Move util to test, contextual errors to util (#575)

This commit is contained in:
Nate Brown 2021-11-10 21:47:38 -06:00 committed by GitHub
parent 19a9a4221e
commit 4453964e34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 117 additions and 106 deletions

View File

@ -7,12 +7,12 @@ import (
"github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestNewAllowListFromConfig(t *testing.T) { func TestNewAllowListFromConfig(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0": true, "192.168.0.0": true,

View File

@ -3,12 +3,12 @@ package nebula
import ( import (
"testing" "testing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestBits(t *testing.T) { func TestBits(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
b := NewBits(10) b := NewBits(10)
// make sure it is the right size // make sure it is the right size
@ -76,7 +76,7 @@ func TestBits(t *testing.T) {
} }
func TestBitsDupeCounter(t *testing.T) { func TestBitsDupeCounter(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
b := NewBits(10) b := NewBits(10)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()
@ -101,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) {
} }
func TestBitsOutOfWindowCounter(t *testing.T) { func TestBitsOutOfWindowCounter(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
b := NewBits(10) b := NewBits(10)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()
@ -131,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
} }
func TestBitsLostCounter(t *testing.T) { func TestBitsLostCounter(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
b := NewBits(10) b := NewBits(10)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()

View File

@ -9,7 +9,7 @@ import (
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@ -752,7 +752,7 @@ func TestNebulaCertificate_Copy(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
cc := c.Copy() cc := c.Copy()
util.AssertDeepCopyEqual(t, c, cc) test.AssertDeepCopyEqual(t, c, cc)
} }
func TestUnmarshalNebulaCertificate(t *testing.T) { func TestUnmarshalNebulaCertificate(t *testing.T) {

View File

@ -8,6 +8,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
) )
// A version string that can be set with // A version string that can be set with
@ -60,7 +61,7 @@ func main() {
ctrl, err := nebula.Main(c, *configTest, Build, l, nil) ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
switch v := err.(type) { switch v := err.(type) {
case nebula.ContextualError: case util.ContextualError:
v.Log(l) v.Log(l)
os.Exit(1) os.Exit(1)
case error: case error:

View File

@ -8,6 +8,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
) )
// A version string that can be set with // A version string that can be set with
@ -54,7 +55,7 @@ func main() {
ctrl, err := nebula.Main(c, *configTest, Build, l, nil) ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
switch v := err.(type) { switch v := err.(type) {
case nebula.ContextualError: case util.ContextualError:
v.Log(l) v.Log(l)
os.Exit(1) os.Exit(1)
case error: case error:

View File

@ -7,12 +7,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestConfig_Load(t *testing.T) { func TestConfig_Load(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
dir, err := ioutil.TempDir("", "config-test") dir, err := ioutil.TempDir("", "config-test")
// invalid yaml // invalid yaml
c := NewC(l) c := NewC(l)
@ -42,7 +42,7 @@ func TestConfig_Load(t *testing.T) {
} }
func TestConfig_Get(t *testing.T) { func TestConfig_Get(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
// test simple type // test simple type
c := NewC(l) c := NewC(l)
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
@ -58,14 +58,14 @@ func TestConfig_Get(t *testing.T) {
} }
func TestConfig_GetStringSlice(t *testing.T) { func TestConfig_GetStringSlice(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
c := NewC(l) c := NewC(l)
c.Settings["slice"] = []interface{}{"one", "two"} c.Settings["slice"] = []interface{}{"one", "two"}
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
} }
func TestConfig_GetBool(t *testing.T) { func TestConfig_GetBool(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
c := NewC(l) c := NewC(l)
c.Settings["bool"] = true c.Settings["bool"] = true
assert.Equal(t, true, c.GetBool("bool", false)) assert.Equal(t, true, c.GetBool("bool", false))
@ -93,7 +93,7 @@ func TestConfig_GetBool(t *testing.T) {
} }
func TestConfig_HasChanged(t *testing.T) { func TestConfig_HasChanged(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
// No reload has occurred, return false // No reload has occurred, return false
c := NewC(l) c := NewC(l)
c.Settings["test"] = "hi" c.Settings["test"] = "hi"
@ -115,7 +115,7 @@ func TestConfig_HasChanged(t *testing.T) {
} }
func TestConfig_ReloadConfig(t *testing.T) { func TestConfig_ReloadConfig(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
done := make(chan bool, 1) done := make(chan bool, 1)
dir, err := ioutil.TempDir("", "config-test") dir, err := ioutil.TempDir("", "config-test")
assert.Nil(t, err) assert.Nil(t, err)

View File

@ -11,15 +11,15 @@ import (
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var vpnIp iputil.VpnIp var vpnIp iputil.VpnIp
func Test_NewConnectionManagerTest(t *testing.T) { func Test_NewConnectionManagerTest(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@ -89,7 +89,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
} }
func Test_NewConnectionManagerTest2(t *testing.T) { func Test_NewConnectionManagerTest2(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@ -164,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Disconnect only if disconnectInvalid: true is set. // Disconnect only if disconnectInvalid: true is set.
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
now := time.Now() now := time.Now()
l := util.NewTestLogger() l := test.NewLogger()
ipNet := net.IPNet{ ipNet := net.IPNet{
IP: net.IPv4(172, 1, 1, 2), IP: net.IPv4(172, 1, 1, 2),
Mask: net.IPMask{255, 255, 255, 0}, Mask: net.IPMask{255, 255, 255, 0},

View File

@ -9,13 +9,13 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestControl_GetHostInfoByVpnIp(t *testing.T) { func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller // To properly ensure we are not exposing core memory to the caller
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0)) hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
@ -94,7 +94,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
// Make sure we don't have any unexpected fields // Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi) assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
util.AssertDeepCopyEqual(t, &expectedInfo, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet // Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() { assert.NotPanics(t, func() {

View File

@ -14,12 +14,12 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestNewFirewall(t *testing.T) { func TestNewFirewall(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
conntrack := fw.Conntrack conntrack := fw.Conntrack
@ -58,7 +58,7 @@ func TestNewFirewall(t *testing.T) {
} }
func TestFirewall_AddRule(t *testing.T) { func TestFirewall_AddRule(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
@ -133,7 +133,7 @@ func TestFirewall_AddRule(t *testing.T) {
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
@ -308,7 +308,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
} }
func TestFirewall_Drop2(t *testing.T) { func TestFirewall_Drop2(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
@ -367,7 +367,7 @@ func TestFirewall_Drop2(t *testing.T) {
} }
func TestFirewall_Drop3(t *testing.T) { func TestFirewall_Drop3(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
@ -453,7 +453,7 @@ func TestFirewall_Drop3(t *testing.T) {
} }
func TestFirewall_DropConntrackReload(t *testing.T) { func TestFirewall_DropConntrackReload(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
@ -635,7 +635,7 @@ func Test_parsePort(t *testing.T) {
} }
func TestNewFirewallFromConfig(t *testing.T) { func TestNewFirewallFromConfig(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
// Test a bad rule definition // Test a bad rule definition
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
conf := config.NewC(l) conf := config.NewC(l)
@ -685,7 +685,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
} }
func TestAddFirewallRulesFromConfig(t *testing.T) { func TestAddFirewallRulesFromConfig(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
// Test adding tcp rule // Test adding tcp rule
conf := config.NewC(l) conf := config.NewC(l)
mf := &mockFirewall{} mf := &mockFirewall{}
@ -849,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) {
} }
func TestFirewall_convertRule(t *testing.T) { func TestFirewall_convertRule(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)

View File

@ -7,13 +7,13 @@ import (
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_NewHandshakeManagerVpnIp(t *testing.T) { func Test_NewHandshakeManagerVpnIp(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
} }
func Test_NewHandshakeManagerTrigger(t *testing.T) { func Test_NewHandshakeManagerTrigger(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")

View File

@ -8,8 +8,8 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -46,7 +46,7 @@ func TestNewLhQuery(t *testing.T) {
} }
func Test_lhStaticMapping(t *testing.T) { func Test_lhStaticMapping(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1) lh1IP := net.ParseIP(lh1)
@ -67,7 +67,7 @@ func Test_lhStaticMapping(t *testing.T) {
} }
func BenchmarkLighthouseHandleRequest(b *testing.B) { func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := util.NewTestLogger() l := test.NewLogger()
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1) lh1IP := net.ParseIP(lh1)
@ -137,7 +137,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
} }
func TestLighthouse_Memory(t *testing.T) { func TestLighthouse_Memory(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242} myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242} myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
@ -266,7 +266,7 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr,
//TODO: this is a RemoteList test //TODO: this is a RemoteList test
//func Test_lhRemoteAllowList(t *testing.T) { //func Test_lhRemoteAllowList(t *testing.T) {
// l := NewTestLogger() // l := NewLogger()
// c := NewConfig(l) // c := NewConfig(l)
// c.Settings["remoteallowlist"] = map[interface{}]interface{}{ // c.Settings["remoteallowlist"] = map[interface{}]interface{}{
// "10.20.0.0/12": false, // "10.20.0.0/12": false,

View File

@ -1,7 +1,6 @@
package nebula package nebula
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@ -10,38 +9,6 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
) )
type ContextualError struct {
RealError error
Fields map[string]interface{}
Context string
}
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
return ContextualError{Context: msg, Fields: fields, RealError: realError}
}
func (ce ContextualError) Error() string {
if ce.RealError == nil {
return ce.Context
}
return ce.RealError.Error()
}
func (ce ContextualError) Unwrap() error {
if ce.RealError == nil {
return errors.New(ce.Context)
}
return ce.RealError
}
func (ce *ContextualError) Log(lr *logrus.Logger) {
if ce.RealError != nil {
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
} else {
lr.WithFields(ce.Fields).Error(ce.Context)
}
}
func configLogger(l *logrus.Logger, c *config.C) error { func configLogger(l *logrus.Logger, c *config.C) error {
// set up our logging level // set up our logging level
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))

43
main.go
View File

@ -12,6 +12,7 @@ import (
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -44,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
err := configLogger(l, c) err := configLogger(l, c)
if err != nil { if err != nil {
return nil, NewContextualError("Failed to configure the logger", nil, err) return nil, util.NewContextualError("Failed to configure the logger", nil, err)
} }
c.RegisterReloadCallback(func(c *config.C) { c.RegisterReloadCallback(func(c *config.C) {
@ -57,20 +58,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
caPool, err := loadCAFromConfig(l, c) caPool, err := loadCAFromConfig(l, c)
if err != nil { if err != nil {
//The errors coming out of loadCA are already nicely formatted //The errors coming out of loadCA are already nicely formatted
return nil, NewContextualError("Failed to load ca from config", nil, err) return nil, util.NewContextualError("Failed to load ca from config", nil, err)
} }
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(c) cs, err := NewCertStateFromConfig(c)
if err != nil { if err != nil {
//The errors coming out of NewCertStateFromConfig are already nicely formatted //The errors coming out of NewCertStateFromConfig are already nicely formatted
return nil, NewContextualError("Failed to load certificate from config", nil, err) return nil, util.NewContextualError("Failed to load certificate from config", nil, err)
} }
l.WithField("cert", cs.certificate).Debug("Client nebula certificate") l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(l, cs.certificate, c) fw, err := NewFirewallFromConfig(l, cs.certificate, c)
if err != nil { if err != nil {
return nil, NewContextualError("Error while loading firewall rules", nil, err) return nil, util.NewContextualError("Error while loading firewall rules", nil, err)
} }
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
@ -78,11 +79,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
tunCidr := cs.certificate.Details.Ips[0] tunCidr := cs.certificate.Details.Ips[0]
routes, err := parseRoutes(c, tunCidr) routes, err := parseRoutes(c, tunCidr)
if err != nil { if err != nil {
return nil, NewContextualError("Could not parse tun.routes", nil, err) return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
} }
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr) unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
if err != nil { if err != nil {
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err) return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
} }
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
@ -91,7 +92,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) { if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c) sshStart, err = configSSH(l, ssh, c)
if err != nil { if err != nil {
return nil, NewContextualError("Error while configuring the sshd", nil, err) return nil, util.NewContextualError("Error while configuring the sshd", nil, err)
} }
} }
@ -167,7 +168,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
if err != nil { if err != nil {
return nil, NewContextualError("Failed to get a tun/tap device", nil, err) return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err)
} }
} }
@ -185,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for i := 0; i < routines; i++ { for i := 0; i < routines; i++ {
udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64)) udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil { if err != nil {
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err) return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
} }
udpServer.ReloadConfig(c) udpServer.ReloadConfig(c)
udpConns[i] = udpServer udpConns[i] = udpServer
@ -194,7 +195,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if port == 0 { if port == 0 {
uPort, err := udpServer.LocalAddr() uPort, err := udpServer.LocalAddr()
if err != nil { if err != nil {
return nil, NewContextualError("Failed to get listening port", nil, err) return nil, util.NewContextualError("Failed to get listening port", nil, err)
} }
port = int(uPort.Port) port = int(uPort.Port)
} }
@ -209,7 +210,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for _, rawPreferredRange := range rawPreferredRanges { for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange) _, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil { if err != nil {
return nil, NewContextualError("Failed to parse preferred ranges", nil, err) return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err)
} }
preferredRanges = append(preferredRanges, preferredRange) preferredRanges = append(preferredRanges, preferredRange)
} }
@ -222,7 +223,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if rawLocalRange != "" { if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange) _, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil { if err != nil {
return nil, NewContextualError("Failed to parse local_range", nil, err) return nil, util.NewContextualError("Failed to parse local_range", nil, err)
} }
// Check if the entry for local_range was already specified in // Check if the entry for local_range was already specified in
@ -261,7 +262,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
// fatal if am_lighthouse is enabled but we are using an ephemeral port // fatal if am_lighthouse is enabled but we are using an ephemeral port
if amLighthouse && (c.GetInt("listen.port", 0) == 0) { if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil) return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
} }
// warn if am_lighthouse is enabled but upstream lighthouses exists // warn if am_lighthouse is enabled but upstream lighthouses exists
@ -274,10 +275,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for i, host := range rawLighthouseHosts { for i, host := range rawLighthouseHosts {
ip := net.ParseIP(host) ip := net.ParseIP(host)
if ip == nil { if ip == nil {
return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
} }
if !tunCidr.Contains(ip) { if !tunCidr.Contains(ip) {
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) return nil, util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
} }
lighthouseHosts[i] = iputil.Ip2VpnIp(ip) lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
} }
@ -298,13 +299,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
if err != nil { if err != nil {
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) return nil, util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
} }
lightHouse.SetRemoteAllowList(remoteAllowList) lightHouse.SetRemoteAllowList(remoteAllowList)
localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list") localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
if err != nil { if err != nil {
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err) return nil, util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
} }
lightHouse.SetLocalAllowList(localAllowList) lightHouse.SetLocalAllowList(localAllowList)
@ -313,21 +314,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
ip := net.ParseIP(fmt.Sprintf("%v", k)) ip := net.ParseIP(fmt.Sprintf("%v", k))
vpnIp := iputil.Ip2VpnIp(ip) vpnIp := iputil.Ip2VpnIp(ip)
if !tunCidr.Contains(ip) { if !tunCidr.Contains(ip) {
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) return nil, util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
} }
vals, ok := v.([]interface{}) vals, ok := v.([]interface{})
if ok { if ok {
for _, v := range vals { for _, v := range vals {
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
if err != nil { if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
} }
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
} }
} else { } else {
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
if err != nil { if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
} }
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
} }
@ -426,7 +427,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
statsStart, err := startStats(l, c, buildVersion, configTest) statsStart, err := startStats(l, c, buildVersion, configTest)
if err != nil { if err != nil {
return nil, NewContextualError("Failed to start stats emitter", nil, err) return nil, util.NewContextualError("Failed to start stats emitter", nil, err)
} }
if configTest { if configTest {

View File

@ -5,12 +5,12 @@ import (
"time" "time"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestNewPunchyFromConfig(t *testing.T) { func TestNewPunchyFromConfig(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
// Test defaults // Test defaults

View File

@ -1,4 +1,4 @@
package util package test
import ( import (
"fmt" "fmt"

View File

@ -1,4 +1,4 @@
package util package test
import ( import (
"io/ioutil" "io/ioutil"
@ -7,7 +7,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func NewTestLogger() *logrus.Logger { func NewLogger() *logrus.Logger {
l := logrus.New() l := logrus.New()
v := os.Getenv("TEST_LOGS") v := os.Getenv("TEST_LOGS")

View File

@ -6,12 +6,12 @@ import (
"testing" "testing"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_parseRoutes(t *testing.T) { func Test_parseRoutes(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24") _, n, _ := net.ParseCIDR("10.0.0.0/24")
@ -107,7 +107,7 @@ func Test_parseRoutes(t *testing.T) {
} }
func Test_parseUnsafeRoutes(t *testing.T) { func Test_parseUnsafeRoutes(t *testing.T) {
l := util.NewTestLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24") _, n, _ := net.ParseCIDR("10.0.0.0/24")

39
util/error.go Normal file
View File

@ -0,0 +1,39 @@
package util
import (
"errors"
"github.com/sirupsen/logrus"
)
type ContextualError struct {
RealError error
Fields map[string]interface{}
Context string
}
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
return ContextualError{Context: msg, Fields: fields, RealError: realError}
}
func (ce ContextualError) Error() string {
if ce.RealError == nil {
return ce.Context
}
return ce.RealError.Error()
}
func (ce ContextualError) Unwrap() error {
if ce.RealError == nil {
return errors.New(ce.Context)
}
return ce.RealError
}
func (ce *ContextualError) Log(lr *logrus.Logger) {
if ce.RealError != nil {
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
} else {
lr.WithFields(ce.Fields).Error(ce.Context)
}
}

View File

@ -1,4 +1,4 @@
package nebula package util
import ( import (
"errors" "errors"
@ -8,6 +8,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
type m map[string]interface{}
type TestLogWriter struct { type TestLogWriter struct {
Logs []string Logs []string
} }