diff --git a/allow_list.go b/allow_list.go index 97c13a0..0e44a12 100644 --- a/allow_list.go +++ b/allow_list.go @@ -4,11 +4,15 @@ import ( "fmt" "net" "regexp" + + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" ) type AllowList struct { // The values of this cidrTree are `bool`, signifying allow/deny - cidrTree *CIDR6Tree + cidrTree *cidr.Tree6 } type RemoteAllowList struct { @@ -16,7 +20,7 @@ type RemoteAllowList struct { // Inside Range Specific, keys of this tree are inside CIDRs and values // are *AllowList - insideAllowLists *CIDR6Tree + insideAllowLists *cidr.Tree6 } type LocalAllowList struct { @@ -31,6 +35,223 @@ type AllowListNameRule struct { Allow bool } +func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) { + var nameRules []AllowListNameRule + handleKey := func(key string, value interface{}) (bool, error) { + if key == "interfaces" { + var err error + nameRules, err = getAllowListInterfaces(k, value) + if err != nil { + return false, err + } + + return true, nil + } + return false, nil + } + + al, err := newAllowListFromConfig(c, k, handleKey) + if err != nil { + return nil, err + } + return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil +} + +func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllowList, error) { + al, err := newAllowListFromConfig(c, k, nil) + if err != nil { + return nil, err + } + remoteAllowRanges, err := getRemoteAllowRanges(c, rangesKey) + if err != nil { + return nil, err + } + return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil +} + +// If the handleKey func returns true, the rest of the parsing is skipped +// for this key. This allows parsing of special values like `interfaces`. +func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { + r := c.Get(k) + if r == nil { + return nil, nil + } + + return newAllowList(k, r, handleKey) +} + +// If the handleKey func returns true, the rest of the parsing is skipped +// for this key. This allows parsing of special values like `interfaces`. +func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { + rawMap, ok := raw.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) + } + + tree := cidr.NewTree6() + + // Keep track of the rules we have added for both ipv4 and ipv6 + type allowListRules struct { + firstValue bool + allValuesMatch bool + defaultSet bool + allValues bool + } + + rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false} + rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false} + + for rawKey, rawValue := range rawMap { + rawCIDR, ok := rawKey.(string) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) + } + + if handleKey != nil { + handled, err := handleKey(rawCIDR, rawValue) + if err != nil { + return nil, err + } + if handled { + continue + } + } + + value, ok := rawValue.(bool) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue) + } + + _, ipNet, err := net.ParseCIDR(rawCIDR) + if err != nil { + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + } + + // TODO: should we error on duplicate CIDRs in the config? + tree.AddCIDR(ipNet, value) + + maskBits, maskSize := ipNet.Mask.Size() + + var rules *allowListRules + if maskSize == 32 { + rules = &rules4 + } else { + rules = &rules6 + } + + if rules.firstValue { + rules.allValues = value + rules.firstValue = false + } else { + if value != rules.allValues { + rules.allValuesMatch = false + } + } + + // Check if this is 0.0.0.0/0 or ::/0 + if maskBits == 0 { + rules.defaultSet = true + } + } + + if !rules4.defaultSet { + if rules4.allValuesMatch { + _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0") + tree.AddCIDR(zeroCIDR, !rules4.allValues) + } else { + return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k) + } + } + + if !rules6.defaultSet { + if rules6.allValuesMatch { + _, zeroCIDR, _ := net.ParseCIDR("::/0") + tree.AddCIDR(zeroCIDR, !rules6.allValues) + } else { + return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k) + } + } + + return &AllowList{cidrTree: tree}, nil +} + +func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) { + var nameRules []AllowListNameRule + + rawRules, ok := v.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v) + } + + firstEntry := true + var allValues bool + for rawName, rawAllow := range rawRules { + name, ok := rawName.(string) + if !ok { + return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName) + } + allow, ok := rawAllow.(bool) + if !ok { + return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow) + } + + nameRE, err := regexp.Compile("^" + name + "$") + if err != nil { + return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err) + } + + nameRules = append(nameRules, AllowListNameRule{ + Name: nameRE, + Allow: allow, + }) + + if firstEntry { + allValues = allow + firstEntry = false + } else { + if allow != allValues { + return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k) + } + } + } + + return nameRules, nil +} + +func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) { + value := c.Get(k) + if value == nil { + return nil, nil + } + + remoteAllowRanges := cidr.NewTree6() + + rawMap, ok := value.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) + } + for rawKey, rawValue := range rawMap { + rawCIDR, ok := rawKey.(string) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) + } + + allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil) + if err != nil { + return nil, err + } + + _, ipNet, err := net.ParseCIDR(rawCIDR) + if err != nil { + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + } + + remoteAllowRanges.AddCIDR(ipNet, allowList) + } + + return remoteAllowRanges, nil +} + func (al *AllowList) Allow(ip net.IP) bool { if al == nil { return true @@ -45,7 +266,7 @@ func (al *AllowList) Allow(ip net.IP) bool { } } -func (al *AllowList) AllowIpV4(ip uint32) bool { +func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool { if al == nil { return true } @@ -102,14 +323,14 @@ func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool { return al.AllowList.Allow(ip) } -func (al *RemoteAllowList) Allow(vpnIp uint32, ip net.IP) bool { +func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool { if !al.getInsideAllowList(vpnIp).Allow(ip) { return false } return al.AllowList.Allow(ip) } -func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool { +func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool { if al == nil { return true } @@ -119,7 +340,7 @@ func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool { return al.AllowList.AllowIpV4(ip) } -func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool { +func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool { if al == nil { return true } @@ -129,7 +350,7 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool { return al.AllowList.AllowIpV6(hi, lo) } -func (al *RemoteAllowList) getInsideAllowList(vpnIp uint32) *AllowList { +func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList { if al.insideAllowLists != nil { inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) if inside != nil { diff --git a/allow_list_test.go b/allow_list_test.go index 2dcc3a1..038a6b2 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -5,21 +5,110 @@ import ( "regexp" "testing" + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) +func TestNewAllowListFromConfig(t *testing.T) { + l := util.NewTestLogger() + c := config.NewC(l) + c.Settings["allowlist"] = map[interface{}]interface{}{ + "192.168.0.0": true, + } + r, err := newAllowListFromConfig(c, "allowlist", nil) + assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0") + assert.Nil(t, r) + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "192.168.0.0/16": "abc", + } + r, err = newAllowListFromConfig(c, "allowlist", nil) + assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "192.168.0.0/16": true, + "10.0.0.0/8": false, + } + r, err = newAllowListFromConfig(c, "allowlist", nil) + assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "0.0.0.0/0": true, + "10.0.0.0/8": false, + "10.42.42.0/24": true, + "fd00::/8": true, + "fd00:fd00::/16": false, + } + r, err = newAllowListFromConfig(c, "allowlist", nil) + assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "0.0.0.0/0": true, + "10.0.0.0/8": false, + "10.42.42.0/24": true, + } + r, err = newAllowListFromConfig(c, "allowlist", nil) + if assert.NoError(t, err) { + assert.NotNil(t, r) + } + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "0.0.0.0/0": true, + "10.0.0.0/8": false, + "10.42.42.0/24": true, + "::/0": false, + "fd00::/8": true, + "fd00:fd00::/16": false, + } + r, err = newAllowListFromConfig(c, "allowlist", nil) + if assert.NoError(t, err) { + assert.NotNil(t, r) + } + + // Test interface names + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "interfaces": map[interface{}]interface{}{ + `docker.*`: "foo", + }, + } + lr, err := NewLocalAllowListFromConfig(c, "allowlist") + assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "interfaces": map[interface{}]interface{}{ + `docker.*`: false, + `eth.*`: true, + }, + } + lr, err = NewLocalAllowListFromConfig(c, "allowlist") + assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "interfaces": map[interface{}]interface{}{ + `docker.*`: false, + }, + } + lr, err = NewLocalAllowListFromConfig(c, "allowlist") + if assert.NoError(t, err) { + assert.NotNil(t, lr) + } +} + func TestAllowList_Allow(t *testing.T) { assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1"))) - tree := NewCIDR6Tree() - tree.AddCIDR(getCIDR("0.0.0.0/0"), true) - tree.AddCIDR(getCIDR("10.0.0.0/8"), false) - tree.AddCIDR(getCIDR("10.42.42.42/32"), true) - tree.AddCIDR(getCIDR("10.42.0.0/16"), true) - tree.AddCIDR(getCIDR("10.42.42.0/24"), true) - tree.AddCIDR(getCIDR("10.42.42.0/24"), false) - tree.AddCIDR(getCIDR("::1/128"), true) - tree.AddCIDR(getCIDR("::2/128"), false) + tree := cidr.NewTree6() + tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true) + tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false) + tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true) + tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true) + tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true) + tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false) + tree.AddCIDR(cidr.Parse("::1/128"), true) + tree.AddCIDR(cidr.Parse("::2/128"), false) al := &AllowList{cidrTree: tree} assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1"))) diff --git a/bits_test.go b/bits_test.go index 50d00bc..3135dfa 100644 --- a/bits_test.go +++ b/bits_test.go @@ -3,11 +3,12 @@ package nebula import ( "testing" + "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) func TestBits(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() b := NewBits(10) // make sure it is the right size @@ -75,7 +76,7 @@ func TestBits(t *testing.T) { } func TestBitsDupeCounter(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() b := NewBits(10) b.lostCounter.Clear() b.dupeCounter.Clear() @@ -100,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) { } func TestBitsOutOfWindowCounter(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() b := NewBits(10) b.lostCounter.Clear() b.dupeCounter.Clear() @@ -130,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) { } func TestBitsLostCounter(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() b := NewBits(10) b.lostCounter.Clear() b.dupeCounter.Clear() diff --git a/cert.go b/cert.go index dfee77a..f02ffe7 100644 --- a/cert.go +++ b/cert.go @@ -9,6 +9,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" ) type CertState struct { @@ -45,7 +46,7 @@ func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*Cert return cs, nil } -func NewCertStateFromConfig(c *Config) (*CertState, error) { +func NewCertStateFromConfig(c *config.C) (*CertState, error) { var pemPrivateKey []byte var err error @@ -118,7 +119,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) { return NewCertState(nebulaCert, rawKey) } -func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) { +func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { var rawCA []byte var err error diff --git a/cidr/parse.go b/cidr/parse.go new file mode 100644 index 0000000..74367f6 --- /dev/null +++ b/cidr/parse.go @@ -0,0 +1,10 @@ +package cidr + +import "net" + +// Parse is a convenience function that returns only the IPNet +// This function ignores errors since it is primarily a test helper, the result could be nil +func Parse(s string) *net.IPNet { + _, c, _ := net.ParseCIDR(s) + return c +} diff --git a/cidr_radix.go b/cidr/tree4.go similarity index 57% rename from cidr_radix.go rename to cidr/tree4.go index aa36c60..28d0e78 100644 --- a/cidr_radix.go +++ b/cidr/tree4.go @@ -1,39 +1,39 @@ -package nebula +package cidr import ( - "encoding/binary" - "fmt" "net" + + "github.com/slackhq/nebula/iputil" ) -type CIDRNode struct { - left *CIDRNode - right *CIDRNode - parent *CIDRNode +type Node struct { + left *Node + right *Node + parent *Node value interface{} } -type CIDRTree struct { - root *CIDRNode +type Tree4 struct { + root *Node } const ( - startbit = uint32(0x80000000) + startbit = iputil.VpnIp(0x80000000) ) -func NewCIDRTree() *CIDRTree { - tree := new(CIDRTree) - tree.root = &CIDRNode{} +func NewTree4() *Tree4 { + tree := new(Tree4) + tree.root = &Node{} return tree } -func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) { +func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { bit := startbit node := tree.root next := tree.root - ip := ip2int(cidr.IP) - mask := ip2int(cidr.Mask) + ip := iputil.Ip2VpnIp(cidr.IP) + mask := iputil.Ip2VpnIp(cidr.Mask) // Find our last ancestor in the tree for bit&mask != 0 { @@ -59,7 +59,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) { // Build up the rest of the tree we don't already have for bit&mask != 0 { - next = &CIDRNode{} + next = &Node{} next.parent = node if ip&bit != 0 { @@ -77,7 +77,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) { } // Finds the first match, which may be the least specific -func (tree *CIDRTree) Contains(ip uint32) (value interface{}) { +func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) { bit := startbit node := tree.root @@ -100,7 +100,7 @@ func (tree *CIDRTree) Contains(ip uint32) (value interface{}) { } // Finds the most specific match -func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) { +func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) { bit := startbit node := tree.root @@ -122,7 +122,7 @@ func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) { } // Finds the most specific match -func (tree *CIDRTree) Match(ip uint32) (value interface{}) { +func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) { bit := startbit node := tree.root lastNode := node @@ -143,27 +143,3 @@ func (tree *CIDRTree) Match(ip uint32) (value interface{}) { } return value } - -// A helper type to avoid converting to IP when logging -type IntIp uint32 - -func (ip IntIp) String() string { - return fmt.Sprintf("%v", int2ip(uint32(ip))) -} - -func (ip IntIp) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf("\"%s\"", int2ip(uint32(ip)).String())), nil -} - -func ip2int(ip []byte) uint32 { - if len(ip) == 16 { - return binary.BigEndian.Uint32(ip[12:16]) - } - return binary.BigEndian.Uint32(ip) -} - -func int2ip(nn uint32) net.IP { - ip := make(net.IP, 4) - binary.BigEndian.PutUint32(ip, nn) - return ip -} diff --git a/cidr/tree4_test.go b/cidr/tree4_test.go new file mode 100644 index 0000000..07f2b0a --- /dev/null +++ b/cidr/tree4_test.go @@ -0,0 +1,153 @@ +package cidr + +import ( + "net" + "testing" + + "github.com/slackhq/nebula/iputil" + "github.com/stretchr/testify/assert" +) + +func TestCIDRTree_Contains(t *testing.T) { + tree := NewTree4() + tree.AddCIDR(Parse("1.0.0.0/8"), "1") + tree.AddCIDR(Parse("2.1.0.0/16"), "2") + tree.AddCIDR(Parse("3.1.1.0/24"), "3") + tree.AddCIDR(Parse("4.1.1.0/24"), "4a") + tree.AddCIDR(Parse("4.1.1.1/32"), "4b") + tree.AddCIDR(Parse("4.1.2.1/32"), "4c") + tree.AddCIDR(Parse("254.0.0.0/4"), "5") + + tests := []struct { + Result interface{} + IP string + }{ + {"1", "1.0.0.0"}, + {"1", "1.255.255.255"}, + {"2", "2.1.0.0"}, + {"2", "2.1.255.255"}, + {"3", "3.1.1.0"}, + {"3", "3.1.1.255"}, + {"4a", "4.1.1.255"}, + {"4a", "4.1.1.1"}, + {"5", "240.0.0.0"}, + {"5", "255.255.255.255"}, + {nil, "239.0.0.0"}, + {nil, "4.1.2.2"}, + } + + for _, tt := range tests { + assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))) + } + + tree = NewTree4() + tree.AddCIDR(Parse("1.1.1.1/0"), "cool") + assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))) + assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))) +} + +func TestCIDRTree_MostSpecificContains(t *testing.T) { + tree := NewTree4() + tree.AddCIDR(Parse("1.0.0.0/8"), "1") + tree.AddCIDR(Parse("2.1.0.0/16"), "2") + tree.AddCIDR(Parse("3.1.1.0/24"), "3") + tree.AddCIDR(Parse("4.1.1.0/24"), "4a") + tree.AddCIDR(Parse("4.1.1.0/30"), "4b") + tree.AddCIDR(Parse("4.1.1.1/32"), "4c") + tree.AddCIDR(Parse("254.0.0.0/4"), "5") + + tests := []struct { + Result interface{} + IP string + }{ + {"1", "1.0.0.0"}, + {"1", "1.255.255.255"}, + {"2", "2.1.0.0"}, + {"2", "2.1.255.255"}, + {"3", "3.1.1.0"}, + {"3", "3.1.1.255"}, + {"4a", "4.1.1.255"}, + {"4b", "4.1.1.2"}, + {"4c", "4.1.1.1"}, + {"5", "240.0.0.0"}, + {"5", "255.255.255.255"}, + {nil, "239.0.0.0"}, + {nil, "4.1.2.2"}, + } + + for _, tt := range tests { + assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))) + } + + tree = NewTree4() + tree.AddCIDR(Parse("1.1.1.1/0"), "cool") + assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))) + assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))) +} + +func TestCIDRTree_Match(t *testing.T) { + tree := NewTree4() + tree.AddCIDR(Parse("4.1.1.0/32"), "1a") + tree.AddCIDR(Parse("4.1.1.1/32"), "1b") + + tests := []struct { + Result interface{} + IP string + }{ + {"1a", "4.1.1.0"}, + {"1b", "4.1.1.1"}, + } + + for _, tt := range tests { + assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))) + } + + tree = NewTree4() + tree.AddCIDR(Parse("1.1.1.1/0"), "cool") + assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))) + assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))) +} + +func BenchmarkCIDRTree_Contains(b *testing.B) { + tree := NewTree4() + tree.AddCIDR(Parse("1.1.0.0/16"), "1") + tree.AddCIDR(Parse("1.2.1.1/32"), "1") + tree.AddCIDR(Parse("192.2.1.1/32"), "1") + tree.AddCIDR(Parse("172.2.1.1/32"), "1") + + ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1")) + b.Run("found", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tree.Contains(ip) + } + }) + + ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255")) + b.Run("not found", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tree.Contains(ip) + } + }) +} + +func BenchmarkCIDRTree_Match(b *testing.B) { + tree := NewTree4() + tree.AddCIDR(Parse("1.1.0.0/16"), "1") + tree.AddCIDR(Parse("1.2.1.1/32"), "1") + tree.AddCIDR(Parse("192.2.1.1/32"), "1") + tree.AddCIDR(Parse("172.2.1.1/32"), "1") + + ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1")) + b.Run("found", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tree.Match(ip) + } + }) + + ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255")) + b.Run("not found", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tree.Match(ip) + } + }) +} diff --git a/cidr6_radix.go b/cidr/tree6.go similarity index 74% rename from cidr6_radix.go rename to cidr/tree6.go index 4ae2570..d13c93d 100644 --- a/cidr6_radix.go +++ b/cidr/tree6.go @@ -1,26 +1,27 @@ -package nebula +package cidr import ( - "encoding/binary" "net" + + "github.com/slackhq/nebula/iputil" ) const startbit6 = uint64(1 << 63) -type CIDR6Tree struct { - root4 *CIDRNode - root6 *CIDRNode +type Tree6 struct { + root4 *Node + root6 *Node } -func NewCIDR6Tree() *CIDR6Tree { - tree := new(CIDR6Tree) - tree.root4 = &CIDRNode{} - tree.root6 = &CIDRNode{} +func NewTree6() *Tree6 { + tree := new(Tree6) + tree.root4 = &Node{} + tree.root6 = &Node{} return tree } -func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) { - var node, next *CIDRNode +func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) { + var node, next *Node cidrIP, ipv4 := isIPV4(cidr.IP) if ipv4 { @@ -33,8 +34,8 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) { } for i := 0; i < len(cidrIP); i += 4 { - ip := binary.BigEndian.Uint32(cidrIP[i : i+4]) - mask := binary.BigEndian.Uint32(cidr.Mask[i : i+4]) + ip := iputil.Ip2VpnIp(cidrIP[i : i+4]) + mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4]) bit := startbit // Find our last ancestor in the tree @@ -55,7 +56,7 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) { // Build up the rest of the tree we don't already have for bit&mask != 0 { - next = &CIDRNode{} + next = &Node{} next.parent = node if ip&bit != 0 { @@ -74,8 +75,8 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) { } // Finds the most specific match -func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) { - var node *CIDRNode +func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) { + var node *Node wholeIP, ipv4 := isIPV4(ip) if ipv4 { @@ -85,7 +86,7 @@ func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) { } for i := 0; i < len(wholeIP); i += 4 { - ip := ip2int(wholeIP[i : i+4]) + ip := iputil.Ip2VpnIp(wholeIP[i : i+4]) bit := startbit for node != nil { @@ -110,7 +111,7 @@ func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) { return value } -func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) { +func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) { bit := startbit node := tree.root4 @@ -131,7 +132,7 @@ func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) { return value } -func (tree *CIDR6Tree) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) { +func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) { ip := hi node := tree.root6 diff --git a/cidr6_radix_test.go b/cidr/tree6_test.go similarity index 57% rename from cidr6_radix_test.go rename to cidr/tree6_test.go index 1e69d5c..b6dc4c2 100644 --- a/cidr6_radix_test.go +++ b/cidr/tree6_test.go @@ -1,6 +1,7 @@ -package nebula +package cidr import ( + "encoding/binary" "net" "testing" @@ -8,17 +9,17 @@ import ( ) func TestCIDR6Tree_MostSpecificContains(t *testing.T) { - tree := NewCIDR6Tree() - tree.AddCIDR(getCIDR("1.0.0.0/8"), "1") - tree.AddCIDR(getCIDR("2.1.0.0/16"), "2") - tree.AddCIDR(getCIDR("3.1.1.0/24"), "3") - tree.AddCIDR(getCIDR("4.1.1.1/24"), "4a") - tree.AddCIDR(getCIDR("4.1.1.1/30"), "4b") - tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c") - tree.AddCIDR(getCIDR("254.0.0.0/4"), "5") - tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a") - tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b") - tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c") + tree := NewTree6() + tree.AddCIDR(Parse("1.0.0.0/8"), "1") + tree.AddCIDR(Parse("2.1.0.0/16"), "2") + tree.AddCIDR(Parse("3.1.1.0/24"), "3") + tree.AddCIDR(Parse("4.1.1.1/24"), "4a") + tree.AddCIDR(Parse("4.1.1.1/30"), "4b") + tree.AddCIDR(Parse("4.1.1.1/32"), "4c") + tree.AddCIDR(Parse("254.0.0.0/4"), "5") + tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") + tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") + tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") tests := []struct { Result interface{} @@ -46,9 +47,9 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) { assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP))) } - tree = NewCIDR6Tree() - tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool") - tree.AddCIDR(getCIDR("::/0"), "cool6") + tree = NewTree6() + tree.AddCIDR(Parse("1.1.1.1/0"), "cool") + tree.AddCIDR(Parse("::/0"), "cool6") assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0"))) assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255"))) assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::"))) @@ -56,10 +57,10 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) { } func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { - tree := NewCIDR6Tree() - tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a") - tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b") - tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c") + tree := NewTree6() + tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") + tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") + tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") tests := []struct { Result interface{} @@ -71,7 +72,10 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { } for _, tt := range tests { - ip := NewIp6AndPort(net.ParseIP(tt.IP), 0) - assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(ip.Hi, ip.Lo)) + ip := net.ParseIP(tt.IP) + hi := binary.BigEndian.Uint64(ip[:8]) + lo := binary.BigEndian.Uint64(ip[8:]) + + assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo)) } } diff --git a/cidr_radix_test.go b/cidr_radix_test.go deleted file mode 100644 index 1e3fad1..0000000 --- a/cidr_radix_test.go +++ /dev/null @@ -1,157 +0,0 @@ -package nebula - -import ( - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCIDRTree_Contains(t *testing.T) { - tree := NewCIDRTree() - tree.AddCIDR(getCIDR("1.0.0.0/8"), "1") - tree.AddCIDR(getCIDR("2.1.0.0/16"), "2") - tree.AddCIDR(getCIDR("3.1.1.0/24"), "3") - tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a") - tree.AddCIDR(getCIDR("4.1.1.1/32"), "4b") - tree.AddCIDR(getCIDR("4.1.2.1/32"), "4c") - tree.AddCIDR(getCIDR("254.0.0.0/4"), "5") - - tests := []struct { - Result interface{} - IP string - }{ - {"1", "1.0.0.0"}, - {"1", "1.255.255.255"}, - {"2", "2.1.0.0"}, - {"2", "2.1.255.255"}, - {"3", "3.1.1.0"}, - {"3", "3.1.1.255"}, - {"4a", "4.1.1.255"}, - {"4a", "4.1.1.1"}, - {"5", "240.0.0.0"}, - {"5", "255.255.255.255"}, - {nil, "239.0.0.0"}, - {nil, "4.1.2.2"}, - } - - for _, tt := range tests { - assert.Equal(t, tt.Result, tree.Contains(ip2int(net.ParseIP(tt.IP)))) - } - - tree = NewCIDRTree() - tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool") - assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0")))) - assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255")))) -} - -func TestCIDRTree_MostSpecificContains(t *testing.T) { - tree := NewCIDRTree() - tree.AddCIDR(getCIDR("1.0.0.0/8"), "1") - tree.AddCIDR(getCIDR("2.1.0.0/16"), "2") - tree.AddCIDR(getCIDR("3.1.1.0/24"), "3") - tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a") - tree.AddCIDR(getCIDR("4.1.1.0/30"), "4b") - tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c") - tree.AddCIDR(getCIDR("254.0.0.0/4"), "5") - - tests := []struct { - Result interface{} - IP string - }{ - {"1", "1.0.0.0"}, - {"1", "1.255.255.255"}, - {"2", "2.1.0.0"}, - {"2", "2.1.255.255"}, - {"3", "3.1.1.0"}, - {"3", "3.1.1.255"}, - {"4a", "4.1.1.255"}, - {"4b", "4.1.1.2"}, - {"4c", "4.1.1.1"}, - {"5", "240.0.0.0"}, - {"5", "255.255.255.255"}, - {nil, "239.0.0.0"}, - {nil, "4.1.2.2"}, - } - - for _, tt := range tests { - assert.Equal(t, tt.Result, tree.MostSpecificContains(ip2int(net.ParseIP(tt.IP)))) - } - - tree = NewCIDRTree() - tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool") - assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("0.0.0.0")))) - assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("255.255.255.255")))) -} - -func TestCIDRTree_Match(t *testing.T) { - tree := NewCIDRTree() - tree.AddCIDR(getCIDR("4.1.1.0/32"), "1a") - tree.AddCIDR(getCIDR("4.1.1.1/32"), "1b") - - tests := []struct { - Result interface{} - IP string - }{ - {"1a", "4.1.1.0"}, - {"1b", "4.1.1.1"}, - } - - for _, tt := range tests { - assert.Equal(t, tt.Result, tree.Match(ip2int(net.ParseIP(tt.IP)))) - } - - tree = NewCIDRTree() - tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool") - assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0")))) - assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255")))) -} - -func BenchmarkCIDRTree_Contains(b *testing.B) { - tree := NewCIDRTree() - tree.AddCIDR(getCIDR("1.1.0.0/16"), "1") - tree.AddCIDR(getCIDR("1.2.1.1/32"), "1") - tree.AddCIDR(getCIDR("192.2.1.1/32"), "1") - tree.AddCIDR(getCIDR("172.2.1.1/32"), "1") - - ip := ip2int(net.ParseIP("1.2.1.1")) - b.Run("found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Contains(ip) - } - }) - - ip = ip2int(net.ParseIP("1.2.1.255")) - b.Run("not found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Contains(ip) - } - }) -} - -func BenchmarkCIDRTree_Match(b *testing.B) { - tree := NewCIDRTree() - tree.AddCIDR(getCIDR("1.1.0.0/16"), "1") - tree.AddCIDR(getCIDR("1.2.1.1/32"), "1") - tree.AddCIDR(getCIDR("192.2.1.1/32"), "1") - tree.AddCIDR(getCIDR("172.2.1.1/32"), "1") - - ip := ip2int(net.ParseIP("1.2.1.1")) - b.Run("found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Match(ip) - } - }) - - ip = ip2int(net.ParseIP("1.2.1.255")) - b.Run("not found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Match(ip) - } - }) -} - -func getCIDR(s string) *net.IPNet { - _, c, _ := net.ParseCIDR(s) - return c -} diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index ea189d2..5040e28 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -7,6 +7,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula" + "github.com/slackhq/nebula/config" ) // A version string that can be set with @@ -49,14 +50,14 @@ func main() { l := logrus.New() l.Out = os.Stdout - config := nebula.NewConfig(l) - err := config.Load(*configPath) + c := config.NewC(l) + err := c.Load(*configPath) if err != nil { fmt.Printf("failed to load config: %s", err) os.Exit(1) } - c, err := nebula.Main(config, *configTest, Build, l, nil) + ctrl, err := nebula.Main(c, *configTest, Build, l, nil) switch v := err.(type) { case nebula.ContextualError: @@ -68,8 +69,8 @@ func main() { } if !*configTest { - c.Start() - c.ShutdownBlock() + ctrl.Start() + ctrl.ShutdownBlock() } os.Exit(0) diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index be2dee0..591e8e7 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -9,6 +9,7 @@ import ( "github.com/kardianos/service" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" + "github.com/slackhq/nebula/config" ) var logger service.Logger @@ -27,13 +28,13 @@ func (p *program) Start(s service.Service) error { l := logrus.New() HookLogger(l) - config := nebula.NewConfig(l) - err := config.Load(*p.configPath) + c := config.NewC(l) + err := c.Load(*p.configPath) if err != nil { return fmt.Errorf("failed to load config: %s", err) } - p.control, err = nebula.Main(config, *p.configTest, Build, l, nil) + p.control, err = nebula.Main(c, *p.configTest, Build, l, nil) if err != nil { return err } diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index cffd75a..a2923c7 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -7,6 +7,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula" + "github.com/slackhq/nebula/config" ) // A version string that can be set with @@ -43,14 +44,14 @@ func main() { l := logrus.New() l.Out = os.Stdout - config := nebula.NewConfig(l) - err := config.Load(*configPath) + c := config.NewC(l) + err := c.Load(*configPath) if err != nil { fmt.Printf("failed to load config: %s", err) os.Exit(1) } - c, err := nebula.Main(config, *configTest, Build, l, nil) + ctrl, err := nebula.Main(c, *configTest, Build, l, nil) switch v := err.(type) { case nebula.ContextualError: @@ -62,8 +63,8 @@ func main() { } if !*configTest { - c.Start() - c.ShutdownBlock() + ctrl.Start() + ctrl.ShutdownBlock() } os.Exit(0) diff --git a/config.go b/config.go deleted file mode 100644 index c4dce64..0000000 --- a/config.go +++ /dev/null @@ -1,611 +0,0 @@ -package nebula - -import ( - "context" - "errors" - "fmt" - "io/ioutil" - "net" - "os" - "os/signal" - "path/filepath" - "regexp" - "sort" - "strconv" - "strings" - "syscall" - "time" - - "github.com/imdario/mergo" - "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" -) - -type Config struct { - path string - files []string - Settings map[interface{}]interface{} - oldSettings map[interface{}]interface{} - callbacks []func(*Config) - l *logrus.Logger -} - -func NewConfig(l *logrus.Logger) *Config { - return &Config{ - Settings: make(map[interface{}]interface{}), - l: l, - } -} - -// Load will find all yaml files within path and load them in lexical order -func (c *Config) Load(path string) error { - c.path = path - c.files = make([]string, 0) - - err := c.resolve(path, true) - if err != nil { - return err - } - - if len(c.files) == 0 { - return fmt.Errorf("no config files found at %s", path) - } - - sort.Strings(c.files) - - err = c.parse() - if err != nil { - return err - } - - return nil -} - -func (c *Config) LoadString(raw string) error { - if raw == "" { - return errors.New("Empty configuration") - } - return c.parseRaw([]byte(raw)) -} - -// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered -// here should decide if they need to make a change to the current process before making the change. HasChanged can be -// used to help decide if a change is necessary. -// These functions should return quickly or spawn their own go routine if they will take a while -func (c *Config) RegisterReloadCallback(f func(*Config)) { - c.callbacks = append(c.callbacks, f) -} - -// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of -// k in both the old and new settings will be serialized, the result of the string comparison is returned. -// If k is an empty string the entire config is tested. -// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating -// there is change when there actually wasn't any. -func (c *Config) HasChanged(k string) bool { - if c.oldSettings == nil { - return false - } - - var ( - nv interface{} - ov interface{} - ) - - if k == "" { - nv = c.Settings - ov = c.oldSettings - k = "all settings" - } else { - nv = c.get(k, c.Settings) - ov = c.get(k, c.oldSettings) - } - - newVals, err := yaml.Marshal(nv) - if err != nil { - c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config") - } - - oldVals, err := yaml.Marshal(ov) - if err != nil { - c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") - } - - return string(newVals) != string(oldVals) -} - -// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the -// original path provided to Load. The old settings are shallow copied for change detection after the reload. -func (c *Config) CatchHUP(ctx context.Context) { - ch := make(chan os.Signal, 1) - signal.Notify(ch, syscall.SIGHUP) - - go func() { - for { - select { - case <-ctx.Done(): - signal.Stop(ch) - close(ch) - return - case <-ch: - c.l.Info("Caught HUP, reloading config") - c.ReloadConfig() - } - } - }() -} - -func (c *Config) ReloadConfig() { - c.oldSettings = make(map[interface{}]interface{}) - for k, v := range c.Settings { - c.oldSettings[k] = v - } - - err := c.Load(c.path) - if err != nil { - c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") - return - } - - for _, v := range c.callbacks { - v(c) - } -} - -// GetString will get the string for k or return the default d if not found or invalid -func (c *Config) GetString(k, d string) string { - r := c.Get(k) - if r == nil { - return d - } - - return fmt.Sprintf("%v", r) -} - -// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid -func (c *Config) GetStringSlice(k string, d []string) []string { - r := c.Get(k) - if r == nil { - return d - } - - rv, ok := r.([]interface{}) - if !ok { - return d - } - - v := make([]string, len(rv)) - for i := 0; i < len(v); i++ { - v[i] = fmt.Sprintf("%v", rv[i]) - } - - return v -} - -// GetMap will get the map for k or return the default d if not found or invalid -func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} { - r := c.Get(k) - if r == nil { - return d - } - - v, ok := r.(map[interface{}]interface{}) - if !ok { - return d - } - - return v -} - -// GetInt will get the int for k or return the default d if not found or invalid -func (c *Config) GetInt(k string, d int) int { - r := c.GetString(k, strconv.Itoa(d)) - v, err := strconv.Atoi(r) - if err != nil { - return d - } - - return v -} - -// GetBool will get the bool for k or return the default d if not found or invalid -func (c *Config) GetBool(k string, d bool) bool { - r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d))) - v, err := strconv.ParseBool(r) - if err != nil { - switch r { - case "y", "yes": - return true - case "n", "no": - return false - } - return d - } - - return v -} - -// GetDuration will get the duration for k or return the default d if not found or invalid -func (c *Config) GetDuration(k string, d time.Duration) time.Duration { - r := c.GetString(k, "") - v, err := time.ParseDuration(r) - if err != nil { - return d - } - return v -} - -func (c *Config) GetLocalAllowList(k string) (*LocalAllowList, error) { - var nameRules []AllowListNameRule - handleKey := func(key string, value interface{}) (bool, error) { - if key == "interfaces" { - var err error - nameRules, err = c.getAllowListInterfaces(k, value) - if err != nil { - return false, err - } - - return true, nil - } - return false, nil - } - - al, err := c.GetAllowList(k, handleKey) - if err != nil { - return nil, err - } - return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil -} - -func (c *Config) GetRemoteAllowList(k, rangesKey string) (*RemoteAllowList, error) { - al, err := c.GetAllowList(k, nil) - if err != nil { - return nil, err - } - remoteAllowRanges, err := c.getRemoteAllowRanges(rangesKey) - if err != nil { - return nil, err - } - return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil -} - -func (c *Config) getRemoteAllowRanges(k string) (*CIDR6Tree, error) { - value := c.Get(k) - if value == nil { - return nil, nil - } - - remoteAllowRanges := NewCIDR6Tree() - - rawMap, ok := value.(map[interface{}]interface{}) - if !ok { - return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) - } - for rawKey, rawValue := range rawMap { - rawCIDR, ok := rawKey.(string) - if !ok { - return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) - } - - allowList, err := c.getAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil) - if err != nil { - return nil, err - } - - _, cidr, err := net.ParseCIDR(rawCIDR) - if err != nil { - return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) - } - - remoteAllowRanges.AddCIDR(cidr, allowList) - } - - return remoteAllowRanges, nil -} - -// If the handleKey func returns true, the rest of the parsing is skipped -// for this key. This allows parsing of special values like `interfaces`. -func (c *Config) GetAllowList(k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { - r := c.Get(k) - if r == nil { - return nil, nil - } - - return c.getAllowList(k, r, handleKey) -} - -// If the handleKey func returns true, the rest of the parsing is skipped -// for this key. This allows parsing of special values like `interfaces`. -func (c *Config) getAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { - rawMap, ok := raw.(map[interface{}]interface{}) - if !ok { - return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) - } - - tree := NewCIDR6Tree() - - // Keep track of the rules we have added for both ipv4 and ipv6 - type allowListRules struct { - firstValue bool - allValuesMatch bool - defaultSet bool - allValues bool - } - rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false} - rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false} - - for rawKey, rawValue := range rawMap { - rawCIDR, ok := rawKey.(string) - if !ok { - return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) - } - - if handleKey != nil { - handled, err := handleKey(rawCIDR, rawValue) - if err != nil { - return nil, err - } - if handled { - continue - } - } - - value, ok := rawValue.(bool) - if !ok { - return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue) - } - - _, cidr, err := net.ParseCIDR(rawCIDR) - if err != nil { - return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) - } - - // TODO: should we error on duplicate CIDRs in the config? - tree.AddCIDR(cidr, value) - - maskBits, maskSize := cidr.Mask.Size() - - var rules *allowListRules - if maskSize == 32 { - rules = &rules4 - } else { - rules = &rules6 - } - - if rules.firstValue { - rules.allValues = value - rules.firstValue = false - } else { - if value != rules.allValues { - rules.allValuesMatch = false - } - } - - // Check if this is 0.0.0.0/0 or ::/0 - if maskBits == 0 { - rules.defaultSet = true - } - } - - if !rules4.defaultSet { - if rules4.allValuesMatch { - _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0") - tree.AddCIDR(zeroCIDR, !rules4.allValues) - } else { - return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k) - } - } - - if !rules6.defaultSet { - if rules6.allValuesMatch { - _, zeroCIDR, _ := net.ParseCIDR("::/0") - tree.AddCIDR(zeroCIDR, !rules6.allValues) - } else { - return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k) - } - } - - return &AllowList{cidrTree: tree}, nil -} - -func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) { - var nameRules []AllowListNameRule - - rawRules, ok := v.(map[interface{}]interface{}) - if !ok { - return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v) - } - - firstEntry := true - var allValues bool - for rawName, rawAllow := range rawRules { - name, ok := rawName.(string) - if !ok { - return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName) - } - allow, ok := rawAllow.(bool) - if !ok { - return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow) - } - - nameRE, err := regexp.Compile("^" + name + "$") - if err != nil { - return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err) - } - - nameRules = append(nameRules, AllowListNameRule{ - Name: nameRE, - Allow: allow, - }) - - if firstEntry { - allValues = allow - firstEntry = false - } else { - if allow != allValues { - return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k) - } - } - } - - return nameRules, nil -} - -func (c *Config) Get(k string) interface{} { - return c.get(k, c.Settings) -} - -func (c *Config) IsSet(k string) bool { - return c.get(k, c.Settings) != nil -} - -func (c *Config) get(k string, v interface{}) interface{} { - parts := strings.Split(k, ".") - for _, p := range parts { - m, ok := v.(map[interface{}]interface{}) - if !ok { - return nil - } - - v, ok = m[p] - if !ok { - return nil - } - } - - return v -} - -// direct signifies if this is the config path directly specified by the user, -// versus a file/dir found by recursing into that path -func (c *Config) resolve(path string, direct bool) error { - i, err := os.Stat(path) - if err != nil { - return nil - } - - if !i.IsDir() { - c.addFile(path, direct) - return nil - } - - paths, err := readDirNames(path) - if err != nil { - return fmt.Errorf("problem while reading directory %s: %s", path, err) - } - - for _, p := range paths { - err := c.resolve(filepath.Join(path, p), false) - if err != nil { - return err - } - } - - return nil -} - -func (c *Config) addFile(path string, direct bool) error { - ext := filepath.Ext(path) - - if !direct && ext != ".yaml" && ext != ".yml" { - return nil - } - - ap, err := filepath.Abs(path) - if err != nil { - return err - } - - c.files = append(c.files, ap) - return nil -} - -func (c *Config) parseRaw(b []byte) error { - var m map[interface{}]interface{} - - err := yaml.Unmarshal(b, &m) - if err != nil { - return err - } - - c.Settings = m - return nil -} - -func (c *Config) parse() error { - var m map[interface{}]interface{} - - for _, path := range c.files { - b, err := ioutil.ReadFile(path) - if err != nil { - return err - } - - var nm map[interface{}]interface{} - err = yaml.Unmarshal(b, &nm) - if err != nil { - return err - } - - // We need to use WithAppendSlice so that firewall rules in separate - // files are appended together - err = mergo.Merge(&nm, m, mergo.WithAppendSlice) - m = nm - if err != nil { - return err - } - } - - c.Settings = m - return nil -} - -func readDirNames(path string) ([]string, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - - paths, err := f.Readdirnames(-1) - f.Close() - if err != nil { - return nil, err - } - - sort.Strings(paths) - return paths, nil -} - -func configLogger(c *Config) error { - // set up our logging level - logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) - if err != nil { - return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels) - } - c.l.SetLevel(logLevel) - - disableTimestamp := c.GetBool("logging.disable_timestamp", false) - timestampFormat := c.GetString("logging.timestamp_format", "") - fullTimestamp := (timestampFormat != "") - if timestampFormat == "" { - timestampFormat = time.RFC3339 - } - - logFormat := strings.ToLower(c.GetString("logging.format", "text")) - switch logFormat { - case "text": - c.l.Formatter = &logrus.TextFormatter{ - TimestampFormat: timestampFormat, - FullTimestamp: fullTimestamp, - DisableTimestamp: disableTimestamp, - } - case "json": - c.l.Formatter = &logrus.JSONFormatter{ - TimestampFormat: timestampFormat, - DisableTimestamp: disableTimestamp, - } - default: - return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"}) - } - - return nil -} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..2328007 --- /dev/null +++ b/config/config.go @@ -0,0 +1,358 @@ +package config + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "os" + "os/signal" + "path/filepath" + "sort" + "strconv" + "strings" + "syscall" + "time" + + "github.com/imdario/mergo" + "github.com/sirupsen/logrus" + "gopkg.in/yaml.v2" +) + +type C struct { + path string + files []string + Settings map[interface{}]interface{} + oldSettings map[interface{}]interface{} + callbacks []func(*C) + l *logrus.Logger +} + +func NewC(l *logrus.Logger) *C { + return &C{ + Settings: make(map[interface{}]interface{}), + l: l, + } +} + +// Load will find all yaml files within path and load them in lexical order +func (c *C) Load(path string) error { + c.path = path + c.files = make([]string, 0) + + err := c.resolve(path, true) + if err != nil { + return err + } + + if len(c.files) == 0 { + return fmt.Errorf("no config files found at %s", path) + } + + sort.Strings(c.files) + + err = c.parse() + if err != nil { + return err + } + + return nil +} + +func (c *C) LoadString(raw string) error { + if raw == "" { + return errors.New("Empty configuration") + } + return c.parseRaw([]byte(raw)) +} + +// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered +// here should decide if they need to make a change to the current process before making the change. HasChanged can be +// used to help decide if a change is necessary. +// These functions should return quickly or spawn their own go routine if they will take a while +func (c *C) RegisterReloadCallback(f func(*C)) { + c.callbacks = append(c.callbacks, f) +} + +// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of +// k in both the old and new settings will be serialized, the result of the string comparison is returned. +// If k is an empty string the entire config is tested. +// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating +// there is change when there actually wasn't any. +func (c *C) HasChanged(k string) bool { + if c.oldSettings == nil { + return false + } + + var ( + nv interface{} + ov interface{} + ) + + if k == "" { + nv = c.Settings + ov = c.oldSettings + k = "all settings" + } else { + nv = c.get(k, c.Settings) + ov = c.get(k, c.oldSettings) + } + + newVals, err := yaml.Marshal(nv) + if err != nil { + c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config") + } + + oldVals, err := yaml.Marshal(ov) + if err != nil { + c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") + } + + return string(newVals) != string(oldVals) +} + +// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the +// original path provided to Load. The old settings are shallow copied for change detection after the reload. +func (c *C) CatchHUP(ctx context.Context) { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGHUP) + + go func() { + for { + select { + case <-ctx.Done(): + signal.Stop(ch) + close(ch) + return + case <-ch: + c.l.Info("Caught HUP, reloading config") + c.ReloadConfig() + } + } + }() +} + +func (c *C) ReloadConfig() { + c.oldSettings = make(map[interface{}]interface{}) + for k, v := range c.Settings { + c.oldSettings[k] = v + } + + err := c.Load(c.path) + if err != nil { + c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") + return + } + + for _, v := range c.callbacks { + v(c) + } +} + +// GetString will get the string for k or return the default d if not found or invalid +func (c *C) GetString(k, d string) string { + r := c.Get(k) + if r == nil { + return d + } + + return fmt.Sprintf("%v", r) +} + +// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid +func (c *C) GetStringSlice(k string, d []string) []string { + r := c.Get(k) + if r == nil { + return d + } + + rv, ok := r.([]interface{}) + if !ok { + return d + } + + v := make([]string, len(rv)) + for i := 0; i < len(v); i++ { + v[i] = fmt.Sprintf("%v", rv[i]) + } + + return v +} + +// GetMap will get the map for k or return the default d if not found or invalid +func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} { + r := c.Get(k) + if r == nil { + return d + } + + v, ok := r.(map[interface{}]interface{}) + if !ok { + return d + } + + return v +} + +// GetInt will get the int for k or return the default d if not found or invalid +func (c *C) GetInt(k string, d int) int { + r := c.GetString(k, strconv.Itoa(d)) + v, err := strconv.Atoi(r) + if err != nil { + return d + } + + return v +} + +// GetBool will get the bool for k or return the default d if not found or invalid +func (c *C) GetBool(k string, d bool) bool { + r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d))) + v, err := strconv.ParseBool(r) + if err != nil { + switch r { + case "y", "yes": + return true + case "n", "no": + return false + } + return d + } + + return v +} + +// GetDuration will get the duration for k or return the default d if not found or invalid +func (c *C) GetDuration(k string, d time.Duration) time.Duration { + r := c.GetString(k, "") + v, err := time.ParseDuration(r) + if err != nil { + return d + } + return v +} + +func (c *C) Get(k string) interface{} { + return c.get(k, c.Settings) +} + +func (c *C) IsSet(k string) bool { + return c.get(k, c.Settings) != nil +} + +func (c *C) get(k string, v interface{}) interface{} { + parts := strings.Split(k, ".") + for _, p := range parts { + m, ok := v.(map[interface{}]interface{}) + if !ok { + return nil + } + + v, ok = m[p] + if !ok { + return nil + } + } + + return v +} + +// direct signifies if this is the config path directly specified by the user, +// versus a file/dir found by recursing into that path +func (c *C) resolve(path string, direct bool) error { + i, err := os.Stat(path) + if err != nil { + return nil + } + + if !i.IsDir() { + c.addFile(path, direct) + return nil + } + + paths, err := readDirNames(path) + if err != nil { + return fmt.Errorf("problem while reading directory %s: %s", path, err) + } + + for _, p := range paths { + err := c.resolve(filepath.Join(path, p), false) + if err != nil { + return err + } + } + + return nil +} + +func (c *C) addFile(path string, direct bool) error { + ext := filepath.Ext(path) + + if !direct && ext != ".yaml" && ext != ".yml" { + return nil + } + + ap, err := filepath.Abs(path) + if err != nil { + return err + } + + c.files = append(c.files, ap) + return nil +} + +func (c *C) parseRaw(b []byte) error { + var m map[interface{}]interface{} + + err := yaml.Unmarshal(b, &m) + if err != nil { + return err + } + + c.Settings = m + return nil +} + +func (c *C) parse() error { + var m map[interface{}]interface{} + + for _, path := range c.files { + b, err := ioutil.ReadFile(path) + if err != nil { + return err + } + + var nm map[interface{}]interface{} + err = yaml.Unmarshal(b, &nm) + if err != nil { + return err + } + + // We need to use WithAppendSlice so that firewall rules in separate + // files are appended together + err = mergo.Merge(&nm, m, mergo.WithAppendSlice) + m = nm + if err != nil { + return err + } + } + + c.Settings = m + return nil +} + +func readDirNames(path string) ([]string, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + + paths, err := f.Readdirnames(-1) + f.Close() + if err != nil { + return nil, err + } + + sort.Strings(paths) + return paths, nil +} diff --git a/config_test.go b/config/config_test.go similarity index 54% rename from config_test.go rename to config/config_test.go index 84848b8..a5254bd 100644 --- a/config_test.go +++ b/config/config_test.go @@ -1,4 +1,4 @@ -package nebula +package config import ( "io/ioutil" @@ -7,19 +7,20 @@ import ( "testing" "time" + "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) func TestConfig_Load(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() dir, err := ioutil.TempDir("", "config-test") // invalid yaml - c := NewConfig(l) + c := NewC(l) ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") // simple multi config merge - c = NewConfig(l) + c = NewC(l) os.RemoveAll(dir) os.Mkdir(dir, 0755) @@ -41,9 +42,9 @@ func TestConfig_Load(t *testing.T) { } func TestConfig_Get(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() // test simple type - c := NewConfig(l) + c := NewC(l) c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} assert.Equal(t, "hi", c.Get("firewall.outbound")) @@ -57,15 +58,15 @@ func TestConfig_Get(t *testing.T) { } func TestConfig_GetStringSlice(t *testing.T) { - l := NewTestLogger() - c := NewConfig(l) + l := util.NewTestLogger() + c := NewC(l) c.Settings["slice"] = []interface{}{"one", "two"} assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) } func TestConfig_GetBool(t *testing.T) { - l := NewTestLogger() - c := NewConfig(l) + l := util.NewTestLogger() + c := NewC(l) c.Settings["bool"] = true assert.Equal(t, true, c.GetBool("bool", false)) @@ -91,108 +92,22 @@ func TestConfig_GetBool(t *testing.T) { assert.Equal(t, false, c.GetBool("bool", true)) } -func TestConfig_GetAllowList(t *testing.T) { - l := NewTestLogger() - c := NewConfig(l) - c.Settings["allowlist"] = map[interface{}]interface{}{ - "192.168.0.0": true, - } - r, err := c.GetAllowList("allowlist", nil) - assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0") - assert.Nil(t, r) - - c.Settings["allowlist"] = map[interface{}]interface{}{ - "192.168.0.0/16": "abc", - } - r, err = c.GetAllowList("allowlist", nil) - assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") - - c.Settings["allowlist"] = map[interface{}]interface{}{ - "192.168.0.0/16": true, - "10.0.0.0/8": false, - } - r, err = c.GetAllowList("allowlist", nil) - assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") - - c.Settings["allowlist"] = map[interface{}]interface{}{ - "0.0.0.0/0": true, - "10.0.0.0/8": false, - "10.42.42.0/24": true, - "fd00::/8": true, - "fd00:fd00::/16": false, - } - r, err = c.GetAllowList("allowlist", nil) - assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") - - c.Settings["allowlist"] = map[interface{}]interface{}{ - "0.0.0.0/0": true, - "10.0.0.0/8": false, - "10.42.42.0/24": true, - } - r, err = c.GetAllowList("allowlist", nil) - if assert.NoError(t, err) { - assert.NotNil(t, r) - } - - c.Settings["allowlist"] = map[interface{}]interface{}{ - "0.0.0.0/0": true, - "10.0.0.0/8": false, - "10.42.42.0/24": true, - "::/0": false, - "fd00::/8": true, - "fd00:fd00::/16": false, - } - r, err = c.GetAllowList("allowlist", nil) - if assert.NoError(t, err) { - assert.NotNil(t, r) - } - - // Test interface names - - c.Settings["allowlist"] = map[interface{}]interface{}{ - "interfaces": map[interface{}]interface{}{ - `docker.*`: "foo", - }, - } - lr, err := c.GetLocalAllowList("allowlist") - assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") - - c.Settings["allowlist"] = map[interface{}]interface{}{ - "interfaces": map[interface{}]interface{}{ - `docker.*`: false, - `eth.*`: true, - }, - } - lr, err = c.GetLocalAllowList("allowlist") - assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") - - c.Settings["allowlist"] = map[interface{}]interface{}{ - "interfaces": map[interface{}]interface{}{ - `docker.*`: false, - }, - } - lr, err = c.GetLocalAllowList("allowlist") - if assert.NoError(t, err) { - assert.NotNil(t, lr) - } -} - func TestConfig_HasChanged(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() // No reload has occurred, return false - c := NewConfig(l) + c := NewC(l) c.Settings["test"] = "hi" assert.False(t, c.HasChanged("")) // Test key change - c = NewConfig(l) + c = NewC(l) c.Settings["test"] = "hi" c.oldSettings = map[interface{}]interface{}{"test": "no"} assert.True(t, c.HasChanged("test")) assert.True(t, c.HasChanged("")) // No key change - c = NewConfig(l) + c = NewC(l) c.Settings["test"] = "hi" c.oldSettings = map[interface{}]interface{}{"test": "hi"} assert.False(t, c.HasChanged("test")) @@ -200,13 +115,13 @@ func TestConfig_HasChanged(t *testing.T) { } func TestConfig_ReloadConfig(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() done := make(chan bool, 1) dir, err := ioutil.TempDir("", "config-test") assert.Nil(t, err) ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) - c := NewConfig(l) + c := NewC(l) assert.Nil(t, c.Load(dir)) assert.False(t, c.HasChanged("outer.inner")) @@ -215,7 +130,7 @@ func TestConfig_ReloadConfig(t *testing.T) { ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644) - c.RegisterReloadCallback(func(c *Config) { + c.RegisterReloadCallback(func(c *C) { done <- true }) diff --git a/connection_manager.go b/connection_manager.go index de9b165..c480bbb 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -6,6 +6,8 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" ) // TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet @@ -13,16 +15,16 @@ import ( type connectionManager struct { hostMap *HostMap - in map[uint32]struct{} + in map[iputil.VpnIp]struct{} inLock *sync.RWMutex inCount int - out map[uint32]struct{} + out map[iputil.VpnIp]struct{} outLock *sync.RWMutex outCount int TrafficTimer *SystemTimerWheel intf *Interface - pendingDeletion map[uint32]int + pendingDeletion map[iputil.VpnIp]int pendingDeletionLock *sync.RWMutex pendingDeletionTimer *SystemTimerWheel @@ -36,15 +38,15 @@ type connectionManager struct { func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager { nc := &connectionManager{ hostMap: intf.hostMap, - in: make(map[uint32]struct{}), + in: make(map[iputil.VpnIp]struct{}), inLock: &sync.RWMutex{}, inCount: 0, - out: make(map[uint32]struct{}), + out: make(map[iputil.VpnIp]struct{}), outLock: &sync.RWMutex{}, outCount: 0, TrafficTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60), intf: intf, - pendingDeletion: make(map[uint32]int), + pendingDeletion: make(map[iputil.VpnIp]int), pendingDeletionLock: &sync.RWMutex{}, pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60), checkInterval: checkInterval, @@ -55,7 +57,7 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface return nc } -func (n *connectionManager) In(ip uint32) { +func (n *connectionManager) In(ip iputil.VpnIp) { n.inLock.RLock() // If this already exists, return if _, ok := n.in[ip]; ok { @@ -68,7 +70,7 @@ func (n *connectionManager) In(ip uint32) { n.inLock.Unlock() } -func (n *connectionManager) Out(ip uint32) { +func (n *connectionManager) Out(ip iputil.VpnIp) { n.outLock.RLock() // If this already exists, return if _, ok := n.out[ip]; ok { @@ -87,9 +89,9 @@ func (n *connectionManager) Out(ip uint32) { n.outLock.Unlock() } -func (n *connectionManager) CheckIn(vpnIP uint32) bool { +func (n *connectionManager) CheckIn(vpnIp iputil.VpnIp) bool { n.inLock.RLock() - if _, ok := n.in[vpnIP]; ok { + if _, ok := n.in[vpnIp]; ok { n.inLock.RUnlock() return true } @@ -97,7 +99,7 @@ func (n *connectionManager) CheckIn(vpnIP uint32) bool { return false } -func (n *connectionManager) ClearIP(ip uint32) { +func (n *connectionManager) ClearIP(ip iputil.VpnIp) { n.inLock.Lock() n.outLock.Lock() delete(n.in, ip) @@ -106,13 +108,13 @@ func (n *connectionManager) ClearIP(ip uint32) { n.outLock.Unlock() } -func (n *connectionManager) ClearPendingDeletion(ip uint32) { +func (n *connectionManager) ClearPendingDeletion(ip iputil.VpnIp) { n.pendingDeletionLock.Lock() delete(n.pendingDeletion, ip) n.pendingDeletionLock.Unlock() } -func (n *connectionManager) AddPendingDeletion(ip uint32) { +func (n *connectionManager) AddPendingDeletion(ip iputil.VpnIp) { n.pendingDeletionLock.Lock() if _, ok := n.pendingDeletion[ip]; ok { n.pendingDeletion[ip] += 1 @@ -123,7 +125,7 @@ func (n *connectionManager) AddPendingDeletion(ip uint32) { n.pendingDeletionLock.Unlock() } -func (n *connectionManager) checkPendingDeletion(ip uint32) bool { +func (n *connectionManager) checkPendingDeletion(ip iputil.VpnIp) bool { n.pendingDeletionLock.RLock() if _, ok := n.pendingDeletion[ip]; ok { @@ -134,8 +136,8 @@ func (n *connectionManager) checkPendingDeletion(ip uint32) bool { return false } -func (n *connectionManager) AddTrafficWatch(vpnIP uint32, seconds int) { - n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds)) +func (n *connectionManager) AddTrafficWatch(vpnIp iputil.VpnIp, seconds int) { + n.TrafficTimer.Add(vpnIp, time.Second*time.Duration(seconds)) } func (n *connectionManager) Start(ctx context.Context) { @@ -169,23 +171,23 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) break } - vpnIP := ep.(uint32) + vpnIp := ep.(iputil.VpnIp) // Check for traffic coming back in from this host. - traf := n.CheckIn(vpnIP) + traf := n.CheckIn(vpnIp) - hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) + hostinfo, err := n.hostMap.QueryVpnIp(vpnIp) if err != nil { - n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) + n.l.Debugf("Not found in hostmap: %s", vpnIp) if !n.intf.disconnectInvalid { - n.ClearIP(vpnIP) - n.ClearPendingDeletion(vpnIP) + n.ClearIP(vpnIp) + n.ClearPendingDeletion(vpnIp) continue } } - if n.handleInvalidCertificate(now, vpnIP, hostinfo) { + if n.handleInvalidCertificate(now, vpnIp, hostinfo) { continue } @@ -193,12 +195,12 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) // expired, just ignore. if traf { if n.l.Level >= logrus.DebugLevel { - n.l.WithField("vpnIp", IntIp(vpnIP)). + n.l.WithField("vpnIp", vpnIp). WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). Debug("Tunnel status") } - n.ClearIP(vpnIP) - n.ClearPendingDeletion(vpnIP) + n.ClearIP(vpnIp) + n.ClearPendingDeletion(vpnIp) continue } @@ -208,12 +210,12 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) if hostinfo != nil && hostinfo.ConnectionState != nil { // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out) + n.intf.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, p, nb, out) } else { - hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP)) + hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", vpnIp) } - n.AddPendingDeletion(vpnIP) + n.AddPendingDeletion(vpnIp) } } @@ -226,38 +228,38 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) { break } - vpnIP := ep.(uint32) + vpnIp := ep.(iputil.VpnIp) - hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) + hostinfo, err := n.hostMap.QueryVpnIp(vpnIp) if err != nil { - n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) + n.l.Debugf("Not found in hostmap: %s", vpnIp) if !n.intf.disconnectInvalid { - n.ClearIP(vpnIP) - n.ClearPendingDeletion(vpnIP) + n.ClearIP(vpnIp) + n.ClearPendingDeletion(vpnIp) continue } } - if n.handleInvalidCertificate(now, vpnIP, hostinfo) { + if n.handleInvalidCertificate(now, vpnIp, hostinfo) { continue } // If we saw an incoming packets from this ip and peer's certificate is not // expired, just ignore. - traf := n.CheckIn(vpnIP) + traf := n.CheckIn(vpnIp) if traf { - n.l.WithField("vpnIp", IntIp(vpnIP)). + n.l.WithField("vpnIp", vpnIp). WithField("tunnelCheck", m{"state": "alive", "method": "active"}). Debug("Tunnel status") - n.ClearIP(vpnIP) - n.ClearPendingDeletion(vpnIP) + n.ClearIP(vpnIp) + n.ClearPendingDeletion(vpnIp) continue } // If it comes around on deletion wheel and hasn't resolved itself, delete - if n.checkPendingDeletion(vpnIP) { + if n.checkPendingDeletion(vpnIp) { cn := "" if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil { cn = hostinfo.ConnectionState.peerCert.Details.Name @@ -267,22 +269,22 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) { WithField("certName", cn). Info("Tunnel status") - n.ClearIP(vpnIP) - n.ClearPendingDeletion(vpnIP) + n.ClearIP(vpnIp) + n.ClearPendingDeletion(vpnIp) // TODO: This is only here to let tests work. Should do proper mocking if n.intf.lightHouse != nil { - n.intf.lightHouse.DeleteVpnIP(vpnIP) + n.intf.lightHouse.DeleteVpnIp(vpnIp) } n.hostMap.DeleteHostInfo(hostinfo) } else { - n.ClearIP(vpnIP) - n.ClearPendingDeletion(vpnIP) + n.ClearIP(vpnIp) + n.ClearPendingDeletion(vpnIp) } } } // handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid -func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32, hostinfo *HostInfo) bool { +func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIp iputil.VpnIp, hostinfo *HostInfo) bool { if !n.intf.disconnectInvalid { return false } @@ -298,7 +300,7 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32 } fingerprint, _ := remoteCert.Sha256Sum() - n.l.WithField("vpnIp", IntIp(vpnIP)).WithError(err). + n.l.WithField("vpnIp", vpnIp).WithError(err). WithField("certName", remoteCert.Details.Name). WithField("fingerprint", fingerprint). Info("Remote certificate is no longer valid, tearing down the tunnel") @@ -307,7 +309,7 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32 n.intf.sendCloseTunnel(hostinfo) n.intf.closeTunnel(hostinfo, false) - n.ClearIP(vpnIP) - n.ClearPendingDeletion(vpnIP) + n.ClearIP(vpnIp) + n.ClearPendingDeletion(vpnIp) return true } diff --git a/connection_manager_test.go b/connection_manager_test.go index fa88640..9da6ddc 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -10,17 +10,20 @@ import ( "github.com/flynn/noise" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" + "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) -var vpnIP uint32 +var vpnIp iputil.VpnIp func Test_NewConnectionManagerTest(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - vpnIP = ip2int(net.ParseIP("172.1.1.2")) + vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} // Very incomplete mock objects @@ -32,15 +35,15 @@ func Test_NewConnectionManagerTest(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) + lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false) ifce := &Interface{ hostMap: hostMap, inside: &Tun{}, - outside: &udpConn{}, + outside: &udp.Conn{}, certState: cs, firewall: &Firewall{}, lightHouse: lh, - handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), + handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig), l: l, } now := time.Now() @@ -54,16 +57,16 @@ func Test_NewConnectionManagerTest(t *testing.T) { out := make([]byte, mtu) nc.HandleMonitorTick(now, p, nb, out) // Add an ip we have established a connection w/ to hostmap - hostinfo := nc.hostMap.AddVpnIP(vpnIP) + hostinfo := nc.hostMap.AddVpnIp(vpnIp) hostinfo.ConnectionState = &ConnectionState{ certState: cs, H: &noise.HandshakeState{}, } - // We saw traffic out to vpnIP - nc.Out(vpnIP) - assert.NotContains(t, nc.pendingDeletion, vpnIP) - assert.Contains(t, nc.hostMap.Hosts, vpnIP) + // We saw traffic out to vpnIp + nc.Out(vpnIp) + assert.NotContains(t, nc.pendingDeletion, vpnIp) + assert.Contains(t, nc.hostMap.Hosts, vpnIp) // Move ahead 5s. Nothing should happen next_tick := now.Add(5 * time.Second) nc.HandleMonitorTick(next_tick, p, nb, out) @@ -73,20 +76,20 @@ func Test_NewConnectionManagerTest(t *testing.T) { nc.HandleMonitorTick(next_tick, p, nb, out) nc.HandleDeletionTick(next_tick) // This host should now be up for deletion - assert.Contains(t, nc.pendingDeletion, vpnIP) - assert.Contains(t, nc.hostMap.Hosts, vpnIP) + assert.Contains(t, nc.pendingDeletion, vpnIp) + assert.Contains(t, nc.hostMap.Hosts, vpnIp) // Move ahead some more next_tick = now.Add(45 * time.Second) nc.HandleMonitorTick(next_tick, p, nb, out) nc.HandleDeletionTick(next_tick) // The host should be evicted - assert.NotContains(t, nc.pendingDeletion, vpnIP) - assert.NotContains(t, nc.hostMap.Hosts, vpnIP) + assert.NotContains(t, nc.pendingDeletion, vpnIp) + assert.NotContains(t, nc.hostMap.Hosts, vpnIp) } func Test_NewConnectionManagerTest2(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") @@ -101,15 +104,15 @@ func Test_NewConnectionManagerTest2(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) + lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false) ifce := &Interface{ hostMap: hostMap, inside: &Tun{}, - outside: &udpConn{}, + outside: &udp.Conn{}, certState: cs, firewall: &Firewall{}, lightHouse: lh, - handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), + handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig), l: l, } now := time.Now() @@ -123,16 +126,16 @@ func Test_NewConnectionManagerTest2(t *testing.T) { out := make([]byte, mtu) nc.HandleMonitorTick(now, p, nb, out) // Add an ip we have established a connection w/ to hostmap - hostinfo := nc.hostMap.AddVpnIP(vpnIP) + hostinfo := nc.hostMap.AddVpnIp(vpnIp) hostinfo.ConnectionState = &ConnectionState{ certState: cs, H: &noise.HandshakeState{}, } - // We saw traffic out to vpnIP - nc.Out(vpnIP) - assert.NotContains(t, nc.pendingDeletion, vpnIP) - assert.Contains(t, nc.hostMap.Hosts, vpnIP) + // We saw traffic out to vpnIp + nc.Out(vpnIp) + assert.NotContains(t, nc.pendingDeletion, vpnIp) + assert.Contains(t, nc.hostMap.Hosts, vpnIp) // Move ahead 5s. Nothing should happen next_tick := now.Add(5 * time.Second) nc.HandleMonitorTick(next_tick, p, nb, out) @@ -142,17 +145,17 @@ func Test_NewConnectionManagerTest2(t *testing.T) { nc.HandleMonitorTick(next_tick, p, nb, out) nc.HandleDeletionTick(next_tick) // This host should now be up for deletion - assert.Contains(t, nc.pendingDeletion, vpnIP) - assert.Contains(t, nc.hostMap.Hosts, vpnIP) + assert.Contains(t, nc.pendingDeletion, vpnIp) + assert.Contains(t, nc.hostMap.Hosts, vpnIp) // We heard back this time - nc.In(vpnIP) + nc.In(vpnIp) // Move ahead some more next_tick = now.Add(45 * time.Second) nc.HandleMonitorTick(next_tick, p, nb, out) nc.HandleDeletionTick(next_tick) // The host should be evicted - assert.NotContains(t, nc.pendingDeletion, vpnIP) - assert.Contains(t, nc.hostMap.Hosts, vpnIP) + assert.NotContains(t, nc.pendingDeletion, vpnIp) + assert.Contains(t, nc.hostMap.Hosts, vpnIp) } @@ -161,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Disconnect only if disconnectInvalid: true is set. func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { now := time.Now() - l := NewTestLogger() + l := util.NewTestLogger() ipNet := net.IPNet{ IP: net.IPv4(172, 1, 1, 2), Mask: net.IPMask{255, 255, 255, 0}, @@ -210,15 +213,15 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) + lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false) ifce := &Interface{ hostMap: hostMap, inside: &Tun{}, - outside: &udpConn{}, + outside: &udp.Conn{}, certState: cs, firewall: &Firewall{}, lightHouse: lh, - handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), + handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig), l: l, disconnectInvalid: true, caPool: ncp, @@ -229,7 +232,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { defer cancel() nc := newConnectionManager(ctx, l, ifce, 5, 10) ifce.connectionManager = nc - hostinfo := nc.hostMap.AddVpnIP(vpnIP) + hostinfo := nc.hostMap.AddVpnIp(vpnIp) hostinfo.ConnectionState = &ConnectionState{ certState: cs, peerCert: &peerCert, @@ -240,13 +243,13 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { // Check if to disconnect with invalid certificate. // Should be alive. nextTick := now.Add(45 * time.Second) - destroyed := nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo) + destroyed := nc.handleInvalidCertificate(nextTick, vpnIp, hostinfo) assert.False(t, destroyed) // Move ahead 61s. // Check if to disconnect with invalid certificate. // Should be disconnected. nextTick = now.Add(61 * time.Second) - destroyed = nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo) + destroyed = nc.handleInvalidCertificate(nextTick, vpnIp, hostinfo) assert.True(t, destroyed) } diff --git a/control.go b/control.go index 4bbe65f..8c93be2 100644 --- a/control.go +++ b/control.go @@ -10,6 +10,9 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" ) // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching @@ -25,14 +28,14 @@ type Control struct { } type ControlHostInfo struct { - VpnIP net.IP `json:"vpnIp"` + VpnIp net.IP `json:"vpnIp"` LocalIndex uint32 `json:"localIndex"` RemoteIndex uint32 `json:"remoteIndex"` - RemoteAddrs []*udpAddr `json:"remoteAddrs"` + RemoteAddrs []*udp.Addr `json:"remoteAddrs"` CachedPackets int `json:"cachedPackets"` Cert *cert.NebulaCertificate `json:"cert"` MessageCounter uint64 `json:"messageCounter"` - CurrentRemote *udpAddr `json:"currentRemote"` + CurrentRemote *udp.Addr `json:"currentRemote"` } // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() @@ -95,8 +98,8 @@ func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo { } } -// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found -func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo { +// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found +func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo { var hm *HostMap if pending { hm = c.f.handshakeManager.pendingHostMap @@ -104,7 +107,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf hm = c.f.hostMap } - h, err := hm.QueryVpnIP(vpnIP) + h, err := hm.QueryVpnIp(vpnIp) if err != nil { return nil } @@ -114,8 +117,8 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf } // SetRemoteForTunnel forces a tunnel to use a specific remote -func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo { - hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP) +func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo { + hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp) if err != nil { return nil } @@ -126,15 +129,15 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf } // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. -func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool { - hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP) +func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool { + hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp) if err != nil { return false } if !localOnly { c.f.send( - closeTunnel, + header.CloseTunnel, 0, hostInfo.ConnectionState, hostInfo, @@ -156,16 +159,16 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { c.f.hostMap.Lock() for _, h := range c.f.hostMap.Hosts { if excludeLighthouses { - if _, ok := c.f.lightHouse.lighthouses[h.hostId]; ok { + if _, ok := c.f.lightHouse.lighthouses[h.vpnIp]; ok { continue } } if h.ConnectionState.ready { - c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) + c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.closeTunnel(h, true) - c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote). + c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote). Debug("Sending close tunnel message") closed++ } @@ -176,7 +179,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { chi := ControlHostInfo{ - VpnIP: int2ip(h.hostId), + VpnIp: h.vpnIp.ToIP(), LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), diff --git a/control_test.go b/control_test.go index 5679ce6..08aa151 100644 --- a/control_test.go +++ b/control_test.go @@ -8,17 +8,19 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) -func TestControl_GetHostInfoByVpnIP(t *testing.T) { - l := NewTestLogger() +func TestControl_GetHostInfoByVpnIp(t *testing.T) { + l := util.NewTestLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0)) - remote1 := NewUDPAddr(int2ip(100), 4444) - remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) + remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444) + remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) ipNet := net.IPNet{ IP: net.IPv4(1, 2, 3, 4), Mask: net.IPMask{255, 255, 255, 0}, @@ -48,7 +50,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) { remotes := NewRemoteList() remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) - hm.Add(ip2int(ipNet.IP), &HostInfo{ + hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ @@ -56,10 +58,10 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - hostId: ip2int(ipNet.IP), + vpnIp: iputil.Ip2VpnIp(ipNet.IP), }) - hm.Add(ip2int(ipNet2.IP), &HostInfo{ + hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ @@ -67,7 +69,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - hostId: ip2int(ipNet2.IP), + vpnIp: iputil.Ip2VpnIp(ipNet2.IP), }) c := Control{ @@ -77,26 +79,26 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) { l: logrus.New(), } - thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false) + thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false) expectedInfo := ControlHostInfo{ - VpnIP: net.IPv4(1, 2, 3, 4).To4(), + VpnIp: net.IPv4(1, 2, 3, 4).To4(), LocalIndex: 201, RemoteIndex: 200, - RemoteAddrs: []*udpAddr{remote2, remote1}, + RemoteAddrs: []*udp.Addr{remote2, remote1}, CachedPackets: 0, Cert: crt.Copy(), MessageCounter: 0, - CurrentRemote: NewUDPAddr(int2ip(100), 4444), + CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444), } // Make sure we don't have any unexpected fields - assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi) + assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi) util.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { - thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false) + thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false) }) } diff --git a/control_tester.go b/control_tester.go index df1d2e7..2c7752e 100644 --- a/control_tester.go +++ b/control_tester.go @@ -8,12 +8,15 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" ) // WaitForTypeByIndex will pipe all messages from this control device into the pipeTo control device // returning after a message matching the criteria has been piped -func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) { - h := &Header{} +func (c *Control) WaitForType(msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) { + h := &header.H{} for { p := c.f.outside.Get(true) if err := h.Parse(p.Data); err != nil { @@ -28,8 +31,8 @@ func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSu // WaitForTypeByIndex is similar to WaitForType except it adds an index check // Useful if you have many nodes communicating and want to wait to find a specific nodes packet -func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) { - h := &Header{} +func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) { + h := &header.H{} for { p := c.f.outside.Get(true) if err := h.Parse(p.Data); err != nil { @@ -46,12 +49,12 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType, // This is necessary if you did not configure static hosts or are not running a lighthouse func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp)) + remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - iVpnIp := ip2int(vpnIp) + iVpnIp := iputil.Ip2VpnIp(vpnIp) if v4 := toAddr.IP.To4(); v4 != nil { remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port))) } else { @@ -65,12 +68,12 @@ func (c *Control) GetFromTun(block bool) []byte { } // GetFromUDP will pull a udp packet off the udp side of nebula -func (c *Control) GetFromUDP(block bool) *UdpPacket { +func (c *Control) GetFromUDP(block bool) *udp.Packet { return c.f.outside.Get(block) } -func (c *Control) GetUDPTxChan() <-chan *UdpPacket { - return c.f.outside.txPackets +func (c *Control) GetUDPTxChan() <-chan *udp.Packet { + return c.f.outside.TxPackets } func (c *Control) GetTunTxChan() <-chan []byte { @@ -78,7 +81,7 @@ func (c *Control) GetTunTxChan() <-chan []byte { } // InjectUDPPacket will inject a packet into the udp side of nebula -func (c *Control) InjectUDPPacket(p *UdpPacket) { +func (c *Control) InjectUDPPacket(p *udp.Packet) { c.f.outside.Send(p) } @@ -115,11 +118,11 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16 } func (c *Control) GetUDPAddr() string { - return c.f.outside.addr.String() + return c.f.outside.Addr.String() } func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { - hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)] + hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)] if !ok { return false } diff --git a/dns_server.go b/dns_server.go index eef5486..dd7e30e 100644 --- a/dns_server.go +++ b/dns_server.go @@ -8,6 +8,8 @@ import ( "github.com/miekg/dns" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" ) // This whole thing should be rewritten to use context @@ -44,8 +46,8 @@ func (d *dnsRecords) QueryCert(data string) string { if ip == nil { return "" } - iip := ip2int(ip) - hostinfo, err := d.hostMap.QueryVpnIP(iip) + iip := iputil.Ip2VpnIp(ip) + hostinfo, err := d.hostMap.QueryVpnIp(iip) if err != nil { return "" } @@ -109,7 +111,7 @@ func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) } -func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() { +func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() { dnsR = newDnsRecords(hostMap) // attach request handler func @@ -117,7 +119,7 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() { handleDnsRequest(l, w, r) }) - c.RegisterReloadCallback(func(c *Config) { + c.RegisterReloadCallback(func(c *config.C) { reloadDns(l, c) }) @@ -126,11 +128,11 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() { } } -func getDnsServerAddr(c *Config) string { +func getDnsServerAddr(c *config.C) string { return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)) } -func startDns(l *logrus.Logger, c *Config) { +func startDns(l *logrus.Logger, c *config.C) { dnsAddr = getDnsServerAddr(c) dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} l.WithField("dnsListener", dnsAddr).Infof("Starting DNS responder") @@ -141,7 +143,7 @@ func startDns(l *logrus.Logger, c *Config) { } } -func reloadDns(l *logrus.Logger, c *Config) { +func reloadDns(l *logrus.Logger, c *config.C) { if dnsAddr == getDnsServerAddr(c) { l.Debug("No DNS server config change detected") return diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index a4a0ea0..9f239c2 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -10,6 +10,9 @@ import ( "github.com/slackhq/nebula" "github.com/slackhq/nebula/e2e/router" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" ) @@ -37,7 +40,7 @@ func TestGoodHandshake(t *testing.T) { t.Log("I consume a garbage packet with a proper nebula header for our tunnel") // this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel badPacket := stage1Packet.Copy() - badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen] + badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len] myControl.InjectUDPPacket(badPacket) t.Log("Have me consume their real stage 1 packet. I have a tunnel now") @@ -87,8 +90,8 @@ func TestWrongResponderHandshake(t *testing.T) { t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) - r.RouteForAllExitFunc(func(p *nebula.UdpPacket, c *nebula.Control) router.ExitType { - h := &nebula.Header{} + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + h := &header.H{} err := h.Parse(p.Data) if err != nil { panic(err) @@ -115,8 +118,8 @@ func TestWrongResponderHandshake(t *testing.T) { r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), false), "My main hostmap should not contain evil") //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //TODO: assert hostmaps for everyone diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 35938cf..843e08c 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -5,7 +5,6 @@ package e2e import ( "crypto/rand" - "encoding/binary" "fmt" "io" "io/ioutil" @@ -19,7 +18,9 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" + "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/ed25519" @@ -82,10 +83,10 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u panic(err) } - config := nebula.NewConfig(l) - config.LoadString(string(cb)) + c := config.NewC(l) + c.LoadString(string(cb)) - control, err := nebula.Main(config, false, "e2e-test", l, nil) + control, err := nebula.Main(c, false, "e2e-test", l, nil) if err != nil { panic(err) @@ -200,19 +201,6 @@ func x25519Keypair() ([]byte, []byte) { return pubkey, privkey } -func ip2int(ip []byte) uint32 { - if len(ip) == 16 { - return binary.BigEndian.Uint32(ip[12:16]) - } - return binary.BigEndian.Uint32(ip) -} - -func int2ip(nn uint32) net.IP { - ip := make(net.IP, 4) - binary.BigEndian.PutUint32(ip, nn) - return ip -} - type doneCb func() func deadline(t *testing.T, seconds time.Duration) doneCb { @@ -245,15 +233,15 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) { // Get both host infos - hBinA := controlA.GetHostInfoByVpnIP(ip2int(vpnIpB), false) - assert.NotNil(t, hBinA, "Host B was not found by vpnIP in controlA") + hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false) + assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") - hAinB := controlB.GetHostInfoByVpnIP(ip2int(vpnIpA), false) - assert.NotNil(t, hAinB, "Host A was not found by vpnIP in controlB") + hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false) + assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") // Check that both vpn and real addr are correct - assert.Equal(t, vpnIpB, hBinA.VpnIP, "Host B VpnIp is wrong in control A") - assert.Equal(t, vpnIpA, hAinB.VpnIP, "Host A VpnIp is wrong in control B") + assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") + assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A") assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B") diff --git a/e2e/router/router.go b/e2e/router/router.go index b53bd94..ac8e02c 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -11,6 +11,8 @@ import ( "sync" "github.com/slackhq/nebula" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" ) type R struct { @@ -41,7 +43,7 @@ const ( RouteAndExit ExitType = 2 ) -type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType +type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType func NewR(controls ...*nebula.Control) *R { r := &R{ @@ -79,7 +81,7 @@ func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) { // OnceFrom will route a single packet from sender then return // If the router doesn't have the nebula controller for that address, we panic func (r *R) OnceFrom(sender *nebula.Control) { - r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType { + r.RouteExitFunc(sender, func(*udp.Packet, *nebula.Control) ExitType { return RouteAndExit }) } @@ -119,7 +121,7 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] // - routeAndExit: this call will return immediately after routing the last packet from sender // - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { - h := &nebula.Header{} + h := &header.H{} for { p := sender.GetFromUDP(true) r.Lock() @@ -159,9 +161,9 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { // RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender // If the router doesn't have the nebula controller for that address, we panic -func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) { - h := &nebula.Header{} - r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType { +func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType header.MessageType, subType header.MessageSubType) { + h := &header.H{} + r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType { if err := h.Parse(p.Data); err != nil { panic(err) } @@ -181,7 +183,7 @@ func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr finish = RouteAndExit } - r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType { + r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType { if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) { return finish } @@ -215,7 +217,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { x, rx, _ := reflect.Select(sc) r.Lock() - p := rx.Interface().(*nebula.UdpPacket) + p := rx.Interface().(*udp.Packet) outAddr := cm[x].GetUDPAddr() inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) @@ -277,7 +279,7 @@ func (r *R) FlushAll() { } r.Lock() - p := rx.Interface().(*nebula.UdpPacket) + p := rx.Interface().(*udp.Packet) outAddr := cm[x].GetUDPAddr() inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) @@ -292,7 +294,7 @@ func (r *R) FlushAll() { // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // This is an internal router function, the caller must hold the lock -func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control { +func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control { if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok { p.FromIp = newAddr.IP p.FromPort = uint16(newAddr.Port) diff --git a/firewall.go b/firewall.go index 5716894..dfc7fd1 100644 --- a/firewall.go +++ b/firewall.go @@ -4,7 +4,6 @@ import ( "crypto/sha256" "encoding/binary" "encoding/hex" - "encoding/json" "errors" "fmt" "net" @@ -12,22 +11,14 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" -) - -const ( - fwProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever - fwProtoTCP = 6 - fwProtoUDP = 17 - fwProtoICMP = 1 - - fwPortAny = 0 // Special value for matching `port: any` - fwPortFragment = -1 // Special value for matching `port: fragment` + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/firewall" ) const tcpACK = 0x10 @@ -63,7 +54,7 @@ type Firewall struct { DefaultTimeout time.Duration //linux: 600s // Used to ensure we don't emit local packets for ips we don't own - localIps *CIDRTree + localIps *cidr.Tree4 rules string rulesVersion uint16 @@ -85,7 +76,7 @@ type firewallMetrics struct { type FirewallConntrack struct { sync.Mutex - Conns map[FirewallPacket]*conn + Conns map[firewall.Packet]*conn TimerWheel *TimerWheel } @@ -116,55 +107,13 @@ type FirewallRule struct { Any bool Hosts map[string]struct{} Groups [][]string - CIDR *CIDRTree + CIDR *cidr.Tree4 } // Even though ports are uint16, int32 maps are faster for lookup // Plus we can use `-1` for fragment rules type firewallPort map[int32]*FirewallCA -type FirewallPacket struct { - LocalIP uint32 - RemoteIP uint32 - LocalPort uint16 - RemotePort uint16 - Protocol uint8 - Fragment bool -} - -func (fp *FirewallPacket) Copy() *FirewallPacket { - return &FirewallPacket{ - LocalIP: fp.LocalIP, - RemoteIP: fp.RemoteIP, - LocalPort: fp.LocalPort, - RemotePort: fp.RemotePort, - Protocol: fp.Protocol, - Fragment: fp.Fragment, - } -} - -func (fp FirewallPacket) MarshalJSON() ([]byte, error) { - var proto string - switch fp.Protocol { - case fwProtoTCP: - proto = "tcp" - case fwProtoICMP: - proto = "icmp" - case fwProtoUDP: - proto = "udp" - default: - proto = fmt.Sprintf("unknown %v", fp.Protocol) - } - return json.Marshal(m{ - "LocalIP": int2ip(fp.LocalIP).String(), - "RemoteIP": int2ip(fp.RemoteIP).String(), - "LocalPort": fp.LocalPort, - "RemotePort": fp.RemotePort, - "Protocol": proto, - "Fragment": fp.Fragment, - }) -} - // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { //TODO: error on 0 duration @@ -184,7 +133,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D max = defaultTimeout } - localIps := NewCIDRTree() + localIps := cidr.NewTree4() for _, ip := range c.Details.Ips { localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) } @@ -195,7 +144,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D return &Firewall{ Conntrack: &FirewallConntrack{ - Conns: make(map[FirewallPacket]*conn), + Conns: make(map[firewall.Packet]*conn), TimerWheel: NewTimerWheel(min, max), }, InRules: newFirewallTable(), @@ -220,7 +169,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } -func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) { fw := NewFirewall( l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), @@ -278,13 +227,13 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort } switch proto { - case fwProtoTCP: + case firewall.ProtoTCP: fp = ft.TCP - case fwProtoUDP: + case firewall.ProtoUDP: fp = ft.UDP - case fwProtoICMP: + case firewall.ProtoICMP: fp = ft.ICMP - case fwProtoAny: + case firewall.ProtoAny: fp = ft.AnyProto default: return fmt.Errorf("unknown protocol %v", proto) @@ -299,7 +248,7 @@ func (f *Firewall) GetRuleHash() string { return hex.EncodeToString(sum[:]) } -func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error { +func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { var table string if inbound { table = "firewall.inbound" @@ -307,7 +256,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, table = "firewall.outbound" } - r := config.Get(table) + r := c.Get(table) if r == nil { return nil } @@ -362,13 +311,13 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, var proto uint8 switch r.Proto { case "any": - proto = fwProtoAny + proto = firewall.ProtoAny case "tcp": - proto = fwProtoTCP + proto = firewall.ProtoTCP case "udp": - proto = fwProtoUDP + proto = firewall.ProtoUDP case "icmp": - proto = fwProtoICMP + proto = firewall.ProtoICMP default: return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) } @@ -396,7 +345,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error { +func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error { // Check if we spoke to this tuple, if we did then allow this packet if f.inConns(packet, fp, incoming, h, caPool, localCache) { return nil @@ -410,7 +359,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host } } else { // Simple case: Certificate has one IP and no subnets - if fp.RemoteIP != h.hostId { + if fp.RemoteIP != h.vpnIp { f.metrics(incoming).droppedRemoteIP.Inc(1) return ErrInvalidRemoteIP } @@ -462,7 +411,7 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion)) } -func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool { +func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool { if localCache != nil { if _, ok := localCache[fp]; ok { return true @@ -520,14 +469,14 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H } switch fp.Protocol { - case fwProtoTCP: + case firewall.ProtoTCP: c.Expires = time.Now().Add(f.TCPTimeout) if incoming { f.checkTCPRTT(c, packet) } else { setTCPRTTTracking(c, packet) } - case fwProtoUDP: + case firewall.ProtoUDP: c.Expires = time.Now().Add(f.UDPTimeout) default: c.Expires = time.Now().Add(f.DefaultTimeout) @@ -542,17 +491,17 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H return true } -func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) { +func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) { var timeout time.Duration c := &conn{} switch fp.Protocol { - case fwProtoTCP: + case firewall.ProtoTCP: timeout = f.TCPTimeout if !incoming { setTCPRTTTracking(c, packet) } - case fwProtoUDP: + case firewall.ProtoUDP: timeout = f.UDPTimeout default: timeout = f.DefaultTimeout @@ -575,7 +524,7 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) { // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Caller must own the connMutex lock! -func (f *Firewall) evict(p FirewallPacket) { +func (f *Firewall) evict(p firewall.Packet) { //TODO: report a stat if the tcp rtt tracking was never resolved? // Are we still tracking this conn? conntrack := f.Conntrack @@ -596,21 +545,21 @@ func (f *Firewall) evict(p FirewallPacket) { delete(conntrack.Conns, p) } -func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { if ft.AnyProto.match(p, incoming, c, caPool) { return true } switch p.Protocol { - case fwProtoTCP: + case firewall.ProtoTCP: if ft.TCP.match(p, incoming, c, caPool) { return true } - case fwProtoUDP: + case firewall.ProtoUDP: if ft.UDP.match(p, incoming, c, caPool) { return true } - case fwProtoICMP: + case firewall.ProtoICMP: if ft.ICMP.match(p, incoming, c, caPool) { return true } @@ -640,7 +589,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, return nil } -func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { // We don't have any allowed ports, bail if fp == nil { return false @@ -649,7 +598,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert var port int32 if p.Fragment { - port = fwPortFragment + port = firewall.PortFragment } else if incoming { port = int32(p.LocalPort) } else { @@ -660,7 +609,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert return true } - return fp[fwPortAny].match(p, c, caPool) + return fp[firewall.PortAny].match(p, c, caPool) } func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error { @@ -668,7 +617,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam return &FirewallRule{ Hosts: make(map[string]struct{}), Groups: make([][]string, 0), - CIDR: NewCIDRTree(), + CIDR: cidr.NewTree4(), } } @@ -703,7 +652,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam return nil } -func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { if fc == nil { return false } @@ -736,7 +685,7 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err // If it's any we need to wipe out any pre-existing rules to save on memory fr.Groups = make([][]string, 0) fr.Hosts = make(map[string]struct{}) - fr.CIDR = NewCIDRTree() + fr.CIDR = cidr.NewTree4() } else { if len(groups) > 0 { fr.Groups = append(fr.Groups, groups) @@ -776,7 +725,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool return false } -func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool { +func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool { if fr == nil { return false } @@ -885,12 +834,12 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er func parsePort(s string) (startPort, endPort int32, err error) { if s == "any" { - startPort = fwPortAny - endPort = fwPortAny + startPort = firewall.PortAny + endPort = firewall.PortAny } else if s == "fragment" { - startPort = fwPortFragment - endPort = fwPortFragment + startPort = firewall.PortFragment + endPort = firewall.PortFragment } else if strings.Contains(s, `-`) { sPorts := strings.SplitN(s, `-`, 2) @@ -914,8 +863,8 @@ func parsePort(s string) (startPort, endPort int32, err error) { startPort = int32(rStartPort) endPort = int32(rEndPort) - if startPort == fwPortAny { - endPort = fwPortAny + if startPort == firewall.PortAny { + endPort = firewall.PortAny } } else { @@ -968,54 +917,3 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool { c.Seq = 0 return true } - -// ConntrackCache is used as a local routine cache to know if a given flow -// has been seen in the conntrack table. -type ConntrackCache map[FirewallPacket]struct{} - -type ConntrackCacheTicker struct { - cacheV uint64 - cacheTick uint64 - - cache ConntrackCache -} - -func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { - if d == 0 { - return nil - } - - c := &ConntrackCacheTicker{ - cache: ConntrackCache{}, - } - - go c.tick(d) - - return c -} - -func (c *ConntrackCacheTicker) tick(d time.Duration) { - for { - time.Sleep(d) - atomic.AddUint64(&c.cacheTick, 1) - } -} - -// Get checks if the cache ticker has moved to the next version before returning -// the map. If it has moved, we reset the map. -func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache { - if c == nil { - return nil - } - if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV { - c.cacheV = tick - if ll := len(c.cache); ll > 0 { - if l.Level == logrus.DebugLevel { - l.WithField("len", ll).Debug("resetting conntrack cache") - } - c.cache = make(ConntrackCache, ll) - } - } - - return c.cache -} diff --git a/firewall/cache.go b/firewall/cache.go new file mode 100644 index 0000000..5560ab2 --- /dev/null +++ b/firewall/cache.go @@ -0,0 +1,59 @@ +package firewall + +import ( + "sync/atomic" + "time" + + "github.com/sirupsen/logrus" +) + +// ConntrackCache is used as a local routine cache to know if a given flow +// has been seen in the conntrack table. +type ConntrackCache map[Packet]struct{} + +type ConntrackCacheTicker struct { + cacheV uint64 + cacheTick uint64 + + cache ConntrackCache +} + +func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { + if d == 0 { + return nil + } + + c := &ConntrackCacheTicker{ + cache: ConntrackCache{}, + } + + go c.tick(d) + + return c +} + +func (c *ConntrackCacheTicker) tick(d time.Duration) { + for { + time.Sleep(d) + atomic.AddUint64(&c.cacheTick, 1) + } +} + +// Get checks if the cache ticker has moved to the next version before returning +// the map. If it has moved, we reset the map. +func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache { + if c == nil { + return nil + } + if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV { + c.cacheV = tick + if ll := len(c.cache); ll > 0 { + if l.Level == logrus.DebugLevel { + l.WithField("len", ll).Debug("resetting conntrack cache") + } + c.cache = make(ConntrackCache, ll) + } + } + + return c.cache +} diff --git a/firewall/packet.go b/firewall/packet.go new file mode 100644 index 0000000..1c4affd --- /dev/null +++ b/firewall/packet.go @@ -0,0 +1,62 @@ +package firewall + +import ( + "encoding/json" + "fmt" + + "github.com/slackhq/nebula/iputil" +) + +type m map[string]interface{} + +const ( + ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever + ProtoTCP = 6 + ProtoUDP = 17 + ProtoICMP = 1 + + PortAny = 0 // Special value for matching `port: any` + PortFragment = -1 // Special value for matching `port: fragment` +) + +type Packet struct { + LocalIP iputil.VpnIp + RemoteIP iputil.VpnIp + LocalPort uint16 + RemotePort uint16 + Protocol uint8 + Fragment bool +} + +func (fp *Packet) Copy() *Packet { + return &Packet{ + LocalIP: fp.LocalIP, + RemoteIP: fp.RemoteIP, + LocalPort: fp.LocalPort, + RemotePort: fp.RemotePort, + Protocol: fp.Protocol, + Fragment: fp.Fragment, + } +} + +func (fp Packet) MarshalJSON() ([]byte, error) { + var proto string + switch fp.Protocol { + case ProtoTCP: + proto = "tcp" + case ProtoICMP: + proto = "icmp" + case ProtoUDP: + proto = "udp" + default: + proto = fmt.Sprintf("unknown %v", fp.Protocol) + } + return json.Marshal(m{ + "LocalIP": fp.LocalIP.String(), + "RemoteIP": fp.RemoteIP.String(), + "LocalPort": fp.LocalPort, + "RemotePort": fp.RemotePort, + "Protocol": proto, + "Fragment": fp.Fragment, + }) +} diff --git a/firewall_test.go b/firewall_test.go index 43902cd..b98a2cf 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -11,11 +11,15 @@ import ( "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) func TestNewFirewall(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() c := &cert.NebulaCertificate{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) conntrack := fw.Conntrack @@ -54,7 +58,7 @@ func TestNewFirewall(t *testing.T) { } func TestFirewall_AddRule(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() ob := &bytes.Buffer{} l.SetOutput(ob) @@ -65,92 +69,80 @@ func TestFirewall_AddRule(t *testing.T) { _, ti, _ := net.ParseCIDR("1.2.3.4/32") - assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) - assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left) - assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right) - assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "")) assert.False(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) - assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left) - assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right) - assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", "")) assert.False(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") - assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left) - assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right) - assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", "")) assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) - assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP))) + assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") // Set any and clear fields fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", "")) assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1") - assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP))) + assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))) // run twice just to make sure //TODO: these ANY rules should clear the CA firewall portion - assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) - assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts) - assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left) - assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right) - assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") - assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", "")) - assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", "")) + assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", "")) } func TestFirewall_Drop(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() ob := &bytes.Buffer{} l.SetOutput(ob) - p := FirewallPacket{ - ip2int(net.IPv4(1, 2, 3, 4)), - ip2int(net.IPv4(1, 2, 3, 4)), + p := firewall.Packet{ + iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), 10, 90, - fwProtoUDP, + firewall.ProtoUDP, false, } @@ -172,12 +164,12 @@ func TestFirewall_Drop(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - hostId: ip2int(ipNet.IP), + vpnIp: iputil.Ip2VpnIp(ipNet.IP), } h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) cp := cert.NewCAPool() // Drop outbound @@ -190,34 +182,34 @@ func TestFirewall_Drop(t *testing.T) { // test remote mismatch oldRemote := p.RemoteIP - p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10)) + p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10)) assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteIP = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) } @@ -237,14 +229,14 @@ func BenchmarkFirewallTable_match(b *testing.B) { b.Run("fail on proto", func(b *testing.B) { c := &cert.NebulaCertificate{} for n := 0; n < b.N; n++ { - ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp) + ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp) } }) b.Run("fail on port", func(b *testing.B) { c := &cert.NebulaCertificate{} for n := 0; n < b.N; n++ { - ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp) + ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp) } }) @@ -258,7 +250,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp) + ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) } }) @@ -270,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp) + ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) } }) @@ -282,12 +274,12 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp) + ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) } }) b.Run("pass on ip", func(b *testing.B) { - ip := ip2int(net.IPv4(172, 1, 1, 1)) + ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) c := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ InvertedGroups: map[string]struct{}{"nope": {}}, @@ -295,14 +287,14 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp) + ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp) } }) _ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "") b.Run("pass on ip with any port", func(b *testing.B) { - ip := ip2int(net.IPv4(172, 1, 1, 1)) + ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) c := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ InvertedGroups: map[string]struct{}{"nope": {}}, @@ -310,22 +302,22 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) + ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) } }) } func TestFirewall_Drop2(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() ob := &bytes.Buffer{} l.SetOutput(ob) - p := FirewallPacket{ - ip2int(net.IPv4(1, 2, 3, 4)), - ip2int(net.IPv4(1, 2, 3, 4)), + p := firewall.Packet{ + iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), 10, 90, - fwProtoUDP, + firewall.ProtoUDP, false, } @@ -345,7 +337,7 @@ func TestFirewall_Drop2(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - hostId: ip2int(ipNet.IP), + vpnIp: iputil.Ip2VpnIp(ipNet.IP), } h.CreateRemoteCIDR(&c) @@ -364,7 +356,7 @@ func TestFirewall_Drop2(t *testing.T) { h1.CreateRemoteCIDR(&c1) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups @@ -375,16 +367,16 @@ func TestFirewall_Drop2(t *testing.T) { } func TestFirewall_Drop3(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() ob := &bytes.Buffer{} l.SetOutput(ob) - p := FirewallPacket{ - ip2int(net.IPv4(1, 2, 3, 4)), - ip2int(net.IPv4(1, 2, 3, 4)), + p := firewall.Packet{ + iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), 1, 1, - fwProtoUDP, + firewall.ProtoUDP, false, } @@ -411,7 +403,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c1, }, - hostId: ip2int(ipNet.IP), + vpnIp: iputil.Ip2VpnIp(ipNet.IP), } h1.CreateRemoteCIDR(&c1) @@ -426,7 +418,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c2, }, - hostId: ip2int(ipNet.IP), + vpnIp: iputil.Ip2VpnIp(ipNet.IP), } h2.CreateRemoteCIDR(&c2) @@ -441,13 +433,13 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c3, }, - hostId: ip2int(ipNet.IP), + vpnIp: iputil.Ip2VpnIp(ipNet.IP), } h3.CreateRemoteCIDR(&c3) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", "")) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match @@ -461,16 +453,16 @@ func TestFirewall_Drop3(t *testing.T) { } func TestFirewall_DropConntrackReload(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() ob := &bytes.Buffer{} l.SetOutput(ob) - p := FirewallPacket{ - ip2int(net.IPv4(1, 2, 3, 4)), - ip2int(net.IPv4(1, 2, 3, 4)), + p := firewall.Packet{ + iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), 10, 90, - fwProtoUDP, + firewall.ProtoUDP, false, } @@ -492,12 +484,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - hostId: ip2int(ipNet.IP), + vpnIp: iputil.Ip2VpnIp(ipNet.IP), } h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) cp := cert.NewCAPool() // Drop outbound @@ -510,7 +502,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -519,7 +511,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -643,28 +635,28 @@ func Test_parsePort(t *testing.T) { } func TestNewFirewallFromConfig(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() // Test a bad rule definition c := &cert.NebulaCertificate{} - conf := NewConfig(l) + conf := config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} _, err := NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code - conf = NewConfig(l) + conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha - conf = NewConfig(l) + conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided") // Test code/port error - conf = NewConfig(l) + conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") @@ -674,91 +666,91 @@ func TestNewFirewallFromConfig(t *testing.T) { assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error - conf = NewConfig(l) + conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error - conf = NewConfig(l) + conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") // Test both group and groups - conf = NewConfig(l) + conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } func TestAddFirewallRulesFromConfig(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() // Test adding tcp rule - conf := NewConfig(l) + conf := config.NewC(l) mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) // Test adding udp rule - conf = NewConfig(l) + conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) // Test adding icmp rule - conf = NewConfig(l) + conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) // Test adding any rule - conf = NewConfig(l) + conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) // Test adding rule with ca_sha - conf = NewConfig(l) + conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name - conf = NewConfig(l) + conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall) // Test single group - conf = NewConfig(l) + conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) // Test single groups - conf = NewConfig(l) + conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) // Test multiple AND groups - conf = NewConfig(l) + conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall) // Test Add error - conf = NewConfig(l) + conf = config.NewC(l) mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} @@ -857,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) { } func TestFirewall_convertRule(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() ob := &bytes.Buffer{} l.SetOutput(ob) @@ -929,6 +921,6 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end func resetConntrack(fw *Firewall) { fw.Conntrack.Lock() - fw.Conntrack.Conns = map[FirewallPacket]*conn{} + fw.Conntrack.Conns = map[firewall.Packet]*conn{} fw.Conntrack.Unlock() } diff --git a/handshake.go b/handshake.go index 8d8aef0..a08fb2e 100644 --- a/handshake.go +++ b/handshake.go @@ -1,11 +1,11 @@ package nebula -const ( - handshakeIXPSK0 = 0 - handshakeXXPSK0 = 1 +import ( + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" ) -func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) { +func HandleIncomingHandshake(f *Interface, addr *udp.Addr, packet []byte, h *header.H, hostinfo *HostInfo) { // First remote allow list check before we know the vpnIp if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) { f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") @@ -13,7 +13,7 @@ func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Head } switch h.Subtype { - case handshakeIXPSK0: + case header.HandshakeIXPSK0: switch h.MessageCounter { case 1: ixHandshakeStage1(f, addr, packet, h) diff --git a/handshake_ix.go b/handshake_ix.go index 46fd1ec..a0defc6 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -6,13 +6,16 @@ import ( "github.com/flynn/noise" "github.com/golang/protobuf/proto" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" ) // NOISE IX Handshakes // This function constructs a handshake packet, but does not actually send it // Sending is done by the handshake manager -func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { +func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { // This queries the lighthouse if we don't know a remote for the host // We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send // more quickly, effect is a quicker handshake. @@ -22,7 +25,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { err := f.handshakeManager.AddIndexHostInfo(hostinfo) if err != nil { - f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). + f.l.WithError(err).WithField("vpnIp", vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") return } @@ -43,17 +46,17 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { hsBytes, err = proto.Marshal(hs) if err != nil { - f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). + f.l.WithError(err).WithField("vpnIp", vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") return } - header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1) + h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) atomic.AddUint64(&ci.atomicMessageCounter, 1) - msg, _, _, err := ci.H.WriteMessage(header, hsBytes) + msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). + f.l.WithError(err).WithField("vpnIp", vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return } @@ -67,12 +70,12 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { hostinfo.handshakeStart = time.Now() } -func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { +func ixHandshakeStage1(f *Interface, addr *udp.Addr, packet []byte, h *header.H) { ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0) // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) - msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) + msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") @@ -97,13 +100,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { Info("Invalid certificate from host") return } - vpnIP := ip2int(remoteCert.Details.Ips[0].IP) + vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) certName := remoteCert.Details.Name fingerprint, _ := remoteCert.Sha256Sum() issuer := remoteCert.Details.Issuer - if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) { - f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + if vpnIp == f.myVpnIp { + f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -111,14 +114,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { return } - if !f.lightHouse.remoteAllowList.Allow(vpnIP, addr.IP) { - f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + if !f.lightHouse.remoteAllowList.Allow(vpnIp, addr.IP) { + f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } myIndex, err := generateIndex(f.l) if err != nil { - f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -130,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { ConnectionState: ci, localIndexId: myIndex, remoteIndexId: hs.Details.InitiatorIndex, - hostId: vpnIP, + vpnIp: vpnIp, HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, } @@ -138,7 +141,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { hostinfo.Lock() defer hostinfo.Unlock() - f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -153,7 +156,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { hsBytes, err := proto.Marshal(hs) if err != nil { - f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -161,17 +164,17 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { return } - header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2) - msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes) + nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) + msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return } else if dKey == nil || eKey == nil { - f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -179,8 +182,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { return } - hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:])) - copy(hostinfo.HandshakePacket[0], packet[HeaderLen:]) + hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:])) + copy(hostinfo.HandshakePacket[0], packet[header.Len:]) // Regardless of whether you are the sender or receiver, you should arrive here // and complete standing up the connection. @@ -195,12 +198,12 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { ci.dKey = NewNebulaCipherState(dKey) ci.eKey = NewNebulaCipherState(eKey) - hostinfo.remotes = f.lightHouse.QueryCache(vpnIP) + hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) hostinfo.SetRemote(addr) hostinfo.CreateRemoteCIDR(remoteCert) // Only overwrite existing record if we should win the handshake race - overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP) + overwrite := vpnIp > f.myVpnIp existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f) if err != nil { switch err { @@ -214,27 +217,27 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { if existing.SetRemoteIfPreferred(f.hostMap, addr) { // Send a test packet to ensure the other side has also switched to // the preferred remote - f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) } existing.Unlock() hostinfo.Lock() msg = existing.HandshakePacket[2] - f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) + f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) err := f.outside.WriteTo(msg, addr) if err != nil { - f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithError(err).Error("Failed to send handshake message") } else { - f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") } return case ErrExistingHostInfo: // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("newHandshakeTime", hostinfo.lastHandshakeTime). @@ -245,22 +248,22 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { Info("Handshake too old") // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)). + WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp). Error("Failed to add HostInfo due to localIndex collision") return case ErrExistingHandshake: // We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish - f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -271,7 +274,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here - f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -283,10 +286,10 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { } // Do the send - f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) + f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) err = f.outside.WriteTo(msg, addr) if err != nil { - f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -294,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithError(err).Error("Failed to send handshake") } else { - f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -309,7 +312,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { return } -func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { +func ixHandshakeStage2(f *Interface, addr *udp.Addr, hostinfo *HostInfo, packet []byte, h *header.H) bool { if hostinfo == nil { // Nothing here to tear down, got a bogus stage 2 packet return true @@ -318,14 +321,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ hostinfo.Lock() defer hostinfo.Unlock() - if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) { - f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) { + f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } ci := hostinfo.ConnectionState if ci.ready { - f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). Info("Handshake is already complete") @@ -333,16 +336,16 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) { // Send a test packet to ensure the other side has also switched to // the preferred remote - f.SendMessageToVpnIp(test, testRequest, hostinfo.hostId, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) } // We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets return false } - msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) + msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). Error("Failed to call noise.ReadMessage") @@ -351,7 +354,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ // near future return false } else if dKey == nil || eKey == nil { - f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Noise did not arrive at a key") @@ -363,7 +366,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ hs := &NebulaHandshake{} err = proto.Unmarshal(msg, hs) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again @@ -372,7 +375,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) if err != nil { - f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Invalid certificate from host") @@ -380,14 +383,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ return true } - vpnIP := ip2int(remoteCert.Details.Ips[0].IP) + vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) certName := remoteCert.Details.Name fingerprint, _ := remoteCert.Sha256Sum() issuer := remoteCert.Details.Issuer // Ensure the right host responded - if vpnIP != hostinfo.hostId { - f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)). + if vpnIp != hostinfo.vpnIp { + f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp). WithField("udpAddr", addr).WithField("certName", certName). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Info("Incorrect host responded to handshake") @@ -397,7 +400,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ // Create a new hostinfo/handshake for the intended vpn ip //TODO: this adds it to the timer wheel in a way that aggressively retries - newHostInfo := f.getOrHandshake(hostinfo.hostId) + newHostInfo := f.getOrHandshake(hostinfo.vpnIp) newHostInfo.Lock() // Block the current used address @@ -405,9 +408,9 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ newHostInfo.remotes.BlockRemote(addr) // Get the correct remote list for the host we did handshake with - hostinfo.remotes = f.lightHouse.QueryCache(vpnIP) + hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) - f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", IntIp(vpnIP)). + f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). Info("Blocked addresses for handshakes") @@ -418,7 +421,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ hostinfo.ConnectionState.queueLock.Unlock() // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.hostId = vpnIP + hostinfo.vpnIp = vpnIp f.sendCloseTunnel(hostinfo) newHostInfo.Unlock() @@ -429,7 +432,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ ci.window.Update(f.l, 2) duration := time.Since(hostinfo.handshakeStart).Nanoseconds() - f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). diff --git a/handshake_manager.go b/handshake_manager.go index d03cad4..7f50c5b 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -11,6 +11,9 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" ) const ( @@ -39,7 +42,7 @@ type HandshakeManager struct { pendingHostMap *HostMap mainHostMap *HostMap lightHouse *LightHouse - outside *udpConn + outside *udp.Conn config HandshakeConfig OutboundHandshakeTimer *SystemTimerWheel messageMetrics *MessageMetrics @@ -47,18 +50,18 @@ type HandshakeManager struct { metricTimedOut metrics.Counter l *logrus.Logger - // can be used to trigger outbound handshake for the given vpnIP - trigger chan uint32 + // can be used to trigger outbound handshake for the given vpnIp + trigger chan iputil.VpnIp } -func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager { +func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges), mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, config: config, - trigger: make(chan uint32, config.triggerBuffer), + trigger: make(chan iputil.VpnIp, config.triggerBuffer), OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)), messageMetrics: config.messageMetrics, metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), @@ -67,7 +70,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [ } } -func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) { +func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) { clockSource := time.NewTicker(c.config.tryInterval) defer clockSource.Stop() @@ -76,7 +79,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) { case <-ctx.Done(): return case vpnIP := <-c.trigger: - c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered") + c.l.WithField("vpnIp", vpnIP).Debug("HandshakeManager: triggered") c.handleOutbound(vpnIP, f, true) case now := <-clockSource.C: c.NextOutboundHandshakeTimerTick(now, f) @@ -84,20 +87,20 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) { } } -func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) { +func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) { c.OutboundHandshakeTimer.advance(now) for { ep := c.OutboundHandshakeTimer.Purge() if ep == nil { break } - vpnIP := ep.(uint32) - c.handleOutbound(vpnIP, f, false) + vpnIp := ep.(iputil.VpnIp) + c.handleOutbound(vpnIp, f, false) } } -func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) { - hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP) +func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) { + hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp) if err != nil { return } @@ -115,7 +118,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT if !hostinfo.HandshakeReady { // There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly // Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState - c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) + c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) return } @@ -143,21 +146,21 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT // Get a remotes object if we don't already have one. // This is mainly to protect us as this should never be the case if hostinfo.remotes == nil { - hostinfo.remotes = c.lightHouse.QueryCache(vpnIP) + hostinfo.remotes = c.lightHouse.QueryCache(vpnIp) } //TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped) if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 { // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse - // Our vpnIP here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about + // Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about // the learned public ip for them. Query again to short circuit the promotion counter - c.lightHouse.QueryServer(vpnIP, f) + c.lightHouse.QueryServer(vpnIp, f) } // Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply - var sentTo []*udpAddr - hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udpAddr, _ bool) { - c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1) + var sentTo []*udp.Addr + hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { + c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { hostinfo.logger(c.l).WithField("udpAddr", addr). @@ -184,16 +187,16 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { //TODO: feel like we dupe handshake real fast in a tight loop, why? - c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) + c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) } } -func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo { - hostinfo := c.pendingHostMap.AddVpnIP(vpnIP) +func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo { + hostinfo := c.pendingHostMap.AddVpnIp(vpnIp) // We lock here and use an array to insert items to prevent locking the // main receive thread for very long by waiting to add items to the pending map //TODO: what lock? - c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval) + c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) c.metricInitiated.Inc(1) return hostinfo @@ -208,12 +211,12 @@ var ( // CheckAndComplete checks for any conflicts in the main and pending hostmap // before adding hostinfo to main. If err is nil, it was added. Otherwise err will be: - +// // ErrAlreadySeen if we already have an entry in the hostmap that has seen the // exact same handshake packet // // ErrExistingHostInfo if we already have an entry in the hostmap for this -// VpnIP and the new handshake was older than the one we currently have +// VpnIp and the new handshake was older than the one we currently have // // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. @@ -224,7 +227,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket defer c.mainHostMap.Unlock() // Check if we already have a tunnel with this vpn ip - existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId] + existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] if found && existingHostInfo != nil { // Is it just a delayed handshake packet? if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) { @@ -252,16 +255,16 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket } existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] - if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId { + if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. hostinfo.logger(c.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). Info("New host shadows existing host remoteIndex") } // Check if we are also handshaking with this vpn ip - pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.hostId] + pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.vpnIp] if found && pendingHostInfo != nil { if !overwrite { // We won, let our pending handshake win @@ -278,7 +281,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket if existingHostInfo != nil { // We are going to overwrite this entry, so remove the old references - delete(c.mainHostMap.Hosts, existingHostInfo.hostId) + delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp) delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId) } @@ -296,10 +299,10 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { c.mainHostMap.Lock() defer c.mainHostMap.Unlock() - existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId] + existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] if found && existingHostInfo != nil { // We are going to overwrite this entry, so remove the old references - delete(c.mainHostMap.Hosts, existingHostInfo.hostId) + delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp) delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId) } @@ -309,7 +312,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. hostinfo.logger(c.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). Info("New host shadows existing host remoteIndex") } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index c34b0cf..b669050 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -5,25 +5,29 @@ import ( "testing" "time" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" + "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) -func Test_NewHandshakeManagerVpnIP(t *testing.T) { - l := NewTestLogger() +func Test_NewHandshakeManagerVpnIp(t *testing.T) { + l := util.NewTestLogger() _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - ip := ip2int(net.ParseIP("172.1.1.2")) + ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) + blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udp.Conn{}, defaultHandshakeConfig) now := time.Now() blah.NextOutboundHandshakeTimerTick(now, mw) - i := blah.AddVpnIP(ip) + i := blah.AddVpnIp(ip) i.remotes = NewRemoteList() i.HandshakeReady = true @@ -50,24 +54,24 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) { } func Test_NewHandshakeManagerTrigger(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - ip := ip2int(net.ParseIP("172.1.1.2")) + ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - lh := &LightHouse{addrMap: make(map[uint32]*RemoteList), l: l} + lh := &LightHouse{addrMap: make(map[iputil.VpnIp]*RemoteList), l: l} - blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig) + blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig) now := time.Now() blah.NextOutboundHandshakeTimerTick(now, mw) assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) - hi := blah.AddVpnIP(ip) + hi := blah.AddVpnIp(ip) hi.HandshakeReady = true assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet") @@ -80,7 +84,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) { // Make sure the trigger doesn't double schedule the timer entry assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) - uaddr := NewUDPAddrFromString("10.1.1.1:4242") + uaddr := udp.NewAddrFromString("10.1.1.1:4242") hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port))) // We now have remotes but only the first trigger should have pushed things forward @@ -103,6 +107,6 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) { type mockEncWriter struct { } -func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { return } diff --git a/header.go b/header/header.go similarity index 54% rename from header.go rename to header/header.go index 3f15faa..3ba6d8c 100644 --- a/header.go +++ b/header/header.go @@ -1,4 +1,4 @@ -package nebula +package header import ( "encoding/binary" @@ -19,82 +19,78 @@ import ( // |-----------------------------------------------------------------------| // | payload... | -const ( - Version uint8 = 1 - HeaderLen = 16 -) - -type NebulaMessageType uint8 -type NebulaMessageSubType uint8 +type m map[string]interface{} const ( - handshake NebulaMessageType = 0 - message NebulaMessageType = 1 - recvError NebulaMessageType = 2 - lightHouse NebulaMessageType = 3 - test NebulaMessageType = 4 - closeTunnel NebulaMessageType = 5 - - //TODO These are deprecated as of 06/12/2018 - NB - testRemote NebulaMessageType = 6 - testRemoteReply NebulaMessageType = 7 + Version uint8 = 1 + Len = 16 ) -var typeMap = map[NebulaMessageType]string{ - handshake: "handshake", - message: "message", - recvError: "recvError", - lightHouse: "lightHouse", - test: "test", - closeTunnel: "closeTunnel", +type MessageType uint8 +type MessageSubType uint8 - //TODO These are deprecated as of 06/12/2018 - NB - testRemote: "testRemote", - testRemoteReply: "testRemoteReply", +const ( + Handshake MessageType = 0 + Message MessageType = 1 + RecvError MessageType = 2 + LightHouse MessageType = 3 + Test MessageType = 4 + CloseTunnel MessageType = 5 +) + +var typeMap = map[MessageType]string{ + Handshake: "handshake", + Message: "message", + RecvError: "recvError", + LightHouse: "lightHouse", + Test: "test", + CloseTunnel: "closeTunnel", } const ( - testRequest NebulaMessageSubType = 0 - testReply NebulaMessageSubType = 1 + TestRequest MessageSubType = 0 + TestReply MessageSubType = 1 ) -var eHeaderTooShort = errors.New("header is too short") +const ( + HandshakeIXPSK0 MessageSubType = 0 + HandshakeXXPSK0 MessageSubType = 1 +) -var subTypeTestMap = map[NebulaMessageSubType]string{ - testRequest: "testRequest", - testReply: "testReply", +var ErrHeaderTooShort = errors.New("header is too short") + +var subTypeTestMap = map[MessageSubType]string{ + TestRequest: "testRequest", + TestReply: "testReply", } -var subTypeNoneMap = map[NebulaMessageSubType]string{0: "none"} +var subTypeNoneMap = map[MessageSubType]string{0: "none"} -var subTypeMap = map[NebulaMessageType]*map[NebulaMessageSubType]string{ - message: &subTypeNoneMap, - recvError: &subTypeNoneMap, - lightHouse: &subTypeNoneMap, - test: &subTypeTestMap, - closeTunnel: &subTypeNoneMap, - handshake: { - handshakeIXPSK0: "ix_psk0", +var subTypeMap = map[MessageType]*map[MessageSubType]string{ + Message: &subTypeNoneMap, + RecvError: &subTypeNoneMap, + LightHouse: &subTypeNoneMap, + Test: &subTypeTestMap, + CloseTunnel: &subTypeNoneMap, + Handshake: { + HandshakeIXPSK0: "ix_psk0", }, - //TODO: these are deprecated - testRemote: &subTypeNoneMap, - testRemoteReply: &subTypeNoneMap, } -type Header struct { +type H struct { Version uint8 - Type NebulaMessageType - Subtype NebulaMessageSubType + Type MessageType + Subtype MessageSubType Reserved uint16 RemoteIndex uint32 MessageCounter uint64 } -// HeaderEncode uses the provided byte array to encode the provided header values into. +// Encode uses the provided byte array to encode the provided header values into. // Byte array must be capped higher than HeaderLen or this will panic -func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []byte { - b = b[:HeaderLen] - b[0] = byte(v<<4 | (t & 0x0f)) +func Encode(b []byte, v uint8, t MessageType, st MessageSubType, ri uint32, c uint64) []byte { + b = b[:Len] + b[0] = v<<4 | byte(t&0x0f) b[1] = byte(st) binary.BigEndian.PutUint16(b[2:4], 0) binary.BigEndian.PutUint32(b[4:8], ri) @@ -103,7 +99,7 @@ func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []b } // String creates a readable string representation of a header -func (h *Header) String() string { +func (h *H) String() string { if h == nil { return "" } @@ -112,7 +108,7 @@ func (h *Header) String() string { } // MarshalJSON creates a json string representation of a header -func (h *Header) MarshalJSON() ([]byte, error) { +func (h *H) MarshalJSON() ([]byte, error) { return json.Marshal(m{ "version": h.Version, "type": h.TypeName(), @@ -124,24 +120,24 @@ func (h *Header) MarshalJSON() ([]byte, error) { } // Encode turns header into bytes -func (h *Header) Encode(b []byte) ([]byte, error) { +func (h *H) Encode(b []byte) ([]byte, error) { if h == nil { return nil, errors.New("nil header") } - return HeaderEncode(b, h.Version, uint8(h.Type), uint8(h.Subtype), h.RemoteIndex, h.MessageCounter), nil + return Encode(b, h.Version, h.Type, h.Subtype, h.RemoteIndex, h.MessageCounter), nil } // Parse is a helper function to parses given bytes into new Header struct -func (h *Header) Parse(b []byte) error { - if len(b) < HeaderLen { - return eHeaderTooShort +func (h *H) Parse(b []byte) error { + if len(b) < Len { + return ErrHeaderTooShort } // get upper 4 bytes h.Version = uint8((b[0] >> 4) & 0x0f) // get lower 4 bytes - h.Type = NebulaMessageType(b[0] & 0x0f) - h.Subtype = NebulaMessageSubType(b[1]) + h.Type = MessageType(b[0] & 0x0f) + h.Subtype = MessageSubType(b[1]) h.Reserved = binary.BigEndian.Uint16(b[2:4]) h.RemoteIndex = binary.BigEndian.Uint32(b[4:8]) h.MessageCounter = binary.BigEndian.Uint64(b[8:16]) @@ -149,12 +145,12 @@ func (h *Header) Parse(b []byte) error { } // TypeName will transform the headers message type into a human string -func (h *Header) TypeName() string { +func (h *H) TypeName() string { return TypeName(h.Type) } // TypeName will transform a nebula message type into a human string -func TypeName(t NebulaMessageType) string { +func TypeName(t MessageType) string { if n, ok := typeMap[t]; ok { return n } @@ -163,12 +159,12 @@ func TypeName(t NebulaMessageType) string { } // SubTypeName will transform the headers message sub type into a human string -func (h *Header) SubTypeName() string { +func (h *H) SubTypeName() string { return SubTypeName(h.Type, h.Subtype) } // SubTypeName will transform a nebula message sub type into a human string -func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string { +func SubTypeName(t MessageType, s MessageSubType) string { if n, ok := subTypeMap[t]; ok { if x, ok := (*n)[s]; ok { return x @@ -179,8 +175,8 @@ func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string { } // NewHeader turns bytes into a header -func NewHeader(b []byte) (*Header, error) { - h := new(Header) +func NewHeader(b []byte) (*H, error) { + h := new(H) if err := h.Parse(b); err != nil { return nil, err } diff --git a/header/header_test.go b/header/header_test.go new file mode 100644 index 0000000..710e9c0 --- /dev/null +++ b/header/header_test.go @@ -0,0 +1,115 @@ +package header + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type headerTest struct { + expectedBytes []byte + *H +} + +// 0001 0010 00010010 +var headerBigEndianTests = []headerTest{{ + expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9}, + // 1010 0000 + H: &H{ + // 1111 1+2+4+8 = 15 + Version: 5, + Type: 4, + Subtype: 0, + Reserved: 0, + RemoteIndex: 10, + MessageCounter: 9, + }, +}, +} + +func TestEncode(t *testing.T) { + for _, tt := range headerBigEndianTests { + b, err := tt.Encode(make([]byte, Len)) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, tt.expectedBytes, b) + } +} + +func TestParse(t *testing.T) { + for _, tt := range headerBigEndianTests { + b := tt.expectedBytes + parsedHeader := &H{} + parsedHeader.Parse(b) + + if !reflect.DeepEqual(tt.H, parsedHeader) { + t.Fatalf("got %#v; want %#v", parsedHeader, tt.H) + } + } +} + +func TestTypeName(t *testing.T) { + assert.Equal(t, "test", TypeName(Test)) + assert.Equal(t, "test", (&H{Type: Test}).TypeName()) + + assert.Equal(t, "unknown", TypeName(99)) + assert.Equal(t, "unknown", (&H{Type: 99}).TypeName()) +} + +func TestSubTypeName(t *testing.T) { + assert.Equal(t, "testRequest", SubTypeName(Test, TestRequest)) + assert.Equal(t, "testRequest", (&H{Type: Test, Subtype: TestRequest}).SubTypeName()) + + assert.Equal(t, "unknown", SubTypeName(99, TestRequest)) + assert.Equal(t, "unknown", (&H{Type: 99, Subtype: TestRequest}).SubTypeName()) + + assert.Equal(t, "unknown", SubTypeName(Test, 99)) + assert.Equal(t, "unknown", (&H{Type: Test, Subtype: 99}).SubTypeName()) + + assert.Equal(t, "none", SubTypeName(Message, 0)) + assert.Equal(t, "none", (&H{Type: Message, Subtype: 0}).SubTypeName()) +} + +func TestTypeMap(t *testing.T) { + // Force people to document this stuff + assert.Equal(t, map[MessageType]string{ + Handshake: "handshake", + Message: "message", + RecvError: "recvError", + LightHouse: "lightHouse", + Test: "test", + CloseTunnel: "closeTunnel", + }, typeMap) + + assert.Equal(t, map[MessageType]*map[MessageSubType]string{ + Message: &subTypeNoneMap, + RecvError: &subTypeNoneMap, + LightHouse: &subTypeNoneMap, + Test: &subTypeTestMap, + CloseTunnel: &subTypeNoneMap, + Handshake: { + HandshakeIXPSK0: "ix_psk0", + }, + }, subTypeMap) +} + +func TestHeader_String(t *testing.T) { + assert.Equal( + t, + "ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97", + (&H{100, Test, TestRequest, 99, 98, 97}).String(), + ) +} + +func TestHeader_MarshalJSON(t *testing.T) { + b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON() + assert.Nil(t, err) + assert.Equal( + t, + "{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}", + string(b), + ) +} diff --git a/header_test.go b/header_test.go deleted file mode 100644 index b2df090..0000000 --- a/header_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package nebula - -import ( - "reflect" - "testing" - - "github.com/stretchr/testify/assert" -) - -type headerTest struct { - expectedBytes []byte - *Header -} - -// 0001 0010 00010010 -var headerBigEndianTests = []headerTest{{ - expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9}, - // 1010 0000 - Header: &Header{ - // 1111 1+2+4+8 = 15 - Version: 5, - Type: 4, - Subtype: 0, - Reserved: 0, - RemoteIndex: 10, - MessageCounter: 9, - }, -}, -} - -func TestEncode(t *testing.T) { - for _, tt := range headerBigEndianTests { - b, err := tt.Encode(make([]byte, HeaderLen)) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, tt.expectedBytes, b) - } -} - -func TestParse(t *testing.T) { - for _, tt := range headerBigEndianTests { - b := tt.expectedBytes - parsedHeader := &Header{} - parsedHeader.Parse(b) - - if !reflect.DeepEqual(tt.Header, parsedHeader) { - t.Fatalf("got %#v; want %#v", parsedHeader, tt.Header) - } - } -} - -func TestTypeName(t *testing.T) { - assert.Equal(t, "test", TypeName(test)) - assert.Equal(t, "test", (&Header{Type: test}).TypeName()) - - assert.Equal(t, "unknown", TypeName(99)) - assert.Equal(t, "unknown", (&Header{Type: 99}).TypeName()) -} - -func TestSubTypeName(t *testing.T) { - assert.Equal(t, "testRequest", SubTypeName(test, testRequest)) - assert.Equal(t, "testRequest", (&Header{Type: test, Subtype: testRequest}).SubTypeName()) - - assert.Equal(t, "unknown", SubTypeName(99, testRequest)) - assert.Equal(t, "unknown", (&Header{Type: 99, Subtype: testRequest}).SubTypeName()) - - assert.Equal(t, "unknown", SubTypeName(test, 99)) - assert.Equal(t, "unknown", (&Header{Type: test, Subtype: 99}).SubTypeName()) - - assert.Equal(t, "none", SubTypeName(message, 0)) - assert.Equal(t, "none", (&Header{Type: message, Subtype: 0}).SubTypeName()) -} - -func TestTypeMap(t *testing.T) { - // Force people to document this stuff - assert.Equal(t, map[NebulaMessageType]string{ - handshake: "handshake", - message: "message", - recvError: "recvError", - lightHouse: "lightHouse", - test: "test", - closeTunnel: "closeTunnel", - testRemote: "testRemote", - testRemoteReply: "testRemoteReply", - }, typeMap) - - assert.Equal(t, map[NebulaMessageType]*map[NebulaMessageSubType]string{ - message: &subTypeNoneMap, - recvError: &subTypeNoneMap, - lightHouse: &subTypeNoneMap, - test: &subTypeTestMap, - closeTunnel: &subTypeNoneMap, - handshake: { - handshakeIXPSK0: "ix_psk0", - }, - testRemote: &subTypeNoneMap, - testRemoteReply: &subTypeNoneMap, - }, subTypeMap) -} - -func TestHeader_String(t *testing.T) { - assert.Equal( - t, - "ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97", - (&Header{100, test, testRequest, 99, 98, 97}).String(), - ) -} - -func TestHeader_MarshalJSON(t *testing.T) { - b, err := (&Header{100, test, testRequest, 99, 98, 97}).MarshalJSON() - assert.Nil(t, err) - assert.Equal( - t, - "{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}", - string(b), - ) -} diff --git a/hostmap.go b/hostmap.go index 2f46d83..6545307 100644 --- a/hostmap.go +++ b/hostmap.go @@ -12,6 +12,10 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" ) //const ProbeLen = 100 @@ -28,10 +32,10 @@ type HostMap struct { name string Indexes map[uint32]*HostInfo RemoteIndexes map[uint32]*HostInfo - Hosts map[uint32]*HostInfo + Hosts map[iputil.VpnIp]*HostInfo preferredRanges []*net.IPNet vpnCIDR *net.IPNet - unsafeRoutes *CIDRTree + unsafeRoutes *cidr.Tree4 metricsEnabled bool l *logrus.Logger } @@ -39,7 +43,7 @@ type HostMap struct { type HostInfo struct { sync.RWMutex - remote *udpAddr + remote *udp.Addr remotes *RemoteList promoteCounter uint32 ConnectionState *ConnectionState @@ -51,9 +55,9 @@ type HostInfo struct { packetStore []*cachedPacket //todo: this is other handshake manager entry remoteIndexId uint32 localIndexId uint32 - hostId uint32 + vpnIp iputil.VpnIp recvError int - remoteCidr *CIDRTree + remoteCidr *cidr.Tree4 // lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH // for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like @@ -66,17 +70,17 @@ type HostInfo struct { lastHandshakeTime uint64 lastRoam time.Time - lastRoamRemote *udpAddr + lastRoamRemote *udp.Addr } type cachedPacket struct { - messageType NebulaMessageType - messageSubType NebulaMessageSubType + messageType header.MessageType + messageSubType header.MessageSubType callback packetCallback packet []byte } -type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte) +type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte) type cachedPacketMetrics struct { sent metrics.Counter @@ -84,7 +88,7 @@ type cachedPacketMetrics struct { } func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { - h := map[uint32]*HostInfo{} + h := map[iputil.VpnIp]*HostInfo{} i := map[uint32]*HostInfo{} r := map[uint32]*HostInfo{} m := HostMap{ @@ -94,7 +98,7 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang Hosts: h, preferredRanges: preferredRanges, vpnCIDR: vpnCIDR, - unsafeRoutes: NewCIDRTree(), + unsafeRoutes: cidr.NewTree4(), l: l, } return &m @@ -113,9 +117,9 @@ func (hm *HostMap) EmitStats(name string) { metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen)) } -func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) { +func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) { hm.RLock() - if i, ok := hm.Hosts[vpnIP]; ok { + if i, ok := hm.Hosts[vpnIp]; ok { index := i.localIndexId hm.RUnlock() return index, nil @@ -124,43 +128,43 @@ func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) { return 0, errors.New("vpn IP not found") } -func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) { +func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) { hm.Lock() hm.Hosts[ip] = hostinfo hm.Unlock() } -func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo { +func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo { h := &HostInfo{} hm.RLock() - if _, ok := hm.Hosts[vpnIP]; !ok { + if _, ok := hm.Hosts[vpnIp]; !ok { hm.RUnlock() h = &HostInfo{ promoteCounter: 0, - hostId: vpnIP, + vpnIp: vpnIp, HandshakePacket: make(map[uint8][]byte, 0), } hm.Lock() - hm.Hosts[vpnIP] = h + hm.Hosts[vpnIp] = h hm.Unlock() return h } else { - h = hm.Hosts[vpnIP] + h = hm.Hosts[vpnIp] hm.RUnlock() return h } } -func (hm *HostMap) DeleteVpnIP(vpnIP uint32) { +func (hm *HostMap) DeleteVpnIp(vpnIp iputil.VpnIp) { hm.Lock() - delete(hm.Hosts, vpnIP) + delete(hm.Hosts, vpnIp) if len(hm.Hosts) == 0 { - hm.Hosts = map[uint32]*HostInfo{} + hm.Hosts = map[iputil.VpnIp]*HostInfo{} } hm.Unlock() if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}). + hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts)}). Debug("Hostmap vpnIp deleted") } } @@ -174,22 +178,22 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) { if hm.l.Level > logrus.DebugLevel { hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), - "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). + "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}). Debug("Hostmap remoteIndex added") } } -func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) { +func (hm *HostMap) AddVpnIpHostInfo(vpnIp iputil.VpnIp, h *HostInfo) { hm.Lock() - h.hostId = vpnIP - hm.Hosts[vpnIP] = h + h.vpnIp = vpnIp + hm.Hosts[vpnIp] = h hm.Indexes[h.localIndexId] = h hm.RemoteIndexes[h.remoteIndexId] = h hm.Unlock() if hm.l.Level > logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts), - "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). + hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "vpnIp": h.vpnIp}}). Debug("Hostmap vpnIp added") } } @@ -204,9 +208,9 @@ func (hm *HostMap) DeleteIndex(index uint32) { // Check if we have an entry under hostId that matches the same hostinfo // instance. Clean it up as well if we do. - hostinfo2, ok := hm.Hosts[hostinfo.hostId] + hostinfo2, ok := hm.Hosts[hostinfo.vpnIp] if ok && hostinfo2 == hostinfo { - delete(hm.Hosts, hostinfo.hostId) + delete(hm.Hosts, hostinfo.vpnIp) } } hm.Unlock() @@ -228,9 +232,9 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) { // Check if we have an entry under hostId that matches the same hostinfo // instance. Clean it up as well if we do (they might not match in pendingHostmap) var hostinfo2 *HostInfo - hostinfo2, ok = hm.Hosts[hostinfo.hostId] + hostinfo2, ok = hm.Hosts[hostinfo.vpnIp] if ok && hostinfo2 == hostinfo { - delete(hm.Hosts, hostinfo.hostId) + delete(hm.Hosts, hostinfo.vpnIp) } } hm.Unlock() @@ -251,16 +255,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { // Check if this same hostId is in the hostmap with a different instance. // This could happen if we have an entry in the pending hostmap with different // index values than the one in the main hostmap. - hostinfo2, ok := hm.Hosts[hostinfo.hostId] + hostinfo2, ok := hm.Hosts[hostinfo.vpnIp] if ok && hostinfo2 != hostinfo { - delete(hm.Hosts, hostinfo2.hostId) + delete(hm.Hosts, hostinfo2.vpnIp) delete(hm.Indexes, hostinfo2.localIndexId) delete(hm.RemoteIndexes, hostinfo2.remoteIndexId) } - delete(hm.Hosts, hostinfo.hostId) + delete(hm.Hosts, hostinfo.vpnIp) if len(hm.Hosts) == 0 { - hm.Hosts = map[uint32]*HostInfo{} + hm.Hosts = map[iputil.VpnIp]*HostInfo{} } delete(hm.Indexes, hostinfo.localIndexId) if len(hm.Indexes) == 0 { @@ -273,7 +277,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { if hm.l.Level >= logrus.DebugLevel { hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), - "vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Hostmap hostInfo deleted") } } @@ -301,17 +305,17 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) { } } -func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) { - return hm.queryVpnIP(vpnIp, nil) +func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) { + return hm.queryVpnIp(vpnIp, nil) } -// PromoteBestQueryVpnIP will attempt to lazily switch to the best remote every +// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every // `PromoteEvery` calls to this function for a given host. -func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostInfo, error) { - return hm.queryVpnIP(vpnIp, ifce) +func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) { + return hm.queryVpnIp(vpnIp, ifce) } -func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) { +func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() @@ -327,10 +331,10 @@ func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, return nil, errors.New("unable to find host") } -func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 { +func (hm *HostMap) queryUnsafeRoute(ip iputil.VpnIp) iputil.VpnIp { r := hm.unsafeRoutes.MostSpecificContains(ip) if r != nil { - return r.(uint32) + return r.(iputil.VpnIp) } else { return 0 } @@ -344,13 +348,13 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) { dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) } - hm.Hosts[hostinfo.hostId] = hostinfo + hm.Hosts[hostinfo.vpnIp] = hostinfo hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts), - "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}). + hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). Debug("Hostmap vpnIp added") } } @@ -370,7 +374,7 @@ func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList { } // Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them -func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) { +func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) { var metricsTxPunchy metrics.Counter if hm.metricsEnabled { metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil) @@ -406,7 +410,7 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) { func (hm *HostMap) addUnsafeRoutes(routes *[]route) { for _, r := range *routes { hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route") - hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via)) + hm.unsafeRoutes.AddCIDR(r.route, iputil.Ip2VpnIp(*r.via)) } } @@ -431,24 +435,24 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } } - i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) { + i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) { if addr == nil || !preferred { return } // Try to send a test packet to that host, this should // cause it to detect a roaming event and switch remotes - ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + ifce.send(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) }) } // Re query our lighthouses for new remotes occasionally if c%ReQueryEvery == 0 && ifce.lightHouse != nil { - ifce.lightHouse.QueryServer(i.hostId, ifce) + ifce.lightHouse.QueryServer(i.vpnIp, ifce) } } -func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { +func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { //TODO: return the error so we can log with more context if len(i.packetStore) < 100 { tempPacket := make([]byte, len(packet)) @@ -510,17 +514,17 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate { return nil } -func (i *HostInfo) SetRemote(remote *udpAddr) { +func (i *HostInfo) SetRemote(remote *udp.Addr) { // We copy here because we likely got this remote from a source that reuses the object if !i.remote.Equals(remote) { i.remote = remote.Copy() - i.remotes.LearnRemote(i.hostId, remote.Copy()) + i.remotes.LearnRemote(i.vpnIp, remote.Copy()) } } // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // time on the HostInfo will also be updated. -func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udpAddr) bool { +func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { currentRemote := i.remote if currentRemote == nil { i.SetRemote(newRemote) @@ -572,7 +576,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { return } - remoteCidr := NewCIDRTree() + remoteCidr := cidr.NewTree4() for _, ip := range c.Details.Ips { remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) } @@ -588,8 +592,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { return logrus.NewEntry(l) } - li := l.WithField("vpnIp", IntIp(i.hostId)) - + li := l.WithField("vpnIp", i.vpnIp) if connState := i.ConnectionState; connState != nil { if peerCert := connState.peerCert; peerCert != nil { li = li.WithField("certName", peerCert.Details.Name) @@ -599,38 +602,6 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { return li } -//######################## - -/* - -func (hm *HostMap) DebugRemotes(vpnIp uint32) string { - s := "\n" - for _, h := range hm.Hosts { - for _, r := range h.Remotes { - s += fmt.Sprintf("%s : %d ## %v\n", r.addr.IP.String(), r.addr.Port, r.probes) - } - } - return s -} - -func (i *HostInfo) HandleReply(addr *net.UDPAddr, counter int) { - for _, r := range i.Remotes { - if r.addr.IP.Equal(addr.IP) && r.addr.Port == addr.Port { - r.ProbeReceived(counter) - } - } -} - -func (i *HostInfo) Probes() []*Probe { - p := []*Probe{} - for _, d := range i.Remotes { - p = append(p, &Probe{Addr: d.addr, Counter: d.Probe()}) - } - return p -} - -*/ - // Utility functions func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { diff --git a/inside.go b/inside.go index 3695028..8a7c990 100644 --- a/inside.go +++ b/inside.go @@ -5,9 +5,13 @@ import ( "github.com/flynn/noise" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) { +func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) @@ -32,7 +36,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, hostinfo := f.getOrHandshake(fwPacket.RemoteIP) if hostinfo == nil { if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)). + f.l.WithField("vpnIp", fwPacket.RemoteIP). WithField("fwPacket", fwPacket). Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes") } @@ -45,7 +49,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, // the packet queue. ci.queueLock.Lock() if !ci.ready { - hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + hostinfo.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) ci.queueLock.Unlock() return } @@ -54,7 +58,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache) if dropReason == nil { - f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q) + f.sendNoMetrics(header.Message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q) } else if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l). @@ -65,20 +69,21 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, } // getOrHandshake returns nil if the vpnIp is not routable -func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo { - if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false { +func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { + //TODO: we can find contains without converting back to bytes + if f.hostMap.vpnCIDR.Contains(vpnIp.ToIP()) == false { vpnIp = f.hostMap.queryUnsafeRoute(vpnIp) if vpnIp == 0 { return nil } } - hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f) + hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f) //if err != nil || hostinfo.ConnectionState == nil { if err != nil { - hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIp) + hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp) if err != nil { - hostinfo = f.handshakeManager.AddVpnIP(vpnIp) + hostinfo = f.handshakeManager.AddVpnIp(vpnIp) } } ci := hostinfo.ConnectionState @@ -126,8 +131,8 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo { return hostinfo } -func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) { - fp := &FirewallPacket{} +func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) { + fp := &firewall.Packet{} err := newPacket(p, false, fp) if err != nil { f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err) @@ -145,15 +150,15 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, return } - f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0) + f.sendNoMetrics(header.Message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0) } // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp -func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { +func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { hostInfo := f.getOrHandshake(vpnIp) if hostInfo == nil { if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", IntIp(vpnIp)). + f.l.WithField("vpnIp", vpnIp). Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes") } return @@ -175,16 +180,16 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT return } -func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) { +func (f *Interface) sendMessageToVpnIp(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) { f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out) } -func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) { +func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) } -func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) { +func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) { if ci.eKey == nil { //TODO: log warning return @@ -196,18 +201,18 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, c := atomic.AddUint64(&ci.atomicMessageCounter, 1) //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) - out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c) - f.connectionManager.Out(hostinfo.hostId) + out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo.vpnIp) // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against // all our IPs and enable a faster roaming. - if t != closeTunnel && hostinfo.lastRebindCount != f.rebindCount { + if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. - f.lightHouse.QueryServer(hostinfo.hostId, f) + f.lightHouse.QueryServer(hostinfo.vpnIp, f) hostinfo.lastRebindCount = f.rebindCount if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter") + f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter") } } @@ -230,7 +235,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, return } -func isMulticast(ip uint32) bool { +func isMulticast(ip iputil.VpnIp) bool { // Class D multicast if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 { return true diff --git a/interface.go b/interface.go index fc5642a..c95a354 100644 --- a/interface.go +++ b/interface.go @@ -12,6 +12,10 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" ) const mtu = 9001 @@ -27,7 +31,7 @@ type Inside interface { type InterfaceConfig struct { HostMap *HostMap - Outside *udpConn + Outside *udp.Conn Inside Inside certState *CertState Cipher string @@ -39,7 +43,6 @@ type InterfaceConfig struct { pendingDeletionInterval int DropLocalBroadcast bool DropMulticast bool - UDPBatchSize int routines int MessageMetrics *MessageMetrics version string @@ -52,7 +55,7 @@ type InterfaceConfig struct { type Interface struct { hostMap *HostMap - outside *udpConn + outside *udp.Conn inside Inside certState *CertState cipher string @@ -62,11 +65,10 @@ type Interface struct { serveDns bool createTime time.Time lightHouse *LightHouse - localBroadcast uint32 - myVpnIp uint32 + localBroadcast iputil.VpnIp + myVpnIp iputil.VpnIp dropLocalBroadcast bool dropMulticast bool - udpBatchSize int routines int caPool *cert.NebulaCAPool disconnectInvalid bool @@ -77,7 +79,7 @@ type Interface struct { conntrackCacheTimeout time.Duration - writers []*udpConn + writers []*udp.Conn readers []io.ReadWriteCloser metricHandshakes metrics.Histogram @@ -101,6 +103,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { return nil, errors.New("no firewall rules") } + myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP) ifce := &Interface{ hostMap: c.HostMap, outside: c.Outside, @@ -112,17 +115,16 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, - localBroadcast: ip2int(c.certState.certificate.Details.Ips[0].IP) | ^ip2int(c.certState.certificate.Details.Ips[0].Mask), + localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask), dropLocalBroadcast: c.DropLocalBroadcast, dropMulticast: c.DropMulticast, - udpBatchSize: c.UDPBatchSize, routines: c.routines, version: c.version, - writers: make([]*udpConn, c.routines), + writers: make([]*udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), caPool: c.caPool, disconnectInvalid: c.disconnectInvalid, - myVpnIp: ip2int(c.certState.certificate.Details.Ips[0].IP), + myVpnIp: myVpnIp, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -190,14 +192,17 @@ func (f *Interface) run() { func (f *Interface) listenOut(i int) { runtime.LockOSThread() - var li *udpConn + var li *udp.Conn // TODO clean this up with a coherent interface for each outside connection if i > 0 { li = f.writers[i] } else { li = f.outside } - li.ListenOut(f, i) + + lhh := f.lightHouse.NewRequestHandler() + conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { @@ -205,10 +210,10 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { packet := make([]byte, mtu) out := make([]byte, mtu) - fwPacket := &FirewallPacket{} + fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout) + conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) for { n, err := reader.Read(packet) @@ -222,16 +227,16 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { } } -func (f *Interface) RegisterConfigChangeCallbacks(c *Config) { +func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { c.RegisterReloadCallback(f.reloadCA) c.RegisterReloadCallback(f.reloadCertKey) c.RegisterReloadCallback(f.reloadFirewall) for _, udpConn := range f.writers { - c.RegisterReloadCallback(udpConn.reloadConfig) + c.RegisterReloadCallback(udpConn.ReloadConfig) } } -func (f *Interface) reloadCA(c *Config) { +func (f *Interface) reloadCA(c *config.C) { // reload and check regardless // todo: need mutex? newCAs, err := loadCAFromConfig(f.l, c) @@ -244,7 +249,7 @@ func (f *Interface) reloadCA(c *Config) { f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed") } -func (f *Interface) reloadCertKey(c *Config) { +func (f *Interface) reloadCertKey(c *config.C) { // reload and check in all cases cs, err := NewCertStateFromConfig(c) if err != nil { @@ -264,7 +269,7 @@ func (f *Interface) reloadCertKey(c *Config) { f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk") } -func (f *Interface) reloadFirewall(c *Config) { +func (f *Interface) reloadFirewall(c *config.C) { //TODO: need to trigger/detect if the certificate changed too if c.HasChanged("firewall") == false { f.l.Debug("No firewall config change detected") @@ -307,7 +312,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { ticker := time.NewTicker(i) defer ticker.Stop() - udpStats := NewUDPStatsEmitter(f.writers) + udpStats := udp.NewUDPStatsEmitter(f.writers) for { select { diff --git a/iputil/util.go b/iputil/util.go new file mode 100644 index 0000000..5e181a2 --- /dev/null +++ b/iputil/util.go @@ -0,0 +1,66 @@ +package iputil + +import ( + "encoding/binary" + "fmt" + "net" +) + +type VpnIp uint32 + +const maxIPv4StringLen = len("255.255.255.255") + +func (ip VpnIp) String() string { + b := make([]byte, maxIPv4StringLen) + + n := ubtoa(b, 0, byte(ip>>24)) + b[n] = '.' + n++ + + n += ubtoa(b, n, byte(ip>>16&255)) + b[n] = '.' + n++ + + n += ubtoa(b, n, byte(ip>>8&255)) + b[n] = '.' + n++ + + n += ubtoa(b, n, byte(ip&255)) + return string(b[:n]) +} + +func (ip VpnIp) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil +} + +func (ip VpnIp) ToIP() net.IP { + nip := make(net.IP, 4) + binary.BigEndian.PutUint32(nip, uint32(ip)) + return nip +} + +func Ip2VpnIp(ip []byte) VpnIp { + if len(ip) == 16 { + return VpnIp(binary.BigEndian.Uint32(ip[12:16])) + } + return VpnIp(binary.BigEndian.Uint32(ip)) +} + +// ubtoa encodes the string form of the integer v to dst[start:] and +// returns the number of bytes written to dst. The caller must ensure +// that dst has sufficient length. +func ubtoa(dst []byte, start int, v byte) int { + if v < 10 { + dst[start] = v + '0' + return 1 + } else if v < 100 { + dst[start+1] = v%10 + '0' + dst[start] = v/10 + '0' + return 2 + } + + dst[start+2] = v%10 + '0' + dst[start+1] = (v/10)%10 + '0' + dst[start] = v/100 + '0' + return 3 +} diff --git a/iputil/util_test.go b/iputil/util_test.go new file mode 100644 index 0000000..712d426 --- /dev/null +++ b/iputil/util_test.go @@ -0,0 +1,17 @@ +package iputil + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestVpnIp_String(t *testing.T) { + assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String()) + assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String()) + assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String()) + assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String()) + assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String()) + assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String()) +} diff --git a/lighthouse.go b/lighthouse.go index 0c12144..ac555fa 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -12,6 +12,9 @@ import ( "github.com/golang/protobuf/proto" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" ) //TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake? @@ -23,13 +26,13 @@ type LightHouse struct { //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps amLighthouse bool - myVpnIp uint32 - myVpnZeros uint32 - punchConn *udpConn + myVpnIp iputil.VpnIp + myVpnZeros iputil.VpnIp + punchConn *udp.Conn // Local cache of answers from light houses // map of vpn Ip to answers - addrMap map[uint32]*RemoteList + addrMap map[iputil.VpnIp]*RemoteList // filters remote addresses allowed for each host // - When we are a lighthouse, this filters what addresses we store and @@ -42,12 +45,12 @@ type LightHouse struct { localAllowList *LocalAllowList // used to trigger the HandshakeManager when we receive HostQueryReply - handshakeTrigger chan<- uint32 + handshakeTrigger chan<- iputil.VpnIp // staticList exists to avoid having a bool in each addrMap entry // since static should be rare - staticList map[uint32]struct{} - lighthouses map[uint32]struct{} + staticList map[iputil.VpnIp]struct{} + lighthouses map[iputil.VpnIp]struct{} interval int nebulaPort uint32 // 32 bits because protobuf does not have a uint16 punchBack bool @@ -58,20 +61,16 @@ type LightHouse struct { l *logrus.Logger } -type EncWriter interface { - SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) -} - -func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse { +func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []iputil.VpnIp, interval int, nebulaPort uint32, pc *udp.Conn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse { ones, _ := myVpnIpNet.Mask.Size() h := LightHouse{ amLighthouse: amLighthouse, - myVpnIp: ip2int(myVpnIpNet.IP), - myVpnZeros: uint32(32 - ones), - addrMap: make(map[uint32]*RemoteList), + myVpnIp: iputil.Ip2VpnIp(myVpnIpNet.IP), + myVpnZeros: iputil.VpnIp(32 - ones), + addrMap: make(map[iputil.VpnIp]*RemoteList), nebulaPort: nebulaPort, - lighthouses: make(map[uint32]struct{}), - staticList: make(map[uint32]struct{}), + lighthouses: make(map[iputil.VpnIp]struct{}), + staticList: make(map[iputil.VpnIp]struct{}), interval: interval, punchConn: pc, punchBack: punchBack, @@ -111,13 +110,13 @@ func (lh *LightHouse) SetLocalAllowList(allowList *LocalAllowList) { func (lh *LightHouse) ValidateLHStaticEntries() error { for lhIP, _ := range lh.lighthouses { if _, ok := lh.staticList[lhIP]; !ok { - return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", IntIp(lhIP)) + return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", lhIP) } } return nil } -func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList { +func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList { if !lh.IsLighthouseIP(ip) { lh.QueryServer(ip, f) } @@ -131,7 +130,7 @@ func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList { } // This is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) { +func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) { if lh.amLighthouse { return } @@ -143,7 +142,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) { // Send a query to the lighthouses and hope for the best next time query, err := proto.Marshal(NewLhQueryByInt(ip)) if err != nil { - lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload") + lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload") return } @@ -151,11 +150,11 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) { nb := make([]byte, 12, 12) out := make([]byte, mtu) for n := range lh.lighthouses { - f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out) + f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) } } -func (lh *LightHouse) QueryCache(ip uint32) *RemoteList { +func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { lh.RLock() if v, ok := lh.addrMap[ip]; ok { lh.RUnlock() @@ -172,7 +171,7 @@ func (lh *LightHouse) QueryCache(ip uint32) *RemoteList { // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() -func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, error)) (bool, int, error) { +func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) { lh.RLock() // Do we have an entry in the main cache? if v, ok := lh.addrMap[vpnIp]; ok { @@ -195,18 +194,18 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, err return false, 0, nil } -func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) { +func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { // First we check the static mapping // and do nothing if it is there - if _, ok := lh.staticList[vpnIP]; ok { + if _, ok := lh.staticList[vpnIp]; ok { return } lh.Lock() //l.Debugln(lh.addrMap) - delete(lh.addrMap, vpnIP) + delete(lh.addrMap, vpnIp) if lh.l.Level >= logrus.DebugLevel { - lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP)) + lh.l.Debugf("deleting %s from lighthouse.", vpnIp) } lh.Unlock() @@ -215,7 +214,7 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) { // AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client -func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) { +func (lh *LightHouse) AddStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr) { lh.Lock() am := lh.unlockedGetRemoteList(vpnIp) am.Lock() @@ -242,23 +241,23 @@ func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) { } // unlockedGetRemoteList assumes you have the lh lock -func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList { - am, ok := lh.addrMap[vpnIP] +func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { + am, ok := lh.addrMap[vpnIp] if !ok { am = NewRemoteList() - lh.addrMap[vpnIP] = am + lh.addrMap[vpnIp] = am } return am } // unlockedShouldAddV4 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool { - allow := lh.remoteAllowList.AllowIpV4(vpnIp, to.Ip) +func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool { + allow := lh.remoteAllowList.AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") } - if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, to.Ip) { + if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) { return false } @@ -266,7 +265,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool { } // unlockedShouldAddV6 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV6(vpnIp uint32, to *Ip6AndPort) bool { +func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool { allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") @@ -287,25 +286,25 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP { return ip } -func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool { - if _, ok := lh.lighthouses[vpnIP]; ok { +func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool { + if _, ok := lh.lighthouses[vpnIp]; ok { return true } return false } -func NewLhQueryByInt(VpnIp uint32) *NebulaMeta { +func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta { return &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: VpnIp, + VpnIp: uint32(VpnIp), }, } } func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort { ipp := Ip4AndPort{Port: port} - ipp.Ip = ip2int(ip) + ipp.Ip = uint32(iputil.Ip2VpnIp(ip)) return &ipp } @@ -317,19 +316,19 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { } } -func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udpAddr { +func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr { ip := ipp.Ip - return NewUDPAddr( + return udp.NewAddr( net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)), uint16(ipp.Port), ) } -func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udpAddr { - return NewUDPAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) +func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { + return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) } -func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) { +func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) { if lh.amLighthouse || lh.interval == 0 { return } @@ -349,12 +348,12 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) { } } -func (lh *LightHouse) SendUpdate(f EncWriter) { +func (lh *LightHouse) SendUpdate(f udp.EncWriter) { var v4 []*Ip4AndPort var v6 []*Ip6AndPort for _, e := range *localIps(lh.l, lh.localAllowList) { - if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip2int(ip4)) { + if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) { continue } @@ -368,7 +367,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) { m := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: lh.myVpnIp, + VpnIp: uint32(lh.myVpnIp), Ip4AndPorts: v4, Ip6AndPorts: v6, }, @@ -385,7 +384,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) { } for vpnIp := range lh.lighthouses { - f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out) + f.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out) } } @@ -415,11 +414,11 @@ func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { } func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) { - lh.metrics.Rx(NebulaMessageType(t), 0, i) + lh.metrics.Rx(header.MessageType(t), 0, i) } func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) { - lh.metrics.Tx(NebulaMessageType(t), 0, i) + lh.metrics.Tx(header.MessageType(t), 0, i) } // This method is similar to Reset(), but it re-uses the pointer structs @@ -436,18 +435,18 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { return lhh.meta } -func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). + lhh.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") //TODO: send recv_error? return } if n.Details == nil { - lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). + lhh.l.WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr). Error("Invalid lighthouse update") //TODO: send recv_error? return @@ -471,7 +470,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr *udpAddr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -481,12 +480,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr } //TODO: we can DRY this further - reqVpnIP := n.Details.VpnIp + reqVpnIp := n.Details.VpnIp //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data - found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(c *cache) (int, error) { + found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply - n.Details.VpnIp = reqVpnIP + n.Details.VpnIp = reqVpnIp lhh.coalesceAnswers(c, n) @@ -498,18 +497,18 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr } if err != nil { - lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply") + lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host query reply") return } lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) - w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) // This signals the other side to punch some zero byte udp packets found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification - n.Details.VpnIp = vpnIp + n.Details.VpnIp = uint32(vpnIp) lhh.coalesceAnswers(c, n) @@ -521,12 +520,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr } if err != nil { - lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host was queried for") + lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host was queried for") return } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0]) } func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { @@ -549,28 +548,29 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { } } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(n.Details.VpnIp) + am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp)) am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + certVpnIp := iputil.VpnIp(n.Details.VpnIp) + am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) am.Unlock() // Non-blocking attempt to trigger, skip if it would block select { - case lhh.lh.handshakeTrigger <- n.Details.VpnIp: + case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp): default: } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp uint32) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) @@ -579,9 +579,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp } //Simple check that the host sent this not someone else - if n.Details.VpnIp != vpnIp { + if n.Details.VpnIp != uint32(vpnIp) { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update") + lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update") } return } @@ -591,18 +591,19 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + certVpnIp := iputil.VpnIp(n.Details.VpnIp) + am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) am.Unlock() } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } empty := []byte{0} - punch := func(vpnPeer *udpAddr) { + punch := func(vpnPeer *udp.Addr) { if vpnPeer == nil { return } @@ -615,7 +616,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u if lhh.l.Level >= logrus.DebugLevel { //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) - lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, IntIp(n.Details.VpnIp)) + lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp)) } } @@ -634,18 +635,18 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u go func() { time.Sleep(time.Second * 5) if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp)) + lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", iputil.VpnIp(n.Details.VpnIp)) } //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine // managed by a channel. - w.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + w.SendMessageToVpnIp(header.Test, header.TestRequest, iputil.VpnIp(n.Details.VpnIp), []byte(""), make([]byte, 12, 12), make([]byte, mtu)) }() } } // ipMaskContains checks if testIp is contained by ip after applying a cidr // zeros is 32 - bits from net.IPMask.Size() -func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool { +func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool { return (testIp^ip)>>zeros == 0 } diff --git a/lighthouse_test.go b/lighthouse_test.go index fcd9cc2..03c96b9 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -6,6 +6,10 @@ import ( "testing" "github.com/golang/protobuf/proto" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" + "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) @@ -17,12 +21,12 @@ func TestOldIPv4Only(t *testing.T) { var m Ip4AndPort err := proto.Unmarshal(b, &m) assert.NoError(t, err) - assert.Equal(t, "10.1.1.1", int2ip(m.GetIp()).String()) + assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String()) } func TestNewLhQuery(t *testing.T) { myIp := net.ParseIP("192.1.1.1") - myIpint := ip2int(myIp) + myIpint := iputil.Ip2VpnIp(myIp) // Generating a new lh query should work a := NewLhQueryByInt(myIpint) @@ -42,37 +46,37 @@ func TestNewLhQuery(t *testing.T) { } func Test_lhStaticMapping(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() lh1 := "10.128.0.2" lh1IP := net.ParseIP(lh1) - udpServer, _ := NewListener(l, "0.0.0.0", 0, true) + udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2) - meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) - meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242))) + meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false) + meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242))) err := meh.ValidateLHStaticEntries() assert.Nil(t, err) lh2 := "10.128.0.3" lh2IP := net.ParseIP(lh2) - meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false) - meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242))) + meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP), iputil.Ip2VpnIp(lh2IP)}, 10, 10003, udpServer, false, 1, false) + meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242))) err = meh.ValidateLHStaticEntries() assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry") } func BenchmarkLighthouseHandleRequest(b *testing.B) { - l := NewTestLogger() + l := util.NewTestLogger() lh1 := "10.128.0.2" lh1IP := net.ParseIP(lh1) - udpServer, _ := NewListener(l, "0.0.0.0", 0, true) + udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2) - lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) + lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false) - hAddr := NewUDPAddrFromString("4.5.6.7:12345") - hAddr2 := NewUDPAddrFromString("4.5.6.7:12346") + hAddr := udp.NewAddrFromString("4.5.6.7:12345") + hAddr2 := udp.NewAddrFromString("4.5.6.7:12346") lh.addrMap[3] = NewRemoteList() lh.addrMap[3].unlockedSetV4( 3, @@ -81,11 +85,11 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)), }, - func(uint32, *Ip4AndPort) bool { return true }, + func(iputil.VpnIp, *Ip4AndPort) bool { return true }, ) - rAddr := NewUDPAddrFromString("1.2.2.3:12345") - rAddr2 := NewUDPAddrFromString("1.2.2.3:12346") + rAddr := udp.NewAddrFromString("1.2.2.3:12345") + rAddr2 := udp.NewAddrFromString("1.2.2.3:12346") lh.addrMap[2] = NewRemoteList() lh.addrMap[2].unlockedSetV4( 3, @@ -94,7 +98,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)), }, - func(uint32, *Ip4AndPort) bool { return true }, + func(iputil.VpnIp, *Ip4AndPort) bool { return true }, ) mw := &mockEncWriter{} @@ -133,50 +137,50 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { } func TestLighthouse_Memory(t *testing.T) { - l := NewTestLogger() + l := util.NewTestLogger() - myUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.2"), Port: 4242} - myUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4242} - myUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.2"), Port: 4242} - myUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.2"), Port: 4242} - myUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.2"), Port: 4242} - myUdpAddr5 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4243} - myUdpAddr6 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4244} - myUdpAddr7 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4245} - myUdpAddr8 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4246} - myUdpAddr9 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4247} - myUdpAddr10 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4248} - myUdpAddr11 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4249} - myVpnIp := ip2int(net.ParseIP("10.128.0.2")) + myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242} + myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242} + myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242} + myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242} + myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242} + myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243} + myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244} + myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245} + myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246} + myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247} + myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248} + myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249} + myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2")) - theirUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.3"), Port: 4242} - theirUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.3"), Port: 4242} - theirUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.3"), Port: 4242} - theirUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.3"), Port: 4242} - theirUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.3"), Port: 4242} - theirVpnIp := ip2int(net.ParseIP("10.128.0.3")) + theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242} + theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242} + theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242} + theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242} + theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242} + theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3")) - udpServer, _ := NewListener(l, "0.0.0.0", 0, true) - lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{}, 10, 10003, udpServer, false, 1, false) + udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2) + lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []iputil.VpnIp{}, 10, 10003, udpServer, false, 1, false) lhh := lh.NewRequestHandler() // Test that my first update responds with just that - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr2}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) // Ensure we don't accumulate addresses - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr3}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) // Grow it back to 2 - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr4}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) // Update a different host - newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udpAddr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) + newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) @@ -189,7 +193,7 @@ func TestLighthouse_Memory(t *testing.T) { newLHHostUpdate( myUdpAddr0, myVpnIp, - []*udpAddr{ + []*udp.Addr{ myUdpAddr1, myUdpAddr2, myUdpAddr3, @@ -212,19 +216,19 @@ func TestLighthouse_Memory(t *testing.T) { ) // Make sure we won't add ips in our vpn network - bad1 := &udpAddr{IP: net.ParseIP("10.128.0.99"), Port: 4242} - bad2 := &udpAddr{IP: net.ParseIP("10.128.0.100"), Port: 4242} - good := &udpAddr{IP: net.ParseIP("1.128.0.99"), Port: 4242} - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{bad1, bad2, good}, lhh) + bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242} + bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242} + good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242} + newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) } -func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightHouseHandler) testLhReply { +func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply { req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: queryVpnIp, + VpnIp: uint32(queryVpnIp), }, } @@ -238,17 +242,17 @@ func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightH return w.lastReply } -func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *LightHouseHandler) { +func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) { req := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: vpnIp, + VpnIp: uint32(vpnIp), Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), }, } for k, v := range addrs { - req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: ip2int(v.IP), Port: uint32(v.Port)} + req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)} } b, err := req.Marshal() @@ -327,15 +331,15 @@ func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *Lig //} func Test_ipMaskContains(t *testing.T) { - assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255")))) - assert.False(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.1.1")))) - assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32, ip2int(net.ParseIP("10.0.1.1")))) + assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255")))) + assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1")))) + assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1")))) } type testLhReply struct { - nebType NebulaMessageType - nebSubType NebulaMessageSubType - vpnIp uint32 + nebType header.MessageType + nebSubType header.MessageSubType + vpnIp iputil.VpnIp msg *NebulaMeta } @@ -343,7 +347,7 @@ type testEncWriter struct { lastReply testLhReply } -func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, _, _ []byte) { +func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) { tw.lastReply = testLhReply{ nebType: t, nebSubType: st, @@ -358,17 +362,17 @@ func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessag } // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match -func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udpAddr) { +func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) { assert.Len(t, have, len(want)) for k, w := range want { - if !(have[k].Ip == ip2int(w.IP) && have[k].Port == uint32(w.Port)) { + if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) { assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have))) } } } // assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match -func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) { +func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) { assert.Len(t, have, len(want)) for k, w := range want { if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) { @@ -377,8 +381,8 @@ func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) { } } -func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr { - addrs := make([]*udpAddr, len(ips)) +func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr { + addrs := make([]*udp.Addr, len(ips)) for k, v := range ips { addrs[k] = NewUDPAddrFromLH4(v) } diff --git a/logger.go b/logger.go index fa42f19..8846264 100644 --- a/logger.go +++ b/logger.go @@ -2,8 +2,12 @@ package nebula import ( "errors" + "fmt" + "strings" + "time" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" ) type ContextualError struct { @@ -37,3 +41,38 @@ func (ce *ContextualError) Log(lr *logrus.Logger) { lr.WithFields(ce.Fields).Error(ce.Context) } } + +func configLogger(l *logrus.Logger, c *config.C) error { + // set up our logging level + logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) + if err != nil { + return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels) + } + l.SetLevel(logLevel) + + disableTimestamp := c.GetBool("logging.disable_timestamp", false) + timestampFormat := c.GetString("logging.timestamp_format", "") + fullTimestamp := (timestampFormat != "") + if timestampFormat == "" { + timestampFormat = time.RFC3339 + } + + logFormat := strings.ToLower(c.GetString("logging.format", "text")) + switch logFormat { + case "text": + l.Formatter = &logrus.TextFormatter{ + TimestampFormat: timestampFormat, + FullTimestamp: fullTimestamp, + DisableTimestamp: disableTimestamp, + } + case "json": + l.Formatter = &logrus.JSONFormatter{ + TimestampFormat: timestampFormat, + DisableTimestamp: disableTimestamp, + } + default: + return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"}) + } + + return nil +} diff --git a/main.go b/main.go index 048a4f3..91418e1 100644 --- a/main.go +++ b/main.go @@ -8,14 +8,16 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/sshd" + "github.com/slackhq/nebula/udp" "gopkg.in/yaml.v2" ) type m map[string]interface{} -func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) { - +func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) { ctx, cancel := context.WithCancel(context.Background()) // Automatically cancel the context if Main returns an error, to signal all created goroutines to quit. defer func() { @@ -31,7 +33,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L // Print the config if in test, the exit comes later if configTest { - b, err := yaml.Marshal(config.Settings) + b, err := yaml.Marshal(c.Settings) if err != nil { return nil, err } @@ -40,33 +42,33 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L l.Println(string(b)) } - err := configLogger(config) + err := configLogger(l, c) if err != nil { return nil, NewContextualError("Failed to configure the logger", nil, err) } - config.RegisterReloadCallback(func(c *Config) { - err := configLogger(c) + c.RegisterReloadCallback(func(c *config.C) { + err := configLogger(l, c) if err != nil { l.WithError(err).Error("Failed to configure the logger") } }) - caPool, err := loadCAFromConfig(l, config) + caPool, err := loadCAFromConfig(l, c) if err != nil { //The errors coming out of loadCA are already nicely formatted return nil, NewContextualError("Failed to load ca from config", nil, err) } l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") - cs, err := NewCertStateFromConfig(config) + cs, err := NewCertStateFromConfig(c) if err != nil { //The errors coming out of NewCertStateFromConfig are already nicely formatted return nil, NewContextualError("Failed to load certificate from config", nil, err) } l.WithField("cert", cs.certificate).Debug("Client nebula certificate") - fw, err := NewFirewallFromConfig(l, cs.certificate, config) + fw, err := NewFirewallFromConfig(l, cs.certificate, c) if err != nil { return nil, NewContextualError("Error while loading firewall rules", nil, err) } @@ -74,20 +76,20 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L // TODO: make sure mask is 4 bytes tunCidr := cs.certificate.Details.Ips[0] - routes, err := parseRoutes(config, tunCidr) + routes, err := parseRoutes(c, tunCidr) if err != nil { return nil, NewContextualError("Could not parse tun.routes", nil, err) } - unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr) + unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr) if err != nil { return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err) } ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) - wireSSHReload(l, ssh, config) + wireSSHReload(l, ssh, c) var sshStart func() - if config.GetBool("sshd.enabled", false) { - sshStart, err = configSSH(l, ssh, config) + if c.GetBool("sshd.enabled", false) { + sshStart, err = configSSH(l, ssh, c) if err != nil { return nil, NewContextualError("Error while configuring the sshd", nil, err) } @@ -101,7 +103,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L var routines int // If `routines` is set, use that and ignore the specific values - if routines = config.GetInt("routines", 0); routines != 0 { + if routines = c.GetInt("routines", 0); routines != 0 { if routines < 1 { routines = 1 } @@ -110,8 +112,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } } else { // deprecated and undocumented - tunQueues := config.GetInt("tun.routines", 1) - udpQueues := config.GetInt("listen.routines", 1) + tunQueues := c.GetInt("tun.routines", 1) + udpQueues := c.GetInt("listen.routines", 1) if tunQueues > udpQueues { routines = tunQueues } else { @@ -125,8 +127,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L // EXPERIMENTAL // Intentionally not documented yet while we do more testing and determine // a good default value. - conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0) - if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") { + conntrackCacheTimeout := c.GetDuration("firewall.conntrack.routine_cache_timeout", 0) + if routines > 1 && !c.IsSet("firewall.conntrack.routine_cache_timeout") { // Use a different default if we are running with multiple routines conntrackCacheTimeout = 1 * time.Second } @@ -136,30 +138,30 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L var tun Inside if !configTest { - config.CatchHUP(ctx) + c.CatchHUP(ctx) switch { - case config.GetBool("tun.disabled", false): - tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l) + case c.GetBool("tun.disabled", false): + tun = newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) case tunFd != nil: tun, err = newTunFromFd( l, *tunFd, tunCidr, - config.GetInt("tun.mtu", DEFAULT_MTU), + c.GetInt("tun.mtu", DEFAULT_MTU), routes, unsafeRoutes, - config.GetInt("tun.tx_queue", 500), + c.GetInt("tun.tx_queue", 500), ) default: tun, err = newTun( l, - config.GetString("tun.dev", ""), + c.GetString("tun.dev", ""), tunCidr, - config.GetInt("tun.mtu", DEFAULT_MTU), + c.GetInt("tun.mtu", DEFAULT_MTU), routes, unsafeRoutes, - config.GetInt("tun.tx_queue", 500), + c.GetInt("tun.tx_queue", 500), routines > 1, ) } @@ -176,16 +178,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L }() // set up our UDP listener - udpConns := make([]*udpConn, routines) - port := config.GetInt("listen.port", 0) + udpConns := make([]*udp.Conn, routines) + port := c.GetInt("listen.port", 0) if !configTest { for i := 0; i < routines; i++ { - udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1) + udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err) } - udpServer.reloadConfig(config) + udpServer.ReloadConfig(c) udpConns[i] = udpServer // If port is dynamic, discover it @@ -201,7 +203,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L // Set up my internal host map var preferredRanges []*net.IPNet - rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{}) + rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{}) // First, check if 'preferred_ranges' is set and fallback to 'local_range' if len(rawPreferredRanges) > 0 { for _, rawPreferredRange := range rawPreferredRanges { @@ -216,7 +218,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L // local_range was superseded by preferred_ranges. If it is still present, // merge the local_range setting into preferred_ranges. We will probably // deprecate local_range and remove in the future. - rawLocalRange := config.GetString("local_range", "") + rawLocalRange := c.GetString("local_range", "") if rawLocalRange != "" { _, localRange, err := net.ParseCIDR(rawLocalRange) if err != nil { @@ -240,7 +242,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L hostMap := NewHostMap(l, "main", tunCidr, preferredRanges) hostMap.addUnsafeRoutes(&unsafeRoutes) - hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false) + hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false) l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created") @@ -249,26 +251,26 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L go hostMap.Promoter(config.GetInt("promoter.interval")) */ - punchy := NewPunchyFromConfig(config) + punchy := NewPunchyFromConfig(c) if punchy.Punch && !configTest { l.Info("UDP hole punching enabled") go hostMap.Punchy(ctx, udpConns[0]) } - amLighthouse := config.GetBool("lighthouse.am_lighthouse", false) + amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) // fatal if am_lighthouse is enabled but we are using an ephemeral port - if amLighthouse && (config.GetInt("listen.port", 0) == 0) { + if amLighthouse && (c.GetInt("listen.port", 0) == 0) { return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil) } // warn if am_lighthouse is enabled but upstream lighthouses exists - rawLighthouseHosts := config.GetStringSlice("lighthouse.hosts", []string{}) + rawLighthouseHosts := c.GetStringSlice("lighthouse.hosts", []string{}) if amLighthouse && len(rawLighthouseHosts) != 0 { l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") } - lighthouseHosts := make([]uint32, len(rawLighthouseHosts)) + lighthouseHosts := make([]iputil.VpnIp, len(rawLighthouseHosts)) for i, host := range rawLighthouseHosts { ip := net.ParseIP(host) if ip == nil { @@ -277,7 +279,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L if !tunCidr.Contains(ip) { return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) } - lighthouseHosts[i] = ip2int(ip) + lighthouseHosts[i] = iputil.Ip2VpnIp(ip) } lightHouse := NewLightHouse( @@ -286,47 +288,48 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L tunCidr, lighthouseHosts, //TODO: change to a duration - config.GetInt("lighthouse.interval", 10), + c.GetInt("lighthouse.interval", 10), uint32(port), udpConns[0], punchy.Respond, punchy.Delay, - config.GetBool("stats.lighthouse_metrics", false), + c.GetBool("stats.lighthouse_metrics", false), ) - remoteAllowList, err := config.GetRemoteAllowList("lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") + remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") if err != nil { return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) } lightHouse.SetRemoteAllowList(remoteAllowList) - localAllowList, err := config.GetLocalAllowList("lighthouse.local_allow_list") + localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list") if err != nil { return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err) } lightHouse.SetLocalAllowList(localAllowList) //TODO: Move all of this inside functions in lighthouse.go - for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) { - vpnIp := net.ParseIP(fmt.Sprintf("%v", k)) - if !tunCidr.Contains(vpnIp) { + for k, v := range c.GetMap("static_host_map", map[interface{}]interface{}{}) { + ip := net.ParseIP(fmt.Sprintf("%v", k)) + vpnIp := iputil.Ip2VpnIp(ip) + if !tunCidr.Contains(ip) { return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) } vals, ok := v.([]interface{}) if ok { for _, v := range vals { - ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v)) + ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) if err != nil { return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) } - lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port)) + lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) } } else { - ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v)) + ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) if err != nil { return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) } - lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port)) + lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) } } @@ -336,16 +339,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } var messageMetrics *MessageMetrics - if config.GetBool("stats.message_metrics", false) { + if c.GetBool("stats.message_metrics", false) { messageMetrics = newMessageMetrics() } else { messageMetrics = newMessageMetricsOnlyRecvError() } handshakeConfig := HandshakeConfig{ - tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), - retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries), - triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), + tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), + retries: c.GetInt("handshakes.retries", DefaultHandshakeRetries), + triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), messageMetrics: messageMetrics, } @@ -358,36 +361,35 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L //handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{}) serveDns := false - if config.GetBool("lighthouse.serve_dns", false) { - if config.GetBool("lighthouse.am_lighthouse", false) { + if c.GetBool("lighthouse.serve_dns", false) { + if c.GetBool("lighthouse.am_lighthouse", false) { serveDns = true } else { l.Warn("DNS server refusing to run because this host is not a lighthouse.") } } - checkInterval := config.GetInt("timers.connection_alive_interval", 5) - pendingDeletionInterval := config.GetInt("timers.pending_deletion_interval", 10) + checkInterval := c.GetInt("timers.connection_alive_interval", 5) + pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10) ifConfig := &InterfaceConfig{ HostMap: hostMap, Inside: tun, Outside: udpConns[0], certState: cs, - Cipher: config.GetString("cipher", "aes"), + Cipher: c.GetString("cipher", "aes"), Firewall: fw, ServeDns: serveDns, HandshakeManager: handshakeManager, lightHouse: lightHouse, checkInterval: checkInterval, pendingDeletionInterval: pendingDeletionInterval, - DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false), - DropMulticast: config.GetBool("tun.drop_multicast", false), - UDPBatchSize: config.GetInt("listen.batch", 64), + DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false), + DropMulticast: c.GetBool("tun.drop_multicast", false), routines: routines, MessageMetrics: messageMetrics, version: buildVersion, caPool: caPool, - disconnectInvalid: config.GetBool("pki.disconnect_invalid", false), + disconnectInvalid: c.GetBool("pki.disconnect_invalid", false), ConntrackCacheTimeout: conntrackCacheTimeout, l: l, @@ -413,7 +415,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L // I don't want to make this initial commit too far-reaching though ifce.writers = udpConns - ifce.RegisterConfigChangeCallbacks(config) + ifce.RegisterConfigChangeCallbacks(c) go handshakeManager.Run(ctx, ifce) go lightHouse.LhUpdateWorker(ctx, ifce) @@ -421,7 +423,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L // TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept // a context so that they can exit when the context is Done. - statsStart, err := startStats(l, config, buildVersion, configTest) + statsStart, err := startStats(l, c, buildVersion, configTest) + if err != nil { return nil, NewContextualError("Failed to start stats emitter", nil, err) } @@ -431,7 +434,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } //TODO: check if we _should_ be emitting stats - go ifce.emitStats(ctx, config.GetDuration("stats.interval", time.Second*10)) + go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10)) attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) @@ -439,7 +442,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L var dnsStart func() if amLighthouse && serveDns { l.Debugln("Starting dns server") - dnsStart = dnsMain(l, hostMap, config) + dnsStart = dnsMain(l, hostMap, c) } return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil diff --git a/message_metrics.go b/message_metrics.go index ccd0207..b229cdf 100644 --- a/message_metrics.go +++ b/message_metrics.go @@ -4,8 +4,11 @@ import ( "fmt" "github.com/rcrowley/go-metrics" + "github.com/slackhq/nebula/header" ) +//TODO: this can probably move into the header package + type MessageMetrics struct { rx [][]metrics.Counter tx [][]metrics.Counter @@ -14,7 +17,7 @@ type MessageMetrics struct { txUnknown metrics.Counter } -func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64) { +func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) { if m != nil { if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) { m.rx[t][s].Inc(i) @@ -23,7 +26,7 @@ func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64 } } } -func (m *MessageMetrics) Tx(t NebulaMessageType, s NebulaMessageSubType, i int64) { +func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) { if m != nil { if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) { m.tx[t][s].Inc(i) diff --git a/outside.go b/outside.go index 57a00dc..a081ec0 100644 --- a/outside.go +++ b/outside.go @@ -10,6 +10,10 @@ import ( "github.com/golang/protobuf/proto" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" ) @@ -17,8 +21,8 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int, localCache ConntrackCache) { - err := header.Parse(packet) +func (f *Interface) readOutsidePackets(addr *udp.Addr, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { + err := h.Parse(packet) if err != nil { // TODO: best if we return this and let caller log // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? @@ -32,30 +36,30 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, //l.Error("in packet ", header, packet[HeaderLen:]) // verify if we've seen this index before, otherwise respond to the handshake initiation - hostinfo, err := f.hostMap.QueryIndex(header.RemoteIndex) + hostinfo, err := f.hostMap.QueryIndex(h.RemoteIndex) var ci *ConnectionState if err == nil { ci = hostinfo.ConnectionState } - switch header.Type { - case message: - if !f.handleEncrypted(ci, addr, header) { + switch h.Type { + case header.Message: + if !f.handleEncrypted(ci, addr, h) { return } - f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q, localCache) + f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) // Fallthrough to the bottom to record incoming traffic - case lightHouse: - f.messageMetrics.Rx(header.Type, header.Subtype, 1) - if !f.handleEncrypted(ci, addr, header) { + case header.LightHouse: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + if !f.handleEncrypted(ci, addr, h) { return } - d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb) + d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). WithField("packet", packet). @@ -66,17 +70,17 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, return } - lhh.HandleRequest(addr, hostinfo.hostId, d, f) + lhf(addr, hostinfo.vpnIp, d, f) // Fallthrough to the bottom to record incoming traffic - case test: - f.messageMetrics.Rx(header.Type, header.Subtype, 1) - if !f.handleEncrypted(ci, addr, header) { + case header.Test: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + if !f.handleEncrypted(ci, addr, h) { return } - d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb) + d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). WithField("packet", packet). @@ -87,11 +91,11 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, return } - if header.Subtype == testRequest { + if h.Subtype == header.TestRequest { // This testRequest might be from TryPromoteBest, so we should roam // to the new IP address before responding f.handleHostRoaming(hostinfo, addr) - f.send(test, testReply, ci, hostinfo, hostinfo.remote, d, nb, out) + f.send(header.Test, header.TestReply, ci, hostinfo, hostinfo.remote, d, nb, out) } // Fallthrough to the bottom to record incoming traffic @@ -99,19 +103,19 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they // are unauthenticated - case handshake: - f.messageMetrics.Rx(header.Type, header.Subtype, 1) - HandleIncomingHandshake(f, addr, packet, header, hostinfo) + case header.Handshake: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + HandleIncomingHandshake(f, addr, packet, h, hostinfo) return - case recvError: - f.messageMetrics.Rx(header.Type, header.Subtype, 1) - f.handleRecvError(addr, header) + case header.RecvError: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + f.handleRecvError(addr, h) return - case closeTunnel: - f.messageMetrics.Rx(header.Type, header.Subtype, 1) - if !f.handleEncrypted(ci, addr, header) { + case header.CloseTunnel: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + if !f.handleEncrypted(ci, addr, h) { return } @@ -122,22 +126,22 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, return default: - f.messageMetrics.Rx(header.Type, header.Subtype, 1) + f.messageMetrics.Rx(h.Type, h.Subtype, 1) hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr) return } f.handleHostRoaming(hostinfo, addr) - f.connectionManager.In(hostinfo.hostId) + f.connectionManager.In(hostinfo.vpnIp) } // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) { //TODO: this would be better as a single function in ConnectionManager that handled locks appropriately - f.connectionManager.ClearIP(hostInfo.hostId) - f.connectionManager.ClearPendingDeletion(hostInfo.hostId) - f.lightHouse.DeleteVpnIP(hostInfo.hostId) + f.connectionManager.ClearIP(hostInfo.vpnIp) + f.connectionManager.ClearPendingDeletion(hostInfo.vpnIp) + f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) if hasHostMapLock { f.hostMap.unlockedDeleteHostInfo(hostInfo) @@ -148,12 +152,12 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) { // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote func (f *Interface) sendCloseTunnel(h *HostInfo) { - f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) + f.send(header.CloseTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { - if hostDidRoam(hostinfo.remote, addr) { - if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) { +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) { + if !hostinfo.remote.Equals(addr) { + if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) { hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") return } @@ -175,11 +179,11 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { } -func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool { +func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool { // If connectionstate exists and the replay protector allows, process packet // Else, send recv errors for 300 seconds after a restart to allow fast reconnection. - if ci == nil || !ci.window.Check(f.l, header.MessageCounter) { - f.sendRecvError(addr, header.RemoteIndex) + if ci == nil || !ci.window.Check(f.l, h.MessageCounter) { + f.sendRecvError(addr, h.RemoteIndex) return false } @@ -187,7 +191,7 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header * } // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers -func newPacket(data []byte, incoming bool, fp *FirewallPacket) error { +func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { // Do we at least have an ipv4 header worth of data? if len(data) < ipv4.HeaderLen { return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen) @@ -215,7 +219,7 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error { // Accounting for a variable header length, do we have enough data for our src/dst tuples? minLen := ihl - if !fp.Fragment && fp.Protocol != fwProtoICMP { + if !fp.Fragment && fp.Protocol != firewall.ProtoICMP { minLen += minFwPacketLen } if len(data) < minLen { @@ -224,9 +228,9 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error { // Firewall packets are locally oriented if incoming { - fp.RemoteIP = binary.BigEndian.Uint32(data[12:16]) - fp.LocalIP = binary.BigEndian.Uint32(data[16:20]) - if fp.Fragment || fp.Protocol == fwProtoICMP { + fp.RemoteIP = iputil.Ip2VpnIp(data[12:16]) + fp.LocalIP = iputil.Ip2VpnIp(data[16:20]) + if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 } else { @@ -234,9 +238,9 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error { fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } } else { - fp.LocalIP = binary.BigEndian.Uint32(data[12:16]) - fp.RemoteIP = binary.BigEndian.Uint32(data[16:20]) - if fp.Fragment || fp.Protocol == fwProtoICMP { + fp.LocalIP = iputil.Ip2VpnIp(data[12:16]) + fp.RemoteIP = iputil.Ip2VpnIp(data[16:20]) + if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 } else { @@ -248,15 +252,15 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error { return nil } -func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, header *Header, nb []byte) ([]byte, error) { +func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, h *header.H, nb []byte) ([]byte, error) { var err error - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], mc, nb) + out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], mc, nb) if err != nil { return nil, err } if !hostinfo.ConnectionState.window.Update(f.l, mc) { - hostinfo.logger(f.l).WithField("header", header). + hostinfo.logger(f.l).WithField("header", h). Debugln("dropping out of window packet") return nil, errors.New("out of window packet") } @@ -264,10 +268,10 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return out, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int, localCache ConntrackCache) { +func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { var err error - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb) + out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) if err != nil { hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") //TODO: maybe after build 64 is out? 06/14/2018 - NB @@ -298,18 +302,18 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return } - f.connectionManager.In(hostinfo.hostId) + f.connectionManager.In(hostinfo.vpnIp) _, err = f.readers[q].Write(out) if err != nil { f.l.WithError(err).Error("Failed to write to tun") } } -func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) { - f.messageMetrics.Tx(recvError, 0, 1) +func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) { + f.messageMetrics.Tx(header.RecvError, 0, 1) //TODO: this should be a signed message so we can trust that we should drop the index - b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0) + b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) f.outside.WriteTo(b, endpoint) if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", index). @@ -318,7 +322,7 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) { } } -func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { +func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", h.RemoteIndex). WithField("udpAddr", addr). diff --git a/outside_test.go b/outside_test.go index 6dd8bdc..682107b 100644 --- a/outside_test.go +++ b/outside_test.go @@ -4,12 +4,14 @@ import ( "net" "testing" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "golang.org/x/net/ipv4" ) func Test_newPacket(t *testing.T) { - p := &FirewallPacket{} + p := &firewall.Packet{} // length fail err := newPacket([]byte{0, 1}, true, p) @@ -44,7 +46,7 @@ func Test_newPacket(t *testing.T) { Src: net.IPv4(10, 0, 0, 1), Dst: net.IPv4(10, 0, 0, 2), Options: []byte{0, 1, 0, 2}, - Protocol: fwProtoTCP, + Protocol: firewall.ProtoTCP, } b, _ = h.Marshal() @@ -52,9 +54,9 @@ func Test_newPacket(t *testing.T) { err = newPacket(b, true, p) assert.Nil(t, err) - assert.Equal(t, p.Protocol, uint8(fwProtoTCP)) - assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 2))) - assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 1))) + assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) + assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) + assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) assert.Equal(t, p.RemotePort, uint16(3)) assert.Equal(t, p.LocalPort, uint16(4)) @@ -74,8 +76,8 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(2)) - assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 1))) - assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 2))) + assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) + assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) assert.Equal(t, p.RemotePort, uint16(6)) assert.Equal(t, p.LocalPort, uint16(5)) } diff --git a/punchy.go b/punchy.go index 153d0ac..90d7b94 100644 --- a/punchy.go +++ b/punchy.go @@ -1,6 +1,10 @@ package nebula -import "time" +import ( + "time" + + "github.com/slackhq/nebula/config" +) type Punchy struct { Punch bool @@ -8,7 +12,7 @@ type Punchy struct { Delay time.Duration } -func NewPunchyFromConfig(c *Config) *Punchy { +func NewPunchyFromConfig(c *config.C) *Punchy { p := &Punchy{} if c.IsSet("punchy.punch") { diff --git a/punchy_test.go b/punchy_test.go index 2ab570f..8b8cd1a 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -4,12 +4,14 @@ import ( "testing" "time" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) func TestNewPunchyFromConfig(t *testing.T) { - l := NewTestLogger() - c := NewConfig(l) + l := util.NewTestLogger() + c := config.NewC(l) // Test defaults p := NewPunchyFromConfig(c) diff --git a/remote_list.go b/remote_list.go index 7c3e716..5135f38 100644 --- a/remote_list.go +++ b/remote_list.go @@ -5,14 +5,17 @@ import ( "net" "sort" "sync" + + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" ) // forEachFunc is used to benefit folks that want to do work inside the lock -type forEachFunc func(addr *udpAddr, preferred bool) +type forEachFunc func(addr *udp.Addr, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) -type checkFuncV4 func(vpnIp uint32, to *Ip4AndPort) bool -type checkFuncV6 func(vpnIp uint32, to *Ip6AndPort) bool +type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool +type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp @@ -21,8 +24,8 @@ type CacheMap map[string]*Cache // Cache is the other part of CacheMap to better represent the lighthouse cache for humans // We don't reason about ipv4 vs ipv6 here type Cache struct { - Learned []*udpAddr `json:"learned,omitempty"` - Reported []*udpAddr `json:"reported,omitempty"` + Learned []*udp.Addr `json:"learned,omitempty"` + Reported []*udp.Addr `json:"reported,omitempty"` } //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion @@ -53,16 +56,16 @@ type RemoteList struct { sync.RWMutex // A deduplicated set of addresses. Any accessor should lock beforehand. - addrs []*udpAddr + addrs []*udp.Addr // These are maps to store v4 and v6 addresses per lighthouse // Map key is the vpnIp of the person that told us about this the cached entries underneath. // For learned addresses, this is the vpnIp that sent the packet - cache map[uint32]*cache + cache map[iputil.VpnIp]*cache // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // They should not be tried again during a handshake - badRemotes []*udpAddr + badRemotes []*udp.Addr // A flag that the cache may have changed and addrs needs to be rebuilt shouldRebuild bool @@ -71,8 +74,8 @@ type RemoteList struct { // NewRemoteList creates a new empty RemoteList func NewRemoteList() *RemoteList { return &RemoteList{ - addrs: make([]*udpAddr, 0), - cache: make(map[uint32]*cache), + addrs: make([]*udp.Addr, 0), + cache: make(map[iputil.VpnIp]*cache), } } @@ -98,7 +101,7 @@ func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) // CopyAddrs locks and makes a deep copy of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr { +func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { if r == nil { return nil } @@ -107,7 +110,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr { r.RLock() defer r.RUnlock() - c := make([]*udpAddr, len(r.addrs)) + c := make([]*udp.Addr, len(r.addrs)) for i, v := range r.addrs { c[i] = v.Copy() } @@ -118,7 +121,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr { // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. // It will mark the deduplicated address list as dirty, so do not call it unless new information is available //TODO: this needs to support the allow list list -func (r *RemoteList) LearnRemote(ownerVpnIp uint32, addr *udpAddr) { +func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) { r.Lock() defer r.Unlock() if v4 := addr.IP.To4(); v4 != nil { @@ -139,8 +142,8 @@ func (r *RemoteList) CopyCache() *CacheMap { c := cm[vpnIp] if c == nil { c = &Cache{ - Learned: make([]*udpAddr, 0), - Reported: make([]*udpAddr, 0), + Learned: make([]*udp.Addr, 0), + Reported: make([]*udp.Addr, 0), } cm[vpnIp] = c } @@ -148,7 +151,7 @@ func (r *RemoteList) CopyCache() *CacheMap { } for owner, mc := range r.cache { - c := getOrMake(IntIp(owner).String()) + c := getOrMake(owner.String()) if mc.v4 != nil { if mc.v4.learned != nil { @@ -175,7 +178,7 @@ func (r *RemoteList) CopyCache() *CacheMap { } // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list -func (r *RemoteList) BlockRemote(bad *udpAddr) { +func (r *RemoteList) BlockRemote(bad *udp.Addr) { r.Lock() defer r.Unlock() @@ -192,11 +195,11 @@ func (r *RemoteList) BlockRemote(bad *udpAddr) { } // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list -func (r *RemoteList) CopyBlockedRemotes() []*udpAddr { +func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr { r.RLock() defer r.RUnlock() - c := make([]*udpAddr, len(r.badRemotes)) + c := make([]*udp.Addr, len(r.badRemotes)) for i, v := range r.badRemotes { c[i] = v.Copy() } @@ -228,7 +231,7 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { } // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list -func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool { +func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool { for _, v := range r.badRemotes { if v.Equals(remote) { return true @@ -239,14 +242,14 @@ func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool { // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) { +func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { r.shouldRebuild = true r.unlockedGetOrMakeV4(ownerVpnIp).learned = to } // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4AndPort, check checkFuncV4) { +func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -263,7 +266,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4And // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) { +func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -276,14 +279,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) { // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) { +func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { r.shouldRebuild = true r.unlockedGetOrMakeV6(ownerVpnIp).learned = to } // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6AndPort, check checkFuncV6) { +func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -300,7 +303,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6And // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) { +func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -313,7 +316,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) { // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established. // The caller must dirty the learned address cache if required -func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 { +func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -328,7 +331,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 { // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established. // The caller must dirty the learned address cache if required -func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp uint32) *cacheV6 { +func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} diff --git a/remote_list_test.go b/remote_list_test.go index 05d3887..2170930 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -4,6 +4,7 @@ import ( "net" "testing" + "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" ) @@ -13,18 +14,18 @@ func TestRemoteList_Rebuild(t *testing.T) { 0, 0, []*Ip4AndPort{ - {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped - {Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is duped - {Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is duped - {Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is a dupe - {Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe - {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port - {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe }, - func(uint32, *Ip4AndPort) bool { return true }, + func(iputil.VpnIp, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( @@ -37,7 +38,7 @@ func TestRemoteList_Rebuild(t *testing.T) { NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe }, - func(uint32, *Ip6AndPort) bool { return true }, + func(iputil.VpnIp, *Ip6AndPort) bool { return true }, ) rl.Rebuild([]*net.IPNet{}) @@ -106,16 +107,16 @@ func BenchmarkFullRebuild(b *testing.B) { 0, 0, []*Ip4AndPort{ - {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, - {Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe - {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port }, - func(uint32, *Ip4AndPort) bool { return true }, + func(iputil.VpnIp, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( @@ -127,7 +128,7 @@ func BenchmarkFullRebuild(b *testing.B) { NewIp6AndPort(net.ParseIP("1:100::1"), 1), NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe }, - func(uint32, *Ip6AndPort) bool { return true }, + func(iputil.VpnIp, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { @@ -171,16 +172,16 @@ func BenchmarkSortRebuild(b *testing.B) { 0, 0, []*Ip4AndPort{ - {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, - {Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101}, - {Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe - {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe + {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port }, - func(uint32, *Ip4AndPort) bool { return true }, + func(iputil.VpnIp, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( @@ -192,7 +193,7 @@ func BenchmarkSortRebuild(b *testing.B) { NewIp6AndPort(net.ParseIP("1:100::1"), 1), NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe }, - func(uint32, *Ip6AndPort) bool { return true }, + func(iputil.VpnIp, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { diff --git a/ssh.go b/ssh.go index dec676e..e640dde 100644 --- a/ssh.go +++ b/ssh.go @@ -15,7 +15,11 @@ import ( "syscall" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/sshd" + "github.com/slackhq/nebula/udp" ) type sshListHostMapFlags struct { @@ -45,8 +49,8 @@ type sshCreateTunnelFlags struct { Address string } -func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) { - c.RegisterReloadCallback(func(c *Config) { +func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { + c.RegisterReloadCallback(func(c *config.C) { if c.GetBool("sshd.enabled", false) { sshRun, err := configSSH(l, ssh, c) if err != nil { @@ -66,7 +70,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) { // updates the passed-in SSHServer. On success, it returns a function // that callers may invoke to run the configured ssh server. On // failure, it returns nil, error. -func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) (func(), error) { +func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { //TODO conntrack list //TODO print firewall rules or hash? @@ -351,7 +355,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error hm := listHostMap(hostMap) sort.Slice(hm, func(i, j int) bool { - return bytes.Compare(hm[i].VpnIP, hm[j].VpnIP) < 0 + return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0 }) if fs.Json || fs.Pretty { @@ -368,7 +372,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error } else { for _, v := range hm { - err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, v.RemoteAddrs)) + err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs)) if err != nil { return err } @@ -386,7 +390,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr } type lighthouseInfo struct { - VpnIP net.IP `json:"vpnIp"` + VpnIp string `json:"vpnIp"` Addrs *CacheMap `json:"addrs"` } @@ -395,7 +399,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr x := 0 for k, v := range lightHouse.addrMap { addrMap[x] = lighthouseInfo{ - VpnIP: int2ip(k), + VpnIp: k.String(), Addrs: v.CopyCache(), } x++ @@ -403,7 +407,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr lightHouse.RUnlock() sort.Slice(addrMap, func(i, j int) bool { - return bytes.Compare(addrMap[i].VpnIP, addrMap[j].VpnIP) < 0 + return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0 }) if fs.Json || fs.Pretty { @@ -424,7 +428,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr if err != nil { return err } - err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, string(b))) + err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b))) if err != nil { return err } @@ -470,7 +474,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := ip2int(parsedIp) + vpnIp := iputil.Ip2VpnIp(parsedIp) if vpnIp == 0 { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -499,19 +503,19 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := ip2int(parsedIp) + vpnIp := iputil.Ip2VpnIp(parsedIp) if vpnIp == 0 { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) + hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) if err != nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } if !flags.LocalOnly { ifce.send( - closeTunnel, + header.CloseTunnel, 0, hostInfo.ConnectionState, hostInfo, @@ -542,30 +546,30 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := ip2int(parsedIp) + vpnIp := iputil.Ip2VpnIp(parsedIp) if vpnIp == 0 { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, _ := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) + hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already exists")) } - hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIP(uint32(vpnIp)) + hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } - var addr *udpAddr + var addr *udp.Addr if flags.Address != "" { - addr = NewUDPAddrFromString(flags.Address) + addr = udp.NewAddrFromString(flags.Address) if addr == nil { return w.WriteLine("Address could not be parsed") } } - hostInfo = ifce.handshakeManager.AddVpnIP(vpnIp) + hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp) if addr != nil { hostInfo.SetRemote(addr) } @@ -589,7 +593,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("No address was provided") } - addr := NewUDPAddrFromString(flags.Address) + addr := udp.NewAddrFromString(flags.Address) if addr == nil { return w.WriteLine("Address could not be parsed") } @@ -599,12 +603,12 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := ip2int(parsedIp) + vpnIp := iputil.Ip2VpnIp(parsedIp) if vpnIp == 0 { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) + hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) if err != nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -680,12 +684,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := ip2int(parsedIp) + vpnIp := iputil.Ip2VpnIp(parsedIp) if vpnIp == 0 { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) + hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) if err != nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -742,12 +746,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := ip2int(parsedIp) + vpnIp := iputil.Ip2VpnIp(parsedIp) if vpnIp == 0 { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIP(vpnIp) + hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) if err != nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } diff --git a/stats.go b/stats.go index 94e75ef..3993455 100644 --- a/stats.go +++ b/stats.go @@ -15,12 +15,13 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" ) // startStats initializes stats from config. On success, if any futher work // is needed to serve stats, it returns a func to handle that work. If no // work is needed, it'll return nil. On failure, it returns nil, error. -func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) (func(), error) { +func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) { mType := c.GetString("stats.type", "") if mType == "" || mType == "none" { return nil, nil @@ -57,7 +58,7 @@ func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest boo return startFn, nil } -func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error { +func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error { proto := c.GetString("stats.protocol", "tcp") host := c.GetString("stats.host", "") if host == "" { @@ -77,7 +78,7 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest return nil } -func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) (func(), error) { +func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) { namespace := c.GetString("stats.namespace", "") subsystem := c.GetString("stats.subsystem", "") diff --git a/timeout.go b/timeout.go index 6e80614..fe63f3e 100644 --- a/timeout.go +++ b/timeout.go @@ -2,12 +2,14 @@ package nebula import ( "time" + + "github.com/slackhq/nebula/firewall" ) // How many timer objects should be cached const timerCacheMax = 50000 -var emptyFWPacket = FirewallPacket{} +var emptyFWPacket = firewall.Packet{} type TimerWheel struct { // Current tick @@ -42,7 +44,7 @@ type TimeoutList struct { // Represents an item within a tick type TimeoutItem struct { - Packet FirewallPacket + Packet firewall.Packet Next *TimeoutItem } @@ -73,8 +75,8 @@ func NewTimerWheel(min, max time.Duration) *TimerWheel { return &tw } -// Add will add a FirewallPacket to the wheel in it's proper timeout -func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem { +// Add will add a firewall.Packet to the wheel in it's proper timeout +func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem { // Check and see if we should progress the tick tw.advance(time.Now()) @@ -103,7 +105,7 @@ func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem return ti } -func (tw *TimerWheel) Purge() (FirewallPacket, bool) { +func (tw *TimerWheel) Purge() (firewall.Packet, bool) { if tw.expired.Head == nil { return emptyFWPacket, false } diff --git a/timeout_system.go b/timeout_system.go index e458458..72f6af9 100644 --- a/timeout_system.go +++ b/timeout_system.go @@ -3,6 +3,8 @@ package nebula import ( "sync" "time" + + "github.com/slackhq/nebula/iputil" ) // How many timer objects should be cached @@ -43,7 +45,7 @@ type SystemTimeoutList struct { // Represents an item within a tick type SystemTimeoutItem struct { - Item uint32 + Item iputil.VpnIp Next *SystemTimeoutItem } @@ -74,7 +76,7 @@ func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel { return &tw } -func (tw *SystemTimerWheel) Add(v uint32, timeout time.Duration) *SystemTimeoutItem { +func (tw *SystemTimerWheel) Add(v iputil.VpnIp, timeout time.Duration) *SystemTimeoutItem { tw.lock.Lock() defer tw.lock.Unlock() diff --git a/timeout_system_test.go b/timeout_system_test.go index 712725d..41c64a0 100644 --- a/timeout_system_test.go +++ b/timeout_system_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" ) @@ -51,7 +52,7 @@ func TestSystemTimerWheel_findWheel(t *testing.T) { func TestSystemTimerWheel_Add(t *testing.T) { tw := NewSystemTimerWheel(time.Second, time.Second*10) - fp1 := ip2int(net.ParseIP("1.2.3.4")) + fp1 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4")) tw.Add(fp1, time.Second*1) // Make sure we set head and tail properly @@ -62,7 +63,7 @@ func TestSystemTimerWheel_Add(t *testing.T) { assert.Nil(t, tw.wheel[2].Tail.Next) // Make sure we only modify head - fp2 := ip2int(net.ParseIP("1.2.3.4")) + fp2 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4")) tw.Add(fp2, time.Second*1) assert.Equal(t, fp2, tw.wheel[2].Head.Item) assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item) @@ -85,7 +86,7 @@ func TestSystemTimerWheel_Purge(t *testing.T) { assert.NotNil(t, tw.lastTick) assert.Equal(t, 0, tw.current) - fps := []uint32{9, 10, 11, 12} + fps := []iputil.VpnIp{9, 10, 11, 12} //fp1 := ip2int(net.ParseIP("1.2.3.4")) diff --git a/timeout_test.go b/timeout_test.go index 2f4ceb1..9678b35 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/slackhq/nebula/firewall" "github.com/stretchr/testify/assert" ) @@ -50,7 +51,7 @@ func TestTimerWheel_findWheel(t *testing.T) { func TestTimerWheel_Add(t *testing.T) { tw := NewTimerWheel(time.Second, time.Second*10) - fp1 := FirewallPacket{} + fp1 := firewall.Packet{} tw.Add(fp1, time.Second*1) // Make sure we set head and tail properly @@ -61,7 +62,7 @@ func TestTimerWheel_Add(t *testing.T) { assert.Nil(t, tw.wheel[2].Tail.Next) // Make sure we only modify head - fp2 := FirewallPacket{} + fp2 := firewall.Packet{} tw.Add(fp2, time.Second*1) assert.Equal(t, fp2, tw.wheel[2].Head.Packet) assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet) @@ -84,7 +85,7 @@ func TestTimerWheel_Purge(t *testing.T) { assert.NotNil(t, tw.lastTick) assert.Equal(t, 0, tw.current) - fps := []FirewallPacket{ + fps := []firewall.Packet{ {LocalIP: 1}, {LocalIP: 2}, {LocalIP: 3}, diff --git a/tun_common.go b/tun_common.go index 68f11ea..0a81eab 100644 --- a/tun_common.go +++ b/tun_common.go @@ -4,6 +4,8 @@ import ( "fmt" "net" "strconv" + + "github.com/slackhq/nebula/config" ) const DEFAULT_MTU = 1300 @@ -14,10 +16,10 @@ type route struct { via *net.IP } -func parseRoutes(config *Config, network *net.IPNet) ([]route, error) { +func parseRoutes(c *config.C, network *net.IPNet) ([]route, error) { var err error - r := config.Get("tun.routes") + r := c.Get("tun.routes") if r == nil { return []route{}, nil } @@ -84,10 +86,10 @@ func parseRoutes(config *Config, network *net.IPNet) ([]route, error) { return routes, nil } -func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) { +func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]route, error) { var err error - r := config.Get("tun.unsafe_routes") + r := c.Get("tun.unsafe_routes") if r == nil { return []route{}, nil } @@ -110,7 +112,7 @@ func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) { rMtu, ok := m["mtu"] if !ok { - rMtu = config.GetInt("tun.mtu", DEFAULT_MTU) + rMtu = c.GetInt("tun.mtu", DEFAULT_MTU) } mtu, ok := rMtu.(int) diff --git a/tun_test.go b/tun_test.go index 08ff10f..6043eb5 100644 --- a/tun_test.go +++ b/tun_test.go @@ -5,12 +5,14 @@ import ( "net" "testing" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" ) func Test_parseRoutes(t *testing.T) { - l := NewTestLogger() - c := NewConfig(l) + l := util.NewTestLogger() + c := config.NewC(l) _, n, _ := net.ParseCIDR("10.0.0.0/24") // test no routes config @@ -105,8 +107,8 @@ func Test_parseRoutes(t *testing.T) { } func Test_parseUnsafeRoutes(t *testing.T) { - l := NewTestLogger() - c := NewConfig(l) + l := util.NewTestLogger() + c := config.NewC(l) _, n, _ := net.ParseCIDR("10.0.0.0/24") // test no routes config diff --git a/udp/conn.go b/udp/conn.go new file mode 100644 index 0000000..f967a9a --- /dev/null +++ b/udp/conn.go @@ -0,0 +1,20 @@ +package udp + +import ( + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" +) + +const MTU = 9001 + +type EncReader func( + addr *Addr, + out []byte, + packet []byte, + header *header.H, + fwPacket *firewall.Packet, + lhh LightHouseHandlerFunc, + nb []byte, + q int, + localCache firewall.ConntrackCache, +) diff --git a/udp/temp.go b/udp/temp.go new file mode 100644 index 0000000..f4ef1b5 --- /dev/null +++ b/udp/temp.go @@ -0,0 +1,14 @@ +package udp + +import ( + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" +) + +//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare + +type EncWriter interface { + SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) +} + +type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) diff --git a/udp_all.go b/udp/udp_all.go similarity index 63% rename from udp_all.go rename to udp/udp_all.go index 1ce8135..a4a462e 100644 --- a/udp_all.go +++ b/udp/udp_all.go @@ -1,4 +1,4 @@ -package nebula +package udp import ( "encoding/json" @@ -7,32 +7,34 @@ import ( "strconv" ) -type udpAddr struct { +type m map[string]interface{} + +type Addr struct { IP net.IP Port uint16 } -func NewUDPAddr(ip net.IP, port uint16) *udpAddr { - addr := udpAddr{IP: make([]byte, net.IPv6len), Port: port} +func NewAddr(ip net.IP, port uint16) *Addr { + addr := Addr{IP: make([]byte, net.IPv6len), Port: port} copy(addr.IP, ip.To16()) return &addr } -func NewUDPAddrFromString(s string) *udpAddr { - ip, port, err := parseIPAndPort(s) +func NewAddrFromString(s string) *Addr { + ip, port, err := ParseIPAndPort(s) //TODO: handle err _ = err - return &udpAddr{IP: ip.To16(), Port: port} + return &Addr{IP: ip.To16(), Port: port} } -func (ua *udpAddr) Equals(t *udpAddr) bool { +func (ua *Addr) Equals(t *Addr) bool { if t == nil || ua == nil { return t == nil && ua == nil } return ua.IP.Equal(t.IP) && ua.Port == t.Port } -func (ua *udpAddr) String() string { +func (ua *Addr) String() string { if ua == nil { return "" } @@ -40,7 +42,7 @@ func (ua *udpAddr) String() string { return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port)) } -func (ua *udpAddr) MarshalJSON() ([]byte, error) { +func (ua *Addr) MarshalJSON() ([]byte, error) { if ua == nil { return nil, nil } @@ -48,12 +50,12 @@ func (ua *udpAddr) MarshalJSON() ([]byte, error) { return json.Marshal(m{"ip": ua.IP, "port": ua.Port}) } -func (ua *udpAddr) Copy() *udpAddr { +func (ua *Addr) Copy() *Addr { if ua == nil { return nil } - nu := udpAddr{ + nu := Addr{ Port: ua.Port, IP: make(net.IP, len(ua.IP)), } @@ -62,7 +64,7 @@ func (ua *udpAddr) Copy() *udpAddr { return &nu } -func parseIPAndPort(s string) (net.IP, uint16, error) { +func ParseIPAndPort(s string) (net.IP, uint16, error) { rIp, sPort, err := net.SplitHostPort(s) if err != nil { return nil, 0, err diff --git a/udp_android.go b/udp/udp_android.go similarity index 93% rename from udp_android.go rename to udp/udp_android.go index c41cf8f..d2812a8 100644 --- a/udp_android.go +++ b/udp/udp_android.go @@ -1,7 +1,7 @@ //go:build !e2e_testing // +build !e2e_testing -package nebula +package udp import ( "fmt" @@ -34,6 +34,6 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *udpConn) Rebind() error { +func (u *Conn) Rebind() error { return nil } diff --git a/udp_darwin.go b/udp/udp_darwin.go similarity index 94% rename from udp_darwin.go rename to udp/udp_darwin.go index 795199a..69d0c58 100644 --- a/udp_darwin.go +++ b/udp/udp_darwin.go @@ -1,7 +1,7 @@ //go:build !e2e_testing // +build !e2e_testing -package nebula +package udp // Darwin support is primarily implemented in udp_generic, besides NewListenConfig @@ -37,7 +37,7 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *udpConn) Rebind() error { +func (u *Conn) Rebind() error { file, err := u.File() if err != nil { return err diff --git a/udp_freebsd.go b/udp/udp_freebsd.go similarity index 93% rename from udp_freebsd.go rename to udp/udp_freebsd.go index 2858919..10ff94b 100644 --- a/udp_freebsd.go +++ b/udp/udp_freebsd.go @@ -1,7 +1,7 @@ //go:build !e2e_testing // +build !e2e_testing -package nebula +package udp // FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig @@ -36,6 +36,6 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *udpConn) Rebind() error { +func (u *Conn) Rebind() error { return nil } diff --git a/udp_generic.go b/udp/udp_generic.go similarity index 57% rename from udp_generic.go rename to udp/udp_generic.go index 28525d9..c314bbe 100644 --- a/udp_generic.go +++ b/udp/udp_generic.go @@ -5,7 +5,7 @@ // udp_generic implements the nebula UDP interface in pure Go stdlib. This // means it can be used on platforms like Darwin and Windows. -package nebula +package udp import ( "context" @@ -13,36 +13,39 @@ import ( "net" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" ) -type udpConn struct { +type Conn struct { *net.UDPConn l *logrus.Logger } -func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) { +func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port)) if err != nil { return nil, err } if uc, ok := pc.(*net.UDPConn); ok { - return &udpConn{UDPConn: uc, l: l}, nil + return &Conn{UDPConn: uc, l: l}, nil } return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) } -func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error { +func (uc *Conn) WriteTo(b []byte, addr *Addr) error { _, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)}) return err } -func (uc *udpConn) LocalAddr() (*udpAddr, error) { +func (uc *Conn) LocalAddr() (*Addr, error) { a := uc.UDPConn.LocalAddr() switch v := a.(type) { case *net.UDPAddr: - addr := &udpAddr{IP: make([]byte, len(v.IP))} + addr := &Addr{IP: make([]byte, len(v.IP))} copy(addr.IP, v.IP) addr.Port = uint16(v.Port) return addr, nil @@ -52,11 +55,11 @@ func (uc *udpConn) LocalAddr() (*udpAddr, error) { } } -func (u *udpConn) reloadConfig(c *Config) { +func (u *Conn) ReloadConfig(c *config.C) { // TODO } -func NewUDPStatsEmitter(udpConns []*udpConn) func() { +func NewUDPStatsEmitter(udpConns []*Conn) func() { // No UDP stats for non-linux return func() {} } @@ -65,32 +68,24 @@ type rawMessage struct { Len uint32 } -func (u *udpConn) ListenOut(f *Interface, q int) { - plaintext := make([]byte, mtu) - buffer := make([]byte, mtu) - header := &Header{} - fwPacket := &FirewallPacket{} - udpAddr := &udpAddr{IP: make([]byte, 16)} +func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { + plaintext := make([]byte, MTU) + buffer := make([]byte, MTU) + h := &header.H{} + fwPacket := &firewall.Packet{} + udpAddr := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) - lhh := f.lightHouse.NewRequestHandler() - - conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout) - for { // Just read one packet at a time n, rua, err := u.ReadFromUDP(buffer) if err != nil { - f.l.WithError(err).Error("Failed to read packets") + u.l.WithError(err).Error("Failed to read packets") continue } udpAddr.IP = rua.IP udpAddr.Port = uint16(rua.Port) - f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l)) + r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } - -func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool { - return !addr.Equals(newaddr) -} diff --git a/udp_linux.go b/udp/udp_linux.go similarity index 82% rename from udp_linux.go rename to udp/udp_linux.go index 45e2fe6..3de397b 100644 --- a/udp_linux.go +++ b/udp/udp_linux.go @@ -1,7 +1,7 @@ //go:build !android && !e2e_testing // +build !android,!e2e_testing -package nebula +package udp import ( "encoding/binary" @@ -12,14 +12,18 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" "golang.org/x/sys/unix" ) //TODO: make it support reload as best you can! -type udpConn struct { +type Conn struct { sysFd int l *logrus.Logger + batch int } var x int @@ -41,7 +45,7 @@ const ( type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 -func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) { +func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) { syscall.ForkLock.RLock() fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP) if err == nil { @@ -73,36 +77,36 @@ func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, e //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //l.Println(v, err) - return &udpConn{sysFd: fd, l: l}, err + return &Conn{sysFd: fd, l: l, batch: batch}, err } -func (u *udpConn) Rebind() error { +func (u *Conn) Rebind() error { return nil } -func (u *udpConn) SetRecvBuffer(n int) error { +func (u *Conn) SetRecvBuffer(n int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) } -func (u *udpConn) SetSendBuffer(n int) error { +func (u *Conn) SetSendBuffer(n int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) } -func (u *udpConn) GetRecvBuffer() (int, error) { +func (u *Conn) GetRecvBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) } -func (u *udpConn) GetSendBuffer() (int, error) { +func (u *Conn) GetSendBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) } -func (u *udpConn) LocalAddr() (*udpAddr, error) { +func (u *Conn) LocalAddr() (*Addr, error) { sa, err := unix.Getsockname(u.sysFd) if err != nil { return nil, err } - addr := &udpAddr{} + addr := &Addr{} switch sa := sa.(type) { case *unix.SockaddrInet4: addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16() @@ -115,25 +119,21 @@ func (u *udpConn) LocalAddr() (*udpAddr, error) { return addr, nil } -func (u *udpConn) ListenOut(f *Interface, q int) { - plaintext := make([]byte, mtu) - header := &Header{} - fwPacket := &FirewallPacket{} - udpAddr := &udpAddr{} +func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { + plaintext := make([]byte, MTU) + h := &header.H{} + fwPacket := &firewall.Packet{} + udpAddr := &Addr{} nb := make([]byte, 12, 12) - lhh := f.lightHouse.NewRequestHandler() - //TODO: should we track this? //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) - msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize) + msgs, buffers, names := u.PrepareRawMessages(u.batch) read := u.ReadMulti - if f.udpBatchSize == 1 { + if u.batch == 1 { read = u.ReadSingle } - conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout) - for { n, err := read(msgs) if err != nil { @@ -145,12 +145,12 @@ func (u *udpConn) ListenOut(f *Interface, q int) { for i := 0; i < n; i++ { udpAddr.IP = names[i][8:24] udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l)) + r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } } -func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) { +func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) { for { n, _, err := unix.Syscall6( unix.SYS_RECVMSG, @@ -171,7 +171,7 @@ func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) { } } -func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) { +func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) { for { n, _, err := unix.Syscall6( unix.SYS_RECVMMSG, @@ -191,7 +191,7 @@ func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) { } } -func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error { +func (u *Conn) WriteTo(b []byte, addr *Addr) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 @@ -221,7 +221,7 @@ func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error { } } -func (u *udpConn) reloadConfig(c *Config) { +func (u *Conn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { err := u.SetRecvBuffer(b) @@ -253,7 +253,7 @@ func (u *udpConn) reloadConfig(c *Config) { } } -func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error { +func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error { var vallen uint32 = 4 * _SK_MEMINFO_VARS _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) if err != 0 { @@ -262,7 +262,7 @@ func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error { return nil } -func NewUDPStatsEmitter(udpConns []*udpConn) func() { +func NewUDPStatsEmitter(udpConns []*Conn) func() { // Check if our kernel supports SO_MEMINFO before registering the gauges var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge var meminfo _SK_MEMINFO @@ -293,7 +293,3 @@ func NewUDPStatsEmitter(udpConns []*udpConn) func() { } } } - -func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool { - return !addr.Equals(newaddr) -} diff --git a/udp_linux_32.go b/udp/udp_linux_32.go similarity index 88% rename from udp_linux_32.go rename to udp/udp_linux_32.go index de01862..06cd382 100644 --- a/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -4,7 +4,7 @@ // +build !android // +build !e2e_testing -package nebula +package udp import ( "golang.org/x/sys/unix" @@ -30,13 +30,13 @@ type rawMessage struct { Len uint32 } -func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) for i := range msgs { - buffers[i] = make([]byte, mtu) + buffers[i] = make([]byte, MTU) names[i] = make([]byte, unix.SizeofSockaddrInet6) //TODO: this is still silly, no need for an array diff --git a/udp_linux_64.go b/udp/udp_linux_64.go similarity index 89% rename from udp_linux_64.go rename to udp/udp_linux_64.go index f88b899..c442405 100644 --- a/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -4,7 +4,7 @@ // +build !android // +build !e2e_testing -package nebula +package udp import ( "golang.org/x/sys/unix" @@ -33,13 +33,13 @@ type rawMessage struct { Pad0 [4]byte } -func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) for i := range msgs { - buffers[i] = make([]byte, mtu) + buffers[i] = make([]byte, MTU) names[i] = make([]byte, unix.SizeofSockaddrInet6) //TODO: this is still silly, no need for an array diff --git a/udp_tester.go b/udp/udp_tester.go similarity index 55% rename from udp_tester.go rename to udp/udp_tester.go index 9350a4b..0157a8e 100644 --- a/udp_tester.go +++ b/udp/udp_tester.go @@ -1,16 +1,19 @@ //go:build e2e_testing // +build e2e_testing -package nebula +package udp import ( "fmt" "net" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" ) -type UdpPacket struct { +type Packet struct { ToIp net.IP ToPort uint16 FromIp net.IP @@ -18,8 +21,8 @@ type UdpPacket struct { Data []byte } -func (u *UdpPacket) Copy() *UdpPacket { - n := &UdpPacket{ +func (u *Packet) Copy() *Packet { + n := &Packet{ ToIp: make(net.IP, len(u.ToIp)), ToPort: u.ToPort, FromIp: make(net.IP, len(u.FromIp)), @@ -33,20 +36,20 @@ func (u *UdpPacket) Copy() *UdpPacket { return n } -type udpConn struct { - addr *udpAddr +type Conn struct { + Addr *Addr - rxPackets chan *UdpPacket // Packets to receive into nebula - txPackets chan *UdpPacket // Packets transmitted outside by nebula + RxPackets chan *Packet // Packets to receive into nebula + TxPackets chan *Packet // Packets transmitted outside by nebula l *logrus.Logger } -func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error) { - return &udpConn{ - addr: &udpAddr{net.ParseIP(ip), uint16(port)}, - rxPackets: make(chan *UdpPacket, 1), - txPackets: make(chan *UdpPacket, 1), +func NewListener(l *logrus.Logger, ip string, port int, _ bool, _ int) (*Conn, error) { + return &Conn{ + Addr: &Addr{net.ParseIP(ip), uint16(port)}, + RxPackets: make(chan *Packet, 1), + TxPackets: make(chan *Packet, 1), l: l, }, nil } @@ -54,8 +57,8 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error // Send will place a UdpPacket onto the receive queue for nebula to consume // this is an encrypted packet or a handshake message in most cases // packets were transmitted from another nebula node, you can send them with Tun.Send -func (u *udpConn) Send(packet *UdpPacket) { - h := &Header{} +func (u *Conn) Send(packet *Packet) { + h := &header.H{} if err := h.Parse(packet.Data); err != nil { panic(err) } @@ -63,19 +66,19 @@ func (u *udpConn) Send(packet *UdpPacket) { WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)). WithField("dataLen", len(packet.Data)). Info("UDP receiving injected packet") - u.rxPackets <- packet + u.RxPackets <- packet } // Get will pull a UdpPacket from the transmit queue // nebula meant to send this message on the network, it will be encrypted // packets were ingested from the tun side (in most cases), you can send them with Tun.Send -func (u *udpConn) Get(block bool) *UdpPacket { +func (u *Conn) Get(block bool) *Packet { if block { - return <-u.txPackets + return <-u.TxPackets } select { - case p := <-u.txPackets: + case p := <-u.TxPackets: return p default: return nil @@ -86,56 +89,49 @@ func (u *udpConn) Get(block bool) *UdpPacket { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error { - p := &UdpPacket{ +func (u *Conn) WriteTo(b []byte, addr *Addr) error { + p := &Packet{ Data: make([]byte, len(b), len(b)), FromIp: make([]byte, 16), - FromPort: u.addr.Port, + FromPort: u.Addr.Port, ToIp: make([]byte, 16), ToPort: addr.Port, } copy(p.Data, b) copy(p.ToIp, addr.IP.To16()) - copy(p.FromIp, u.addr.IP.To16()) + copy(p.FromIp, u.Addr.IP.To16()) - u.txPackets <- p + u.TxPackets <- p return nil } -func (u *udpConn) ListenOut(f *Interface, q int) { - plaintext := make([]byte, mtu) - header := &Header{} - fwPacket := &FirewallPacket{} - ua := &udpAddr{IP: make([]byte, 16)} +func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { + plaintext := make([]byte, MTU) + h := &header.H{} + fwPacket := &firewall.Packet{} + ua := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) - lhh := f.lightHouse.NewRequestHandler() - conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout) - for { - p := <-u.rxPackets + p := <-u.RxPackets ua.Port = p.FromPort copy(ua.IP, p.FromIp.To16()) - f.readOutsidePackets(ua, plaintext[:0], p.Data, header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l)) + r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } -func (u *udpConn) reloadConfig(*Config) {} +func (u *Conn) ReloadConfig(*config.C) {} -func NewUDPStatsEmitter(_ []*udpConn) func() { +func NewUDPStatsEmitter(_ []*Conn) func() { // No UDP stats for non-linux return func() {} } -func (u *udpConn) LocalAddr() (*udpAddr, error) { - return u.addr, nil +func (u *Conn) LocalAddr() (*Addr, error) { + return u.Addr, nil } -func (u *udpConn) Rebind() error { +func (u *Conn) Rebind() error { return nil } - -func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool { - return !addr.Equals(newaddr) -} diff --git a/udp_windows.go b/udp/udp_windows.go similarity index 92% rename from udp_windows.go rename to udp/udp_windows.go index b983d97..1f2ce64 100644 --- a/udp_windows.go +++ b/udp/udp_windows.go @@ -1,7 +1,7 @@ //go:build !e2e_testing // +build !e2e_testing -package nebula +package udp // Windows support is primarily implemented in udp_generic, besides NewListenConfig @@ -24,6 +24,6 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *udpConn) Rebind() error { +func (u *Conn) Rebind() error { return nil } diff --git a/main_test.go b/util/main.go similarity index 81% rename from main_test.go rename to util/main.go index f638011..0d84c73 100644 --- a/main_test.go +++ b/util/main.go @@ -1,4 +1,4 @@ -package nebula +package util import ( "io/ioutil" @@ -17,13 +17,12 @@ func NewTestLogger() *logrus.Logger { } switch v { - case "1": - // This is the default level but we are being explicit - l.SetLevel(logrus.InfoLevel) case "2": l.SetLevel(logrus.DebugLevel) case "3": l.SetLevel(logrus.TraceLevel) + default: + l.SetLevel(logrus.InfoLevel) } return l