Add a context object in nebula.Main to clean up on error (#550)

This commit is contained in:
brad-defined 2021-11-02 14:14:26 -04:00 committed by GitHub
parent 32cd9a93f1
commit 6ae8ba26f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 139 additions and 40 deletions

View File

@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -114,14 +115,21 @@ func (c *Config) HasChanged(k string) bool {
// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
// original path provided to Load. The old settings are shallow copied for change detection after the reload. // original path provided to Load. The old settings are shallow copied for change detection after the reload.
func (c *Config) CatchHUP() { func (c *Config) CatchHUP(ctx context.Context) {
ch := make(chan os.Signal, 1) ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGHUP) signal.Notify(ch, syscall.SIGHUP)
go func() { go func() {
for range ch { for {
c.l.Info("Caught HUP, reloading config") select {
c.ReloadConfig() case <-ctx.Done():
signal.Stop(ch)
close(ch)
return
case <-ch:
c.l.Info("Caught HUP, reloading config")
c.ReloadConfig()
}
} }
}() }()
} }

View File

@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"context"
"sync" "sync"
"time" "time"
@ -32,7 +33,7 @@ type connectionManager struct {
// I wanted to call one matLock // I wanted to call one matLock
} }
func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager { func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
nc := &connectionManager{ nc := &connectionManager{
hostMap: intf.hostMap, hostMap: intf.hostMap,
in: make(map[uint32]struct{}), in: make(map[uint32]struct{}),
@ -50,7 +51,7 @@ func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pend
pendingDeletionInterval: pendingDeletionInterval, pendingDeletionInterval: pendingDeletionInterval,
l: l, l: l,
} }
nc.Start() nc.Start(ctx)
return nc return nc
} }
@ -137,19 +138,26 @@ func (n *connectionManager) AddTrafficWatch(vpnIP uint32, seconds int) {
n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds)) n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds))
} }
func (n *connectionManager) Start() { func (n *connectionManager) Start(ctx context.Context) {
go n.Run() go n.Run(ctx)
} }
func (n *connectionManager) Run() { func (n *connectionManager) Run(ctx context.Context) {
clockSource := time.Tick(500 * time.Millisecond) clockSource := time.NewTicker(500 * time.Millisecond)
defer clockSource.Stop()
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for now := range clockSource { for {
n.HandleMonitorTick(now, p, nb, out) select {
n.HandleDeletionTick(now) case <-ctx.Done():
return
case now := <-clockSource.C:
n.HandleMonitorTick(now, p, nb, out)
n.HandleDeletionTick(now)
}
} }
} }

View File

