Rework some things into packages (#489)

This commit is contained in:
Nate Brown 2021-11-03 20:54:04 -05:00 committed by GitHub
parent 1f75fb3c73
commit bcabcfdaca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
73 changed files with 2526 additions and 2374 deletions

View File

@ -4,11 +4,15 @@ import (
"fmt" "fmt"
"net" "net"
"regexp" "regexp"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
) )
type AllowList struct { type AllowList struct {
// The values of this cidrTree are `bool`, signifying allow/deny // The values of this cidrTree are `bool`, signifying allow/deny
cidrTree *CIDR6Tree cidrTree *cidr.Tree6
} }
type RemoteAllowList struct { type RemoteAllowList struct {
@ -16,7 +20,7 @@ type RemoteAllowList struct {
// Inside Range Specific, keys of this tree are inside CIDRs and values // Inside Range Specific, keys of this tree are inside CIDRs and values
// are *AllowList // are *AllowList
insideAllowLists *CIDR6Tree insideAllowLists *cidr.Tree6
} }
type LocalAllowList struct { type LocalAllowList struct {
@ -31,6 +35,223 @@ type AllowListNameRule struct {
Allow bool Allow bool
} }
func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
var nameRules []AllowListNameRule
handleKey := func(key string, value interface{}) (bool, error) {
if key == "interfaces" {
var err error
nameRules, err = getAllowListInterfaces(k, value)
if err != nil {
return false, err
}
return true, nil
}
return false, nil
}
al, err := newAllowListFromConfig(c, k, handleKey)
if err != nil {
return nil, err
}
return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil
}
func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllowList, error) {
al, err := newAllowListFromConfig(c, k, nil)
if err != nil {
return nil, err
}
remoteAllowRanges, err := getRemoteAllowRanges(c, rangesKey)
if err != nil {
return nil, err
}
return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil
}
// If the handleKey func returns true, the rest of the parsing is skipped
// for this key. This allows parsing of special values like `interfaces`.
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
r := c.Get(k)
if r == nil {
return nil, nil
}
return newAllowList(k, r, handleKey)
}
// If the handleKey func returns true, the rest of the parsing is skipped
// for this key. This allows parsing of special values like `interfaces`.
func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
rawMap, ok := raw.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
}
tree := cidr.NewTree6()
// Keep track of the rules we have added for both ipv4 and ipv6
type allowListRules struct {
firstValue bool
allValuesMatch bool
defaultSet bool
allValues bool
}
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
if handleKey != nil {
handled, err := handleKey(rawCIDR, rawValue)
if err != nil {
return nil, err
}
if handled {
continue
}
}
value, ok := rawValue.(bool)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
}
_, ipNet, err := net.ParseCIDR(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
// TODO: should we error on duplicate CIDRs in the config?
tree.AddCIDR(ipNet, value)
maskBits, maskSize := ipNet.Mask.Size()
var rules *allowListRules
if maskSize == 32 {
rules = &rules4
} else {
rules = &rules6
}
if rules.firstValue {
rules.allValues = value
rules.firstValue = false
} else {
if value != rules.allValues {
rules.allValuesMatch = false
}
}
// Check if this is 0.0.0.0/0 or ::/0
if maskBits == 0 {
rules.defaultSet = true
}
}
if !rules4.defaultSet {
if rules4.allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
tree.AddCIDR(zeroCIDR, !rules4.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
}
}
if !rules6.defaultSet {
if rules6.allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("::/0")
tree.AddCIDR(zeroCIDR, !rules6.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
}
}
return &AllowList{cidrTree: tree}, nil
}
func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
var nameRules []AllowListNameRule
rawRules, ok := v.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
}
firstEntry := true
var allValues bool
for rawName, rawAllow := range rawRules {
name, ok := rawName.(string)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
}
allow, ok := rawAllow.(bool)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
}
nameRE, err := regexp.Compile("^" + name + "$")
if err != nil {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
}
nameRules = append(nameRules, AllowListNameRule{
Name: nameRE,
Allow: allow,
})
if firstEntry {
allValues = allow
firstEntry = false
} else {
if allow != allValues {
return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
}
}
}
return nameRules, nil
}
func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) {
value := c.Get(k)
if value == nil {
return nil, nil
}
remoteAllowRanges := cidr.NewTree6()
rawMap, ok := value.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
}
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
if err != nil {
return nil, err
}
_, ipNet, err := net.ParseCIDR(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
remoteAllowRanges.AddCIDR(ipNet, allowList)
}
return remoteAllowRanges, nil
}
func (al *AllowList) Allow(ip net.IP) bool { func (al *AllowList) Allow(ip net.IP) bool {
if al == nil { if al == nil {
return true return true
@ -45,7 +266,7 @@ func (al *AllowList) Allow(ip net.IP) bool {
} }
} }
func (al *AllowList) AllowIpV4(ip uint32) bool { func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
if al == nil { if al == nil {
return true return true
} }
@ -102,14 +323,14 @@ func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool {
return al.AllowList.Allow(ip) return al.AllowList.Allow(ip)
} }
func (al *RemoteAllowList) Allow(vpnIp uint32, ip net.IP) bool { func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool {
if !al.getInsideAllowList(vpnIp).Allow(ip) { if !al.getInsideAllowList(vpnIp).Allow(ip) {
return false return false
} }
return al.AllowList.Allow(ip) return al.AllowList.Allow(ip)
} }
func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool { func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool {
if al == nil { if al == nil {
return true return true
} }
@ -119,7 +340,7 @@ func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool {
return al.AllowList.AllowIpV4(ip) return al.AllowList.AllowIpV4(ip)
} }
func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool { func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
if al == nil { if al == nil {
return true return true
} }
@ -129,7 +350,7 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool {
return al.AllowList.AllowIpV6(hi, lo) return al.AllowList.AllowIpV6(hi, lo)
} }
func (al *RemoteAllowList) getInsideAllowList(vpnIp uint32) *AllowList { func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
if al.insideAllowLists != nil { if al.insideAllowLists != nil {
inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
if inside != nil { if inside != nil {

View File

@ -5,21 +5,110 @@ import (
"regexp" "regexp"
"testing" "testing"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestNewAllowListFromConfig(t *testing.T) {
l := util.NewTestLogger()
c := config.NewC(l)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0": true,
}
r, err := newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
assert.Nil(t, r)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": "abc",
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": true,
"10.0.0.0/8": false,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
"fd00::/8": true,
"fd00:fd00::/16": false,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
"::/0": false,
"fd00::/8": true,
"fd00:fd00::/16": false,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
// Test interface names
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: "foo",
},
}
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
`eth.*`: true,
},
}
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
},
}
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
if assert.NoError(t, err) {
assert.NotNil(t, lr)
}
}
func TestAllowList_Allow(t *testing.T) { func TestAllowList_Allow(t *testing.T) {
assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1"))) assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
tree := NewCIDR6Tree() tree := cidr.NewTree6()
tree.AddCIDR(getCIDR("0.0.0.0/0"), true) tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
tree.AddCIDR(getCIDR("10.0.0.0/8"), false) tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
tree.AddCIDR(getCIDR("10.42.42.42/32"), true) tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
tree.AddCIDR(getCIDR("10.42.0.0/16"), true) tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true)
tree.AddCIDR(getCIDR("10.42.42.0/24"), true) tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true)
tree.AddCIDR(getCIDR("10.42.42.0/24"), false) tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false)
tree.AddCIDR(getCIDR("::1/128"), true) tree.AddCIDR(cidr.Parse("::1/128"), true)
tree.AddCIDR(getCIDR("::2/128"), false) tree.AddCIDR(cidr.Parse("::2/128"), false)
al := &AllowList{cidrTree: tree} al := &AllowList{cidrTree: tree}
assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1"))) assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))

View File

@ -3,11 +3,12 @@ package nebula
import ( import (
"testing" "testing"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestBits(t *testing.T) { func TestBits(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
b := NewBits(10) b := NewBits(10)
// make sure it is the right size // make sure it is the right size
@ -75,7 +76,7 @@ func TestBits(t *testing.T) {
} }
func TestBitsDupeCounter(t *testing.T) { func TestBitsDupeCounter(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
b := NewBits(10) b := NewBits(10)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()
@ -100,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) {
} }
func TestBitsOutOfWindowCounter(t *testing.T) { func TestBitsOutOfWindowCounter(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
b := NewBits(10) b := NewBits(10)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()
@ -130,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
} }
func TestBitsLostCounter(t *testing.T) { func TestBitsLostCounter(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
b := NewBits(10) b := NewBits(10)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()

View File

@ -9,6 +9,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
) )
type CertState struct { type CertState struct {
@ -45,7 +46,7 @@ func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*Cert
return cs, nil return cs, nil
} }
func NewCertStateFromConfig(c *Config) (*CertState, error) { func NewCertStateFromConfig(c *config.C) (*CertState, error) {
var pemPrivateKey []byte var pemPrivateKey []byte
var err error var err error
@ -118,7 +119,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) {
return NewCertState(nebulaCert, rawKey) return NewCertState(nebulaCert, rawKey)
} }
func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) { func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
var rawCA []byte var rawCA []byte
var err error var err error

10
cidr/parse.go Normal file
View File

@ -0,0 +1,10 @@
package cidr
import "net"
// Parse is a convenience function that returns only the IPNet
// This function ignores errors since it is primarily a test helper, the result could be nil
func Parse(s string) *net.IPNet {
_, c, _ := net.ParseCIDR(s)
return c
}

View File

@ -1,39 +1,39 @@
package nebula package cidr
import ( import (
"encoding/binary"
"fmt"
"net" "net"
"github.com/slackhq/nebula/iputil"
) )
type CIDRNode struct { type Node struct {
left *CIDRNode left *Node
right *CIDRNode right *Node
parent *CIDRNode parent *Node
value interface{} value interface{}
} }
type CIDRTree struct { type Tree4 struct {
root *CIDRNode root *Node
} }
const ( const (
startbit = uint32(0x80000000) startbit = iputil.VpnIp(0x80000000)
) )
func NewCIDRTree() *CIDRTree { func NewTree4() *Tree4 {
tree := new(CIDRTree) tree := new(Tree4)
tree.root = &CIDRNode{} tree.root = &Node{}
return tree return tree
} }
func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) { func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
bit := startbit bit := startbit
node := tree.root node := tree.root
next := tree.root next := tree.root
ip := ip2int(cidr.IP) ip := iputil.Ip2VpnIp(cidr.IP)
mask := ip2int(cidr.Mask) mask := iputil.Ip2VpnIp(cidr.Mask)
// Find our last ancestor in the tree // Find our last ancestor in the tree
for bit&mask != 0 { for bit&mask != 0 {
@ -59,7 +59,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
// Build up the rest of the tree we don't already have // Build up the rest of the tree we don't already have
for bit&mask != 0 { for bit&mask != 0 {
next = &CIDRNode{} next = &Node{}
next.parent = node next.parent = node
if ip&bit != 0 { if ip&bit != 0 {
@ -77,7 +77,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
} }
// Finds the first match, which may be the least specific // Finds the first match, which may be the least specific
func (tree *CIDRTree) Contains(ip uint32) (value interface{}) { func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
bit := startbit bit := startbit
node := tree.root node := tree.root
@ -100,7 +100,7 @@ func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
} }
// Finds the most specific match // Finds the most specific match
func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) { func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
bit := startbit bit := startbit
node := tree.root node := tree.root
@ -122,7 +122,7 @@ func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) {
} }
// Finds the most specific match // Finds the most specific match
func (tree *CIDRTree) Match(ip uint32) (value interface{}) { func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
bit := startbit bit := startbit
node := tree.root node := tree.root
lastNode := node lastNode := node
@ -143,27 +143,3 @@ func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
} }
return value return value
} }
// A helper type to avoid converting to IP when logging
type IntIp uint32
func (ip IntIp) String() string {
return fmt.Sprintf("%v", int2ip(uint32(ip)))
}
func (ip IntIp) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("\"%s\"", int2ip(uint32(ip)).String())), nil
}
func ip2int(ip []byte) uint32 {
if len(ip) == 16 {
return binary.BigEndian.Uint32(ip[12:16])
}
return binary.BigEndian.Uint32(ip)
}
func int2ip(nn uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, nn)
return ip
}

153
cidr/tree4_test.go Normal file
View File

@ -0,0 +1,153 @@
package cidr
import (
"net"
"testing"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
)
func TestCIDRTree_Contains(t *testing.T) {
tree := NewTree4()
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tests := []struct {
Result interface{}
IP string
}{
{"1", "1.0.0.0"},
{"1", "1.255.255.255"},
{"2", "2.1.0.0"},
{"2", "2.1.255.255"},
{"3", "3.1.1.0"},
{"3", "3.1.1.255"},
{"4a", "4.1.1.255"},
{"4a", "4.1.1.1"},
{"5", "240.0.0.0"},
{"5", "255.255.255.255"},
{nil, "239.0.0.0"},
{nil, "4.1.2.2"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
}
tree = NewTree4()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
}
func TestCIDRTree_MostSpecificContains(t *testing.T) {
tree := NewTree4()
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
tree.AddCIDR(Parse("4.1.1.0/30"), "4b")
tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tests := []struct {
Result interface{}
IP string
}{
{"1", "1.0.0.0"},
{"1", "1.255.255.255"},
{"2", "2.1.0.0"},
{"2", "2.1.255.255"},
{"3", "3.1.1.0"},
{"3", "3.1.1.255"},
{"4a", "4.1.1.255"},
{"4b", "4.1.1.2"},
{"4c", "4.1.1.1"},
{"5", "240.0.0.0"},
{"5", "255.255.255.255"},
{nil, "239.0.0.0"},
{nil, "4.1.2.2"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
}
tree = NewTree4()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
}
func TestCIDRTree_Match(t *testing.T) {
tree := NewTree4()
tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
tests := []struct {
Result interface{}
IP string
}{
{"1a", "4.1.1.0"},
{"1b", "4.1.1.1"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
}
tree = NewTree4()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
}
func BenchmarkCIDRTree_Contains(b *testing.B) {
tree := NewTree4()
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Contains(ip)
}
})
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Contains(ip)
}
})
}
func BenchmarkCIDRTree_Match(b *testing.B) {
tree := NewTree4()
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
}

View File

@ -1,26 +1,27 @@
package nebula package cidr
import ( import (
"encoding/binary"
"net" "net"
"github.com/slackhq/nebula/iputil"
) )
const startbit6 = uint64(1 << 63) const startbit6 = uint64(1 << 63)
type CIDR6Tree struct { type Tree6 struct {
root4 *CIDRNode root4 *Node
root6 *CIDRNode root6 *Node
} }
func NewCIDR6Tree() *CIDR6Tree { func NewTree6() *Tree6 {
tree := new(CIDR6Tree) tree := new(Tree6)
tree.root4 = &CIDRNode{} tree.root4 = &Node{}
tree.root6 = &CIDRNode{} tree.root6 = &Node{}
return tree return tree
} }
func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) { func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
var node, next *CIDRNode var node, next *Node
cidrIP, ipv4 := isIPV4(cidr.IP) cidrIP, ipv4 := isIPV4(cidr.IP)
if ipv4 { if ipv4 {
@ -33,8 +34,8 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
} }
for i := 0; i < len(cidrIP); i += 4 { for i := 0; i < len(cidrIP); i += 4 {
ip := binary.BigEndian.Uint32(cidrIP[i : i+4]) ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
mask := binary.BigEndian.Uint32(cidr.Mask[i : i+4]) mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
bit := startbit bit := startbit
// Find our last ancestor in the tree // Find our last ancestor in the tree
@ -55,7 +56,7 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
// Build up the rest of the tree we don't already have // Build up the rest of the tree we don't already have
for bit&mask != 0 { for bit&mask != 0 {
next = &CIDRNode{} next = &Node{}
next.parent = node next.parent = node
if ip&bit != 0 { if ip&bit != 0 {
@ -74,8 +75,8 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
} }
// Finds the most specific match // Finds the most specific match
func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) { func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
var node *CIDRNode var node *Node
wholeIP, ipv4 := isIPV4(ip) wholeIP, ipv4 := isIPV4(ip)
if ipv4 { if ipv4 {
@ -85,7 +86,7 @@ func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
} }
for i := 0; i < len(wholeIP); i += 4 { for i := 0; i < len(wholeIP); i += 4 {
ip := ip2int(wholeIP[i : i+4]) ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
bit := startbit bit := startbit
for node != nil { for node != nil {
@ -110,7 +111,7 @@ func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
return value return value
} }
func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) { func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) {
bit := startbit bit := startbit
node := tree.root4 node := tree.root4
@ -131,7 +132,7 @@ func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) {
return value return value
} }
func (tree *CIDR6Tree) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) { func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
ip := hi ip := hi
node := tree.root6 node := tree.root6

View File

@ -1,6 +1,7 @@
package nebula package cidr
import ( import (
"encoding/binary"
"net" "net"
"testing" "testing"
@ -8,17 +9,17 @@ import (
) )
func TestCIDR6Tree_MostSpecificContains(t *testing.T) { func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
tree := NewCIDR6Tree() tree := NewTree6()
tree.AddCIDR(getCIDR("1.0.0.0/8"), "1") tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(getCIDR("2.1.0.0/16"), "2") tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(getCIDR("3.1.1.0/24"), "3") tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(getCIDR("4.1.1.1/24"), "4a") tree.AddCIDR(Parse("4.1.1.1/24"), "4a")
tree.AddCIDR(getCIDR("4.1.1.1/30"), "4b") tree.AddCIDR(Parse("4.1.1.1/30"), "4b")
tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c") tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
tree.AddCIDR(getCIDR("254.0.0.0/4"), "5") tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a") tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b") tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c") tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
tests := []struct { tests := []struct {
Result interface{} Result interface{}
@ -46,9 +47,9 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP))) assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP)))
} }
tree = NewCIDR6Tree() tree = NewTree6()
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool") tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
tree.AddCIDR(getCIDR("::/0"), "cool6") tree.AddCIDR(Parse("::/0"), "cool6")
assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0"))) assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0")))
assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255"))) assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255")))
assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::"))) assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::")))
@ -56,10 +57,10 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
} }
func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
tree := NewCIDR6Tree() tree := NewTree6()
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a") tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b") tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c") tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
tests := []struct { tests := []struct {
Result interface{} Result interface{}
@ -71,7 +72,10 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
ip := NewIp6AndPort(net.ParseIP(tt.IP), 0) ip := net.ParseIP(tt.IP)
assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(ip.Hi, ip.Lo)) hi := binary.BigEndian.Uint64(ip[:8])
lo := binary.BigEndian.Uint64(ip[8:])
assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo))
} }
} }

View File

@ -1,157 +0,0 @@
package nebula
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCIDRTree_Contains(t *testing.T) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a")
tree.AddCIDR(getCIDR("4.1.1.1/32"), "4b")
tree.AddCIDR(getCIDR("4.1.2.1/32"), "4c")
tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
tests := []struct {
Result interface{}
IP string
}{
{"1", "1.0.0.0"},
{"1", "1.255.255.255"},
{"2", "2.1.0.0"},
{"2", "2.1.255.255"},
{"3", "3.1.1.0"},
{"3", "3.1.1.255"},
{"4a", "4.1.1.255"},
{"4a", "4.1.1.1"},
{"5", "240.0.0.0"},
{"5", "255.255.255.255"},
{nil, "239.0.0.0"},
{nil, "4.1.2.2"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.Contains(ip2int(net.ParseIP(tt.IP))))
}
tree = NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
}
func TestCIDRTree_MostSpecificContains(t *testing.T) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a")
tree.AddCIDR(getCIDR("4.1.1.0/30"), "4b")
tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c")
tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
tests := []struct {
Result interface{}
IP string
}{
{"1", "1.0.0.0"},
{"1", "1.255.255.255"},
{"2", "2.1.0.0"},
{"2", "2.1.255.255"},
{"3", "3.1.1.0"},
{"3", "3.1.1.255"},
{"4a", "4.1.1.255"},
{"4b", "4.1.1.2"},
{"4c", "4.1.1.1"},
{"5", "240.0.0.0"},
{"5", "255.255.255.255"},
{nil, "239.0.0.0"},
{nil, "4.1.2.2"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.MostSpecificContains(ip2int(net.ParseIP(tt.IP))))
}
tree = NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("255.255.255.255"))))
}
func TestCIDRTree_Match(t *testing.T) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("4.1.1.0/32"), "1a")
tree.AddCIDR(getCIDR("4.1.1.1/32"), "1b")
tests := []struct {
Result interface{}
IP string
}{
{"1a", "4.1.1.0"},
{"1b", "4.1.1.1"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.Match(ip2int(net.ParseIP(tt.IP))))
}
tree = NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
}
func BenchmarkCIDRTree_Contains(b *testing.B) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.0.0/16"), "1")
tree.AddCIDR(getCIDR("1.2.1.1/32"), "1")
tree.AddCIDR(getCIDR("192.2.1.1/32"), "1")
tree.AddCIDR(getCIDR("172.2.1.1/32"), "1")
ip := ip2int(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Contains(ip)
}
})
ip = ip2int(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Contains(ip)
}
})
}
func BenchmarkCIDRTree_Match(b *testing.B) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.0.0/16"), "1")
tree.AddCIDR(getCIDR("1.2.1.1/32"), "1")
tree.AddCIDR(getCIDR("192.2.1.1/32"), "1")
tree.AddCIDR(getCIDR("172.2.1.1/32"), "1")
ip := ip2int(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
ip = ip2int(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
}
func getCIDR(s string) *net.IPNet {
_, c, _ := net.ParseCIDR(s)
return c
}

View File

@ -7,6 +7,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
) )
// A version string that can be set with // A version string that can be set with
@ -49,14 +50,14 @@ func main() {
l := logrus.New() l := logrus.New()
l.Out = os.Stdout l.Out = os.Stdout
config := nebula.NewConfig(l) c := config.NewC(l)
err := config.Load(*configPath) err := c.Load(*configPath)
if err != nil { if err != nil {
fmt.Printf("failed to load config: %s", err) fmt.Printf("failed to load config: %s", err)
os.Exit(1) os.Exit(1)
} }
c, err := nebula.Main(config, *configTest, Build, l, nil) ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
switch v := err.(type) { switch v := err.(type) {
case nebula.ContextualError: case nebula.ContextualError:
@ -68,8 +69,8 @@ func main() {
} }
if !*configTest { if !*configTest {
c.Start() ctrl.Start()
c.ShutdownBlock() ctrl.ShutdownBlock()
} }
os.Exit(0) os.Exit(0)

View File

@ -9,6 +9,7 @@ import (
"github.com/kardianos/service" "github.com/kardianos/service"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
) )
var logger service.Logger var logger service.Logger
@ -27,13 +28,13 @@ func (p *program) Start(s service.Service) error {
l := logrus.New() l := logrus.New()
HookLogger(l) HookLogger(l)
config := nebula.NewConfig(l) c := config.NewC(l)
err := config.Load(*p.configPath) err := c.Load(*p.configPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to load config: %s", err) return fmt.Errorf("failed to load config: %s", err)
} }
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil) p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
if err != nil { if err != nil {
return err return err
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
) )
// A version string that can be set with // A version string that can be set with
@ -43,14 +44,14 @@ func main() {
l := logrus.New() l := logrus.New()
l.Out = os.Stdout l.Out = os.Stdout
config := nebula.NewConfig(l) c := config.NewC(l)
err := config.Load(*configPath) err := c.Load(*configPath)
if err != nil { if err != nil {
fmt.Printf("failed to load config: %s", err) fmt.Printf("failed to load config: %s", err)
os.Exit(1) os.Exit(1)
} }
c, err := nebula.Main(config, *configTest, Build, l, nil) ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
switch v := err.(type) { switch v := err.(type) {
case nebula.ContextualError: case nebula.ContextualError:
@ -62,8 +63,8 @@ func main() {
} }
if !*configTest { if !*configTest {
c.Start() ctrl.Start()
c.ShutdownBlock() ctrl.ShutdownBlock()
} }
os.Exit(0) os.Exit(0)

611
config.go
View File

@ -1,611 +0,0 @@
package nebula
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net"
"os"
"os/signal"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"syscall"
"time"
"github.com/imdario/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)
type Config struct {
path string
files []string
Settings map[interface{}]interface{}
oldSettings map[interface{}]interface{}
callbacks []func(*Config)
l *logrus.Logger
}
func NewConfig(l *logrus.Logger) *Config {
return &Config{
Settings: make(map[interface{}]interface{}),
l: l,
}
}
// Load will find all yaml files within path and load them in lexical order
func (c *Config) Load(path string) error {
c.path = path
c.files = make([]string, 0)
err := c.resolve(path, true)
if err != nil {
return err
}
if len(c.files) == 0 {
return fmt.Errorf("no config files found at %s", path)
}
sort.Strings(c.files)
err = c.parse()
if err != nil {
return err
}
return nil
}
func (c *Config) LoadString(raw string) error {
if raw == "" {
return errors.New("Empty configuration")
}
return c.parseRaw([]byte(raw))
}
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
// here should decide if they need to make a change to the current process before making the change. HasChanged can be
// used to help decide if a change is necessary.
// These functions should return quickly or spawn their own go routine if they will take a while
func (c *Config) RegisterReloadCallback(f func(*Config)) {
c.callbacks = append(c.callbacks, f)
}
// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
// k in both the old and new settings will be serialized, the result of the string comparison is returned.
// If k is an empty string the entire config is tested.
// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
// there is change when there actually wasn't any.
func (c *Config) HasChanged(k string) bool {
if c.oldSettings == nil {
return false
}
var (
nv interface{}
ov interface{}
)
if k == "" {
nv = c.Settings
ov = c.oldSettings
k = "all settings"
} else {
nv = c.get(k, c.Settings)
ov = c.get(k, c.oldSettings)
}
newVals, err := yaml.Marshal(nv)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
}
oldVals, err := yaml.Marshal(ov)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
}
return string(newVals) != string(oldVals)
}
// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
// original path provided to Load. The old settings are shallow copied for change detection after the reload.
func (c *Config) CatchHUP(ctx context.Context) {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGHUP)
go func() {
for {
select {
case <-ctx.Done():
signal.Stop(ch)
close(ch)
return
case <-ch:
c.l.Info("Caught HUP, reloading config")
c.ReloadConfig()
}
}
}()
}
func (c *Config) ReloadConfig() {
c.oldSettings = make(map[interface{}]interface{})
for k, v := range c.Settings {
c.oldSettings[k] = v
}
err := c.Load(c.path)
if err != nil {
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
return
}
for _, v := range c.callbacks {
v(c)
}
}
// GetString will get the string for k or return the default d if not found or invalid
func (c *Config) GetString(k, d string) string {
r := c.Get(k)
if r == nil {
return d
}
return fmt.Sprintf("%v", r)
}
// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
func (c *Config) GetStringSlice(k string, d []string) []string {
r := c.Get(k)
if r == nil {
return d
}
rv, ok := r.([]interface{})
if !ok {
return d
}
v := make([]string, len(rv))
for i := 0; i < len(v); i++ {
v[i] = fmt.Sprintf("%v", rv[i])
}
return v
}
// GetMap will get the map for k or return the default d if not found or invalid
func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
r := c.Get(k)
if r == nil {
return d
}
v, ok := r.(map[interface{}]interface{})
if !ok {
return d
}
return v
}
// GetInt will get the int for k or return the default d if not found or invalid
func (c *Config) GetInt(k string, d int) int {
r := c.GetString(k, strconv.Itoa(d))
v, err := strconv.Atoi(r)
if err != nil {
return d
}
return v
}
// GetBool will get the bool for k or return the default d if not found or invalid
func (c *Config) GetBool(k string, d bool) bool {
r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
v, err := strconv.ParseBool(r)
if err != nil {
switch r {
case "y", "yes":
return true
case "n", "no":
return false
}
return d
}
return v
}
// GetDuration will get the duration for k or return the default d if not found or invalid
func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
r := c.GetString(k, "")
v, err := time.ParseDuration(r)
if err != nil {
return d
}
return v
}
func (c *Config) GetLocalAllowList(k string) (*LocalAllowList, error) {
var nameRules []AllowListNameRule
handleKey := func(key string, value interface{}) (bool, error) {
if key == "interfaces" {
var err error
nameRules, err = c.getAllowListInterfaces(k, value)
if err != nil {
return false, err
}
return true, nil
}
return false, nil
}
al, err := c.GetAllowList(k, handleKey)
if err != nil {
return nil, err
}
return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil
}
func (c *Config) GetRemoteAllowList(k, rangesKey string) (*RemoteAllowList, error) {
al, err := c.GetAllowList(k, nil)
if err != nil {
return nil, err
}
remoteAllowRanges, err := c.getRemoteAllowRanges(rangesKey)
if err != nil {
return nil, err
}
return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil
}
func (c *Config) getRemoteAllowRanges(k string) (*CIDR6Tree, error) {
value := c.Get(k)
if value == nil {
return nil, nil
}
remoteAllowRanges := NewCIDR6Tree()
rawMap, ok := value.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
}
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
allowList, err := c.getAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
if err != nil {
return nil, err
}
_, cidr, err := net.ParseCIDR(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
remoteAllowRanges.AddCIDR(cidr, allowList)
}
return remoteAllowRanges, nil
}
// If the handleKey func returns true, the rest of the parsing is skipped
// for this key. This allows parsing of special values like `interfaces`.
func (c *Config) GetAllowList(k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
r := c.Get(k)
if r == nil {
return nil, nil
}
return c.getAllowList(k, r, handleKey)
}
// If the handleKey func returns true, the rest of the parsing is skipped
// for this key. This allows parsing of special values like `interfaces`.
func (c *Config) getAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
rawMap, ok := raw.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
}
tree := NewCIDR6Tree()
// Keep track of the rules we have added for both ipv4 and ipv6
type allowListRules struct {
firstValue bool
allValuesMatch bool
defaultSet bool
allValues bool
}
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
if handleKey != nil {
handled, err := handleKey(rawCIDR, rawValue)
if err != nil {
return nil, err
}
if handled {
continue
}
}
value, ok := rawValue.(bool)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
}
_, cidr, err := net.ParseCIDR(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
// TODO: should we error on duplicate CIDRs in the config?
tree.AddCIDR(cidr, value)
maskBits, maskSize := cidr.Mask.Size()
var rules *allowListRules
if maskSize == 32 {
rules = &rules4
} else {
rules = &rules6
}
if rules.firstValue {
rules.allValues = value
rules.firstValue = false
} else {
if value != rules.allValues {
rules.allValuesMatch = false
}
}
// Check if this is 0.0.0.0/0 or ::/0
if maskBits == 0 {
rules.defaultSet = true
}
}
if !rules4.defaultSet {
if rules4.allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
tree.AddCIDR(zeroCIDR, !rules4.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
}
}
if !rules6.defaultSet {
if rules6.allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("::/0")
tree.AddCIDR(zeroCIDR, !rules6.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
}
}
return &AllowList{cidrTree: tree}, nil
}
func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
var nameRules []AllowListNameRule
rawRules, ok := v.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
}
firstEntry := true
var allValues bool
for rawName, rawAllow := range rawRules {
name, ok := rawName.(string)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
}
allow, ok := rawAllow.(bool)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
}
nameRE, err := regexp.Compile("^" + name + "$")
if err != nil {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
}
nameRules = append(nameRules, AllowListNameRule{
Name: nameRE,
Allow: allow,
})
if firstEntry {
allValues = allow
firstEntry = false
} else {
if allow != allValues {
return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
}
}
}
return nameRules, nil
}
func (c *Config) Get(k string) interface{} {
return c.get(k, c.Settings)
}
func (c *Config) IsSet(k string) bool {
return c.get(k, c.Settings) != nil
}
func (c *Config) get(k string, v interface{}) interface{} {
parts := strings.Split(k, ".")
for _, p := range parts {
m, ok := v.(map[interface{}]interface{})
if !ok {
return nil
}
v, ok = m[p]
if !ok {
return nil
}
}
return v
}
// direct signifies if this is the config path directly specified by the user,
// versus a file/dir found by recursing into that path
func (c *Config) resolve(path string, direct bool) error {
i, err := os.Stat(path)
if err != nil {
return nil
}
if !i.IsDir() {
c.addFile(path, direct)
return nil
}
paths, err := readDirNames(path)
if err != nil {
return fmt.Errorf("problem while reading directory %s: %s", path, err)
}
for _, p := range paths {
err := c.resolve(filepath.Join(path, p), false)
if err != nil {
return err
}
}
return nil
}
func (c *Config) addFile(path string, direct bool) error {
ext := filepath.Ext(path)
if !direct && ext != ".yaml" && ext != ".yml" {
return nil
}
ap, err := filepath.Abs(path)
if err != nil {
return err
}
c.files = append(c.files, ap)
return nil
}
func (c *Config) parseRaw(b []byte) error {
var m map[interface{}]interface{}
err := yaml.Unmarshal(b, &m)
if err != nil {
return err
}
c.Settings = m
return nil
}
func (c *Config) parse() error {
var m map[interface{}]interface{}
for _, path := range c.files {
b, err := ioutil.ReadFile(path)
if err != nil {
return err
}
var nm map[interface{}]interface{}
err = yaml.Unmarshal(b, &nm)
if err != nil {
return err
}
// We need to use WithAppendSlice so that firewall rules in separate
// files are appended together
err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
m = nm
if err != nil {
return err
}
}
c.Settings = m
return nil
}
func readDirNames(path string) ([]string, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
paths, err := f.Readdirnames(-1)
f.Close()
if err != nil {
return nil, err
}
sort.Strings(paths)
return paths, nil
}
func configLogger(c *Config) error {
// set up our logging level
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
if err != nil {
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
}
c.l.SetLevel(logLevel)
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
timestampFormat := c.GetString("logging.timestamp_format", "")
fullTimestamp := (timestampFormat != "")
if timestampFormat == "" {
timestampFormat = time.RFC3339
}
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
switch logFormat {
case "text":
c.l.Formatter = &logrus.TextFormatter{
TimestampFormat: timestampFormat,
FullTimestamp: fullTimestamp,
DisableTimestamp: disableTimestamp,
}
case "json":
c.l.Formatter = &logrus.JSONFormatter{
TimestampFormat: timestampFormat,
DisableTimestamp: disableTimestamp,
}
default:
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
}
return nil
}

