From 6ae8ba26f774cb74348349bb069c99221d7a0cea Mon Sep 17 00:00:00 2001 From: brad-defined <77982333+brad-defined@users.noreply.github.com> Date: Tue, 2 Nov 2021 14:14:26 -0400 Subject: [PATCH] Add a context object in nebula.Main to clean up on error (#550) --- config.go | 16 ++++++++++++---- connection_manager.go | 26 +++++++++++++++++--------- connection_manager_test.go | 13 ++++++++++--- control.go | 3 +++ handshake_manager.go | 11 ++++++++--- hostmap.go | 15 +++++++++++++-- interface.go | 22 ++++++++++++++-------- lighthouse.go | 14 ++++++++++++-- main.go | 34 ++++++++++++++++++++++++++-------- stats.go | 4 +++- tun_darwin.go | 7 +++++++ tun_freebsd.go | 7 +++++++ tun_windows.go | 7 +++++++ 13 files changed, 139 insertions(+), 40 deletions(-) diff --git a/config.go b/config.go index 152fd64..c4dce64 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package nebula import ( + "context" "errors" "fmt" "io/ioutil" @@ -114,14 +115,21 @@ func (c *Config) HasChanged(k string) bool { // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the // original path provided to Load. The old settings are shallow copied for change detection after the reload. -func (c *Config) CatchHUP() { +func (c *Config) CatchHUP(ctx context.Context) { ch := make(chan os.Signal, 1) signal.Notify(ch, syscall.SIGHUP) go func() { - for range ch { - c.l.Info("Caught HUP, reloading config") - c.ReloadConfig() + for { + select { + case <-ctx.Done(): + signal.Stop(ch) + close(ch) + return + case <-ch: + c.l.Info("Caught HUP, reloading config") + c.ReloadConfig() + } } }() } diff --git a/connection_manager.go b/connection_manager.go index 78b1a8a..de9b165 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -1,6 +1,7 @@ package nebula import ( + "context" "sync" "time" @@ -32,7 +33,7 @@ type connectionManager struct { // I wanted to call one matLock } -func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager { +func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager { nc := &connectionManager{ hostMap: intf.hostMap, in: make(map[uint32]struct{}), @@ -50,7 +51,7 @@ func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pend pendingDeletionInterval: pendingDeletionInterval, l: l, } - nc.Start() + nc.Start(ctx) return nc } @@ -137,19 +138,26 @@ func (n *connectionManager) AddTrafficWatch(vpnIP uint32, seconds int) { n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds)) } -func (n *connectionManager) Start() { - go n.Run() +func (n *connectionManager) Start(ctx context.Context) { + go n.Run(ctx) } -func (n *connectionManager) Run() { - clockSource := time.Tick(500 * time.Millisecond) +func (n *connectionManager) Run(ctx context.Context) { + clockSource := time.NewTicker(500 * time.Millisecond) + defer clockSource.Stop() + p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) - for now := range clockSource { - n.HandleMonitorTick(now, p, nb, out) - n.HandleDeletionTick(now) + for { + select { + case <-ctx.Done(): + return + case now := <-clockSource.C: + n.HandleMonitorTick(now, p, nb, out) + n.HandleDeletionTick(now) + } } } diff --git a/connection_manager_test.go b/connection_manager_test.go index d3b2b49..fa88640 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -1,6 +1,7 @@ package nebula import ( + "context" "crypto/ed25519" "crypto/rand" "net" @@ -45,7 +46,9 @@ func Test_NewConnectionManagerTest(t *testing.T) { now := time.Now() // Create manager - nc := newConnectionManager(l, ifce, 5, 10) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + nc := newConnectionManager(ctx, l, ifce, 5, 10) p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) @@ -112,7 +115,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) { now := time.Now() // Create manager - nc := newConnectionManager(l, ifce, 5, 10) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + nc := newConnectionManager(ctx, l, ifce, 5, 10) p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) @@ -220,7 +225,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { } // Create manager - nc := newConnectionManager(l, ifce, 5, 10) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + nc := newConnectionManager(ctx, l, ifce, 5, 10) ifce.connectionManager = nc hostinfo := nc.hostMap.AddVpnIP(vpnIP) hostinfo.ConnectionState = &ConnectionState{ diff --git a/control.go b/control.go index 39e1979..4bbe65f 100644 --- a/control.go +++ b/control.go @@ -1,6 +1,7 @@ package nebula import ( + "context" "net" "os" "os/signal" @@ -17,6 +18,7 @@ import ( type Control struct { f *Interface l *logrus.Logger + cancel context.CancelFunc sshStart func() statsStart func() dnsStart func() @@ -57,6 +59,7 @@ func (c *Control) Start() { func (c *Control) Stop() { //TODO: stop tun and udp routines, the lock on hostMap effectively does that though c.CloseAllTunnels(false) + c.cancel() c.l.Info("Goodbye") } diff --git a/handshake_manager.go b/handshake_manager.go index e97d85b..d03cad4 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -2,6 +2,7 @@ package nebula import ( "bytes" + "context" "crypto/rand" "encoding/binary" "errors" @@ -66,14 +67,18 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [ } } -func (c *HandshakeManager) Run(f EncWriter) { - clockSource := time.Tick(c.config.tryInterval) +func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) { + clockSource := time.NewTicker(c.config.tryInterval) + defer clockSource.Stop() + for { select { + case <-ctx.Done(): + return case vpnIP := <-c.trigger: c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered") c.handleOutbound(vpnIP, f, true) - case now := <-clockSource: + case now := <-clockSource.C: c.NextOutboundHandshakeTimerTick(now, f) } } diff --git a/hostmap.go b/hostmap.go index c2b520e..2f46d83 100644 --- a/hostmap.go +++ b/hostmap.go @@ -1,6 +1,7 @@ package nebula import ( + "context" "errors" "fmt" "net" @@ -369,7 +370,7 @@ func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList { } // Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them -func (hm *HostMap) Punchy(conn *udpConn) { +func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) { var metricsTxPunchy metrics.Counter if hm.metricsEnabled { metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil) @@ -379,6 +380,10 @@ func (hm *HostMap) Punchy(conn *udpConn) { var remotes []*RemoteList b := []byte{1} + + clockSource := time.NewTicker(time.Second * 10) + defer clockSource.Stop() + for { remotes = hm.punchList(remotes[:0]) for _, rl := range remotes { @@ -388,7 +393,13 @@ func (hm *HostMap) Punchy(conn *udpConn) { conn.WriteTo(b, addr) } } - time.Sleep(time.Second * 10) + + select { + case <-ctx.Done(): + return + case <-clockSource.C: + continue + } } } diff --git a/interface.go b/interface.go index 9ea6c3b..fc5642a 100644 --- a/interface.go +++ b/interface.go @@ -1,6 +1,7 @@ package nebula import ( + "context" "errors" "io" "net" @@ -86,7 +87,7 @@ type Interface struct { l *logrus.Logger } -func NewInterface(c *InterfaceConfig) (*Interface, error) { +func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { if c.Outside == nil { return nil, errors.New("no outside connection") } @@ -135,7 +136,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) { l: c.l, } - ifce.connectionManager = newConnectionManager(c.l, ifce, c.checkInterval, c.pendingDeletionInterval) + ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval) return ifce, nil } @@ -302,15 +303,20 @@ func (f *Interface) reloadFirewall(c *Config) { Info("New firewall has been installed") } -func (f *Interface) emitStats(i time.Duration) { +func (f *Interface) emitStats(ctx context.Context, i time.Duration) { ticker := time.NewTicker(i) + defer ticker.Stop() udpStats := NewUDPStatsEmitter(f.writers) - for range ticker.C { - f.firewall.EmitStats() - f.handshakeManager.EmitStats() - - udpStats() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + f.firewall.EmitStats() + f.handshakeManager.EmitStats() + udpStats() + } } } diff --git a/lighthouse.go b/lighthouse.go index 56e2851..0c12144 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -1,6 +1,7 @@ package nebula import ( + "context" "encoding/binary" "errors" "fmt" @@ -328,14 +329,23 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udpAddr { return NewUDPAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) } -func (lh *LightHouse) LhUpdateWorker(f EncWriter) { +func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) { if lh.amLighthouse || lh.interval == 0 { return } + clockSource := time.NewTicker(time.Second * time.Duration(lh.interval)) + defer clockSource.Stop() + for { lh.SendUpdate(f) - time.Sleep(time.Second * time.Duration(lh.interval)) + + select { + case <-ctx.Done(): + return + case <-clockSource.C: + continue + } } } diff --git a/main.go b/main.go index f18a971..048a4f3 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package nebula import ( + "context" "encoding/binary" "fmt" "net" @@ -13,7 +14,16 @@ import ( type m map[string]interface{} -func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) { +func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) { + + ctx, cancel := context.WithCancel(context.Background()) + // Automatically cancel the context if Main returns an error, to signal all created goroutines to quit. + defer func() { + if reterr != nil { + cancel() + } + }() + l := logger l.Formatter = &logrus.TextFormatter{ FullTimestamp: true, @@ -126,7 +136,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L var tun Inside if !configTest { - config.CatchHUP() + config.CatchHUP(ctx) switch { case config.GetBool("tun.disabled", false): @@ -159,6 +169,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } } + defer func() { + if reterr != nil { + tun.Close() + } + }() + // set up our UDP listener udpConns := make([]*udpConn, routines) port := config.GetInt("listen.port", 0) @@ -236,7 +252,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L punchy := NewPunchyFromConfig(config) if punchy.Punch && !configTest { l.Info("UDP hole punching enabled") - go hostMap.Punchy(udpConns[0]) + go hostMap.Punchy(ctx, udpConns[0]) } amLighthouse := config.GetBool("lighthouse.am_lighthouse", false) @@ -388,7 +404,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L var ifce *Interface if !configTest { - ifce, err = NewInterface(ifConfig) + ifce, err = NewInterface(ctx, ifConfig) if err != nil { return nil, fmt.Errorf("failed to initialize interface: %s", err) } @@ -399,10 +415,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L ifce.RegisterConfigChangeCallbacks(config) - go handshakeManager.Run(ifce) - go lightHouse.LhUpdateWorker(ifce) + go handshakeManager.Run(ctx, ifce) + go lightHouse.LhUpdateWorker(ctx, ifce) } + // TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept + // a context so that they can exit when the context is Done. statsStart, err := startStats(l, config, buildVersion, configTest) if err != nil { return nil, NewContextualError("Failed to start stats emitter", nil, err) @@ -413,7 +431,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } //TODO: check if we _should_ be emitting stats - go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10)) + go ifce.emitStats(ctx, config.GetDuration("stats.interval", time.Second*10)) attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) @@ -424,5 +442,5 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L dnsStart = dnsMain(l, hostMap, config) } - return &Control{ifce, l, sshStart, statsStart, dnsStart}, nil + return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil } diff --git a/stats.go b/stats.go index 205a89a..94e75ef 100644 --- a/stats.go +++ b/stats.go @@ -93,7 +93,9 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVer pr := prometheus.NewRegistry() pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i) - go pClient.UpdatePrometheusMetrics() + if !configTest { + go pClient.UpdatePrometheusMetrics() + } // Export our version information as labels on a static gauge g := prometheus.NewGauge(prometheus.GaugeOpts{ diff --git a/tun_darwin.go b/tun_darwin.go index e2801ba..079c80e 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -41,6 +41,13 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } +func (c *Tun) Close() error { + if c.Interface != nil { + return c.Interface.Close() + } + return nil +} + func (c *Tun) Activate() error { var err error c.Interface, err = water.New(water.Config{ diff --git a/tun_freebsd.go b/tun_freebsd.go index accbd40..482edad 100644 --- a/tun_freebsd.go +++ b/tun_freebsd.go @@ -28,6 +28,13 @@ type Tun struct { io.ReadWriteCloser } +func (c *Tun) Close() error { + if c.ReadWriteCloser != nil { + return c.ReadWriteCloser.Close() + } + return nil +} + func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } diff --git a/tun_windows.go b/tun_windows.go index 8dcc002..17bbe4e 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -24,6 +24,13 @@ type Tun struct { *water.Interface } +func (c *Tun) Close() error { + if c.Interface != nil { + return c.Interface.Close() + } + return nil +} + func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") }