Ensure the Nebula device exists before attempting to bind to the Nebula IP (#375)

This commit is contained in:
brad-defined 2021-04-16 11:34:28 -04:00 committed by GitHub
parent ab08be1e3e
commit 17106f83a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 97 additions and 44 deletions

View File

@ -15,8 +15,11 @@ import (
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
type Control struct {
f *Interface
l *logrus.Logger
f *Interface
l *logrus.Logger
sshStart func()
statsStart func()
dnsStart func()
}
type ControlHostInfo struct {
@ -32,6 +35,21 @@ type ControlHostInfo struct {
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
// Activate the interface
c.f.activate()
// Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil {
go c.sshStart()
}
if c.statsStart != nil {
go c.statsStart()
}
if c.dnsStart != nil {
go c.dnsStart()
}
// Start reading packets.
c.f.run()
}

View File

@ -109,7 +109,7 @@ func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m)
}
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) {
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
dnsR = newDnsRecords(hostMap)
// attach request handler func
@ -120,7 +120,10 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) {
c.RegisterReloadCallback(func(c *Config) {
reloadDns(l, c)
})
startDns(l, c)
return func() {
startDns(l, c)
}
}
func getDnsServerAddr(c *Config) string {

View File

@ -130,7 +130,10 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
return ifce, nil
}
func (f *Interface) run() {
// activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called.
func (f *Interface) activate() {
// actually turn on tun dev
addr, err := f.outside.LocalAddr()
@ -159,7 +162,9 @@ func (f *Interface) run() {
if err := f.inside.Activate(); err != nil {
f.l.Fatal(err)
}
}
func (f *Interface) run() {
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ {
go f.listenOut(i)

10
main.go
View File

@ -75,8 +75,9 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
wireSSHReload(l, ssh, config)
var sshStart func()
if config.GetBool("sshd.enabled", false) {
err = configSSH(l, ssh, config)
sshStart, err = configSSH(l, ssh, config)
if err != nil {
return nil, NewContextualError("Error while configuring the sshd", nil, err)
}
@ -393,7 +394,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
go lightHouse.LhUpdateWorker(ifce)
}
err = startStats(l, config, buildVersion, configTest)
statsStart, err := startStats(l, config, buildVersion, configTest)
if err != nil {
return nil, NewContextualError("Failed to start stats emitter", nil, err)
}
@ -408,10 +409,11 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
var dnsStart func()
if amLighthouse && serveDns {
l.Debugln("Starting dns server")
go dnsMain(l, hostMap, config)
dnsStart = dnsMain(l, hostMap, config)
}
return &Control{ifce, l}, nil
return &Control{ifce, l, sshStart, statsStart, dnsStart}, nil
}

32
ssh.go
View File

@ -47,48 +47,55 @@ type sshCreateTunnelFlags struct {
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
c.RegisterReloadCallback(func(c *Config) {
if c.GetBool("sshd.enabled", false) {
err := configSSH(l, ssh, c)
sshRun, err := configSSH(l, ssh, c)
if err != nil {
l.WithError(err).Error("Failed to reconfigure the sshd")
ssh.Stop()
}
if sshRun != nil {
go sshRun()
}
} else {
ssh.Stop()
}
})
}
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error {
// configSSH reads the ssh info out of the passed-in Config and
// updates the passed-in SSHServer. On success, it returns a function
// that callers may invoke to run the configured ssh server. On
// failure, it returns nil, error.
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) (func(), error) {
//TODO conntrack list
//TODO print firewall rules or hash?
listen := c.GetString("sshd.listen", "")
if listen == "" {
return fmt.Errorf("sshd.listen must be provided")
return nil, fmt.Errorf("sshd.listen must be provided")
}
_, port, err := net.SplitHostPort(listen)
if err != nil {
return fmt.Errorf("invalid sshd.listen address: %s", err)
return nil, fmt.Errorf("invalid sshd.listen address: %s", err)
}
if port == "22" {
return fmt.Errorf("sshd.listen can not use port 22")
return nil, fmt.Errorf("sshd.listen can not use port 22")
}
//TODO: no good way to reload this right now
hostKeyFile := c.GetString("sshd.host_key", "")
if hostKeyFile == "" {
return fmt.Errorf("sshd.host_key must be provided")
return nil, fmt.Errorf("sshd.host_key must be provided")
}
hostKeyBytes, err := ioutil.ReadFile(hostKeyFile)
if err != nil {
return fmt.Errorf("error while loading sshd.host_key file: %s", err)
return nil, fmt.Errorf("error while loading sshd.host_key file: %s", err)
}
err = ssh.SetHostKey(hostKeyBytes)
if err != nil {
return fmt.Errorf("error while adding sshd.host_key: %s", err)
return nil, fmt.Errorf("error while adding sshd.host_key: %s", err)
}
rawKeys := c.Get("sshd.authorized_users")
@ -139,14 +146,19 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error {
l.Info("no ssh users to authorize")
}
var runner func()
if c.GetBool("sshd.enabled", false) {
ssh.Stop()
go ssh.Run(listen)
runner = func() {
if err := ssh.Run(listen); err != nil {
l.WithField("err", err).Warn("Failed to run the SSH server")
}
}
} else {
ssh.Stop()
}
return nil
return runner, nil
}
func attachCommands(l *logrus.Logger, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {

View File

@ -141,21 +141,22 @@ func (s *SSHServer) Run(addr string) error {
}
func (s *SSHServer) Stop() {
// Close the listener first, to prevent any new connections being accepted.
if s.listener != nil {
if err := s.listener.Close(); err != nil {
s.l.WithError(err).Warn("Failed to close the sshd listener")
} else {
s.l.Info("SSH server stopped listening")
}
}
// Force close all existing connections.
// TODO I believe this has a slight race if the listener has just accepted
// a connection. Can fix by moving this to the goroutine that's accepting.
for _, c := range s.conns {
c.Close()
}
if s.listener == nil {
return
}
err := s.listener.Close()
if err != nil {
s.l.WithError(err).Warn("Failed to close the sshd listener")
return
}
s.l.Info("SSH server stopped listening")
return
}

View File

@ -17,24 +17,35 @@ import (
"github.com/sirupsen/logrus"
)
func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) error {
// startStats initializes stats from config. On success, if any futher work
// is needed to serve stats, it returns a func to handle that work. If no
// work is needed, it'll return nil. On failure, it returns nil, error.
func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) (func(), error) {
mType := c.GetString("stats.type", "")
if mType == "" || mType == "none" {
return nil
return nil, nil
}
interval := c.GetDuration("stats.interval", 0)
if interval == 0 {
return fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", ""))
return nil, fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", ""))
}
var startFn func()
switch mType {
case "graphite":
startGraphiteStats(l, interval, c, configTest)
err := startGraphiteStats(l, interval, c, configTest)
if err != nil {
return nil, err
}
case "prometheus":
startPrometheusStats(l, interval, c, buildVersion, configTest)
var err error
startFn, err = startPrometheusStats(l, interval, c, buildVersion, configTest)
if err != nil {
return nil, err
}
default:
return fmt.Errorf("stats.type was not understood: %s", mType)
return nil, fmt.Errorf("stats.type was not understood: %s", mType)
}
metrics.RegisterDebugGCStats(metrics.DefaultRegistry)
@ -43,7 +54,7 @@ func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest boo
go metrics.CaptureDebugGCStats(metrics.DefaultRegistry, interval)
go metrics.CaptureRuntimeMemStats(metrics.DefaultRegistry, interval)
return nil
return startFn, nil
}
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
@ -59,25 +70,25 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest
return fmt.Errorf("error while setting up graphite sink: %s", err)
}
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
if !configTest {
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
}
return nil
}
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) error {
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) (func(), error) {
namespace := c.GetString("stats.namespace", "")
subsystem := c.GetString("stats.subsystem", "")
listen := c.GetString("stats.listen", "")
if listen == "" {
return fmt.Errorf("stats.listen should not be empty")
return nil, fmt.Errorf("stats.listen should not be empty")
}
path := c.GetString("stats.path", "")
if path == "" {
return fmt.Errorf("stats.path should not be empty")
return nil, fmt.Errorf("stats.path should not be empty")
}
pr := prometheus.NewRegistry()
@ -98,13 +109,14 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVer
pr.MustRegister(g)
g.Set(1)
var startFn func()
if !configTest {
go func() {
startFn = func() {
l.Infof("Prometheus stats listening on %s at %s", listen, path)
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
log.Fatal(http.ListenAndServe(listen, nil))
}()
}
}
return nil
return startFn, nil
}