358
config/config.go Normal file
View File

@ -0,0 +1,358 @@
package config
import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
"os/signal"
"path/filepath"
"sort"
"strconv"
"strings"
"syscall"
"time"
"github.com/imdario/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)
type C struct {
path string
files []string
Settings map[interface{}]interface{}
oldSettings map[interface{}]interface{}
callbacks []func(*C)
l *logrus.Logger
}
func NewC(l *logrus.Logger) *C {
return &C{
Settings: make(map[interface{}]interface{}),
l: l,
}
}
// Load will find all yaml files within path and load them in lexical order
func (c *C) Load(path string) error {
c.path = path
c.files = make([]string, 0)
err := c.resolve(path, true)
if err != nil {
return err
}
if len(c.files) == 0 {
return fmt.Errorf("no config files found at %s", path)
}
sort.Strings(c.files)
err = c.parse()
if err != nil {
return err
}
return nil
}
func (c *C) LoadString(raw string) error {
if raw == "" {
return errors.New("Empty configuration")
}
return c.parseRaw([]byte(raw))
}
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
// here should decide if they need to make a change to the current process before making the change. HasChanged can be
// used to help decide if a change is necessary.
// These functions should return quickly or spawn their own go routine if they will take a while
func (c *C) RegisterReloadCallback(f func(*C)) {
c.callbacks = append(c.callbacks, f)
}
// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
// k in both the old and new settings will be serialized, the result of the string comparison is returned.
// If k is an empty string the entire config is tested.
// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
// there is change when there actually wasn't any.
func (c *C) HasChanged(k string) bool {
if c.oldSettings == nil {
return false
}
var (
nv interface{}
ov interface{}
)
if k == "" {
nv = c.Settings
ov = c.oldSettings
k = "all settings"
} else {
nv = c.get(k, c.Settings)
ov = c.get(k, c.oldSettings)
}
newVals, err := yaml.Marshal(nv)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
}
oldVals, err := yaml.Marshal(ov)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
}
return string(newVals) != string(oldVals)
}
// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
// original path provided to Load. The old settings are shallow copied for change detection after the reload.
func (c *C) CatchHUP(ctx context.Context) {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGHUP)
go func() {
for {
select {
case <-ctx.Done():
signal.Stop(ch)
close(ch)
return
case <-ch:
c.l.Info("Caught HUP, reloading config")
c.ReloadConfig()
}
}
}()
}
func (c *C) ReloadConfig() {
c.oldSettings = make(map[interface{}]interface{})
for k, v := range c.Settings {
c.oldSettings[k] = v
}
err := c.Load(c.path)
if err != nil {
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
return
}
for _, v := range c.callbacks {
v(c)
}
}
// GetString will get the string for k or return the default d if not found or invalid
func (c *C) GetString(k, d string) string {
r := c.Get(k)
if r == nil {
return d
}
return fmt.Sprintf("%v", r)
}
// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
func (c *C) GetStringSlice(k string, d []string) []string {
r := c.Get(k)
if r == nil {
return d
}
rv, ok := r.([]interface{})
if !ok {
return d
}
v := make([]string, len(rv))
for i := 0; i < len(v); i++ {
v[i] = fmt.Sprintf("%v", rv[i])
}
return v
}
// GetMap will get the map for k or return the default d if not found or invalid
func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
r := c.Get(k)
if r == nil {
return d
}
v, ok := r.(map[interface{}]interface{})
if !ok {
return d
}
return v
}
// GetInt will get the int for k or return the default d if not found or invalid
func (c *C) GetInt(k string, d int) int {
r := c.GetString(k, strconv.Itoa(d))
v, err := strconv.Atoi(r)
if err != nil {
return d
}
return v
}
// GetBool will get the bool for k or return the default d if not found or invalid
func (c *C) GetBool(k string, d bool) bool {
r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
v, err := strconv.ParseBool(r)
if err != nil {
switch r {
case "y", "yes":
return true
case "n", "no":
return false
}
return d
}
return v
}
// GetDuration will get the duration for k or return the default d if not found or invalid
func (c *C) GetDuration(k string, d time.Duration) time.Duration {
r := c.GetString(k, "")
v, err := time.ParseDuration(r)
if err != nil {
return d
}
return v
}
func (c *C) Get(k string) interface{} {
return c.get(k, c.Settings)
}
func (c *C) IsSet(k string) bool {
return c.get(k, c.Settings) != nil
}
func (c *C) get(k string, v interface{}) interface{} {
parts := strings.Split(k, ".")
for _, p := range parts {
m, ok := v.(map[interface{}]interface{})
if !ok {
return nil
}
v, ok = m[p]
if !ok {
return nil
}
}
return v
}
// direct signifies if this is the config path directly specified by the user,
// versus a file/dir found by recursing into that path
func (c *C) resolve(path string, direct bool) error {
i, err := os.Stat(path)
if err != nil {
return nil
}
if !i.IsDir() {
c.addFile(path, direct)
return nil
}
paths, err := readDirNames(path)
if err != nil {
return fmt.Errorf("problem while reading directory %s: %s", path, err)
}
for _, p := range paths {
err := c.resolve(filepath.Join(path, p), false)
if err != nil {
return err
}
}
return nil
}
func (c *C) addFile(path string, direct bool) error {
ext := filepath.Ext(path)
if !direct && ext != ".yaml" && ext != ".yml" {
return nil
}
ap, err := filepath.Abs(path)
if err != nil {
return err
}
c.files = append(c.files, ap)
return nil
}
func (c *C) parseRaw(b []byte) error {
var m map[interface{}]interface{}
err := yaml.Unmarshal(b, &m)
if err != nil {
return err
}
c.Settings = m
return nil
}
func (c *C) parse() error {
var m map[interface{}]interface{}
for _, path := range c.files {
b, err := ioutil.ReadFile(path)
if err != nil {
return err
}
var nm map[interface{}]interface{}
err = yaml.Unmarshal(b, &nm)
if err != nil {
return err
}
// We need to use WithAppendSlice so that firewall rules in separate
// files are appended together
err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
m = nm
if err != nil {
return err
}
}
c.Settings = m
return nil
}
func readDirNames(path string) ([]string, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
paths, err := f.Readdirnames(-1)
f.Close()
if err != nil {
return nil, err
}
sort.Strings(paths)
return paths, nil
}

View File

@ -1,4 +1,4 @@
package nebula package config
import ( import (
"io/ioutil" "io/ioutil"
@ -7,19 +7,20 @@ import (
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestConfig_Load(t *testing.T) { func TestConfig_Load(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
dir, err := ioutil.TempDir("", "config-test") dir, err := ioutil.TempDir("", "config-test")
// invalid yaml // invalid yaml
c := NewConfig(l) c := NewC(l)
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
// simple multi config merge // simple multi config merge
c = NewConfig(l) c = NewC(l)
os.RemoveAll(dir) os.RemoveAll(dir)
os.Mkdir(dir, 0755) os.Mkdir(dir, 0755)
@ -41,9 +42,9 @@ func TestConfig_Load(t *testing.T) {
} }
func TestConfig_Get(t *testing.T) { func TestConfig_Get(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
// test simple type // test simple type
c := NewConfig(l) c := NewC(l)
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
assert.Equal(t, "hi", c.Get("firewall.outbound")) assert.Equal(t, "hi", c.Get("firewall.outbound"))
@ -57,15 +58,15 @@ func TestConfig_Get(t *testing.T) {
} }
func TestConfig_GetStringSlice(t *testing.T) { func TestConfig_GetStringSlice(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
c := NewConfig(l) c := NewC(l)
c.Settings["slice"] = []interface{}{"one", "two"} c.Settings["slice"] = []interface{}{"one", "two"}
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
} }
func TestConfig_GetBool(t *testing.T) { func TestConfig_GetBool(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
c := NewConfig(l) c := NewC(l)
c.Settings["bool"] = true c.Settings["bool"] = true
assert.Equal(t, true, c.GetBool("bool", false)) assert.Equal(t, true, c.GetBool("bool", false))
@ -91,108 +92,22 @@ func TestConfig_GetBool(t *testing.T) {
assert.Equal(t, false, c.GetBool("bool", true)) assert.Equal(t, false, c.GetBool("bool", true))
} }
func TestConfig_GetAllowList(t *testing.T) {
l := NewTestLogger()
c := NewConfig(l)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0": true,
}
r, err := c.GetAllowList("allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
assert.Nil(t, r)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": "abc",
}
r, err = c.GetAllowList("allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": true,
"10.0.0.0/8": false,
}
r, err = c.GetAllowList("allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
"fd00::/8": true,
"fd00:fd00::/16": false,
}
r, err = c.GetAllowList("allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
}
r, err = c.GetAllowList("allowlist", nil)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
"::/0": false,
"fd00::/8": true,
"fd00:fd00::/16": false,
}
r, err = c.GetAllowList("allowlist", nil)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
// Test interface names
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: "foo",
},
}
lr, err := c.GetLocalAllowList("allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
`eth.*`: true,
},
}
lr, err = c.GetLocalAllowList("allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
},
}
lr, err = c.GetLocalAllowList("allowlist")
if assert.NoError(t, err) {
assert.NotNil(t, lr)
}
}
func TestConfig_HasChanged(t *testing.T) { func TestConfig_HasChanged(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
// No reload has occurred, return false // No reload has occurred, return false
c := NewConfig(l) c := NewC(l)
c.Settings["test"] = "hi" c.Settings["test"] = "hi"
assert.False(t, c.HasChanged("")) assert.False(t, c.HasChanged(""))
// Test key change // Test key change
c = NewConfig(l) c = NewC(l)
c.Settings["test"] = "hi" c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "no"} c.oldSettings = map[interface{}]interface{}{"test": "no"}
assert.True(t, c.HasChanged("test")) assert.True(t, c.HasChanged("test"))
assert.True(t, c.HasChanged("")) assert.True(t, c.HasChanged(""))
// No key change // No key change
c = NewConfig(l) c = NewC(l)
c.Settings["test"] = "hi" c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "hi"} c.oldSettings = map[interface{}]interface{}{"test": "hi"}
assert.False(t, c.HasChanged("test")) assert.False(t, c.HasChanged("test"))
@ -200,13 +115,13 @@ func TestConfig_HasChanged(t *testing.T) {
} }
func TestConfig_ReloadConfig(t *testing.T) { func TestConfig_ReloadConfig(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
done := make(chan bool, 1) done := make(chan bool, 1)
dir, err := ioutil.TempDir("", "config-test") dir, err := ioutil.TempDir("", "config-test")
assert.Nil(t, err) assert.Nil(t, err)
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
c := NewConfig(l) c := NewC(l)
assert.Nil(t, c.Load(dir)) assert.Nil(t, c.Load(dir))
assert.False(t, c.HasChanged("outer.inner")) assert.False(t, c.HasChanged("outer.inner"))
@ -215,7 +130,7 @@ func TestConfig_ReloadConfig(t *testing.T) {
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644) ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644)
c.RegisterReloadCallback(func(c *Config) { c.RegisterReloadCallback(func(c *C) {
done <- true done <- true
}) })

View File

@ -6,6 +6,8 @@ import (
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
) )
// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet // TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
@ -13,16 +15,16 @@ import (
type connectionManager struct { type connectionManager struct {
hostMap *HostMap hostMap *HostMap
in map[uint32]struct{} in map[iputil.VpnIp]struct{}
inLock *sync.RWMutex inLock *sync.RWMutex
inCount int inCount int
out map[uint32]struct{} out map[iputil.VpnIp]struct{}
outLock *sync.RWMutex outLock *sync.RWMutex
outCount int outCount int
TrafficTimer *SystemTimerWheel TrafficTimer *SystemTimerWheel
intf *Interface intf *Interface
pendingDeletion map[uint32]int pendingDeletion map[iputil.VpnIp]int
pendingDeletionLock *sync.RWMutex pendingDeletionLock *sync.RWMutex
pendingDeletionTimer *SystemTimerWheel pendingDeletionTimer *SystemTimerWheel
@ -36,15 +38,15 @@ type connectionManager struct {
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager { func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
nc := &connectionManager{ nc := &connectionManager{
hostMap: intf.hostMap, hostMap: intf.hostMap,
in: make(map[uint32]struct{}), in: make(map[iputil.VpnIp]struct{}),
inLock: &sync.RWMutex{}, inLock: &sync.RWMutex{},
inCount: 0, inCount: 0,
out: make(map[uint32]struct{}), out: make(map[iputil.VpnIp]struct{}),
outLock: &sync.RWMutex{}, outLock: &sync.RWMutex{},
outCount: 0, outCount: 0,
TrafficTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60), TrafficTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
intf: intf, intf: intf,
pendingDeletion: make(map[uint32]int), pendingDeletion: make(map[iputil.VpnIp]int),
pendingDeletionLock: &sync.RWMutex{}, pendingDeletionLock: &sync.RWMutex{},
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60), pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
checkInterval: checkInterval, checkInterval: checkInterval,
@ -55,7 +57,7 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
return nc return nc
} }
func (n *connectionManager) In(ip uint32) { func (n *connectionManager) In(ip iputil.VpnIp) {
n.inLock.RLock() n.inLock.RLock()
// If this already exists, return // If this already exists, return
if _, ok := n.in[ip]; ok { if _, ok := n.in[ip]; ok {
@ -68,7 +70,7 @@ func (n *connectionManager) In(ip uint32) {
n.inLock.Unlock() n.inLock.Unlock()
} }
func (n *connectionManager) Out(ip uint32) { func (n *connectionManager) Out(ip iputil.VpnIp) {
n.outLock.RLock() n.outLock.RLock()
// If this already exists, return // If this already exists, return
if _, ok := n.out[ip]; ok { if _, ok := n.out[ip]; ok {
@ -87,9 +89,9 @@ func (n *connectionManager) Out(ip uint32) {
n.outLock.Unlock() n.outLock.Unlock()
} }
func (n *connectionManager) CheckIn(vpnIP uint32) bool { func (n *connectionManager) CheckIn(vpnIp iputil.VpnIp) bool {
n.inLock.RLock() n.inLock.RLock()
if _, ok := n.in[vpnIP]; ok { if _, ok := n.in[vpnIp]; ok {
n.inLock.RUnlock() n.inLock.RUnlock()
return true return true
} }
@ -97,7 +99,7 @@ func (n *connectionManager) CheckIn(vpnIP uint32) bool {
return false return false
} }
func (n *connectionManager) ClearIP(ip uint32) { func (n *connectionManager) ClearIP(ip iputil.VpnIp) {
n.inLock.Lock() n.inLock.Lock()
n.outLock.Lock() n.outLock.Lock()
delete(n.in, ip) delete(n.in, ip)
@ -106,13 +108,13 @@ func (n *connectionManager) ClearIP(ip uint32) {
n.outLock.Unlock() n.outLock.Unlock()
} }
func (n *connectionManager) ClearPendingDeletion(ip uint32) { func (n *connectionManager) ClearPendingDeletion(ip iputil.VpnIp) {
n.pendingDeletionLock.Lock() n.pendingDeletionLock.Lock()
delete(n.pendingDeletion, ip) delete(n.pendingDeletion, ip)
n.pendingDeletionLock.Unlock() n.pendingDeletionLock.Unlock()
} }
func (n *connectionManager) AddPendingDeletion(ip uint32) { func (n *connectionManager) AddPendingDeletion(ip iputil.VpnIp) {
n.pendingDeletionLock.Lock() n.pendingDeletionLock.Lock()
if _, ok := n.pendingDeletion[ip]; ok { if _, ok := n.pendingDeletion[ip]; ok {
n.pendingDeletion[ip] += 1 n.pendingDeletion[ip] += 1
@ -123,7 +125,7 @@ func (n *connectionManager) AddPendingDeletion(ip uint32) {
n.pendingDeletionLock.Unlock() n.pendingDeletionLock.Unlock()
} }
func (n *connectionManager) checkPendingDeletion(ip uint32) bool { func (n *connectionManager) checkPendingDeletion(ip iputil.VpnIp) bool {
n.pendingDeletionLock.RLock() n.pendingDeletionLock.RLock()
if _, ok := n.pendingDeletion[ip]; ok { if _, ok := n.pendingDeletion[ip]; ok {
@ -134,8 +136,8 @@ func (n *connectionManager) checkPendingDeletion(ip uint32) bool {
return false return false
} }
func (n *connectionManager) AddTrafficWatch(vpnIP uint32, seconds int) { func (n *connectionManager) AddTrafficWatch(vpnIp iputil.VpnIp, seconds int) {
n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds)) n.TrafficTimer.Add(vpnIp, time.Second*time.Duration(seconds))
} }
func (n *connectionManager) Start(ctx context.Context) { func (n *connectionManager) Start(ctx context.Context) {
@ -169,23 +171,23 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
break break
} }
vpnIP := ep.(uint32) vpnIp := ep.(iputil.VpnIp)
// Check for traffic coming back in from this host. // Check for traffic coming back in from this host.
traf := n.CheckIn(vpnIP) traf := n.CheckIn(vpnIp)
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) n.l.Debugf("Not found in hostmap: %s", vpnIp)
if !n.intf.disconnectInvalid { if !n.intf.disconnectInvalid {
n.ClearIP(vpnIP) n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIP) n.ClearPendingDeletion(vpnIp)
continue continue
} }
} }
if n.handleInvalidCertificate(now, vpnIP, hostinfo) { if n.handleInvalidCertificate(now, vpnIp, hostinfo) {
continue continue
} }
@ -193,12 +195,12 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
// expired, just ignore. // expired, just ignore.
if traf { if traf {
if n.l.Level >= logrus.DebugLevel { if n.l.Level >= logrus.DebugLevel {
n.l.WithField("vpnIp", IntIp(vpnIP)). n.l.WithField("vpnIp", vpnIp).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status") Debug("Tunnel status")
} }
n.ClearIP(vpnIP) n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIP) n.ClearPendingDeletion(vpnIp)
continue continue
} }
@ -208,12 +210,12 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
if hostinfo != nil && hostinfo.ConnectionState != nil { if hostinfo != nil && hostinfo.ConnectionState != nil {
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out) n.intf.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, p, nb, out)
} else { } else {
hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP)) hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", vpnIp)
} }
n.AddPendingDeletion(vpnIP) n.AddPendingDeletion(vpnIp)
} }
} }
@ -226,38 +228,38 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
break break
} }
vpnIP := ep.(uint32) vpnIp := ep.(iputil.VpnIp)
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) n.l.Debugf("Not found in hostmap: %s", vpnIp)
if !n.intf.disconnectInvalid { if !n.intf.disconnectInvalid {
n.ClearIP(vpnIP) n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIP) n.ClearPendingDeletion(vpnIp)
continue continue
} }
} }
if n.handleInvalidCertificate(now, vpnIP, hostinfo) { if n.handleInvalidCertificate(now, vpnIp, hostinfo) {
continue continue
} }
// If we saw an incoming packets from this ip and peer's certificate is not // If we saw an incoming packets from this ip and peer's certificate is not
// expired, just ignore. // expired, just ignore.
traf := n.CheckIn(vpnIP) traf := n.CheckIn(vpnIp)
if traf { if traf {
n.l.WithField("vpnIp", IntIp(vpnIP)). n.l.WithField("vpnIp", vpnIp).
WithField("tunnelCheck", m{"state": "alive", "method": "active"}). WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
Debug("Tunnel status") Debug("Tunnel status")
n.ClearIP(vpnIP) n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIP) n.ClearPendingDeletion(vpnIp)
continue continue
} }
// If it comes around on deletion wheel and hasn't resolved itself, delete // If it comes around on deletion wheel and hasn't resolved itself, delete
if n.checkPendingDeletion(vpnIP) { if n.checkPendingDeletion(vpnIp) {
cn := "" cn := ""
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil { if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
cn = hostinfo.ConnectionState.peerCert.Details.Name cn = hostinfo.ConnectionState.peerCert.Details.Name
@ -267,22 +269,22 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
WithField("certName", cn). WithField("certName", cn).
Info("Tunnel status") Info("Tunnel status")
n.ClearIP(vpnIP) n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIP) n.ClearPendingDeletion(vpnIp)
// TODO: This is only here to let tests work. Should do proper mocking // TODO: This is only here to let tests work. Should do proper mocking
if n.intf.lightHouse != nil { if n.intf.lightHouse != nil {
n.intf.lightHouse.DeleteVpnIP(vpnIP) n.intf.lightHouse.DeleteVpnIp(vpnIp)
} }
n.hostMap.DeleteHostInfo(hostinfo) n.hostMap.DeleteHostInfo(hostinfo)
} else { } else {
n.ClearIP(vpnIP) n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIP) n.ClearPendingDeletion(vpnIp)
} }
} }
} }
// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid // handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32, hostinfo *HostInfo) bool { func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIp iputil.VpnIp, hostinfo *HostInfo) bool {
if !n.intf.disconnectInvalid { if !n.intf.disconnectInvalid {
return false return false
} }
@ -298,7 +300,7 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32
} }
fingerprint, _ := remoteCert.Sha256Sum() fingerprint, _ := remoteCert.Sha256Sum()
n.l.WithField("vpnIp", IntIp(vpnIP)).WithError(err). n.l.WithField("vpnIp", vpnIp).WithError(err).
WithField("certName", remoteCert.Details.Name). WithField("certName", remoteCert.Details.Name).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel") Info("Remote certificate is no longer valid, tearing down the tunnel")
@ -307,7 +309,7 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32
n.intf.sendCloseTunnel(hostinfo) n.intf.sendCloseTunnel(hostinfo)
n.intf.closeTunnel(hostinfo, false) n.intf.closeTunnel(hostinfo, false)
n.ClearIP(vpnIP) n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIP) n.ClearPendingDeletion(vpnIp)
return true return true
} }

View File

@ -10,17 +10,20 @@ import (
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var vpnIP uint32 var vpnIp iputil.VpnIp
func Test_NewConnectionManagerTest(t *testing.T) { func Test_NewConnectionManagerTest(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
vpnIP = ip2int(net.ParseIP("172.1.1.2")) vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects // Very incomplete mock objects
@ -32,15 +35,15 @@ func Test_NewConnectionManagerTest(t *testing.T) {
rawCertificateNoKey: []byte{}, rawCertificateNoKey: []byte{},
} }
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &Tun{}, inside: &Tun{},
outside: &udpConn{}, outside: &udp.Conn{},
certState: cs, certState: cs,
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
l: l, l: l,
} }
now := time.Now() now := time.Now()
@ -54,16 +57,16 @@ func Test_NewConnectionManagerTest(t *testing.T) {
out := make([]byte, mtu) out := make([]byte, mtu)
nc.HandleMonitorTick(now, p, nb, out) nc.HandleMonitorTick(now, p, nb, out)
// Add an ip we have established a connection w/ to hostmap // Add an ip we have established a connection w/ to hostmap
hostinfo := nc.hostMap.AddVpnIP(vpnIP) hostinfo := nc.hostMap.AddVpnIp(vpnIp)
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
certState: cs, certState: cs,
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
} }
// We saw traffic out to vpnIP // We saw traffic out to vpnIp
nc.Out(vpnIP) nc.Out(vpnIp)
assert.NotContains(t, nc.pendingDeletion, vpnIP) assert.NotContains(t, nc.pendingDeletion, vpnIp)
assert.Contains(t, nc.hostMap.Hosts, vpnIP) assert.Contains(t, nc.hostMap.Hosts, vpnIp)
// Move ahead 5s. Nothing should happen // Move ahead 5s. Nothing should happen
next_tick := now.Add(5 * time.Second) next_tick := now.Add(5 * time.Second)
nc.HandleMonitorTick(next_tick, p, nb, out) nc.HandleMonitorTick(next_tick, p, nb, out)
@ -73,20 +76,20 @@ func Test_NewConnectionManagerTest(t *testing.T) {
nc.HandleMonitorTick(next_tick, p, nb, out) nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick) nc.HandleDeletionTick(next_tick)
// This host should now be up for deletion // This host should now be up for deletion
assert.Contains(t, nc.pendingDeletion, vpnIP) assert.Contains(t, nc.pendingDeletion, vpnIp)
assert.Contains(t, nc.hostMap.Hosts, vpnIP) assert.Contains(t, nc.hostMap.Hosts, vpnIp)
// Move ahead some more // Move ahead some more
next_tick = now.Add(45 * time.Second) next_tick = now.Add(45 * time.Second)
nc.HandleMonitorTick(next_tick, p, nb, out) nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick) nc.HandleDeletionTick(next_tick)
// The host should be evicted // The host should be evicted
assert.NotContains(t, nc.pendingDeletion, vpnIP) assert.NotContains(t, nc.pendingDeletion, vpnIp)
assert.NotContains(t, nc.hostMap.Hosts, vpnIP) assert.NotContains(t, nc.hostMap.Hosts, vpnIp)
} }
func Test_NewConnectionManagerTest2(t *testing.T) { func Test_NewConnectionManagerTest2(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@ -101,15 +104,15 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
rawCertificateNoKey: []byte{}, rawCertificateNoKey: []byte{},
} }
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &Tun{}, inside: &Tun{},
outside: &udpConn{}, outside: &udp.Conn{},
certState: cs, certState: cs,
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
l: l, l: l,
} }
now := time.Now() now := time.Now()
@ -123,16 +126,16 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
out := make([]byte, mtu) out := make([]byte, mtu)
nc.HandleMonitorTick(now, p, nb, out) nc.HandleMonitorTick(now, p, nb, out)
// Add an ip we have established a connection w/ to hostmap // Add an ip we have established a connection w/ to hostmap
hostinfo := nc.hostMap.AddVpnIP(vpnIP) hostinfo := nc.hostMap.AddVpnIp(vpnIp)
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
certState: cs, certState: cs,
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
} }
// We saw traffic out to vpnIP // We saw traffic out to vpnIp
nc.Out(vpnIP) nc.Out(vpnIp)
assert.NotContains(t, nc.pendingDeletion, vpnIP) assert.NotContains(t, nc.pendingDeletion, vpnIp)
assert.Contains(t, nc.hostMap.Hosts, vpnIP) assert.Contains(t, nc.hostMap.Hosts, vpnIp)
// Move ahead 5s. Nothing should happen // Move ahead 5s. Nothing should happen
next_tick := now.Add(5 * time.Second) next_tick := now.Add(5 * time.Second)
nc.HandleMonitorTick(next_tick, p, nb, out) nc.HandleMonitorTick(next_tick, p, nb, out)
@ -142,17 +145,17 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
nc.HandleMonitorTick(next_tick, p, nb, out) nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick) nc.HandleDeletionTick(next_tick)
// This host should now be up for deletion // This host should now be up for deletion
assert.Contains(t, nc.pendingDeletion, vpnIP) assert.Contains(t, nc.pendingDeletion, vpnIp)
assert.Contains(t, nc.hostMap.Hosts, vpnIP) assert.Contains(t, nc.hostMap.Hosts, vpnIp)
// We heard back this time // We heard back this time
nc.In(vpnIP) nc.In(vpnIp)
// Move ahead some more // Move ahead some more
next_tick = now.Add(45 * time.Second) next_tick = now.Add(45 * time.Second)
nc.HandleMonitorTick(next_tick, p, nb, out) nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick) nc.HandleDeletionTick(next_tick)
// The host should be evicted // The host should be evicted
assert.NotContains(t, nc.pendingDeletion, vpnIP) assert.NotContains(t, nc.pendingDeletion, vpnIp)
assert.Contains(t, nc.hostMap.Hosts, vpnIP) assert.Contains(t, nc.hostMap.Hosts, vpnIp)
} }
@ -161,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Disconnect only if disconnectInvalid: true is set. // Disconnect only if disconnectInvalid: true is set.
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
now := time.Now() now := time.Now()
l := NewTestLogger() l := util.NewTestLogger()
ipNet := net.IPNet{ ipNet := net.IPNet{
IP: net.IPv4(172, 1, 1, 2), IP: net.IPv4(172, 1, 1, 2),
Mask: net.IPMask{255, 255, 255, 0}, Mask: net.IPMask{255, 255, 255, 0},
@ -210,15 +213,15 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
rawCertificateNoKey: []byte{}, rawCertificateNoKey: []byte{},
} }
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &Tun{}, inside: &Tun{},
outside: &udpConn{}, outside: &udp.Conn{},
certState: cs, certState: cs,
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
l: l, l: l,
disconnectInvalid: true, disconnectInvalid: true,
caPool: ncp, caPool: ncp,
@ -229,7 +232,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
defer cancel() defer cancel()
nc := newConnectionManager(ctx, l, ifce, 5, 10) nc := newConnectionManager(ctx, l, ifce, 5, 10)
ifce.connectionManager = nc ifce.connectionManager = nc
hostinfo := nc.hostMap.AddVpnIP(vpnIP) hostinfo := nc.hostMap.AddVpnIp(vpnIp)
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
certState: cs, certState: cs,
peerCert: &peerCert, peerCert: &peerCert,
@ -240,13 +243,13 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
// Check if to disconnect with invalid certificate. // Check if to disconnect with invalid certificate.
// Should be alive. // Should be alive.
nextTick := now.Add(45 * time.Second) nextTick := now.Add(45 * time.Second)
destroyed := nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo) destroyed := nc.handleInvalidCertificate(nextTick, vpnIp, hostinfo)
assert.False(t, destroyed) assert.False(t, destroyed)
// Move ahead 61s. // Move ahead 61s.
// Check if to disconnect with invalid certificate. // Check if to disconnect with invalid certificate.
// Should be disconnected. // Should be disconnected.
nextTick = now.Add(61 * time.Second) nextTick = now.Add(61 * time.Second)
destroyed = nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo) destroyed = nc.handleInvalidCertificate(nextTick, vpnIp, hostinfo)
assert.True(t, destroyed) assert.True(t, destroyed)
} }

