From 4453964e34631cb1fce827eb11360da36f7774bf Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 10 Nov 2021 21:47:38 -0600 Subject: [PATCH] Move util to test, contextual errors to util (#575) --- allow_list_test.go | 4 +-- bits_test.go | 10 +++---- cert/cert_test.go | 4 +-- cmd/nebula-service/main.go | 3 +- cmd/nebula/main.go | 3 +- config/config_test.go | 14 ++++----- connection_manager_test.go | 8 +++--- control_test.go | 6 ++-- firewall_test.go | 20 ++++++------- handshake_manager_test.go | 6 ++-- lighthouse_test.go | 10 +++---- logger.go | 33 --------------------- main.go | 43 ++++++++++++++-------------- punchy_test.go | 4 +-- {util => test}/assert.go | 2 +- util/main.go => test/logger.go | 4 +-- tun_test.go | 6 ++-- util/error.go | 39 +++++++++++++++++++++++++ logger_test.go => util/error_test.go | 4 ++- 19 files changed, 117 insertions(+), 106 deletions(-) rename {util => test}/assert.go (99%) rename util/main.go => test/logger.go (86%) create mode 100644 util/error.go rename logger_test.go => util/error_test.go (97%) diff --git a/allow_list_test.go b/allow_list_test.go index 038a6b2..991b8a3 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -7,12 +7,12 @@ import ( "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) func TestNewAllowListFromConfig(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() c := config.NewC(l) c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0": true, diff --git a/bits_test.go b/bits_test.go index 3135dfa..95abe01 100644 --- a/bits_test.go +++ b/bits_test.go @@ -3,12 +3,12 @@ package nebula import ( "testing" - "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) func TestBits(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() b := NewBits(10) // make sure it is the right size @@ -76,7 +76,7 @@ func TestBits(t *testing.T) { } func TestBitsDupeCounter(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() b := NewBits(10) b.lostCounter.Clear() b.dupeCounter.Clear() @@ -101,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) { } func TestBitsOutOfWindowCounter(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() b := NewBits(10) b.lostCounter.Clear() b.dupeCounter.Clear() @@ -131,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) { } func TestBitsLostCounter(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() b := NewBits(10) b.lostCounter.Clear() b.dupeCounter.Clear() diff --git a/cert/cert_test.go b/cert/cert_test.go index 4fe13cc..fef3a3c 100644 --- a/cert/cert_test.go +++ b/cert/cert_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/golang/protobuf/proto" - "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/ed25519" @@ -752,7 +752,7 @@ func TestNebulaCertificate_Copy(t *testing.T) { assert.Nil(t, err) cc := c.Copy() - util.AssertDeepCopyEqual(t, c, cc) + test.AssertDeepCopyEqual(t, c, cc) } func TestUnmarshalNebulaCertificate(t *testing.T) { diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 5040e28..f211c97 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -8,6 +8,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" ) // A version string that can be set with @@ -60,7 +61,7 @@ func main() { ctrl, err := nebula.Main(c, *configTest, Build, l, nil) switch v := err.(type) { - case nebula.ContextualError: + case util.ContextualError: v.Log(l) os.Exit(1) case error: diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index a2923c7..efe406b 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -8,6 +8,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" ) // A version string that can be set with @@ -54,7 +55,7 @@ func main() { ctrl, err := nebula.Main(c, *configTest, Build, l, nil) switch v := err.(type) { - case nebula.ContextualError: + case util.ContextualError: v.Log(l) os.Exit(1) case error: diff --git a/config/config_test.go b/config/config_test.go index a5254bd..8dfcbb8 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -7,12 +7,12 @@ import ( "testing" "time" - "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) func TestConfig_Load(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() dir, err := ioutil.TempDir("", "config-test") // invalid yaml c := NewC(l) @@ -42,7 +42,7 @@ func TestConfig_Load(t *testing.T) { } func TestConfig_Get(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() // test simple type c := NewC(l) c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} @@ -58,14 +58,14 @@ func TestConfig_Get(t *testing.T) { } func TestConfig_GetStringSlice(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() c := NewC(l) c.Settings["slice"] = []interface{}{"one", "two"} assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) } func TestConfig_GetBool(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() c := NewC(l) c.Settings["bool"] = true assert.Equal(t, true, c.GetBool("bool", false)) @@ -93,7 +93,7 @@ func TestConfig_GetBool(t *testing.T) { } func TestConfig_HasChanged(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() // No reload has occurred, return false c := NewC(l) c.Settings["test"] = "hi" @@ -115,7 +115,7 @@ func TestConfig_HasChanged(t *testing.T) { } func TestConfig_ReloadConfig(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() done := make(chan bool, 1) dir, err := ioutil.TempDir("", "config-test") assert.Nil(t, err) diff --git a/connection_manager_test.go b/connection_manager_test.go index 9f2fe6e..f05eaee 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -11,15 +11,15 @@ import ( "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" - "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) var vpnIp iputil.VpnIp func Test_NewConnectionManagerTest(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") @@ -89,7 +89,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { } func Test_NewConnectionManagerTest2(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") @@ -164,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Disconnect only if disconnectInvalid: true is set. func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { now := time.Now() - l := util.NewTestLogger() + l := test.NewLogger() ipNet := net.IPNet{ IP: net.IPv4(172, 1, 1, 2), Mask: net.IPMask{255, 255, 255, 0}, diff --git a/control_test.go b/control_test.go index 08aa151..3b6cc2a 100644 --- a/control_test.go +++ b/control_test.go @@ -9,13 +9,13 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" - "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) func TestControl_GetHostInfoByVpnIp(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0)) @@ -94,7 +94,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { // Make sure we don't have any unexpected fields assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi) - util.AssertDeepCopyEqual(t, &expectedInfo, thi) + test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { diff --git a/firewall_test.go b/firewall_test.go index b98a2cf..ce6ba18 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -14,12 +14,12 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) func TestNewFirewall(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() c := &cert.NebulaCertificate{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) conntrack := fw.Conntrack @@ -58,7 +58,7 @@ func TestNewFirewall(t *testing.T) { } func TestFirewall_AddRule(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) @@ -133,7 +133,7 @@ func TestFirewall_AddRule(t *testing.T) { } func TestFirewall_Drop(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) @@ -308,7 +308,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { } func TestFirewall_Drop2(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) @@ -367,7 +367,7 @@ func TestFirewall_Drop2(t *testing.T) { } func TestFirewall_Drop3(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) @@ -453,7 +453,7 @@ func TestFirewall_Drop3(t *testing.T) { } func TestFirewall_DropConntrackReload(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) @@ -635,7 +635,7 @@ func Test_parsePort(t *testing.T) { } func TestNewFirewallFromConfig(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() // Test a bad rule definition c := &cert.NebulaCertificate{} conf := config.NewC(l) @@ -685,7 +685,7 @@ func TestNewFirewallFromConfig(t *testing.T) { } func TestAddFirewallRulesFromConfig(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() // Test adding tcp rule conf := config.NewC(l) mf := &mockFirewall{} @@ -849,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) { } func TestFirewall_convertRule(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) diff --git a/handshake_manager_test.go b/handshake_manager_test.go index dfc8d2c..0ca651c 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -7,13 +7,13 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" - "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) func Test_NewHandshakeManagerVpnIp(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") @@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { } func Test_NewHandshakeManagerTrigger(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") diff --git a/lighthouse_test.go b/lighthouse_test.go index 03c96b9..41fde97 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -8,8 +8,8 @@ import ( "github.com/golang/protobuf/proto" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" - "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) @@ -46,7 +46,7 @@ func TestNewLhQuery(t *testing.T) { } func Test_lhStaticMapping(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() lh1 := "10.128.0.2" lh1IP := net.ParseIP(lh1) @@ -67,7 +67,7 @@ func Test_lhStaticMapping(t *testing.T) { } func BenchmarkLighthouseHandleRequest(b *testing.B) { - l := util.NewTestLogger() + l := test.NewLogger() lh1 := "10.128.0.2" lh1IP := net.ParseIP(lh1) @@ -137,7 +137,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { } func TestLighthouse_Memory(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242} myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242} @@ -266,7 +266,7 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, //TODO: this is a RemoteList test //func Test_lhRemoteAllowList(t *testing.T) { -// l := NewTestLogger() +// l := NewLogger() // c := NewConfig(l) // c.Settings["remoteallowlist"] = map[interface{}]interface{}{ // "10.20.0.0/12": false, diff --git a/logger.go b/logger.go index 8846264..aaf6f29 100644 --- a/logger.go +++ b/logger.go @@ -1,7 +1,6 @@ package nebula import ( - "errors" "fmt" "strings" "time" @@ -10,38 +9,6 @@ import ( "github.com/slackhq/nebula/config" ) -type ContextualError struct { - RealError error - Fields map[string]interface{} - Context string -} - -func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError { - return ContextualError{Context: msg, Fields: fields, RealError: realError} -} - -func (ce ContextualError) Error() string { - if ce.RealError == nil { - return ce.Context - } - return ce.RealError.Error() -} - -func (ce ContextualError) Unwrap() error { - if ce.RealError == nil { - return errors.New(ce.Context) - } - return ce.RealError -} - -func (ce *ContextualError) Log(lr *logrus.Logger) { - if ce.RealError != nil { - lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) - } else { - lr.WithFields(ce.Fields).Error(ce.Context) - } -} - func configLogger(l *logrus.Logger, c *config.C) error { // set up our logging level logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) diff --git a/main.go b/main.go index 91418e1..8dc9536 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,7 @@ import ( "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/udp" + "github.com/slackhq/nebula/util" "gopkg.in/yaml.v2" ) @@ -44,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg err := configLogger(l, c) if err != nil { - return nil, NewContextualError("Failed to configure the logger", nil, err) + return nil, util.NewContextualError("Failed to configure the logger", nil, err) } c.RegisterReloadCallback(func(c *config.C) { @@ -57,20 +58,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg caPool, err := loadCAFromConfig(l, c) if err != nil { //The errors coming out of loadCA are already nicely formatted - return nil, NewContextualError("Failed to load ca from config", nil, err) + return nil, util.NewContextualError("Failed to load ca from config", nil, err) } l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") cs, err := NewCertStateFromConfig(c) if err != nil { //The errors coming out of NewCertStateFromConfig are already nicely formatted - return nil, NewContextualError("Failed to load certificate from config", nil, err) + return nil, util.NewContextualError("Failed to load certificate from config", nil, err) } l.WithField("cert", cs.certificate).Debug("Client nebula certificate") fw, err := NewFirewallFromConfig(l, cs.certificate, c) if err != nil { - return nil, NewContextualError("Error while loading firewall rules", nil, err) + return nil, util.NewContextualError("Error while loading firewall rules", nil, err) } l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") @@ -78,11 +79,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg tunCidr := cs.certificate.Details.Ips[0] routes, err := parseRoutes(c, tunCidr) if err != nil { - return nil, NewContextualError("Could not parse tun.routes", nil, err) + return nil, util.NewContextualError("Could not parse tun.routes", nil, err) } unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr) if err != nil { - return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err) + return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) } ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) @@ -91,7 +92,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if c.GetBool("sshd.enabled", false) { sshStart, err = configSSH(l, ssh, c) if err != nil { - return nil, NewContextualError("Error while configuring the sshd", nil, err) + return nil, util.NewContextualError("Error while configuring the sshd", nil, err) } } @@ -167,7 +168,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } if err != nil { - return nil, NewContextualError("Failed to get a tun/tap device", nil, err) + return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err) } } @@ -185,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg for i := 0; i < routines; i++ { udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { - return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err) + return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } udpServer.ReloadConfig(c) udpConns[i] = udpServer @@ -194,7 +195,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if port == 0 { uPort, err := udpServer.LocalAddr() if err != nil { - return nil, NewContextualError("Failed to get listening port", nil, err) + return nil, util.NewContextualError("Failed to get listening port", nil, err) } port = int(uPort.Port) } @@ -209,7 +210,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg for _, rawPreferredRange := range rawPreferredRanges { _, preferredRange, err := net.ParseCIDR(rawPreferredRange) if err != nil { - return nil, NewContextualError("Failed to parse preferred ranges", nil, err) + return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err) } preferredRanges = append(preferredRanges, preferredRange) } @@ -222,7 +223,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if rawLocalRange != "" { _, localRange, err := net.ParseCIDR(rawLocalRange) if err != nil { - return nil, NewContextualError("Failed to parse local_range", nil, err) + return nil, util.NewContextualError("Failed to parse local_range", nil, err) } // Check if the entry for local_range was already specified in @@ -261,7 +262,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg // fatal if am_lighthouse is enabled but we are using an ephemeral port if amLighthouse && (c.GetInt("listen.port", 0) == 0) { - return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil) + return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil) } // warn if am_lighthouse is enabled but upstream lighthouses exists @@ -274,10 +275,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg for i, host := range rawLighthouseHosts { ip := net.ParseIP(host) if ip == nil { - return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) + return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) } if !tunCidr.Contains(ip) { - return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) + return nil, util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) } lighthouseHosts[i] = iputil.Ip2VpnIp(ip) } @@ -298,13 +299,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") if err != nil { - return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) + return nil, util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) } lightHouse.SetRemoteAllowList(remoteAllowList) localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list") if err != nil { - return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err) + return nil, util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err) } lightHouse.SetLocalAllowList(localAllowList) @@ -313,21 +314,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg ip := net.ParseIP(fmt.Sprintf("%v", k)) vpnIp := iputil.Ip2VpnIp(ip) if !tunCidr.Contains(ip) { - return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) + return nil, util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) } vals, ok := v.([]interface{}) if ok { for _, v := range vals { ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) if err != nil { - return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) + return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) } lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) } } else { ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) if err != nil { - return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) + return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) } lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) } @@ -426,7 +427,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg statsStart, err := startStats(l, c, buildVersion, configTest) if err != nil { - return nil, NewContextualError("Failed to start stats emitter", nil, err) + return nil, util.NewContextualError("Failed to start stats emitter", nil, err) } if configTest { diff --git a/punchy_test.go b/punchy_test.go index 8b8cd1a..89b5136 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -5,12 +5,12 @@ import ( "time" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) func TestNewPunchyFromConfig(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() c := config.NewC(l) // Test defaults diff --git a/util/assert.go b/test/assert.go similarity index 99% rename from util/assert.go rename to test/assert.go index 6f13d6b..6c6c795 100644 --- a/util/assert.go +++ b/test/assert.go @@ -1,4 +1,4 @@ -package util +package test import ( "fmt" diff --git a/util/main.go b/test/logger.go similarity index 86% rename from util/main.go rename to test/logger.go index 0d84c73..197ab44 100644 --- a/util/main.go +++ b/test/logger.go @@ -1,4 +1,4 @@ -package util +package test import ( "io/ioutil" @@ -7,7 +7,7 @@ import ( "github.com/sirupsen/logrus" ) -func NewTestLogger() *logrus.Logger { +func NewLogger() *logrus.Logger { l := logrus.New() v := os.Getenv("TEST_LOGS") diff --git a/tun_test.go b/tun_test.go index 01c904a..9cbb548 100644 --- a/tun_test.go +++ b/tun_test.go @@ -6,12 +6,12 @@ import ( "testing" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) func Test_parseRoutes(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() c := config.NewC(l) _, n, _ := net.ParseCIDR("10.0.0.0/24") @@ -107,7 +107,7 @@ func Test_parseRoutes(t *testing.T) { } func Test_parseUnsafeRoutes(t *testing.T) { - l := util.NewTestLogger() + l := test.NewLogger() c := config.NewC(l) _, n, _ := net.ParseCIDR("10.0.0.0/24") diff --git a/util/error.go b/util/error.go new file mode 100644 index 0000000..7f9bc47 --- /dev/null +++ b/util/error.go @@ -0,0 +1,39 @@ +package util + +import ( + "errors" + + "github.com/sirupsen/logrus" +) + +type ContextualError struct { + RealError error + Fields map[string]interface{} + Context string +} + +func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError { + return ContextualError{Context: msg, Fields: fields, RealError: realError} +} + +func (ce ContextualError) Error() string { + if ce.RealError == nil { + return ce.Context + } + return ce.RealError.Error() +} + +func (ce ContextualError) Unwrap() error { + if ce.RealError == nil { + return errors.New(ce.Context) + } + return ce.RealError +} + +func (ce *ContextualError) Log(lr *logrus.Logger) { + if ce.RealError != nil { + lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) + } else { + lr.WithFields(ce.Fields).Error(ce.Context) + } +} diff --git a/logger_test.go b/util/error_test.go similarity index 97% rename from logger_test.go rename to util/error_test.go index 1594fb9..747d04e 100644 --- a/logger_test.go +++ b/util/error_test.go @@ -1,4 +1,4 @@ -package nebula +package util import ( "errors" @@ -8,6 +8,8 @@ import ( "github.com/stretchr/testify/assert" ) +type m map[string]interface{} + type TestLogWriter struct { Logs []string }