Add a context object in nebula.Main to clean up on error (#550)

This commit is contained in:
brad-defined 2021-11-02 14:14:26 -04:00 committed by GitHub
parent 32cd9a93f1
commit 6ae8ba26f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 139 additions and 40 deletions

View File

@ -1,6 +1,7 @@
package nebula
import (
"context"
"errors"
"fmt"
"io/ioutil"
@ -114,14 +115,21 @@ func (c *Config) HasChanged(k string) bool {
// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
// original path provided to Load. The old settings are shallow copied for change detection after the reload.
func (c *Config) CatchHUP() {
func (c *Config) CatchHUP(ctx context.Context) {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGHUP)
go func() {
for range ch {
c.l.Info("Caught HUP, reloading config")
c.ReloadConfig()
for {
select {
case <-ctx.Done():
signal.Stop(ch)
close(ch)
return
case <-ch:
c.l.Info("Caught HUP, reloading config")
c.ReloadConfig()
}
}
}()
}

View File

@ -1,6 +1,7 @@
package nebula
import (
"context"
"sync"
"time"
@ -32,7 +33,7 @@ type connectionManager struct {
// I wanted to call one matLock
}
func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
nc := &connectionManager{
hostMap: intf.hostMap,
in: make(map[uint32]struct{}),
@ -50,7 +51,7 @@ func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pend
pendingDeletionInterval: pendingDeletionInterval,
l: l,
}
nc.Start()
nc.Start(ctx)
return nc
}
@ -137,19 +138,26 @@ func (n *connectionManager) AddTrafficWatch(vpnIP uint32, seconds int) {
n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds))
}
func (n *connectionManager) Start() {
go n.Run()
func (n *connectionManager) Start(ctx context.Context) {
go n.Run(ctx)
}
func (n *connectionManager) Run() {
clockSource := time.Tick(500 * time.Millisecond)
func (n *connectionManager) Run(ctx context.Context) {
clockSource := time.NewTicker(500 * time.Millisecond)
defer clockSource.Stop()
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for now := range clockSource {
n.HandleMonitorTick(now, p, nb, out)
n.HandleDeletionTick(now)
for {
select {
case <-ctx.Done():
return
case now := <-clockSource.C:
n.HandleMonitorTick(now, p, nb, out)
n.HandleDeletionTick(now)
}
}
}

View File

@ -1,6 +1,7 @@
package nebula
import (
"context"
"crypto/ed25519"
"crypto/rand"
"net"
@ -45,7 +46,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
now := time.Now()
// Create manager
nc := newConnectionManager(l, ifce, 5, 10)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
nc := newConnectionManager(ctx, l, ifce, 5, 10)
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
@ -112,7 +115,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
now := time.Now()
// Create manager
nc := newConnectionManager(l, ifce, 5, 10)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
nc := newConnectionManager(ctx, l, ifce, 5, 10)
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
@ -220,7 +225,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
}
// Create manager
nc := newConnectionManager(l, ifce, 5, 10)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
nc := newConnectionManager(ctx, l, ifce, 5, 10)
ifce.connectionManager = nc
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
hostinfo.ConnectionState = &ConnectionState{

View File

@ -1,6 +1,7 @@
package nebula
import (
"context"
"net"
"os"
"os/signal"
@ -17,6 +18,7 @@ import (
type Control struct {
f *Interface
l *logrus.Logger
cancel context.CancelFunc
sshStart func()
statsStart func()
dnsStart func()
@ -57,6 +59,7 @@ func (c *Control) Start() {
func (c *Control) Stop() {
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
c.CloseAllTunnels(false)
c.cancel()
c.l.Info("Goodbye")
}

View File

@ -2,6 +2,7 @@ package nebula
import (
"bytes"
"context"
"crypto/rand"
"encoding/binary"
"errors"
@ -66,14 +67,18 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
}
}
func (c *HandshakeManager) Run(f EncWriter) {
clockSource := time.Tick(c.config.tryInterval)
func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
clockSource := time.NewTicker(c.config.tryInterval)
defer clockSource.Stop()
for {
select {
case <-ctx.Done():
return
case vpnIP := <-c.trigger:
c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
c.handleOutbound(vpnIP, f, true)
case now := <-clockSource:
case now := <-clockSource.C:
c.NextOutboundHandshakeTimerTick(now, f)
}
}

View File

@ -1,6 +1,7 @@
package nebula
import (
"context"
"errors"
"fmt"
"net"
@ -369,7 +370,7 @@ func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
}
// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
func (hm *HostMap) Punchy(conn *udpConn) {
func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
var metricsTxPunchy metrics.Counter
if hm.metricsEnabled {
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
@ -379,6 +380,10 @@ func (hm *HostMap) Punchy(conn *udpConn) {
var remotes []*RemoteList
b := []byte{1}
clockSource := time.NewTicker(time.Second * 10)
defer clockSource.Stop()
for {
remotes = hm.punchList(remotes[:0])
for _, rl := range remotes {
@ -388,7 +393,13 @@ func (hm *HostMap) Punchy(conn *udpConn) {
conn.WriteTo(b, addr)
}
}
time.Sleep(time.Second * 10)
select {
case <-ctx.Done():
return
case <-clockSource.C:
continue
}
}
}

View File

@ -1,6 +1,7 @@
package nebula
import (
"context"
"errors"
"io"
"net"
@ -86,7 +87,7 @@ type Interface struct {
l *logrus.Logger
}
func NewInterface(c *InterfaceConfig) (*Interface, error) {
func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
if c.Outside == nil {
return nil, errors.New("no outside connection")
}
@ -135,7 +136,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
l: c.l,
}
ifce.connectionManager = newConnectionManager(c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
return ifce, nil
}
@ -302,15 +303,20 @@ func (f *Interface) reloadFirewall(c *Config) {
Info("New firewall has been installed")
}
func (f *Interface) emitStats(i time.Duration) {
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
ticker := time.NewTicker(i)
defer ticker.Stop()
udpStats := NewUDPStatsEmitter(f.writers)
for range ticker.C {
f.firewall.EmitStats()
f.handshakeManager.EmitStats()
udpStats()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
f.firewall.EmitStats()
f.handshakeManager.EmitStats()
udpStats()
}
}
}

View File

@ -1,6 +1,7 @@
package nebula
import (
"context"
"encoding/binary"
"errors"
"fmt"
@ -328,14 +329,23 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udpAddr {
return NewUDPAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
}
func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
if lh.amLighthouse || lh.interval == 0 {
return
}
clockSource := time.NewTicker(time.Second * time.Duration(lh.interval))
defer clockSource.Stop()
for {
lh.SendUpdate(f)
time.Sleep(time.Second * time.Duration(lh.interval))
select {
case <-ctx.Done():
return
case <-clockSource.C:
continue
}
}
}

34
main.go
View File

@ -1,6 +1,7 @@
package nebula
import (
"context"
"encoding/binary"
"fmt"
"net"
@ -13,7 +14,16 @@ import (
type m map[string]interface{}
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
ctx, cancel := context.WithCancel(context.Background())
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
defer func() {
if reterr != nil {
cancel()
}
}()
l := logger
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
@ -126,7 +136,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
var tun Inside
if !configTest {
config.CatchHUP()
config.CatchHUP(ctx)
switch {
case config.GetBool("tun.disabled", false):
@ -159,6 +169,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
}
defer func() {
if reterr != nil {
tun.Close()
}
}()
// set up our UDP listener
udpConns := make([]*udpConn, routines)
port := config.GetInt("listen.port", 0)
@ -236,7 +252,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
punchy := NewPunchyFromConfig(config)
if punchy.Punch && !configTest {
l.Info("UDP hole punching enabled")
go hostMap.Punchy(udpConns[0])
go hostMap.Punchy(ctx, udpConns[0])
}
amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
@ -388,7 +404,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
var ifce *Interface
if !configTest {
ifce, err = NewInterface(ifConfig)
ifce, err = NewInterface(ctx, ifConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize interface: %s", err)
}
@ -399,10 +415,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
ifce.RegisterConfigChangeCallbacks(config)
go handshakeManager.Run(ifce)
go lightHouse.LhUpdateWorker(ifce)
go handshakeManager.Run(ctx, ifce)
go lightHouse.LhUpdateWorker(ctx, ifce)
}
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
// a context so that they can exit when the context is Done.
statsStart, err := startStats(l, config, buildVersion, configTest)
if err != nil {
return nil, NewContextualError("Failed to start stats emitter", nil, err)
@ -413,7 +431,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
//TODO: check if we _should_ be emitting stats
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
go ifce.emitStats(ctx, config.GetDuration("stats.interval", time.Second*10))
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
@ -424,5 +442,5 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
dnsStart = dnsMain(l, hostMap, config)
}
return &Control{ifce, l, sshStart, statsStart, dnsStart}, nil
return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil
}

View File

@ -93,7 +93,9 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVer
pr := prometheus.NewRegistry()
pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i)
go pClient.UpdatePrometheusMetrics()
if !configTest {
go pClient.UpdatePrometheusMetrics()
}
// Export our version information as labels on a static gauge
g := prometheus.NewGauge(prometheus.GaugeOpts{

View File

@ -41,6 +41,13 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
func (c *Tun) Close() error {
if c.Interface != nil {
return c.Interface.Close()
}
return nil
}
func (c *Tun) Activate() error {
var err error
c.Interface, err = water.New(water.Config{

View File

@ -28,6 +28,13 @@ type Tun struct {
io.ReadWriteCloser
}
func (c *Tun) Close() error {
if c.ReadWriteCloser != nil {
return c.ReadWriteCloser.Close()
}
return nil
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}

View File

@ -24,6 +24,13 @@ type Tun struct {
*water.Interface
}
func (c *Tun) Close() error {
if c.Interface != nil {
return c.Interface.Close()
}
return nil
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}