View File

@ -10,6 +10,9 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
@ -25,14 +28,14 @@ type Control struct {
} }
type ControlHostInfo struct { type ControlHostInfo struct {
VpnIP net.IP `json:"vpnIp"` VpnIp net.IP `json:"vpnIp"`
LocalIndex uint32 `json:"localIndex"` LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"` RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []*udpAddr `json:"remoteAddrs"` RemoteAddrs []*udp.Addr `json:"remoteAddrs"`
CachedPackets int `json:"cachedPackets"` CachedPackets int `json:"cachedPackets"`
Cert *cert.NebulaCertificate `json:"cert"` Cert *cert.NebulaCertificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"` MessageCounter uint64 `json:"messageCounter"`
CurrentRemote *udpAddr `json:"currentRemote"` CurrentRemote *udp.Addr `json:"currentRemote"`
} }
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
@ -95,8 +98,8 @@ func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
} }
} }
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo { func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
var hm *HostMap var hm *HostMap
if pending { if pending {
hm = c.f.handshakeManager.pendingHostMap hm = c.f.handshakeManager.pendingHostMap
@ -104,7 +107,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
hm = c.f.hostMap hm = c.f.hostMap
} }
h, err := hm.QueryVpnIP(vpnIP) h, err := hm.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
return nil return nil
} }
@ -114,8 +117,8 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
} }
// SetRemoteForTunnel forces a tunnel to use a specific remote // SetRemoteForTunnel forces a tunnel to use a specific remote
func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo { func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP) hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
return nil return nil
} }
@ -126,15 +129,15 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf
} }
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool { func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP) hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
return false return false
} }
if !localOnly { if !localOnly {
c.f.send( c.f.send(
closeTunnel, header.CloseTunnel,
0, 0,
hostInfo.ConnectionState, hostInfo.ConnectionState,
hostInfo, hostInfo,
@ -156,16 +159,16 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
c.f.hostMap.Lock() c.f.hostMap.Lock()
for _, h := range c.f.hostMap.Hosts { for _, h := range c.f.hostMap.Hosts {
if excludeLighthouses { if excludeLighthouses {
if _, ok := c.f.lightHouse.lighthouses[h.hostId]; ok { if _, ok := c.f.lightHouse.lighthouses[h.vpnIp]; ok {
continue continue
} }
} }
if h.ConnectionState.ready { if h.ConnectionState.ready {
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h, true) c.f.closeTunnel(h, true)
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote). c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message") Debug("Sending close tunnel message")
closed++ closed++
} }
@ -176,7 +179,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
chi := ControlHostInfo{ chi := ControlHostInfo{
VpnIP: int2ip(h.hostId), VpnIp: h.vpnIp.ToIP(),
LocalIndex: h.localIndexId, LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId, RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),

View File

@ -8,17 +8,19 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestControl_GetHostInfoByVpnIP(t *testing.T) { func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller // To properly ensure we are not exposing core memory to the caller
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0)) hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := NewUDPAddr(int2ip(100), 4444) remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{ ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4), IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0}, Mask: net.IPMask{255, 255, 255, 0},
@ -48,7 +50,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
remotes := NewRemoteList() remotes := NewRemoteList()
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
hm.Add(ip2int(ipNet.IP), &HostInfo{ hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{
remote: remote1, remote: remote1,
remotes: remotes, remotes: remotes,
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
@ -56,10 +58,10 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
}, },
remoteIndexId: 200, remoteIndexId: 200,
localIndexId: 201, localIndexId: 201,
hostId: ip2int(ipNet.IP), vpnIp: iputil.Ip2VpnIp(ipNet.IP),
}) })
hm.Add(ip2int(ipNet2.IP), &HostInfo{ hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{
remote: remote1, remote: remote1,
remotes: remotes, remotes: remotes,
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
@ -67,7 +69,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
}, },
remoteIndexId: 200, remoteIndexId: 200,
localIndexId: 201, localIndexId: 201,
hostId: ip2int(ipNet2.IP), vpnIp: iputil.Ip2VpnIp(ipNet2.IP),
}) })
c := Control{ c := Control{
@ -77,26 +79,26 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
l: logrus.New(), l: logrus.New(),
} }
thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false) thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
expectedInfo := ControlHostInfo{ expectedInfo := ControlHostInfo{
VpnIP: net.IPv4(1, 2, 3, 4).To4(), VpnIp: net.IPv4(1, 2, 3, 4).To4(),
LocalIndex: 201, LocalIndex: 201,
RemoteIndex: 200, RemoteIndex: 200,
RemoteAddrs: []*udpAddr{remote2, remote1}, RemoteAddrs: []*udp.Addr{remote2, remote1},
CachedPackets: 0, CachedPackets: 0,
Cert: crt.Copy(), Cert: crt.Copy(),
MessageCounter: 0, MessageCounter: 0,
CurrentRemote: NewUDPAddr(int2ip(100), 4444), CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
} }
// Make sure we don't have any unexpected fields // Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi) assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
util.AssertDeepCopyEqual(t, &expectedInfo, thi) util.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet // Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false) thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
}) })
} }

View File

@ -8,12 +8,15 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
// WaitForTypeByIndex will pipe all messages from this control device into the pipeTo control device // WaitForTypeByIndex will pipe all messages from this control device into the pipeTo control device
// returning after a message matching the criteria has been piped // returning after a message matching the criteria has been piped
func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) { func (c *Control) WaitForType(msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
h := &Header{} h := &header.H{}
for { for {
p := c.f.outside.Get(true) p := c.f.outside.Get(true)
if err := h.Parse(p.Data); err != nil { if err := h.Parse(p.Data); err != nil {
@ -28,8 +31,8 @@ func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSu
// WaitForTypeByIndex is similar to WaitForType except it adds an index check // WaitForTypeByIndex is similar to WaitForType except it adds an index check
// Useful if you have many nodes communicating and want to wait to find a specific nodes packet // Useful if you have many nodes communicating and want to wait to find a specific nodes packet
func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) { func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
h := &Header{} h := &header.H{}
for { for {
p := c.f.outside.Get(true) p := c.f.outside.Get(true)
if err := h.Parse(p.Data); err != nil { if err := h.Parse(p.Data); err != nil {
@ -46,12 +49,12 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType,
// This is necessary if you did not configure static hosts or are not running a lighthouse // This is necessary if you did not configure static hosts or are not running a lighthouse
func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) { func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
c.f.lightHouse.Lock() c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp)) remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
remoteList.Lock() remoteList.Lock()
defer remoteList.Unlock() defer remoteList.Unlock()
c.f.lightHouse.Unlock() c.f.lightHouse.Unlock()
iVpnIp := ip2int(vpnIp) iVpnIp := iputil.Ip2VpnIp(vpnIp)
if v4 := toAddr.IP.To4(); v4 != nil { if v4 := toAddr.IP.To4(); v4 != nil {
remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port))) remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
} else { } else {
@ -65,12 +68,12 @@ func (c *Control) GetFromTun(block bool) []byte {
} }
// GetFromUDP will pull a udp packet off the udp side of nebula // GetFromUDP will pull a udp packet off the udp side of nebula
func (c *Control) GetFromUDP(block bool) *UdpPacket { func (c *Control) GetFromUDP(block bool) *udp.Packet {
return c.f.outside.Get(block) return c.f.outside.Get(block)
} }
func (c *Control) GetUDPTxChan() <-chan *UdpPacket { func (c *Control) GetUDPTxChan() <-chan *udp.Packet {
return c.f.outside.txPackets return c.f.outside.TxPackets
} }
func (c *Control) GetTunTxChan() <-chan []byte { func (c *Control) GetTunTxChan() <-chan []byte {
@ -78,7 +81,7 @@ func (c *Control) GetTunTxChan() <-chan []byte {
} }
// InjectUDPPacket will inject a packet into the udp side of nebula // InjectUDPPacket will inject a packet into the udp side of nebula
func (c *Control) InjectUDPPacket(p *UdpPacket) { func (c *Control) InjectUDPPacket(p *udp.Packet) {
c.f.outside.Send(p) c.f.outside.Send(p)
} }
@ -115,11 +118,11 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
} }
func (c *Control) GetUDPAddr() string { func (c *Control) GetUDPAddr() string {
return c.f.outside.addr.String() return c.f.outside.Addr.String()
} }
func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)] hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)]
if !ok { if !ok {
return false return false
} }

View File

@ -8,6 +8,8 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
) )
// This whole thing should be rewritten to use context // This whole thing should be rewritten to use context
@ -44,8 +46,8 @@ func (d *dnsRecords) QueryCert(data string) string {
if ip == nil { if ip == nil {
return "" return ""
} }
iip := ip2int(ip) iip := iputil.Ip2VpnIp(ip)
hostinfo, err := d.hostMap.QueryVpnIP(iip) hostinfo, err := d.hostMap.QueryVpnIp(iip)
if err != nil { if err != nil {
return "" return ""
} }
@ -109,7 +111,7 @@ func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m) w.WriteMsg(m)
} }
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() { func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(hostMap) dnsR = newDnsRecords(hostMap)
// attach request handler func // attach request handler func
@ -117,7 +119,7 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
handleDnsRequest(l, w, r) handleDnsRequest(l, w, r)
}) })
c.RegisterReloadCallback(func(c *Config) { c.RegisterReloadCallback(func(c *config.C) {
reloadDns(l, c) reloadDns(l, c)
}) })
@ -126,11 +128,11 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
} }
} }
func getDnsServerAddr(c *Config) string { func getDnsServerAddr(c *config.C) string {
return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)) return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
} }
func startDns(l *logrus.Logger, c *Config) { func startDns(l *logrus.Logger, c *config.C) {
dnsAddr = getDnsServerAddr(c) dnsAddr = getDnsServerAddr(c)
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
l.WithField("dnsListener", dnsAddr).Infof("Starting DNS responder") l.WithField("dnsListener", dnsAddr).Infof("Starting DNS responder")
@ -141,7 +143,7 @@ func startDns(l *logrus.Logger, c *Config) {
} }
} }
func reloadDns(l *logrus.Logger, c *Config) { func reloadDns(l *logrus.Logger, c *config.C) {
if dnsAddr == getDnsServerAddr(c) { if dnsAddr == getDnsServerAddr(c) {
l.Debug("No DNS server config change detected") l.Debug("No DNS server config change detected")
return return

View File

@ -10,6 +10,9 @@ import (
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -37,7 +40,7 @@ func TestGoodHandshake(t *testing.T) {
t.Log("I consume a garbage packet with a proper nebula header for our tunnel") t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel // this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
badPacket := stage1Packet.Copy() badPacket := stage1Packet.Copy()
badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen] badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
myControl.InjectUDPPacket(badPacket) myControl.InjectUDPPacket(badPacket)
t.Log("Have me consume their real stage 1 packet. I have a tunnel now") t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
@ -87,8 +90,8 @@ func TestWrongResponderHandshake(t *testing.T) {
t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
r.RouteForAllExitFunc(func(p *nebula.UdpPacket, c *nebula.Control) router.ExitType { r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
h := &nebula.Header{} h := &header.H{}
err := h.Parse(p.Data) err := h.Parse(p.Data)
if err != nil { if err != nil {
panic(err) panic(err)
@ -115,8 +118,8 @@ func TestWrongResponderHandshake(t *testing.T) {
r.FlushAll() r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), true), "My pending hostmap should not contain evil") assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false), "My main hostmap should not contain evil") assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), false), "My main hostmap should not contain evil")
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
//TODO: assert hostmaps for everyone //TODO: assert hostmaps for everyone

View File

@ -5,7 +5,6 @@ package e2e
import ( import (
"crypto/rand" "crypto/rand"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -19,7 +18,9 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@ -82,10 +83,10 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
panic(err) panic(err)
} }
config := nebula.NewConfig(l) c := config.NewC(l)
config.LoadString(string(cb)) c.LoadString(string(cb))
control, err := nebula.Main(config, false, "e2e-test", l, nil) control, err := nebula.Main(c, false, "e2e-test", l, nil)
if err != nil { if err != nil {
panic(err) panic(err)
@ -200,19 +201,6 @@ func x25519Keypair() ([]byte, []byte) {
return pubkey, privkey return pubkey, privkey
} }
func ip2int(ip []byte) uint32 {
if len(ip) == 16 {
return binary.BigEndian.Uint32(ip[12:16])
}
return binary.BigEndian.Uint32(ip)
}
func int2ip(nn uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, nn)
return ip
}
type doneCb func() type doneCb func()
func deadline(t *testing.T, seconds time.Duration) doneCb { func deadline(t *testing.T, seconds time.Duration) doneCb {
@ -245,15 +233,15 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul
func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) { func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
// Get both host infos // Get both host infos
hBinA := controlA.GetHostInfoByVpnIP(ip2int(vpnIpB), false) hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false)
assert.NotNil(t, hBinA, "Host B was not found by vpnIP in controlA") assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
hAinB := controlB.GetHostInfoByVpnIP(ip2int(vpnIpA), false) hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false)
assert.NotNil(t, hAinB, "Host A was not found by vpnIP in controlB") assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
// Check that both vpn and real addr are correct // Check that both vpn and real addr are correct
assert.Equal(t, vpnIpB, hBinA.VpnIP, "Host B VpnIp is wrong in control A") assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
assert.Equal(t, vpnIpA, hAinB.VpnIP, "Host A VpnIp is wrong in control B") assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A") assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B") assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")

View File

@ -11,6 +11,8 @@ import (
"sync" "sync"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
) )
type R struct { type R struct {
@ -41,7 +43,7 @@ const (
RouteAndExit ExitType = 2 RouteAndExit ExitType = 2
) )
type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType
func NewR(controls ...*nebula.Control) *R { func NewR(controls ...*nebula.Control) *R {
r := &R{ r := &R{
@ -79,7 +81,7 @@ func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
// OnceFrom will route a single packet from sender then return // OnceFrom will route a single packet from sender then return
// If the router doesn't have the nebula controller for that address, we panic // If the router doesn't have the nebula controller for that address, we panic
func (r *R) OnceFrom(sender *nebula.Control) { func (r *R) OnceFrom(sender *nebula.Control) {
r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType { r.RouteExitFunc(sender, func(*udp.Packet, *nebula.Control) ExitType {
return RouteAndExit return RouteAndExit
}) })
} }
@ -119,7 +121,7 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
// - routeAndExit: this call will return immediately after routing the last packet from sender // - routeAndExit: this call will return immediately after routing the last packet from sender
// - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender // - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
h := &nebula.Header{} h := &header.H{}
for { for {
p := sender.GetFromUDP(true) p := sender.GetFromUDP(true)
r.Lock() r.Lock()
@ -159,9 +161,9 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
// RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender // RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender
// If the router doesn't have the nebula controller for that address, we panic // If the router doesn't have the nebula controller for that address, we panic
func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) { func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType header.MessageType, subType header.MessageSubType) {
h := &nebula.Header{} h := &header.H{}
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType { r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
if err := h.Parse(p.Data); err != nil { if err := h.Parse(p.Data); err != nil {
panic(err) panic(err)
} }
@ -181,7 +183,7 @@ func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr
finish = RouteAndExit finish = RouteAndExit
} }
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType { r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) { if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
return finish return finish
} }
@ -215,7 +217,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
x, rx, _ := reflect.Select(sc) x, rx, _ := reflect.Select(sc)
r.Lock() r.Lock()
p := rx.Interface().(*nebula.UdpPacket) p := rx.Interface().(*udp.Packet)
outAddr := cm[x].GetUDPAddr() outAddr := cm[x].GetUDPAddr()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
@ -277,7 +279,7 @@ func (r *R) FlushAll() {
} }
r.Lock() r.Lock()
p := rx.Interface().(*nebula.UdpPacket) p := rx.Interface().(*udp.Packet)
outAddr := cm[x].GetUDPAddr() outAddr := cm[x].GetUDPAddr()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
@ -292,7 +294,7 @@ func (r *R) FlushAll() {
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
// This is an internal router function, the caller must hold the lock // This is an internal router function, the caller must hold the lock
func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control { func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok { if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
p.FromIp = newAddr.IP p.FromIp = newAddr.IP
p.FromPort = uint16(newAddr.Port) p.FromPort = uint16(newAddr.Port)

View File

@ -4,7 +4,6 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -12,22 +11,14 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
) "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
const ( "github.com/slackhq/nebula/firewall"
fwProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
fwProtoTCP = 6
fwProtoUDP = 17
fwProtoICMP = 1
fwPortAny = 0 // Special value for matching `port: any`
fwPortFragment = -1 // Special value for matching `port: fragment`
) )
const tcpACK = 0x10 const tcpACK = 0x10
@ -63,7 +54,7 @@ type Firewall struct {
DefaultTimeout time.Duration //linux: 600s DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own // Used to ensure we don't emit local packets for ips we don't own
localIps *CIDRTree localIps *cidr.Tree4
rules string rules string
rulesVersion uint16 rulesVersion uint16
@ -85,7 +76,7 @@ type firewallMetrics struct {
type FirewallConntrack struct { type FirewallConntrack struct {
sync.Mutex sync.Mutex
Conns map[FirewallPacket]*conn Conns map[firewall.Packet]*conn
TimerWheel *TimerWheel TimerWheel *TimerWheel
} }
@ -116,55 +107,13 @@ type FirewallRule struct {
Any bool Any bool
Hosts map[string]struct{} Hosts map[string]struct{}
Groups [][]string Groups [][]string
CIDR *CIDRTree CIDR *cidr.Tree4
} }
// Even though ports are uint16, int32 maps are faster for lookup // Even though ports are uint16, int32 maps are faster for lookup
// Plus we can use `-1` for fragment rules // Plus we can use `-1` for fragment rules
type firewallPort map[int32]*FirewallCA type firewallPort map[int32]*FirewallCA
type FirewallPacket struct {
LocalIP uint32
RemoteIP uint32
LocalPort uint16
RemotePort uint16
Protocol uint8
Fragment bool
}
func (fp *FirewallPacket) Copy() *FirewallPacket {
return &FirewallPacket{
LocalIP: fp.LocalIP,
RemoteIP: fp.RemoteIP,
LocalPort: fp.LocalPort,
RemotePort: fp.RemotePort,
Protocol: fp.Protocol,
Fragment: fp.Fragment,
}
}
func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
var proto string
switch fp.Protocol {
case fwProtoTCP:
proto = "tcp"
case fwProtoICMP:
proto = "icmp"
case fwProtoUDP:
proto = "udp"
default:
proto = fmt.Sprintf("unknown %v", fp.Protocol)
}
return json.Marshal(m{
"LocalIP": int2ip(fp.LocalIP).String(),
"RemoteIP": int2ip(fp.RemoteIP).String(),
"LocalPort": fp.LocalPort,
"RemotePort": fp.RemotePort,
"Protocol": proto,
"Fragment": fp.Fragment,
})
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
//TODO: error on 0 duration //TODO: error on 0 duration
@ -184,7 +133,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
max = defaultTimeout max = defaultTimeout
} }
localIps := NewCIDRTree() localIps := cidr.NewTree4()
for _, ip := range c.Details.Ips { for _, ip := range c.Details.Ips {
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
} }
@ -195,7 +144,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
return &Firewall{ return &Firewall{
Conntrack: &FirewallConntrack{ Conntrack: &FirewallConntrack{
Conns: make(map[FirewallPacket]*conn), Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel(min, max), TimerWheel: NewTimerWheel(min, max),
}, },
InRules: newFirewallTable(), InRules: newFirewallTable(),
@ -220,7 +169,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
} }
} }
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) { func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
fw := NewFirewall( fw := NewFirewall(
l, l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
@ -278,13 +227,13 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
} }
switch proto { switch proto {
case fwProtoTCP: case firewall.ProtoTCP:
fp = ft.TCP fp = ft.TCP
case fwProtoUDP: case firewall.ProtoUDP:
fp = ft.UDP fp = ft.UDP
case fwProtoICMP: case firewall.ProtoICMP:
fp = ft.ICMP fp = ft.ICMP
case fwProtoAny: case firewall.ProtoAny:
fp = ft.AnyProto fp = ft.AnyProto
default: default:
return fmt.Errorf("unknown protocol %v", proto) return fmt.Errorf("unknown protocol %v", proto)
@ -299,7 +248,7 @@ func (f *Firewall) GetRuleHash() string {
return hex.EncodeToString(sum[:]) return hex.EncodeToString(sum[:])
} }
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error { func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
var table string var table string
if inbound { if inbound {
table = "firewall.inbound" table = "firewall.inbound"
@ -307,7 +256,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
table = "firewall.outbound" table = "firewall.outbound"
} }
r := config.Get(table) r := c.Get(table)
if r == nil { if r == nil {
return nil return nil
} }
@ -362,13 +311,13 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
var proto uint8 var proto uint8
switch r.Proto { switch r.Proto {
case "any": case "any":
proto = fwProtoAny proto = firewall.ProtoAny
case "tcp": case "tcp":
proto = fwProtoTCP proto = firewall.ProtoTCP
case "udp": case "udp":
proto = fwProtoUDP proto = firewall.ProtoUDP
case "icmp": case "icmp":
proto = fwProtoICMP proto = firewall.ProtoICMP
default: default:
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
} }
@ -396,7 +345,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It // Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped. // returns nil if the packet should not be dropped.
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error { func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet // Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming, h, caPool, localCache) { if f.inConns(packet, fp, incoming, h, caPool, localCache) {
return nil return nil
@ -410,7 +359,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
} }
} else { } else {
// Simple case: Certificate has one IP and no subnets // Simple case: Certificate has one IP and no subnets
if fp.RemoteIP != h.hostId { if fp.RemoteIP != h.vpnIp {
f.metrics(incoming).droppedRemoteIP.Inc(1) f.metrics(incoming).droppedRemoteIP.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
@ -462,7 +411,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion)) metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
} }
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool { func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil { if localCache != nil {
if _, ok := localCache[fp]; ok { if _, ok := localCache[fp]; ok {
return true return true
@ -520,14 +469,14 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
} }
switch fp.Protocol { switch fp.Protocol {
case fwProtoTCP: case firewall.ProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout) c.Expires = time.Now().Add(f.TCPTimeout)
if incoming { if incoming {
f.checkTCPRTT(c, packet) f.checkTCPRTT(c, packet)
} else { } else {
setTCPRTTTracking(c, packet) setTCPRTTTracking(c, packet)
} }
case fwProtoUDP: case firewall.ProtoUDP:
c.Expires = time.Now().Add(f.UDPTimeout) c.Expires = time.Now().Add(f.UDPTimeout)
default: default:
c.Expires = time.Now().Add(f.DefaultTimeout) c.Expires = time.Now().Add(f.DefaultTimeout)
@ -542,17 +491,17 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
return true return true
} }
func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) { func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
var timeout time.Duration var timeout time.Duration
c := &conn{} c := &conn{}
switch fp.Protocol { switch fp.Protocol {
case fwProtoTCP: case firewall.ProtoTCP:
timeout = f.TCPTimeout timeout = f.TCPTimeout
if !incoming { if !incoming {
setTCPRTTTracking(c, packet) setTCPRTTTracking(c, packet)
} }
case fwProtoUDP: case firewall.ProtoUDP:
timeout = f.UDPTimeout timeout = f.UDPTimeout
default: default:
timeout = f.DefaultTimeout timeout = f.DefaultTimeout
@ -575,7 +524,7 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock! // Caller must own the connMutex lock!
func (f *Firewall) evict(p FirewallPacket) { func (f *Firewall) evict(p firewall.Packet) {
//TODO: report a stat if the tcp rtt tracking was never resolved? //TODO: report a stat if the tcp rtt tracking was never resolved?
// Are we still tracking this conn? // Are we still tracking this conn?
conntrack := f.Conntrack conntrack := f.Conntrack
@ -596,21 +545,21 @@ func (f *Firewall) evict(p FirewallPacket) {
delete(conntrack.Conns, p) delete(conntrack.Conns, p)
} }
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
if ft.AnyProto.match(p, incoming, c, caPool) { if ft.AnyProto.match(p, incoming, c, caPool) {
return true return true
} }
switch p.Protocol { switch p.Protocol {
case fwProtoTCP: case firewall.ProtoTCP:
if ft.TCP.match(p, incoming, c, caPool) { if ft.TCP.match(p, incoming, c, caPool) {
return true return true
} }
case fwProtoUDP: case firewall.ProtoUDP:
if ft.UDP.match(p, incoming, c, caPool) { if ft.UDP.match(p, incoming, c, caPool) {
return true return true
} }
case fwProtoICMP: case firewall.ProtoICMP:
if ft.ICMP.match(p, incoming, c, caPool) { if ft.ICMP.match(p, incoming, c, caPool) {
return true return true
} }
@ -640,7 +589,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
return nil return nil
} }
func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
// We don't have any allowed ports, bail // We don't have any allowed ports, bail
if fp == nil { if fp == nil {
return false return false
@ -649,7 +598,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
var port int32 var port int32
if p.Fragment { if p.Fragment {
port = fwPortFragment port = firewall.PortFragment
} else if incoming { } else if incoming {
port = int32(p.LocalPort) port = int32(p.LocalPort)
} else { } else {
@ -660,7 +609,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
return true return true
} }
return fp[fwPortAny].match(p, c, caPool) return fp[firewall.PortAny].match(p, c, caPool)
} }
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error { func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
@ -668,7 +617,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
return &FirewallRule{ return &FirewallRule{
Hosts: make(map[string]struct{}), Hosts: make(map[string]struct{}),
Groups: make([][]string, 0), Groups: make([][]string, 0),
CIDR: NewCIDRTree(), CIDR: cidr.NewTree4(),
} }
} }
@ -703,7 +652,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
return nil return nil
} }
func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
if fc == nil { if fc == nil {
return false return false
} }
@ -736,7 +685,7 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
// If it's any we need to wipe out any pre-existing rules to save on memory // If it's any we need to wipe out any pre-existing rules to save on memory
fr.Groups = make([][]string, 0) fr.Groups = make([][]string, 0)
fr.Hosts = make(map[string]struct{}) fr.Hosts = make(map[string]struct{})
fr.CIDR = NewCIDRTree() fr.CIDR = cidr.NewTree4()
} else { } else {
if len(groups) > 0 { if len(groups) > 0 {
fr.Groups = append(fr.Groups, groups) fr.Groups = append(fr.Groups, groups)
@ -776,7 +725,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
return false return false
} }
func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool { func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
if fr == nil { if fr == nil {
return false return false
} }
@ -885,12 +834,12 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
func parsePort(s string) (startPort, endPort int32, err error) { func parsePort(s string) (startPort, endPort int32, err error) {
if s == "any" { if s == "any" {
startPort = fwPortAny startPort = firewall.PortAny
endPort = fwPortAny endPort = firewall.PortAny
} else if s == "fragment" { } else if s == "fragment" {
startPort = fwPortFragment startPort = firewall.PortFragment
endPort = fwPortFragment endPort = firewall.PortFragment
} else if strings.Contains(s, `-`) { } else if strings.Contains(s, `-`) {
sPorts := strings.SplitN(s, `-`, 2) sPorts := strings.SplitN(s, `-`, 2)
@ -914,8 +863,8 @@ func parsePort(s string) (startPort, endPort int32, err error) {
startPort = int32(rStartPort) startPort = int32(rStartPort)
endPort = int32(rEndPort) endPort = int32(rEndPort)
if startPort == fwPortAny { if startPort == firewall.PortAny {
endPort = fwPortAny endPort = firewall.PortAny
} }
} else { } else {
@ -968,54 +917,3 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
c.Seq = 0 c.Seq = 0
return true return true
} }
// ConntrackCache is used as a local routine cache to know if a given flow
// has been seen in the conntrack table.
type ConntrackCache map[FirewallPacket]struct{}
type ConntrackCacheTicker struct {
cacheV uint64
cacheTick uint64
cache ConntrackCache
}
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
if d == 0 {
return nil
}
c := &ConntrackCacheTicker{
cache: ConntrackCache{},
}
go c.tick(d)
return c
}
func (c *ConntrackCacheTicker) tick(d time.Duration) {
for {
time.Sleep(d)
atomic.AddUint64(&c.cacheTick, 1)
}
}
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
if c == nil {
return nil
}
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
}
c.cache = make(ConntrackCache, ll)
}
}
return c.cache
}

59
firewall/cache.go Normal file
View File

@ -0,0 +1,59 @@
package firewall
import (
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
)
// ConntrackCache is used as a local routine cache to know if a given flow
// has been seen in the conntrack table.
type ConntrackCache map[Packet]struct{}
type ConntrackCacheTicker struct {
cacheV uint64
cacheTick uint64
cache ConntrackCache
}
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
if d == 0 {
return nil
}
c := &ConntrackCacheTicker{
cache: ConntrackCache{},
}
go c.tick(d)
return c
}
func (c *ConntrackCacheTicker) tick(d time.Duration) {
for {
time.Sleep(d)
atomic.AddUint64(&c.cacheTick, 1)
}
}
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
if c == nil {
return nil
}
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
}
c.cache = make(ConntrackCache, ll)
}
}
return c.cache
}

62
firewall/packet.go Normal file
View File

@ -0,0 +1,62 @@
package firewall
import (
"encoding/json"
"fmt"
"github.com/slackhq/nebula/iputil"
)
type m map[string]interface{}
const (
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
ProtoTCP = 6
ProtoUDP = 17
ProtoICMP = 1
PortAny = 0 // Special value for matching `port: any`
PortFragment = -1 // Special value for matching `port: fragment`
)
type Packet struct {
LocalIP iputil.VpnIp
RemoteIP iputil.VpnIp
LocalPort uint16
RemotePort uint16
Protocol uint8
Fragment bool
}
func (fp *Packet) Copy() *Packet {
return &Packet{
LocalIP: fp.LocalIP,
RemoteIP: fp.RemoteIP,
LocalPort: fp.LocalPort,
RemotePort: fp.RemotePort,
Protocol: fp.Protocol,
Fragment: fp.Fragment,
}
}
func (fp Packet) MarshalJSON() ([]byte, error) {
var proto string
switch fp.Protocol {
case ProtoTCP:
proto = "tcp"
case ProtoICMP:
proto = "icmp"
case ProtoUDP:
proto = "udp"
default:
proto = fmt.Sprintf("unknown %v", fp.Protocol)
}
return json.Marshal(m{
"LocalIP": fp.LocalIP.String(),
"RemoteIP": fp.RemoteIP.String(),
"LocalPort": fp.LocalPort,
"RemotePort": fp.RemotePort,
"Protocol": proto,
"Fragment": fp.Fragment,
})
}

View File

