From 3ea7e1b75f5b839615e73de27061a88f896bdfd0 Mon Sep 17 00:00:00 2001 From: Nathan Brown Date: Fri, 26 Mar 2021 09:46:30 -0500 Subject: [PATCH] Don't use a global logger (#423) --- bits.go | 4 +- bits_test.go | 154 +++++++++++++++++----------------- cert.go | 3 +- cmd/nebula-service/main.go | 7 +- cmd/nebula-service/service.go | 7 +- cmd/nebula/main.go | 7 +- config.go | 18 ++-- config_test.go | 27 +++--- connection_manager.go | 20 +++-- connection_manager_test.go | 20 +++-- connection_state.go | 5 +- control_test.go | 3 +- dns_server.go | 25 +++--- firewall.go | 31 +++---- firewall_test.go | 149 ++++++++++++++++---------------- handshake.go | 2 +- handshake_ix.go | 66 +++++++-------- handshake_manager.go | 20 +++-- handshake_manager_test.go | 26 +++--- hostmap.go | 54 ++++++------ hostmap_test.go | 9 +- inside.go | 40 ++++----- interface.go | 40 +++++---- lighthouse.go | 36 ++++---- lighthouse_test.go | 19 +++-- main.go | 29 +++---- main_test.go | 29 +++++++ outside.go | 50 +++++------ punchy_test.go | 3 +- ssh.go | 20 +++-- stats.go | 11 +-- tun_android.go | 5 +- tun_darwin.go | 8 +- tun_disabled.go | 29 ++++--- tun_freebsd.go | 16 ++-- tun_linux.go | 12 ++- tun_test.go | 6 +- tun_windows.go | 7 +- udp_android.go | 2 + udp_darwin.go | 2 + udp_freebsd.go | 2 + udp_generic.go | 12 ++- udp_linux.go | 23 ++--- udp_linux_32.go | 1 + udp_linux_64.go | 1 + 45 files changed, 590 insertions(+), 470 deletions(-) diff --git a/bits.go b/bits.go index 49cadc1..b4f96c6 100644 --- a/bits.go +++ b/bits.go @@ -26,7 +26,7 @@ func NewBits(bits uint64) *Bits { } } -func (b *Bits) Check(i uint64) bool { +func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool { // If i is the next number, return true. if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) { return true @@ -47,7 +47,7 @@ func (b *Bits) Check(i uint64) bool { return false } -func (b *Bits) Update(i uint64) bool { +func (b *Bits) Update(l *logrus.Logger, i uint64) bool { // If i is the next number, return true and update current. if i == b.current+1 { // Report missed packets, we can only understand what was missed after the first window has been gone through diff --git a/bits_test.go b/bits_test.go index 880fcd9..50d00bc 100644 --- a/bits_test.go +++ b/bits_test.go @@ -7,6 +7,7 @@ import ( ) func TestBits(t *testing.T) { + l := NewTestLogger() b := NewBits(10) // make sure it is the right size @@ -14,46 +15,46 @@ func TestBits(t *testing.T) { // This is initialized to zero - receive one. This should work. - assert.True(t, b.Check(1)) - u := b.Update(1) + assert.True(t, b.Check(l, 1)) + u := b.Update(l, 1) assert.True(t, u) assert.EqualValues(t, 1, b.current) g := []bool{false, true, false, false, false, false, false, false, false, false} assert.Equal(t, g, b.bits) // Receive two - assert.True(t, b.Check(2)) - u = b.Update(2) + assert.True(t, b.Check(l, 2)) + u = b.Update(l, 2) assert.True(t, u) assert.EqualValues(t, 2, b.current) g = []bool{false, true, true, false, false, false, false, false, false, false} assert.Equal(t, g, b.bits) // Receive two again - it will fail - assert.False(t, b.Check(2)) - u = b.Update(2) + assert.False(t, b.Check(l, 2)) + u = b.Update(l, 2) assert.False(t, u) assert.EqualValues(t, 2, b.current) // Jump ahead to 15, which should clear everything and set the 6th element - assert.True(t, b.Check(15)) - u = b.Update(15) + assert.True(t, b.Check(l, 15)) + u = b.Update(l, 15) assert.True(t, u) assert.EqualValues(t, 15, b.current) g = []bool{false, false, false, false, false, true, false, false, false, false} assert.Equal(t, g, b.bits) // Mark 14, which is allowed because it is in the window - assert.True(t, b.Check(14)) - u = b.Update(14) + assert.True(t, b.Check(l, 14)) + u = b.Update(l, 14) assert.True(t, u) assert.EqualValues(t, 15, b.current) g = []bool{false, false, false, false, true, true, false, false, false, false} assert.Equal(t, g, b.bits) // Mark 5, which is not allowed because it is not in the window - assert.False(t, b.Check(5)) - u = b.Update(5) + assert.False(t, b.Check(l, 5)) + u = b.Update(l, 5) assert.False(t, u) assert.EqualValues(t, 15, b.current) g = []bool{false, false, false, false, true, true, false, false, false, false} @@ -61,63 +62,65 @@ func TestBits(t *testing.T) { // make sure we handle wrapping around once to the current position b = NewBits(10) - assert.True(t, b.Update(1)) - assert.True(t, b.Update(11)) + assert.True(t, b.Update(l, 1)) + assert.True(t, b.Update(l, 11)) assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits) // Walk through a few windows in order b = NewBits(10) for i := uint64(0); i <= 100; i++ { - assert.True(t, b.Check(i), "Error while checking %v", i) - assert.True(t, b.Update(i), "Error while updating %v", i) + assert.True(t, b.Check(l, i), "Error while checking %v", i) + assert.True(t, b.Update(l, i), "Error while updating %v", i) } } func TestBitsDupeCounter(t *testing.T) { + l := NewTestLogger() b := NewBits(10) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() - assert.True(t, b.Update(1)) + assert.True(t, b.Update(l, 1)) assert.Equal(t, int64(0), b.dupeCounter.Count()) - assert.False(t, b.Update(1)) + assert.False(t, b.Update(l, 1)) assert.Equal(t, int64(1), b.dupeCounter.Count()) - assert.True(t, b.Update(2)) + assert.True(t, b.Update(l, 2)) assert.Equal(t, int64(1), b.dupeCounter.Count()) - assert.True(t, b.Update(3)) + assert.True(t, b.Update(l, 3)) assert.Equal(t, int64(1), b.dupeCounter.Count()) - assert.False(t, b.Update(1)) + assert.False(t, b.Update(l, 1)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(2), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) } func TestBitsOutOfWindowCounter(t *testing.T) { + l := NewTestLogger() b := NewBits(10) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() - assert.True(t, b.Update(20)) + assert.True(t, b.Update(l, 20)) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) - assert.True(t, b.Update(21)) - assert.True(t, b.Update(22)) - assert.True(t, b.Update(23)) - assert.True(t, b.Update(24)) - assert.True(t, b.Update(25)) - assert.True(t, b.Update(26)) - assert.True(t, b.Update(27)) - assert.True(t, b.Update(28)) - assert.True(t, b.Update(29)) + assert.True(t, b.Update(l, 21)) + assert.True(t, b.Update(l, 22)) + assert.True(t, b.Update(l, 23)) + assert.True(t, b.Update(l, 24)) + assert.True(t, b.Update(l, 25)) + assert.True(t, b.Update(l, 26)) + assert.True(t, b.Update(l, 27)) + assert.True(t, b.Update(l, 28)) + assert.True(t, b.Update(l, 29)) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) - assert.False(t, b.Update(0)) + assert.False(t, b.Update(l, 0)) assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) //tODO: make sure lostcounter doesn't increase in orderly increment @@ -127,23 +130,24 @@ func TestBitsOutOfWindowCounter(t *testing.T) { } func TestBitsLostCounter(t *testing.T) { + l := NewTestLogger() b := NewBits(10) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() //assert.True(t, b.Update(0)) - assert.True(t, b.Update(0)) - assert.True(t, b.Update(20)) - assert.True(t, b.Update(21)) - assert.True(t, b.Update(22)) - assert.True(t, b.Update(23)) - assert.True(t, b.Update(24)) - assert.True(t, b.Update(25)) - assert.True(t, b.Update(26)) - assert.True(t, b.Update(27)) - assert.True(t, b.Update(28)) - assert.True(t, b.Update(29)) + assert.True(t, b.Update(l, 0)) + assert.True(t, b.Update(l, 20)) + assert.True(t, b.Update(l, 21)) + assert.True(t, b.Update(l, 22)) + assert.True(t, b.Update(l, 23)) + assert.True(t, b.Update(l, 24)) + assert.True(t, b.Update(l, 25)) + assert.True(t, b.Update(l, 26)) + assert.True(t, b.Update(l, 27)) + assert.True(t, b.Update(l, 28)) + assert.True(t, b.Update(l, 29)) assert.Equal(t, int64(20), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) @@ -153,56 +157,56 @@ func TestBitsLostCounter(t *testing.T) { b.dupeCounter.Clear() b.outOfWindowCounter.Clear() - assert.True(t, b.Update(0)) + assert.True(t, b.Update(l, 0)) assert.Equal(t, int64(0), b.lostCounter.Count()) - assert.True(t, b.Update(9)) + assert.True(t, b.Update(l, 9)) assert.Equal(t, int64(0), b.lostCounter.Count()) // 10 will set 0 index, 0 was already set, no lost packets - assert.True(t, b.Update(10)) + assert.True(t, b.Update(l, 10)) assert.Equal(t, int64(0), b.lostCounter.Count()) // 11 will set 1 index, 1 was missed, we should see 1 packet lost - assert.True(t, b.Update(11)) + assert.True(t, b.Update(l, 11)) assert.Equal(t, int64(1), b.lostCounter.Count()) // Now let's fill in the window, should end up with 8 lost packets - assert.True(t, b.Update(12)) - assert.True(t, b.Update(13)) - assert.True(t, b.Update(14)) - assert.True(t, b.Update(15)) - assert.True(t, b.Update(16)) - assert.True(t, b.Update(17)) - assert.True(t, b.Update(18)) - assert.True(t, b.Update(19)) + assert.True(t, b.Update(l, 12)) + assert.True(t, b.Update(l, 13)) + assert.True(t, b.Update(l, 14)) + assert.True(t, b.Update(l, 15)) + assert.True(t, b.Update(l, 16)) + assert.True(t, b.Update(l, 17)) + assert.True(t, b.Update(l, 18)) + assert.True(t, b.Update(l, 19)) assert.Equal(t, int64(8), b.lostCounter.Count()) // Jump ahead by a window size - assert.True(t, b.Update(29)) + assert.True(t, b.Update(l, 29)) assert.Equal(t, int64(8), b.lostCounter.Count()) // Now lets walk ahead normally through the window, the missed packets should fill in - assert.True(t, b.Update(30)) - assert.True(t, b.Update(31)) - assert.True(t, b.Update(32)) - assert.True(t, b.Update(33)) - assert.True(t, b.Update(34)) - assert.True(t, b.Update(35)) - assert.True(t, b.Update(36)) - assert.True(t, b.Update(37)) - assert.True(t, b.Update(38)) + assert.True(t, b.Update(l, 30)) + assert.True(t, b.Update(l, 31)) + assert.True(t, b.Update(l, 32)) + assert.True(t, b.Update(l, 33)) + assert.True(t, b.Update(l, 34)) + assert.True(t, b.Update(l, 35)) + assert.True(t, b.Update(l, 36)) + assert.True(t, b.Update(l, 37)) + assert.True(t, b.Update(l, 38)) // 39 packets tracked, 22 seen, 17 lost assert.Equal(t, int64(17), b.lostCounter.Count()) // Jump ahead by 2 windows, should have recording 1 full window missing - assert.True(t, b.Update(58)) + assert.True(t, b.Update(l, 58)) assert.Equal(t, int64(27), b.lostCounter.Count()) // Now lets walk ahead normally through the window, the missed packets should fill in from this window - assert.True(t, b.Update(59)) - assert.True(t, b.Update(60)) - assert.True(t, b.Update(61)) - assert.True(t, b.Update(62)) - assert.True(t, b.Update(63)) - assert.True(t, b.Update(64)) - assert.True(t, b.Update(65)) - assert.True(t, b.Update(66)) - assert.True(t, b.Update(67)) + assert.True(t, b.Update(l, 59)) + assert.True(t, b.Update(l, 60)) + assert.True(t, b.Update(l, 61)) + assert.True(t, b.Update(l, 62)) + assert.True(t, b.Update(l, 63)) + assert.True(t, b.Update(l, 64)) + assert.True(t, b.Update(l, 65)) + assert.True(t, b.Update(l, 66)) + assert.True(t, b.Update(l, 67)) // 68 packets tracked, 32 seen, 36 missed assert.Equal(t, int64(36), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) diff --git a/cert.go b/cert.go index bc51175..0e2ce3c 100644 --- a/cert.go +++ b/cert.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" ) @@ -119,7 +120,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) { return NewCertState(nebulaCert, rawKey) } -func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) { +func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) { var rawCA []byte var err error diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 912f470..ea189d2 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -46,15 +46,16 @@ func main() { os.Exit(1) } - config := nebula.NewConfig() + l := logrus.New() + l.Out = os.Stdout + + config := nebula.NewConfig(l) err := config.Load(*configPath) if err != nil { fmt.Printf("failed to load config: %s", err) os.Exit(1) } - l := logrus.New() - l.Out = os.Stdout c, err := nebula.Main(config, *configTest, Build, l, nil) switch v := err.(type) { diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index 6e1dcd9..03336b6 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -24,14 +24,15 @@ func (p *program) Start(s service.Service) error { // Start should not block. logger.Info("Nebula service starting.") - config := nebula.NewConfig() + l := logrus.New() + l.Out = os.Stdout + + config := nebula.NewConfig(l) err := config.Load(*p.configPath) if err != nil { return fmt.Errorf("failed to load config: %s", err) } - l := logrus.New() - l.Out = os.Stdout p.control, err = nebula.Main(config, *p.configTest, Build, l, nil) if err != nil { return err diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index b28fa13..cffd75a 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -40,15 +40,16 @@ func main() { os.Exit(1) } - config := nebula.NewConfig() + l := logrus.New() + l.Out = os.Stdout + + config := nebula.NewConfig(l) err := config.Load(*configPath) if err != nil { fmt.Printf("failed to load config: %s", err) os.Exit(1) } - l := logrus.New() - l.Out = os.Stdout c, err := nebula.Main(config, *configTest, Build, l, nil) switch v := err.(type) { diff --git a/config.go b/config.go index a5493df..a11b89a 100644 --- a/config.go +++ b/config.go @@ -26,11 +26,13 @@ type Config struct { Settings map[interface{}]interface{} oldSettings map[interface{}]interface{} callbacks []func(*Config) + l *logrus.Logger } -func NewConfig() *Config { +func NewConfig(l *logrus.Logger) *Config { return &Config{ Settings: make(map[interface{}]interface{}), + l: l, } } @@ -99,12 +101,12 @@ func (c *Config) HasChanged(k string) bool { newVals, err := yaml.Marshal(nv) if err != nil { - l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config") + c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config") } oldVals, err := yaml.Marshal(ov) if err != nil { - l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") + c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") } return string(newVals) != string(oldVals) @@ -118,7 +120,7 @@ func (c *Config) CatchHUP() { go func() { for range ch { - l.Info("Caught HUP, reloading config") + c.l.Info("Caught HUP, reloading config") c.ReloadConfig() } }() @@ -132,7 +134,7 @@ func (c *Config) ReloadConfig() { err := c.Load(c.path) if err != nil { - l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") + c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") return } @@ -500,7 +502,7 @@ func configLogger(c *Config) error { if err != nil { return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels) } - l.SetLevel(logLevel) + c.l.SetLevel(logLevel) disableTimestamp := c.GetBool("logging.disable_timestamp", false) timestampFormat := c.GetString("logging.timestamp_format", "") @@ -512,13 +514,13 @@ func configLogger(c *Config) error { logFormat := strings.ToLower(c.GetString("logging.format", "text")) switch logFormat { case "text": - l.Formatter = &logrus.TextFormatter{ + c.l.Formatter = &logrus.TextFormatter{ TimestampFormat: timestampFormat, FullTimestamp: fullTimestamp, DisableTimestamp: disableTimestamp, } case "json": - l.Formatter = &logrus.JSONFormatter{ + c.l.Formatter = &logrus.JSONFormatter{ TimestampFormat: timestampFormat, DisableTimestamp: disableTimestamp, } diff --git a/config_test.go b/config_test.go index 359a2af..5a1aea4 100644 --- a/config_test.go +++ b/config_test.go @@ -11,14 +11,15 @@ import ( ) func TestConfig_Load(t *testing.T) { + l := NewTestLogger() dir, err := ioutil.TempDir("", "config-test") // invalid yaml - c := NewConfig() + c := NewConfig(l) ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") // simple multi config merge - c = NewConfig() + c = NewConfig(l) os.RemoveAll(dir) os.Mkdir(dir, 0755) @@ -40,8 +41,9 @@ func TestConfig_Load(t *testing.T) { } func TestConfig_Get(t *testing.T) { + l := NewTestLogger() // test simple type - c := NewConfig() + c := NewConfig(l) c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} assert.Equal(t, "hi", c.Get("firewall.outbound")) @@ -55,13 +57,15 @@ func TestConfig_Get(t *testing.T) { } func TestConfig_GetStringSlice(t *testing.T) { - c := NewConfig() + l := NewTestLogger() + c := NewConfig(l) c.Settings["slice"] = []interface{}{"one", "two"} assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) } func TestConfig_GetBool(t *testing.T) { - c := NewConfig() + l := NewTestLogger() + c := NewConfig(l) c.Settings["bool"] = true assert.Equal(t, true, c.GetBool("bool", false)) @@ -88,7 +92,8 @@ func TestConfig_GetBool(t *testing.T) { } func TestConfig_GetAllowList(t *testing.T) { - c := NewConfig() + l := NewTestLogger() + c := NewConfig(l) c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0": true, } @@ -181,20 +186,21 @@ func TestConfig_GetAllowList(t *testing.T) { } func TestConfig_HasChanged(t *testing.T) { + l := NewTestLogger() // No reload has occurred, return false - c := NewConfig() + c := NewConfig(l) c.Settings["test"] = "hi" assert.False(t, c.HasChanged("")) // Test key change - c = NewConfig() + c = NewConfig(l) c.Settings["test"] = "hi" c.oldSettings = map[interface{}]interface{}{"test": "no"} assert.True(t, c.HasChanged("test")) assert.True(t, c.HasChanged("")) // No key change - c = NewConfig() + c = NewConfig(l) c.Settings["test"] = "hi" c.oldSettings = map[interface{}]interface{}{"test": "hi"} assert.False(t, c.HasChanged("test")) @@ -202,12 +208,13 @@ func TestConfig_HasChanged(t *testing.T) { } func TestConfig_ReloadConfig(t *testing.T) { + l := NewTestLogger() done := make(chan bool, 1) dir, err := ioutil.TempDir("", "config-test") assert.Nil(t, err) ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) - c := NewConfig() + c := NewConfig(l) assert.Nil(t, c.Load(dir)) assert.False(t, c.HasChanged("outer.inner")) diff --git a/connection_manager.go b/connection_manager.go index bc2ce05..db58274 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -28,10 +28,11 @@ type connectionManager struct { checkInterval int pendingDeletionInterval int + l *logrus.Logger // I wanted to call one matLock } -func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager { +func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager { nc := &connectionManager{ hostMap: intf.hostMap, in: make(map[uint32]struct{}), @@ -47,6 +48,7 @@ func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterva pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60), checkInterval: checkInterval, pendingDeletionInterval: pendingDeletionInterval, + l: l, } nc.Start() return nc @@ -166,8 +168,8 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) // If we saw incoming packets from this ip, just return if traf { - if l.Level >= logrus.DebugLevel { - l.WithField("vpnIp", IntIp(vpnIP)). + if n.l.Level >= logrus.DebugLevel { + n.l.WithField("vpnIp", IntIp(vpnIP)). WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). Debug("Tunnel status") } @@ -179,13 +181,13 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) // If we didn't we may need to probe or destroy the conn hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) if err != nil { - l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) + n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) n.ClearIP(vpnIP) n.ClearPendingDeletion(vpnIP) continue } - hostinfo.logger(). + hostinfo.logger(n.l). WithField("tunnelCheck", m{"state": "testing", "method": "active"}). Debug("Tunnel status") @@ -194,7 +196,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out) } else { - hostinfo.logger().Debugf("Hostinfo sadness: %s", IntIp(vpnIP)) + hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP)) } n.AddPendingDeletion(vpnIP) } @@ -214,7 +216,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) { // If we saw incoming packets from this ip, just return traf := n.CheckIn(vpnIP) if traf { - l.WithField("vpnIp", IntIp(vpnIP)). + n.l.WithField("vpnIp", IntIp(vpnIP)). WithField("tunnelCheck", m{"state": "alive", "method": "active"}). Debug("Tunnel status") n.ClearIP(vpnIP) @@ -226,7 +228,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) { if err != nil { n.ClearIP(vpnIP) n.ClearPendingDeletion(vpnIP) - l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) + n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) continue } @@ -236,7 +238,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) { if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil { cn = hostinfo.ConnectionState.peerCert.Details.Name } - hostinfo.logger(). + hostinfo.logger(n.l). WithField("tunnelCheck", m{"state": "dead", "method": "active"}). WithField("certName", cn). Info("Tunnel status") diff --git a/connection_manager_test.go b/connection_manager_test.go index 15baae2..31489ed 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -13,6 +13,7 @@ import ( var vpnIP uint32 func Test_NewConnectionManagerTest(t *testing.T) { + l := NewTestLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") @@ -20,7 +21,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { preferredRanges := []*net.IPNet{localrange} // Very incomplete mock objects - hostMap := NewHostMap("test", vpncidr, preferredRanges) + hostMap := NewHostMap(l, "test", vpncidr, preferredRanges) cs := &CertState{ rawCertificate: []byte{}, privateKey: []byte{}, @@ -28,7 +29,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) + lh := NewLightHouse(l, false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) ifce := &Interface{ hostMap: hostMap, inside: &Tun{}, @@ -36,12 +37,13 @@ func Test_NewConnectionManagerTest(t *testing.T) { certState: cs, firewall: &Firewall{}, lightHouse: lh, - handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), + handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), + l: l, } now := time.Now() // Create manager - nc := newConnectionManager(ifce, 5, 10) + nc := newConnectionManager(l, ifce, 5, 10) p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) @@ -79,13 +81,14 @@ func Test_NewConnectionManagerTest(t *testing.T) { } func Test_NewConnectionManagerTest2(t *testing.T) { + l := NewTestLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") preferredRanges := []*net.IPNet{localrange} // Very incomplete mock objects - hostMap := NewHostMap("test", vpncidr, preferredRanges) + hostMap := NewHostMap(l, "test", vpncidr, preferredRanges) cs := &CertState{ rawCertificate: []byte{}, privateKey: []byte{}, @@ -93,7 +96,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) + lh := NewLightHouse(l, false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) ifce := &Interface{ hostMap: hostMap, inside: &Tun{}, @@ -101,12 +104,13 @@ func Test_NewConnectionManagerTest2(t *testing.T) { certState: cs, firewall: &Firewall{}, lightHouse: lh, - handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), + handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), + l: l, } now := time.Now() // Create manager - nc := newConnectionManager(ifce, 5, 10) + nc := newConnectionManager(l, ifce, 5, 10) p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) diff --git a/connection_state.go b/connection_state.go index 25cdc58..c28cc42 100644 --- a/connection_state.go +++ b/connection_state.go @@ -7,6 +7,7 @@ import ( "sync/atomic" "github.com/flynn/noise" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" ) @@ -26,7 +27,7 @@ type ConnectionState struct { ready bool } -func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { +func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256) if f.cipher == "chachapoly" { cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) @@ -37,7 +38,7 @@ func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePa b := NewBits(ReplayWindow) // Clear out bit 0, we never transmit it and we don't want it showing as packet loss - b.Update(0) + b.Update(l, 0) hs, err := noise.NewHandshakeState(noise.Config{ CipherSuite: cs, diff --git a/control_test.go b/control_test.go index 4f6beac..a411fc1 100644 --- a/control_test.go +++ b/control_test.go @@ -13,9 +13,10 @@ import ( ) func TestControl_GetHostInfoByVpnIP(t *testing.T) { + l := NewTestLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0)) + hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0)) remote1 := NewUDPAddr(int2ip(100), 4444) remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) ipNet := net.IPNet{ diff --git a/dns_server.go b/dns_server.go index 3fec6b6..a4e1f13 100644 --- a/dns_server.go +++ b/dns_server.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/miekg/dns" + "github.com/sirupsen/logrus" ) // This whole thing should be rewritten to use context @@ -63,7 +64,7 @@ func (d *dnsRecords) Add(host, data string) { d.Unlock() } -func parseQuery(m *dns.Msg, w dns.ResponseWriter) { +func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { for _, q := range m.Question { switch q.Qtype { case dns.TypeA: @@ -95,34 +96,38 @@ func parseQuery(m *dns.Msg, w dns.ResponseWriter) { } } -func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { +func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Compress = false switch r.Opcode { case dns.OpcodeQuery: - parseQuery(m, w) + parseQuery(l, m, w) } w.WriteMsg(m) } -func dnsMain(hostMap *HostMap, c *Config) { +func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) { dnsR = newDnsRecords(hostMap) // attach request handler func - dns.HandleFunc(".", handleDnsRequest) + dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { + handleDnsRequest(l, w, r) + }) - c.RegisterReloadCallback(reloadDns) - startDns(c) + c.RegisterReloadCallback(func(c *Config) { + reloadDns(l, c) + }) + startDns(l, c) } func getDnsServerAddr(c *Config) string { return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)) } -func startDns(c *Config) { +func startDns(l *logrus.Logger, c *Config) { dnsAddr = getDnsServerAddr(c) dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} l.Debugf("Starting DNS responder at %s\n", dnsAddr) @@ -133,7 +138,7 @@ func startDns(c *Config) { } } -func reloadDns(c *Config) { +func reloadDns(l *logrus.Logger, c *Config) { if dnsAddr == getDnsServerAddr(c) { l.Debug("No DNS server config change detected") return @@ -141,5 +146,5 @@ func reloadDns(c *Config) { l.Debug("Restarting DNS server") dnsServer.Shutdown() - go startDns(c) + go startDns(l, c) } diff --git a/firewall.go b/firewall.go index f09a701..81f3377 100644 --- a/firewall.go +++ b/firewall.go @@ -70,6 +70,7 @@ type Firewall struct { trackTCPRTT bool metricTCPRTT metrics.Histogram + l *logrus.Logger } type FirewallConntrack struct { @@ -156,7 +157,7 @@ func (fp FirewallPacket) MarshalJSON() ([]byte, error) { } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. -func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { +func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { //TODO: error on 0 duration var min, max time.Duration @@ -195,11 +196,13 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N DefaultTimeout: defaultTimeout, localIps: localIps, metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)), + l: l, } } -func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) { fw := NewFirewall( + l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), @@ -207,12 +210,12 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er //TODO: max_connections ) - err := AddFirewallRulesFromConfig(false, c, fw) + err := AddFirewallRulesFromConfig(l, false, c, fw) if err != nil { return nil, err } - err = AddFirewallRulesFromConfig(true, c, fw) + err = AddFirewallRulesFromConfig(l, true, c, fw) if err != nil { return nil, err } @@ -240,7 +243,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort if !incoming { direction = "outgoing" } - l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}). + f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}). Info("Firewall rule added") var ( @@ -276,7 +279,7 @@ func (f *Firewall) GetRuleHash() string { return hex.EncodeToString(sum[:]) } -func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterface) error { +func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error { var table string if inbound { table = "firewall.inbound" @@ -296,7 +299,7 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa for i, t := range rs { var groups []string - r, err := convertRule(t, table, i) + r, err := convertRule(l, t, table, i) if err != nil { return fmt.Errorf("%s rule #%v; %s", table, i, err) } @@ -459,8 +462,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H // We now know which firewall table to check against if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) { - if l.Level >= logrus.DebugLevel { - h.logger(). + if f.l.Level >= logrus.DebugLevel { + h.logger(f.l). WithField("fwPacket", fp). WithField("incoming", c.incoming). WithField("rulesVersion", f.rulesVersion). @@ -472,8 +475,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H return false } - if l.Level >= logrus.DebugLevel { - h.logger(). + if f.l.Level >= logrus.DebugLevel { + h.logger(f.l). WithField("fwPacket", fp). WithField("incoming", c.incoming). WithField("rulesVersion", f.rulesVersion). @@ -795,7 +798,7 @@ type rule struct { CASha string } -func convertRule(p interface{}, table string, i int) (rule, error) { +func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) { r := rule{} m, ok := p.(map[interface{}]interface{}) @@ -968,14 +971,14 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) { // Get checks if the cache ticker has moved to the next version before returning // the map. If it has moved, we reset the map. -func (c *ConntrackCacheTicker) Get() ConntrackCache { +func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache { if c == nil { return nil } if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV { c.cacheV = tick if ll := len(c.cache); ll > 0 { - if l.GetLevel() == logrus.DebugLevel { + if l.Level == logrus.DebugLevel { l.WithField("len", ll).Debug("resetting conntrack cache") } c.cache = make(ConntrackCache, ll) diff --git a/firewall_test.go b/firewall_test.go index 3995e8d..43902cd 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -15,8 +15,9 @@ import ( ) func TestNewFirewall(t *testing.T) { + l := NewTestLogger() c := &cert.NebulaCertificate{} - fw := NewFirewall(time.Second, time.Minute, time.Hour, c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) conntrack := fw.Conntrack assert.NotNil(t, conntrack) assert.NotNil(t, conntrack.Conns) @@ -31,35 +32,34 @@ func TestNewFirewall(t *testing.T) { assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(time.Second, time.Hour, time.Minute, c) + fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(time.Hour, time.Second, time.Minute, c) + fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(time.Hour, time.Minute, time.Second, c) + fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(time.Minute, time.Hour, time.Second, c) + fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(time.Minute, time.Second, time.Hour, c) + fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) } func TestFirewall_AddRule(t *testing.T) { + l := NewTestLogger() ob := &bytes.Buffer{} - out := l.Out l.SetOutput(ob) - defer l.SetOutput(out) c := &cert.NebulaCertificate{} - fw := NewFirewall(time.Second, time.Minute, time.Hour, c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) @@ -74,7 +74,7 @@ func TestFirewall_AddRule(t *testing.T) { assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right) assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value) - fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "")) assert.False(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1") @@ -83,7 +83,7 @@ func TestFirewall_AddRule(t *testing.T) { assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right) assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value) - fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", "")) assert.False(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) @@ -92,23 +92,23 @@ func TestFirewall_AddRule(t *testing.T) { assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right) assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value) - fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", "")) assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP))) - fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") - fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") // Set any and clear fields - fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", "")) assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1") @@ -125,26 +125,25 @@ func TestFirewall_AddRule(t *testing.T) { assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right) assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value) - fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any) - fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any) // Test error conditions - fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", "")) assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", "")) } func TestFirewall_Drop(t *testing.T) { + l := NewTestLogger() ob := &bytes.Buffer{} - out := l.Out l.SetOutput(ob) - defer l.SetOutput(out) p := FirewallPacket{ ip2int(net.IPv4(1, 2, 3, 4)), @@ -177,7 +176,7 @@ func TestFirewall_Drop(t *testing.T) { } h.CreateRemoteCIDR(&c) - fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) cp := cert.NewCAPool() @@ -196,27 +195,27 @@ func TestFirewall_Drop(t *testing.T) { p.RemoteIP = oldRemote // ensure signer doesn't get in the way of group checks - fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match - fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} - fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} - fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) @@ -317,10 +316,9 @@ func BenchmarkFirewallTable_match(b *testing.B) { } func TestFirewall_Drop2(t *testing.T) { + l := NewTestLogger() ob := &bytes.Buffer{} - out := l.Out l.SetOutput(ob) - defer l.SetOutput(out) p := FirewallPacket{ ip2int(net.IPv4(1, 2, 3, 4)), @@ -365,7 +363,7 @@ func TestFirewall_Drop2(t *testing.T) { } h1.CreateRemoteCIDR(&c1) - fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", "")) cp := cert.NewCAPool() @@ -377,10 +375,9 @@ func TestFirewall_Drop2(t *testing.T) { } func TestFirewall_Drop3(t *testing.T) { + l := NewTestLogger() ob := &bytes.Buffer{} - out := l.Out l.SetOutput(ob) - defer l.SetOutput(out) p := FirewallPacket{ ip2int(net.IPv4(1, 2, 3, 4)), @@ -448,7 +445,7 @@ func TestFirewall_Drop3(t *testing.T) { } h3.CreateRemoteCIDR(&c3) - fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha")) cp := cert.NewCAPool() @@ -464,10 +461,9 @@ func TestFirewall_Drop3(t *testing.T) { } func TestFirewall_DropConntrackReload(t *testing.T) { + l := NewTestLogger() ob := &bytes.Buffer{} - out := l.Out l.SetOutput(ob) - defer l.SetOutput(out) p := FirewallPacket{ ip2int(net.IPv4(1, 2, 3, 4)), @@ -500,7 +496,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { } h.CreateRemoteCIDR(&c) - fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) cp := cert.NewCAPool() @@ -513,7 +509,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) oldFw := fw - fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -522,7 +518,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) oldFw = fw - fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -647,124 +643,126 @@ func Test_parsePort(t *testing.T) { } func TestNewFirewallFromConfig(t *testing.T) { + l := NewTestLogger() // Test a bad rule definition c := &cert.NebulaCertificate{} - conf := NewConfig() + conf := NewConfig(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} - _, err := NewFirewallFromConfig(c, conf) + _, err := NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code - conf = NewConfig() + conf = NewConfig(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} - _, err = NewFirewallFromConfig(c, conf) + _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha - conf = NewConfig() + conf = NewConfig(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} - _, err = NewFirewallFromConfig(c, conf) + _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided") // Test code/port error - conf = NewConfig() + conf = NewConfig(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(c, conf) + _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(c, conf) + _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error - conf = NewConfig() + conf = NewConfig(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} - _, err = NewFirewallFromConfig(c, conf) + _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error - conf = NewConfig() + conf = NewConfig(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(c, conf) + _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") // Test both group and groups - conf = NewConfig() + conf = NewConfig(l) conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} - _, err = NewFirewallFromConfig(c, conf) + _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } func TestAddFirewallRulesFromConfig(t *testing.T) { + l := NewTestLogger() // Test adding tcp rule - conf := NewConfig() + conf := NewConfig(l) mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf)) + assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) // Test adding udp rule - conf = NewConfig() + conf = NewConfig(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf)) + assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) // Test adding icmp rule - conf = NewConfig() + conf = NewConfig(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf)) + assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) // Test adding any rule - conf = NewConfig() + conf = NewConfig(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) + assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) // Test adding rule with ca_sha - conf = NewConfig() + conf = NewConfig(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} - assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) + assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name - conf = NewConfig() + conf = NewConfig(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} - assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) + assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall) // Test single group - conf = NewConfig() + conf = NewConfig(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) + assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) // Test single groups - conf = NewConfig() + conf = NewConfig(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) + assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) // Test multiple AND groups - conf = NewConfig() + conf = NewConfig(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} - assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) + assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall) // Test Add error - conf = NewConfig() + conf = NewConfig(l) mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`") + assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") } func TestTCPRTTTracking(t *testing.T) { @@ -859,17 +857,16 @@ func TestTCPRTTTracking(t *testing.T) { } func TestFirewall_convertRule(t *testing.T) { + l := NewTestLogger() ob := &bytes.Buffer{} - out := l.Out l.SetOutput(ob) - defer l.SetOutput(out) // Ensure group array of 1 is converted and a warning is printed c := map[interface{}]interface{}{ "group": []interface{}{"group1"}, } - r, err := convertRule(c, "test", 1) + r, err := convertRule(l, c, "test", 1) assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") assert.Nil(t, err) assert.Equal(t, "group1", r.Group) @@ -880,7 +877,7 @@ func TestFirewall_convertRule(t *testing.T) { "group": []interface{}{"group1", "group2"}, } - r, err = convertRule(c, "test", 1) + r, err = convertRule(l, c, "test", 1) assert.Equal(t, "", ob.String()) assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided") @@ -890,7 +887,7 @@ func TestFirewall_convertRule(t *testing.T) { "group": "group1", } - r, err = convertRule(c, "test", 1) + r, err = convertRule(l, c, "test", 1) assert.Nil(t, err) assert.Equal(t, "group1", r.Group) } diff --git a/handshake.go b/handshake.go index aa4cd8f..a703ff8 100644 --- a/handshake.go +++ b/handshake.go @@ -7,7 +7,7 @@ const ( func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) { if !f.lightHouse.remoteAllowList.Allow(addr.IP) { - l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } diff --git a/handshake_ix.go b/handshake_ix.go index 5d2dd84..63070de 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -27,7 +27,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { err := f.handshakeManager.AddIndexHostInfo(hostinfo) if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). + f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") return } @@ -48,7 +48,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { hsBytes, err = proto.Marshal(hs) if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). + f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") return } @@ -58,14 +58,14 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { msg, _, _, err := ci.H.WriteMessage(header, hsBytes) if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). + f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return } // We are sending handshake packet 1, so we don't expect to receive // handshake packet 1 from the responder - ci.window.Update(1) + ci.window.Update(f.l, 1) hostinfo.HandshakePacket[0] = msg hostinfo.HandshakeReady = true @@ -74,13 +74,13 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { } func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { - ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0) + ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0) // Mark packet 1 as seen so it doesn't show up as missed - ci.window.Update(1) + ci.window.Update(f.l, 1) msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) if err != nil { - l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") return } @@ -91,14 +91,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) */ if err != nil || hs.Details == nil { - l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") return } remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) if err != nil { - l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). Info("Invalid certificate from host") return @@ -108,16 +108,16 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { fingerprint, _ := remoteCert.Sha256Sum() if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) { - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") return } - myIndex, err := generateIndex() + myIndex, err := generateIndex(f.l) if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") @@ -133,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { HandshakePacket: make(map[uint8][]byte, 0), } - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -145,7 +145,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { hsBytes, err := proto.Marshal(hs) if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") @@ -155,13 +155,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2) msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes) if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return } else if dKey == nil || eKey == nil { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") @@ -178,7 +178,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { // We are sending handshake packet 2, so we don't expect to receive // handshake packet 2 from the initiator. - ci.window.Update(2) + ci.window.Update(f.l, 2) ci.peerCert = remoteCert ci.dKey = NewNebulaCipherState(dKey) @@ -203,11 +203,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) err := f.outside.WriteTo(msg, addr) if err != nil { - l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithError(err).Error("Failed to send handshake message") } else { - l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") } @@ -215,7 +215,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { case ErrExistingHostInfo: // This means there was an existing tunnel and we didn't win // handshake avoidance - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -227,7 +227,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -238,7 +238,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here - l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -252,14 +252,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) err = f.outside.WriteTo(msg, addr) if err != nil { - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithError(err).Error("Failed to send handshake") } else { - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -267,7 +267,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { Info("Handshake message sent") } - hostinfo.handshakeComplete() + hostinfo.handshakeComplete(f.l) return } @@ -280,7 +280,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ defer hostinfo.Unlock() if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). Info("Already seen this handshake packet") return false @@ -288,14 +288,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ ci := hostinfo.ConnectionState // Mark packet 2 as seen so it doesn't show up as missed - ci.window.Update(2) + ci.window.Update(f.l, 2) hostinfo.HandshakePacket[2] = make([]byte, len(packet[HeaderLen:])) copy(hostinfo.HandshakePacket[2], packet[HeaderLen:]) msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). Error("Failed to call noise.ReadMessage") @@ -304,7 +304,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ // near future return false } else if dKey == nil || eKey == nil { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Noise did not arrive at a key") return true @@ -313,14 +313,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ hs := &NebulaHandshake{} err = proto.Unmarshal(msg, hs) if err != nil || hs.Details == nil { - l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") return true } remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Invalid certificate from host") return true @@ -330,7 +330,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ fingerprint, _ := remoteCert.Sha256Sum() duration := time.Since(hostinfo.handshakeStart).Nanoseconds() - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -362,7 +362,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ hostinfo.CreateRemoteCIDR(remoteCert) f.handshakeManager.Complete(hostinfo, f) - hostinfo.handshakeComplete() + hostinfo.handshakeComplete(f.l) f.metricHandshakes.Update(duration) return false diff --git a/handshake_manager.go b/handshake_manager.go index f82e603..099d002 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -53,11 +53,12 @@ type HandshakeManager struct { InboundHandshakeTimer *SystemTimerWheel messageMetrics *MessageMetrics + l *logrus.Logger } -func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager { +func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ - pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges), + pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges), mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, @@ -70,6 +71,7 @@ func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainH InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)), messageMetrics: config.messageMetrics, + l: l, } } @@ -78,7 +80,7 @@ func (c *HandshakeManager) Run(f EncWriter) { for { select { case vpnIP := <-c.trigger: - l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered") + c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered") c.handleOutbound(vpnIP, f, true) case now := <-clockSource: c.NextOutboundHandshakeTimerTick(now, f) @@ -149,7 +151,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1) err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote) if err != nil { - hostinfo.logger().WithField("udpAddr", hostinfo.remote). + hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote). WithField("initiatorIndex", hostinfo.localIndexId). WithField("remoteIndex", hostinfo.remoteIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). @@ -157,7 +159,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT } else { //TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should // keep the real packet struct around for logging purposes - hostinfo.logger().WithField("udpAddr", hostinfo.remote). + hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote). WithField("initiatorIndex", hostinfo.localIndexId). WithField("remoteIndex", hostinfo.remoteIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). @@ -245,7 +247,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(). + hostinfo.logger(c.l). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)). Info("New host shadows existing host remoteIndex") } @@ -280,7 +282,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { if found && existingRemoteIndex != nil { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(). + hostinfo.logger(c.l). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)). Info("New host shadows existing host remoteIndex") } @@ -298,7 +300,7 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error { defer c.mainHostMap.RUnlock() for i := 0; i < 32; i++ { - index, err := generateIndex() + index, err := generateIndex(c.l) if err != nil { return err } @@ -336,7 +338,7 @@ func (c *HandshakeManager) EmitStats() { // Utility functions below -func generateIndex() (uint32, error) { +func generateIndex(l *logrus.Logger) (uint32, error) { b := make([]byte, 4) // Let zero mean we don't know the ID, so don't generate zero diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 712028a..ee1c640 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -12,15 +12,15 @@ import ( var ips []uint32 func Test_NewHandshakeManagerIndex(t *testing.T) { - + l := NewTestLogger() _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))} preferredRanges := []*net.IPNet{localrange} - mainHM := NewHostMap("test", vpncidr, preferredRanges) + mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) + blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) now := time.Now() blah.NextInboundHandshakeTimerTick(now) @@ -63,15 +63,16 @@ func Test_NewHandshakeManagerIndex(t *testing.T) { } func Test_NewHandshakeManagerVpnIP(t *testing.T) { + l := NewTestLogger() _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))} preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} - mainHM := NewHostMap("test", vpncidr, preferredRanges) + mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) + blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) now := time.Now() blah.NextOutboundHandshakeTimerTick(now, mw) @@ -112,16 +113,17 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) { } func Test_NewHandshakeManagerTrigger(t *testing.T) { + l := NewTestLogger() _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") ip := ip2int(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} - mainHM := NewHostMap("test", vpncidr, preferredRanges) + mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) lh := &LightHouse{} - blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig) + blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig) now := time.Now() blah.NextOutboundHandshakeTimerTick(now, mw) @@ -162,15 +164,16 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) { } func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { + l := NewTestLogger() _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") vpnIP = ip2int(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} - mainHM := NewHostMap("test", vpncidr, preferredRanges) + mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) + blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) now := time.Now() blah.NextOutboundHandshakeTimerTick(now, mw) @@ -216,13 +219,14 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { } func Test_NewHandshakeManagerIndexcleanup(t *testing.T) { + l := NewTestLogger() _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") preferredRanges := []*net.IPNet{localrange} - mainHM := NewHostMap("test", vpncidr, preferredRanges) + mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) + blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) now := time.Now() blah.NextInboundHandshakeTimerTick(now) diff --git a/hostmap.go b/hostmap.go index 58d3c5d..b718ee5 100644 --- a/hostmap.go +++ b/hostmap.go @@ -33,6 +33,7 @@ type HostMap struct { defaultRoute uint32 unsafeRoutes *CIDRTree metricsEnabled bool + l *logrus.Logger } type HostInfo struct { @@ -83,7 +84,7 @@ type Probe struct { Counter int } -func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { +func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { h := map[uint32]*HostInfo{} i := map[uint32]*HostInfo{} r := map[uint32]*HostInfo{} @@ -96,6 +97,7 @@ func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) * vpnCIDR: vpnCIDR, defaultRoute: 0, unsafeRoutes: NewCIDRTree(), + l: l, } return &m } @@ -160,8 +162,8 @@ func (hm *HostMap) DeleteVpnIP(vpnIP uint32) { } hm.Unlock() - if l.Level >= logrus.DebugLevel { - l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}). + if hm.l.Level >= logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}). Debug("Hostmap vpnIp deleted") } } @@ -173,8 +175,8 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) { hm.RemoteIndexes[index] = h hm.Unlock() - if l.Level > logrus.DebugLevel { - l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), + if hm.l.Level > logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). Debug("Hostmap remoteIndex added") } @@ -188,8 +190,8 @@ func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) { hm.RemoteIndexes[h.remoteIndexId] = h hm.Unlock() - if l.Level > logrus.DebugLevel { - l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts), + if hm.l.Level > logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts), "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). Debug("Hostmap vpnIp added") } @@ -212,8 +214,8 @@ func (hm *HostMap) DeleteIndex(index uint32) { } hm.Unlock() - if l.Level >= logrus.DebugLevel { - l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}). + if hm.l.Level >= logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}). Debug("Hostmap index deleted") } } @@ -236,8 +238,8 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) { } hm.Unlock() - if l.Level >= logrus.DebugLevel { - l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}). + if hm.l.Level >= logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}). Debug("Hostmap remote index deleted") } } @@ -269,8 +271,8 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) { } hm.Unlock() - if l.Level >= logrus.DebugLevel { - l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), + if hm.l.Level >= logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), "vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Hostmap hostInfo deleted") } @@ -313,8 +315,10 @@ func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo { } i.remote = i.Remotes[0].addr hm.Hosts[vpnIp] = i - l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}). - Debug("Hostmap remote ip added") + if hm.l.Level >= logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}). + Debug("Hostmap remote ip added") + } } i.ForcePromoteBest(hm.preferredRanges) hm.Unlock() @@ -377,8 +381,8 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) { hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo - if l.Level >= logrus.DebugLevel { - l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts), + if hm.l.Level >= logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts), "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}). Debug("Hostmap vpnIp added") } @@ -436,7 +440,7 @@ func (hm *HostMap) Punchy(conn *udpConn) { func (hm *HostMap) addUnsafeRoutes(routes *[]route) { for _, r := range *routes { - l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route") + hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route") hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via)) } } @@ -566,7 +570,7 @@ func (i *HostInfo) rotateRemote() { i.remote = i.Remotes[0].addr } -func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) { +func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) { //TODO: return the error so we can log with more context if len(i.packetStore) < 100 { tempPacket := make([]byte, len(packet)) @@ -574,14 +578,14 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac //l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket) i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket}) if l.Level >= logrus.DebugLevel { - i.logger(). + i.logger(l). WithField("length", len(i.packetStore)). WithField("stored", true). Debugf("Packet store") } } else if l.Level >= logrus.DebugLevel { - i.logger(). + i.logger(l). WithField("length", len(i.packetStore)). WithField("stored", false). Debugf("Packet store") @@ -589,7 +593,7 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac } // handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets -func (i *HostInfo) handshakeComplete() { +func (i *HostInfo) handshakeComplete(l *logrus.Logger) { //TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because: //TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send //TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical @@ -601,7 +605,7 @@ func (i *HostInfo) handshakeComplete() { atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2) if l.Level >= logrus.DebugLevel { - i.logger().Debugf("Sending %d stored packets", len(i.packetStore)) + i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore)) } if len(i.packetStore) > 0 { @@ -689,7 +693,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { i.remoteCidr = remoteCidr } -func (i *HostInfo) logger() *logrus.Entry { +func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { if i == nil { return logrus.NewEntry(l) } @@ -804,7 +808,7 @@ func (d *HostInfoDest) ProbeReceived(probeCount int) { // Utility functions -func localIps(allowList *AllowList) *[]net.IP { +func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP { //FIXME: This function is pretty garbage var ips []net.IP ifaces, _ := net.Interfaces() diff --git a/hostmap_test.go b/hostmap_test.go index bbc47b5..f158b9e 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -64,12 +64,13 @@ func TestHostInfoDestProbe(t *testing.T) { */ func TestHostmap(t *testing.T) { + l := NewTestLogger() _, myNet, _ := net.ParseCIDR("10.128.0.0/16") _, localToMe, _ := net.ParseCIDR("192.168.1.0/24") myNets := []*net.IPNet{myNet} preferredRanges := []*net.IPNet{localToMe} - m := NewHostMap("test", myNet, preferredRanges) + m := NewHostMap(l, "test", myNet, preferredRanges) a := NewUDPAddrFromString("10.127.0.3:11111") b := NewUDPAddrFromString("1.0.0.1:22222") @@ -103,10 +104,11 @@ func TestHostmap(t *testing.T) { } func TestHostmapdebug(t *testing.T) { + l := NewTestLogger() _, myNet, _ := net.ParseCIDR("10.128.0.0/16") _, localToMe, _ := net.ParseCIDR("192.168.1.0/24") preferredRanges := []*net.IPNet{localToMe} - m := NewHostMap("test", myNet, preferredRanges) + m := NewHostMap(l, "test", myNet, preferredRanges) a := NewUDPAddrFromString("10.127.0.3:11111") b := NewUDPAddrFromString("1.0.0.1:22222") @@ -151,11 +153,12 @@ func TestHostMap_rotateRemote(t *testing.T) { } func BenchmarkHostmappromote2(b *testing.B) { + l := NewTestLogger() for n := 0; n < b.N; n++ { _, myNet, _ := net.ParseCIDR("10.128.0.0/16") _, localToMe, _ := net.ParseCIDR("192.168.1.0/24") preferredRanges := []*net.IPNet{localToMe} - m := NewHostMap("test", myNet, preferredRanges) + m := NewHostMap(l, "test", myNet, preferredRanges) y := NewUDPAddrFromString("10.128.0.3:11111") a := NewUDPAddrFromString("10.127.0.3:11111") g := NewUDPAddrFromString("1.0.0.1:22222") diff --git a/inside.go b/inside.go index 5e898f2..c682f19 100644 --- a/inside.go +++ b/inside.go @@ -10,7 +10,7 @@ import ( func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { - l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) + f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) return } @@ -31,8 +31,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, hostinfo := f.getOrHandshake(fwPacket.RemoteIP) if hostinfo == nil { - if l.Level >= logrus.DebugLevel { - l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)). + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)). WithField("fwPacket", fwPacket). Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes") } @@ -45,7 +45,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, // the packet queue. ci.queueLock.Lock() if !ci.ready { - hostinfo.cachePacket(message, 0, packet, f.sendMessageNow) + hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow) ci.queueLock.Unlock() return } @@ -59,8 +59,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, f.lightHouse.Query(fwPacket.RemoteIP, f) } - } else if l.Level >= logrus.DebugLevel { - hostinfo.logger(). + } else if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l). WithField("fwPacket", fwPacket). WithField("reason", dropReason). Debugln("dropping outbound packet") @@ -104,7 +104,7 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo { if ci == nil { // if we don't have a connection state, then send a handshake initiation - ci = f.newConnectionState(true, noise.HandshakeIX, []byte{}, 0) + ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0) // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. //ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0) hostinfo.ConnectionState = ci @@ -135,15 +135,15 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, fp := &FirewallPacket{} err := newPacket(p, false, fp) if err != nil { - l.Warnf("error while parsing outgoing packet for firewall check; %v", err) + f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err) return } // check if packet is in outbound fw rules dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil) if dropReason != nil { - if l.Level >= logrus.DebugLevel { - l.WithField("fwPacket", fp). + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("fwPacket", fp). WithField("reason", dropReason). Debugln("dropping cached packet") } @@ -160,8 +160,8 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { hostInfo := f.getOrHandshake(vpnIp) if hostInfo == nil { - if l.Level >= logrus.DebugLevel { - l.WithField("vpnIp", IntIp(vpnIp)). + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnIp", IntIp(vpnIp)). Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes") } return @@ -172,7 +172,7 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT // the packet queue. hostInfo.ConnectionState.queueLock.Lock() if !hostInfo.ConnectionState.ready { - hostInfo.cachePacket(t, st, p, f.sendMessageToVpnIp) + hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToVpnIp) hostInfo.ConnectionState.queueLock.Unlock() return } @@ -191,8 +191,8 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { hostInfo := f.getOrHandshake(vpnIp) if hostInfo == nil { - if l.Level >= logrus.DebugLevel { - l.WithField("vpnIp", IntIp(vpnIp)). + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnIp", IntIp(vpnIp)). Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes") } return @@ -203,7 +203,7 @@ func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubTyp // the packet queue. hostInfo.ConnectionState.queueLock.Lock() if !hostInfo.ConnectionState.ready { - hostInfo.cachePacket(t, st, p, f.sendMessageToAll) + hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToAll) hostInfo.ConnectionState.queueLock.Unlock() return } @@ -247,8 +247,8 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. f.lightHouse.Query(hostinfo.hostId, f) hostinfo.lastRebindCount = f.rebindCount - if l.Level >= logrus.DebugLevel { - l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter") + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter") } } @@ -256,7 +256,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, //TODO: see above note on lock //ci.writeLock.Unlock() if err != nil { - hostinfo.logger().WithError(err). + hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).WithField("counter", c). WithField("attemptedCounter", c). Error("Failed to encrypt outgoing packet") @@ -265,7 +265,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, err = f.writers[q].WriteTo(out, remote) if err != nil { - hostinfo.logger().WithError(err). + hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).Error("Failed to write outgoing packet") } return c diff --git a/interface.go b/interface.go index e90fef0..8bd95ae 100644 --- a/interface.go +++ b/interface.go @@ -9,6 +9,7 @@ import ( "time" "github.com/rcrowley/go-metrics" + "github.com/sirupsen/logrus" ) const mtu = 9001 @@ -42,6 +43,7 @@ type InterfaceConfig struct { version string ConntrackCacheTimeout time.Duration + l *logrus.Logger } type Interface struct { @@ -73,6 +75,7 @@ type Interface struct { metricHandshakes metrics.Histogram messageMetrics *MessageMetrics + l *logrus.Logger } func NewInterface(c *InterfaceConfig) (*Interface, error) { @@ -113,9 +116,10 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) { metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), messageMetrics: c.MessageMetrics, + l: c.l, } - ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval) + ifce.connectionManager = newConnectionManager(c.l, ifce, c.checkInterval, c.pendingDeletionInterval) return ifce, nil } @@ -125,10 +129,10 @@ func (f *Interface) run() { addr, err := f.outside.LocalAddr() if err != nil { - l.WithError(err).Error("Failed to get udp listen address") + f.l.WithError(err).Error("Failed to get udp listen address") } - l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()). + f.l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()). WithField("build", f.version).WithField("udpAddr", addr). Info("Nebula interface is active") @@ -140,14 +144,14 @@ func (f *Interface) run() { if i > 0 { reader, err = f.inside.NewMultiQueueReader() if err != nil { - l.Fatal(err) + f.l.Fatal(err) } } f.readers[i] = reader } if err := f.inside.Activate(); err != nil { - l.Fatal(err) + f.l.Fatal(err) } // Launch n queues to read packets from udp @@ -187,12 +191,12 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { for { n, err := reader.Read(packet) if err != nil { - l.WithError(err).Error("Error while reading outbound packet") + f.l.WithError(err).Error("Error while reading outbound packet") // This only seems to happen when something fatal happens to the fd, so exit. os.Exit(2) } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get()) + f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) } } @@ -208,21 +212,21 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *Config) { func (f *Interface) reloadCA(c *Config) { // reload and check regardless // todo: need mutex? - newCAs, err := loadCAFromConfig(c) + newCAs, err := loadCAFromConfig(f.l, c) if err != nil { - l.WithError(err).Error("Could not refresh trusted CA certificates") + f.l.WithError(err).Error("Could not refresh trusted CA certificates") return } trustedCAs = newCAs - l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed") + f.l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed") } func (f *Interface) reloadCertKey(c *Config) { // reload and check in all cases cs, err := NewCertStateFromConfig(c) if err != nil { - l.WithError(err).Error("Could not refresh client cert") + f.l.WithError(err).Error("Could not refresh client cert") return } @@ -230,24 +234,24 @@ func (f *Interface) reloadCertKey(c *Config) { oldIPs := f.certState.certificate.Details.Ips newIPs := cs.certificate.Details.Ips if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { - l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old") + f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old") return } f.certState = cs - l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk") + f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk") } func (f *Interface) reloadFirewall(c *Config) { //TODO: need to trigger/detect if the certificate changed too if c.HasChanged("firewall") == false { - l.Debug("No firewall config change detected") + f.l.Debug("No firewall config change detected") return } - fw, err := NewFirewallFromConfig(f.certState.certificate, c) + fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c) if err != nil { - l.WithError(err).Error("Error while creating firewall during reload") + f.l.WithError(err).Error("Error while creating firewall during reload") return } @@ -260,7 +264,7 @@ func (f *Interface) reloadFirewall(c *Config) { // If rulesVersion is back to zero, we have wrapped all the way around. Be // safe and just reset conntrack in this case. if fw.rulesVersion == 0 { - l.WithField("firewallHash", fw.GetRuleHash()). + f.l.WithField("firewallHash", fw.GetRuleHash()). WithField("oldFirewallHash", oldFw.GetRuleHash()). WithField("rulesVersion", fw.rulesVersion). Warn("firewall rulesVersion has overflowed, resetting conntrack") @@ -271,7 +275,7 @@ func (f *Interface) reloadFirewall(c *Config) { f.firewall = fw oldFw.Destroy() - l.WithField("firewallHash", fw.GetRuleHash()). + f.l.WithField("firewallHash", fw.GetRuleHash()). WithField("oldFirewallHash", oldFw.GetRuleHash()). WithField("rulesVersion", fw.rulesVersion). Info("New firewall has been installed") diff --git a/lighthouse.go b/lighthouse.go index 861fe77..84f1dd2 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -48,6 +48,7 @@ type LightHouse struct { metrics *MessageMetrics metricHolepunchTx metrics.Counter + l *logrus.Logger } type EncWriter interface { @@ -55,7 +56,7 @@ type EncWriter interface { SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) } -func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse { +func NewLightHouse(l *logrus.Logger, amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse { h := LightHouse{ amLighthouse: amLighthouse, myIp: myIp, @@ -67,6 +68,7 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n punchConn: pc, punchBack: punchBack, punchDelay: punchDelay, + l: l, } if metricsEnabled { @@ -126,7 +128,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) { // Send a query to the lighthouses and hope for the best next time query, err := proto.Marshal(NewLhQueryByInt(ip)) if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload") + lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload") return } @@ -159,7 +161,7 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) { lh.Lock() //l.Debugln(lh.addrMap) delete(lh.addrMap, vpnIP) - l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP)) + lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP)) lh.Unlock() } @@ -181,7 +183,7 @@ func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) { } allow := lh.remoteAllowList.Allow(toIp.IP) - l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow") + lh.l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow") if !allow { return } @@ -270,7 +272,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) { var v4 []*IpAndPort var v6 []*Ip6AndPort - for _, e := range *localIps(lh.localAllowList) { + for _, e := range *localIps(lh.l, lh.localAllowList) { // Only add IPs that aren't my VPN/tun IP if ip2int(e) != lh.myIp { ipp := NewIpAndPort(e, lh.nebulaPort) @@ -297,7 +299,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) { for vpnIp := range lh.lighthouses { mm, err := proto.Marshal(m) if err != nil { - l.Debugf("Invalid marshal to update") + lh.l.Debugf("Invalid marshal to update") } //l.Error("LIGHTHOUSE PACKET SEND", mm) f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out) @@ -368,14 +370,14 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by n := lhh.resetMeta() err := proto.UnmarshalMerge(p, n) if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). + lh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") //TODO: send recv_error? return } if n.Details == nil { - l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). + lh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). Error("Invalid lighthouse update") //TODO: send recv_error? return @@ -387,7 +389,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by case NebulaMeta_HostQuery: // Exit if we don't answer queries if !lh.amLighthouse { - l.Debugln("I don't answer queries, but received from: ", rAddr) + lh.l.Debugln("I don't answer queries, but received from: ", rAddr) return } @@ -422,7 +424,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by reply, err := proto.Marshal(n) if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply") + lh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply") return } lh.metricTx(NebulaMeta_HostQueryReply, 1) @@ -431,7 +433,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by // This signals the other side to punch some zero byte udp packets ips, err = lh.Query(vpnIp, f) if err != nil { - l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch") + lh.l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch") return } else { //l.Debugln("Notify host to punch", iap) @@ -492,7 +494,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by case NebulaMeta_HostUpdateNotification: //Simple check that the host sent this not someone else if n.Details.VpnIp != vpnIp { - l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update") + lh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update") return } @@ -530,9 +532,9 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by }() - if l.Level >= logrus.DebugLevel { + if lh.l.Level >= logrus.DebugLevel { //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) - l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp)) + lh.l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp)) } } @@ -549,9 +551,9 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by }() - if l.Level >= logrus.DebugLevel { + if lh.l.Level >= logrus.DebugLevel { //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) - l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp)) + lh.l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp)) } } @@ -561,7 +563,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by if lh.punchBack { go func() { time.Sleep(time.Second * 5) - l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp)) + lh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp)) // TODO we have to allocate a new output buffer here since we are spawning a new goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine // managed by a channel. diff --git a/lighthouse_test.go b/lighthouse_test.go index edd4d7a..0e96c2a 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -65,12 +65,13 @@ func TestSetipandportsfromudpaddrs(t *testing.T) { } func Test_lhStaticMapping(t *testing.T) { + l := NewTestLogger() lh1 := "10.128.0.2" lh1IP := net.ParseIP(lh1) - udpServer, _ := NewListener("0.0.0.0", 0, true) + udpServer, _ := NewListener(l, "0.0.0.0", 0, true) - meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) + meh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true) err := meh.ValidateLHStaticEntries() assert.Nil(t, err) @@ -78,19 +79,20 @@ func Test_lhStaticMapping(t *testing.T) { lh2 := "10.128.0.3" lh2IP := net.ParseIP(lh2) - meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false) + meh = NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false) meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true) err = meh.ValidateLHStaticEntries() assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry") } func BenchmarkLighthouseHandleRequest(b *testing.B) { + l := NewTestLogger() lh1 := "10.128.0.2" lh1IP := net.ParseIP(lh1) - udpServer, _ := NewListener("0.0.0.0", 0, true) + udpServer, _ := NewListener(l, "0.0.0.0", 0, true) - lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) + lh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) hAddr := NewUDPAddrFromString("4.5.6.7:12345") hAddr2 := NewUDPAddrFromString("4.5.6.7:12346") @@ -136,7 +138,8 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { } func Test_lhRemoteAllowList(t *testing.T) { - c := NewConfig() + l := NewTestLogger() + c := NewConfig(l) c.Settings["remoteallowlist"] = map[interface{}]interface{}{ "10.20.0.0/12": false, } @@ -146,9 +149,9 @@ func Test_lhRemoteAllowList(t *testing.T) { lh1 := "10.128.0.2" lh1IP := net.ParseIP(lh1) - udpServer, _ := NewListener("0.0.0.0", 0, true) + udpServer, _ := NewListener(l, "0.0.0.0", 0, true) - lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) + lh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) lh.SetRemoteAllowList(allowList) remote1 := "10.20.0.3" diff --git a/main.go b/main.go index e4c937a..5396a39 100644 --- a/main.go +++ b/main.go @@ -11,13 +11,10 @@ import ( "gopkg.in/yaml.v2" ) -// The caller should provide a real logger, we have one just in case -var l = logrus.New() - type m map[string]interface{} func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) { - l = logger + l := logger l.Formatter = &logrus.TextFormatter{ FullTimestamp: true, } @@ -46,7 +43,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L }) // trustedCAs is currently a global, so loadCA operates on that global directly - trustedCAs, err = loadCAFromConfig(config) + trustedCAs, err = loadCAFromConfig(l, config) if err != nil { //The errors coming out of loadCA are already nicely formatted return nil, NewContextualError("Failed to load ca from config", nil, err) @@ -60,7 +57,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } l.WithField("cert", cs.certificate).Debug("Client nebula certificate") - fw, err := NewFirewallFromConfig(cs.certificate, config) + fw, err := NewFirewallFromConfig(l, cs.certificate, config) if err != nil { return nil, NewContextualError("Error while loading firewall rules", nil, err) } @@ -78,9 +75,9 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) - wireSSHReload(ssh, config) + wireSSHReload(l, ssh, config) if config.GetBool("sshd.enabled", false) { - err = configSSH(ssh, config) + err = configSSH(l, ssh, config) if err != nil { return nil, NewContextualError("Error while configuring the sshd", nil, err) } @@ -136,6 +133,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l) case tunFd != nil: tun, err = newTunFromFd( + l, *tunFd, tunCidr, config.GetInt("tun.mtu", DEFAULT_MTU), @@ -145,6 +143,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L ) default: tun, err = newTun( + l, config.GetString("tun.dev", ""), tunCidr, config.GetInt("tun.mtu", DEFAULT_MTU), @@ -166,7 +165,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L if !configTest { for i := 0; i < routines; i++ { - udpServer, err := NewListener(config.GetString("listen.host", "0.0.0.0"), port, routines > 1) + udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1) if err != nil { return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err) } @@ -222,7 +221,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } } - hostMap := NewHostMap("main", tunCidr, preferredRanges) + hostMap := NewHostMap(l, "main", tunCidr, preferredRanges) hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0")))) hostMap.addUnsafeRoutes(&unsafeRoutes) hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false) @@ -266,6 +265,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } lightHouse := NewLightHouse( + l, amLighthouse, ip2int(tunCidr.IP), lighthouseHosts, @@ -337,7 +337,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L messageMetrics: messageMetrics, } - handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig) + handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger //TODO: These will be reused for psk @@ -367,6 +367,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L version: buildVersion, ConntrackCacheTimeout: conntrackCacheTimeout, + l: l, } switch ifConfig.Cipher { @@ -395,7 +396,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L go lightHouse.LhUpdateWorker(ifce) } - err = startStats(config, configTest) + err = startStats(l, config, configTest) if err != nil { return nil, NewContextualError("Failed to start stats emitter", nil, err) } @@ -407,12 +408,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L //TODO: check if we _should_ be emitting stats go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10)) - attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) + attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) // Start DNS server last to allow using the nebula IP as lighthouse.dns.host if amLighthouse && serveDns { l.Debugln("Starting dns server") - go dnsMain(hostMap, config) + go dnsMain(l, hostMap, config) } return &Control{ifce, l}, nil diff --git a/main_test.go b/main_test.go index 2808317..f638011 100644 --- a/main_test.go +++ b/main_test.go @@ -1 +1,30 @@ package nebula + +import ( + "io/ioutil" + "os" + + "github.com/sirupsen/logrus" +) + +func NewTestLogger() *logrus.Logger { + l := logrus.New() + + v := os.Getenv("TEST_LOGS") + if v == "" { + l.SetOutput(ioutil.Discard) + return l + } + + switch v { + case "1": + // This is the default level but we are being explicit + l.SetLevel(logrus.InfoLevel) + case "2": + l.SetLevel(logrus.DebugLevel) + case "3": + l.SetLevel(logrus.TraceLevel) + } + + return l +} diff --git a/outside.go b/outside.go index c199157..9acd2e1 100644 --- a/outside.go +++ b/outside.go @@ -24,7 +24,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { - l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err) + f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err) } return } @@ -57,7 +57,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb) if err != nil { - hostinfo.logger().WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). WithField("packet", packet). Error("Failed to decrypt lighthouse packet") @@ -78,7 +78,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb) if err != nil { - hostinfo.logger().WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). WithField("packet", packet). Error("Failed to decrypt test packet") @@ -115,7 +115,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, return } - hostinfo.logger().WithField("udpAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", addr). Info("Close tunnel received, tearing down.") f.closeTunnel(hostinfo) @@ -123,7 +123,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, default: f.messageMetrics.Rx(header.Type, header.Subtype, 1) - hostinfo.logger().Debugf("Unexpected packet received from %s", addr) + hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr) return } @@ -143,18 +143,18 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { if hostDidRoam(hostinfo.remote, addr) { if !f.lightHouse.remoteAllowList.Allow(addr.IP) { - hostinfo.logger().WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") + hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") return } if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { - if l.Level >= logrus.DebugLevel { - hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() remoteCopy := *hostinfo.remote @@ -170,7 +170,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool { // If connectionstate exists and the replay protector allows, process packet // Else, send recv errors for 300 seconds after a restart to allow fast reconnection. - if ci == nil || !ci.window.Check(header.MessageCounter) { + if ci == nil || !ci.window.Check(f.l, header.MessageCounter) { f.sendRecvError(addr, header.RemoteIndex) return false } @@ -247,8 +247,8 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return nil, err } - if !hostinfo.ConnectionState.window.Update(mc) { - hostinfo.logger().WithField("header", header). + if !hostinfo.ConnectionState.window.Update(f.l, mc) { + hostinfo.logger(f.l).WithField("header", header). Debugln("dropping out of window packet") return nil, errors.New("out of window packet") } @@ -261,7 +261,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb) if err != nil { - hostinfo.logger().WithError(err).Error("Failed to decrypt packet") + hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") //TODO: maybe after build 64 is out? 06/14/2018 - NB //f.sendRecvError(hostinfo.remote, header.RemoteIndex) return @@ -269,21 +269,21 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out err = newPacket(out, true, fwPacket) if err != nil { - hostinfo.logger().WithError(err).WithField("packet", out). + hostinfo.logger(f.l).WithError(err).WithField("packet", out). Warnf("Error while validating inbound packet") return } - if !hostinfo.ConnectionState.window.Update(messageCounter) { - hostinfo.logger().WithField("fwPacket", fwPacket). + if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { + hostinfo.logger(f.l).WithField("fwPacket", fwPacket). Debugln("dropping out of window packet") return } dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache) if dropReason != nil { - if l.Level >= logrus.DebugLevel { - hostinfo.logger().WithField("fwPacket", fwPacket). + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l).WithField("fwPacket", fwPacket). WithField("reason", dropReason). Debugln("dropping inbound packet") } @@ -293,7 +293,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out f.connectionManager.In(hostinfo.hostId) _, err = f.readers[q].Write(out) if err != nil { - l.WithError(err).Error("Failed to write to tun") + f.l.WithError(err).Error("Failed to write to tun") } } @@ -303,16 +303,16 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) { //TODO: this should be a signed message so we can trust that we should drop the index b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0) f.outside.WriteTo(b, endpoint) - if l.Level >= logrus.DebugLevel { - l.WithField("index", index). + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("index", index). WithField("udpAddr", endpoint). Debug("Recv error sent") } } func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { - if l.Level >= logrus.DebugLevel { - l.WithField("index", h.RemoteIndex). + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("index", h.RemoteIndex). WithField("udpAddr", addr). Debug("Recv error received") } @@ -322,7 +322,7 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex) if err != nil { - l.Debugln(err, ": ", h.RemoteIndex) + f.l.Debugln(err, ": ", h.RemoteIndex) return } @@ -333,7 +333,7 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { return } if hostinfo.remote != nil && hostinfo.remote.String() != addr.String() { - l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) + f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) return } diff --git a/punchy_test.go b/punchy_test.go index 145dbe0..2ab570f 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -8,7 +8,8 @@ import ( ) func TestNewPunchyFromConfig(t *testing.T) { - c := NewConfig() + l := NewTestLogger() + c := NewConfig(l) // Test defaults p := NewPunchyFromConfig(c) diff --git a/ssh.go b/ssh.go index 25badf3..a9b0729 100644 --- a/ssh.go +++ b/ssh.go @@ -44,10 +44,10 @@ type sshCreateTunnelFlags struct { Address string } -func wireSSHReload(ssh *sshd.SSHServer, c *Config) { +func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) { c.RegisterReloadCallback(func(c *Config) { if c.GetBool("sshd.enabled", false) { - err := configSSH(ssh, c) + err := configSSH(l, ssh, c) if err != nil { l.WithError(err).Error("Failed to reconfigure the sshd") ssh.Stop() @@ -58,7 +58,7 @@ func wireSSHReload(ssh *sshd.SSHServer, c *Config) { }) } -func configSSH(ssh *sshd.SSHServer, c *Config) error { +func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error { //TODO conntrack list //TODO print firewall rules or hash? @@ -149,7 +149,7 @@ func configSSH(ssh *sshd.SSHServer, c *Config) error { return nil } -func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) { +func attachCommands(l *logrus.Logger, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) { ssh.RegisterCommand(&sshd.Command{ Name: "list-hostmap", ShortDescription: "List all known previously connected hosts", @@ -225,13 +225,17 @@ func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostM ssh.RegisterCommand(&sshd.Command{ Name: "log-level", ShortDescription: "Gets or sets the current log level", - Callback: sshLogLevel, + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshLogLevel(l, fs, a, w) + }, }) ssh.RegisterCommand(&sshd.Command{ Name: "log-format", ShortDescription: "Gets or sets the current log format", - Callback: sshLogFormat, + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshLogFormat(l, fs, a, w) + }, }) ssh.RegisterCommand(&sshd.Command{ @@ -629,7 +633,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error { return err } -func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error { +func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) } @@ -643,7 +647,7 @@ func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error { return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) } -func sshLogFormat(fs interface{}, a []string, w sshd.StringWriter) error { +func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) } diff --git a/stats.go b/stats.go index fec1189..e09966e 100644 --- a/stats.go +++ b/stats.go @@ -13,9 +13,10 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rcrowley/go-metrics" + "github.com/sirupsen/logrus" ) -func startStats(c *Config, configTest bool) error { +func startStats(l *logrus.Logger, c *Config, configTest bool) error { mType := c.GetString("stats.type", "") if mType == "" || mType == "none" { return nil @@ -28,9 +29,9 @@ func startStats(c *Config, configTest bool) error { switch mType { case "graphite": - startGraphiteStats(interval, c, configTest) + startGraphiteStats(l, interval, c, configTest) case "prometheus": - startPrometheusStats(interval, c, configTest) + startPrometheusStats(l, interval, c, configTest) default: return fmt.Errorf("stats.type was not understood: %s", mType) } @@ -44,7 +45,7 @@ func startStats(c *Config, configTest bool) error { return nil } -func startGraphiteStats(i time.Duration, c *Config, configTest bool) error { +func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error { proto := c.GetString("stats.protocol", "tcp") host := c.GetString("stats.host", "") if host == "" { @@ -64,7 +65,7 @@ func startGraphiteStats(i time.Duration, c *Config, configTest bool) error { return nil } -func startPrometheusStats(i time.Duration, c *Config, configTest bool) error { +func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error { namespace := c.GetString("stats.namespace", "") subsystem := c.GetString("stats.subsystem", "") diff --git a/tun_android.go b/tun_android.go index 8f6463f..b71a627 100644 --- a/tun_android.go +++ b/tun_android.go @@ -6,6 +6,7 @@ import ( "net" "os" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) @@ -19,9 +20,10 @@ type Tun struct { TXQueueLen int Routes []route UnsafeRoutes []route + l *logrus.Logger } -func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") ifce = &Tun{ @@ -33,6 +35,7 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, TXQueueLen: txQueueLen, Routes: routes, UnsafeRoutes: unsafeRoutes, + l: l, } return } diff --git a/tun_darwin.go b/tun_darwin.go index 0dfbe3c..0e39481 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -9,6 +9,7 @@ import ( "os/exec" "strconv" + "github.com/sirupsen/logrus" "github.com/songgao/water" ) @@ -17,11 +18,11 @@ type Tun struct { Cidr *net.IPNet MTU int UnsafeRoutes []route - + l *logrus.Logger *water.Interface } -func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) { if len(routes) > 0 { return nil, fmt.Errorf("route MTU not supported in Darwin") } @@ -31,10 +32,11 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, Cidr: cidr, MTU: defaultMTU, UnsafeRoutes: unsafeRoutes, + l: l, }, nil } -func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } diff --git a/tun_disabled.go b/tun_disabled.go index 78eeed8..5db8961 100644 --- a/tun_disabled.go +++ b/tun_disabled.go @@ -9,24 +9,23 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - log "github.com/sirupsen/logrus" ) type disabledTun struct { - read chan []byte - cidr *net.IPNet - logger *log.Logger + read chan []byte + cidr *net.IPNet // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter rx metrics.Counter + l *logrus.Logger } -func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *log.Logger) *disabledTun { +func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ - cidr: cidr, - read: make(chan []byte, queueLen), - logger: l, + cidr: cidr, + read: make(chan []byte, queueLen), + l: l, } if metricsEnabled { @@ -63,8 +62,8 @@ func (t *disabledTun) Read(b []byte) (int, error) { } t.tx.Inc(1) - if l.Level >= logrus.DebugLevel { - t.logger.WithField("raw", prettyPacket(r)).Debugf("Write payload") + if t.l.Level >= logrus.DebugLevel { + t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload") } return copy(b, r), nil @@ -103,7 +102,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool { select { case t.read <- buf: default: - t.logger.Debugf("tun_disabled: dropped ICMP Echo Reply response") + t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response") } return true @@ -114,11 +113,11 @@ func (t *disabledTun) Write(b []byte) (int, error) { // Check for ICMP Echo Request before spending time doing the full parsing if t.handleICMPEchoRequest(b) { - if l.Level >= logrus.DebugLevel { - t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request") + if t.l.Level >= logrus.DebugLevel { + t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request") } - } else if l.Level >= logrus.DebugLevel { - t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload") + } else if t.l.Level >= logrus.DebugLevel { + t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload") } return len(b), nil } diff --git a/tun_freebsd.go b/tun_freebsd.go index 7e6f98c..4415401 100644 --- a/tun_freebsd.go +++ b/tun_freebsd.go @@ -9,6 +9,8 @@ import ( "regexp" "strconv" "strings" + + "github.com/sirupsen/logrus" ) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) @@ -18,15 +20,16 @@ type Tun struct { Cidr *net.IPNet MTU int UnsafeRoutes []route + l *logrus.Logger io.ReadWriteCloser } -func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) { if len(routes) > 0 { return nil, fmt.Errorf("Route MTU not supported in FreeBSD") } @@ -41,6 +44,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, Cidr: cidr, MTU: defaultMTU, UnsafeRoutes: unsafeRoutes, + l: l, }, nil } @@ -52,21 +56,21 @@ func (c *Tun) Activate() error { } // TODO use syscalls instead of exec.Command - l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()) + c.l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()) if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device) + c.l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device) if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) } - l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)) + c.l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)) if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } // Unsafe path routes for _, r := range c.UnsafeRoutes { - l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device) + c.l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device) if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil { return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err) } diff --git a/tun_linux.go b/tun_linux.go index 4d0707b..5dd3e5d 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -10,6 +10,7 @@ import ( "strings" "unsafe" + "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) @@ -24,6 +25,7 @@ type Tun struct { TXQueueLen int Routes []route UnsafeRoutes []route + l *logrus.Logger } type ifReq struct { @@ -78,7 +80,7 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -91,11 +93,12 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, TXQueueLen: txQueueLen, Routes: routes, UnsafeRoutes: unsafeRoutes, + l: l, } return } -func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err @@ -131,6 +134,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, TXQueueLen: txQueueLen, Routes: routes, UnsafeRoutes: unsafeRoutes, + l: l, } return } @@ -233,14 +237,14 @@ func (c Tun) Activate() error { ifm := ifreqMTU{Name: devName, MTU: int32(c.MaxMTU)} if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { // This is currently a non fatal condition because the route table must have the MTU set appropriately as well - l.WithError(err).Error("Failed to set tun mtu") + c.l.WithError(err).Error("Failed to set tun mtu") } // Set the transmit queue length ifrq := ifreqQLEN{Name: devName, Value: int32(c.TXQueueLen)} if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { // If we can't set the queue length nebula will still work but it may lead to packet loss - l.WithError(err).Error("Failed to set tun tx queue length") + c.l.WithError(err).Error("Failed to set tun tx queue length") } // Bring up the interface diff --git a/tun_test.go b/tun_test.go index ef7f34c..08ff10f 100644 --- a/tun_test.go +++ b/tun_test.go @@ -9,7 +9,8 @@ import ( ) func Test_parseRoutes(t *testing.T) { - c := NewConfig() + l := NewTestLogger() + c := NewConfig(l) _, n, _ := net.ParseCIDR("10.0.0.0/24") // test no routes config @@ -104,7 +105,8 @@ func Test_parseRoutes(t *testing.T) { } func Test_parseUnsafeRoutes(t *testing.T) { - c := NewConfig() + l := NewTestLogger() + c := NewConfig(l) _, n, _ := net.ParseCIDR("10.0.0.0/24") // test no routes config diff --git a/tun_windows.go b/tun_windows.go index 040f1ce..594675e 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -7,6 +7,7 @@ import ( "os/exec" "strconv" + "github.com/sirupsen/logrus" "github.com/songgao/water" ) @@ -15,15 +16,16 @@ type Tun struct { Cidr *net.IPNet MTU int UnsafeRoutes []route + l *logrus.Logger *water.Interface } -func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) { if len(routes) > 0 { return nil, fmt.Errorf("route MTU not supported in Windows") } @@ -33,6 +35,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, Cidr: cidr, MTU: defaultMTU, UnsafeRoutes: unsafeRoutes, + l: l, }, nil } diff --git a/udp_android.go b/udp_android.go index ac5606c..9e688f3 100644 --- a/udp_android.go +++ b/udp_android.go @@ -1,3 +1,5 @@ +// +build !e2e_testing + package nebula import ( diff --git a/udp_darwin.go b/udp_darwin.go index 52506b5..861334f 100644 --- a/udp_darwin.go +++ b/udp_darwin.go @@ -1,3 +1,5 @@ +// +build !e2e_testing + package nebula // Darwin support is primarily implemented in udp_generic, besides NewListenConfig diff --git a/udp_freebsd.go b/udp_freebsd.go index 5910a9d..184e092 100644 --- a/udp_freebsd.go +++ b/udp_freebsd.go @@ -1,3 +1,5 @@ +// +build !e2e_testing + package nebula // FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig diff --git a/udp_generic.go b/udp_generic.go index 492a695..64978a2 100644 --- a/udp_generic.go +++ b/udp_generic.go @@ -1,4 +1,5 @@ // +build !linux android +// +build !e2e_testing // udp_generic implements the nebula UDP interface in pure Go stdlib. This // means it can be used on platforms like Darwin and Windows. @@ -9,20 +10,23 @@ import ( "context" "fmt" "net" + + "github.com/sirupsen/logrus" ) type udpConn struct { *net.UDPConn + l *logrus.Logger } -func NewListener(ip string, port int, multi bool) (*udpConn, error) { +func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port)) if err != nil { return nil, err } if uc, ok := pc.(*net.UDPConn); ok { - return &udpConn{UDPConn: uc}, nil + return &udpConn{UDPConn: uc, l: l}, nil } return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) } @@ -76,13 +80,13 @@ func (u *udpConn) ListenOut(f *Interface, q int) { // Just read one packet at a time n, rua, err := u.ReadFromUDP(buffer) if err != nil { - l.WithError(err).Error("Failed to read packets") + f.l.WithError(err).Error("Failed to read packets") continue } udpAddr.IP = rua.IP udpAddr.Port = uint16(rua.Port) - f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get()) + f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l)) } } diff --git a/udp_linux.go b/udp_linux.go index 764b21a..c49aea5 100644 --- a/udp_linux.go +++ b/udp_linux.go @@ -1,4 +1,5 @@ // +build !android +// +build !e2e_testing package nebula @@ -10,6 +11,7 @@ import ( "unsafe" "github.com/rcrowley/go-metrics" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) @@ -17,6 +19,7 @@ import ( type udpConn struct { sysFd int + l *logrus.Logger } var x int @@ -38,7 +41,7 @@ const ( type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 -func NewListener(ip string, port int, multi bool) (*udpConn, error) { +func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) { syscall.ForkLock.RLock() fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP) if err == nil { @@ -70,7 +73,7 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) { //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //l.Println(v, err) - return &udpConn{sysFd: fd}, err + return &udpConn{sysFd: fd, l: l}, err } func (u *udpConn) Rebind() error { @@ -153,7 +156,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) { for { n, err := read(msgs) if err != nil { - l.WithError(err).Error("Failed to read packets") + u.l.WithError(err).Error("Failed to read packets") continue } @@ -161,7 +164,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) { for i := 0; i < n; i++ { udpAddr.IP = names[i][8:24] udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get()) + f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l)) } } } @@ -244,12 +247,12 @@ func (u *udpConn) reloadConfig(c *Config) { if err == nil { s, err := u.GetRecvBuffer() if err == nil { - l.WithField("size", s).Info("listen.read_buffer was set") + u.l.WithField("size", s).Info("listen.read_buffer was set") } else { - l.WithError(err).Warn("Failed to get listen.read_buffer") + u.l.WithError(err).Warn("Failed to get listen.read_buffer") } } else { - l.WithError(err).Error("Failed to set listen.read_buffer") + u.l.WithError(err).Error("Failed to set listen.read_buffer") } } @@ -259,12 +262,12 @@ func (u *udpConn) reloadConfig(c *Config) { if err == nil { s, err := u.GetSendBuffer() if err == nil { - l.WithField("size", s).Info("listen.write_buffer was set") + u.l.WithField("size", s).Info("listen.write_buffer was set") } else { - l.WithError(err).Warn("Failed to get listen.write_buffer") + u.l.WithError(err).Warn("Failed to get listen.write_buffer") } } else { - l.WithError(err).Error("Failed to set listen.write_buffer") + u.l.WithError(err).Error("Failed to set listen.write_buffer") } } } diff --git a/udp_linux_32.go b/udp_linux_32.go index 0b2f6b1..e8c56b2 100644 --- a/udp_linux_32.go +++ b/udp_linux_32.go @@ -1,6 +1,7 @@ // +build linux // +build 386 amd64p32 arm mips mipsle // +build !android +// +build !e2e_testing package nebula diff --git a/udp_linux_64.go b/udp_linux_64.go index faeda69..3ed1c2e 100644 --- a/udp_linux_64.go +++ b/udp_linux_64.go @@ -1,6 +1,7 @@ // +build linux // +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x // +build !android +// +build !e2e_testing package nebula