@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"context"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"net" "net"
@ -45,7 +46,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
now := time.Now() now := time.Now()
// Create manager // Create manager
nc := newConnectionManager(l, ifce, 5, 10) ctx, cancel := context.WithCancel(context.Background())
defer cancel()
nc := newConnectionManager(ctx, l, ifce, 5, 10)
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
@ -112,7 +115,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
now := time.Now() now := time.Now()
// Create manager // Create manager
nc := newConnectionManager(l, ifce, 5, 10) ctx, cancel := context.WithCancel(context.Background())
defer cancel()
nc := newConnectionManager(ctx, l, ifce, 5, 10)
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
@ -220,7 +225,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
} }
// Create manager // Create manager
nc := newConnectionManager(l, ifce, 5, 10) ctx, cancel := context.WithCancel(context.Background())
defer cancel()
nc := newConnectionManager(ctx, l, ifce, 5, 10)
ifce.connectionManager = nc ifce.connectionManager = nc
hostinfo := nc.hostMap.AddVpnIP(vpnIP) hostinfo := nc.hostMap.AddVpnIP(vpnIP)
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{

View File

@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"context"
"net" "net"
"os" "os"
"os/signal" "os/signal"
@ -17,6 +18,7 @@ import (
type Control struct { type Control struct {
f *Interface f *Interface
l *logrus.Logger l *logrus.Logger
cancel context.CancelFunc
sshStart func() sshStart func()
statsStart func() statsStart func()
dnsStart func() dnsStart func()
@ -57,6 +59,7 @@ func (c *Control) Start() {
func (c *Control) Stop() { func (c *Control) Stop() {
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though //TODO: stop tun and udp routines, the lock on hostMap effectively does that though
c.CloseAllTunnels(false) c.CloseAllTunnels(false)
c.cancel()
c.l.Info("Goodbye") c.l.Info("Goodbye")
} }

View File

@ -2,6 +2,7 @@ package nebula
import ( import (
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
@ -66,14 +67,18 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
} }
} }
func (c *HandshakeManager) Run(f EncWriter) { func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
clockSource := time.Tick(c.config.tryInterval) clockSource := time.NewTicker(c.config.tryInterval)
defer clockSource.Stop()
for { for {
select { select {
case <-ctx.Done():
return
case vpnIP := <-c.trigger: case vpnIP := <-c.trigger:
c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered") c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
c.handleOutbound(vpnIP, f, true) c.handleOutbound(vpnIP, f, true)
case now := <-clockSource: case now := <-clockSource.C:
c.NextOutboundHandshakeTimerTick(now, f) c.NextOutboundHandshakeTimerTick(now, f)
} }
} }

View File

@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -369,7 +370,7 @@ func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
} }
// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them // Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
func (hm *HostMap) Punchy(conn *udpConn) { func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
var metricsTxPunchy metrics.Counter var metricsTxPunchy metrics.Counter
if hm.metricsEnabled { if hm.metricsEnabled {
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil) metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
@ -379,6 +380,10 @@ func (hm *HostMap) Punchy(conn *udpConn) {
var remotes []*RemoteList var remotes []*RemoteList
b := []byte{1} b := []byte{1}
clockSource := time.NewTicker(time.Second * 10)
defer clockSource.Stop()
for { for {
remotes = hm.punchList(remotes[:0]) remotes = hm.punchList(remotes[:0])
for _, rl := range remotes { for _, rl := range remotes {
@ -388,7 +393,13 @@ func (hm *HostMap) Punchy(conn *udpConn) {
conn.WriteTo(b, addr) conn.WriteTo(b, addr)
} }
} }
time.Sleep(time.Second * 10)
select {
case <-ctx.Done():
return
case <-clockSource.C:
continue
}
} }
} }

View File

@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"context"
"errors" "errors"
"io" "io"
"net" "net"
@ -86,7 +87,7 @@ type Interface struct {
l *logrus.Logger l *logrus.Logger
} }
func NewInterface(c *InterfaceConfig) (*Interface, error) { func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
if c.Outside == nil { if c.Outside == nil {
return nil, errors.New("no outside connection") return nil, errors.New("no outside connection")
} }
@ -135,7 +136,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
l: c.l, l: c.l,
} }
ifce.connectionManager = newConnectionManager(c.l, ifce, c.checkInterval, c.pendingDeletionInterval) ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
return ifce, nil return ifce, nil
} }
@ -302,15 +303,20 @@ func (f *Interface) reloadFirewall(c *Config) {
Info("New firewall has been installed") Info("New firewall has been installed")
} }
func (f *Interface) emitStats(i time.Duration) { func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
ticker := time.NewTicker(i) ticker := time.NewTicker(i)
defer ticker.Stop()
udpStats := NewUDPStatsEmitter(f.writers) udpStats := NewUDPStatsEmitter(f.writers)
for range ticker.C { for {
f.firewall.EmitStats() select {
f.handshakeManager.EmitStats() case <-ctx.Done():
return
udpStats() case <-ticker.C:
f.firewall.EmitStats()
f.handshakeManager.EmitStats()
udpStats()
}
} }
} }

View File

@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -328,14 +329,23 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udpAddr {
return NewUDPAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) return NewUDPAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
} }
func (lh *LightHouse) LhUpdateWorker(f EncWriter) { func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
if lh.amLighthouse || lh.interval == 0 { if lh.amLighthouse || lh.interval == 0 {
return return
} }
clockSource := time.NewTicker(time.Second * time.Duration(lh.interval))
defer clockSource.Stop()
for { for {
lh.SendUpdate(f) lh.SendUpdate(f)
time.Sleep(time.Second * time.Duration(lh.interval))
select {
case <-ctx.Done():
return
case <-clockSource.C:
continue
}
} }
} }

34
main.go
View File