@ -11,11 +11,15 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestNewFirewall(t *testing.T) { func TestNewFirewall(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
conntrack := fw.Conntrack conntrack := fw.Conntrack
@ -54,7 +58,7 @@ func TestNewFirewall(t *testing.T) {
} }
func TestFirewall_AddRule(t *testing.T) { func TestFirewall_AddRule(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
@ -65,92 +69,80 @@ func TestFirewall_AddRule(t *testing.T) {
_, ti, _ := net.ParseCIDR("1.2.3.4/32") _, ti, _ := net.ParseCIDR("1.2.3.4/32")
assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
assert.False(t, fw.InRules.UDP[1].Any.Any) assert.False(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
assert.False(t, fw.InRules.ICMP[1].Any.Any) assert.False(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP))) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha")) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
// Set any and clear fields // Set any and clear fields
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1") assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP))) assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
// run twice just to make sure // run twice just to make sure
//TODO: these ANY rules should clear the CA firewall portion //TODO: these ANY rules should clear the CA firewall portion
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts) assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0") _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", "")) assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", "")) assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", ""))
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
p := FirewallPacket{ p := firewall.Packet{
ip2int(net.IPv4(1, 2, 3, 4)), iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
ip2int(net.IPv4(1, 2, 3, 4)), iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
10, 10,
90, 90,
fwProtoUDP, firewall.ProtoUDP,
false, false,
} }
@ -172,12 +164,12 @@ func TestFirewall_Drop(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c, peerCert: &c,
}, },
hostId: ip2int(ipNet.IP), vpnIp: iputil.Ip2VpnIp(ipNet.IP),
} }
h.CreateRemoteCIDR(&c) h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@ -190,34 +182,34 @@ func TestFirewall_Drop(t *testing.T) {
// test remote mismatch // test remote mismatch
oldRemote := p.RemoteIP oldRemote := p.RemoteIP
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10)) p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP) assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
} }
@ -237,14 +229,14 @@ func BenchmarkFirewallTable_match(b *testing.B) {
b.Run("fail on proto", func(b *testing.B) { b.Run("fail on proto", func(b *testing.B) {
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp) ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)
} }
}) })
b.Run("fail on port", func(b *testing.B) { b.Run("fail on port", func(b *testing.B) {
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp) ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)
} }
}) })
@ -258,7 +250,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}, },
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp) ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
} }
}) })
@ -270,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}, },
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp) ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
} }
}) })
@ -282,12 +274,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}, },
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp) ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
} }
}) })
b.Run("pass on ip", func(b *testing.B) { b.Run("pass on ip", func(b *testing.B) {
ip := ip2int(net.IPv4(172, 1, 1, 1)) ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{ c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}}, InvertedGroups: map[string]struct{}{"nope": {}},
@ -295,14 +287,14 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}, },
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp) ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
} }
}) })
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "") _ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
b.Run("pass on ip with any port", func(b *testing.B) { b.Run("pass on ip with any port", func(b *testing.B) {
ip := ip2int(net.IPv4(172, 1, 1, 1)) ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{ c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}}, InvertedGroups: map[string]struct{}{"nope": {}},
@ -310,22 +302,22 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}, },
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
} }
}) })
} }
func TestFirewall_Drop2(t *testing.T) { func TestFirewall_Drop2(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
p := FirewallPacket{ p := firewall.Packet{
ip2int(net.IPv4(1, 2, 3, 4)), iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
ip2int(net.IPv4(1, 2, 3, 4)), iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
10, 10,
90, 90,
fwProtoUDP, firewall.ProtoUDP,
false, false,
} }
@ -345,7 +337,7 @@ func TestFirewall_Drop2(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c, peerCert: &c,
}, },
hostId: ip2int(ipNet.IP), vpnIp: iputil.Ip2VpnIp(ipNet.IP),
} }
h.CreateRemoteCIDR(&c) h.CreateRemoteCIDR(&c)
@ -364,7 +356,7 @@ func TestFirewall_Drop2(t *testing.T) {
h1.CreateRemoteCIDR(&c1) h1.CreateRemoteCIDR(&c1)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
@ -375,16 +367,16 @@ func TestFirewall_Drop2(t *testing.T) {
} }
func TestFirewall_Drop3(t *testing.T) { func TestFirewall_Drop3(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
p := FirewallPacket{ p := firewall.Packet{
ip2int(net.IPv4(1, 2, 3, 4)), iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
ip2int(net.IPv4(1, 2, 3, 4)), iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
1, 1,
1, 1,
fwProtoUDP, firewall.ProtoUDP,
false, false,
} }
@ -411,7 +403,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c1, peerCert: &c1,
}, },
hostId: ip2int(ipNet.IP), vpnIp: iputil.Ip2VpnIp(ipNet.IP),
} }
h1.CreateRemoteCIDR(&c1) h1.CreateRemoteCIDR(&c1)
@ -426,7 +418,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c2, peerCert: &c2,
}, },
hostId: ip2int(ipNet.IP), vpnIp: iputil.Ip2VpnIp(ipNet.IP),
} }
h2.CreateRemoteCIDR(&c2) h2.CreateRemoteCIDR(&c2)
@ -441,13 +433,13 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c3, peerCert: &c3,
}, },
hostId: ip2int(ipNet.IP), vpnIp: iputil.Ip2VpnIp(ipNet.IP),
} }
h3.CreateRemoteCIDR(&c3) h3.CreateRemoteCIDR(&c3)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
@ -461,16 +453,16 @@ func TestFirewall_Drop3(t *testing.T) {
} }
func TestFirewall_DropConntrackReload(t *testing.T) { func TestFirewall_DropConntrackReload(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
p := FirewallPacket{ p := firewall.Packet{
ip2int(net.IPv4(1, 2, 3, 4)), iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
ip2int(net.IPv4(1, 2, 3, 4)), iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
10, 10,
90, 90,
fwProtoUDP, firewall.ProtoUDP,
false, false,
} }
@ -492,12 +484,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c, peerCert: &c,
}, },
hostId: ip2int(ipNet.IP), vpnIp: iputil.Ip2VpnIp(ipNet.IP),
} }
h.CreateRemoteCIDR(&c) h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@ -510,7 +502,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@ -519,7 +511,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@ -643,28 +635,28 @@ func Test_parsePort(t *testing.T) {
} }
func TestNewFirewallFromConfig(t *testing.T) { func TestNewFirewallFromConfig(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
// Test a bad rule definition // Test a bad rule definition
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
conf := NewConfig(l) conf := config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
_, err := NewFirewallFromConfig(l, c, conf) _, err := NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code // Test both port and code
conf = NewConfig(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha // Test missing host, group, cidr, ca_name and ca_sha
conf = NewConfig(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided") assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
// Test code/port error // Test code/port error
conf = NewConfig(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
@ -674,91 +666,91 @@ func TestNewFirewallFromConfig(t *testing.T) {
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error // Test proto error
conf = NewConfig(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error // Test cidr parse error
conf = NewConfig(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
// Test both group and groups // Test both group and groups
conf = NewConfig(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
} }
func TestAddFirewallRulesFromConfig(t *testing.T) { func TestAddFirewallRulesFromConfig(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
// Test adding tcp rule // Test adding tcp rule
conf := NewConfig(l) conf := config.NewC(l)
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = NewConfig(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = NewConfig(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = NewConfig(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = NewConfig(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = NewConfig(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = NewConfig(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
// Test single groups // Test single groups
conf = NewConfig(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = NewConfig(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
// Test Add error // Test Add error
conf = NewConfig(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error") mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
@ -857,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) {
} }
func TestFirewall_convertRule(t *testing.T) { func TestFirewall_convertRule(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
@ -929,6 +921,6 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
func resetConntrack(fw *Firewall) { func resetConntrack(fw *Firewall) {
fw.Conntrack.Lock() fw.Conntrack.Lock()
fw.Conntrack.Conns = map[FirewallPacket]*conn{} fw.Conntrack.Conns = map[firewall.Packet]*conn{}
fw.Conntrack.Unlock() fw.Conntrack.Unlock()
} }

View File

@ -1,11 +1,11 @@
package nebula package nebula
const ( import (
handshakeIXPSK0 = 0 "github.com/slackhq/nebula/header"
handshakeXXPSK0 = 1 "github.com/slackhq/nebula/udp"
) )
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) { func HandleIncomingHandshake(f *Interface, addr *udp.Addr, packet []byte, h *header.H, hostinfo *HostInfo) {
// First remote allow list check before we know the vpnIp // First remote allow list check before we know the vpnIp
if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) { if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) {
f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
@ -13,7 +13,7 @@ func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Head
} }
switch h.Subtype { switch h.Subtype {
case handshakeIXPSK0: case header.HandshakeIXPSK0:
switch h.MessageCounter { switch h.MessageCounter {
case 1: case 1:
ixHandshakeStage1(f, addr, packet, h) ixHandshakeStage1(f, addr, packet, h)

View File

@ -6,13 +6,16 @@ import (
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
// NOISE IX Handshakes // NOISE IX Handshakes
// This function constructs a handshake packet, but does not actually send it // This function constructs a handshake packet, but does not actually send it
// Sending is done by the handshake manager // Sending is done by the handshake manager
func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
// This queries the lighthouse if we don't know a remote for the host // This queries the lighthouse if we don't know a remote for the host
// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send // We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
// more quickly, effect is a quicker handshake. // more quickly, effect is a quicker handshake.
@ -22,7 +25,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
err := f.handshakeManager.AddIndexHostInfo(hostinfo) err := f.handshakeManager.AddIndexHostInfo(hostinfo)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). f.l.WithError(err).WithField("vpnIp", vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return return
} }
@ -43,17 +46,17 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
hsBytes, err = proto.Marshal(hs) hsBytes, err = proto.Marshal(hs)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). f.l.WithError(err).WithField("vpnIp", vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return return
} }
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1) h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
atomic.AddUint64(&ci.atomicMessageCounter, 1) atomic.AddUint64(&ci.atomicMessageCounter, 1)
msg, _, _, err := ci.H.WriteMessage(header, hsBytes) msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). f.l.WithError(err).WithField("vpnIp", vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return return
} }
@ -67,12 +70,12 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
hostinfo.handshakeStart = time.Now() hostinfo.handshakeStart = time.Now()
} }
func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { func ixHandshakeStage1(f *Interface, addr *udp.Addr, packet []byte, h *header.H) {
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0) ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed // Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1) ci.window.Update(f.l, 1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
@ -97,13 +100,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
Info("Invalid certificate from host") Info("Invalid certificate from host")
return return
} }
vpnIP := ip2int(remoteCert.Details.Ips[0].IP) vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum() fingerprint, _ := remoteCert.Sha256Sum()
issuer := remoteCert.Details.Issuer issuer := remoteCert.Details.Issuer
if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) { if vpnIp == f.myVpnIp {
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -111,14 +114,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
return return
} }
if !f.lightHouse.remoteAllowList.Allow(vpnIP, addr.IP) { if !f.lightHouse.remoteAllowList.Allow(vpnIp, addr.IP) {
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
myIndex, err := generateIndex(f.l) myIndex, err := generateIndex(f.l)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -130,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
ConnectionState: ci, ConnectionState: ci,
localIndexId: myIndex, localIndexId: myIndex,
remoteIndexId: hs.Details.InitiatorIndex, remoteIndexId: hs.Details.InitiatorIndex,
hostId: vpnIP, vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time, lastHandshakeTime: hs.Details.Time,
} }
@ -138,7 +141,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
hostinfo.Lock() hostinfo.Lock()
defer hostinfo.Unlock() defer hostinfo.Unlock()
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -153,7 +156,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
hsBytes, err := proto.Marshal(hs) hsBytes, err := proto.Marshal(hs)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -161,17 +164,17 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
return return
} }
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2) nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return return
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -179,8 +182,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
return return
} }
hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:])) hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:]))
copy(hostinfo.HandshakePacket[0], packet[HeaderLen:]) copy(hostinfo.HandshakePacket[0], packet[header.Len:])
// Regardless of whether you are the sender or receiver, you should arrive here // Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection. // and complete standing up the connection.
@ -195,12 +198,12 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
ci.dKey = NewNebulaCipherState(dKey) ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey) ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnIP) hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
hostinfo.CreateRemoteCIDR(remoteCert) hostinfo.CreateRemoteCIDR(remoteCert)
// Only overwrite existing record if we should win the handshake race // Only overwrite existing record if we should win the handshake race
overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP) overwrite := vpnIp > f.myVpnIp
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
if err != nil { if err != nil {
switch err { switch err {
@ -214,27 +217,27 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
if existing.SetRemoteIfPreferred(f.hostMap, addr) { if existing.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to // Send a test packet to ensure the other side has also switched to
// the preferred remote // the preferred remote
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
} }
existing.Unlock() existing.Unlock()
hostinfo.Lock() hostinfo.Lock()
msg = existing.HandshakePacket[2] msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr) err := f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message") WithError(err).Error("Failed to send handshake message")
} else { } else {
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent") Info("Handshake message sent")
} }
return return
case ErrExistingHostInfo: case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on // This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("oldHandshakeTime", existing.lastHandshakeTime).
WithField("newHandshakeTime", hostinfo.lastHandshakeTime). WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
@ -245,22 +248,22 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
Info("Handshake too old") Info("Handshake too old")
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return return
case ErrLocalIndexCollision: case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)). WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
Error("Failed to add HostInfo due to localIndex collision") Error("Failed to add HostInfo due to localIndex collision")
return return
case ErrExistingHandshake: case ErrExistingHandshake:
// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish // We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -271,7 +274,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
default: default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here // And we forget to update it here
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -283,10 +286,10 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
} }
// Do the send // Do the send
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
err = f.outside.WriteTo(msg, addr) err = f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -294,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake") WithError(err).Error("Failed to send handshake")
} else { } else {
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -309,7 +312,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
return return
} }
func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { func ixHandshakeStage2(f *Interface, addr *udp.Addr, hostinfo *HostInfo, packet []byte, h *header.H) bool {
if hostinfo == nil { if hostinfo == nil {
// Nothing here to tear down, got a bogus stage 2 packet // Nothing here to tear down, got a bogus stage 2 packet
return true return true
@ -318,14 +321,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo.Lock() hostinfo.Lock()
defer hostinfo.Unlock() defer hostinfo.Unlock()
if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) { if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) {
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false return false
} }
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
if ci.ready { if ci.ready {
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Info("Handshake is already complete") Info("Handshake is already complete")
@ -333,16 +336,16 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) { if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to // Send a test packet to ensure the other side has also switched to
// the preferred remote // the preferred remote
f.SendMessageToVpnIp(test, testRequest, hostinfo.hostId, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
} }
// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets // We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
return false return false
} }
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage") Error("Failed to call noise.ReadMessage")
@ -351,7 +354,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// near future // near future
return false return false
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key") Error("Noise did not arrive at a key")
@ -363,7 +366,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = proto.Unmarshal(msg, hs) err = proto.Unmarshal(msg, hs)
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@ -372,7 +375,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Invalid certificate from host") Error("Invalid certificate from host")
@ -380,14 +383,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
return true return true
} }
vpnIP := ip2int(remoteCert.Details.Ips[0].IP) vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum() fingerprint, _ := remoteCert.Sha256Sum()
issuer := remoteCert.Details.Issuer issuer := remoteCert.Details.Issuer
// Ensure the right host responded // Ensure the right host responded
if vpnIP != hostinfo.hostId { if vpnIp != hostinfo.vpnIp {
f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)). f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
WithField("udpAddr", addr).WithField("certName", certName). WithField("udpAddr", addr).WithField("certName", certName).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake") Info("Incorrect host responded to handshake")
@ -397,7 +400,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// Create a new hostinfo/handshake for the intended vpn ip // Create a new hostinfo/handshake for the intended vpn ip
//TODO: this adds it to the timer wheel in a way that aggressively retries //TODO: this adds it to the timer wheel in a way that aggressively retries
newHostInfo := f.getOrHandshake(hostinfo.hostId) newHostInfo := f.getOrHandshake(hostinfo.vpnIp)
newHostInfo.Lock() newHostInfo.Lock()
// Block the current used address // Block the current used address
@ -405,9 +408,9 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
newHostInfo.remotes.BlockRemote(addr) newHostInfo.remotes.BlockRemote(addr)
// Get the correct remote list for the host we did handshake with // Get the correct remote list for the host we did handshake with
hostinfo.remotes = f.lightHouse.QueryCache(vpnIP) hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", IntIp(vpnIP)). f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
Info("Blocked addresses for handshakes") Info("Blocked addresses for handshakes")
@ -418,7 +421,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo.ConnectionState.queueLock.Unlock() hostinfo.ConnectionState.queueLock.Unlock()
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.hostId = vpnIP hostinfo.vpnIp = vpnIp
f.sendCloseTunnel(hostinfo) f.sendCloseTunnel(hostinfo)
newHostInfo.Unlock() newHostInfo.Unlock()
@ -429,7 +432,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
ci.window.Update(f.l, 2) ci.window.Update(f.l, 2)
duration := time.Since(hostinfo.handshakeStart).Nanoseconds() duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).

View File

@ -11,6 +11,9 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
const ( const (
@ -39,7 +42,7 @@ type HandshakeManager struct {
pendingHostMap *HostMap pendingHostMap *HostMap
mainHostMap *HostMap mainHostMap *HostMap
lightHouse *LightHouse lightHouse *LightHouse
outside *udpConn outside *udp.Conn
config HandshakeConfig config HandshakeConfig
OutboundHandshakeTimer *SystemTimerWheel OutboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
@ -47,18 +50,18 @@ type HandshakeManager struct {
metricTimedOut metrics.Counter metricTimedOut metrics.Counter
l *logrus.Logger l *logrus.Logger
// can be used to trigger outbound handshake for the given vpnIP // can be used to trigger outbound handshake for the given vpnIp
trigger chan uint32 trigger chan iputil.VpnIp
} }
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager { func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udp.Conn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{ return &HandshakeManager{
pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges), pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
mainHostMap: mainHostMap, mainHostMap: mainHostMap,
lightHouse: lightHouse, lightHouse: lightHouse,
outside: outside, outside: outside,
config: config, config: config,
trigger: make(chan uint32, config.triggerBuffer), trigger: make(chan iputil.VpnIp, config.triggerBuffer),
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)), OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
messageMetrics: config.messageMetrics, messageMetrics: config.messageMetrics,
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
@ -67,7 +70,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
} }
} }
func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) { func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
clockSource := time.NewTicker(c.config.tryInterval) clockSource := time.NewTicker(c.config.tryInterval)
defer clockSource.Stop() defer clockSource.Stop()
@ -76,7 +79,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
case <-ctx.Done(): case <-ctx.Done():
return return
case vpnIP := <-c.trigger: case vpnIP := <-c.trigger:
c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered") c.l.WithField("vpnIp", vpnIP).Debug("HandshakeManager: triggered")
c.handleOutbound(vpnIP, f, true) c.handleOutbound(vpnIP, f, true)
case now := <-clockSource.C: case now := <-clockSource.C:
c.NextOutboundHandshakeTimerTick(now, f) c.NextOutboundHandshakeTimerTick(now, f)
@ -84,20 +87,20 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
} }
} }
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) { func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
c.OutboundHandshakeTimer.advance(now) c.OutboundHandshakeTimer.advance(now)
for { for {
ep := c.OutboundHandshakeTimer.Purge() ep := c.OutboundHandshakeTimer.Purge()
if ep == nil { if ep == nil {
break break
} }
vpnIP := ep.(uint32) vpnIp := ep.(iputil.VpnIp)
c.handleOutbound(vpnIP, f, false) c.handleOutbound(vpnIp, f, false)
} }
} }
func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) { func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP) hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
return return
} }
@ -115,7 +118,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
if !hostinfo.HandshakeReady { if !hostinfo.HandshakeReady {
// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly // There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState // Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
return return
} }
@ -143,21 +146,21 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
// Get a remotes object if we don't already have one. // Get a remotes object if we don't already have one.
// This is mainly to protect us as this should never be the case // This is mainly to protect us as this should never be the case
if hostinfo.remotes == nil { if hostinfo.remotes == nil {
hostinfo.remotes = c.lightHouse.QueryCache(vpnIP) hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
} }
//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped) //TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 { if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
// Our vpnIP here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about // Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
// the learned public ip for them. Query again to short circuit the promotion counter // the learned public ip for them. Query again to short circuit the promotion counter
c.lightHouse.QueryServer(vpnIP, f) c.lightHouse.QueryServer(vpnIp, f)
} }
// Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply // Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udpAddr var sentTo []*udp.Addr
hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udpAddr, _ bool) { hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1) c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr) err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil { if err != nil {
hostinfo.logger(c.l).WithField("udpAddr", addr). hostinfo.logger(c.l).WithField("udpAddr", addr).
@ -184,16 +187,16 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
if !lighthouseTriggered { if !lighthouseTriggered {
//TODO: feel like we dupe handshake real fast in a tight loop, why? //TODO: feel like we dupe handshake real fast in a tight loop, why?
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
} }
} }
func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo { func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
hostinfo := c.pendingHostMap.AddVpnIP(vpnIP) hostinfo := c.pendingHostMap.AddVpnIp(vpnIp)
// We lock here and use an array to insert items to prevent locking the // We lock here and use an array to insert items to prevent locking the
// main receive thread for very long by waiting to add items to the pending map // main receive thread for very long by waiting to add items to the pending map
//TODO: what lock? //TODO: what lock?
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval) c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
c.metricInitiated.Inc(1) c.metricInitiated.Inc(1)
return hostinfo return hostinfo
@ -208,12 +211,12 @@ var (
// CheckAndComplete checks for any conflicts in the main and pending hostmap // CheckAndComplete checks for any conflicts in the main and pending hostmap
// before adding hostinfo to main. If err is nil, it was added. Otherwise err will be: // before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
//
// ErrAlreadySeen if we already have an entry in the hostmap that has seen the // ErrAlreadySeen if we already have an entry in the hostmap that has seen the
// exact same handshake packet // exact same handshake packet
// //
// ErrExistingHostInfo if we already have an entry in the hostmap for this // ErrExistingHostInfo if we already have an entry in the hostmap for this
// VpnIP and the new handshake was older than the one we currently have // VpnIp and the new handshake was older than the one we currently have
// //
// ErrLocalIndexCollision if we already have an entry in the main or pending // ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId. // hostmap for the hostinfo.localIndexId.
@ -224,7 +227,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
defer c.mainHostMap.Unlock() defer c.mainHostMap.Unlock()
// Check if we already have a tunnel with this vpn ip // Check if we already have a tunnel with this vpn ip
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId] existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
if found && existingHostInfo != nil { if found && existingHostInfo != nil {
// Is it just a delayed handshake packet? // Is it just a delayed handshake packet?
if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) { if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
@ -252,16 +255,16 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
} }
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId { if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note. // the remote ID. Just log about the situation as a note.
hostinfo.logger(c.l). hostinfo.logger(c.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
Info("New host shadows existing host remoteIndex") Info("New host shadows existing host remoteIndex")
} }
// Check if we are also handshaking with this vpn ip // Check if we are also handshaking with this vpn ip
pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.hostId] pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.vpnIp]
if found && pendingHostInfo != nil { if found && pendingHostInfo != nil {
if !overwrite { if !overwrite {
// We won, let our pending handshake win // We won, let our pending handshake win
@ -278,7 +281,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
if existingHostInfo != nil { if existingHostInfo != nil {
// We are going to overwrite this entry, so remove the old references // We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId) delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId) delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
} }
@ -296,10 +299,10 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
c.mainHostMap.Lock() c.mainHostMap.Lock()
defer c.mainHostMap.Unlock() defer c.mainHostMap.Unlock()
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId] existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
if found && existingHostInfo != nil { if found && existingHostInfo != nil {
// We are going to overwrite this entry, so remove the old references // We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId) delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId) delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
} }
@ -309,7 +312,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note. // the remote ID. Just log about the situation as a note.
hostinfo.logger(c.l). hostinfo.logger(c.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
Info("New host shadows existing host remoteIndex") Info("New host shadows existing host remoteIndex")
} }

View File

@ -5,25 +5,29 @@ import (
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_NewHandshakeManagerVpnIP(t *testing.T) { func Test_NewHandshakeManagerVpnIp(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := ip2int(net.ParseIP("172.1.1.2")) ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{} mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udp.Conn{}, defaultHandshakeConfig)
now := time.Now() now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw) blah.NextOutboundHandshakeTimerTick(now, mw)
i := blah.AddVpnIP(ip) i := blah.AddVpnIp(ip)
i.remotes = NewRemoteList() i.remotes = NewRemoteList()
i.HandshakeReady = true i.HandshakeReady = true
@ -50,24 +54,24 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
} }
func Test_NewHandshakeManagerTrigger(t *testing.T) { func Test_NewHandshakeManagerTrigger(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := ip2int(net.ParseIP("172.1.1.2")) ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{} mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
lh := &LightHouse{addrMap: make(map[uint32]*RemoteList), l: l} lh := &LightHouse{addrMap: make(map[iputil.VpnIp]*RemoteList), l: l}
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)
now := time.Now() now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw) blah.NextOutboundHandshakeTimerTick(now, mw)
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
hi := blah.AddVpnIP(ip) hi := blah.AddVpnIp(ip)
hi.HandshakeReady = true hi.HandshakeReady = true
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet") assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
@ -80,7 +84,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
// Make sure the trigger doesn't double schedule the timer entry // Make sure the trigger doesn't double schedule the timer entry
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
uaddr := NewUDPAddrFromString("10.1.1.1:4242") uaddr := udp.NewAddrFromString("10.1.1.1:4242")
hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port))) hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
// We now have remotes but only the first trigger should have pushed things forward // We now have remotes but only the first trigger should have pushed things forward
@ -103,6 +107,6 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
type mockEncWriter struct { type mockEncWriter struct {
} }
func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
return return
} }

View File

@ -1,4 +1,4 @@
package nebula package header
import ( import (
"encoding/binary" "encoding/binary"
@ -19,82 +19,78 @@ import (
// |-----------------------------------------------------------------------| // |-----------------------------------------------------------------------|
// | payload... | // | payload... |
const ( type m map[string]interface{}
Version uint8 = 1
HeaderLen = 16
)
type NebulaMessageType uint8
type NebulaMessageSubType uint8
const ( const (
handshake NebulaMessageType = 0 Version uint8 = 1
message NebulaMessageType = 1 Len = 16
recvError NebulaMessageType = 2
lightHouse NebulaMessageType = 3
test NebulaMessageType = 4
closeTunnel NebulaMessageType = 5
//TODO These are deprecated as of 06/12/2018 - NB
testRemote NebulaMessageType = 6
testRemoteReply NebulaMessageType = 7
) )
var typeMap = map[NebulaMessageType]string{ type MessageType uint8
handshake: "handshake", type MessageSubType uint8
message: "message",
recvError: "recvError",
lightHouse: "lightHouse",
test: "test",
closeTunnel: "closeTunnel",
//TODO These are deprecated as of 06/12/2018 - NB const (
testRemote: "testRemote", Handshake MessageType = 0
testRemoteReply: "testRemoteReply", Message MessageType = 1
RecvError MessageType = 2
LightHouse MessageType = 3
Test MessageType = 4
CloseTunnel MessageType = 5
)
var typeMap = map[MessageType]string{
Handshake: "handshake",
Message: "message",
RecvError: "recvError",
LightHouse: "lightHouse",
Test: "test",
CloseTunnel: "closeTunnel",
} }
const ( const (
testRequest NebulaMessageSubType = 0 TestRequest MessageSubType = 0
testReply NebulaMessageSubType = 1 TestReply MessageSubType = 1
) )
var eHeaderTooShort = errors.New("header is too short") const (
HandshakeIXPSK0 MessageSubType = 0
HandshakeXXPSK0 MessageSubType = 1
)
var subTypeTestMap = map[NebulaMessageSubType]string{ var ErrHeaderTooShort = errors.New("header is too short")
testRequest: "testRequest",
testReply: "testReply", var subTypeTestMap = map[MessageSubType]string{
TestRequest: "testRequest",
TestReply: "testReply",
} }
var subTypeNoneMap = map[NebulaMessageSubType]string{0: "none"} var subTypeNoneMap = map[MessageSubType]string{0: "none"}
var subTypeMap = map[NebulaMessageType]*map[NebulaMessageSubType]string{ var subTypeMap = map[MessageType]*map[MessageSubType]string{
message: &subTypeNoneMap, Message: &subTypeNoneMap,
recvError: &subTypeNoneMap, RecvError: &subTypeNoneMap,
lightHouse: &subTypeNoneMap, LightHouse: &subTypeNoneMap,
test: &subTypeTestMap, Test: &subTypeTestMap,
closeTunnel: &subTypeNoneMap, CloseTunnel: &subTypeNoneMap,
handshake: { Handshake: {
handshakeIXPSK0: "ix_psk0", HandshakeIXPSK0: "ix_psk0",
}, },
//TODO: these are deprecated
testRemote: &subTypeNoneMap,
testRemoteReply: &subTypeNoneMap,
} }
type Header struct { type H struct {
Version uint8 Version uint8
Type NebulaMessageType Type MessageType
Subtype NebulaMessageSubType Subtype MessageSubType
Reserved uint16 Reserved uint16
RemoteIndex uint32 RemoteIndex uint32
MessageCounter uint64 MessageCounter uint64
} }
// HeaderEncode uses the provided byte array to encode the provided header values into. // Encode uses the provided byte array to encode the provided header values into.
// Byte array must be capped higher than HeaderLen or this will panic // Byte array must be capped higher than HeaderLen or this will panic
func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []byte { func Encode(b []byte, v uint8, t MessageType, st MessageSubType, ri uint32, c uint64) []byte {
b = b[:HeaderLen] b = b[:Len]
b[0] = byte(v<<4 | (t & 0x0f)) b[0] = v<<4 | byte(t&0x0f)
b[1] = byte(st) b[1] = byte(st)
binary.BigEndian.PutUint16(b[2:4], 0) binary.BigEndian.PutUint16(b[2:4], 0)
binary.BigEndian.PutUint32(b[4:8], ri) binary.BigEndian.PutUint32(b[4:8], ri)
@ -103,7 +99,7 @@ func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []b
} }
// String creates a readable string representation of a header // String creates a readable string representation of a header
func (h *Header) String() string { func (h *H) String() string {
if h == nil { if h == nil {
return "<nil>" return "<nil>"
} }
@ -112,7 +108,7 @@ func (h *Header) String() string {
} }
// MarshalJSON creates a json string representation of a header // MarshalJSON creates a json string representation of a header
func (h *Header) MarshalJSON() ([]byte, error) { func (h *H) MarshalJSON() ([]byte, error) {
return json.Marshal(m{ return json.Marshal(m{
"version": h.Version, "version": h.Version,
"type": h.TypeName(), "type": h.TypeName(),
@ -124,24 +120,24 @@ func (h *Header) MarshalJSON() ([]byte, error) {
} }
// Encode turns header into bytes // Encode turns header into bytes
func (h *Header) Encode(b []byte) ([]byte, error) { func (h *H) Encode(b []byte) ([]byte, error) {
if h == nil { if h == nil {
return nil, errors.New("nil header") return nil, errors.New("nil header")
} }
return HeaderEncode(b, h.Version, uint8(h.Type), uint8(h.Subtype), h.RemoteIndex, h.MessageCounter), nil return Encode(b, h.Version, h.Type, h.Subtype, h.RemoteIndex, h.MessageCounter), nil
} }
// Parse is a helper function to parses given bytes into new Header struct // Parse is a helper function to parses given bytes into new Header struct
func (h *Header) Parse(b []byte) error { func (h *H) Parse(b []byte) error {
if len(b) < HeaderLen { if len(b) < Len {
return eHeaderTooShort return ErrHeaderTooShort
} }
// get upper 4 bytes // get upper 4 bytes
h.Version = uint8((b[0] >> 4) & 0x0f) h.Version = uint8((b[0] >> 4) & 0x0f)
// get lower 4 bytes // get lower 4 bytes
h.Type = NebulaMessageType(b[0] & 0x0f) h.Type = MessageType(b[0] & 0x0f)
h.Subtype = NebulaMessageSubType(b[1]) h.Subtype = MessageSubType(b[1])
h.Reserved = binary.BigEndian.Uint16(b[2:4]) h.Reserved = binary.BigEndian.Uint16(b[2:4])
h.RemoteIndex = binary.BigEndian.Uint32(b[4:8]) h.RemoteIndex = binary.BigEndian.Uint32(b[4:8])
h.MessageCounter = binary.BigEndian.Uint64(b[8:16]) h.MessageCounter = binary.BigEndian.Uint64(b[8:16])
@ -149,12 +145,12 @@ func (h *Header) Parse(b []byte) error {
} }
// TypeName will transform the headers message type into a human string // TypeName will transform the headers message type into a human string
func (h *Header) TypeName() string { func (h *H) TypeName() string {
return TypeName(h.Type) return TypeName(h.Type)
} }
// TypeName will transform a nebula message type into a human string // TypeName will transform a nebula message type into a human string
func TypeName(t NebulaMessageType) string { func TypeName(t MessageType) string {
if n, ok := typeMap[t]; ok { if n, ok := typeMap[t]; ok {
return n return n
} }
@ -163,12 +159,12 @@ func TypeName(t NebulaMessageType) string {
} }
// SubTypeName will transform the headers message sub type into a human string // SubTypeName will transform the headers message sub type into a human string
func (h *Header) SubTypeName() string { func (h *H) SubTypeName() string {
return SubTypeName(h.Type, h.Subtype) return SubTypeName(h.Type, h.Subtype)
} }
// SubTypeName will transform a nebula message sub type into a human string // SubTypeName will transform a nebula message sub type into a human string
func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string { func SubTypeName(t MessageType, s MessageSubType) string {
if n, ok := subTypeMap[t]; ok { if n, ok := subTypeMap[t]; ok {
if x, ok := (*n)[s]; ok { if x, ok := (*n)[s]; ok {
return x return x
@ -179,8 +175,8 @@ func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string {
} }
// NewHeader turns bytes into a header // NewHeader turns bytes into a header
func NewHeader(b []byte) (*Header, error) { func NewHeader(b []byte) (*H, error) {
h := new(Header) h := new(H)
if err := h.Parse(b); err != nil { if err := h.Parse(b); err != nil {
return nil, err return nil, err
} }

115
header/header_test.go Normal file
View File

@ -0,0 +1,115 @@
package header
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
type headerTest struct {
expectedBytes []byte
*H
}
// 0001 0010 00010010
var headerBigEndianTests = []headerTest{{
expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9},
// 1010 0000
H: &H{
// 1111 1+2+4+8 = 15
Version: 5,
Type: 4,
Subtype: 0,
Reserved: 0,
RemoteIndex: 10,
MessageCounter: 9,
},
},
}
func TestEncode(t *testing.T) {
for _, tt := range headerBigEndianTests {
b, err := tt.Encode(make([]byte, Len))
if err != nil {
t.Fatal(err)
}
assert.Equal(t, tt.expectedBytes, b)
}
}
func TestParse(t *testing.T) {
for _, tt := range headerBigEndianTests {
b := tt.expectedBytes
parsedHeader := &H{}
parsedHeader.Parse(b)
if !reflect.DeepEqual(tt.H, parsedHeader) {
t.Fatalf("got %#v; want %#v", parsedHeader, tt.H)
}
}
}
func TestTypeName(t *testing.T) {
assert.Equal(t, "test", TypeName(Test))
assert.Equal(t, "test", (&H{Type: Test}).TypeName())
assert.Equal(t, "unknown", TypeName(99))
assert.Equal(t, "unknown", (&H{Type: 99}).TypeName())
}
func TestSubTypeName(t *testing.T) {
assert.Equal(t, "testRequest", SubTypeName(Test, TestRequest))
assert.Equal(t, "testRequest", (&H{Type: Test, Subtype: TestRequest}).SubTypeName())
assert.Equal(t, "unknown", SubTypeName(99, TestRequest))
assert.Equal(t, "unknown", (&H{Type: 99, Subtype: TestRequest}).SubTypeName())
assert.Equal(t, "unknown", SubTypeName(Test, 99))
assert.Equal(t, "unknown", (&H{Type: Test, Subtype: 99}).SubTypeName())
assert.Equal(t, "none", SubTypeName(Message, 0))
assert.Equal(t, "none", (&H{Type: Message, Subtype: 0}).SubTypeName())
}
func TestTypeMap(t *testing.T) {
// Force people to document this stuff
assert.Equal(t, map[MessageType]string{
Handshake: "handshake",
Message: "message",
RecvError: "recvError",
LightHouse: "lightHouse",
Test: "test",
CloseTunnel: "closeTunnel",
}, typeMap)
assert.Equal(t, map[MessageType]*map[MessageSubType]string{
Message: &subTypeNoneMap,
RecvError: &subTypeNoneMap,
LightHouse: &subTypeNoneMap,
Test: &subTypeTestMap,
CloseTunnel: &subTypeNoneMap,
Handshake: {
HandshakeIXPSK0: "ix_psk0",
},
}, subTypeMap)
}
func TestHeader_String(t *testing.T) {
assert.Equal(
t,
"ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97",
(&H{100, Test, TestRequest, 99, 98, 97}).String(),
)
}
func TestHeader_MarshalJSON(t *testing.T) {
b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
assert.Nil(t, err)
assert.Equal(
t,
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
string(b),
)
}

View File

@ -1,119 +0,0 @@
package nebula
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
type headerTest struct {
expectedBytes []byte
*Header
}
// 0001 0010 00010010
var headerBigEndianTests = []headerTest{{
expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9},
// 1010 0000
Header: &Header{
// 1111 1+2+4+8 = 15
Version: 5,
Type: 4,
Subtype: 0,
Reserved: 0,
RemoteIndex: 10,
MessageCounter: 9,
},
},
}
func TestEncode(t *testing.T) {
for _, tt := range headerBigEndianTests {
b, err := tt.Encode(make([]byte, HeaderLen))
if err != nil {
t.Fatal(err)
}
assert.Equal(t, tt.expectedBytes, b)
}
}
func TestParse(t *testing.T) {
for _, tt := range headerBigEndianTests {
b := tt.expectedBytes
parsedHeader := &Header{}
parsedHeader.Parse(b)
if !reflect.DeepEqual(tt.Header, parsedHeader) {
t.Fatalf("got %#v; want %#v", parsedHeader, tt.Header)
}
}
}
func TestTypeName(t *testing.T) {
assert.Equal(t, "test", TypeName(test))
assert.Equal(t, "test", (&Header{Type: test}).TypeName())
assert.Equal(t, "unknown", TypeName(99))
assert.Equal(t, "unknown", (&Header{Type: 99}).TypeName())
}
func TestSubTypeName(t *testing.T) {
assert.Equal(t, "testRequest", SubTypeName(test, testRequest))
assert.Equal(t, "testRequest", (&Header{Type: test, Subtype: testRequest}).SubTypeName())
assert.Equal(t, "unknown", SubTypeName(99, testRequest))
assert.Equal(t, "unknown", (&Header{Type: 99, Subtype: testRequest}).SubTypeName())
assert.Equal(t, "unknown", SubTypeName(test, 99))
assert.Equal(t, "unknown", (&Header{Type: test, Subtype: 99}).SubTypeName())
assert.Equal(t, "none", SubTypeName(message, 0))
assert.Equal(t, "none", (&Header{Type: message, Subtype: 0}).SubTypeName())
}
func TestTypeMap(t *testing.T) {
// Force people to document this stuff
assert.Equal(t, map[NebulaMessageType]string{
handshake: "handshake",
message: "message",
recvError: "recvError",
lightHouse: "lightHouse",
test: "test",
closeTunnel: "closeTunnel",
testRemote: "testRemote",
testRemoteReply: "testRemoteReply",
}, typeMap)
assert.Equal(t, map[NebulaMessageType]*map[NebulaMessageSubType]string{
message: &subTypeNoneMap,
recvError: &subTypeNoneMap,
lightHouse: &subTypeNoneMap,
test: &subTypeTestMap,
closeTunnel: &subTypeNoneMap,
handshake: {
handshakeIXPSK0: "ix_psk0",
},
testRemote: &subTypeNoneMap,
testRemoteReply: &subTypeNoneMap,
}, subTypeMap)
}
func TestHeader_String(t *testing.T) {
assert.Equal(
t,
"ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97",
(&Header{100, test, testRequest, 99, 98, 97}).String(),
)
}
func TestHeader_MarshalJSON(t *testing.T) {
b, err := (&Header{100, test, testRequest, 99, 98, 97}).MarshalJSON()
assert.Nil(t, err)
assert.Equal(
t,
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
string(b),
)
}

View File

@ -12,6 +12,10 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
//const ProbeLen = 100 //const ProbeLen = 100
@ -28,10 +32,10 @@ type HostMap struct {
name string name string
Indexes map[uint32]*HostInfo Indexes map[uint32]*HostInfo
RemoteIndexes map[uint32]*HostInfo RemoteIndexes map[uint32]*HostInfo
Hosts map[uint32]*HostInfo Hosts map[iputil.VpnIp]*HostInfo
preferredRanges []*net.IPNet preferredRanges []*net.IPNet
vpnCIDR *net.IPNet vpnCIDR *net.IPNet
unsafeRoutes *CIDRTree unsafeRoutes *cidr.Tree4
metricsEnabled bool metricsEnabled bool
l *logrus.Logger l *logrus.Logger
} }
@ -39,7 +43,7 @@ type HostMap struct {
type HostInfo struct { type HostInfo struct {
sync.RWMutex sync.RWMutex
remote *udpAddr remote *udp.Addr
remotes *RemoteList remotes *RemoteList
promoteCounter uint32 promoteCounter uint32
ConnectionState *ConnectionState ConnectionState *ConnectionState
@ -51,9 +55,9 @@ type HostInfo struct {
packetStore []*cachedPacket //todo: this is other handshake manager entry packetStore []*cachedPacket //todo: this is other handshake manager entry
remoteIndexId uint32 remoteIndexId uint32
localIndexId uint32 localIndexId uint32
hostId uint32 vpnIp iputil.VpnIp
recvError int recvError int
remoteCidr *CIDRTree remoteCidr *cidr.Tree4
// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH // lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like // for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
@ -66,17 +70,17 @@ type HostInfo struct {
lastHandshakeTime uint64 lastHandshakeTime uint64
lastRoam time.Time lastRoam time.Time
lastRoamRemote *udpAddr lastRoamRemote *udp.Addr
} }
type cachedPacket struct { type cachedPacket struct {
messageType NebulaMessageType messageType header.MessageType
messageSubType NebulaMessageSubType messageSubType header.MessageSubType
callback packetCallback callback packetCallback
packet []byte packet []byte
} }
type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte) type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte)
type cachedPacketMetrics struct { type cachedPacketMetrics struct {
sent metrics.Counter sent metrics.Counter
@ -84,7 +88,7 @@ type cachedPacketMetrics struct {
} }
func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[uint32]*HostInfo{} h := map[iputil.VpnIp]*HostInfo{}
i := map[uint32]*HostInfo{} i := map[uint32]*HostInfo{}
r := map[uint32]*HostInfo{} r := map[uint32]*HostInfo{}
m := HostMap{ m := HostMap{
@ -94,7 +98,7 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
Hosts: h, Hosts: h,
preferredRanges: preferredRanges, preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR, vpnCIDR: vpnCIDR,
unsafeRoutes: NewCIDRTree(), unsafeRoutes: cidr.NewTree4(),
l: l, l: l,
} }
return &m return &m
@ -113,9 +117,9 @@ func (hm *HostMap) EmitStats(name string) {
metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen)) metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
} }
func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) { func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
hm.RLock() hm.RLock()
if i, ok := hm.Hosts[vpnIP]; ok { if i, ok := hm.Hosts[vpnIp]; ok {
index := i.localIndexId index := i.localIndexId
hm.RUnlock() hm.RUnlock()
return index, nil return index, nil
@ -124,43 +128,43 @@ func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
return 0, errors.New("vpn IP not found") return 0, errors.New("vpn IP not found")
} }
func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) { func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
hm.Lock() hm.Lock()
hm.Hosts[ip] = hostinfo hm.Hosts[ip] = hostinfo
hm.Unlock() hm.Unlock()
} }
func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo { func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
h := &HostInfo{} h := &HostInfo{}
hm.RLock() hm.RLock()
if _, ok := hm.Hosts[vpnIP]; !ok { if _, ok := hm.Hosts[vpnIp]; !ok {
hm.RUnlock() hm.RUnlock()
h = &HostInfo{ h = &HostInfo{
promoteCounter: 0, promoteCounter: 0,
hostId: vpnIP, vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
} }
hm.Lock() hm.Lock()
hm.Hosts[vpnIP] = h hm.Hosts[vpnIp] = h
hm.Unlock() hm.Unlock()
return h return h
} else { } else {
h = hm.Hosts[vpnIP] h = hm.Hosts[vpnIp]
hm.RUnlock() hm.RUnlock()
return h return h
} }
} }
func (hm *HostMap) DeleteVpnIP(vpnIP uint32) { func (hm *HostMap) DeleteVpnIp(vpnIp iputil.VpnIp) {
hm.Lock() hm.Lock()
delete(hm.Hosts, vpnIP) delete(hm.Hosts, vpnIp)
if len(hm.Hosts) == 0 { if len(hm.Hosts) == 0 {
hm.Hosts = map[uint32]*HostInfo{} hm.Hosts = map[iputil.VpnIp]*HostInfo{}
} }
hm.Unlock() hm.Unlock()
if hm.l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}). hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap vpnIp deleted") Debug("Hostmap vpnIp deleted")
} }
} }
@ -174,22 +178,22 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
if hm.l.Level > logrus.DebugLevel { if hm.l.Level > logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}).
Debug("Hostmap remoteIndex added") Debug("Hostmap remoteIndex added")
} }
} }
func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) { func (hm *HostMap) AddVpnIpHostInfo(vpnIp iputil.VpnIp, h *HostInfo) {
hm.Lock() hm.Lock()
h.hostId = vpnIP h.vpnIp = vpnIp
hm.Hosts[vpnIP] = h hm.Hosts[vpnIp] = h
hm.Indexes[h.localIndexId] = h hm.Indexes[h.localIndexId] = h
hm.RemoteIndexes[h.remoteIndexId] = h hm.RemoteIndexes[h.remoteIndexId] = h
hm.Unlock() hm.Unlock()
if hm.l.Level > logrus.DebugLevel { if hm.l.Level > logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts), hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "vpnIp": h.vpnIp}}).
Debug("Hostmap vpnIp added") Debug("Hostmap vpnIp added")
} }
} }
@ -204,9 +208,9 @@ func (hm *HostMap) DeleteIndex(index uint32) {
// Check if we have an entry under hostId that matches the same hostinfo // Check if we have an entry under hostId that matches the same hostinfo
// instance. Clean it up as well if we do. // instance. Clean it up as well if we do.
hostinfo2, ok := hm.Hosts[hostinfo.hostId] hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
if ok && hostinfo2 == hostinfo { if ok && hostinfo2 == hostinfo {
delete(hm.Hosts, hostinfo.hostId) delete(hm.Hosts, hostinfo.vpnIp)
} }
} }
hm.Unlock() hm.Unlock()
@ -228,9 +232,9 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
// Check if we have an entry under hostId that matches the same hostinfo // Check if we have an entry under hostId that matches the same hostinfo
// instance. Clean it up as well if we do (they might not match in pendingHostmap) // instance. Clean it up as well if we do (they might not match in pendingHostmap)
var hostinfo2 *HostInfo var hostinfo2 *HostInfo
hostinfo2, ok = hm.Hosts[hostinfo.hostId] hostinfo2, ok = hm.Hosts[hostinfo.vpnIp]
if ok && hostinfo2 == hostinfo { if ok && hostinfo2 == hostinfo {
delete(hm.Hosts, hostinfo.hostId) delete(hm.Hosts, hostinfo.vpnIp)
} }
} }
hm.Unlock() hm.Unlock()
@ -251,16 +255,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
// Check if this same hostId is in the hostmap with a different instance. // Check if this same hostId is in the hostmap with a different instance.
// This could happen if we have an entry in the pending hostmap with different // This could happen if we have an entry in the pending hostmap with different
// index values than the one in the main hostmap. // index values than the one in the main hostmap.
hostinfo2, ok := hm.Hosts[hostinfo.hostId] hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
if ok && hostinfo2 != hostinfo { if ok && hostinfo2 != hostinfo {
delete(hm.Hosts, hostinfo2.hostId) delete(hm.Hosts, hostinfo2.vpnIp)
delete(hm.Indexes, hostinfo2.localIndexId) delete(hm.Indexes, hostinfo2.localIndexId)
delete(hm.RemoteIndexes, hostinfo2.remoteIndexId) delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
} }
delete(hm.Hosts, hostinfo.hostId) delete(hm.Hosts, hostinfo.vpnIp)
if len(hm.Hosts) == 0 { if len(hm.Hosts) == 0 {
hm.Hosts = map[uint32]*HostInfo{} hm.Hosts = map[iputil.VpnIp]*HostInfo{}
} }
delete(hm.Indexes, hostinfo.localIndexId) delete(hm.Indexes, hostinfo.localIndexId)
if len(hm.Indexes) == 0 { if len(hm.Indexes) == 0 {
@ -273,7 +277,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
if hm.l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted") Debug("Hostmap hostInfo deleted")
} }
} }
@ -301,17 +305,17 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
} }
} }
func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) { func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
return hm.queryVpnIP(vpnIp, nil) return hm.queryVpnIp(vpnIp, nil)
} }
// PromoteBestQueryVpnIP will attempt to lazily switch to the best remote every // PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
// `PromoteEvery` calls to this function for a given host. // `PromoteEvery` calls to this function for a given host.
func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostInfo, error) { func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
return hm.queryVpnIP(vpnIp, ifce) return hm.queryVpnIp(vpnIp, ifce)
} }
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) { func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) {
hm.RLock() hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok { if h, ok := hm.Hosts[vpnIp]; ok {
hm.RUnlock() hm.RUnlock()
@ -327,10 +331,10 @@ func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo,
return nil, errors.New("unable to find host") return nil, errors.New("unable to find host")
} }
func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 { func (hm *HostMap) queryUnsafeRoute(ip iputil.VpnIp) iputil.VpnIp {
r := hm.unsafeRoutes.MostSpecificContains(ip) r := hm.unsafeRoutes.MostSpecificContains(ip)
if r != nil { if r != nil {
return r.(uint32) return r.(iputil.VpnIp)
} else { } else {
return 0 return 0
} }
@ -344,13 +348,13 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
} }
hm.Hosts[hostinfo.hostId] = hostinfo hm.Hosts[hostinfo.vpnIp] = hostinfo
hm.Indexes[hostinfo.localIndexId] = hostinfo hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts), hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}). "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
Debug("Hostmap vpnIp added") Debug("Hostmap vpnIp added")
} }
} }
@ -370,7 +374,7 @@ func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
} }
// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them // Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) { func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) {
var metricsTxPunchy metrics.Counter var metricsTxPunchy metrics.Counter
if hm.metricsEnabled { if hm.metricsEnabled {
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil) metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
@ -406,7 +410,7 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
func (hm *HostMap) addUnsafeRoutes(routes *[]route) { func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
for _, r := range *routes { for _, r := range *routes {
hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route") hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via)) hm.unsafeRoutes.AddCIDR(r.route, iputil.Ip2VpnIp(*r.via))
} }
} }
@ -431,24 +435,24 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
} }
} }
i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) { i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
if addr == nil || !preferred { if addr == nil || !preferred {
return return
} }
// Try to send a test packet to that host, this should // Try to send a test packet to that host, this should
// cause it to detect a roaming event and switch remotes // cause it to detect a roaming event and switch remotes
ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) ifce.send(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}) })
} }
// Re query our lighthouses for new remotes occasionally // Re query our lighthouses for new remotes occasionally
if c%ReQueryEvery == 0 && ifce.lightHouse != nil { if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
ifce.lightHouse.QueryServer(i.hostId, ifce) ifce.lightHouse.QueryServer(i.vpnIp, ifce)
} }
} }
func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
//TODO: return the error so we can log with more context //TODO: return the error so we can log with more context
if len(i.packetStore) < 100 { if len(i.packetStore) < 100 {
tempPacket := make([]byte, len(packet)) tempPacket := make([]byte, len(packet))
@ -510,17 +514,17 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
return nil return nil
} }
func (i *HostInfo) SetRemote(remote *udpAddr) { func (i *HostInfo) SetRemote(remote *udp.Addr) {
// We copy here because we likely got this remote from a source that reuses the object // We copy here because we likely got this remote from a source that reuses the object
if !i.remote.Equals(remote) { if !i.remote.Equals(remote) {
i.remote = remote.Copy() i.remote = remote.Copy()
i.remotes.LearnRemote(i.hostId, remote.Copy()) i.remotes.LearnRemote(i.vpnIp, remote.Copy())
} }
} }
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
// time on the HostInfo will also be updated. // time on the HostInfo will also be updated.
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udpAddr) bool { func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
currentRemote := i.remote currentRemote := i.remote
if currentRemote == nil { if currentRemote == nil {
i.SetRemote(newRemote) i.SetRemote(newRemote)
@ -572,7 +576,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
return return
} }
remoteCidr := NewCIDRTree() remoteCidr := cidr.NewTree4()
for _, ip := range c.Details.Ips { for _, ip := range c.Details.Ips {
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
} }
@ -588,8 +592,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
return logrus.NewEntry(l) return logrus.NewEntry(l)
} }
li := l.WithField("vpnIp", IntIp(i.hostId)) li := l.WithField("vpnIp", i.vpnIp)
if connState := i.ConnectionState; connState != nil { if connState := i.ConnectionState; connState != nil {
if peerCert := connState.peerCert; peerCert != nil { if peerCert := connState.peerCert; peerCert != nil {
li = li.WithField("certName", peerCert.Details.Name) li = li.WithField("certName", peerCert.Details.Name)
@ -599,38 +602,6 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
return li return li
} }
//########################
/*
func (hm *HostMap) DebugRemotes(vpnIp uint32) string {
s := "\n"
for _, h := range hm.Hosts {
for _, r := range h.Remotes {
s += fmt.Sprintf("%s : %d ## %v\n", r.addr.IP.String(), r.addr.Port, r.probes)
}
}
return s
}
func (i *HostInfo) HandleReply(addr *net.UDPAddr, counter int) {
for _, r := range i.Remotes {
if r.addr.IP.Equal(addr.IP) && r.addr.Port == addr.Port {
r.ProbeReceived(counter)
}
}
}
func (i *HostInfo) Probes() []*Probe {
p := []*Probe{}
for _, d := range i.Remotes {
p = append(p, &Probe{Addr: d.addr, Counter: d.Probe()})
}
return p
}
*/
// Utility functions // Utility functions
func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {

View File

@ -5,9 +5,13 @@ import (
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) { func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
@ -32,7 +36,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
hostinfo := f.getOrHandshake(fwPacket.RemoteIP) hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
if hostinfo == nil { if hostinfo == nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)). f.l.WithField("vpnIp", fwPacket.RemoteIP).
WithField("fwPacket", fwPacket). WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes") Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
} }
@ -45,7 +49,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
// the packet queue. // the packet queue.
ci.queueLock.Lock() ci.queueLock.Lock()
if !ci.ready { if !ci.ready {
hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) hostinfo.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
ci.queueLock.Unlock() ci.queueLock.Unlock()
return return
} }
@ -54,7 +58,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache) dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
if dropReason == nil { if dropReason == nil {
f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q) f.sendNoMetrics(header.Message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
} else if f.l.Level >= logrus.DebugLevel { } else if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l). hostinfo.logger(f.l).
@ -65,20 +69,21 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
} }
// getOrHandshake returns nil if the vpnIp is not routable // getOrHandshake returns nil if the vpnIp is not routable
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo { func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false { //TODO: we can find contains without converting back to bytes
if f.hostMap.vpnCIDR.Contains(vpnIp.ToIP()) == false {
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp) vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
if vpnIp == 0 { if vpnIp == 0 {
return nil return nil
} }
} }
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f) hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
//if err != nil || hostinfo.ConnectionState == nil { //if err != nil || hostinfo.ConnectionState == nil {
if err != nil { if err != nil {
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIp) hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
hostinfo = f.handshakeManager.AddVpnIP(vpnIp) hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
} }
} }
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
@ -126,8 +131,8 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
return hostinfo return hostinfo
} }
func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) { func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
fp := &FirewallPacket{} fp := &firewall.Packet{}
err := newPacket(p, false, fp) err := newPacket(p, false, fp)
if err != nil { if err != nil {
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err) f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
@ -145,15 +150,15 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
return return
} }
f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0) f.sendNoMetrics(header.Message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
} }
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp) hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil { if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", IntIp(vpnIp)). f.l.WithField("vpnIp", vpnIp).
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes") Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
} }
return return
@ -175,16 +180,16 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
return return
} }
func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) { func (f *Interface) sendMessageToVpnIp(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out) f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
} }
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) { func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1) f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
} }
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) { func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) {
if ci.eKey == nil { if ci.eKey == nil {
//TODO: log warning //TODO: log warning
return return
@ -196,18 +201,18 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
c := atomic.AddUint64(&ci.atomicMessageCounter, 1) c := atomic.AddUint64(&ci.atomicMessageCounter, 1)
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c) out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo.hostId) f.connectionManager.Out(hostinfo.vpnIp)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our IPs and enable a faster roaming. // all our IPs and enable a faster roaming.
if t != closeTunnel && hostinfo.lastRebindCount != f.rebindCount { if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.hostId, f) f.lightHouse.QueryServer(hostinfo.vpnIp, f)
hostinfo.lastRebindCount = f.rebindCount hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter") f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
} }
} }
@ -230,7 +235,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
return return
} }
func isMulticast(ip uint32) bool { func isMulticast(ip iputil.VpnIp) bool {
// Class D multicast // Class D multicast
if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 { if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 {
return true return true

View File

@ -12,6 +12,10 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
const mtu = 9001 const mtu = 9001
@ -27,7 +31,7 @@ type Inside interface {
type InterfaceConfig struct { type InterfaceConfig struct {
HostMap *HostMap HostMap *HostMap
Outside *udpConn Outside *udp.Conn
Inside Inside Inside Inside
certState *CertState certState *CertState
Cipher string Cipher string
@ -39,7 +43,6 @@ type InterfaceConfig struct {
pendingDeletionInterval int pendingDeletionInterval int
DropLocalBroadcast bool DropLocalBroadcast bool
DropMulticast bool DropMulticast bool
UDPBatchSize int
routines int routines int
MessageMetrics *MessageMetrics MessageMetrics *MessageMetrics
version string version string
@ -52,7 +55,7 @@ type InterfaceConfig struct {
type Interface struct { type Interface struct {
hostMap *HostMap hostMap *HostMap
outside *udpConn outside *udp.Conn
inside Inside inside Inside
certState *CertState certState *CertState
cipher string cipher string
@ -62,11 +65,10 @@ type Interface struct {
serveDns bool serveDns bool
createTime time.Time createTime time.Time
lightHouse *LightHouse lightHouse *LightHouse
localBroadcast uint32 localBroadcast iputil.VpnIp
myVpnIp uint32 myVpnIp iputil.VpnIp
dropLocalBroadcast bool dropLocalBroadcast bool
dropMulticast bool dropMulticast bool
udpBatchSize int
routines int routines int
caPool *cert.NebulaCAPool caPool *cert.NebulaCAPool
disconnectInvalid bool disconnectInvalid bool
@ -77,7 +79,7 @@ type Interface struct {
conntrackCacheTimeout time.Duration conntrackCacheTimeout time.Duration
writers []*udpConn writers []*udp.Conn
readers []io.ReadWriteCloser readers []io.ReadWriteCloser
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
@ -101,6 +103,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
return nil, errors.New("no firewall rules") return nil, errors.New("no firewall rules")
} }
myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
ifce := &Interface{ ifce := &Interface{
hostMap: c.HostMap, hostMap: c.HostMap,
outside: c.Outside, outside: c.Outside,
@ -112,17 +115,16 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
handshakeManager: c.HandshakeManager, handshakeManager: c.HandshakeManager,
createTime: time.Now(), createTime: time.Now(),
lightHouse: c.lightHouse, lightHouse: c.lightHouse,
localBroadcast: ip2int(c.certState.certificate.Details.Ips[0].IP) | ^ip2int(c.certState.certificate.Details.Ips[0].Mask), localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask),
dropLocalBroadcast: c.DropLocalBroadcast, dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast, dropMulticast: c.DropMulticast,
udpBatchSize: c.UDPBatchSize,
routines: c.routines, routines: c.routines,
version: c.version, version: c.version,
writers: make([]*udpConn, c.routines), writers: make([]*udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines), readers: make([]io.ReadWriteCloser, c.routines),
caPool: c.caPool, caPool: c.caPool,
disconnectInvalid: c.disconnectInvalid, disconnectInvalid: c.disconnectInvalid,
myVpnIp: ip2int(c.certState.certificate.Details.Ips[0].IP), myVpnIp: myVpnIp,
conntrackCacheTimeout: c.ConntrackCacheTimeout, conntrackCacheTimeout: c.ConntrackCacheTimeout,
@ -190,14 +192,17 @@ func (f *Interface) run() {
func (f *Interface) listenOut(i int) { func (f *Interface) listenOut(i int) {
runtime.LockOSThread() runtime.LockOSThread()
var li *udpConn var li *udp.Conn
// TODO clean this up with a coherent interface for each outside connection // TODO clean this up with a coherent interface for each outside connection
if i > 0 { if i > 0 {
li = f.writers[i] li = f.writers[i]
} else { } else {
li = f.outside li = f.outside
} }
li.ListenOut(f, i)
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i)
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@ -205,10 +210,10 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
packet := make([]byte, mtu) packet := make([]byte, mtu)
out := make([]byte, mtu) out := make([]byte, mtu)
fwPacket := &FirewallPacket{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for { for {
n, err := reader.Read(packet) n, err := reader.Read(packet)
@ -222,16 +227,16 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
} }
} }
func (f *Interface) RegisterConfigChangeCallbacks(c *Config) { func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
c.RegisterReloadCallback(f.reloadCA) c.RegisterReloadCallback(f.reloadCA)
c.RegisterReloadCallback(f.reloadCertKey) c.RegisterReloadCallback(f.reloadCertKey)
c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadFirewall)
for _, udpConn := range f.writers { for _, udpConn := range f.writers {
c.RegisterReloadCallback(udpConn.reloadConfig) c.RegisterReloadCallback(udpConn.ReloadConfig)
} }
} }
func (f *Interface) reloadCA(c *Config) { func (f *Interface) reloadCA(c *config.C) {
// reload and check regardless // reload and check regardless
// todo: need mutex? // todo: need mutex?
newCAs, err := loadCAFromConfig(f.l, c) newCAs, err := loadCAFromConfig(f.l, c)
@ -244,7 +249,7 @@ func (f *Interface) reloadCA(c *Config) {
f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed") f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
} }
func (f *Interface) reloadCertKey(c *Config) { func (f *Interface) reloadCertKey(c *config.C) {
// reload and check in all cases // reload and check in all cases
cs, err := NewCertStateFromConfig(c) cs, err := NewCertStateFromConfig(c)
if err != nil { if err != nil {
@ -264,7 +269,7 @@ func (f *Interface) reloadCertKey(c *Config) {
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk") f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
} }
func (f *Interface) reloadFirewall(c *Config) { func (f *Interface) reloadFirewall(c *config.C) {
//TODO: need to trigger/detect if the certificate changed too //TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false { if c.HasChanged("firewall") == false {
f.l.Debug("No firewall config change detected") f.l.Debug("No firewall config change detected")
@ -307,7 +312,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
ticker := time.NewTicker(i) ticker := time.NewTicker(i)
defer ticker.Stop() defer ticker.Stop()
udpStats := NewUDPStatsEmitter(f.writers) udpStats := udp.NewUDPStatsEmitter(f.writers)
for { for {
select { select {

66
iputil/util.go Normal file
View File

@ -0,0 +1,66 @@
package iputil
import (
"encoding/binary"
"fmt"
"net"
)
type VpnIp uint32
const maxIPv4StringLen = len("255.255.255.255")
func (ip VpnIp) String() string {
b := make([]byte, maxIPv4StringLen)
n := ubtoa(b, 0, byte(ip>>24))
b[n] = '.'
n++
n += ubtoa(b, n, byte(ip>>16&255))
b[n] = '.'
n++
n += ubtoa(b, n, byte(ip>>8&255))
b[n] = '.'
n++
n += ubtoa(b, n, byte(ip&255))
return string(b[:n])
}
func (ip VpnIp) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil
}
func (ip VpnIp) ToIP() net.IP {
nip := make(net.IP, 4)
binary.BigEndian.PutUint32(nip, uint32(ip))
return nip
}
func Ip2VpnIp(ip []byte) VpnIp {
if len(ip) == 16 {
return VpnIp(binary.BigEndian.Uint32(ip[12:16]))
}
return VpnIp(binary.BigEndian.Uint32(ip))
}
// ubtoa encodes the string form of the integer v to dst[start:] and
// returns the number of bytes written to dst. The caller must ensure
// that dst has sufficient length.
func ubtoa(dst []byte, start int, v byte) int {
if v < 10 {
dst[start] = v + '0'
return 1
} else if v < 100 {
dst[start+1] = v%10 + '0'
dst[start] = v/10 + '0'
return 2
}
dst[start+2] = v%10 + '0'
dst[start+1] = (v/10)%10 + '0'
dst[start] = v/100 + '0'
return 3
}

17
iputil/util_test.go Normal file
View File

@ -0,0 +1,17 @@
package iputil
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestVpnIp_String(t *testing.T) {
assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String())
assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String())
assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String())
assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String())
assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String())
assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String())
}

View File

@ -12,6 +12,9 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake? //TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake?
@ -23,13 +26,13 @@ type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
sync.RWMutex //Because we concurrently read and write to our maps sync.RWMutex //Because we concurrently read and write to our maps
amLighthouse bool amLighthouse bool
myVpnIp uint32 myVpnIp iputil.VpnIp
myVpnZeros uint32 myVpnZeros iputil.VpnIp
punchConn *udpConn punchConn *udp.Conn
// Local cache of answers from light houses // Local cache of answers from light houses
// map of vpn Ip to answers // map of vpn Ip to answers
addrMap map[uint32]*RemoteList addrMap map[iputil.VpnIp]*RemoteList
// filters remote addresses allowed for each host // filters remote addresses allowed for each host
// - When we are a lighthouse, this filters what addresses we store and // - When we are a lighthouse, this filters what addresses we store and
@ -42,12 +45,12 @@ type LightHouse struct {
localAllowList *LocalAllowList localAllowList *LocalAllowList
// used to trigger the HandshakeManager when we receive HostQueryReply // used to trigger the HandshakeManager when we receive HostQueryReply
handshakeTrigger chan<- uint32 handshakeTrigger chan<- iputil.VpnIp
// staticList exists to avoid having a bool in each addrMap entry // staticList exists to avoid having a bool in each addrMap entry
// since static should be rare // since static should be rare
staticList map[uint32]struct{} staticList map[iputil.VpnIp]struct{}
lighthouses map[uint32]struct{} lighthouses map[iputil.VpnIp]struct{}
interval int interval int
nebulaPort uint32 // 32 bits because protobuf does not have a uint16 nebulaPort uint32 // 32 bits because protobuf does not have a uint16
punchBack bool punchBack bool
@ -58,20 +61,16 @@ type LightHouse struct {
l *logrus.Logger l *logrus.Logger
} }
type EncWriter interface { func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []iputil.VpnIp, interval int, nebulaPort uint32, pc *udp.Conn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
}
func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
ones, _ := myVpnIpNet.Mask.Size() ones, _ := myVpnIpNet.Mask.Size()
h := LightHouse{ h := LightHouse{
amLighthouse: amLighthouse, amLighthouse: amLighthouse,
myVpnIp: ip2int(myVpnIpNet.IP), myVpnIp: iputil.Ip2VpnIp(myVpnIpNet.IP),
myVpnZeros: uint32(32 - ones), myVpnZeros: iputil.VpnIp(32 - ones),
addrMap: make(map[uint32]*RemoteList), addrMap: make(map[iputil.VpnIp]*RemoteList),
nebulaPort: nebulaPort, nebulaPort: nebulaPort,
lighthouses: make(map[uint32]struct{}), lighthouses: make(map[iputil.VpnIp]struct{}),
staticList: make(map[uint32]struct{}), staticList: make(map[iputil.VpnIp]struct{}),
interval: interval, interval: interval,
punchConn: pc, punchConn: pc,
punchBack: punchBack, punchBack: punchBack,
@ -111,13 +110,13 @@ func (lh *LightHouse) SetLocalAllowList(allowList *LocalAllowList) {
func (lh *LightHouse) ValidateLHStaticEntries() error { func (lh *LightHouse) ValidateLHStaticEntries() error {
for lhIP, _ := range lh.lighthouses { for lhIP, _ := range lh.lighthouses {
if _, ok := lh.staticList[lhIP]; !ok { if _, ok := lh.staticList[lhIP]; !ok {
return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", IntIp(lhIP)) return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", lhIP)
} }
} }
return nil return nil
} }
func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList { func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
if !lh.IsLighthouseIP(ip) { if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip, f) lh.QueryServer(ip, f)
} }
@ -131,7 +130,7 @@ func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
} }
// This is asynchronous so no reply should be expected // This is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) { func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) {
if lh.amLighthouse { if lh.amLighthouse {
return return
} }
@ -143,7 +142,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
// Send a query to the lighthouses and hope for the best next time // Send a query to the lighthouses and hope for the best next time
query, err := proto.Marshal(NewLhQueryByInt(ip)) query, err := proto.Marshal(NewLhQueryByInt(ip))
if err != nil { if err != nil {
lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload") lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload")
return return
} }
@ -151,11 +150,11 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for n := range lh.lighthouses { for n := range lh.lighthouses {
f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out) f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
} }
} }
func (lh *LightHouse) QueryCache(ip uint32) *RemoteList { func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
lh.RLock() lh.RLock()
if v, ok := lh.addrMap[ip]; ok { if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock() lh.RUnlock()
@ -172,7 +171,7 @@ func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
// If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, error)) (bool, int, error) { func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) {
lh.RLock() lh.RLock()
// Do we have an entry in the main cache? // Do we have an entry in the main cache?
if v, ok := lh.addrMap[vpnIp]; ok { if v, ok := lh.addrMap[vpnIp]; ok {
@ -195,18 +194,18 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, err
return false, 0, nil return false, 0, nil
} }
func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) { func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
// First we check the static mapping // First we check the static mapping
// and do nothing if it is there // and do nothing if it is there
if _, ok := lh.staticList[vpnIP]; ok { if _, ok := lh.staticList[vpnIp]; ok {
return return
} }
lh.Lock() lh.Lock()
//l.Debugln(lh.addrMap) //l.Debugln(lh.addrMap)
delete(lh.addrMap, vpnIP) delete(lh.addrMap, vpnIp)
if lh.l.Level >= logrus.DebugLevel { if lh.l.Level >= logrus.DebugLevel {
lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP)) lh.l.Debugf("deleting %s from lighthouse.", vpnIp)
} }
lh.Unlock() lh.Unlock()
@ -215,7 +214,7 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner // AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) { func (lh *LightHouse) AddStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr) {
lh.Lock() lh.Lock()
am := lh.unlockedGetRemoteList(vpnIp) am := lh.unlockedGetRemoteList(vpnIp)
am.Lock() am.Lock()
@ -242,23 +241,23 @@ func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
} }
// unlockedGetRemoteList assumes you have the lh lock // unlockedGetRemoteList assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList { func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
am, ok := lh.addrMap[vpnIP] am, ok := lh.addrMap[vpnIp]
if !ok { if !ok {
am = NewRemoteList() am = NewRemoteList()
lh.addrMap[vpnIP] = am lh.addrMap[vpnIp] = am
} }
return am return am
} }
// unlockedShouldAddV4 checks if to is allowed by our allow list // unlockedShouldAddV4 checks if to is allowed by our allow list
func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool { func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
allow := lh.remoteAllowList.AllowIpV4(vpnIp, to.Ip) allow := lh.remoteAllowList.AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
if lh.l.Level >= logrus.TraceLevel { if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow") lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
} }
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, to.Ip) { if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) {
return false return false
} }
@ -266,7 +265,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool {
} }
// unlockedShouldAddV6 checks if to is allowed by our allow list // unlockedShouldAddV6 checks if to is allowed by our allow list
func (lh *LightHouse) unlockedShouldAddV6(vpnIp uint32, to *Ip6AndPort) bool { func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool {
allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo) allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo)
if lh.l.Level >= logrus.TraceLevel { if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
@ -287,25 +286,25 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
return ip return ip
} }
func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool { func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool {
if _, ok := lh.lighthouses[vpnIP]; ok { if _, ok := lh.lighthouses[vpnIp]; ok {
return true return true
} }
return false return false
} }
func NewLhQueryByInt(VpnIp uint32) *NebulaMeta { func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta {
return &NebulaMeta{ return &NebulaMeta{
Type: NebulaMeta_HostQuery, Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: VpnIp, VpnIp: uint32(VpnIp),
}, },
} }
} }
func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort { func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
ipp := Ip4AndPort{Port: port} ipp := Ip4AndPort{Port: port}
ipp.Ip = ip2int(ip) ipp.Ip = uint32(iputil.Ip2VpnIp(ip))
return &ipp return &ipp
} }
@ -317,19 +316,19 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
} }
} }
func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udpAddr { func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
ip := ipp.Ip ip := ipp.Ip
return NewUDPAddr( return udp.NewAddr(
net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)), net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
uint16(ipp.Port), uint16(ipp.Port),
) )
} }
func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udpAddr { func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
return NewUDPAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
} }
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) { func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
if lh.amLighthouse || lh.interval == 0 { if lh.amLighthouse || lh.interval == 0 {
return return
} }
@ -349,12 +348,12 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
} }
} }
func (lh *LightHouse) SendUpdate(f EncWriter) { func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
var v4 []*Ip4AndPort var v4 []*Ip4AndPort
var v6 []*Ip6AndPort var v6 []*Ip6AndPort
for _, e := range *localIps(lh.l, lh.localAllowList) { for _, e := range *localIps(lh.l, lh.localAllowList) {
if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip2int(ip4)) { if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) {
continue continue
} }
@ -368,7 +367,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
m := &NebulaMeta{ m := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification, Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: lh.myVpnIp, VpnIp: uint32(lh.myVpnIp),
Ip4AndPorts: v4, Ip4AndPorts: v4,
Ip6AndPorts: v6, Ip6AndPorts: v6,
}, },
@ -385,7 +384,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
} }
for vpnIp := range lh.lighthouses { for vpnIp := range lh.lighthouses {
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out) f.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out)
} }
} }
@ -415,11 +414,11 @@ func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
} }
func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) { func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Rx(NebulaMessageType(t), 0, i) lh.metrics.Rx(header.MessageType(t), 0, i)
} }
func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) { func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Tx(NebulaMessageType(t), 0, i) lh.metrics.Tx(header.MessageType(t), 0, i)
} }
// This method is similar to Reset(), but it re-uses the pointer structs // This method is similar to Reset(), but it re-uses the pointer structs
@ -436,18 +435,18 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
return lhh.meta return lhh.meta
} }
func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) { func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) {
n := lhh.resetMeta() n := lhh.resetMeta()
err := n.Unmarshal(p) err := n.Unmarshal(p)
if err != nil { if err != nil {
lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). lhh.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr).
Error("Failed to unmarshal lighthouse packet") Error("Failed to unmarshal lighthouse packet")
//TODO: send recv_error? //TODO: send recv_error?
return return
} }
if n.Details == nil { if n.Details == nil {
lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). lhh.l.WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr).
Error("Invalid lighthouse update") Error("Invalid lighthouse update")
//TODO: send recv_error? //TODO: send recv_error?
return return
@ -471,7 +470,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
} }
} }
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr *udpAddr, w EncWriter) { func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) {
// Exit if we don't answer queries // Exit if we don't answer queries
if !lhh.lh.amLighthouse { if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
@ -481,12 +480,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
} }
//TODO: we can DRY this further //TODO: we can DRY this further
reqVpnIP := n.Details.VpnIp reqVpnIp := n.Details.VpnIp
//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(c *cache) (int, error) { found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) {
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply n.Type = NebulaMeta_HostQueryReply
n.Details.VpnIp = reqVpnIP n.Details.VpnIp = reqVpnIp
lhh.coalesceAnswers(c, n) lhh.coalesceAnswers(c, n)
@ -498,18 +497,18 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
} }
if err != nil { if err != nil {
lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply") lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host query reply")
return return
} }
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
// This signals the other side to punch some zero byte udp packets // This signals the other side to punch some zero byte udp packets
found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification n.Type = NebulaMeta_HostPunchNotification
n.Details.VpnIp = vpnIp n.Details.VpnIp = uint32(vpnIp)
lhh.coalesceAnswers(c, n) lhh.coalesceAnswers(c, n)
@ -521,12 +520,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
} }
if err != nil { if err != nil {
lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host was queried for") lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host was queried for")
return return
} }
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0]) w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0])
} }
func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
@ -549,28 +548,29 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
} }
} }
func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) { func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) {
if !lhh.lh.IsLighthouseIP(vpnIp) { if !lhh.lh.IsLighthouseIP(vpnIp) {
return return
} }
lhh.lh.Lock() lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList(n.Details.VpnIp) am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp))
am.Lock() am.Lock()
lhh.lh.Unlock() lhh.lh.Unlock()
am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) certVpnIp := iputil.VpnIp(n.Details.VpnIp)
am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
am.Unlock() am.Unlock()
// Non-blocking attempt to trigger, skip if it would block // Non-blocking attempt to trigger, skip if it would block
select { select {
case lhh.lh.handshakeTrigger <- n.Details.VpnIp: case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp):
default: default:
} }
} }
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp uint32) { func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp) {
if !lhh.lh.amLighthouse { if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
@ -579,9 +579,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
} }
//Simple check that the host sent this not someone else //Simple check that the host sent this not someone else
if n.Details.VpnIp != vpnIp { if n.Details.VpnIp != uint32(vpnIp) {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update") lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
} }
return return
} }
@ -591,18 +591,19 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
am.Lock() am.Lock()
lhh.lh.Unlock() lhh.lh.Unlock()
am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) certVpnIp := iputil.VpnIp(n.Details.VpnIp)
am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
am.Unlock() am.Unlock()
} }
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) { func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) {
if !lhh.lh.IsLighthouseIP(vpnIp) { if !lhh.lh.IsLighthouseIP(vpnIp) {
return return
} }
empty := []byte{0} empty := []byte{0}
punch := func(vpnPeer *udpAddr) { punch := func(vpnPeer *udp.Addr) {
if vpnPeer == nil { if vpnPeer == nil {
return return
} }
@ -615,7 +616,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, IntIp(n.Details.VpnIp)) lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp))
} }
} }
@ -634,18 +635,18 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
go func() { go func() {
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp)) lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", iputil.VpnIp(n.Details.VpnIp))
} }
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel. // managed by a channel.
w.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) w.SendMessageToVpnIp(header.Test, header.TestRequest, iputil.VpnIp(n.Details.VpnIp), []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}() }()
} }
} }
// ipMaskContains checks if testIp is contained by ip after applying a cidr // ipMaskContains checks if testIp is contained by ip after applying a cidr
// zeros is 32 - bits from net.IPMask.Size() // zeros is 32 - bits from net.IPMask.Size()
func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool { func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool {
return (testIp^ip)>>zeros == 0 return (testIp^ip)>>zeros == 0
} }