@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
@ -13,7 +14,16 @@ import (
type m map[string]interface{} type m map[string]interface{}
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) { func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
ctx, cancel := context.WithCancel(context.Background())
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
defer func() {
if reterr != nil {
cancel()
}
}()
l := logger l := logger
l.Formatter = &logrus.TextFormatter{ l.Formatter = &logrus.TextFormatter{
FullTimestamp: true, FullTimestamp: true,
@ -126,7 +136,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
var tun Inside var tun Inside
if !configTest { if !configTest {
config.CatchHUP() config.CatchHUP(ctx)
switch { switch {
case config.GetBool("tun.disabled", false): case config.GetBool("tun.disabled", false):
@ -159,6 +169,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
} }
} }
defer func() {
if reterr != nil {
tun.Close()
}
}()
// set up our UDP listener // set up our UDP listener
udpConns := make([]*udpConn, routines) udpConns := make([]*udpConn, routines)
port := config.GetInt("listen.port", 0) port := config.GetInt("listen.port", 0)
@ -236,7 +252,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
punchy := NewPunchyFromConfig(config) punchy := NewPunchyFromConfig(config)
if punchy.Punch && !configTest { if punchy.Punch && !configTest {
l.Info("UDP hole punching enabled") l.Info("UDP hole punching enabled")
go hostMap.Punchy(udpConns[0]) go hostMap.Punchy(ctx, udpConns[0])
} }
amLighthouse := config.GetBool("lighthouse.am_lighthouse", false) amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
@ -388,7 +404,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
var ifce *Interface var ifce *Interface
if !configTest { if !configTest {
ifce, err = NewInterface(ifConfig) ifce, err = NewInterface(ctx, ifConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize interface: %s", err) return nil, fmt.Errorf("failed to initialize interface: %s", err)
} }
@ -399,10 +415,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
ifce.RegisterConfigChangeCallbacks(config) ifce.RegisterConfigChangeCallbacks(config)
go handshakeManager.Run(ifce) go handshakeManager.Run(ctx, ifce)
go lightHouse.LhUpdateWorker(ifce) go lightHouse.LhUpdateWorker(ctx, ifce)
} }
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
// a context so that they can exit when the context is Done.
statsStart, err := startStats(l, config, buildVersion, configTest) statsStart, err := startStats(l, config, buildVersion, configTest)
if err != nil { if err != nil {
return nil, NewContextualError("Failed to start stats emitter", nil, err) return nil, NewContextualError("Failed to start stats emitter", nil, err)
@ -413,7 +431,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
} }
//TODO: check if we _should_ be emitting stats //TODO: check if we _should_ be emitting stats
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10)) go ifce.emitStats(ctx, config.GetDuration("stats.interval", time.Second*10))
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
@ -424,5 +442,5 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
dnsStart = dnsMain(l, hostMap, config) dnsStart = dnsMain(l, hostMap, config)
} }
return &Control{ifce, l, sshStart, statsStart, dnsStart}, nil return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil
} }

View File

@ -93,7 +93,9 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVer
pr := prometheus.NewRegistry() pr := prometheus.NewRegistry()
pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i) pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i)
go pClient.UpdatePrometheusMetrics() if !configTest {
go pClient.UpdatePrometheusMetrics()
}
// Export our version information as labels on a static gauge // Export our version information as labels on a static gauge
g := prometheus.NewGauge(prometheus.GaugeOpts{ g := prometheus.NewGauge(prometheus.GaugeOpts{

View File

@ -41,6 +41,13 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
return nil, fmt.Errorf("newTunFromFd not supported in Darwin") return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
} }
func (c *Tun) Close() error {
if c.Interface != nil {
return c.Interface.Close()
}
return nil
}
func (c *Tun) Activate() error { func (c *Tun) Activate() error {
var err error var err error
c.Interface, err = water.New(water.Config{ c.Interface, err = water.New(water.Config{

View File

@ -28,6 +28,13 @@ type Tun struct {
io.ReadWriteCloser io.ReadWriteCloser
} }
func (c *Tun) Close() error {
if c.ReadWriteCloser != nil {
return c.ReadWriteCloser.Close()
}
return nil
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
} }

View File

@ -24,6 +24,13 @@ type Tun struct {
*water.Interface *water.Interface
} }
func (c *Tun) Close() error {
if c.Interface != nil {
return c.Interface.Close()
}
return nil
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows") return nil, fmt.Errorf("newTunFromFd not supported in Windows")
} }