View File

@ -6,6 +6,10 @@ import (
"testing" "testing"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -17,12 +21,12 @@ func TestOldIPv4Only(t *testing.T) {
var m Ip4AndPort var m Ip4AndPort
err := proto.Unmarshal(b, &m) err := proto.Unmarshal(b, &m)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "10.1.1.1", int2ip(m.GetIp()).String()) assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String())
} }
func TestNewLhQuery(t *testing.T) { func TestNewLhQuery(t *testing.T) {
myIp := net.ParseIP("192.1.1.1") myIp := net.ParseIP("192.1.1.1")
myIpint := ip2int(myIp) myIpint := iputil.Ip2VpnIp(myIp)
// Generating a new lh query should work // Generating a new lh query should work
a := NewLhQueryByInt(myIpint) a := NewLhQueryByInt(myIpint)
@ -42,37 +46,37 @@ func TestNewLhQuery(t *testing.T) {
} }
func Test_lhStaticMapping(t *testing.T) { func Test_lhStaticMapping(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1) lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener(l, "0.0.0.0", 0, true) udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242))) meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242)))
err := meh.ValidateLHStaticEntries() err := meh.ValidateLHStaticEntries()
assert.Nil(t, err) assert.Nil(t, err)
lh2 := "10.128.0.3" lh2 := "10.128.0.3"
lh2IP := net.ParseIP(lh2) lh2IP := net.ParseIP(lh2)
meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false) meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP), iputil.Ip2VpnIp(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242))) meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242)))
err = meh.ValidateLHStaticEntries() err = meh.ValidateLHStaticEntries()
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry") assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
} }
func BenchmarkLighthouseHandleRequest(b *testing.B) { func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := NewTestLogger() l := util.NewTestLogger()
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1) lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener(l, "0.0.0.0", 0, true) udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false)
hAddr := NewUDPAddrFromString("4.5.6.7:12345") hAddr := udp.NewAddrFromString("4.5.6.7:12345")
hAddr2 := NewUDPAddrFromString("4.5.6.7:12346") hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
lh.addrMap[3] = NewRemoteList() lh.addrMap[3] = NewRemoteList()
lh.addrMap[3].unlockedSetV4( lh.addrMap[3].unlockedSetV4(
3, 3,
@ -81,11 +85,11 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)), NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
}, },
func(uint32, *Ip4AndPort) bool { return true }, func(iputil.VpnIp, *Ip4AndPort) bool { return true },
) )
rAddr := NewUDPAddrFromString("1.2.2.3:12345") rAddr := udp.NewAddrFromString("1.2.2.3:12345")
rAddr2 := NewUDPAddrFromString("1.2.2.3:12346") rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
lh.addrMap[2] = NewRemoteList() lh.addrMap[2] = NewRemoteList()
lh.addrMap[2].unlockedSetV4( lh.addrMap[2].unlockedSetV4(
3, 3,
@ -94,7 +98,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)), NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
}, },
func(uint32, *Ip4AndPort) bool { return true }, func(iputil.VpnIp, *Ip4AndPort) bool { return true },
) )
mw := &mockEncWriter{} mw := &mockEncWriter{}
@ -133,50 +137,50 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
} }
func TestLighthouse_Memory(t *testing.T) { func TestLighthouse_Memory(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
myUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.2"), Port: 4242} myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
myUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4242} myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
myUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.2"), Port: 4242} myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
myUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.2"), Port: 4242} myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
myUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.2"), Port: 4242} myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
myUdpAddr5 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4243} myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
myUdpAddr6 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4244} myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
myUdpAddr7 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4245} myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
myUdpAddr8 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4246} myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
myUdpAddr9 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4247} myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
myUdpAddr10 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4248} myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
myUdpAddr11 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4249} myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
myVpnIp := ip2int(net.ParseIP("10.128.0.2")) myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2"))
theirUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.3"), Port: 4242} theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
theirUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.3"), Port: 4242} theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
theirUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.3"), Port: 4242} theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
theirUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.3"), Port: 4242} theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
theirUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.3"), Port: 4242} theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
theirVpnIp := ip2int(net.ParseIP("10.128.0.3")) theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3"))
udpServer, _ := NewListener(l, "0.0.0.0", 0, true) udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{}, 10, 10003, udpServer, false, 1, false) lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []iputil.VpnIp{}, 10, 10003, udpServer, false, 1, false)
lhh := lh.NewRequestHandler() lhh := lh.NewRequestHandler()
// Test that my first update responds with just that // Test that my first update responds with just that
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr2}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh)
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
// Ensure we don't accumulate addresses // Ensure we don't accumulate addresses
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr3}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
// Grow it back to 2 // Grow it back to 2
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr4}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
// Update a different host // Update a different host
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udpAddr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
@ -189,7 +193,7 @@ func TestLighthouse_Memory(t *testing.T) {
newLHHostUpdate( newLHHostUpdate(
myUdpAddr0, myUdpAddr0,
myVpnIp, myVpnIp,
[]*udpAddr{ []*udp.Addr{
myUdpAddr1, myUdpAddr1,
myUdpAddr2, myUdpAddr2,
myUdpAddr3, myUdpAddr3,
@ -212,19 +216,19 @@ func TestLighthouse_Memory(t *testing.T) {
) )
// Make sure we won't add ips in our vpn network // Make sure we won't add ips in our vpn network
bad1 := &udpAddr{IP: net.ParseIP("10.128.0.99"), Port: 4242} bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
bad2 := &udpAddr{IP: net.ParseIP("10.128.0.100"), Port: 4242} bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
good := &udpAddr{IP: net.ParseIP("1.128.0.99"), Port: 4242} good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{bad1, bad2, good}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
} }
func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightHouseHandler) testLhReply { func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostQuery, Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: queryVpnIp, VpnIp: uint32(queryVpnIp),
}, },
} }
@ -238,17 +242,17 @@ func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightH
return w.lastReply return w.lastReply
} }
func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *LightHouseHandler) { func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) {
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification, Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: vpnIp, VpnIp: uint32(vpnIp),
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
}, },
} }
for k, v := range addrs { for k, v := range addrs {
req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: ip2int(v.IP), Port: uint32(v.Port)} req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)}
} }
b, err := req.Marshal() b, err := req.Marshal()
@ -327,15 +331,15 @@ func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *Lig
//} //}
func Test_ipMaskContains(t *testing.T) { func Test_ipMaskContains(t *testing.T) {
assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255")))) assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255"))))
assert.False(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.1.1")))) assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32, ip2int(net.ParseIP("10.0.1.1")))) assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
} }
type testLhReply struct { type testLhReply struct {
nebType NebulaMessageType nebType header.MessageType
nebSubType NebulaMessageSubType nebSubType header.MessageSubType
vpnIp uint32 vpnIp iputil.VpnIp
msg *NebulaMeta msg *NebulaMeta
} }
@ -343,7 +347,7 @@ type testEncWriter struct {
lastReply testLhReply lastReply testLhReply
} }
func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, _, _ []byte) { func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
tw.lastReply = testLhReply{ tw.lastReply = testLhReply{
nebType: t, nebType: t,
nebSubType: st, nebSubType: st,
@ -358,17 +362,17 @@ func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessag
} }
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udpAddr) { func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
assert.Len(t, have, len(want)) assert.Len(t, have, len(want))
for k, w := range want { for k, w := range want {
if !(have[k].Ip == ip2int(w.IP) && have[k].Port == uint32(w.Port)) { if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) {
assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have))) assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
} }
} }
} }
// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match // assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) { func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
assert.Len(t, have, len(want)) assert.Len(t, have, len(want))
for k, w := range want { for k, w := range want {
if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) { if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
@ -377,8 +381,8 @@ func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
} }
} }
func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr { func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr {
addrs := make([]*udpAddr, len(ips)) addrs := make([]*udp.Addr, len(ips))
for k, v := range ips { for k, v := range ips {
addrs[k] = NewUDPAddrFromLH4(v) addrs[k] = NewUDPAddrFromLH4(v)
} }

View File

@ -2,8 +2,12 @@ package nebula
import ( import (
"errors" "errors"
"fmt"
"strings"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
) )
type ContextualError struct { type ContextualError struct {
@ -37,3 +41,38 @@ func (ce *ContextualError) Log(lr *logrus.Logger) {
lr.WithFields(ce.Fields).Error(ce.Context) lr.WithFields(ce.Fields).Error(ce.Context)
} }
} }
func configLogger(l *logrus.Logger, c *config.C) error {
// set up our logging level
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
if err != nil {
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
}
l.SetLevel(logLevel)
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
timestampFormat := c.GetString("logging.timestamp_format", "")
fullTimestamp := (timestampFormat != "")
if timestampFormat == "" {
timestampFormat = time.RFC3339
}
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
switch logFormat {
case "text":
l.Formatter = &logrus.TextFormatter{
TimestampFormat: timestampFormat,
FullTimestamp: fullTimestamp,
DisableTimestamp: disableTimestamp,
}
case "json":
l.Formatter = &logrus.JSONFormatter{
TimestampFormat: timestampFormat,
DisableTimestamp: disableTimestamp,
}
default:
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
}
return nil
}

139
main.go
View File

@ -8,14 +8,16 @@ import (
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
type m map[string]interface{} type m map[string]interface{}
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) { func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit. // Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
defer func() { defer func() {
@ -31,7 +33,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// Print the config if in test, the exit comes later // Print the config if in test, the exit comes later
if configTest { if configTest {
b, err := yaml.Marshal(config.Settings) b, err := yaml.Marshal(c.Settings)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -40,33 +42,33 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
l.Println(string(b)) l.Println(string(b))
} }
err := configLogger(config) err := configLogger(l, c)
if err != nil { if err != nil {
return nil, NewContextualError("Failed to configure the logger", nil, err) return nil, NewContextualError("Failed to configure the logger", nil, err)
} }
config.RegisterReloadCallback(func(c *Config) { c.RegisterReloadCallback(func(c *config.C) {
err := configLogger(c) err := configLogger(l, c)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to configure the logger") l.WithError(err).Error("Failed to configure the logger")
} }
}) })
caPool, err := loadCAFromConfig(l, config) caPool, err := loadCAFromConfig(l, c)
if err != nil { if err != nil {
//The errors coming out of loadCA are already nicely formatted //The errors coming out of loadCA are already nicely formatted
return nil, NewContextualError("Failed to load ca from config", nil, err) return nil, NewContextualError("Failed to load ca from config", nil, err)
} }
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(config) cs, err := NewCertStateFromConfig(c)
if err != nil { if err != nil {
//The errors coming out of NewCertStateFromConfig are already nicely formatted //The errors coming out of NewCertStateFromConfig are already nicely formatted
return nil, NewContextualError("Failed to load certificate from config", nil, err) return nil, NewContextualError("Failed to load certificate from config", nil, err)
} }
l.WithField("cert", cs.certificate).Debug("Client nebula certificate") l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(l, cs.certificate, config) fw, err := NewFirewallFromConfig(l, cs.certificate, c)
if err != nil { if err != nil {
return nil, NewContextualError("Error while loading firewall rules", nil, err) return nil, NewContextualError("Error while loading firewall rules", nil, err)
} }
@ -74,20 +76,20 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// TODO: make sure mask is 4 bytes // TODO: make sure mask is 4 bytes
tunCidr := cs.certificate.Details.Ips[0] tunCidr := cs.certificate.Details.Ips[0]
routes, err := parseRoutes(config, tunCidr) routes, err := parseRoutes(c, tunCidr)
if err != nil { if err != nil {
return nil, NewContextualError("Could not parse tun.routes", nil, err) return nil, NewContextualError("Could not parse tun.routes", nil, err)
} }
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr) unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
if err != nil { if err != nil {
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err) return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
} }
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
wireSSHReload(l, ssh, config) wireSSHReload(l, ssh, c)
var sshStart func() var sshStart func()
if config.GetBool("sshd.enabled", false) { if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, config) sshStart, err = configSSH(l, ssh, c)
if err != nil { if err != nil {
return nil, NewContextualError("Error while configuring the sshd", nil, err) return nil, NewContextualError("Error while configuring the sshd", nil, err)
} }
@ -101,7 +103,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
var routines int var routines int
// If `routines` is set, use that and ignore the specific values // If `routines` is set, use that and ignore the specific values
if routines = config.GetInt("routines", 0); routines != 0 { if routines = c.GetInt("routines", 0); routines != 0 {
if routines < 1 { if routines < 1 {
routines = 1 routines = 1
} }
@ -110,8 +112,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
} }
} else { } else {
// deprecated and undocumented // deprecated and undocumented
tunQueues := config.GetInt("tun.routines", 1) tunQueues := c.GetInt("tun.routines", 1)
udpQueues := config.GetInt("listen.routines", 1) udpQueues := c.GetInt("listen.routines", 1)
if tunQueues > udpQueues { if tunQueues > udpQueues {
routines = tunQueues routines = tunQueues
} else { } else {
@ -125,8 +127,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// EXPERIMENTAL // EXPERIMENTAL
// Intentionally not documented yet while we do more testing and determine // Intentionally not documented yet while we do more testing and determine
// a good default value. // a good default value.
conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0) conntrackCacheTimeout := c.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") { if routines > 1 && !c.IsSet("firewall.conntrack.routine_cache_timeout") {
// Use a different default if we are running with multiple routines // Use a different default if we are running with multiple routines
conntrackCacheTimeout = 1 * time.Second conntrackCacheTimeout = 1 * time.Second
} }
@ -136,30 +138,30 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
var tun Inside var tun Inside
if !configTest { if !configTest {
config.CatchHUP(ctx) c.CatchHUP(ctx)
switch { switch {
case config.GetBool("tun.disabled", false): case c.GetBool("tun.disabled", false):
tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l) tun = newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
case tunFd != nil: case tunFd != nil:
tun, err = newTunFromFd( tun, err = newTunFromFd(
l, l,
*tunFd, *tunFd,
tunCidr, tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU), c.GetInt("tun.mtu", DEFAULT_MTU),
routes, routes,
unsafeRoutes, unsafeRoutes,
config.GetInt("tun.tx_queue", 500), c.GetInt("tun.tx_queue", 500),
) )
default: default:
tun, err = newTun( tun, err = newTun(
l, l,
config.GetString("tun.dev", ""), c.GetString("tun.dev", ""),
tunCidr, tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU), c.GetInt("tun.mtu", DEFAULT_MTU),
routes, routes,
unsafeRoutes, unsafeRoutes,
config.GetInt("tun.tx_queue", 500), c.GetInt("tun.tx_queue", 500),
routines > 1, routines > 1,
) )
} }
@ -176,16 +178,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}() }()
// set up our UDP listener // set up our UDP listener
udpConns := make([]*udpConn, routines) udpConns := make([]*udp.Conn, routines)
port := config.GetInt("listen.port", 0) port := c.GetInt("listen.port", 0)
if !configTest { if !configTest {
for i := 0; i < routines; i++ { for i := 0; i < routines; i++ {
udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1) udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil { if err != nil {
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err) return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
} }
udpServer.reloadConfig(config) udpServer.ReloadConfig(c)
udpConns[i] = udpServer udpConns[i] = udpServer
// If port is dynamic, discover it // If port is dynamic, discover it
@ -201,7 +203,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// Set up my internal host map // Set up my internal host map
var preferredRanges []*net.IPNet var preferredRanges []*net.IPNet
rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{}) rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
// First, check if 'preferred_ranges' is set and fallback to 'local_range' // First, check if 'preferred_ranges' is set and fallback to 'local_range'
if len(rawPreferredRanges) > 0 { if len(rawPreferredRanges) > 0 {
for _, rawPreferredRange := range rawPreferredRanges { for _, rawPreferredRange := range rawPreferredRanges {
@ -216,7 +218,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// local_range was superseded by preferred_ranges. If it is still present, // local_range was superseded by preferred_ranges. If it is still present,
// merge the local_range setting into preferred_ranges. We will probably // merge the local_range setting into preferred_ranges. We will probably
// deprecate local_range and remove in the future. // deprecate local_range and remove in the future.
rawLocalRange := config.GetString("local_range", "") rawLocalRange := c.GetString("local_range", "")
if rawLocalRange != "" { if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange) _, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil { if err != nil {
@ -240,7 +242,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges) hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
hostMap.addUnsafeRoutes(&unsafeRoutes) hostMap.addUnsafeRoutes(&unsafeRoutes)
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false) hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created") l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
@ -249,26 +251,26 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
go hostMap.Promoter(config.GetInt("promoter.interval")) go hostMap.Promoter(config.GetInt("promoter.interval"))
*/ */
punchy := NewPunchyFromConfig(config) punchy := NewPunchyFromConfig(c)
if punchy.Punch && !configTest { if punchy.Punch && !configTest {
l.Info("UDP hole punching enabled") l.Info("UDP hole punching enabled")
go hostMap.Punchy(ctx, udpConns[0]) go hostMap.Punchy(ctx, udpConns[0])
} }
amLighthouse := config.GetBool("lighthouse.am_lighthouse", false) amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
// fatal if am_lighthouse is enabled but we are using an ephemeral port // fatal if am_lighthouse is enabled but we are using an ephemeral port
if amLighthouse && (config.GetInt("listen.port", 0) == 0) { if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil) return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
} }
// warn if am_lighthouse is enabled but upstream lighthouses exists // warn if am_lighthouse is enabled but upstream lighthouses exists
rawLighthouseHosts := config.GetStringSlice("lighthouse.hosts", []string{}) rawLighthouseHosts := c.GetStringSlice("lighthouse.hosts", []string{})
if amLighthouse && len(rawLighthouseHosts) != 0 { if amLighthouse && len(rawLighthouseHosts) != 0 {
l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
} }
lighthouseHosts := make([]uint32, len(rawLighthouseHosts)) lighthouseHosts := make([]iputil.VpnIp, len(rawLighthouseHosts))
for i, host := range rawLighthouseHosts { for i, host := range rawLighthouseHosts {
ip := net.ParseIP(host) ip := net.ParseIP(host)
if ip == nil { if ip == nil {
@ -277,7 +279,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
if !tunCidr.Contains(ip) { if !tunCidr.Contains(ip) {
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
} }
lighthouseHosts[i] = ip2int(ip) lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
} }
lightHouse := NewLightHouse( lightHouse := NewLightHouse(
@ -286,47 +288,48 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
tunCidr, tunCidr,
lighthouseHosts, lighthouseHosts,
//TODO: change to a duration //TODO: change to a duration
config.GetInt("lighthouse.interval", 10), c.GetInt("lighthouse.interval", 10),
uint32(port), uint32(port),
udpConns[0], udpConns[0],
punchy.Respond, punchy.Respond,
punchy.Delay, punchy.Delay,
config.GetBool("stats.lighthouse_metrics", false), c.GetBool("stats.lighthouse_metrics", false),
) )
remoteAllowList, err := config.GetRemoteAllowList("lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
if err != nil { if err != nil {
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
} }
lightHouse.SetRemoteAllowList(remoteAllowList) lightHouse.SetRemoteAllowList(remoteAllowList)
localAllowList, err := config.GetLocalAllowList("lighthouse.local_allow_list") localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
if err != nil { if err != nil {
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err) return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
} }
lightHouse.SetLocalAllowList(localAllowList) lightHouse.SetLocalAllowList(localAllowList)
//TODO: Move all of this inside functions in lighthouse.go //TODO: Move all of this inside functions in lighthouse.go
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) { for k, v := range c.GetMap("static_host_map", map[interface{}]interface{}{}) {
vpnIp := net.ParseIP(fmt.Sprintf("%v", k)) ip := net.ParseIP(fmt.Sprintf("%v", k))
if !tunCidr.Contains(vpnIp) { vpnIp := iputil.Ip2VpnIp(ip)
if !tunCidr.Contains(ip) {
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
} }
vals, ok := v.([]interface{}) vals, ok := v.([]interface{})
if ok { if ok {
for _, v := range vals { for _, v := range vals {
ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v)) ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
if err != nil { if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
} }
lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port)) lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
} }
} else { } else {
ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v)) ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
if err != nil { if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
} }
lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port)) lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
} }
} }
@ -336,16 +339,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
} }
var messageMetrics *MessageMetrics var messageMetrics *MessageMetrics
if config.GetBool("stats.message_metrics", false) { if c.GetBool("stats.message_metrics", false) {
messageMetrics = newMessageMetrics() messageMetrics = newMessageMetrics()
} else { } else {
messageMetrics = newMessageMetricsOnlyRecvError() messageMetrics = newMessageMetricsOnlyRecvError()
} }
handshakeConfig := HandshakeConfig{ handshakeConfig := HandshakeConfig{
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries), retries: c.GetInt("handshakes.retries", DefaultHandshakeRetries),
triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
messageMetrics: messageMetrics, messageMetrics: messageMetrics,
} }
@ -358,36 +361,35 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{}) //handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
serveDns := false serveDns := false
if config.GetBool("lighthouse.serve_dns", false) { if c.GetBool("lighthouse.serve_dns", false) {
if config.GetBool("lighthouse.am_lighthouse", false) { if c.GetBool("lighthouse.am_lighthouse", false) {
serveDns = true serveDns = true
} else { } else {
l.Warn("DNS server refusing to run because this host is not a lighthouse.") l.Warn("DNS server refusing to run because this host is not a lighthouse.")
} }
} }
checkInterval := config.GetInt("timers.connection_alive_interval", 5) checkInterval := c.GetInt("timers.connection_alive_interval", 5)
pendingDeletionInterval := config.GetInt("timers.pending_deletion_interval", 10) pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
ifConfig := &InterfaceConfig{ ifConfig := &InterfaceConfig{
HostMap: hostMap, HostMap: hostMap,
Inside: tun, Inside: tun,
Outside: udpConns[0], Outside: udpConns[0],
certState: cs, certState: cs,
Cipher: config.GetString("cipher", "aes"), Cipher: c.GetString("cipher", "aes"),
Firewall: fw, Firewall: fw,
ServeDns: serveDns, ServeDns: serveDns,
HandshakeManager: handshakeManager, HandshakeManager: handshakeManager,
lightHouse: lightHouse, lightHouse: lightHouse,
checkInterval: checkInterval, checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval, pendingDeletionInterval: pendingDeletionInterval,
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false), DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
DropMulticast: config.GetBool("tun.drop_multicast", false), DropMulticast: c.GetBool("tun.drop_multicast", false),
UDPBatchSize: config.GetInt("listen.batch", 64),
routines: routines, routines: routines,
MessageMetrics: messageMetrics, MessageMetrics: messageMetrics,
version: buildVersion, version: buildVersion,
caPool: caPool, caPool: caPool,
disconnectInvalid: config.GetBool("pki.disconnect_invalid", false), disconnectInvalid: c.GetBool("pki.disconnect_invalid", false),
ConntrackCacheTimeout: conntrackCacheTimeout, ConntrackCacheTimeout: conntrackCacheTimeout,
l: l, l: l,
@ -413,7 +415,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// I don't want to make this initial commit too far-reaching though // I don't want to make this initial commit too far-reaching though
ifce.writers = udpConns ifce.writers = udpConns
ifce.RegisterConfigChangeCallbacks(config) ifce.RegisterConfigChangeCallbacks(c)
go handshakeManager.Run(ctx, ifce) go handshakeManager.Run(ctx, ifce)
go lightHouse.LhUpdateWorker(ctx, ifce) go lightHouse.LhUpdateWorker(ctx, ifce)
@ -421,7 +423,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept // TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
// a context so that they can exit when the context is Done. // a context so that they can exit when the context is Done.
statsStart, err := startStats(l, config, buildVersion, configTest) statsStart, err := startStats(l, c, buildVersion, configTest)
if err != nil { if err != nil {
return nil, NewContextualError("Failed to start stats emitter", nil, err) return nil, NewContextualError("Failed to start stats emitter", nil, err)
} }
@ -431,7 +434,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
} }
//TODO: check if we _should_ be emitting stats //TODO: check if we _should_ be emitting stats
go ifce.emitStats(ctx, config.GetDuration("stats.interval", time.Second*10)) go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
@ -439,7 +442,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
var dnsStart func() var dnsStart func()
if amLighthouse && serveDns { if amLighthouse && serveDns {
l.Debugln("Starting dns server") l.Debugln("Starting dns server")
dnsStart = dnsMain(l, hostMap, config) dnsStart = dnsMain(l, hostMap, c)
} }
return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil

View File

@ -4,8 +4,11 @@ import (
"fmt" "fmt"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/header"
) )
//TODO: this can probably move into the header package
type MessageMetrics struct { type MessageMetrics struct {
rx [][]metrics.Counter rx [][]metrics.Counter
tx [][]metrics.Counter tx [][]metrics.Counter
@ -14,7 +17,7 @@ type MessageMetrics struct {
txUnknown metrics.Counter txUnknown metrics.Counter
} }
func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64) { func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
if m != nil { if m != nil {
if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) { if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
m.rx[t][s].Inc(i) m.rx[t][s].Inc(i)
@ -23,7 +26,7 @@ func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64
} }
} }
} }
func (m *MessageMetrics) Tx(t NebulaMessageType, s NebulaMessageSubType, i int64) { func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) {
if m != nil { if m != nil {
if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) { if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
m.tx[t][s].Inc(i) m.tx[t][s].Inc(i)

View File

@ -10,6 +10,10 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@ -17,8 +21,8 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int, localCache ConntrackCache) { func (f *Interface) readOutsidePackets(addr *udp.Addr, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
err := header.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// TODO: best if we return this and let caller log // TODO: best if we return this and let caller log
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
@ -32,30 +36,30 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
//l.Error("in packet ", header, packet[HeaderLen:]) //l.Error("in packet ", header, packet[HeaderLen:])
// verify if we've seen this index before, otherwise respond to the handshake initiation // verify if we've seen this index before, otherwise respond to the handshake initiation
hostinfo, err := f.hostMap.QueryIndex(header.RemoteIndex) hostinfo, err := f.hostMap.QueryIndex(h.RemoteIndex)
var ci *ConnectionState var ci *ConnectionState
if err == nil { if err == nil {
ci = hostinfo.ConnectionState ci = hostinfo.ConnectionState
} }
switch header.Type { switch h.Type {
case message: case header.Message:
if !f.handleEncrypted(ci, addr, header) { if !f.handleEncrypted(ci, addr, h) {
return return
} }
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q, localCache) f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache)
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
case lightHouse: case header.LightHouse:
f.messageMetrics.Rx(header.Type, header.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) { if !f.handleEncrypted(ci, addr, h) {
return return
} }
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
WithField("packet", packet). WithField("packet", packet).
@ -66,17 +70,17 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
return return
} }
lhh.HandleRequest(addr, hostinfo.hostId, d, f) lhf(addr, hostinfo.vpnIp, d, f)
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
case test: case header.Test:
f.messageMetrics.Rx(header.Type, header.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) { if !f.handleEncrypted(ci, addr, h) {
return return
} }
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
WithField("packet", packet). WithField("packet", packet).
@ -87,11 +91,11 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
return return
} }
if header.Subtype == testRequest { if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam // This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding // to the new IP address before responding
f.handleHostRoaming(hostinfo, addr) f.handleHostRoaming(hostinfo, addr)
f.send(test, testReply, ci, hostinfo, hostinfo.remote, d, nb, out) f.send(header.Test, header.TestReply, ci, hostinfo, hostinfo.remote, d, nb, out)
} }
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
@ -99,19 +103,19 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated // are unauthenticated
case handshake: case header.Handshake:
f.messageMetrics.Rx(header.Type, header.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
HandleIncomingHandshake(f, addr, packet, header, hostinfo) HandleIncomingHandshake(f, addr, packet, h, hostinfo)
return return
case recvError: case header.RecvError:
f.messageMetrics.Rx(header.Type, header.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(addr, header) f.handleRecvError(addr, h)
return return
case closeTunnel: case header.CloseTunnel:
f.messageMetrics.Rx(header.Type, header.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) { if !f.handleEncrypted(ci, addr, h) {
return return
} }
@ -122,22 +126,22 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
return return
default: default:
f.messageMetrics.Rx(header.Type, header.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr) hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
return return
} }
f.handleHostRoaming(hostinfo, addr) f.handleHostRoaming(hostinfo, addr)
f.connectionManager.In(hostinfo.hostId) f.connectionManager.In(hostinfo.vpnIp)
} }
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) { func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately //TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
f.connectionManager.ClearIP(hostInfo.hostId) f.connectionManager.ClearIP(hostInfo.vpnIp)
f.connectionManager.ClearPendingDeletion(hostInfo.hostId) f.connectionManager.ClearPendingDeletion(hostInfo.vpnIp)
f.lightHouse.DeleteVpnIP(hostInfo.hostId) f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
if hasHostMapLock { if hasHostMapLock {
f.hostMap.unlockedDeleteHostInfo(hostInfo) f.hostMap.unlockedDeleteHostInfo(hostInfo)
@ -148,12 +152,12 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
func (f *Interface) sendCloseTunnel(h *HostInfo) { func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) f.send(header.CloseTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
} }
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) {
if hostDidRoam(hostinfo.remote, addr) { if !hostinfo.remote.Equals(addr) {
if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) { if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) {
hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
return return
} }
@ -175,11 +179,11 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
} }
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool { func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool {
// If connectionstate exists and the replay protector allows, process packet // If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection. // Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(f.l, header.MessageCounter) { if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
f.sendRecvError(addr, header.RemoteIndex) f.sendRecvError(addr, h.RemoteIndex)
return false return false
} }
@ -187,7 +191,7 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *
} }
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
func newPacket(data []byte, incoming bool, fp *FirewallPacket) error { func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data? // Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen { if len(data) < ipv4.HeaderLen {
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen) return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
@ -215,7 +219,7 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
// Accounting for a variable header length, do we have enough data for our src/dst tuples? // Accounting for a variable header length, do we have enough data for our src/dst tuples?
minLen := ihl minLen := ihl
if !fp.Fragment && fp.Protocol != fwProtoICMP { if !fp.Fragment && fp.Protocol != firewall.ProtoICMP {
minLen += minFwPacketLen minLen += minFwPacketLen
} }
if len(data) < minLen { if len(data) < minLen {
@ -224,9 +228,9 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
// Firewall packets are locally oriented // Firewall packets are locally oriented
if incoming { if incoming {
fp.RemoteIP = binary.BigEndian.Uint32(data[12:16]) fp.RemoteIP = iputil.Ip2VpnIp(data[12:16])
fp.LocalIP = binary.BigEndian.Uint32(data[16:20]) fp.LocalIP = iputil.Ip2VpnIp(data[16:20])
if fp.Fragment || fp.Protocol == fwProtoICMP { if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0 fp.RemotePort = 0
fp.LocalPort = 0 fp.LocalPort = 0
} else { } else {
@ -234,9 +238,9 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
} }
} else { } else {
fp.LocalIP = binary.BigEndian.Uint32(data[12:16]) fp.LocalIP = iputil.Ip2VpnIp(data[12:16])
fp.RemoteIP = binary.BigEndian.Uint32(data[16:20]) fp.RemoteIP = iputil.Ip2VpnIp(data[16:20])
if fp.Fragment || fp.Protocol == fwProtoICMP { if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0 fp.RemotePort = 0
fp.LocalPort = 0 fp.LocalPort = 0
} else { } else {
@ -248,15 +252,15 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
return nil return nil
} }
func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, header *Header, nb []byte) ([]byte, error) { func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, h *header.H, nb []byte) ([]byte, error) {
var err error var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], mc, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], mc, nb)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !hostinfo.ConnectionState.window.Update(f.l, mc) { if !hostinfo.ConnectionState.window.Update(f.l, mc) {
hostinfo.logger(f.l).WithField("header", header). hostinfo.logger(f.l).WithField("header", h).
Debugln("dropping out of window packet") Debugln("dropping out of window packet")
return nil, errors.New("out of window packet") return nil, errors.New("out of window packet")
} }
@ -264,10 +268,10 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil return out, nil
} }
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int, localCache ConntrackCache) { func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
var err error var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB //TODO: maybe after build 64 is out? 06/14/2018 - NB
@ -298,18 +302,18 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return return
} }
f.connectionManager.In(hostinfo.hostId) f.connectionManager.In(hostinfo.vpnIp)
_, err = f.readers[q].Write(out) _, err = f.readers[q].Write(out)
if err != nil { if err != nil {
f.l.WithError(err).Error("Failed to write to tun") f.l.WithError(err).Error("Failed to write to tun")
} }
} }
func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) { func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
f.messageMetrics.Tx(recvError, 0, 1) f.messageMetrics.Tx(header.RecvError, 0, 1)
//TODO: this should be a signed message so we can trust that we should drop the index //TODO: this should be a signed message so we can trust that we should drop the index
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0) b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
f.outside.WriteTo(b, endpoint) f.outside.WriteTo(b, endpoint)
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", index). f.l.WithField("index", index).
@ -318,7 +322,7 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
} }
} }
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", h.RemoteIndex). f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr). WithField("udpAddr", addr).

View File

@ -4,12 +4,14 @@ import (
"net" "net"
"testing" "testing"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
func Test_newPacket(t *testing.T) { func Test_newPacket(t *testing.T) {
p := &FirewallPacket{} p := &firewall.Packet{}
// length fail // length fail
err := newPacket([]byte{0, 1}, true, p) err := newPacket([]byte{0, 1}, true, p)
@ -44,7 +46,7 @@ func Test_newPacket(t *testing.T) {
Src: net.IPv4(10, 0, 0, 1), Src: net.IPv4(10, 0, 0, 1),
Dst: net.IPv4(10, 0, 0, 2), Dst: net.IPv4(10, 0, 0, 2),
Options: []byte{0, 1, 0, 2}, Options: []byte{0, 1, 0, 2},
Protocol: fwProtoTCP, Protocol: firewall.ProtoTCP,
} }
b, _ = h.Marshal() b, _ = h.Marshal()
@ -52,9 +54,9 @@ func Test_newPacket(t *testing.T) {
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(fwProtoTCP)) assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 2))) assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 1))) assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
assert.Equal(t, p.RemotePort, uint16(3)) assert.Equal(t, p.RemotePort, uint16(3))
assert.Equal(t, p.LocalPort, uint16(4)) assert.Equal(t, p.LocalPort, uint16(4))
@ -74,8 +76,8 @@ func Test_newPacket(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(2)) assert.Equal(t, p.Protocol, uint8(2))
assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 1))) assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 2))) assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
assert.Equal(t, p.RemotePort, uint16(6)) assert.Equal(t, p.RemotePort, uint16(6))
assert.Equal(t, p.LocalPort, uint16(5)) assert.Equal(t, p.LocalPort, uint16(5))
} }

View File

@ -1,6 +1,10 @@
package nebula package nebula
import "time" import (
"time"
"github.com/slackhq/nebula/config"
)
type Punchy struct { type Punchy struct {
Punch bool Punch bool
@ -8,7 +12,7 @@ type Punchy struct {
Delay time.Duration Delay time.Duration
} }
func NewPunchyFromConfig(c *Config) *Punchy { func NewPunchyFromConfig(c *config.C) *Punchy {
p := &Punchy{} p := &Punchy{}
if c.IsSet("punchy.punch") { if c.IsSet("punchy.punch") {

View File

@ -4,12 +4,14 @@ import (
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestNewPunchyFromConfig(t *testing.T) { func TestNewPunchyFromConfig(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
c := NewConfig(l) c := config.NewC(l)
// Test defaults // Test defaults
p := NewPunchyFromConfig(c) p := NewPunchyFromConfig(c)

View File

@ -5,14 +5,17 @@ import (
"net" "net"
"sort" "sort"
"sync" "sync"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
// forEachFunc is used to benefit folks that want to do work inside the lock // forEachFunc is used to benefit folks that want to do work inside the lock
type forEachFunc func(addr *udpAddr, preferred bool) type forEachFunc func(addr *udp.Addr, preferred bool)
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
type checkFuncV4 func(vpnIp uint32, to *Ip4AndPort) bool type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool
type checkFuncV6 func(vpnIp uint32, to *Ip6AndPort) bool type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool
// CacheMap is a struct that better represents the lighthouse cache for humans // CacheMap is a struct that better represents the lighthouse cache for humans
// The string key is the owners vpnIp // The string key is the owners vpnIp
@ -21,8 +24,8 @@ type CacheMap map[string]*Cache
// Cache is the other part of CacheMap to better represent the lighthouse cache for humans // Cache is the other part of CacheMap to better represent the lighthouse cache for humans
// We don't reason about ipv4 vs ipv6 here // We don't reason about ipv4 vs ipv6 here
type Cache struct { type Cache struct {
Learned []*udpAddr `json:"learned,omitempty"` Learned []*udp.Addr `json:"learned,omitempty"`
Reported []*udpAddr `json:"reported,omitempty"` Reported []*udp.Addr `json:"reported,omitempty"`
} }
//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
@ -53,16 +56,16 @@ type RemoteList struct {
sync.RWMutex sync.RWMutex
// A deduplicated set of addresses. Any accessor should lock beforehand. // A deduplicated set of addresses. Any accessor should lock beforehand.
addrs []*udpAddr addrs []*udp.Addr
// These are maps to store v4 and v6 addresses per lighthouse // These are maps to store v4 and v6 addresses per lighthouse
// Map key is the vpnIp of the person that told us about this the cached entries underneath. // Map key is the vpnIp of the person that told us about this the cached entries underneath.
// For learned addresses, this is the vpnIp that sent the packet // For learned addresses, this is the vpnIp that sent the packet
cache map[uint32]*cache cache map[iputil.VpnIp]*cache
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake // They should not be tried again during a handshake
badRemotes []*udpAddr badRemotes []*udp.Addr
// A flag that the cache may have changed and addrs needs to be rebuilt // A flag that the cache may have changed and addrs needs to be rebuilt
shouldRebuild bool shouldRebuild bool
@ -71,8 +74,8 @@ type RemoteList struct {
// NewRemoteList creates a new empty RemoteList // NewRemoteList creates a new empty RemoteList
func NewRemoteList() *RemoteList { func NewRemoteList() *RemoteList {
return &RemoteList{ return &RemoteList{
addrs: make([]*udpAddr, 0), addrs: make([]*udp.Addr, 0),
cache: make(map[uint32]*cache), cache: make(map[iputil.VpnIp]*cache),
} }
} }
@ -98,7 +101,7 @@ func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc)
// CopyAddrs locks and makes a deep copy of the deduplicated address list // CopyAddrs locks and makes a deep copy of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges // The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr { func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
if r == nil { if r == nil {
return nil return nil
} }
@ -107,7 +110,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
r.RLock() r.RLock()
defer r.RUnlock() defer r.RUnlock()
c := make([]*udpAddr, len(r.addrs)) c := make([]*udp.Addr, len(r.addrs))
for i, v := range r.addrs { for i, v := range r.addrs {
c[i] = v.Copy() c[i] = v.Copy()
} }
@ -118,7 +121,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available // It will mark the deduplicated address list as dirty, so do not call it unless new information is available
//TODO: this needs to support the allow list list //TODO: this needs to support the allow list list
func (r *RemoteList) LearnRemote(ownerVpnIp uint32, addr *udpAddr) { func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
if v4 := addr.IP.To4(); v4 != nil { if v4 := addr.IP.To4(); v4 != nil {
@ -139,8 +142,8 @@ func (r *RemoteList) CopyCache() *CacheMap {
c := cm[vpnIp] c := cm[vpnIp]
if c == nil { if c == nil {
c = &Cache{ c = &Cache{
Learned: make([]*udpAddr, 0), Learned: make([]*udp.Addr, 0),
Reported: make([]*udpAddr, 0), Reported: make([]*udp.Addr, 0),
} }
cm[vpnIp] = c cm[vpnIp] = c
} }
@ -148,7 +151,7 @@ func (r *RemoteList) CopyCache() *CacheMap {
} }
for owner, mc := range r.cache { for owner, mc := range r.cache {
c := getOrMake(IntIp(owner).String()) c := getOrMake(owner.String())
if mc.v4 != nil { if mc.v4 != nil {
if mc.v4.learned != nil { if mc.v4.learned != nil {
@ -175,7 +178,7 @@ func (r *RemoteList) CopyCache() *CacheMap {
} }
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
func (r *RemoteList) BlockRemote(bad *udpAddr) { func (r *RemoteList) BlockRemote(bad *udp.Addr) {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
@ -192,11 +195,11 @@ func (r *RemoteList) BlockRemote(bad *udpAddr) {
} }
// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
func (r *RemoteList) CopyBlockedRemotes() []*udpAddr { func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr {
r.RLock() r.RLock()
defer r.RUnlock() defer r.RUnlock()
c := make([]*udpAddr, len(r.badRemotes)) c := make([]*udp.Addr, len(r.badRemotes))
for i, v := range r.badRemotes { for i, v := range r.badRemotes {
c[i] = v.Copy() c[i] = v.Copy()
} }
@ -228,7 +231,7 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
} }
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool { func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
for _, v := range r.badRemotes { for _, v := range r.badRemotes {
if v.Equals(remote) { if v.Equals(remote) {
return true return true
@ -239,14 +242,14 @@ func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty // deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) { func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
r.shouldRebuild = true r.shouldRebuild = true
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
} }
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty // and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4AndPort, check checkFuncV4) { func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp) c := r.unlockedGetOrMakeV4(ownerVpnIp)
@ -263,7 +266,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4And
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts // This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) { func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp) c := r.unlockedGetOrMakeV4(ownerVpnIp)
@ -276,14 +279,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty // deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) { func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
r.shouldRebuild = true r.shouldRebuild = true
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
} }
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty // and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6AndPort, check checkFuncV6) { func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp) c := r.unlockedGetOrMakeV6(ownerVpnIp)
@ -300,7 +303,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6And
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts // This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) { func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp) c := r.unlockedGetOrMakeV6(ownerVpnIp)
@ -313,7 +316,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established. // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
// The caller must dirty the learned address cache if required // The caller must dirty the learned address cache if required
func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 { func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
am := r.cache[ownerVpnIp] am := r.cache[ownerVpnIp]
if am == nil { if am == nil {
am = &cache{} am = &cache{}
@ -328,7 +331,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established. // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
// The caller must dirty the learned address cache if required // The caller must dirty the learned address cache if required
func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp uint32) *cacheV6 { func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 {
am := r.cache[ownerVpnIp] am := r.cache[ownerVpnIp]
if am == nil { if am == nil {
am = &cache{} am = &cache{}

View File

@ -4,6 +4,7 @@ import (
"net" "net"
"testing" "testing"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -13,18 +14,18 @@ func TestRemoteList_Rebuild(t *testing.T) {
0, 0,
0, 0,
[]*Ip4AndPort{ []*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is duped {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is duped {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is a dupe {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe
}, },
func(uint32, *Ip4AndPort) bool { return true }, func(iputil.VpnIp, *Ip4AndPort) bool { return true },
) )
rl.unlockedSetV6( rl.unlockedSetV6(
@ -37,7 +38,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
}, },
func(uint32, *Ip6AndPort) bool { return true }, func(iputil.VpnIp, *Ip6AndPort) bool { return true },
) )
rl.Rebuild([]*net.IPNet{}) rl.Rebuild([]*net.IPNet{})
@ -106,16 +107,16 @@ func BenchmarkFullRebuild(b *testing.B) {
0, 0,
0, 0,
[]*Ip4AndPort{ []*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
}, },
func(uint32, *Ip4AndPort) bool { return true }, func(iputil.VpnIp, *Ip4AndPort) bool { return true },
) )
rl.unlockedSetV6( rl.unlockedSetV6(
@ -127,7 +128,7 @@ func BenchmarkFullRebuild(b *testing.B) {
NewIp6AndPort(net.ParseIP("1:100::1"), 1), NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
}, },
func(uint32, *Ip6AndPort) bool { return true }, func(iputil.VpnIp, *Ip6AndPort) bool { return true },
) )
b.Run("no preferred", func(b *testing.B) { b.Run("no preferred", func(b *testing.B) {
@ -171,16 +172,16 @@ func BenchmarkSortRebuild(b *testing.B) {
0, 0,
0, 0,
[]*Ip4AndPort{ []*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101}, {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
}, },
func(uint32, *Ip4AndPort) bool { return true }, func(iputil.VpnIp, *Ip4AndPort) bool { return true },
) )
rl.unlockedSetV6( rl.unlockedSetV6(
@ -192,7 +193,7 @@ func BenchmarkSortRebuild(b *testing.B) {
NewIp6AndPort(net.ParseIP("1:100::1"), 1), NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
}, },
func(uint32, *Ip6AndPort) bool { return true }, func(iputil.VpnIp, *Ip6AndPort) bool { return true },
) )
b.Run("no preferred", func(b *testing.B) { b.Run("no preferred", func(b *testing.B) {

56
ssh.go
View File

@ -15,7 +15,11 @@ import (
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp"
) )
type sshListHostMapFlags struct { type sshListHostMapFlags struct {
@ -45,8 +49,8 @@ type sshCreateTunnelFlags struct {
Address string Address string
} }
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) { func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
c.RegisterReloadCallback(func(c *Config) { c.RegisterReloadCallback(func(c *config.C) {
if c.GetBool("sshd.enabled", false) { if c.GetBool("sshd.enabled", false) {
sshRun, err := configSSH(l, ssh, c) sshRun, err := configSSH(l, ssh, c)
if err != nil { if err != nil {
@ -66,7 +70,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
// updates the passed-in SSHServer. On success, it returns a function // updates the passed-in SSHServer. On success, it returns a function
// that callers may invoke to run the configured ssh server. On // that callers may invoke to run the configured ssh server. On
// failure, it returns nil, error. // failure, it returns nil, error.
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) (func(), error) { func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
//TODO conntrack list //TODO conntrack list
//TODO print firewall rules or hash? //TODO print firewall rules or hash?
@ -351,7 +355,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
hm := listHostMap(hostMap) hm := listHostMap(hostMap)
sort.Slice(hm, func(i, j int) bool { sort.Slice(hm, func(i, j int) bool {
return bytes.Compare(hm[i].VpnIP, hm[j].VpnIP) < 0 return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
}) })
if fs.Json || fs.Pretty { if fs.Json || fs.Pretty {
@ -368,7 +372,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
} else { } else {
for _, v := range hm { for _, v := range hm {
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, v.RemoteAddrs)) err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs))
if err != nil { if err != nil {
return err return err
} }
@ -386,7 +390,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
} }
type lighthouseInfo struct { type lighthouseInfo struct {
VpnIP net.IP `json:"vpnIp"` VpnIp string `json:"vpnIp"`
Addrs *CacheMap `json:"addrs"` Addrs *CacheMap `json:"addrs"`
} }
@ -395,7 +399,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
x := 0 x := 0
for k, v := range lightHouse.addrMap { for k, v := range lightHouse.addrMap {
addrMap[x] = lighthouseInfo{ addrMap[x] = lighthouseInfo{
VpnIP: int2ip(k), VpnIp: k.String(),
Addrs: v.CopyCache(), Addrs: v.CopyCache(),
} }
x++ x++
@ -403,7 +407,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
lightHouse.RUnlock() lightHouse.RUnlock()
sort.Slice(addrMap, func(i, j int) bool { sort.Slice(addrMap, func(i, j int) bool {
return bytes.Compare(addrMap[i].VpnIP, addrMap[j].VpnIP) < 0 return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0
}) })
if fs.Json || fs.Pretty { if fs.Json || fs.Pretty {
@ -424,7 +428,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
if err != nil { if err != nil {
return err return err
} }
err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, string(b))) err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b)))
if err != nil { if err != nil {
return err return err
} }
@ -470,7 +474,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
vpnIp := ip2int(parsedIp) vpnIp := iputil.Ip2VpnIp(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@ -499,19 +503,19 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
vpnIp := ip2int(parsedIp) vpnIp := iputil.Ip2VpnIp(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
} }
if !flags.LocalOnly { if !flags.LocalOnly {
ifce.send( ifce.send(
closeTunnel, header.CloseTunnel,
0, 0,
hostInfo.ConnectionState, hostInfo.ConnectionState,
hostInfo, hostInfo,
@ -542,30 +546,30 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
vpnIp := ip2int(parsedIp) vpnIp := iputil.Ip2VpnIp(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
hostInfo, _ := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp)
if hostInfo != nil { if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already exists")) return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
} }
hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIP(uint32(vpnIp)) hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
if hostInfo != nil { if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
} }
var addr *udpAddr var addr *udp.Addr
if flags.Address != "" { if flags.Address != "" {
addr = NewUDPAddrFromString(flags.Address) addr = udp.NewAddrFromString(flags.Address)
if addr == nil { if addr == nil {
return w.WriteLine("Address could not be parsed") return w.WriteLine("Address could not be parsed")
} }
} }
hostInfo = ifce.handshakeManager.AddVpnIP(vpnIp) hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
if addr != nil { if addr != nil {
hostInfo.SetRemote(addr) hostInfo.SetRemote(addr)
} }
@ -589,7 +593,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("No address was provided") return w.WriteLine("No address was provided")
} }
addr := NewUDPAddrFromString(flags.Address) addr := udp.NewAddrFromString(flags.Address)
if addr == nil { if addr == nil {
return w.WriteLine("Address could not be parsed") return w.WriteLine("Address could not be parsed")
} }
@ -599,12 +603,12 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
vpnIp := ip2int(parsedIp) vpnIp := iputil.Ip2VpnIp(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
} }
@ -680,12 +684,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
vpnIp := ip2int(parsedIp) vpnIp := iputil.Ip2VpnIp(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
} }
@ -742,12 +746,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
vpnIp := ip2int(parsedIp) vpnIp := iputil.Ip2VpnIp(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
hostInfo, err := ifce.hostMap.QueryVpnIP(vpnIp) hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
} }

View File

@ -15,12 +15,13 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
) )
// startStats initializes stats from config. On success, if any futher work // startStats initializes stats from config. On success, if any futher work
// is needed to serve stats, it returns a func to handle that work. If no // is needed to serve stats, it returns a func to handle that work. If no
// work is needed, it'll return nil. On failure, it returns nil, error. // work is needed, it'll return nil. On failure, it returns nil, error.
func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) (func(), error) { func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) {
mType := c.GetString("stats.type", "") mType := c.GetString("stats.type", "")
if mType == "" || mType == "none" { if mType == "" || mType == "none" {
return nil, nil return nil, nil
@ -57,7 +58,7 @@ func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest boo
return startFn, nil return startFn, nil
} }
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error { func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error {
proto := c.GetString("stats.protocol", "tcp") proto := c.GetString("stats.protocol", "tcp")
host := c.GetString("stats.host", "") host := c.GetString("stats.host", "")
if host == "" { if host == "" {
@ -77,7 +78,7 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest
return nil return nil
} }
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) (func(), error) { func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) {
namespace := c.GetString("stats.namespace", "") namespace := c.GetString("stats.namespace", "")
subsystem := c.GetString("stats.subsystem", "") subsystem := c.GetString("stats.subsystem", "")

View File

@ -2,12 +2,14 @@ package nebula
import ( import (
"time" "time"
"github.com/slackhq/nebula/firewall"
) )
// How many timer objects should be cached // How many timer objects should be cached
const timerCacheMax = 50000 const timerCacheMax = 50000
var emptyFWPacket = FirewallPacket{} var emptyFWPacket = firewall.Packet{}
type TimerWheel struct { type TimerWheel struct {
// Current tick // Current tick
@ -42,7 +44,7 @@ type TimeoutList struct {
// Represents an item within a tick // Represents an item within a tick
type TimeoutItem struct { type TimeoutItem struct {
Packet FirewallPacket Packet firewall.Packet
Next *TimeoutItem Next *TimeoutItem
} }
@ -73,8 +75,8 @@ func NewTimerWheel(min, max time.Duration) *TimerWheel {
return &tw return &tw
} }
// Add will add a FirewallPacket to the wheel in it's proper timeout // Add will add a firewall.Packet to the wheel in it's proper timeout
func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem { func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem {
// Check and see if we should progress the tick // Check and see if we should progress the tick
tw.advance(time.Now()) tw.advance(time.Now())
@ -103,7 +105,7 @@ func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem
return ti return ti
} }
func (tw *TimerWheel) Purge() (FirewallPacket, bool) { func (tw *TimerWheel) Purge() (firewall.Packet, bool) {
if tw.expired.Head == nil { if tw.expired.Head == nil {
return emptyFWPacket, false return emptyFWPacket, false
} }

View File

@ -3,6 +3,8 @@ package nebula
import ( import (
"sync" "sync"
"time" "time"
"github.com/slackhq/nebula/iputil"
) )
// How many timer objects should be cached // How many timer objects should be cached
@ -43,7 +45,7 @@ type SystemTimeoutList struct {
// Represents an item within a tick // Represents an item within a tick
type SystemTimeoutItem struct { type SystemTimeoutItem struct {
Item uint32 Item iputil.VpnIp
Next *SystemTimeoutItem Next *SystemTimeoutItem
} }
@ -74,7 +76,7 @@ func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel {
return &tw return &tw
} }
func (tw *SystemTimerWheel) Add(v uint32, timeout time.Duration) *SystemTimeoutItem { func (tw *SystemTimerWheel) Add(v iputil.VpnIp, timeout time.Duration) *SystemTimeoutItem {
tw.lock.Lock() tw.lock.Lock()
defer tw.lock.Unlock() defer tw.lock.Unlock()

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -51,7 +52,7 @@ func TestSystemTimerWheel_findWheel(t *testing.T) {
func TestSystemTimerWheel_Add(t *testing.T) { func TestSystemTimerWheel_Add(t *testing.T) {
tw := NewSystemTimerWheel(time.Second, time.Second*10) tw := NewSystemTimerWheel(time.Second, time.Second*10)
fp1 := ip2int(net.ParseIP("1.2.3.4")) fp1 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
tw.Add(fp1, time.Second*1) tw.Add(fp1, time.Second*1)
// Make sure we set head and tail properly // Make sure we set head and tail properly
@ -62,7 +63,7 @@ func TestSystemTimerWheel_Add(t *testing.T) {
assert.Nil(t, tw.wheel[2].Tail.Next) assert.Nil(t, tw.wheel[2].Tail.Next)
// Make sure we only modify head // Make sure we only modify head
fp2 := ip2int(net.ParseIP("1.2.3.4")) fp2 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
tw.Add(fp2, time.Second*1) tw.Add(fp2, time.Second*1)
assert.Equal(t, fp2, tw.wheel[2].Head.Item) assert.Equal(t, fp2, tw.wheel[2].Head.Item)
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item) assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
@ -85,7 +86,7 @@ func TestSystemTimerWheel_Purge(t *testing.T) {
assert.NotNil(t, tw.lastTick) assert.NotNil(t, tw.lastTick)
assert.Equal(t, 0, tw.current) assert.Equal(t, 0, tw.current)
fps := []uint32{9, 10, 11, 12} fps := []iputil.VpnIp{9, 10, 11, 12}
//fp1 := ip2int(net.ParseIP("1.2.3.4")) //fp1 := ip2int(net.ParseIP("1.2.3.4"))

View File

@ -4,6 +4,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/firewall"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -50,7 +51,7 @@ func TestTimerWheel_findWheel(t *testing.T) {
func TestTimerWheel_Add(t *testing.T) { func TestTimerWheel_Add(t *testing.T) {
tw := NewTimerWheel(time.Second, time.Second*10) tw := NewTimerWheel(time.Second, time.Second*10)
fp1 := FirewallPacket{} fp1 := firewall.Packet{}
tw.Add(fp1, time.Second*1) tw.Add(fp1, time.Second*1)
// Make sure we set head and tail properly // Make sure we set head and tail properly
@ -61,7 +62,7 @@ func TestTimerWheel_Add(t *testing.T) {
assert.Nil(t, tw.wheel[2].Tail.Next) assert.Nil(t, tw.wheel[2].Tail.Next)
// Make sure we only modify head // Make sure we only modify head
fp2 := FirewallPacket{} fp2 := firewall.Packet{}
tw.Add(fp2, time.Second*1) tw.Add(fp2, time.Second*1)
assert.Equal(t, fp2, tw.wheel[2].Head.Packet) assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet) assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
@ -84,7 +85,7 @@ func TestTimerWheel_Purge(t *testing.T) {
assert.NotNil(t, tw.lastTick) assert.NotNil(t, tw.lastTick)
assert.Equal(t, 0, tw.current) assert.Equal(t, 0, tw.current)
fps := []FirewallPacket{ fps := []firewall.Packet{
{LocalIP: 1}, {LocalIP: 1},
{LocalIP: 2}, {LocalIP: 2},
{LocalIP: 3}, {LocalIP: 3},

View File

@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"github.com/slackhq/nebula/config"
) )
const DEFAULT_MTU = 1300 const DEFAULT_MTU = 1300
@ -14,10 +16,10 @@ type route struct {
via *net.IP via *net.IP
} }
func parseRoutes(config *Config, network *net.IPNet) ([]route, error) { func parseRoutes(c *config.C, network *net.IPNet) ([]route, error) {
var err error var err error
r := config.Get("tun.routes") r := c.Get("tun.routes")
if r == nil { if r == nil {
return []route{}, nil return []route{}, nil
} }
@ -84,10 +86,10 @@ func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
return routes, nil return routes, nil
} }
func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) { func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]route, error) {
var err error var err error
r := config.Get("tun.unsafe_routes") r := c.Get("tun.unsafe_routes")
if r == nil { if r == nil {
return []route{}, nil return []route{}, nil
} }
@ -110,7 +112,7 @@ func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
rMtu, ok := m["mtu"] rMtu, ok := m["mtu"]
if !ok { if !ok {
rMtu = config.GetInt("tun.mtu", DEFAULT_MTU) rMtu = c.GetInt("tun.mtu", DEFAULT_MTU)
} }
mtu, ok := rMtu.(int) mtu, ok := rMtu.(int)

View File

@ -5,12 +5,14 @@ import (
"net" "net"
"testing" "testing"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_parseRoutes(t *testing.T) { func Test_parseRoutes(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
c := NewConfig(l) c := config.NewC(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24") _, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config // test no routes config
@ -105,8 +107,8 @@ func Test_parseRoutes(t *testing.T) {
} }
func Test_parseUnsafeRoutes(t *testing.T) { func Test_parseUnsafeRoutes(t *testing.T) {
l := NewTestLogger() l := util.NewTestLogger()
c := NewConfig(l) c := config.NewC(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24") _, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config // test no routes config

20
udp/conn.go Normal file
View File

@ -0,0 +1,20 @@
package udp
import (
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
)
const MTU = 9001
type EncReader func(
addr *Addr,
out []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
)

14
udp/temp.go Normal file
View File

@ -0,0 +1,14 @@
package udp
import (
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
)
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
type EncWriter interface {
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
}
type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter)

View File

@ -1,4 +1,4 @@
package nebula package udp
import ( import (
"encoding/json" "encoding/json"
@ -7,32 +7,34 @@ import (
"strconv" "strconv"
) )
type udpAddr struct { type m map[string]interface{}
type Addr struct {
IP net.IP IP net.IP
Port uint16 Port uint16
} }
func NewUDPAddr(ip net.IP, port uint16) *udpAddr { func NewAddr(ip net.IP, port uint16) *Addr {
addr := udpAddr{IP: make([]byte, net.IPv6len), Port: port} addr := Addr{IP: make([]byte, net.IPv6len), Port: port}
copy(addr.IP, ip.To16()) copy(addr.IP, ip.To16())
return &addr return &addr
} }
func NewUDPAddrFromString(s string) *udpAddr { func NewAddrFromString(s string) *Addr {
ip, port, err := parseIPAndPort(s) ip, port, err := ParseIPAndPort(s)
//TODO: handle err //TODO: handle err
_ = err _ = err
return &udpAddr{IP: ip.To16(), Port: port} return &Addr{IP: ip.To16(), Port: port}
} }
func (ua *udpAddr) Equals(t *udpAddr) bool { func (ua *Addr) Equals(t *Addr) bool {
if t == nil || ua == nil { if t == nil || ua == nil {
return t == nil && ua == nil return t == nil && ua == nil
} }
return ua.IP.Equal(t.IP) && ua.Port == t.Port return ua.IP.Equal(t.IP) && ua.Port == t.Port
} }
func (ua *udpAddr) String() string { func (ua *Addr) String() string {
if ua == nil { if ua == nil {
return "<nil>" return "<nil>"
} }
@ -40,7 +42,7 @@ func (ua *udpAddr) String() string {
return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port)) return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
} }
func (ua *udpAddr) MarshalJSON() ([]byte, error) { func (ua *Addr) MarshalJSON() ([]byte, error) {
if ua == nil { if ua == nil {
return nil, nil return nil, nil
} }
@ -48,12 +50,12 @@ func (ua *udpAddr) MarshalJSON() ([]byte, error) {
return json.Marshal(m{"ip": ua.IP, "port": ua.Port}) return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
} }
func (ua *udpAddr) Copy() *udpAddr { func (ua *Addr) Copy() *Addr {
if ua == nil { if ua == nil {
return nil return nil
} }
nu := udpAddr{ nu := Addr{
Port: ua.Port, Port: ua.Port,
IP: make(net.IP, len(ua.IP)), IP: make(net.IP, len(ua.IP)),
} }
@ -62,7 +64,7 @@ func (ua *udpAddr) Copy() *udpAddr {
return &nu return &nu
} }
func parseIPAndPort(s string) (net.IP, uint16, error) { func ParseIPAndPort(s string) (net.IP, uint16, error) {
rIp, sPort, err := net.SplitHostPort(s) rIp, sPort, err := net.SplitHostPort(s)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err

View File

@ -1,7 +1,7 @@
//go:build !e2e_testing //go:build !e2e_testing
// +build !e2e_testing // +build !e2e_testing
package nebula package udp
import ( import (
"fmt" "fmt"
@ -34,6 +34,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
} }
} }
func (u *udpConn) Rebind() error { func (u *Conn) Rebind() error {
return nil return nil
} }

View File

@ -1,7 +1,7 @@
//go:build !e2e_testing //go:build !e2e_testing
// +build !e2e_testing // +build !e2e_testing
package nebula package udp
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig // Darwin support is primarily implemented in udp_generic, besides NewListenConfig
@ -37,7 +37,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
} }
} }
func (u *udpConn) Rebind() error { func (u *Conn) Rebind() error {
file, err := u.File() file, err := u.File()
if err != nil { if err != nil {
return err return err

View File

@ -1,7 +1,7 @@
//go:build !e2e_testing //go:build !e2e_testing
// +build !e2e_testing // +build !e2e_testing
package nebula package udp
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig // FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
@ -36,6 +36,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
} }
} }
func (u *udpConn) Rebind() error { func (u *Conn) Rebind() error {
return nil return nil
} }

View File

@ -5,7 +5,7 @@
// udp_generic implements the nebula UDP interface in pure Go stdlib. This // udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows. // means it can be used on platforms like Darwin and Windows.
package nebula package udp
import ( import (
"context" "context"
@ -13,36 +13,39 @@ import (
"net" "net"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
) )
type udpConn struct { type Conn struct {
*net.UDPConn *net.UDPConn
l *logrus.Logger l *logrus.Logger
} }
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) { func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) {
lc := NewListenConfig(multi) lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port)) pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if uc, ok := pc.(*net.UDPConn); ok { if uc, ok := pc.(*net.UDPConn); ok {
return &udpConn{UDPConn: uc, l: l}, nil return &Conn{UDPConn: uc, l: l}, nil
} }
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
} }
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error { func (uc *Conn) WriteTo(b []byte, addr *Addr) error {
_, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)}) _, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
return err return err
} }
func (uc *udpConn) LocalAddr() (*udpAddr, error) { func (uc *Conn) LocalAddr() (*Addr, error) {
a := uc.UDPConn.LocalAddr() a := uc.UDPConn.LocalAddr()
switch v := a.(type) { switch v := a.(type) {
case *net.UDPAddr: case *net.UDPAddr:
addr := &udpAddr{IP: make([]byte, len(v.IP))} addr := &Addr{IP: make([]byte, len(v.IP))}
copy(addr.IP, v.IP) copy(addr.IP, v.IP)
addr.Port = uint16(v.Port) addr.Port = uint16(v.Port)
return addr, nil return addr, nil
@ -52,11 +55,11 @@ func (uc *udpConn) LocalAddr() (*udpAddr, error) {
} }
} }
func (u *udpConn) reloadConfig(c *Config) { func (u *Conn) ReloadConfig(c *config.C) {
// TODO // TODO
} }
func NewUDPStatsEmitter(udpConns []*udpConn) func() { func NewUDPStatsEmitter(udpConns []*Conn) func() {
// No UDP stats for non-linux // No UDP stats for non-linux
return func() {} return func() {}
} }
@ -65,32 +68,24 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *udpConn) ListenOut(f *Interface, q int) { func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, mtu) plaintext := make([]byte, MTU)
buffer := make([]byte, mtu) buffer := make([]byte, MTU)
header := &Header{} h := &header.H{}
fwPacket := &FirewallPacket{} fwPacket := &firewall.Packet{}
udpAddr := &udpAddr{IP: make([]byte, 16)} udpAddr := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for { for {
// Just read one packet at a time // Just read one packet at a time
n, rua, err := u.ReadFromUDP(buffer) n, rua, err := u.ReadFromUDP(buffer)
if err != nil { if err != nil {
f.l.WithError(err).Error("Failed to read packets") u.l.WithError(err).Error("Failed to read packets")
continue continue
} }
udpAddr.IP = rua.IP udpAddr.IP = rua.IP
udpAddr.Port = uint16(rua.Port) udpAddr.Port = uint16(rua.Port)
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l)) r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
} }
} }
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
return !addr.Equals(newaddr)
}

View File

@ -1,7 +1,7 @@
//go:build !android && !e2e_testing //go:build !android && !e2e_testing
// +build !android,!e2e_testing // +build !android,!e2e_testing
package nebula package udp
import ( import (
"encoding/binary" "encoding/binary"
@ -12,14 +12,18 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
//TODO: make it support reload as best you can! //TODO: make it support reload as best you can!
type udpConn struct { type Conn struct {
sysFd int sysFd int
l *logrus.Logger l *logrus.Logger
batch int
} }
var x int var x int
@ -41,7 +45,7 @@ const (
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) { func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) {
syscall.ForkLock.RLock() syscall.ForkLock.RLock()
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP) fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
if err == nil { if err == nil {
@ -73,36 +77,36 @@ func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, e
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
//l.Println(v, err) //l.Println(v, err)
return &udpConn{sysFd: fd, l: l}, err return &Conn{sysFd: fd, l: l, batch: batch}, err
} }
func (u *udpConn) Rebind() error { func (u *Conn) Rebind() error {
return nil return nil
} }
func (u *udpConn) SetRecvBuffer(n int) error { func (u *Conn) SetRecvBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
} }
func (u *udpConn) SetSendBuffer(n int) error { func (u *Conn) SetSendBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
} }
func (u *udpConn) GetRecvBuffer() (int, error) { func (u *Conn) GetRecvBuffer() (int, error) {
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
} }
func (u *udpConn) GetSendBuffer() (int, error) { func (u *Conn) GetSendBuffer() (int, error) {
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
} }
func (u *udpConn) LocalAddr() (*udpAddr, error) { func (u *Conn) LocalAddr() (*Addr, error) {
sa, err := unix.Getsockname(u.sysFd) sa, err := unix.Getsockname(u.sysFd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
addr := &udpAddr{} addr := &Addr{}
switch sa := sa.(type) { switch sa := sa.(type) {
case *unix.SockaddrInet4: case *unix.SockaddrInet4:
addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16() addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
@ -115,25 +119,21 @@ func (u *udpConn) LocalAddr() (*udpAddr, error) {
return addr, nil return addr, nil
} }
func (u *udpConn) ListenOut(f *Interface, q int) { func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, mtu) plaintext := make([]byte, MTU)
header := &Header{} h := &header.H{}
fwPacket := &FirewallPacket{} fwPacket := &firewall.Packet{}
udpAddr := &udpAddr{} udpAddr := &Addr{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
lhh := f.lightHouse.NewRequestHandler()
//TODO: should we track this? //TODO: should we track this?
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize) msgs, buffers, names := u.PrepareRawMessages(u.batch)
read := u.ReadMulti read := u.ReadMulti
if f.udpBatchSize == 1 { if u.batch == 1 {
read = u.ReadSingle read = u.ReadSingle
} }
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for { for {
n, err := read(msgs) n, err := read(msgs)
if err != nil { if err != nil {
@ -145,12 +145,12 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
udpAddr.IP = names[i][8:24] udpAddr.IP = names[i][8:24]
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l)) r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
} }
} }
} }
func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) { func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
for { for {
n, _, err := unix.Syscall6( n, _, err := unix.Syscall6(
unix.SYS_RECVMSG, unix.SYS_RECVMSG,
@ -171,7 +171,7 @@ func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
} }
} }
func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) { func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
for { for {
n, _, err := unix.Syscall6( n, _, err := unix.Syscall6(
unix.SYS_RECVMMSG, unix.SYS_RECVMMSG,
@ -191,7 +191,7 @@ func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) {
} }
} }
func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error { func (u *Conn) WriteTo(b []byte, addr *Addr) error {
var rsa unix.RawSockaddrInet6 var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6 rsa.Family = unix.AF_INET6
@ -221,7 +221,7 @@ func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
} }
} }
func (u *udpConn) reloadConfig(c *Config) { func (u *Conn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0) b := c.GetInt("listen.read_buffer", 0)
if b > 0 { if b > 0 {
err := u.SetRecvBuffer(b) err := u.SetRecvBuffer(b)
@ -253,7 +253,7 @@ func (u *udpConn) reloadConfig(c *Config) {
} }
} }
func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error { func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error {
var vallen uint32 = 4 * _SK_MEMINFO_VARS var vallen uint32 = 4 * _SK_MEMINFO_VARS
_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
if err != 0 { if err != 0 {
@ -262,7 +262,7 @@ func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error {
return nil return nil
} }
func NewUDPStatsEmitter(udpConns []*udpConn) func() { func NewUDPStatsEmitter(udpConns []*Conn) func() {
// Check if our kernel supports SO_MEMINFO before registering the gauges // Check if our kernel supports SO_MEMINFO before registering the gauges
var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
var meminfo _SK_MEMINFO var meminfo _SK_MEMINFO
@ -293,7 +293,3 @@ func NewUDPStatsEmitter(udpConns []*udpConn) func() {
} }
} }
} }
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
return !addr.Equals(newaddr)
}

View File

@ -4,7 +4,7 @@
// +build !android // +build !android
// +build !e2e_testing // +build !e2e_testing
package nebula package udp
import ( import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@ -30,13 +30,13 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
buffers := make([][]byte, n) buffers := make([][]byte, n)
names := make([][]byte, n) names := make([][]byte, n)
for i := range msgs { for i := range msgs {
buffers[i] = make([]byte, mtu) buffers[i] = make([]byte, MTU)
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
//TODO: this is still silly, no need for an array //TODO: this is still silly, no need for an array

View File

@ -4,7 +4,7 @@
// +build !android // +build !android
// +build !e2e_testing // +build !e2e_testing
package nebula package udp
import ( import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@ -33,13 +33,13 @@ type rawMessage struct {
Pad0 [4]byte Pad0 [4]byte
} }
func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
buffers := make([][]byte, n) buffers := make([][]byte, n)
names := make([][]byte, n) names := make([][]byte, n)
for i := range msgs { for i := range msgs {
buffers[i] = make([]byte, mtu) buffers[i] = make([]byte, MTU)
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
//TODO: this is still silly, no need for an array //TODO: this is still silly, no need for an array

View File

@ -1,16 +1,19 @@
//go:build e2e_testing //go:build e2e_testing
// +build e2e_testing // +build e2e_testing
package nebula package udp
import ( import (
"fmt" "fmt"
"net" "net"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
) )
type UdpPacket struct { type Packet struct {
ToIp net.IP ToIp net.IP
ToPort uint16 ToPort uint16
FromIp net.IP FromIp net.IP
@ -18,8 +21,8 @@ type UdpPacket struct {
Data []byte Data []byte
} }
func (u *UdpPacket) Copy() *UdpPacket { func (u *Packet) Copy() *Packet {
n := &UdpPacket{ n := &Packet{
ToIp: make(net.IP, len(u.ToIp)), ToIp: make(net.IP, len(u.ToIp)),
ToPort: u.ToPort, ToPort: u.ToPort,
FromIp: make(net.IP, len(u.FromIp)), FromIp: make(net.IP, len(u.FromIp)),
@ -33,20 +36,20 @@ func (u *UdpPacket) Copy() *UdpPacket {
return n return n
} }
type udpConn struct { type Conn struct {
addr *udpAddr Addr *Addr
rxPackets chan *UdpPacket // Packets to receive into nebula RxPackets chan *Packet // Packets to receive into nebula
txPackets chan *UdpPacket // Packets transmitted outside by nebula TxPackets chan *Packet // Packets transmitted outside by nebula
l *logrus.Logger l *logrus.Logger
} }
func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error) { func NewListener(l *logrus.Logger, ip string, port int, _ bool, _ int) (*Conn, error) {
return &udpConn{ return &Conn{
addr: &udpAddr{net.ParseIP(ip), uint16(port)}, Addr: &Addr{net.ParseIP(ip), uint16(port)},
rxPackets: make(chan *UdpPacket, 1), RxPackets: make(chan *Packet, 1),
txPackets: make(chan *UdpPacket, 1), TxPackets: make(chan *Packet, 1),
l: l, l: l,
}, nil }, nil
} }
@ -54,8 +57,8 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error
// Send will place a UdpPacket onto the receive queue for nebula to consume // Send will place a UdpPacket onto the receive queue for nebula to consume
// this is an encrypted packet or a handshake message in most cases // this is an encrypted packet or a handshake message in most cases
// packets were transmitted from another nebula node, you can send them with Tun.Send // packets were transmitted from another nebula node, you can send them with Tun.Send
func (u *udpConn) Send(packet *UdpPacket) { func (u *Conn) Send(packet *Packet) {
h := &Header{} h := &header.H{}
if err := h.Parse(packet.Data); err != nil { if err := h.Parse(packet.Data); err != nil {
panic(err) panic(err)
} }
@ -63,19 +66,19 @@ func (u *udpConn) Send(packet *UdpPacket) {
WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)). WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
WithField("dataLen", len(packet.Data)). WithField("dataLen", len(packet.Data)).
Info("UDP receiving injected packet") Info("UDP receiving injected packet")
u.rxPackets <- packet u.RxPackets <- packet
} }
// Get will pull a UdpPacket from the transmit queue // Get will pull a UdpPacket from the transmit queue
// nebula meant to send this message on the network, it will be encrypted // nebula meant to send this message on the network, it will be encrypted
// packets were ingested from the tun side (in most cases), you can send them with Tun.Send // packets were ingested from the tun side (in most cases), you can send them with Tun.Send
func (u *udpConn) Get(block bool) *UdpPacket { func (u *Conn) Get(block bool) *Packet {
if block { if block {
return <-u.txPackets return <-u.TxPackets
} }
select { select {
case p := <-u.txPackets: case p := <-u.TxPackets:
return p return p
default: default:
return nil return nil
@ -86,56 +89,49 @@ func (u *udpConn) Get(block bool) *UdpPacket {
// Below this is boilerplate implementation to make nebula actually work // Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************// //********************************************************************************************************************//
func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error { func (u *Conn) WriteTo(b []byte, addr *Addr) error {
p := &UdpPacket{ p := &Packet{
Data: make([]byte, len(b), len(b)), Data: make([]byte, len(b), len(b)),
FromIp: make([]byte, 16), FromIp: make([]byte, 16),
FromPort: u.addr.Port, FromPort: u.Addr.Port,
ToIp: make([]byte, 16), ToIp: make([]byte, 16),
ToPort: addr.Port, ToPort: addr.Port,
} }
copy(p.Data, b) copy(p.Data, b)
copy(p.ToIp, addr.IP.To16()) copy(p.ToIp, addr.IP.To16())
copy(p.FromIp, u.addr.IP.To16()) copy(p.FromIp, u.Addr.IP.To16())
u.txPackets <- p u.TxPackets <- p
return nil return nil
} }
func (u *udpConn) ListenOut(f *Interface, q int) { func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, mtu) plaintext := make([]byte, MTU)
header := &Header{} h := &header.H{}
fwPacket := &FirewallPacket{} fwPacket := &firewall.Packet{}
ua := &udpAddr{IP: make([]byte, 16)} ua := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for { for {
p := <-u.rxPackets p := <-u.RxPackets
ua.Port = p.FromPort ua.Port = p.FromPort
copy(ua.IP, p.FromIp.To16()) copy(ua.IP, p.FromIp.To16())
f.readOutsidePackets(ua, plaintext[:0], p.Data, header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l)) r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
} }
} }
func (u *udpConn) reloadConfig(*Config) {} func (u *Conn) ReloadConfig(*config.C) {}
func NewUDPStatsEmitter(_ []*udpConn) func() { func NewUDPStatsEmitter(_ []*Conn) func() {
// No UDP stats for non-linux // No UDP stats for non-linux
return func() {} return func() {}
} }
func (u *udpConn) LocalAddr() (*udpAddr, error) { func (u *Conn) LocalAddr() (*Addr, error) {
return u.addr, nil return u.Addr, nil
} }
func (u *udpConn) Rebind() error { func (u *Conn) Rebind() error {
return nil return nil
} }
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
return !addr.Equals(newaddr)
}

View File

@ -1,7 +1,7 @@
//go:build !e2e_testing //go:build !e2e_testing
// +build !e2e_testing // +build !e2e_testing
package nebula package udp
// Windows support is primarily implemented in udp_generic, besides NewListenConfig // Windows support is primarily implemented in udp_generic, besides NewListenConfig
@ -24,6 +24,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
} }
} }
func (u *udpConn) Rebind() error { func (u *Conn) Rebind() error {
return nil return nil
} }

View File

@ -1,4 +1,4 @@
package nebula package util
import ( import (
"io/ioutil" "io/ioutil"
@ -17,13 +17,12 @@ func NewTestLogger() *logrus.Logger {
} }
switch v { switch v {
case "1":
// This is the default level but we are being explicit
l.SetLevel(logrus.InfoLevel)
case "2": case "2":
l.SetLevel(logrus.DebugLevel) l.SetLevel(logrus.DebugLevel)
case "3": case "3":
l.SetLevel(logrus.TraceLevel) l.SetLevel(logrus.TraceLevel)
default:
l.SetLevel(logrus.InfoLevel)
} }
return l return l