Rework some things into packages (#489)

This commit is contained in:
Nate Brown 2021-11-03 20:54:04 -05:00 committed by GitHub
parent 1f75fb3c73
commit bcabcfdaca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
73 changed files with 2526 additions and 2374 deletions

View File

@ -4,11 +4,15 @@ import (
"fmt"
"net"
"regexp"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
)
type AllowList struct {
// The values of this cidrTree are `bool`, signifying allow/deny
cidrTree *CIDR6Tree
cidrTree *cidr.Tree6
}
type RemoteAllowList struct {
@ -16,7 +20,7 @@ type RemoteAllowList struct {
// Inside Range Specific, keys of this tree are inside CIDRs and values
// are *AllowList
insideAllowLists *CIDR6Tree
insideAllowLists *cidr.Tree6
}
type LocalAllowList struct {
@ -31,6 +35,223 @@ type AllowListNameRule struct {
Allow bool
}
func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
var nameRules []AllowListNameRule
handleKey := func(key string, value interface{}) (bool, error) {
if key == "interfaces" {
var err error
nameRules, err = getAllowListInterfaces(k, value)
if err != nil {
return false, err
}
return true, nil
}
return false, nil
}
al, err := newAllowListFromConfig(c, k, handleKey)
if err != nil {
return nil, err
}
return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil
}
func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllowList, error) {
al, err := newAllowListFromConfig(c, k, nil)
if err != nil {
return nil, err
}
remoteAllowRanges, err := getRemoteAllowRanges(c, rangesKey)
if err != nil {
return nil, err
}
return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil
}
// If the handleKey func returns true, the rest of the parsing is skipped
// for this key. This allows parsing of special values like `interfaces`.
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
r := c.Get(k)
if r == nil {
return nil, nil
}
return newAllowList(k, r, handleKey)
}
// If the handleKey func returns true, the rest of the parsing is skipped
// for this key. This allows parsing of special values like `interfaces`.
func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
rawMap, ok := raw.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
}
tree := cidr.NewTree6()
// Keep track of the rules we have added for both ipv4 and ipv6
type allowListRules struct {
firstValue bool
allValuesMatch bool
defaultSet bool
allValues bool
}
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
if handleKey != nil {
handled, err := handleKey(rawCIDR, rawValue)
if err != nil {
return nil, err
}
if handled {
continue
}
}
value, ok := rawValue.(bool)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
}
_, ipNet, err := net.ParseCIDR(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
// TODO: should we error on duplicate CIDRs in the config?
tree.AddCIDR(ipNet, value)
maskBits, maskSize := ipNet.Mask.Size()
var rules *allowListRules
if maskSize == 32 {
rules = &rules4
} else {
rules = &rules6
}
if rules.firstValue {
rules.allValues = value
rules.firstValue = false
} else {
if value != rules.allValues {
rules.allValuesMatch = false
}
}
// Check if this is 0.0.0.0/0 or ::/0
if maskBits == 0 {
rules.defaultSet = true
}
}
if !rules4.defaultSet {
if rules4.allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
tree.AddCIDR(zeroCIDR, !rules4.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
}
}
if !rules6.defaultSet {
if rules6.allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("::/0")
tree.AddCIDR(zeroCIDR, !rules6.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
}
}
return &AllowList{cidrTree: tree}, nil
}
func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
var nameRules []AllowListNameRule
rawRules, ok := v.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
}
firstEntry := true
var allValues bool
for rawName, rawAllow := range rawRules {
name, ok := rawName.(string)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
}
allow, ok := rawAllow.(bool)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
}
nameRE, err := regexp.Compile("^" + name + "$")
if err != nil {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
}
nameRules = append(nameRules, AllowListNameRule{
Name: nameRE,
Allow: allow,
})
if firstEntry {
allValues = allow
firstEntry = false
} else {
if allow != allValues {
return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
}
}
}
return nameRules, nil
}
func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) {
value := c.Get(k)
if value == nil {
return nil, nil
}
remoteAllowRanges := cidr.NewTree6()
rawMap, ok := value.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
}
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
if err != nil {
return nil, err
}
_, ipNet, err := net.ParseCIDR(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
remoteAllowRanges.AddCIDR(ipNet, allowList)
}
return remoteAllowRanges, nil
}
func (al *AllowList) Allow(ip net.IP) bool {
if al == nil {
return true
@ -45,7 +266,7 @@ func (al *AllowList) Allow(ip net.IP) bool {
}
}
func (al *AllowList) AllowIpV4(ip uint32) bool {
func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
if al == nil {
return true
}
@ -102,14 +323,14 @@ func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool {
return al.AllowList.Allow(ip)
}
func (al *RemoteAllowList) Allow(vpnIp uint32, ip net.IP) bool {
func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool {
if !al.getInsideAllowList(vpnIp).Allow(ip) {
return false
}
return al.AllowList.Allow(ip)
}
func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool {
func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool {
if al == nil {
return true
}
@ -119,7 +340,7 @@ func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool {
return al.AllowList.AllowIpV4(ip)
}
func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool {
func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
if al == nil {
return true
}
@ -129,7 +350,7 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool {
return al.AllowList.AllowIpV6(hi, lo)
}
func (al *RemoteAllowList) getInsideAllowList(vpnIp uint32) *AllowList {
func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
if al.insideAllowLists != nil {
inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
if inside != nil {

View File

@ -5,21 +5,110 @@ import (
"regexp"
"testing"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func TestNewAllowListFromConfig(t *testing.T) {
l := util.NewTestLogger()
c := config.NewC(l)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0": true,
}
r, err := newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
assert.Nil(t, r)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": "abc",
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": true,
"10.0.0.0/8": false,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
"fd00::/8": true,
"fd00:fd00::/16": false,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
"::/0": false,
"fd00::/8": true,
"fd00:fd00::/16": false,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
// Test interface names
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: "foo",
},
}
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
`eth.*`: true,
},
}
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
},
}
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
if assert.NoError(t, err) {
assert.NotNil(t, lr)
}
}
func TestAllowList_Allow(t *testing.T) {
assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
tree := NewCIDR6Tree()
tree.AddCIDR(getCIDR("0.0.0.0/0"), true)
tree.AddCIDR(getCIDR("10.0.0.0/8"), false)
tree.AddCIDR(getCIDR("10.42.42.42/32"), true)
tree.AddCIDR(getCIDR("10.42.0.0/16"), true)
tree.AddCIDR(getCIDR("10.42.42.0/24"), true)
tree.AddCIDR(getCIDR("10.42.42.0/24"), false)
tree.AddCIDR(getCIDR("::1/128"), true)
tree.AddCIDR(getCIDR("::2/128"), false)
tree := cidr.NewTree6()
tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true)
tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true)
tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false)
tree.AddCIDR(cidr.Parse("::1/128"), true)
tree.AddCIDR(cidr.Parse("::2/128"), false)
al := &AllowList{cidrTree: tree}
assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))

View File

@ -3,11 +3,12 @@ package nebula
import (
"testing"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func TestBits(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
b := NewBits(10)
// make sure it is the right size
@ -75,7 +76,7 @@ func TestBits(t *testing.T) {
}
func TestBitsDupeCounter(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
@ -100,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) {
}
func TestBitsOutOfWindowCounter(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
@ -130,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
}
func TestBitsLostCounter(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()

View File

@ -9,6 +9,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
)
type CertState struct {
@ -45,7 +46,7 @@ func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*Cert
return cs, nil
}
func NewCertStateFromConfig(c *Config) (*CertState, error) {
func NewCertStateFromConfig(c *config.C) (*CertState, error) {
var pemPrivateKey []byte
var err error
@ -118,7 +119,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) {
return NewCertState(nebulaCert, rawKey)
}
func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) {
func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
var rawCA []byte
var err error

10
cidr/parse.go Normal file
View File

@ -0,0 +1,10 @@
package cidr
import "net"
// Parse is a convenience function that returns only the IPNet
// This function ignores errors since it is primarily a test helper, the result could be nil
func Parse(s string) *net.IPNet {
_, c, _ := net.ParseCIDR(s)
return c
}

View File

@ -1,39 +1,39 @@
package nebula
package cidr
import (
"encoding/binary"
"fmt"
"net"
"github.com/slackhq/nebula/iputil"
)
type CIDRNode struct {
left *CIDRNode
right *CIDRNode
parent *CIDRNode
type Node struct {
left *Node
right *Node
parent *Node
value interface{}
}
type CIDRTree struct {
root *CIDRNode
type Tree4 struct {
root *Node
}
const (
startbit = uint32(0x80000000)
startbit = iputil.VpnIp(0x80000000)
)
func NewCIDRTree() *CIDRTree {
tree := new(CIDRTree)
tree.root = &CIDRNode{}
func NewTree4() *Tree4 {
tree := new(Tree4)
tree.root = &Node{}
return tree
}
func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
bit := startbit
node := tree.root
next := tree.root
ip := ip2int(cidr.IP)
mask := ip2int(cidr.Mask)
ip := iputil.Ip2VpnIp(cidr.IP)
mask := iputil.Ip2VpnIp(cidr.Mask)
// Find our last ancestor in the tree
for bit&mask != 0 {
@ -59,7 +59,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
// Build up the rest of the tree we don't already have
for bit&mask != 0 {
next = &CIDRNode{}
next = &Node{}
next.parent = node
if ip&bit != 0 {
@ -77,7 +77,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
}
// Finds the first match, which may be the least specific
func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
bit := startbit
node := tree.root
@ -100,7 +100,7 @@ func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
}
// Finds the most specific match
func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) {
func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
bit := startbit
node := tree.root
@ -122,7 +122,7 @@ func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) {
}
// Finds the most specific match
func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
bit := startbit
node := tree.root
lastNode := node
@ -143,27 +143,3 @@ func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
}
return value
}
// A helper type to avoid converting to IP when logging
type IntIp uint32
func (ip IntIp) String() string {
return fmt.Sprintf("%v", int2ip(uint32(ip)))
}
func (ip IntIp) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("\"%s\"", int2ip(uint32(ip)).String())), nil
}
func ip2int(ip []byte) uint32 {
if len(ip) == 16 {
return binary.BigEndian.Uint32(ip[12:16])
}
return binary.BigEndian.Uint32(ip)
}
func int2ip(nn uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, nn)
return ip
}

153
cidr/tree4_test.go Normal file
View File

@ -0,0 +1,153 @@
package cidr
import (
"net"
"testing"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
)
func TestCIDRTree_Contains(t *testing.T) {
tree := NewTree4()
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tests := []struct {
Result interface{}
IP string
}{
{"1", "1.0.0.0"},
{"1", "1.255.255.255"},
{"2", "2.1.0.0"},
{"2", "2.1.255.255"},
{"3", "3.1.1.0"},
{"3", "3.1.1.255"},
{"4a", "4.1.1.255"},
{"4a", "4.1.1.1"},
{"5", "240.0.0.0"},
{"5", "255.255.255.255"},
{nil, "239.0.0.0"},
{nil, "4.1.2.2"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
}
tree = NewTree4()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
}
func TestCIDRTree_MostSpecificContains(t *testing.T) {
tree := NewTree4()
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
tree.AddCIDR(Parse("4.1.1.0/30"), "4b")
tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tests := []struct {
Result interface{}
IP string
}{
{"1", "1.0.0.0"},
{"1", "1.255.255.255"},
{"2", "2.1.0.0"},
{"2", "2.1.255.255"},
{"3", "3.1.1.0"},
{"3", "3.1.1.255"},
{"4a", "4.1.1.255"},
{"4b", "4.1.1.2"},
{"4c", "4.1.1.1"},
{"5", "240.0.0.0"},
{"5", "255.255.255.255"},
{nil, "239.0.0.0"},
{nil, "4.1.2.2"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
}
tree = NewTree4()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
}
func TestCIDRTree_Match(t *testing.T) {
tree := NewTree4()
tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
tests := []struct {
Result interface{}
IP string
}{
{"1a", "4.1.1.0"},
{"1b", "4.1.1.1"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
}
tree = NewTree4()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
}
func BenchmarkCIDRTree_Contains(b *testing.B) {
tree := NewTree4()
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Contains(ip)
}
})
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Contains(ip)
}
})
}
func BenchmarkCIDRTree_Match(b *testing.B) {
tree := NewTree4()
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
}

View File

@ -1,26 +1,27 @@
package nebula
package cidr
import (
"encoding/binary"
"net"
"github.com/slackhq/nebula/iputil"
)
const startbit6 = uint64(1 << 63)
type CIDR6Tree struct {
root4 *CIDRNode
root6 *CIDRNode
type Tree6 struct {
root4 *Node
root6 *Node
}
func NewCIDR6Tree() *CIDR6Tree {
tree := new(CIDR6Tree)
tree.root4 = &CIDRNode{}
tree.root6 = &CIDRNode{}
func NewTree6() *Tree6 {
tree := new(Tree6)
tree.root4 = &Node{}
tree.root6 = &Node{}
return tree
}
func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
var node, next *CIDRNode
func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
var node, next *Node
cidrIP, ipv4 := isIPV4(cidr.IP)
if ipv4 {
@ -33,8 +34,8 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
}
for i := 0; i < len(cidrIP); i += 4 {
ip := binary.BigEndian.Uint32(cidrIP[i : i+4])
mask := binary.BigEndian.Uint32(cidr.Mask[i : i+4])
ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
bit := startbit
// Find our last ancestor in the tree
@ -55,7 +56,7 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
// Build up the rest of the tree we don't already have
for bit&mask != 0 {
next = &CIDRNode{}
next = &Node{}
next.parent = node
if ip&bit != 0 {
@ -74,8 +75,8 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
}
// Finds the most specific match
func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
var node *CIDRNode
func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
var node *Node
wholeIP, ipv4 := isIPV4(ip)
if ipv4 {
@ -85,7 +86,7 @@ func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
}
for i := 0; i < len(wholeIP); i += 4 {
ip := ip2int(wholeIP[i : i+4])
ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
bit := startbit
for node != nil {
@ -110,7 +111,7 @@ func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
return value
}
func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) {
func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) {
bit := startbit
node := tree.root4
@ -131,7 +132,7 @@ func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) {
return value
}
func (tree *CIDR6Tree) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
ip := hi
node := tree.root6

View File

@ -1,6 +1,7 @@
package nebula
package cidr
import (
"encoding/binary"
"net"
"testing"
@ -8,17 +9,17 @@ import (
)
func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
tree := NewCIDR6Tree()
tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
tree.AddCIDR(getCIDR("4.1.1.1/24"), "4a")
tree.AddCIDR(getCIDR("4.1.1.1/30"), "4b")
tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c")
tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c")
tree := NewTree6()
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(Parse("4.1.1.1/24"), "4a")
tree.AddCIDR(Parse("4.1.1.1/30"), "4b")
tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
tests := []struct {
Result interface{}
@ -46,9 +47,9 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP)))
}
tree = NewCIDR6Tree()
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
tree.AddCIDR(getCIDR("::/0"), "cool6")
tree = NewTree6()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
tree.AddCIDR(Parse("::/0"), "cool6")
assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0")))
assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255")))
assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::")))
@ -56,10 +57,10 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
}
func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
tree := NewCIDR6Tree()
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c")
tree := NewTree6()
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
tests := []struct {
Result interface{}
@ -71,7 +72,10 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
}
for _, tt := range tests {
ip := NewIp6AndPort(net.ParseIP(tt.IP), 0)
assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(ip.Hi, ip.Lo))
ip := net.ParseIP(tt.IP)
hi := binary.BigEndian.Uint64(ip[:8])
lo := binary.BigEndian.Uint64(ip[8:])
assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo))
}
}

View File

@ -1,157 +0,0 @@
package nebula
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCIDRTree_Contains(t *testing.T) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a")
tree.AddCIDR(getCIDR("4.1.1.1/32"), "4b")
tree.AddCIDR(getCIDR("4.1.2.1/32"), "4c")
tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
tests := []struct {
Result interface{}
IP string
}{
{"1", "1.0.0.0"},
{"1", "1.255.255.255"},
{"2", "2.1.0.0"},
{"2", "2.1.255.255"},
{"3", "3.1.1.0"},
{"3", "3.1.1.255"},
{"4a", "4.1.1.255"},
{"4a", "4.1.1.1"},
{"5", "240.0.0.0"},
{"5", "255.255.255.255"},
{nil, "239.0.0.0"},
{nil, "4.1.2.2"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.Contains(ip2int(net.ParseIP(tt.IP))))
}
tree = NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
}
func TestCIDRTree_MostSpecificContains(t *testing.T) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a")
tree.AddCIDR(getCIDR("4.1.1.0/30"), "4b")
tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c")
tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
tests := []struct {
Result interface{}
IP string
}{
{"1", "1.0.0.0"},
{"1", "1.255.255.255"},
{"2", "2.1.0.0"},
{"2", "2.1.255.255"},
{"3", "3.1.1.0"},
{"3", "3.1.1.255"},
{"4a", "4.1.1.255"},
{"4b", "4.1.1.2"},
{"4c", "4.1.1.1"},
{"5", "240.0.0.0"},
{"5", "255.255.255.255"},
{nil, "239.0.0.0"},
{nil, "4.1.2.2"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.MostSpecificContains(ip2int(net.ParseIP(tt.IP))))
}
tree = NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("255.255.255.255"))))
}
func TestCIDRTree_Match(t *testing.T) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("4.1.1.0/32"), "1a")
tree.AddCIDR(getCIDR("4.1.1.1/32"), "1b")
tests := []struct {
Result interface{}
IP string
}{
{"1a", "4.1.1.0"},
{"1b", "4.1.1.1"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.Match(ip2int(net.ParseIP(tt.IP))))
}
tree = NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
}
func BenchmarkCIDRTree_Contains(b *testing.B) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.0.0/16"), "1")
tree.AddCIDR(getCIDR("1.2.1.1/32"), "1")
tree.AddCIDR(getCIDR("192.2.1.1/32"), "1")
tree.AddCIDR(getCIDR("172.2.1.1/32"), "1")
ip := ip2int(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Contains(ip)
}
})
ip = ip2int(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Contains(ip)
}
})
}
func BenchmarkCIDRTree_Match(b *testing.B) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.0.0/16"), "1")
tree.AddCIDR(getCIDR("1.2.1.1/32"), "1")
tree.AddCIDR(getCIDR("192.2.1.1/32"), "1")
tree.AddCIDR(getCIDR("172.2.1.1/32"), "1")
ip := ip2int(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
ip = ip2int(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
}
func getCIDR(s string) *net.IPNet {
_, c, _ := net.ParseCIDR(s)
return c
}

View File

@ -7,6 +7,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
)
// A version string that can be set with
@ -49,14 +50,14 @@ func main() {
l := logrus.New()
l.Out = os.Stdout
config := nebula.NewConfig(l)
err := config.Load(*configPath)
c := config.NewC(l)
err := c.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
c, err := nebula.Main(config, *configTest, Build, l, nil)
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
switch v := err.(type) {
case nebula.ContextualError:
@ -68,8 +69,8 @@ func main() {
}
if !*configTest {
c.Start()
c.ShutdownBlock()
ctrl.Start()
ctrl.ShutdownBlock()
}
os.Exit(0)

View File

@ -9,6 +9,7 @@ import (
"github.com/kardianos/service"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
)
var logger service.Logger
@ -27,13 +28,13 @@ func (p *program) Start(s service.Service) error {
l := logrus.New()
HookLogger(l)
config := nebula.NewConfig(l)
err := config.Load(*p.configPath)
c := config.NewC(l)
err := c.Load(*p.configPath)
if err != nil {
return fmt.Errorf("failed to load config: %s", err)
}
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
if err != nil {
return err
}

View File

@ -7,6 +7,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
)
// A version string that can be set with
@ -43,14 +44,14 @@ func main() {
l := logrus.New()
l.Out = os.Stdout
config := nebula.NewConfig(l)
err := config.Load(*configPath)
c := config.NewC(l)
err := c.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
c, err := nebula.Main(config, *configTest, Build, l, nil)
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
switch v := err.(type) {
case nebula.ContextualError:
@ -62,8 +63,8 @@ func main() {
}
if !*configTest {
c.Start()
c.ShutdownBlock()
ctrl.Start()
ctrl.ShutdownBlock()
}
os.Exit(0)

611
config.go
View File

@ -1,611 +0,0 @@
package nebula
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net"
"os"
"os/signal"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"syscall"
"time"
"github.com/imdario/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)
type Config struct {
path string
files []string
Settings map[interface{}]interface{}
oldSettings map[interface{}]interface{}
callbacks []func(*Config)
l *logrus.Logger
}
func NewConfig(l *logrus.Logger) *Config {
return &Config{
Settings: make(map[interface{}]interface{}),
l: l,
}
}
// Load will find all yaml files within path and load them in lexical order
func (c *Config) Load(path string) error {
c.path = path
c.files = make([]string, 0)
err := c.resolve(path, true)
if err != nil {
return err
}
if len(c.files) == 0 {
return fmt.Errorf("no config files found at %s", path)
}
sort.Strings(c.files)
err = c.parse()
if err != nil {
return err
}
return nil
}
func (c *Config) LoadString(raw string) error {
if raw == "" {
return errors.New("Empty configuration")
}
return c.parseRaw([]byte(raw))
}
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
// here should decide if they need to make a change to the current process before making the change. HasChanged can be
// used to help decide if a change is necessary.
// These functions should return quickly or spawn their own go routine if they will take a while
func (c *Config) RegisterReloadCallback(f func(*Config)) {
c.callbacks = append(c.callbacks, f)
}
// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
// k in both the old and new settings will be serialized, the result of the string comparison is returned.
// If k is an empty string the entire config is tested.
// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
// there is change when there actually wasn't any.
func (c *Config) HasChanged(k string) bool {
if c.oldSettings == nil {
return false
}
var (
nv interface{}
ov interface{}
)
if k == "" {
nv = c.Settings
ov = c.oldSettings
k = "all settings"
} else {
nv = c.get(k, c.Settings)
ov = c.get(k, c.oldSettings)
}
newVals, err := yaml.Marshal(nv)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
}
oldVals, err := yaml.Marshal(ov)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
}
return string(newVals) != string(oldVals)
}
// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
// original path provided to Load. The old settings are shallow copied for change detection after the reload.
func (c *Config) CatchHUP(ctx context.Context) {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGHUP)
go func() {
for {
select {
case <-ctx.Done():
signal.Stop(ch)
close(ch)
return
case <-ch:
c.l.Info("Caught HUP, reloading config")
c.ReloadConfig()
}
}
}()
}
func (c *Config) ReloadConfig() {
c.oldSettings = make(map[interface{}]interface{})
for k, v := range c.Settings {
c.oldSettings[k] = v
}
err := c.Load(c.path)
if err != nil {
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
return
}
for _, v := range c.callbacks {
v(c)
}
}
// GetString will get the string for k or return the default d if not found or invalid
func (c *Config) GetString(k, d string) string {
r := c.Get(k)
if r == nil {
return d
}
return fmt.Sprintf("%v", r)
}
// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
func (c *Config) GetStringSlice(k string, d []string) []string {
r := c.Get(k)
if r == nil {
return d
}
rv, ok := r.([]interface{})
if !ok {
return d
}
v := make([]string, len(rv))
for i := 0; i < len(v); i++ {
v[i] = fmt.Sprintf("%v", rv[i])
}
return v
}
// GetMap will get the map for k or return the default d if not found or invalid
func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
r := c.Get(k)
if r == nil {
return d
}
v, ok := r.(map[interface{}]interface{})
if !ok {
return d
}
return v
}
// GetInt will get the int for k or return the default d if not found or invalid
func (c *Config) GetInt(k string, d int) int {
r := c.GetString(k, strconv.Itoa(d))
v, err := strconv.Atoi(r)
if err != nil {
return d
}
return v
}
// GetBool will get the bool for k or return the default d if not found or invalid
func (c *Config) GetBool(k string, d bool) bool {
r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
v, err := strconv.ParseBool(r)
if err != nil {
switch r {
case "y", "yes":
return true
case "n", "no":
return false
}
return d
}
return v
}
// GetDuration will get the duration for k or return the default d if not found or invalid
func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
r := c.GetString(k, "")
v, err := time.ParseDuration(r)
if err != nil {
return d
}
return v
}
func (c *Config) GetLocalAllowList(k string) (*LocalAllowList, error) {
var nameRules []AllowListNameRule
handleKey := func(key string, value interface{}) (bool, error) {
if key == "interfaces" {
var err error
nameRules, err = c.getAllowListInterfaces(k, value)
if err != nil {
return false, err
}
return true, nil
}
return false, nil
}
al, err := c.GetAllowList(k, handleKey)
if err != nil {
return nil, err
}
return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil
}
func (c *Config) GetRemoteAllowList(k, rangesKey string) (*RemoteAllowList, error) {
al, err := c.GetAllowList(k, nil)
if err != nil {
return nil, err
}
remoteAllowRanges, err := c.getRemoteAllowRanges(rangesKey)
if err != nil {
return nil, err
}
return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil
}
func (c *Config) getRemoteAllowRanges(k string) (*CIDR6Tree, error) {
value := c.Get(k)
if value == nil {
return nil, nil
}
remoteAllowRanges := NewCIDR6Tree()
rawMap, ok := value.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
}
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
allowList, err := c.getAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
if err != nil {
return nil, err
}
_, cidr, err := net.ParseCIDR(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
remoteAllowRanges.AddCIDR(cidr, allowList)
}
return remoteAllowRanges, nil
}
// If the handleKey func returns true, the rest of the parsing is skipped
// for this key. This allows parsing of special values like `interfaces`.
func (c *Config) GetAllowList(k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
r := c.Get(k)
if r == nil {
return nil, nil
}
return c.getAllowList(k, r, handleKey)
}
// If the handleKey func returns true, the rest of the parsing is skipped
// for this key. This allows parsing of special values like `interfaces`.
func (c *Config) getAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
rawMap, ok := raw.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
}
tree := NewCIDR6Tree()
// Keep track of the rules we have added for both ipv4 and ipv6
type allowListRules struct {
firstValue bool
allValuesMatch bool
defaultSet bool
allValues bool
}
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
if handleKey != nil {
handled, err := handleKey(rawCIDR, rawValue)
if err != nil {
return nil, err
}
if handled {
continue
}
}
value, ok := rawValue.(bool)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
}
_, cidr, err := net.ParseCIDR(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
// TODO: should we error on duplicate CIDRs in the config?
tree.AddCIDR(cidr, value)
maskBits, maskSize := cidr.Mask.Size()
var rules *allowListRules
if maskSize == 32 {
rules = &rules4
} else {
rules = &rules6
}
if rules.firstValue {
rules.allValues = value
rules.firstValue = false
} else {
if value != rules.allValues {
rules.allValuesMatch = false
}
}
// Check if this is 0.0.0.0/0 or ::/0
if maskBits == 0 {
rules.defaultSet = true
}
}
if !rules4.defaultSet {
if rules4.allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
tree.AddCIDR(zeroCIDR, !rules4.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
}
}
if !rules6.defaultSet {
if rules6.allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("::/0")
tree.AddCIDR(zeroCIDR, !rules6.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
}
}
return &AllowList{cidrTree: tree}, nil
}
func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
var nameRules []AllowListNameRule
rawRules, ok := v.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
}
firstEntry := true
var allValues bool
for rawName, rawAllow := range rawRules {
name, ok := rawName.(string)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
}
allow, ok := rawAllow.(bool)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
}
nameRE, err := regexp.Compile("^" + name + "$")
if err != nil {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
}
nameRules = append(nameRules, AllowListNameRule{
Name: nameRE,
Allow: allow,
})
if firstEntry {
allValues = allow
firstEntry = false
} else {
if allow != allValues {
return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
}
}
}
return nameRules, nil
}
func (c *Config) Get(k string) interface{} {
return c.get(k, c.Settings)
}
func (c *Config) IsSet(k string) bool {
return c.get(k, c.Settings) != nil
}
func (c *Config) get(k string, v interface{}) interface{} {
parts := strings.Split(k, ".")
for _, p := range parts {
m, ok := v.(map[interface{}]interface{})
if !ok {
return nil
}
v, ok = m[p]
if !ok {
return nil
}
}
return v
}
// direct signifies if this is the config path directly specified by the user,
// versus a file/dir found by recursing into that path
func (c *Config) resolve(path string, direct bool) error {
i, err := os.Stat(path)
if err != nil {
return nil
}
if !i.IsDir() {
c.addFile(path, direct)
return nil
}
paths, err := readDirNames(path)
if err != nil {
return fmt.Errorf("problem while reading directory %s: %s", path, err)
}
for _, p := range paths {
err := c.resolve(filepath.Join(path, p), false)
if err != nil {
return err
}
}
return nil
}
func (c *Config) addFile(path string, direct bool) error {
ext := filepath.Ext(path)
if !direct && ext != ".yaml" && ext != ".yml" {
return nil
}
ap, err := filepath.Abs(path)
if err != nil {
return err
}
c.files = append(c.files, ap)
return nil
}
func (c *Config) parseRaw(b []byte) error {
var m map[interface{}]interface{}
err := yaml.Unmarshal(b, &m)
if err != nil {
return err
}
c.Settings = m
return nil
}
func (c *Config) parse() error {
var m map[interface{}]interface{}
for _, path := range c.files {
b, err := ioutil.ReadFile(path)
if err != nil {
return err
}
var nm map[interface{}]interface{}
err = yaml.Unmarshal(b, &nm)
if err != nil {
return err
}
// We need to use WithAppendSlice so that firewall rules in separate
// files are appended together
err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
m = nm
if err != nil {
return err
}
}
c.Settings = m
return nil
}
func readDirNames(path string) ([]string, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
paths, err := f.Readdirnames(-1)
f.Close()
if err != nil {
return nil, err
}
sort.Strings(paths)
return paths, nil
}
func configLogger(c *Config) error {
// set up our logging level
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
if err != nil {
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
}
c.l.SetLevel(logLevel)
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
timestampFormat := c.GetString("logging.timestamp_format", "")
fullTimestamp := (timestampFormat != "")
if timestampFormat == "" {
timestampFormat = time.RFC3339
}
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
switch logFormat {
case "text":
c.l.Formatter = &logrus.TextFormatter{
TimestampFormat: timestampFormat,
FullTimestamp: fullTimestamp,
DisableTimestamp: disableTimestamp,
}
case "json":
c.l.Formatter = &logrus.JSONFormatter{
TimestampFormat: timestampFormat,
DisableTimestamp: disableTimestamp,
}
default:
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
}
return nil
}

358
config/config.go Normal file
View File

@ -0,0 +1,358 @@
package config
import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
"os/signal"
"path/filepath"
"sort"
"strconv"
"strings"
"syscall"
"time"
"github.com/imdario/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)
type C struct {
path string
files []string
Settings map[interface{}]interface{}
oldSettings map[interface{}]interface{}
callbacks []func(*C)
l *logrus.Logger
}
func NewC(l *logrus.Logger) *C {
return &C{
Settings: make(map[interface{}]interface{}),
l: l,
}
}
// Load will find all yaml files within path and load them in lexical order
func (c *C) Load(path string) error {
c.path = path
c.files = make([]string, 0)
err := c.resolve(path, true)
if err != nil {
return err
}
if len(c.files) == 0 {
return fmt.Errorf("no config files found at %s", path)
}
sort.Strings(c.files)
err = c.parse()
if err != nil {
return err
}
return nil
}
func (c *C) LoadString(raw string) error {
if raw == "" {
return errors.New("Empty configuration")
}
return c.parseRaw([]byte(raw))
}
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
// here should decide if they need to make a change to the current process before making the change. HasChanged can be
// used to help decide if a change is necessary.
// These functions should return quickly or spawn their own go routine if they will take a while
func (c *C) RegisterReloadCallback(f func(*C)) {
c.callbacks = append(c.callbacks, f)
}
// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
// k in both the old and new settings will be serialized, the result of the string comparison is returned.
// If k is an empty string the entire config is tested.
// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
// there is change when there actually wasn't any.
func (c *C) HasChanged(k string) bool {
if c.oldSettings == nil {
return false
}
var (
nv interface{}
ov interface{}
)
if k == "" {
nv = c.Settings
ov = c.oldSettings
k = "all settings"
} else {
nv = c.get(k, c.Settings)
ov = c.get(k, c.oldSettings)
}
newVals, err := yaml.Marshal(nv)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
}
oldVals, err := yaml.Marshal(ov)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
}
return string(newVals) != string(oldVals)
}
// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
// original path provided to Load. The old settings are shallow copied for change detection after the reload.
func (c *C) CatchHUP(ctx context.Context) {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGHUP)
go func() {
for {
select {
case <-ctx.Done():
signal.Stop(ch)
close(ch)
return
case <-ch:
c.l.Info("Caught HUP, reloading config")
c.ReloadConfig()
}
}
}()
}
func (c *C) ReloadConfig() {
c.oldSettings = make(map[interface{}]interface{})
for k, v := range c.Settings {
c.oldSettings[k] = v
}
err := c.Load(c.path)
if err != nil {
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
return
}
for _, v := range c.callbacks {
v(c)
}
}
// GetString will get the string for k or return the default d if not found or invalid
func (c *C) GetString(k, d string) string {
r := c.Get(k)
if r == nil {
return d
}
return fmt.Sprintf("%v", r)
}
// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
func (c *C) GetStringSlice(k string, d []string) []string {
r := c.Get(k)
if r == nil {
return d
}
rv, ok := r.([]interface{})
if !ok {
return d
}
v := make([]string, len(rv))
for i := 0; i < len(v); i++ {
v[i] = fmt.Sprintf("%v", rv[i])
}
return v
}
// GetMap will get the map for k or return the default d if not found or invalid
func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
r := c.Get(k)
if r == nil {
return d
}
v, ok := r.(map[interface{}]interface{})
if !ok {
return d
}
return v
}
// GetInt will get the int for k or return the default d if not found or invalid
func (c *C) GetInt(k string, d int) int {
r := c.GetString(k, strconv.Itoa(d))
v, err := strconv.Atoi(r)
if err != nil {
return d
}
return v
}
// GetBool will get the bool for k or return the default d if not found or invalid
func (c *C) GetBool(k string, d bool) bool {
r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
v, err := strconv.ParseBool(r)
if err != nil {
switch r {
case "y", "yes":
return true
case "n", "no":
return false
}
return d
}
return v
}
// GetDuration will get the duration for k or return the default d if not found or invalid
func (c *C) GetDuration(k string, d time.Duration) time.Duration {
r := c.GetString(k, "")
v, err := time.ParseDuration(r)
if err != nil {
return d
}
return v
}
func (c *C) Get(k string) interface{} {
return c.get(k, c.Settings)
}
func (c *C) IsSet(k string) bool {
return c.get(k, c.Settings) != nil
}
func (c *C) get(k string, v interface{}) interface{} {
parts := strings.Split(k, ".")
for _, p := range parts {
m, ok := v.(map[interface{}]interface{})
if !ok {
return nil
}
v, ok = m[p]
if !ok {
return nil
}
}
return v
}
// direct signifies if this is the config path directly specified by the user,
// versus a file/dir found by recursing into that path
func (c *C) resolve(path string, direct bool) error {
i, err := os.Stat(path)
if err != nil {
return nil
}
if !i.IsDir() {
c.addFile(path, direct)
return nil
}
paths, err := readDirNames(path)
if err != nil {
return fmt.Errorf("problem while reading directory %s: %s", path, err)
}
for _, p := range paths {
err := c.resolve(filepath.Join(path, p), false)
if err != nil {
return err
}
}
return nil
}
func (c *C) addFile(path string, direct bool) error {
ext := filepath.Ext(path)
if !direct && ext != ".yaml" && ext != ".yml" {
return nil
}
ap, err := filepath.Abs(path)
if err != nil {
return err
}
c.files = append(c.files, ap)
return nil
}
func (c *C) parseRaw(b []byte) error {
var m map[interface{}]interface{}
err := yaml.Unmarshal(b, &m)
if err != nil {
return err
}
c.Settings = m
return nil
}
func (c *C) parse() error {
var m map[interface{}]interface{}
for _, path := range c.files {
b, err := ioutil.ReadFile(path)
if err != nil {
return err
}
var nm map[interface{}]interface{}
err = yaml.Unmarshal(b, &nm)
if err != nil {
return err
}
// We need to use WithAppendSlice so that firewall rules in separate
// files are appended together
err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
m = nm
if err != nil {
return err
}
}
c.Settings = m
return nil
}
func readDirNames(path string) ([]string, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
paths, err := f.Readdirnames(-1)
f.Close()
if err != nil {
return nil, err
}
sort.Strings(paths)
return paths, nil
}

View File

@ -1,4 +1,4 @@
package nebula
package config
import (
"io/ioutil"
@ -7,19 +7,20 @@ import (
"testing"
"time"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func TestConfig_Load(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
dir, err := ioutil.TempDir("", "config-test")
// invalid yaml
c := NewConfig(l)
c := NewC(l)
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
// simple multi config merge
c = NewConfig(l)
c = NewC(l)
os.RemoveAll(dir)
os.Mkdir(dir, 0755)
@ -41,9 +42,9 @@ func TestConfig_Load(t *testing.T) {
}
func TestConfig_Get(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
// test simple type
c := NewConfig(l)
c := NewC(l)
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
assert.Equal(t, "hi", c.Get("firewall.outbound"))
@ -57,15 +58,15 @@ func TestConfig_Get(t *testing.T) {
}
func TestConfig_GetStringSlice(t *testing.T) {
l := NewTestLogger()
c := NewConfig(l)
l := util.NewTestLogger()
c := NewC(l)
c.Settings["slice"] = []interface{}{"one", "two"}
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
}
func TestConfig_GetBool(t *testing.T) {
l := NewTestLogger()
c := NewConfig(l)
l := util.NewTestLogger()
c := NewC(l)
c.Settings["bool"] = true
assert.Equal(t, true, c.GetBool("bool", false))
@ -91,108 +92,22 @@ func TestConfig_GetBool(t *testing.T) {
assert.Equal(t, false, c.GetBool("bool", true))
}
func TestConfig_GetAllowList(t *testing.T) {
l := NewTestLogger()
c := NewConfig(l)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0": true,
}
r, err := c.GetAllowList("allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
assert.Nil(t, r)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": "abc",
}
r, err = c.GetAllowList("allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": true,
"10.0.0.0/8": false,
}
r, err = c.GetAllowList("allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
"fd00::/8": true,
"fd00:fd00::/16": false,
}
r, err = c.GetAllowList("allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
}
r, err = c.GetAllowList("allowlist", nil)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
"::/0": false,
"fd00::/8": true,
"fd00:fd00::/16": false,
}
r, err = c.GetAllowList("allowlist", nil)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
// Test interface names
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: "foo",
},
}
lr, err := c.GetLocalAllowList("allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
`eth.*`: true,
},
}
lr, err = c.GetLocalAllowList("allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
},
}
lr, err = c.GetLocalAllowList("allowlist")
if assert.NoError(t, err) {
assert.NotNil(t, lr)
}
}
func TestConfig_HasChanged(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
// No reload has occurred, return false
c := NewConfig(l)
c := NewC(l)
c.Settings["test"] = "hi"
assert.False(t, c.HasChanged(""))
// Test key change
c = NewConfig(l)
c = NewC(l)
c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "no"}
assert.True(t, c.HasChanged("test"))
assert.True(t, c.HasChanged(""))
// No key change
c = NewConfig(l)
c = NewC(l)
c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
assert.False(t, c.HasChanged("test"))
@ -200,13 +115,13 @@ func TestConfig_HasChanged(t *testing.T) {
}
func TestConfig_ReloadConfig(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
done := make(chan bool, 1)
dir, err := ioutil.TempDir("", "config-test")
assert.Nil(t, err)
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
c := NewConfig(l)
c := NewC(l)
assert.Nil(t, c.Load(dir))
assert.False(t, c.HasChanged("outer.inner"))
@ -215,7 +130,7 @@ func TestConfig_ReloadConfig(t *testing.T) {
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644)
c.RegisterReloadCallback(func(c *Config) {
c.RegisterReloadCallback(func(c *C) {
done <- true
})

View File

@ -6,6 +6,8 @@ import (
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
)
// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
@ -13,16 +15,16 @@ import (
type connectionManager struct {
hostMap *HostMap
in map[uint32]struct{}
in map[iputil.VpnIp]struct{}
inLock *sync.RWMutex
inCount int
out map[uint32]struct{}
out map[iputil.VpnIp]struct{}
outLock *sync.RWMutex
outCount int
TrafficTimer *SystemTimerWheel
intf *Interface
pendingDeletion map[uint32]int
pendingDeletion map[iputil.VpnIp]int
pendingDeletionLock *sync.RWMutex
pendingDeletionTimer *SystemTimerWheel
@ -36,15 +38,15 @@ type connectionManager struct {
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
nc := &connectionManager{
hostMap: intf.hostMap,
in: make(map[uint32]struct{}),
in: make(map[iputil.VpnIp]struct{}),
inLock: &sync.RWMutex{},
inCount: 0,
out: make(map[uint32]struct{}),
out: make(map[iputil.VpnIp]struct{}),
outLock: &sync.RWMutex{},
outCount: 0,
TrafficTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
intf: intf,
pendingDeletion: make(map[uint32]int),
pendingDeletion: make(map[iputil.VpnIp]int),
pendingDeletionLock: &sync.RWMutex{},
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
checkInterval: checkInterval,
@ -55,7 +57,7 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
return nc
}
func (n *connectionManager) In(ip uint32) {
func (n *connectionManager) In(ip iputil.VpnIp) {
n.inLock.RLock()
// If this already exists, return
if _, ok := n.in[ip]; ok {
@ -68,7 +70,7 @@ func (n *connectionManager) In(ip uint32) {
n.inLock.Unlock()
}
func (n *connectionManager) Out(ip uint32) {
func (n *connectionManager) Out(ip iputil.VpnIp) {
n.outLock.RLock()
// If this already exists, return
if _, ok := n.out[ip]; ok {
@ -87,9 +89,9 @@ func (n *connectionManager) Out(ip uint32) {
n.outLock.Unlock()
}
func (n *connectionManager) CheckIn(vpnIP uint32) bool {
func (n *connectionManager) CheckIn(vpnIp iputil.VpnIp) bool {
n.inLock.RLock()
if _, ok := n.in[vpnIP]; ok {
if _, ok := n.in[vpnIp]; ok {
n.inLock.RUnlock()
return true
}
@ -97,7 +99,7 @@ func (n *connectionManager) CheckIn(vpnIP uint32) bool {
return false
}
func (n *connectionManager) ClearIP(ip uint32) {
func (n *connectionManager) ClearIP(ip iputil.VpnIp) {
n.inLock.Lock()
n.outLock.Lock()
delete(n.in, ip)
@ -106,13 +108,13 @@ func (n *connectionManager) ClearIP(ip uint32) {
n.outLock.Unlock()
}
func (n *connectionManager) ClearPendingDeletion(ip uint32) {
func (n *connectionManager) ClearPendingDeletion(ip iputil.VpnIp) {
n.pendingDeletionLock.Lock()
delete(n.pendingDeletion, ip)
n.pendingDeletionLock.Unlock()
}
func (n *connectionManager) AddPendingDeletion(ip uint32) {
func (n *connectionManager) AddPendingDeletion(ip iputil.VpnIp) {
n.pendingDeletionLock.Lock()
if _, ok := n.pendingDeletion[ip]; ok {
n.pendingDeletion[ip] += 1
@ -123,7 +125,7 @@ func (n *connectionManager) AddPendingDeletion(ip uint32) {
n.pendingDeletionLock.Unlock()
}
func (n *connectionManager) checkPendingDeletion(ip uint32) bool {
func (n *connectionManager) checkPendingDeletion(ip iputil.VpnIp) bool {
n.pendingDeletionLock.RLock()
if _, ok := n.pendingDeletion[ip]; ok {
@ -134,8 +136,8 @@ func (n *connectionManager) checkPendingDeletion(ip uint32) bool {
return false
}
func (n *connectionManager) AddTrafficWatch(vpnIP uint32, seconds int) {
n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds))
func (n *connectionManager) AddTrafficWatch(vpnIp iputil.VpnIp, seconds int) {
n.TrafficTimer.Add(vpnIp, time.Second*time.Duration(seconds))
}
func (n *connectionManager) Start(ctx context.Context) {
@ -169,23 +171,23 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
break
}
vpnIP := ep.(uint32)
vpnIp := ep.(iputil.VpnIp)
// Check for traffic coming back in from this host.
traf := n.CheckIn(vpnIP)
traf := n.CheckIn(vpnIp)
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
if err != nil {
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
n.l.Debugf("Not found in hostmap: %s", vpnIp)
if !n.intf.disconnectInvalid {
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIp)
continue
}
}
if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
if n.handleInvalidCertificate(now, vpnIp, hostinfo) {
continue
}
@ -193,12 +195,12 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
// expired, just ignore.
if traf {
if n.l.Level >= logrus.DebugLevel {
n.l.WithField("vpnIp", IntIp(vpnIP)).
n.l.WithField("vpnIp", vpnIp).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
}
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIp)
continue
}
@ -208,12 +210,12 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
if hostinfo != nil && hostinfo.ConnectionState != nil {
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out)
n.intf.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, p, nb, out)
} else {
hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", vpnIp)
}
n.AddPendingDeletion(vpnIP)
n.AddPendingDeletion(vpnIp)
}
}
@ -226,38 +228,38 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
break
}
vpnIP := ep.(uint32)
vpnIp := ep.(iputil.VpnIp)
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
if err != nil {
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
n.l.Debugf("Not found in hostmap: %s", vpnIp)
if !n.intf.disconnectInvalid {
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIp)
continue
}
}
if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
if n.handleInvalidCertificate(now, vpnIp, hostinfo) {
continue
}
// If we saw an incoming packets from this ip and peer's certificate is not
// expired, just ignore.
traf := n.CheckIn(vpnIP)
traf := n.CheckIn(vpnIp)
if traf {
n.l.WithField("vpnIp", IntIp(vpnIP)).
n.l.WithField("vpnIp", vpnIp).
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
Debug("Tunnel status")
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIp)
continue
}
// If it comes around on deletion wheel and hasn't resolved itself, delete
if n.checkPendingDeletion(vpnIP) {
if n.checkPendingDeletion(vpnIp) {
cn := ""
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
cn = hostinfo.ConnectionState.peerCert.Details.Name
@ -267,22 +269,22 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
WithField("certName", cn).
Info("Tunnel status")
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIp)
// TODO: This is only here to let tests work. Should do proper mocking
if n.intf.lightHouse != nil {
n.intf.lightHouse.DeleteVpnIP(vpnIP)
n.intf.lightHouse.DeleteVpnIp(vpnIp)
}
n.hostMap.DeleteHostInfo(hostinfo)
} else {
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIp)
}
}
}
// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32, hostinfo *HostInfo) bool {
func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIp iputil.VpnIp, hostinfo *HostInfo) bool {
if !n.intf.disconnectInvalid {
return false
}
@ -298,7 +300,7 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32
}
fingerprint, _ := remoteCert.Sha256Sum()
n.l.WithField("vpnIp", IntIp(vpnIP)).WithError(err).
n.l.WithField("vpnIp", vpnIp).WithError(err).
WithField("certName", remoteCert.Details.Name).
WithField("fingerprint", fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
@ -307,7 +309,7 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32
n.intf.sendCloseTunnel(hostinfo)
n.intf.closeTunnel(hostinfo, false)
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
n.ClearIP(vpnIp)
n.ClearPendingDeletion(vpnIp)
return true
}

View File

@ -10,17 +10,20 @@ import (
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
var vpnIP uint32
var vpnIp iputil.VpnIp
func Test_NewConnectionManagerTest(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
@ -32,15 +35,15 @@ func Test_NewConnectionManagerTest(t *testing.T) {
rawCertificateNoKey: []byte{},
}
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
outside: &udpConn{},
outside: &udp.Conn{},
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
l: l,
}
now := time.Now()
@ -54,16 +57,16 @@ func Test_NewConnectionManagerTest(t *testing.T) {
out := make([]byte, mtu)
nc.HandleMonitorTick(now, p, nb, out)
// Add an ip we have established a connection w/ to hostmap
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
hostinfo := nc.hostMap.AddVpnIp(vpnIp)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
}
// We saw traffic out to vpnIP
nc.Out(vpnIP)
assert.NotContains(t, nc.pendingDeletion, vpnIP)
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
// We saw traffic out to vpnIp
nc.Out(vpnIp)
assert.NotContains(t, nc.pendingDeletion, vpnIp)
assert.Contains(t, nc.hostMap.Hosts, vpnIp)
// Move ahead 5s. Nothing should happen
next_tick := now.Add(5 * time.Second)
nc.HandleMonitorTick(next_tick, p, nb, out)
@ -73,20 +76,20 @@ func Test_NewConnectionManagerTest(t *testing.T) {
nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick)
// This host should now be up for deletion
assert.Contains(t, nc.pendingDeletion, vpnIP)
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
assert.Contains(t, nc.pendingDeletion, vpnIp)
assert.Contains(t, nc.hostMap.Hosts, vpnIp)
// Move ahead some more
next_tick = now.Add(45 * time.Second)
nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick)
// The host should be evicted
assert.NotContains(t, nc.pendingDeletion, vpnIP)
assert.NotContains(t, nc.hostMap.Hosts, vpnIP)
assert.NotContains(t, nc.pendingDeletion, vpnIp)
assert.NotContains(t, nc.hostMap.Hosts, vpnIp)
}
func Test_NewConnectionManagerTest2(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@ -101,15 +104,15 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
rawCertificateNoKey: []byte{},
}
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
outside: &udpConn{},
outside: &udp.Conn{},
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
l: l,
}
now := time.Now()
@ -123,16 +126,16 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
out := make([]byte, mtu)
nc.HandleMonitorTick(now, p, nb, out)
// Add an ip we have established a connection w/ to hostmap
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
hostinfo := nc.hostMap.AddVpnIp(vpnIp)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
}
// We saw traffic out to vpnIP
nc.Out(vpnIP)
assert.NotContains(t, nc.pendingDeletion, vpnIP)
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
// We saw traffic out to vpnIp
nc.Out(vpnIp)
assert.NotContains(t, nc.pendingDeletion, vpnIp)
assert.Contains(t, nc.hostMap.Hosts, vpnIp)
// Move ahead 5s. Nothing should happen
next_tick := now.Add(5 * time.Second)
nc.HandleMonitorTick(next_tick, p, nb, out)
@ -142,17 +145,17 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick)
// This host should now be up for deletion
assert.Contains(t, nc.pendingDeletion, vpnIP)
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
assert.Contains(t, nc.pendingDeletion, vpnIp)
assert.Contains(t, nc.hostMap.Hosts, vpnIp)
// We heard back this time
nc.In(vpnIP)
nc.In(vpnIp)
// Move ahead some more
next_tick = now.Add(45 * time.Second)
nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick)
// The host should be evicted
assert.NotContains(t, nc.pendingDeletion, vpnIP)
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
assert.NotContains(t, nc.pendingDeletion, vpnIp)
assert.Contains(t, nc.hostMap.Hosts, vpnIp)
}
@ -161,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Disconnect only if disconnectInvalid: true is set.
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
now := time.Now()
l := NewTestLogger()
l := util.NewTestLogger()
ipNet := net.IPNet{
IP: net.IPv4(172, 1, 1, 2),
Mask: net.IPMask{255, 255, 255, 0},
@ -210,15 +213,15 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
rawCertificateNoKey: []byte{},
}
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
outside: &udpConn{},
outside: &udp.Conn{},
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
l: l,
disconnectInvalid: true,
caPool: ncp,
@ -229,7 +232,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
defer cancel()
nc := newConnectionManager(ctx, l, ifce, 5, 10)
ifce.connectionManager = nc
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
hostinfo := nc.hostMap.AddVpnIp(vpnIp)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
peerCert: &peerCert,
@ -240,13 +243,13 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
// Check if to disconnect with invalid certificate.
// Should be alive.
nextTick := now.Add(45 * time.Second)
destroyed := nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
destroyed := nc.handleInvalidCertificate(nextTick, vpnIp, hostinfo)
assert.False(t, destroyed)
// Move ahead 61s.
// Check if to disconnect with invalid certificate.
// Should be disconnected.
nextTick = now.Add(61 * time.Second)
destroyed = nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
destroyed = nc.handleInvalidCertificate(nextTick, vpnIp, hostinfo)
assert.True(t, destroyed)
}

View File

@ -10,6 +10,9 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
@ -25,14 +28,14 @@ type Control struct {
}
type ControlHostInfo struct {
VpnIP net.IP `json:"vpnIp"`
VpnIp net.IP `json:"vpnIp"`
LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []*udpAddr `json:"remoteAddrs"`
RemoteAddrs []*udp.Addr `json:"remoteAddrs"`
CachedPackets int `json:"cachedPackets"`
Cert *cert.NebulaCertificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"`
CurrentRemote *udpAddr `json:"currentRemote"`
CurrentRemote *udp.Addr `json:"currentRemote"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
@ -95,8 +98,8 @@ func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
}
}
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
var hm *HostMap
if pending {
hm = c.f.handshakeManager.pendingHostMap
@ -104,7 +107,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
hm = c.f.hostMap
}
h, err := hm.QueryVpnIP(vpnIP)
h, err := hm.QueryVpnIp(vpnIp)
if err != nil {
return nil
}
@ -114,8 +117,8 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
}
// SetRemoteForTunnel forces a tunnel to use a specific remote
func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
if err != nil {
return nil
}
@ -126,15 +129,15 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf
}
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
if err != nil {
return false
}
if !localOnly {
c.f.send(
closeTunnel,
header.CloseTunnel,
0,
hostInfo.ConnectionState,
hostInfo,
@ -156,16 +159,16 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
c.f.hostMap.Lock()
for _, h := range c.f.hostMap.Hosts {
if excludeLighthouses {
if _, ok := c.f.lightHouse.lighthouses[h.hostId]; ok {
if _, ok := c.f.lightHouse.lighthouses[h.vpnIp]; ok {
continue
}
}
if h.ConnectionState.ready {
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h, true)
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
closed++
}
@ -176,7 +179,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
chi := ControlHostInfo{
VpnIP: int2ip(h.hostId),
VpnIp: h.vpnIp.ToIP(),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),

View File

@ -8,17 +8,19 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
l := NewTestLogger()
func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := util.NewTestLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := NewUDPAddr(int2ip(100), 4444)
remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
@ -48,7 +50,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
remotes := NewRemoteList()
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
hm.Add(ip2int(ipNet.IP), &HostInfo{
hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{
remote: remote1,
remotes: remotes,
ConnectionState: &ConnectionState{
@ -56,10 +58,10 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
},
remoteIndexId: 200,
localIndexId: 201,
hostId: ip2int(ipNet.IP),
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
})
hm.Add(ip2int(ipNet2.IP), &HostInfo{
hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{
remote: remote1,
remotes: remotes,
ConnectionState: &ConnectionState{
@ -67,7 +69,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
},
remoteIndexId: 200,
localIndexId: 201,
hostId: ip2int(ipNet2.IP),
vpnIp: iputil.Ip2VpnIp(ipNet2.IP),
})
c := Control{
@ -77,26 +79,26 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
l: logrus.New(),
}
thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
expectedInfo := ControlHostInfo{
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
VpnIp: net.IPv4(1, 2, 3, 4).To4(),
LocalIndex: 201,
RemoteIndex: 200,
RemoteAddrs: []*udpAddr{remote2, remote1},
RemoteAddrs: []*udp.Addr{remote2, remote1},
CachedPackets: 0,
Cert: crt.Copy(),
MessageCounter: 0,
CurrentRemote: NewUDPAddr(int2ip(100), 4444),
CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
}
// Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() {
thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
})
}

View File

@ -8,12 +8,15 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
// WaitForTypeByIndex will pipe all messages from this control device into the pipeTo control device
// returning after a message matching the criteria has been piped
func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) {
h := &Header{}
func (c *Control) WaitForType(msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
h := &header.H{}
for {
p := c.f.outside.Get(true)
if err := h.Parse(p.Data); err != nil {
@ -28,8 +31,8 @@ func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSu
// WaitForTypeByIndex is similar to WaitForType except it adds an index check
// Useful if you have many nodes communicating and want to wait to find a specific nodes packet
func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) {
h := &Header{}
func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
h := &header.H{}
for {
p := c.f.outside.Get(true)
if err := h.Parse(p.Data); err != nil {
@ -46,12 +49,12 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType,
// This is necessary if you did not configure static hosts or are not running a lighthouse
func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp))
remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()
iVpnIp := ip2int(vpnIp)
iVpnIp := iputil.Ip2VpnIp(vpnIp)
if v4 := toAddr.IP.To4(); v4 != nil {
remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
} else {
@ -65,12 +68,12 @@ func (c *Control) GetFromTun(block bool) []byte {
}
// GetFromUDP will pull a udp packet off the udp side of nebula
func (c *Control) GetFromUDP(block bool) *UdpPacket {
func (c *Control) GetFromUDP(block bool) *udp.Packet {
return c.f.outside.Get(block)
}
func (c *Control) GetUDPTxChan() <-chan *UdpPacket {
return c.f.outside.txPackets
func (c *Control) GetUDPTxChan() <-chan *udp.Packet {
return c.f.outside.TxPackets
}
func (c *Control) GetTunTxChan() <-chan []byte {
@ -78,7 +81,7 @@ func (c *Control) GetTunTxChan() <-chan []byte {
}
// InjectUDPPacket will inject a packet into the udp side of nebula
func (c *Control) InjectUDPPacket(p *UdpPacket) {
func (c *Control) InjectUDPPacket(p *udp.Packet) {
c.f.outside.Send(p)
}
@ -115,11 +118,11 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
}
func (c *Control) GetUDPAddr() string {
return c.f.outside.addr.String()
return c.f.outside.Addr.String()
}
func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)]
hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)]
if !ok {
return false
}

View File

@ -8,6 +8,8 @@ import (
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
)
// This whole thing should be rewritten to use context
@ -44,8 +46,8 @@ func (d *dnsRecords) QueryCert(data string) string {
if ip == nil {
return ""
}
iip := ip2int(ip)
hostinfo, err := d.hostMap.QueryVpnIP(iip)
iip := iputil.Ip2VpnIp(ip)
hostinfo, err := d.hostMap.QueryVpnIp(iip)
if err != nil {
return ""
}
@ -109,7 +111,7 @@ func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m)
}
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(hostMap)
// attach request handler func
@ -117,7 +119,7 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
handleDnsRequest(l, w, r)
})
c.RegisterReloadCallback(func(c *Config) {
c.RegisterReloadCallback(func(c *config.C) {
reloadDns(l, c)
})
@ -126,11 +128,11 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
}
}
func getDnsServerAddr(c *Config) string {
func getDnsServerAddr(c *config.C) string {
return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
}
func startDns(l *logrus.Logger, c *Config) {
func startDns(l *logrus.Logger, c *config.C) {
dnsAddr = getDnsServerAddr(c)
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
l.WithField("dnsListener", dnsAddr).Infof("Starting DNS responder")
@ -141,7 +143,7 @@ func startDns(l *logrus.Logger, c *Config) {
}
}
func reloadDns(l *logrus.Logger, c *Config) {
func reloadDns(l *logrus.Logger, c *config.C) {
if dnsAddr == getDnsServerAddr(c) {
l.Debug("No DNS server config change detected")
return

View File

@ -10,6 +10,9 @@ import (
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
)
@ -37,7 +40,7 @@ func TestGoodHandshake(t *testing.T) {
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
badPacket := stage1Packet.Copy()
badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen]
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
myControl.InjectUDPPacket(badPacket)
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
@ -87,8 +90,8 @@ func TestWrongResponderHandshake(t *testing.T) {
t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
r.RouteForAllExitFunc(func(p *nebula.UdpPacket, c *nebula.Control) router.ExitType {
h := &nebula.Header{}
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
h := &header.H{}
err := h.Parse(p.Data)
if err != nil {
panic(err)
@ -115,8 +118,8 @@ func TestWrongResponderHandshake(t *testing.T) {
r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false), "My main hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), false), "My main hostmap should not contain evil")
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
//TODO: assert hostmaps for everyone

View File

@ -5,7 +5,6 @@ package e2e
import (
"crypto/rand"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
@ -19,7 +18,9 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
@ -82,10 +83,10 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
panic(err)
}
config := nebula.NewConfig(l)
config.LoadString(string(cb))
c := config.NewC(l)
c.LoadString(string(cb))
control, err := nebula.Main(config, false, "e2e-test", l, nil)
control, err := nebula.Main(c, false, "e2e-test", l, nil)
if err != nil {
panic(err)
@ -200,19 +201,6 @@ func x25519Keypair() ([]byte, []byte) {
return pubkey, privkey
}
func ip2int(ip []byte) uint32 {
if len(ip) == 16 {
return binary.BigEndian.Uint32(ip[12:16])
}
return binary.BigEndian.Uint32(ip)
}
func int2ip(nn uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, nn)
return ip
}
type doneCb func()
func deadline(t *testing.T, seconds time.Duration) doneCb {
@ -245,15 +233,15 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul
func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
// Get both host infos
hBinA := controlA.GetHostInfoByVpnIP(ip2int(vpnIpB), false)
assert.NotNil(t, hBinA, "Host B was not found by vpnIP in controlA")
hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false)
assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
hAinB := controlB.GetHostInfoByVpnIP(ip2int(vpnIpA), false)
assert.NotNil(t, hAinB, "Host A was not found by vpnIP in controlB")
hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false)
assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
// Check that both vpn and real addr are correct
assert.Equal(t, vpnIpB, hBinA.VpnIP, "Host B VpnIp is wrong in control A")
assert.Equal(t, vpnIpA, hAinB.VpnIP, "Host A VpnIp is wrong in control B")
assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")

View File

@ -11,6 +11,8 @@ import (
"sync"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
)
type R struct {
@ -41,7 +43,7 @@ const (
RouteAndExit ExitType = 2
)
type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType
type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType
func NewR(controls ...*nebula.Control) *R {
r := &R{
@ -79,7 +81,7 @@ func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
// OnceFrom will route a single packet from sender then return
// If the router doesn't have the nebula controller for that address, we panic
func (r *R) OnceFrom(sender *nebula.Control) {
r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType {
r.RouteExitFunc(sender, func(*udp.Packet, *nebula.Control) ExitType {
return RouteAndExit
})
}
@ -119,7 +121,7 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
// - routeAndExit: this call will return immediately after routing the last packet from sender
// - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
h := &nebula.Header{}
h := &header.H{}
for {
p := sender.GetFromUDP(true)
r.Lock()
@ -159,9 +161,9 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
// RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender
// If the router doesn't have the nebula controller for that address, we panic
func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) {
h := &nebula.Header{}
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType header.MessageType, subType header.MessageSubType) {
h := &header.H{}
r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
if err := h.Parse(p.Data); err != nil {
panic(err)
}
@ -181,7 +183,7 @@ func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr
finish = RouteAndExit
}
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
return finish
}
@ -215,7 +217,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
x, rx, _ := reflect.Select(sc)
r.Lock()
p := rx.Interface().(*nebula.UdpPacket)
p := rx.Interface().(*udp.Packet)
outAddr := cm[x].GetUDPAddr()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
@ -277,7 +279,7 @@ func (r *R) FlushAll() {
}
r.Lock()
p := rx.Interface().(*nebula.UdpPacket)
p := rx.Interface().(*udp.Packet)
outAddr := cm[x].GetUDPAddr()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
@ -292,7 +294,7 @@ func (r *R) FlushAll() {
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
// This is an internal router function, the caller must hold the lock
func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control {
func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
p.FromIp = newAddr.IP
p.FromPort = uint16(newAddr.Port)

View File

@ -4,7 +4,6 @@ import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net"
@ -12,22 +11,14 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
const (
fwProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
fwProtoTCP = 6
fwProtoUDP = 17
fwProtoICMP = 1
fwPortAny = 0 // Special value for matching `port: any`
fwPortFragment = -1 // Special value for matching `port: fragment`
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
)
const tcpACK = 0x10
@ -63,7 +54,7 @@ type Firewall struct {
DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own
localIps *CIDRTree
localIps *cidr.Tree4
rules string
rulesVersion uint16
@ -85,7 +76,7 @@ type firewallMetrics struct {
type FirewallConntrack struct {
sync.Mutex
Conns map[FirewallPacket]*conn
Conns map[firewall.Packet]*conn
TimerWheel *TimerWheel
}
@ -116,55 +107,13 @@ type FirewallRule struct {
Any bool
Hosts map[string]struct{}
Groups [][]string
CIDR *CIDRTree
CIDR *cidr.Tree4
}
// Even though ports are uint16, int32 maps are faster for lookup
// Plus we can use `-1` for fragment rules
type firewallPort map[int32]*FirewallCA
type FirewallPacket struct {
LocalIP uint32
RemoteIP uint32
LocalPort uint16
RemotePort uint16
Protocol uint8
Fragment bool
}
func (fp *FirewallPacket) Copy() *FirewallPacket {
return &FirewallPacket{
LocalIP: fp.LocalIP,
RemoteIP: fp.RemoteIP,
LocalPort: fp.LocalPort,
RemotePort: fp.RemotePort,
Protocol: fp.Protocol,
Fragment: fp.Fragment,
}
}
func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
var proto string
switch fp.Protocol {
case fwProtoTCP:
proto = "tcp"
case fwProtoICMP:
proto = "icmp"
case fwProtoUDP:
proto = "udp"
default:
proto = fmt.Sprintf("unknown %v", fp.Protocol)
}
return json.Marshal(m{
"LocalIP": int2ip(fp.LocalIP).String(),
"RemoteIP": int2ip(fp.RemoteIP).String(),
"LocalPort": fp.LocalPort,
"RemotePort": fp.RemotePort,
"Protocol": proto,
"Fragment": fp.Fragment,
})
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
//TODO: error on 0 duration
@ -184,7 +133,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
max = defaultTimeout
}
localIps := NewCIDRTree()
localIps := cidr.NewTree4()
for _, ip := range c.Details.Ips {
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
}
@ -195,7 +144,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
return &Firewall{
Conntrack: &FirewallConntrack{
Conns: make(map[FirewallPacket]*conn),
Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel(min, max),
},
InRules: newFirewallTable(),
@ -220,7 +169,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
}
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
fw := NewFirewall(
l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
@ -278,13 +227,13 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
}
switch proto {
case fwProtoTCP:
case firewall.ProtoTCP:
fp = ft.TCP
case fwProtoUDP:
case firewall.ProtoUDP:
fp = ft.UDP
case fwProtoICMP:
case firewall.ProtoICMP:
fp = ft.ICMP
case fwProtoAny:
case firewall.ProtoAny:
fp = ft.AnyProto
default:
return fmt.Errorf("unknown protocol %v", proto)
@ -299,7 +248,7 @@ func (f *Firewall) GetRuleHash() string {
return hex.EncodeToString(sum[:])
}
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
var table string
if inbound {
table = "firewall.inbound"
@ -307,7 +256,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
table = "firewall.outbound"
}
r := config.Get(table)
r := c.Get(table)
if r == nil {
return nil
}
@ -362,13 +311,13 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
var proto uint8
switch r.Proto {
case "any":
proto = fwProtoAny
proto = firewall.ProtoAny
case "tcp":
proto = fwProtoTCP
proto = firewall.ProtoTCP
case "udp":
proto = fwProtoUDP
proto = firewall.ProtoUDP
case "icmp":
proto = fwProtoICMP
proto = firewall.ProtoICMP
default:
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
}
@ -396,7 +345,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error {
func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming, h, caPool, localCache) {
return nil
@ -410,7 +359,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
}
} else {
// Simple case: Certificate has one IP and no subnets
if fp.RemoteIP != h.hostId {
if fp.RemoteIP != h.vpnIp {
f.metrics(incoming).droppedRemoteIP.Inc(1)
return ErrInvalidRemoteIP
}
@ -462,7 +411,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
}
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool {
func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
@ -520,14 +469,14 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
}
switch fp.Protocol {
case fwProtoTCP:
case firewall.ProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout)
if incoming {
f.checkTCPRTT(c, packet)
} else {
setTCPRTTTracking(c, packet)
}
case fwProtoUDP:
case firewall.ProtoUDP:
c.Expires = time.Now().Add(f.UDPTimeout)
default:
c.Expires = time.Now().Add(f.DefaultTimeout)
@ -542,17 +491,17 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
return true
}
func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
var timeout time.Duration
c := &conn{}
switch fp.Protocol {
case fwProtoTCP:
case firewall.ProtoTCP:
timeout = f.TCPTimeout
if !incoming {
setTCPRTTTracking(c, packet)
}
case fwProtoUDP:
case firewall.ProtoUDP:
timeout = f.UDPTimeout
default:
timeout = f.DefaultTimeout
@ -575,7 +524,7 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock!
func (f *Firewall) evict(p FirewallPacket) {
func (f *Firewall) evict(p firewall.Packet) {
//TODO: report a stat if the tcp rtt tracking was never resolved?
// Are we still tracking this conn?
conntrack := f.Conntrack
@ -596,21 +545,21 @@ func (f *Firewall) evict(p FirewallPacket) {
delete(conntrack.Conns, p)
}
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
if ft.AnyProto.match(p, incoming, c, caPool) {
return true
}
switch p.Protocol {
case fwProtoTCP:
case firewall.ProtoTCP:
if ft.TCP.match(p, incoming, c, caPool) {
return true
}
case fwProtoUDP:
case firewall.ProtoUDP:
if ft.UDP.match(p, incoming, c, caPool) {
return true
}
case fwProtoICMP:
case firewall.ProtoICMP:
if ft.ICMP.match(p, incoming, c, caPool) {
return true
}
@ -640,7 +589,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
return nil
}
func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
// We don't have any allowed ports, bail
if fp == nil {
return false
@ -649,7 +598,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
var port int32
if p.Fragment {
port = fwPortFragment
port = firewall.PortFragment
} else if incoming {
port = int32(p.LocalPort)
} else {
@ -660,7 +609,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
return true
}
return fp[fwPortAny].match(p, c, caPool)
return fp[firewall.PortAny].match(p, c, caPool)
}
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
@ -668,7 +617,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
return &FirewallRule{
Hosts: make(map[string]struct{}),
Groups: make([][]string, 0),
CIDR: NewCIDRTree(),
CIDR: cidr.NewTree4(),
}
}
@ -703,7 +652,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
return nil
}
func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
if fc == nil {
return false
}
@ -736,7 +685,7 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
// If it's any we need to wipe out any pre-existing rules to save on memory
fr.Groups = make([][]string, 0)
fr.Hosts = make(map[string]struct{})
fr.CIDR = NewCIDRTree()
fr.CIDR = cidr.NewTree4()
} else {
if len(groups) > 0 {
fr.Groups = append(fr.Groups, groups)
@ -776,7 +725,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
return false
}
func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
if fr == nil {
return false
}
@ -885,12 +834,12 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
func parsePort(s string) (startPort, endPort int32, err error) {
if s == "any" {
startPort = fwPortAny
endPort = fwPortAny
startPort = firewall.PortAny
endPort = firewall.PortAny
} else if s == "fragment" {
startPort = fwPortFragment
endPort = fwPortFragment
startPort = firewall.PortFragment
endPort = firewall.PortFragment
} else if strings.Contains(s, `-`) {
sPorts := strings.SplitN(s, `-`, 2)
@ -914,8 +863,8 @@ func parsePort(s string) (startPort, endPort int32, err error) {
startPort = int32(rStartPort)
endPort = int32(rEndPort)
if startPort == fwPortAny {
endPort = fwPortAny
if startPort == firewall.PortAny {
endPort = firewall.PortAny
}
} else {
@ -968,54 +917,3 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
c.Seq = 0
return true
}
// ConntrackCache is used as a local routine cache to know if a given flow
// has been seen in the conntrack table.
type ConntrackCache map[FirewallPacket]struct{}
type ConntrackCacheTicker struct {
cacheV uint64
cacheTick uint64
cache ConntrackCache
}
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
if d == 0 {
return nil
}
c := &ConntrackCacheTicker{
cache: ConntrackCache{},
}
go c.tick(d)
return c
}
func (c *ConntrackCacheTicker) tick(d time.Duration) {
for {
time.Sleep(d)
atomic.AddUint64(&c.cacheTick, 1)
}
}
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
if c == nil {
return nil
}
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
}
c.cache = make(ConntrackCache, ll)
}
}
return c.cache
}

59
firewall/cache.go Normal file
View File

@ -0,0 +1,59 @@
package firewall
import (
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
)
// ConntrackCache is used as a local routine cache to know if a given flow
// has been seen in the conntrack table.
type ConntrackCache map[Packet]struct{}
type ConntrackCacheTicker struct {
cacheV uint64
cacheTick uint64
cache ConntrackCache
}
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
if d == 0 {
return nil
}
c := &ConntrackCacheTicker{
cache: ConntrackCache{},
}
go c.tick(d)
return c
}
func (c *ConntrackCacheTicker) tick(d time.Duration) {
for {
time.Sleep(d)
atomic.AddUint64(&c.cacheTick, 1)
}
}
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
if c == nil {
return nil
}
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
}
c.cache = make(ConntrackCache, ll)
}
}
return c.cache
}

62
firewall/packet.go Normal file
View File

@ -0,0 +1,62 @@
package firewall
import (
"encoding/json"
"fmt"
"github.com/slackhq/nebula/iputil"
)
type m map[string]interface{}
const (
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
ProtoTCP = 6
ProtoUDP = 17
ProtoICMP = 1
PortAny = 0 // Special value for matching `port: any`
PortFragment = -1 // Special value for matching `port: fragment`
)
type Packet struct {
LocalIP iputil.VpnIp
RemoteIP iputil.VpnIp
LocalPort uint16
RemotePort uint16
Protocol uint8
Fragment bool
}
func (fp *Packet) Copy() *Packet {
return &Packet{
LocalIP: fp.LocalIP,
RemoteIP: fp.RemoteIP,
LocalPort: fp.LocalPort,
RemotePort: fp.RemotePort,
Protocol: fp.Protocol,
Fragment: fp.Fragment,
}
}
func (fp Packet) MarshalJSON() ([]byte, error) {
var proto string
switch fp.Protocol {
case ProtoTCP:
proto = "tcp"
case ProtoICMP:
proto = "icmp"
case ProtoUDP:
proto = "udp"
default:
proto = fmt.Sprintf("unknown %v", fp.Protocol)
}
return json.Marshal(m{
"LocalIP": fp.LocalIP.String(),
"RemoteIP": fp.RemoteIP.String(),
"LocalPort": fp.LocalPort,
"RemotePort": fp.RemotePort,
"Protocol": proto,
"Fragment": fp.Fragment,
})
}

View File

@ -11,11 +11,15 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func TestNewFirewall(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
c := &cert.NebulaCertificate{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
conntrack := fw.Conntrack
@ -54,7 +58,7 @@ func TestNewFirewall(t *testing.T) {
}
func TestFirewall_AddRule(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
@ -65,92 +69,80 @@ func TestFirewall_AddRule(t *testing.T) {
_, ti, _ := net.ParseCIDR("1.2.3.4/32")
assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
assert.False(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
assert.False(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
// Set any and clear fields
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP)))
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
// run twice just to make sure
//TODO: these ANY rules should clear the CA firewall portion
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", ""))
}
func TestFirewall_Drop(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
ip2int(net.IPv4(1, 2, 3, 4)),
p := firewall.Packet{
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
10,
90,
fwProtoUDP,
firewall.ProtoUDP,
false,
}
@ -172,12 +164,12 @@ func TestFirewall_Drop(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
hostId: ip2int(ipNet.IP),
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool()
// Drop outbound
@ -190,34 +182,34 @@ func TestFirewall_Drop(t *testing.T) {
// test remote mismatch
oldRemote := p.RemoteIP
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
}
@ -237,14 +229,14 @@ func BenchmarkFirewallTable_match(b *testing.B) {
b.Run("fail on proto", func(b *testing.B) {
c := &cert.NebulaCertificate{}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp)
ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)
}
})
b.Run("fail on port", func(b *testing.B) {
c := &cert.NebulaCertificate{}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp)
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)
}
})
@ -258,7 +250,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
},
}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
}
})
@ -270,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
},
}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
}
})
@ -282,12 +274,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
},
}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
}
})
b.Run("pass on ip", func(b *testing.B) {
ip := ip2int(net.IPv4(172, 1, 1, 1))
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
@ -295,14 +287,14 @@ func BenchmarkFirewallTable_match(b *testing.B) {
},
}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
}
})
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
b.Run("pass on ip with any port", func(b *testing.B) {
ip := ip2int(net.IPv4(172, 1, 1, 1))
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
@ -310,22 +302,22 @@ func BenchmarkFirewallTable_match(b *testing.B) {
},
}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
}
})
}
func TestFirewall_Drop2(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
ip2int(net.IPv4(1, 2, 3, 4)),
p := firewall.Packet{
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
10,
90,
fwProtoUDP,
firewall.ProtoUDP,
false,
}
@ -345,7 +337,7 @@ func TestFirewall_Drop2(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
hostId: ip2int(ipNet.IP),
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
}
h.CreateRemoteCIDR(&c)
@ -364,7 +356,7 @@ func TestFirewall_Drop2(t *testing.T) {
h1.CreateRemoteCIDR(&c1)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
@ -375,16 +367,16 @@ func TestFirewall_Drop2(t *testing.T) {
}
func TestFirewall_Drop3(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
ip2int(net.IPv4(1, 2, 3, 4)),
p := firewall.Packet{
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
1,
1,
fwProtoUDP,
firewall.ProtoUDP,
false,
}
@ -411,7 +403,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c1,
},
hostId: ip2int(ipNet.IP),
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
}
h1.CreateRemoteCIDR(&c1)
@ -426,7 +418,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c2,
},
hostId: ip2int(ipNet.IP),
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
}
h2.CreateRemoteCIDR(&c2)
@ -441,13 +433,13 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c3,
},
hostId: ip2int(ipNet.IP),
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
}
h3.CreateRemoteCIDR(&c3)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
cp := cert.NewCAPool()
// c1 should pass because host match
@ -461,16 +453,16 @@ func TestFirewall_Drop3(t *testing.T) {
}
func TestFirewall_DropConntrackReload(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
ip2int(net.IPv4(1, 2, 3, 4)),
p := firewall.Packet{
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
10,
90,
fwProtoUDP,
firewall.ProtoUDP,
false,
}
@ -492,12 +484,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
hostId: ip2int(ipNet.IP),
vpnIp: iputil.Ip2VpnIp(ipNet.IP),
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool()
// Drop outbound
@ -510,7 +502,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@ -519,7 +511,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@ -643,28 +635,28 @@ func Test_parsePort(t *testing.T) {
}
func TestNewFirewallFromConfig(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
// Test a bad rule definition
c := &cert.NebulaCertificate{}
conf := NewConfig(l)
conf := config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
_, err := NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = NewConfig(l)
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = NewConfig(l)
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = NewConfig(l)
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
@ -674,91 +666,91 @@ func TestNewFirewallFromConfig(t *testing.T) {
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = NewConfig(l)
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = NewConfig(l)
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
// Test both group and groups
conf = NewConfig(l)
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
}
func TestAddFirewallRulesFromConfig(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
// Test adding tcp rule
conf := NewConfig(l)
conf := config.NewC(l)
mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding udp rule
conf = NewConfig(l)
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding icmp rule
conf = NewConfig(l)
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding any rule
conf = NewConfig(l)
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding rule with ca_sha
conf = NewConfig(l)
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = NewConfig(l)
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
// Test single group
conf = NewConfig(l)
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
// Test single groups
conf = NewConfig(l)
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
// Test multiple AND groups
conf = NewConfig(l)
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
// Test Add error
conf = NewConfig(l)
conf = config.NewC(l)
mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
@ -857,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) {
}
func TestFirewall_convertRule(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
@ -929,6 +921,6 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
func resetConntrack(fw *Firewall) {
fw.Conntrack.Lock()
fw.Conntrack.Conns = map[FirewallPacket]*conn{}
fw.Conntrack.Conns = map[firewall.Packet]*conn{}
fw.Conntrack.Unlock()
}

View File

@ -1,11 +1,11 @@
package nebula
const (
handshakeIXPSK0 = 0
handshakeXXPSK0 = 1
import (
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
)
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
func HandleIncomingHandshake(f *Interface, addr *udp.Addr, packet []byte, h *header.H, hostinfo *HostInfo) {
// First remote allow list check before we know the vpnIp
if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) {
f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
@ -13,7 +13,7 @@ func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Head
}
switch h.Subtype {
case handshakeIXPSK0:
case header.HandshakeIXPSK0:
switch h.MessageCounter {
case 1:
ixHandshakeStage1(f, addr, packet, h)

View File

@ -6,13 +6,16 @@ import (
"github.com/flynn/noise"
"github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
// NOISE IX Handshakes
// This function constructs a handshake packet, but does not actually send it
// Sending is done by the handshake manager
func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
// This queries the lighthouse if we don't know a remote for the host
// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
// more quickly, effect is a quicker handshake.
@ -22,7 +25,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
err := f.handshakeManager.AddIndexHostInfo(hostinfo)
if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
f.l.WithError(err).WithField("vpnIp", vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return
}
@ -43,17 +46,17 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
hsBytes, err = proto.Marshal(hs)
if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
f.l.WithError(err).WithField("vpnIp", vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return
}
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1)
h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
atomic.AddUint64(&ci.atomicMessageCounter, 1)
msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
f.l.WithError(err).WithField("vpnIp", vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
}
@ -67,12 +70,12 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
hostinfo.handshakeStart = time.Now()
}
func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
func ixHandshakeStage1(f *Interface, addr *udp.Addr, packet []byte, h *header.H) {
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
@ -97,13 +100,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
Info("Invalid certificate from host")
return
}
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
issuer := remoteCert.Details.Issuer
if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) {
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
if vpnIp == f.myVpnIp {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -111,14 +114,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
return
}
if !f.lightHouse.remoteAllowList.Allow(vpnIP, addr.IP) {
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
if !f.lightHouse.remoteAllowList.Allow(vpnIp, addr.IP) {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
myIndex, err := generateIndex(f.l)
if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -130,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
ConnectionState: ci,
localIndexId: myIndex,
remoteIndexId: hs.Details.InitiatorIndex,
hostId: vpnIP,
vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time,
}
@ -138,7 +141,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
hostinfo.Lock()
defer hostinfo.Unlock()
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -153,7 +156,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
hsBytes, err := proto.Marshal(hs)
if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -161,17 +164,17 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
return
}
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -179,8 +182,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
return
}
hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:]))
copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:]))
copy(hostinfo.HandshakePacket[0], packet[header.Len:])
// Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection.
@ -195,12 +198,12 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
hostinfo.SetRemote(addr)
hostinfo.CreateRemoteCIDR(remoteCert)
// Only overwrite existing record if we should win the handshake race
overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP)
overwrite := vpnIp > f.myVpnIp
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
if err != nil {
switch err {
@ -214,27 +217,27 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to
// the preferred remote
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}
existing.Unlock()
hostinfo.Lock()
msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr)
if err != nil {
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
return
case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("oldHandshakeTime", existing.lastHandshakeTime).
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
@ -245,22 +248,22 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
Info("Handshake too old")
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return
case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
Error("Failed to add HostInfo due to localIndex collision")
return
case ErrExistingHandshake:
// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -271,7 +274,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -283,10 +286,10 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
}
// Do the send
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
err = f.outside.WriteTo(msg, addr)
if err != nil {
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -294,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else {
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -309,7 +312,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
return
}
func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool {
func ixHandshakeStage2(f *Interface, addr *udp.Addr, hostinfo *HostInfo, packet []byte, h *header.H) bool {
if hostinfo == nil {
// Nothing here to tear down, got a bogus stage 2 packet
return true
@ -318,14 +321,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo.Lock()
defer hostinfo.Unlock()
if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) {
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false
}
ci := hostinfo.ConnectionState
if ci.ready {
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Info("Handshake is already complete")
@ -333,16 +336,16 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to
// the preferred remote
f.SendMessageToVpnIp(test, testRequest, hostinfo.hostId, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}
// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
return false
}
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage")
@ -351,7 +354,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// near future
return false
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
@ -363,7 +366,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hs := &NebulaHandshake{}
err = proto.Unmarshal(msg, hs)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@ -372,7 +375,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Invalid certificate from host")
@ -380,14 +383,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
return true
}
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
issuer := remoteCert.Details.Issuer
// Ensure the right host responded
if vpnIP != hostinfo.hostId {
f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)).
if vpnIp != hostinfo.vpnIp {
f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
WithField("udpAddr", addr).WithField("certName", certName).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake")
@ -397,7 +400,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// Create a new hostinfo/handshake for the intended vpn ip
//TODO: this adds it to the timer wheel in a way that aggressively retries
newHostInfo := f.getOrHandshake(hostinfo.hostId)
newHostInfo := f.getOrHandshake(hostinfo.vpnIp)
newHostInfo.Lock()
// Block the current used address
@ -405,9 +408,9 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
newHostInfo.remotes.BlockRemote(addr)
// Get the correct remote list for the host we did handshake with
hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", IntIp(vpnIP)).
f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
Info("Blocked addresses for handshakes")
@ -418,7 +421,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo.ConnectionState.queueLock.Unlock()
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.hostId = vpnIP
hostinfo.vpnIp = vpnIp
f.sendCloseTunnel(hostinfo)
newHostInfo.Unlock()
@ -429,7 +432,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
ci.window.Update(f.l, 2)
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).

View File

@ -11,6 +11,9 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
const (
@ -39,7 +42,7 @@ type HandshakeManager struct {
pendingHostMap *HostMap
mainHostMap *HostMap
lightHouse *LightHouse
outside *udpConn
outside *udp.Conn
config HandshakeConfig
OutboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics
@ -47,18 +50,18 @@ type HandshakeManager struct {
metricTimedOut metrics.Counter
l *logrus.Logger
// can be used to trigger outbound handshake for the given vpnIP
trigger chan uint32
// can be used to trigger outbound handshake for the given vpnIp
trigger chan iputil.VpnIp
}
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udp.Conn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{
pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
mainHostMap: mainHostMap,
lightHouse: lightHouse,
outside: outside,
config: config,
trigger: make(chan uint32, config.triggerBuffer),
trigger: make(chan iputil.VpnIp, config.triggerBuffer),
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
messageMetrics: config.messageMetrics,
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
@ -67,7 +70,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
}
}
func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
clockSource := time.NewTicker(c.config.tryInterval)
defer clockSource.Stop()
@ -76,7 +79,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
case <-ctx.Done():
return
case vpnIP := <-c.trigger:
c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
c.l.WithField("vpnIp", vpnIP).Debug("HandshakeManager: triggered")
c.handleOutbound(vpnIP, f, true)
case now := <-clockSource.C:
c.NextOutboundHandshakeTimerTick(now, f)
@ -84,20 +87,20 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
}
}
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
c.OutboundHandshakeTimer.advance(now)
for {
ep := c.OutboundHandshakeTimer.Purge()
if ep == nil {
break
}
vpnIP := ep.(uint32)
c.handleOutbound(vpnIP, f, false)
vpnIp := ep.(iputil.VpnIp)
c.handleOutbound(vpnIp, f, false)
}
}
func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil {
return
}
@ -115,7 +118,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
if !hostinfo.HandshakeReady {
// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
return
}
@ -143,21 +146,21 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
// Get a remotes object if we don't already have one.
// This is mainly to protect us as this should never be the case
if hostinfo.remotes == nil {
hostinfo.remotes = c.lightHouse.QueryCache(vpnIP)
hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
}
//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
// Our vpnIP here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
// the learned public ip for them. Query again to short circuit the promotion counter
c.lightHouse.QueryServer(vpnIP, f)
c.lightHouse.QueryServer(vpnIp, f)
}
// Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udpAddr
hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udpAddr, _ bool) {
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
var sentTo []*udp.Addr
hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
hostinfo.logger(c.l).WithField("udpAddr", addr).
@ -184,16 +187,16 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
if !lighthouseTriggered {
//TODO: feel like we dupe handshake real fast in a tight loop, why?
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
}
}
func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
hostinfo := c.pendingHostMap.AddVpnIp(vpnIp)
// We lock here and use an array to insert items to prevent locking the
// main receive thread for very long by waiting to add items to the pending map
//TODO: what lock?
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
c.metricInitiated.Inc(1)
return hostinfo
@ -208,12 +211,12 @@ var (
// CheckAndComplete checks for any conflicts in the main and pending hostmap
// before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
//
// ErrAlreadySeen if we already have an entry in the hostmap that has seen the
// exact same handshake packet
//
// ErrExistingHostInfo if we already have an entry in the hostmap for this
// VpnIP and the new handshake was older than the one we currently have
// VpnIp and the new handshake was older than the one we currently have
//
// ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId.
@ -224,7 +227,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
defer c.mainHostMap.Unlock()
// Check if we already have a tunnel with this vpn ip
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
if found && existingHostInfo != nil {
// Is it just a delayed handshake packet?
if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
@ -252,16 +255,16 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
}
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(c.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
Info("New host shadows existing host remoteIndex")
}
// Check if we are also handshaking with this vpn ip
pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.hostId]
pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.vpnIp]
if found && pendingHostInfo != nil {
if !overwrite {
// We won, let our pending handshake win
@ -278,7 +281,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
if existingHostInfo != nil {
// We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
}
@ -296,10 +299,10 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
if found && existingHostInfo != nil {
// We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
}
@ -309,7 +312,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(c.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
Info("New host shadows existing host remoteIndex")
}

View File

@ -5,25 +5,29 @@ import (
"testing"
"time"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func Test_NewHandshakeManagerVpnIP(t *testing.T) {
l := NewTestLogger()
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
l := util.NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := ip2int(net.ParseIP("172.1.1.2"))
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udp.Conn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
i := blah.AddVpnIP(ip)
i := blah.AddVpnIp(ip)
i.remotes = NewRemoteList()
i.HandshakeReady = true
@ -50,24 +54,24 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
}
func Test_NewHandshakeManagerTrigger(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := ip2int(net.ParseIP("172.1.1.2"))
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
lh := &LightHouse{addrMap: make(map[uint32]*RemoteList), l: l}
lh := &LightHouse{addrMap: make(map[iputil.VpnIp]*RemoteList), l: l}
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
hi := blah.AddVpnIP(ip)
hi := blah.AddVpnIp(ip)
hi.HandshakeReady = true
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
@ -80,7 +84,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
// Make sure the trigger doesn't double schedule the timer entry
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
uaddr := NewUDPAddrFromString("10.1.1.1:4242")
uaddr := udp.NewAddrFromString("10.1.1.1:4242")
hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
// We now have remotes but only the first trigger should have pushed things forward
@ -103,6 +107,6 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
type mockEncWriter struct {
}
func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
return
}

View File

@ -1,4 +1,4 @@
package nebula
package header
import (
"encoding/binary"
@ -19,82 +19,78 @@ import (
// |-----------------------------------------------------------------------|
// | payload... |
const (
Version uint8 = 1
HeaderLen = 16
)
type NebulaMessageType uint8
type NebulaMessageSubType uint8
type m map[string]interface{}
const (
handshake NebulaMessageType = 0
message NebulaMessageType = 1
recvError NebulaMessageType = 2
lightHouse NebulaMessageType = 3
test NebulaMessageType = 4
closeTunnel NebulaMessageType = 5
//TODO These are deprecated as of 06/12/2018 - NB
testRemote NebulaMessageType = 6
testRemoteReply NebulaMessageType = 7
Version uint8 = 1
Len = 16
)
var typeMap = map[NebulaMessageType]string{
handshake: "handshake",
message: "message",
recvError: "recvError",
lightHouse: "lightHouse",
test: "test",
closeTunnel: "closeTunnel",
type MessageType uint8
type MessageSubType uint8
//TODO These are deprecated as of 06/12/2018 - NB
testRemote: "testRemote",
testRemoteReply: "testRemoteReply",
const (
Handshake MessageType = 0
Message MessageType = 1
RecvError MessageType = 2
LightHouse MessageType = 3
Test MessageType = 4
CloseTunnel MessageType = 5
)
var typeMap = map[MessageType]string{
Handshake: "handshake",
Message: "message",
RecvError: "recvError",
LightHouse: "lightHouse",
Test: "test",
CloseTunnel: "closeTunnel",
}
const (
testRequest NebulaMessageSubType = 0
testReply NebulaMessageSubType = 1
TestRequest MessageSubType = 0
TestReply MessageSubType = 1
)
var eHeaderTooShort = errors.New("header is too short")
const (
HandshakeIXPSK0 MessageSubType = 0
HandshakeXXPSK0 MessageSubType = 1
)
var subTypeTestMap = map[NebulaMessageSubType]string{
testRequest: "testRequest",
testReply: "testReply",
var ErrHeaderTooShort = errors.New("header is too short")
var subTypeTestMap = map[MessageSubType]string{
TestRequest: "testRequest",
TestReply: "testReply",
}
var subTypeNoneMap = map[NebulaMessageSubType]string{0: "none"}
var subTypeNoneMap = map[MessageSubType]string{0: "none"}
var subTypeMap = map[NebulaMessageType]*map[NebulaMessageSubType]string{
message: &subTypeNoneMap,
recvError: &subTypeNoneMap,
lightHouse: &subTypeNoneMap,
test: &subTypeTestMap,
closeTunnel: &subTypeNoneMap,
handshake: {
handshakeIXPSK0: "ix_psk0",
var subTypeMap = map[MessageType]*map[MessageSubType]string{
Message: &subTypeNoneMap,
RecvError: &subTypeNoneMap,
LightHouse: &subTypeNoneMap,
Test: &subTypeTestMap,
CloseTunnel: &subTypeNoneMap,
Handshake: {
HandshakeIXPSK0: "ix_psk0",
},
//TODO: these are deprecated
testRemote: &subTypeNoneMap,
testRemoteReply: &subTypeNoneMap,
}
type Header struct {
type H struct {
Version uint8
Type NebulaMessageType
Subtype NebulaMessageSubType
Type MessageType
Subtype MessageSubType
Reserved uint16
RemoteIndex uint32
MessageCounter uint64
}
// HeaderEncode uses the provided byte array to encode the provided header values into.
// Encode uses the provided byte array to encode the provided header values into.
// Byte array must be capped higher than HeaderLen or this will panic
func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []byte {
b = b[:HeaderLen]
b[0] = byte(v<<4 | (t & 0x0f))
func Encode(b []byte, v uint8, t MessageType, st MessageSubType, ri uint32, c uint64) []byte {
b = b[:Len]
b[0] = v<<4 | byte(t&0x0f)
b[1] = byte(st)
binary.BigEndian.PutUint16(b[2:4], 0)
binary.BigEndian.PutUint32(b[4:8], ri)
@ -103,7 +99,7 @@ func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []b
}
// String creates a readable string representation of a header
func (h *Header) String() string {
func (h *H) String() string {
if h == nil {
return "<nil>"
}
@ -112,7 +108,7 @@ func (h *Header) String() string {
}
// MarshalJSON creates a json string representation of a header
func (h *Header) MarshalJSON() ([]byte, error) {
func (h *H) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"version": h.Version,
"type": h.TypeName(),
@ -124,24 +120,24 @@ func (h *Header) MarshalJSON() ([]byte, error) {
}
// Encode turns header into bytes
func (h *Header) Encode(b []byte) ([]byte, error) {
func (h *H) Encode(b []byte) ([]byte, error) {
if h == nil {
return nil, errors.New("nil header")
}
return HeaderEncode(b, h.Version, uint8(h.Type), uint8(h.Subtype), h.RemoteIndex, h.MessageCounter), nil
return Encode(b, h.Version, h.Type, h.Subtype, h.RemoteIndex, h.MessageCounter), nil
}
// Parse is a helper function to parses given bytes into new Header struct
func (h *Header) Parse(b []byte) error {
if len(b) < HeaderLen {
return eHeaderTooShort
func (h *H) Parse(b []byte) error {
if len(b) < Len {
return ErrHeaderTooShort
}
// get upper 4 bytes
h.Version = uint8((b[0] >> 4) & 0x0f)
// get lower 4 bytes
h.Type = NebulaMessageType(b[0] & 0x0f)
h.Subtype = NebulaMessageSubType(b[1])
h.Type = MessageType(b[0] & 0x0f)
h.Subtype = MessageSubType(b[1])
h.Reserved = binary.BigEndian.Uint16(b[2:4])
h.RemoteIndex = binary.BigEndian.Uint32(b[4:8])
h.MessageCounter = binary.BigEndian.Uint64(b[8:16])
@ -149,12 +145,12 @@ func (h *Header) Parse(b []byte) error {
}
// TypeName will transform the headers message type into a human string
func (h *Header) TypeName() string {
func (h *H) TypeName() string {
return TypeName(h.Type)
}
// TypeName will transform a nebula message type into a human string
func TypeName(t NebulaMessageType) string {
func TypeName(t MessageType) string {
if n, ok := typeMap[t]; ok {
return n
}
@ -163,12 +159,12 @@ func TypeName(t NebulaMessageType) string {
}
// SubTypeName will transform the headers message sub type into a human string
func (h *Header) SubTypeName() string {
func (h *H) SubTypeName() string {
return SubTypeName(h.Type, h.Subtype)
}
// SubTypeName will transform a nebula message sub type into a human string
func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string {
func SubTypeName(t MessageType, s MessageSubType) string {
if n, ok := subTypeMap[t]; ok {
if x, ok := (*n)[s]; ok {
return x
@ -179,8 +175,8 @@ func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string {
}
// NewHeader turns bytes into a header
func NewHeader(b []byte) (*Header, error) {
h := new(Header)
func NewHeader(b []byte) (*H, error) {
h := new(H)
if err := h.Parse(b); err != nil {
return nil, err
}

115
header/header_test.go Normal file
View File

@ -0,0 +1,115 @@
package header
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
type headerTest struct {
expectedBytes []byte
*H
}
// 0001 0010 00010010
var headerBigEndianTests = []headerTest{{
expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9},
// 1010 0000
H: &H{
// 1111 1+2+4+8 = 15
Version: 5,
Type: 4,
Subtype: 0,
Reserved: 0,
RemoteIndex: 10,
MessageCounter: 9,
},
},
}
func TestEncode(t *testing.T) {
for _, tt := range headerBigEndianTests {
b, err := tt.Encode(make([]byte, Len))
if err != nil {
t.Fatal(err)
}
assert.Equal(t, tt.expectedBytes, b)
}
}
func TestParse(t *testing.T) {
for _, tt := range headerBigEndianTests {
b := tt.expectedBytes
parsedHeader := &H{}
parsedHeader.Parse(b)
if !reflect.DeepEqual(tt.H, parsedHeader) {
t.Fatalf("got %#v; want %#v", parsedHeader, tt.H)
}
}
}
func TestTypeName(t *testing.T) {
assert.Equal(t, "test", TypeName(Test))
assert.Equal(t, "test", (&H{Type: Test}).TypeName())
assert.Equal(t, "unknown", TypeName(99))
assert.Equal(t, "unknown", (&H{Type: 99}).TypeName())
}
func TestSubTypeName(t *testing.T) {
assert.Equal(t, "testRequest", SubTypeName(Test, TestRequest))
assert.Equal(t, "testRequest", (&H{Type: Test, Subtype: TestRequest}).SubTypeName())
assert.Equal(t, "unknown", SubTypeName(99, TestRequest))
assert.Equal(t, "unknown", (&H{Type: 99, Subtype: TestRequest}).SubTypeName())
assert.Equal(t, "unknown", SubTypeName(Test, 99))
assert.Equal(t, "unknown", (&H{Type: Test, Subtype: 99}).SubTypeName())
assert.Equal(t, "none", SubTypeName(Message, 0))
assert.Equal(t, "none", (&H{Type: Message, Subtype: 0}).SubTypeName())
}
func TestTypeMap(t *testing.T) {
// Force people to document this stuff
assert.Equal(t, map[MessageType]string{
Handshake: "handshake",
Message: "message",
RecvError: "recvError",
LightHouse: "lightHouse",
Test: "test",
CloseTunnel: "closeTunnel",
}, typeMap)
assert.Equal(t, map[MessageType]*map[MessageSubType]string{
Message: &subTypeNoneMap,
RecvError: &subTypeNoneMap,
LightHouse: &subTypeNoneMap,
Test: &subTypeTestMap,
CloseTunnel: &subTypeNoneMap,
Handshake: {
HandshakeIXPSK0: "ix_psk0",
},
}, subTypeMap)
}
func TestHeader_String(t *testing.T) {
assert.Equal(
t,
"ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97",
(&H{100, Test, TestRequest, 99, 98, 97}).String(),
)
}
func TestHeader_MarshalJSON(t *testing.T) {
b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
assert.Nil(t, err)
assert.Equal(
t,
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
string(b),
)
}

View File

@ -1,119 +0,0 @@
package nebula
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
type headerTest struct {
expectedBytes []byte
*Header
}
// 0001 0010 00010010
var headerBigEndianTests = []headerTest{{
expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9},
// 1010 0000
Header: &Header{
// 1111 1+2+4+8 = 15
Version: 5,
Type: 4,
Subtype: 0,
Reserved: 0,
RemoteIndex: 10,
MessageCounter: 9,
},
},
}
func TestEncode(t *testing.T) {
for _, tt := range headerBigEndianTests {
b, err := tt.Encode(make([]byte, HeaderLen))
if err != nil {
t.Fatal(err)
}
assert.Equal(t, tt.expectedBytes, b)
}
}
func TestParse(t *testing.T) {
for _, tt := range headerBigEndianTests {
b := tt.expectedBytes
parsedHeader := &Header{}
parsedHeader.Parse(b)
if !reflect.DeepEqual(tt.Header, parsedHeader) {
t.Fatalf("got %#v; want %#v", parsedHeader, tt.Header)
}
}
}
func TestTypeName(t *testing.T) {
assert.Equal(t, "test", TypeName(test))
assert.Equal(t, "test", (&Header{Type: test}).TypeName())
assert.Equal(t, "unknown", TypeName(99))
assert.Equal(t, "unknown", (&Header{Type: 99}).TypeName())
}
func TestSubTypeName(t *testing.T) {
assert.Equal(t, "testRequest", SubTypeName(test, testRequest))
assert.Equal(t, "testRequest", (&Header{Type: test, Subtype: testRequest}).SubTypeName())
assert.Equal(t, "unknown", SubTypeName(99, testRequest))
assert.Equal(t, "unknown", (&Header{Type: 99, Subtype: testRequest}).SubTypeName())
assert.Equal(t, "unknown", SubTypeName(test, 99))
assert.Equal(t, "unknown", (&Header{Type: test, Subtype: 99}).SubTypeName())
assert.Equal(t, "none", SubTypeName(message, 0))
assert.Equal(t, "none", (&Header{Type: message, Subtype: 0}).SubTypeName())
}
func TestTypeMap(t *testing.T) {
// Force people to document this stuff
assert.Equal(t, map[NebulaMessageType]string{
handshake: "handshake",
message: "message",
recvError: "recvError",
lightHouse: "lightHouse",
test: "test",
closeTunnel: "closeTunnel",
testRemote: "testRemote",
testRemoteReply: "testRemoteReply",
}, typeMap)
assert.Equal(t, map[NebulaMessageType]*map[NebulaMessageSubType]string{
message: &subTypeNoneMap,
recvError: &subTypeNoneMap,
lightHouse: &subTypeNoneMap,
test: &subTypeTestMap,
closeTunnel: &subTypeNoneMap,
handshake: {
handshakeIXPSK0: "ix_psk0",
},
testRemote: &subTypeNoneMap,
testRemoteReply: &subTypeNoneMap,
}, subTypeMap)
}
func TestHeader_String(t *testing.T) {
assert.Equal(
t,
"ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97",
(&Header{100, test, testRequest, 99, 98, 97}).String(),
)
}
func TestHeader_MarshalJSON(t *testing.T) {
b, err := (&Header{100, test, testRequest, 99, 98, 97}).MarshalJSON()
assert.Nil(t, err)
assert.Equal(
t,
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
string(b),
)
}

View File

@ -12,6 +12,10 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
//const ProbeLen = 100
@ -28,10 +32,10 @@ type HostMap struct {
name string
Indexes map[uint32]*HostInfo
RemoteIndexes map[uint32]*HostInfo
Hosts map[uint32]*HostInfo
Hosts map[iputil.VpnIp]*HostInfo
preferredRanges []*net.IPNet
vpnCIDR *net.IPNet
unsafeRoutes *CIDRTree
unsafeRoutes *cidr.Tree4
metricsEnabled bool
l *logrus.Logger
}
@ -39,7 +43,7 @@ type HostMap struct {
type HostInfo struct {
sync.RWMutex
remote *udpAddr
remote *udp.Addr
remotes *RemoteList
promoteCounter uint32
ConnectionState *ConnectionState
@ -51,9 +55,9 @@ type HostInfo struct {
packetStore []*cachedPacket //todo: this is other handshake manager entry
remoteIndexId uint32
localIndexId uint32
hostId uint32
vpnIp iputil.VpnIp
recvError int
remoteCidr *CIDRTree
remoteCidr *cidr.Tree4
// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
@ -66,17 +70,17 @@ type HostInfo struct {
lastHandshakeTime uint64
lastRoam time.Time
lastRoamRemote *udpAddr
lastRoamRemote *udp.Addr
}
type cachedPacket struct {
messageType NebulaMessageType
messageSubType NebulaMessageSubType
messageType header.MessageType
messageSubType header.MessageSubType
callback packetCallback
packet []byte
}
type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte)
type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte)
type cachedPacketMetrics struct {
sent metrics.Counter
@ -84,7 +88,7 @@ type cachedPacketMetrics struct {
}
func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[uint32]*HostInfo{}
h := map[iputil.VpnIp]*HostInfo{}
i := map[uint32]*HostInfo{}
r := map[uint32]*HostInfo{}
m := HostMap{
@ -94,7 +98,7 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
Hosts: h,
preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR,
unsafeRoutes: NewCIDRTree(),
unsafeRoutes: cidr.NewTree4(),
l: l,
}
return &m
@ -113,9 +117,9 @@ func (hm *HostMap) EmitStats(name string) {
metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
}
func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
hm.RLock()
if i, ok := hm.Hosts[vpnIP]; ok {
if i, ok := hm.Hosts[vpnIp]; ok {
index := i.localIndexId
hm.RUnlock()
return index, nil
@ -124,43 +128,43 @@ func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
return 0, errors.New("vpn IP not found")
}
func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) {
func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
hm.Lock()
hm.Hosts[ip] = hostinfo
hm.Unlock()
}
func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
h := &HostInfo{}
hm.RLock()
if _, ok := hm.Hosts[vpnIP]; !ok {
if _, ok := hm.Hosts[vpnIp]; !ok {
hm.RUnlock()
h = &HostInfo{
promoteCounter: 0,
hostId: vpnIP,
vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
}
hm.Lock()
hm.Hosts[vpnIP] = h
hm.Hosts[vpnIp] = h
hm.Unlock()
return h
} else {
h = hm.Hosts[vpnIP]
h = hm.Hosts[vpnIp]
hm.RUnlock()
return h
}
}
func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
func (hm *HostMap) DeleteVpnIp(vpnIp iputil.VpnIp) {
hm.Lock()
delete(hm.Hosts, vpnIP)
delete(hm.Hosts, vpnIp)
if len(hm.Hosts) == 0 {
hm.Hosts = map[uint32]*HostInfo{}
hm.Hosts = map[iputil.VpnIp]*HostInfo{}
}
hm.Unlock()
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap vpnIp deleted")
}
}
@ -174,22 +178,22 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
if hm.l.Level > logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}).
Debug("Hostmap remoteIndex added")
}
}
func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
func (hm *HostMap) AddVpnIpHostInfo(vpnIp iputil.VpnIp, h *HostInfo) {
hm.Lock()
h.hostId = vpnIP
hm.Hosts[vpnIP] = h
h.vpnIp = vpnIp
hm.Hosts[vpnIp] = h
hm.Indexes[h.localIndexId] = h
hm.RemoteIndexes[h.remoteIndexId] = h
hm.Unlock()
if hm.l.Level > logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "vpnIp": h.vpnIp}}).
Debug("Hostmap vpnIp added")
}
}
@ -204,9 +208,9 @@ func (hm *HostMap) DeleteIndex(index uint32) {
// Check if we have an entry under hostId that matches the same hostinfo
// instance. Clean it up as well if we do.
hostinfo2, ok := hm.Hosts[hostinfo.hostId]
hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
if ok && hostinfo2 == hostinfo {
delete(hm.Hosts, hostinfo.hostId)
delete(hm.Hosts, hostinfo.vpnIp)
}
}
hm.Unlock()
@ -228,9 +232,9 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
// Check if we have an entry under hostId that matches the same hostinfo
// instance. Clean it up as well if we do (they might not match in pendingHostmap)
var hostinfo2 *HostInfo
hostinfo2, ok = hm.Hosts[hostinfo.hostId]
hostinfo2, ok = hm.Hosts[hostinfo.vpnIp]
if ok && hostinfo2 == hostinfo {
delete(hm.Hosts, hostinfo.hostId)
delete(hm.Hosts, hostinfo.vpnIp)
}
}
hm.Unlock()
@ -251,16 +255,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
// Check if this same hostId is in the hostmap with a different instance.
// This could happen if we have an entry in the pending hostmap with different
// index values than the one in the main hostmap.
hostinfo2, ok := hm.Hosts[hostinfo.hostId]
hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
if ok && hostinfo2 != hostinfo {
delete(hm.Hosts, hostinfo2.hostId)
delete(hm.Hosts, hostinfo2.vpnIp)
delete(hm.Indexes, hostinfo2.localIndexId)
delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
}
delete(hm.Hosts, hostinfo.hostId)
delete(hm.Hosts, hostinfo.vpnIp)
if len(hm.Hosts) == 0 {
hm.Hosts = map[uint32]*HostInfo{}
hm.Hosts = map[iputil.VpnIp]*HostInfo{}
}
delete(hm.Indexes, hostinfo.localIndexId)
if len(hm.Indexes) == 0 {
@ -273,7 +277,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted")
}
}
@ -301,17 +305,17 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
}
}
func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
return hm.queryVpnIP(vpnIp, nil)
func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
return hm.queryVpnIp(vpnIp, nil)
}
// PromoteBestQueryVpnIP will attempt to lazily switch to the best remote every
// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
// `PromoteEvery` calls to this function for a given host.
func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostInfo, error) {
return hm.queryVpnIP(vpnIp, ifce)
func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
return hm.queryVpnIp(vpnIp, ifce)
}
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
hm.RUnlock()
@ -327,10 +331,10 @@ func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo,
return nil, errors.New("unable to find host")
}
func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
func (hm *HostMap) queryUnsafeRoute(ip iputil.VpnIp) iputil.VpnIp {
r := hm.unsafeRoutes.MostSpecificContains(ip)
if r != nil {
return r.(uint32)
return r.(iputil.VpnIp)
} else {
return 0
}
@ -344,13 +348,13 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
}
hm.Hosts[hostinfo.hostId] = hostinfo
hm.Hosts[hostinfo.vpnIp] = hostinfo
hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
Debug("Hostmap vpnIp added")
}
}
@ -370,7 +374,7 @@ func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
}
// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) {
var metricsTxPunchy metrics.Counter
if hm.metricsEnabled {
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
@ -406,7 +410,7 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
for _, r := range *routes {
hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
hm.unsafeRoutes.AddCIDR(r.route, iputil.Ip2VpnIp(*r.via))
}
}
@ -431,24 +435,24 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
}
}
i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) {
i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
if addr == nil || !preferred {
return
}
// Try to send a test packet to that host, this should
// cause it to detect a roaming event and switch remotes
ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
ifce.send(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
})
}
// Re query our lighthouses for new remotes occasionally
if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
ifce.lightHouse.QueryServer(i.hostId, ifce)
ifce.lightHouse.QueryServer(i.vpnIp, ifce)
}
}
func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
//TODO: return the error so we can log with more context
if len(i.packetStore) < 100 {
tempPacket := make([]byte, len(packet))
@ -510,17 +514,17 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
return nil
}
func (i *HostInfo) SetRemote(remote *udpAddr) {
func (i *HostInfo) SetRemote(remote *udp.Addr) {
// We copy here because we likely got this remote from a source that reuses the object
if !i.remote.Equals(remote) {
i.remote = remote.Copy()
i.remotes.LearnRemote(i.hostId, remote.Copy())
i.remotes.LearnRemote(i.vpnIp, remote.Copy())
}
}
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
// time on the HostInfo will also be updated.
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udpAddr) bool {
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
currentRemote := i.remote
if currentRemote == nil {
i.SetRemote(newRemote)
@ -572,7 +576,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
return
}
remoteCidr := NewCIDRTree()
remoteCidr := cidr.NewTree4()
for _, ip := range c.Details.Ips {
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
}
@ -588,8 +592,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
return logrus.NewEntry(l)
}
li := l.WithField("vpnIp", IntIp(i.hostId))
li := l.WithField("vpnIp", i.vpnIp)
if connState := i.ConnectionState; connState != nil {
if peerCert := connState.peerCert; peerCert != nil {
li = li.WithField("certName", peerCert.Details.Name)
@ -599,38 +602,6 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
return li
}
//########################
/*
func (hm *HostMap) DebugRemotes(vpnIp uint32) string {
s := "\n"
for _, h := range hm.Hosts {
for _, r := range h.Remotes {
s += fmt.Sprintf("%s : %d ## %v\n", r.addr.IP.String(), r.addr.Port, r.probes)
}
}
return s
}
func (i *HostInfo) HandleReply(addr *net.UDPAddr, counter int) {
for _, r := range i.Remotes {
if r.addr.IP.Equal(addr.IP) && r.addr.Port == addr.Port {
r.ProbeReceived(counter)
}
}
}
func (i *HostInfo) Probes() []*Probe {
p := []*Probe{}
for _, d := range i.Remotes {
p = append(p, &Probe{Addr: d.addr, Counter: d.Probe()})
}
return p
}
*/
// Utility functions
func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {

View File

@ -5,9 +5,13 @@ import (
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
@ -32,7 +36,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
if hostinfo == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
f.l.WithField("vpnIp", fwPacket.RemoteIP).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
}
@ -45,7 +49,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
// the packet queue.
ci.queueLock.Lock()
if !ci.ready {
hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
hostinfo.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
ci.queueLock.Unlock()
return
}
@ -54,7 +58,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
if dropReason == nil {
f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
f.sendNoMetrics(header.Message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
} else if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
@ -65,20 +69,21 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
}
// getOrHandshake returns nil if the vpnIp is not routable
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
//TODO: we can find contains without converting back to bytes
if f.hostMap.vpnCIDR.Contains(vpnIp.ToIP()) == false {
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
if vpnIp == 0 {
return nil
}
}
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
//if err != nil || hostinfo.ConnectionState == nil {
if err != nil {
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIp)
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil {
hostinfo = f.handshakeManager.AddVpnIP(vpnIp)
hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
}
}
ci := hostinfo.ConnectionState
@ -126,8 +131,8 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
return hostinfo
}
func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
fp := &FirewallPacket{}
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
fp := &firewall.Packet{}
err := newPacket(p, false, fp)
if err != nil {
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
@ -145,15 +150,15 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
return
}
f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
f.sendNoMetrics(header.Message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
}
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", IntIp(vpnIp)).
f.l.WithField("vpnIp", vpnIp).
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
}
return
@ -175,16 +180,16 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
return
}
func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
func (f *Interface) sendMessageToVpnIp(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
}
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
}
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) {
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) {
if ci.eKey == nil {
//TODO: log warning
return
@ -196,18 +201,18 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
c := atomic.AddUint64(&ci.atomicMessageCounter, 1)
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo.hostId)
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo.vpnIp)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our IPs and enable a faster roaming.
if t != closeTunnel && hostinfo.lastRebindCount != f.rebindCount {
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.hostId, f)
f.lightHouse.QueryServer(hostinfo.vpnIp, f)
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
@ -230,7 +235,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
return
}
func isMulticast(ip uint32) bool {
func isMulticast(ip iputil.VpnIp) bool {
// Class D multicast
if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 {
return true

View File

@ -12,6 +12,10 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
const mtu = 9001
@ -27,7 +31,7 @@ type Inside interface {
type InterfaceConfig struct {
HostMap *HostMap
Outside *udpConn
Outside *udp.Conn
Inside Inside
certState *CertState
Cipher string
@ -39,7 +43,6 @@ type InterfaceConfig struct {
pendingDeletionInterval int
DropLocalBroadcast bool
DropMulticast bool
UDPBatchSize int
routines int
MessageMetrics *MessageMetrics
version string
@ -52,7 +55,7 @@ type InterfaceConfig struct {
type Interface struct {
hostMap *HostMap
outside *udpConn
outside *udp.Conn
inside Inside
certState *CertState
cipher string
@ -62,11 +65,10 @@ type Interface struct {
serveDns bool
createTime time.Time
lightHouse *LightHouse
localBroadcast uint32
myVpnIp uint32
localBroadcast iputil.VpnIp
myVpnIp iputil.VpnIp
dropLocalBroadcast bool
dropMulticast bool
udpBatchSize int
routines int
caPool *cert.NebulaCAPool
disconnectInvalid bool
@ -77,7 +79,7 @@ type Interface struct {
conntrackCacheTimeout time.Duration
writers []*udpConn
writers []*udp.Conn
readers []io.ReadWriteCloser
metricHandshakes metrics.Histogram
@ -101,6 +103,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
return nil, errors.New("no firewall rules")
}
myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
ifce := &Interface{
hostMap: c.HostMap,
outside: c.Outside,
@ -112,17 +115,16 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
handshakeManager: c.HandshakeManager,
createTime: time.Now(),
lightHouse: c.lightHouse,
localBroadcast: ip2int(c.certState.certificate.Details.Ips[0].IP) | ^ip2int(c.certState.certificate.Details.Ips[0].Mask),
localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask),
dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast,
udpBatchSize: c.UDPBatchSize,
routines: c.routines,
version: c.version,
writers: make([]*udpConn, c.routines),
writers: make([]*udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
caPool: c.caPool,
disconnectInvalid: c.disconnectInvalid,
myVpnIp: ip2int(c.certState.certificate.Details.Ips[0].IP),
myVpnIp: myVpnIp,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
@ -190,14 +192,17 @@ func (f *Interface) run() {
func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li *udpConn
var li *udp.Conn
// TODO clean this up with a coherent interface for each outside connection
if i > 0 {
li = f.writers[i]
} else {
li = f.outside
}
li.ListenOut(f, i)
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i)
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@ -205,10 +210,10 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &FirewallPacket{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := reader.Read(packet)
@ -222,16 +227,16 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
}
}
func (f *Interface) RegisterConfigChangeCallbacks(c *Config) {
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
c.RegisterReloadCallback(f.reloadCA)
c.RegisterReloadCallback(f.reloadCertKey)
c.RegisterReloadCallback(f.reloadFirewall)
for _, udpConn := range f.writers {
c.RegisterReloadCallback(udpConn.reloadConfig)
c.RegisterReloadCallback(udpConn.ReloadConfig)
}
}
func (f *Interface) reloadCA(c *Config) {
func (f *Interface) reloadCA(c *config.C) {
// reload and check regardless
// todo: need mutex?
newCAs, err := loadCAFromConfig(f.l, c)
@ -244,7 +249,7 @@ func (f *Interface) reloadCA(c *Config) {
f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
}
func (f *Interface) reloadCertKey(c *Config) {
func (f *Interface) reloadCertKey(c *config.C) {
// reload and check in all cases
cs, err := NewCertStateFromConfig(c)
if err != nil {
@ -264,7 +269,7 @@ func (f *Interface) reloadCertKey(c *Config) {
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
}
func (f *Interface) reloadFirewall(c *Config) {
func (f *Interface) reloadFirewall(c *config.C) {
//TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false {
f.l.Debug("No firewall config change detected")
@ -307,7 +312,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
ticker := time.NewTicker(i)
defer ticker.Stop()
udpStats := NewUDPStatsEmitter(f.writers)
udpStats := udp.NewUDPStatsEmitter(f.writers)
for {
select {

66
iputil/util.go Normal file
View File

@ -0,0 +1,66 @@
package iputil
import (
"encoding/binary"
"fmt"
"net"
)
type VpnIp uint32
const maxIPv4StringLen = len("255.255.255.255")
func (ip VpnIp) String() string {
b := make([]byte, maxIPv4StringLen)
n := ubtoa(b, 0, byte(ip>>24))
b[n] = '.'
n++
n += ubtoa(b, n, byte(ip>>16&255))
b[n] = '.'
n++
n += ubtoa(b, n, byte(ip>>8&255))
b[n] = '.'
n++
n += ubtoa(b, n, byte(ip&255))
return string(b[:n])
}
func (ip VpnIp) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil
}
func (ip VpnIp) ToIP() net.IP {
nip := make(net.IP, 4)
binary.BigEndian.PutUint32(nip, uint32(ip))
return nip
}
func Ip2VpnIp(ip []byte) VpnIp {
if len(ip) == 16 {
return VpnIp(binary.BigEndian.Uint32(ip[12:16]))
}
return VpnIp(binary.BigEndian.Uint32(ip))
}
// ubtoa encodes the string form of the integer v to dst[start:] and
// returns the number of bytes written to dst. The caller must ensure
// that dst has sufficient length.
func ubtoa(dst []byte, start int, v byte) int {
if v < 10 {
dst[start] = v + '0'
return 1
} else if v < 100 {
dst[start+1] = v%10 + '0'
dst[start] = v/10 + '0'
return 2
}
dst[start+2] = v%10 + '0'
dst[start+1] = (v/10)%10 + '0'
dst[start] = v/100 + '0'
return 3
}

17
iputil/util_test.go Normal file
View File

@ -0,0 +1,17 @@
package iputil
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestVpnIp_String(t *testing.T) {
assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String())
assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String())
assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String())
assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String())
assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String())
assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String())
}

View File

@ -12,6 +12,9 @@ import (
"github.com/golang/protobuf/proto"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake?
@ -23,13 +26,13 @@ type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
sync.RWMutex //Because we concurrently read and write to our maps
amLighthouse bool
myVpnIp uint32
myVpnZeros uint32
punchConn *udpConn
myVpnIp iputil.VpnIp
myVpnZeros iputil.VpnIp
punchConn *udp.Conn
// Local cache of answers from light houses
// map of vpn Ip to answers
addrMap map[uint32]*RemoteList
addrMap map[iputil.VpnIp]*RemoteList
// filters remote addresses allowed for each host
// - When we are a lighthouse, this filters what addresses we store and
@ -42,12 +45,12 @@ type LightHouse struct {
localAllowList *LocalAllowList
// used to trigger the HandshakeManager when we receive HostQueryReply
handshakeTrigger chan<- uint32
handshakeTrigger chan<- iputil.VpnIp
// staticList exists to avoid having a bool in each addrMap entry
// since static should be rare
staticList map[uint32]struct{}
lighthouses map[uint32]struct{}
staticList map[iputil.VpnIp]struct{}
lighthouses map[iputil.VpnIp]struct{}
interval int
nebulaPort uint32 // 32 bits because protobuf does not have a uint16
punchBack bool
@ -58,20 +61,16 @@ type LightHouse struct {
l *logrus.Logger
}
type EncWriter interface {
SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
}
func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []iputil.VpnIp, interval int, nebulaPort uint32, pc *udp.Conn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
ones, _ := myVpnIpNet.Mask.Size()
h := LightHouse{
amLighthouse: amLighthouse,
myVpnIp: ip2int(myVpnIpNet.IP),
myVpnZeros: uint32(32 - ones),
addrMap: make(map[uint32]*RemoteList),
myVpnIp: iputil.Ip2VpnIp(myVpnIpNet.IP),
myVpnZeros: iputil.VpnIp(32 - ones),
addrMap: make(map[iputil.VpnIp]*RemoteList),
nebulaPort: nebulaPort,
lighthouses: make(map[uint32]struct{}),
staticList: make(map[uint32]struct{}),
lighthouses: make(map[iputil.VpnIp]struct{}),
staticList: make(map[iputil.VpnIp]struct{}),
interval: interval,
punchConn: pc,
punchBack: punchBack,
@ -111,13 +110,13 @@ func (lh *LightHouse) SetLocalAllowList(allowList *LocalAllowList) {
func (lh *LightHouse) ValidateLHStaticEntries() error {
for lhIP, _ := range lh.lighthouses {
if _, ok := lh.staticList[lhIP]; !ok {
return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", IntIp(lhIP))
return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", lhIP)
}
}
return nil
}
func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip, f)
}
@ -131,7 +130,7 @@ func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
}
// This is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) {
if lh.amLighthouse {
return
}
@ -143,7 +142,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
// Send a query to the lighthouses and hope for the best next time
query, err := proto.Marshal(NewLhQueryByInt(ip))
if err != nil {
lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload")
return
}
@ -151,11 +150,11 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for n := range lh.lighthouses {
f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
}
}
func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
lh.RLock()
if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock()
@ -172,7 +171,7 @@ func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
// If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, error)) (bool, int, error) {
func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) {
lh.RLock()
// Do we have an entry in the main cache?
if v, ok := lh.addrMap[vpnIp]; ok {
@ -195,18 +194,18 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, err
return false, 0, nil
}
func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
// First we check the static mapping
// and do nothing if it is there
if _, ok := lh.staticList[vpnIP]; ok {
if _, ok := lh.staticList[vpnIp]; ok {
return
}
lh.Lock()
//l.Debugln(lh.addrMap)
delete(lh.addrMap, vpnIP)
delete(lh.addrMap, vpnIp)
if lh.l.Level >= logrus.DebugLevel {
lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
lh.l.Debugf("deleting %s from lighthouse.", vpnIp)
}
lh.Unlock()
@ -215,7 +214,7 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
func (lh *LightHouse) AddStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr) {
lh.Lock()
am := lh.unlockedGetRemoteList(vpnIp)
am.Lock()
@ -242,23 +241,23 @@ func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
}
// unlockedGetRemoteList assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList {
am, ok := lh.addrMap[vpnIP]
func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
am, ok := lh.addrMap[vpnIp]
if !ok {
am = NewRemoteList()
lh.addrMap[vpnIP] = am
lh.addrMap[vpnIp] = am
}
return am
}
// unlockedShouldAddV4 checks if to is allowed by our allow list
func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool {
allow := lh.remoteAllowList.AllowIpV4(vpnIp, to.Ip)
func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
allow := lh.remoteAllowList.AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow")
lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, to.Ip) {
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) {
return false
}
@ -266,7 +265,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool {
}
// unlockedShouldAddV6 checks if to is allowed by our allow list
func (lh *LightHouse) unlockedShouldAddV6(vpnIp uint32, to *Ip6AndPort) bool {
func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool {
allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
@ -287,25 +286,25 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
return ip
}
func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
if _, ok := lh.lighthouses[vpnIP]; ok {
func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool {
if _, ok := lh.lighthouses[vpnIp]; ok {
return true
}
return false
}
func NewLhQueryByInt(VpnIp uint32) *NebulaMeta {
func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta {
return &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
VpnIp: VpnIp,
VpnIp: uint32(VpnIp),
},
}
}
func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
ipp := Ip4AndPort{Port: port}
ipp.Ip = ip2int(ip)
ipp.Ip = uint32(iputil.Ip2VpnIp(ip))
return &ipp
}
@ -317,19 +316,19 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
}
}
func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udpAddr {
func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
ip := ipp.Ip
return NewUDPAddr(
return udp.NewAddr(
net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
uint16(ipp.Port),
)
}
func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udpAddr {
return NewUDPAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
}
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
if lh.amLighthouse || lh.interval == 0 {
return
}
@ -349,12 +348,12 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
}
}
func (lh *LightHouse) SendUpdate(f EncWriter) {
func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
var v4 []*Ip4AndPort
var v6 []*Ip6AndPort
for _, e := range *localIps(lh.l, lh.localAllowList) {
if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip2int(ip4)) {
if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) {
continue
}
@ -368,7 +367,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
m := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
VpnIp: lh.myVpnIp,
VpnIp: uint32(lh.myVpnIp),
Ip4AndPorts: v4,
Ip6AndPorts: v6,
},
@ -385,7 +384,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
}
for vpnIp := range lh.lighthouses {
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
f.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out)
}
}
@ -415,11 +414,11 @@ func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
}
func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Rx(NebulaMessageType(t), 0, i)
lh.metrics.Rx(header.MessageType(t), 0, i)
}
func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Tx(NebulaMessageType(t), 0, i)
lh.metrics.Tx(header.MessageType(t), 0, i)
}
// This method is similar to Reset(), but it re-uses the pointer structs
@ -436,18 +435,18 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
return lhh.meta
}
func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) {
func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) {
n := lhh.resetMeta()
err := n.Unmarshal(p)
if err != nil {
lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
lhh.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr).
Error("Failed to unmarshal lighthouse packet")
//TODO: send recv_error?
return
}
if n.Details == nil {
lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
lhh.l.WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr).
Error("Invalid lighthouse update")
//TODO: send recv_error?
return
@ -471,7 +470,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
}
}
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr *udpAddr, w EncWriter) {
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) {
// Exit if we don't answer queries
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
@ -481,12 +480,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
}
//TODO: we can DRY this further
reqVpnIP := n.Details.VpnIp
reqVpnIp := n.Details.VpnIp
//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(c *cache) (int, error) {
found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply
n.Details.VpnIp = reqVpnIP
n.Details.VpnIp = reqVpnIp
lhh.coalesceAnswers(c, n)
@ -498,18 +497,18 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
}
if err != nil {
lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host query reply")
return
}
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
// This signals the other side to punch some zero byte udp packets
found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification
n.Details.VpnIp = vpnIp
n.Details.VpnIp = uint32(vpnIp)
lhh.coalesceAnswers(c, n)
@ -521,12 +520,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
}
if err != nil {
lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host was queried for")
lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host was queried for")
return
}
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0])
w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0])
}
func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
@ -549,28 +548,29 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
}
}
func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) {
func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) {
if !lhh.lh.IsLighthouseIP(vpnIp) {
return
}
lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList(n.Details.VpnIp)
am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp))
am.Lock()
lhh.lh.Unlock()
am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
certVpnIp := iputil.VpnIp(n.Details.VpnIp)
am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
am.Unlock()
// Non-blocking attempt to trigger, skip if it would block
select {
case lhh.lh.handshakeTrigger <- n.Details.VpnIp:
case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp):
default:
}
}
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp uint32) {
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp) {
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
@ -579,9 +579,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
}
//Simple check that the host sent this not someone else
if n.Details.VpnIp != vpnIp {
if n.Details.VpnIp != uint32(vpnIp) {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
}
return
}
@ -591,18 +591,19 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
am.Lock()
lhh.lh.Unlock()
am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
certVpnIp := iputil.VpnIp(n.Details.VpnIp)
am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
am.Unlock()
}
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) {
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) {
if !lhh.lh.IsLighthouseIP(vpnIp) {
return
}
empty := []byte{0}
punch := func(vpnPeer *udpAddr) {
punch := func(vpnPeer *udp.Addr) {
if vpnPeer == nil {
return
}
@ -615,7 +616,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
if lhh.l.Level >= logrus.DebugLevel {
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, IntIp(n.Details.VpnIp))
lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp))
}
}
@ -634,18 +635,18 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
go func() {
time.Sleep(time.Second * 5)
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", iputil.VpnIp(n.Details.VpnIp))
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel.
w.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
w.SendMessageToVpnIp(header.Test, header.TestRequest, iputil.VpnIp(n.Details.VpnIp), []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}()
}
}
// ipMaskContains checks if testIp is contained by ip after applying a cidr
// zeros is 32 - bits from net.IPMask.Size()
func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool {
func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool {
return (testIp^ip)>>zeros == 0
}

View File

@ -6,6 +6,10 @@ import (
"testing"
"github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
@ -17,12 +21,12 @@ func TestOldIPv4Only(t *testing.T) {
var m Ip4AndPort
err := proto.Unmarshal(b, &m)
assert.NoError(t, err)
assert.Equal(t, "10.1.1.1", int2ip(m.GetIp()).String())
assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String())
}
func TestNewLhQuery(t *testing.T) {
myIp := net.ParseIP("192.1.1.1")
myIpint := ip2int(myIp)
myIpint := iputil.Ip2VpnIp(myIp)
// Generating a new lh query should work
a := NewLhQueryByInt(myIpint)
@ -42,37 +46,37 @@ func TestNewLhQuery(t *testing.T) {
}
func Test_lhStaticMapping(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242)))
err := meh.ValidateLHStaticEntries()
assert.Nil(t, err)
lh2 := "10.128.0.3"
lh2IP := net.ParseIP(lh2)
meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP), iputil.Ip2VpnIp(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242)))
err = meh.ValidateLHStaticEntries()
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
}
func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := NewTestLogger()
l := util.NewTestLogger()
lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false)
hAddr := NewUDPAddrFromString("4.5.6.7:12345")
hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
hAddr := udp.NewAddrFromString("4.5.6.7:12345")
hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
lh.addrMap[3] = NewRemoteList()
lh.addrMap[3].unlockedSetV4(
3,
@ -81,11 +85,11 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
},
func(uint32, *Ip4AndPort) bool { return true },
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
)
rAddr := NewUDPAddrFromString("1.2.2.3:12345")
rAddr2 := NewUDPAddrFromString("1.2.2.3:12346")
rAddr := udp.NewAddrFromString("1.2.2.3:12345")
rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
lh.addrMap[2] = NewRemoteList()
lh.addrMap[2].unlockedSetV4(
3,
@ -94,7 +98,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
},
func(uint32, *Ip4AndPort) bool { return true },
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
)
mw := &mockEncWriter{}
@ -133,50 +137,50 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
}
func TestLighthouse_Memory(t *testing.T) {
l := NewTestLogger()
l := util.NewTestLogger()
myUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
myUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
myUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
myUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
myUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
myUdpAddr5 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
myUdpAddr6 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
myUdpAddr7 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
myUdpAddr8 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
myUdpAddr9 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
myUdpAddr10 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
myUdpAddr11 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
myVpnIp := ip2int(net.ParseIP("10.128.0.2"))
myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2"))
theirUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
theirUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
theirUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
theirUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
theirUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
theirVpnIp := ip2int(net.ParseIP("10.128.0.3"))
theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3"))
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{}, 10, 10003, udpServer, false, 1, false)
udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []iputil.VpnIp{}, 10, 10003, udpServer, false, 1, false)
lhh := lh.NewRequestHandler()
// Test that my first update responds with just that
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr2}, lhh)
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh)
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
// Ensure we don't accumulate addresses
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr3}, lhh)
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
// Grow it back to 2
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr4}, lhh)
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
// Update a different host
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udpAddr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
@ -189,7 +193,7 @@ func TestLighthouse_Memory(t *testing.T) {
newLHHostUpdate(
myUdpAddr0,
myVpnIp,
[]*udpAddr{
[]*udp.Addr{
myUdpAddr1,
myUdpAddr2,
myUdpAddr3,
@ -212,19 +216,19 @@ func TestLighthouse_Memory(t *testing.T) {
)
// Make sure we won't add ips in our vpn network
bad1 := &udpAddr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
bad2 := &udpAddr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
good := &udpAddr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{bad1, bad2, good}, lhh)
bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
}
func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightHouseHandler) testLhReply {
func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
VpnIp: queryVpnIp,
VpnIp: uint32(queryVpnIp),
},
}
@ -238,17 +242,17 @@ func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightH
return w.lastReply
}
func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *LightHouseHandler) {
func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) {
req := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
VpnIp: vpnIp,
VpnIp: uint32(vpnIp),
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
},
}
for k, v := range addrs {
req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: ip2int(v.IP), Port: uint32(v.Port)}
req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)}
}
b, err := req.Marshal()
@ -327,15 +331,15 @@ func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *Lig
//}
func Test_ipMaskContains(t *testing.T) {
assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255"))))
assert.False(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.1.1"))))
assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32, ip2int(net.ParseIP("10.0.1.1"))))
assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255"))))
assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
}
type testLhReply struct {
nebType NebulaMessageType
nebSubType NebulaMessageSubType
vpnIp uint32
nebType header.MessageType
nebSubType header.MessageSubType
vpnIp iputil.VpnIp
msg *NebulaMeta
}
@ -343,7 +347,7 @@ type testEncWriter struct {
lastReply testLhReply
}
func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, _, _ []byte) {
func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
tw.lastReply = testLhReply{
nebType: t,
nebSubType: st,
@ -358,17 +362,17 @@ func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessag
}
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udpAddr) {
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
assert.Len(t, have, len(want))
for k, w := range want {
if !(have[k].Ip == ip2int(w.IP) && have[k].Port == uint32(w.Port)) {
if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) {
assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
}
}
}
// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
assert.Len(t, have, len(want))
for k, w := range want {
if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
@ -377,8 +381,8 @@ func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
}
}
func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr {
addrs := make([]*udpAddr, len(ips))
func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr {
addrs := make([]*udp.Addr, len(ips))
for k, v := range ips {
addrs[k] = NewUDPAddrFromLH4(v)
}

View File

@ -2,8 +2,12 @@ package nebula
import (
"errors"
"fmt"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
type ContextualError struct {
@ -37,3 +41,38 @@ func (ce *ContextualError) Log(lr *logrus.Logger) {
lr.WithFields(ce.Fields).Error(ce.Context)
}
}
func configLogger(l *logrus.Logger, c *config.C) error {
// set up our logging level
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
if err != nil {
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
}
l.SetLevel(logLevel)
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
timestampFormat := c.GetString("logging.timestamp_format", "")
fullTimestamp := (timestampFormat != "")
if timestampFormat == "" {
timestampFormat = time.RFC3339
}
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
switch logFormat {
case "text":
l.Formatter = &logrus.TextFormatter{
TimestampFormat: timestampFormat,
FullTimestamp: fullTimestamp,
DisableTimestamp: disableTimestamp,
}
case "json":
l.Formatter = &logrus.JSONFormatter{
TimestampFormat: timestampFormat,
DisableTimestamp: disableTimestamp,
}
default:
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
}
return nil
}

139
main.go
View File

@ -8,14 +8,16 @@ import (
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp"
"gopkg.in/yaml.v2"
)
type m map[string]interface{}
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
ctx, cancel := context.WithCancel(context.Background())
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
defer func() {
@ -31,7 +33,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// Print the config if in test, the exit comes later
if configTest {
b, err := yaml.Marshal(config.Settings)
b, err := yaml.Marshal(c.Settings)
if err != nil {
return nil, err
}
@ -40,33 +42,33 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
l.Println(string(b))
}
err := configLogger(config)
err := configLogger(l, c)
if err != nil {
return nil, NewContextualError("Failed to configure the logger", nil, err)
}
config.RegisterReloadCallback(func(c *Config) {
err := configLogger(c)
c.RegisterReloadCallback(func(c *config.C) {
err := configLogger(l, c)
if err != nil {
l.WithError(err).Error("Failed to configure the logger")
}
})
caPool, err := loadCAFromConfig(l, config)
caPool, err := loadCAFromConfig(l, c)
if err != nil {
//The errors coming out of loadCA are already nicely formatted
return nil, NewContextualError("Failed to load ca from config", nil, err)
}
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(config)
cs, err := NewCertStateFromConfig(c)
if err != nil {
//The errors coming out of NewCertStateFromConfig are already nicely formatted
return nil, NewContextualError("Failed to load certificate from config", nil, err)
}
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(l, cs.certificate, config)
fw, err := NewFirewallFromConfig(l, cs.certificate, c)
if err != nil {
return nil, NewContextualError("Error while loading firewall rules", nil, err)
}
@ -74,20 +76,20 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// TODO: make sure mask is 4 bytes
tunCidr := cs.certificate.Details.Ips[0]
routes, err := parseRoutes(config, tunCidr)
routes, err := parseRoutes(c, tunCidr)
if err != nil {
return nil, NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
if err != nil {
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
wireSSHReload(l, ssh, config)
wireSSHReload(l, ssh, c)
var sshStart func()
if config.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, config)
if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c)
if err != nil {
return nil, NewContextualError("Error while configuring the sshd", nil, err)
}
@ -101,7 +103,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
var routines int
// If `routines` is set, use that and ignore the specific values
if routines = config.GetInt("routines", 0); routines != 0 {
if routines = c.GetInt("routines", 0); routines != 0 {
if routines < 1 {
routines = 1
}
@ -110,8 +112,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
} else {
// deprecated and undocumented
tunQueues := config.GetInt("tun.routines", 1)
udpQueues := config.GetInt("listen.routines", 1)
tunQueues := c.GetInt("tun.routines", 1)
udpQueues := c.GetInt("listen.routines", 1)
if tunQueues > udpQueues {
routines = tunQueues
} else {
@ -125,8 +127,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// EXPERIMENTAL
// Intentionally not documented yet while we do more testing and determine
// a good default value.
conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") {
conntrackCacheTimeout := c.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
if routines > 1 && !c.IsSet("firewall.conntrack.routine_cache_timeout") {
// Use a different default if we are running with multiple routines
conntrackCacheTimeout = 1 * time.Second
}
@ -136,30 +138,30 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
var tun Inside
if !configTest {
config.CatchHUP(ctx)
c.CatchHUP(ctx)
switch {
case config.GetBool("tun.disabled", false):
tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
case c.GetBool("tun.disabled", false):
tun = newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
case tunFd != nil:
tun, err = newTunFromFd(
l,
*tunFd,
tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU),
c.GetInt("tun.mtu", DEFAULT_MTU),
routes,
unsafeRoutes,
config.GetInt("tun.tx_queue", 500),
c.GetInt("tun.tx_queue", 500),
)
default:
tun, err = newTun(
l,
config.GetString("tun.dev", ""),
c.GetString("tun.dev", ""),
tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU),
c.GetInt("tun.mtu", DEFAULT_MTU),
routes,
unsafeRoutes,
config.GetInt("tun.tx_queue", 500),
c.GetInt("tun.tx_queue", 500),
routines > 1,
)
}
@ -176,16 +178,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}()
// set up our UDP listener
udpConns := make([]*udpConn, routines)
port := config.GetInt("listen.port", 0)
udpConns := make([]*udp.Conn, routines)
port := c.GetInt("listen.port", 0)
if !configTest {
for i := 0; i < routines; i++ {
udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil {
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
}
udpServer.reloadConfig(config)
udpServer.ReloadConfig(c)
udpConns[i] = udpServer
// If port is dynamic, discover it
@ -201,7 +203,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// Set up my internal host map
var preferredRanges []*net.IPNet
rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
// First, check if 'preferred_ranges' is set and fallback to 'local_range'
if len(rawPreferredRanges) > 0 {
for _, rawPreferredRange := range rawPreferredRanges {
@ -216,7 +218,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// local_range was superseded by preferred_ranges. If it is still present,
// merge the local_range setting into preferred_ranges. We will probably
// deprecate local_range and remove in the future.
rawLocalRange := config.GetString("local_range", "")
rawLocalRange := c.GetString("local_range", "")
if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil {
@ -240,7 +242,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
hostMap.addUnsafeRoutes(&unsafeRoutes)
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
@ -249,26 +251,26 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
go hostMap.Promoter(config.GetInt("promoter.interval"))
*/
punchy := NewPunchyFromConfig(config)
punchy := NewPunchyFromConfig(c)
if punchy.Punch && !configTest {
l.Info("UDP hole punching enabled")
go hostMap.Punchy(ctx, udpConns[0])
}
amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
// fatal if am_lighthouse is enabled but we are using an ephemeral port
if amLighthouse && (config.GetInt("listen.port", 0) == 0) {
if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
}
// warn if am_lighthouse is enabled but upstream lighthouses exists
rawLighthouseHosts := config.GetStringSlice("lighthouse.hosts", []string{})
rawLighthouseHosts := c.GetStringSlice("lighthouse.hosts", []string{})
if amLighthouse && len(rawLighthouseHosts) != 0 {
l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
}
lighthouseHosts := make([]uint32, len(rawLighthouseHosts))
lighthouseHosts := make([]iputil.VpnIp, len(rawLighthouseHosts))
for i, host := range rawLighthouseHosts {
ip := net.ParseIP(host)
if ip == nil {
@ -277,7 +279,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
if !tunCidr.Contains(ip) {
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
}
lighthouseHosts[i] = ip2int(ip)
lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
}
lightHouse := NewLightHouse(
@ -286,47 +288,48 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
tunCidr,
lighthouseHosts,
//TODO: change to a duration
config.GetInt("lighthouse.interval", 10),
c.GetInt("lighthouse.interval", 10),
uint32(port),
udpConns[0],
punchy.Respond,
punchy.Delay,
config.GetBool("stats.lighthouse_metrics", false),
c.GetBool("stats.lighthouse_metrics", false),
)
remoteAllowList, err := config.GetRemoteAllowList("lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
if err != nil {
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
}
lightHouse.SetRemoteAllowList(remoteAllowList)
localAllowList, err := config.GetLocalAllowList("lighthouse.local_allow_list")
localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
if err != nil {
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
}
lightHouse.SetLocalAllowList(localAllowList)
//TODO: Move all of this inside functions in lighthouse.go
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
if !tunCidr.Contains(vpnIp) {
for k, v := range c.GetMap("static_host_map", map[interface{}]interface{}{}) {
ip := net.ParseIP(fmt.Sprintf("%v", k))
vpnIp := iputil.Ip2VpnIp(ip)
if !tunCidr.Contains(ip) {
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
}
vals, ok := v.([]interface{})
if ok {
for _, v := range vals {
ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
}
} else {
ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
}
}
@ -336,16 +339,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
var messageMetrics *MessageMetrics
if config.GetBool("stats.message_metrics", false) {
if c.GetBool("stats.message_metrics", false) {
messageMetrics = newMessageMetrics()
} else {
messageMetrics = newMessageMetricsOnlyRecvError()
}
handshakeConfig := HandshakeConfig{
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: c.GetInt("handshakes.retries", DefaultHandshakeRetries),
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
messageMetrics: messageMetrics,
}
@ -358,36 +361,35 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
serveDns := false
if config.GetBool("lighthouse.serve_dns", false) {
if config.GetBool("lighthouse.am_lighthouse", false) {
if c.GetBool("lighthouse.serve_dns", false) {
if c.GetBool("lighthouse.am_lighthouse", false) {
serveDns = true
} else {
l.Warn("DNS server refusing to run because this host is not a lighthouse.")
}
}
checkInterval := config.GetInt("timers.connection_alive_interval", 5)
pendingDeletionInterval := config.GetInt("timers.pending_deletion_interval", 10)
checkInterval := c.GetInt("timers.connection_alive_interval", 5)
pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
ifConfig := &InterfaceConfig{
HostMap: hostMap,
Inside: tun,
Outside: udpConns[0],
certState: cs,
Cipher: config.GetString("cipher", "aes"),
Cipher: c.GetString("cipher", "aes"),
Firewall: fw,
ServeDns: serveDns,
HandshakeManager: handshakeManager,
lightHouse: lightHouse,
checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval,
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
DropMulticast: config.GetBool("tun.drop_multicast", false),
UDPBatchSize: config.GetInt("listen.batch", 64),
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
DropMulticast: c.GetBool("tun.drop_multicast", false),
routines: routines,
MessageMetrics: messageMetrics,
version: buildVersion,
caPool: caPool,
disconnectInvalid: config.GetBool("pki.disconnect_invalid", false),
disconnectInvalid: c.GetBool("pki.disconnect_invalid", false),
ConntrackCacheTimeout: conntrackCacheTimeout,
l: l,
@ -413,7 +415,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// I don't want to make this initial commit too far-reaching though
ifce.writers = udpConns
ifce.RegisterConfigChangeCallbacks(config)
ifce.RegisterConfigChangeCallbacks(c)
go handshakeManager.Run(ctx, ifce)
go lightHouse.LhUpdateWorker(ctx, ifce)
@ -421,7 +423,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
// a context so that they can exit when the context is Done.
statsStart, err := startStats(l, config, buildVersion, configTest)
statsStart, err := startStats(l, c, buildVersion, configTest)
if err != nil {
return nil, NewContextualError("Failed to start stats emitter", nil, err)
}
@ -431,7 +434,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
//TODO: check if we _should_ be emitting stats
go ifce.emitStats(ctx, config.GetDuration("stats.interval", time.Second*10))
go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
@ -439,7 +442,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
var dnsStart func()
if amLighthouse && serveDns {
l.Debugln("Starting dns server")
dnsStart = dnsMain(l, hostMap, config)
dnsStart = dnsMain(l, hostMap, c)
}
return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil

View File

@ -4,8 +4,11 @@ import (
"fmt"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/header"
)
//TODO: this can probably move into the header package
type MessageMetrics struct {
rx [][]metrics.Counter
tx [][]metrics.Counter
@ -14,7 +17,7 @@ type MessageMetrics struct {
txUnknown metrics.Counter
}
func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
if m != nil {
if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
m.rx[t][s].Inc(i)
@ -23,7 +26,7 @@ func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64
}
}
}
func (m *MessageMetrics) Tx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) {
if m != nil {
if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
m.tx[t][s].Inc(i)

View File

@ -10,6 +10,10 @@ import (
"github.com/golang/protobuf/proto"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4"
)
@ -17,8 +21,8 @@ const (
minFwPacketLen = 4
)
func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int, localCache ConntrackCache) {
err := header.Parse(packet)
func (f *Interface) readOutsidePackets(addr *udp.Addr, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// TODO: best if we return this and let caller log
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
@ -32,30 +36,30 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
//l.Error("in packet ", header, packet[HeaderLen:])
// verify if we've seen this index before, otherwise respond to the handshake initiation
hostinfo, err := f.hostMap.QueryIndex(header.RemoteIndex)
hostinfo, err := f.hostMap.QueryIndex(h.RemoteIndex)
var ci *ConnectionState
if err == nil {
ci = hostinfo.ConnectionState
}
switch header.Type {
case message:
if !f.handleEncrypted(ci, addr, header) {
switch h.Type {
case header.Message:
if !f.handleEncrypted(ci, addr, h) {
return
}
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q, localCache)
f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache)
// Fallthrough to the bottom to record incoming traffic
case lightHouse:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) {
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, addr, h) {
return
}
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
@ -66,17 +70,17 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
return
}
lhh.HandleRequest(addr, hostinfo.hostId, d, f)
lhf(addr, hostinfo.vpnIp, d, f)
// Fallthrough to the bottom to record incoming traffic
case test:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) {
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, addr, h) {
return
}
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
@ -87,11 +91,11 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
return
}
if header.Subtype == testRequest {
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, addr)
f.send(test, testReply, ci, hostinfo, hostinfo.remote, d, nb, out)
f.send(header.Test, header.TestReply, ci, hostinfo, hostinfo.remote, d, nb, out)
}
// Fallthrough to the bottom to record incoming traffic
@ -99,19 +103,19 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case handshake:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
HandleIncomingHandshake(f, addr, packet, header, hostinfo)
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
HandleIncomingHandshake(f, addr, packet, h, hostinfo)
return
case recvError:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
f.handleRecvError(addr, header)
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(addr, h)
return
case closeTunnel:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) {
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, addr, h) {
return
}
@ -122,22 +126,22 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
return
default:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
return
}
f.handleHostRoaming(hostinfo, addr)
f.connectionManager.In(hostinfo.hostId)
f.connectionManager.In(hostinfo.vpnIp)
}
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
f.connectionManager.ClearIP(hostInfo.hostId)
f.connectionManager.ClearPendingDeletion(hostInfo.hostId)
f.lightHouse.DeleteVpnIP(hostInfo.hostId)
f.connectionManager.ClearIP(hostInfo.vpnIp)
f.connectionManager.ClearPendingDeletion(hostInfo.vpnIp)
f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
if hasHostMapLock {
f.hostMap.unlockedDeleteHostInfo(hostInfo)
@ -148,12 +152,12 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
f.send(header.CloseTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
}
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
if hostDidRoam(hostinfo.remote, addr) {
if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) {
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) {
if !hostinfo.remote.Equals(addr) {
if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) {
hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
return
}
@ -175,11 +179,11 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
}
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool {
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool {
// If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(f.l, header.MessageCounter) {
f.sendRecvError(addr, header.RemoteIndex)
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
f.sendRecvError(addr, h.RemoteIndex)
return false
}
@ -187,7 +191,7 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *
}
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen {
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
@ -215,7 +219,7 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
minLen := ihl
if !fp.Fragment && fp.Protocol != fwProtoICMP {
if !fp.Fragment && fp.Protocol != firewall.ProtoICMP {
minLen += minFwPacketLen
}
if len(data) < minLen {
@ -224,9 +228,9 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
// Firewall packets are locally oriented
if incoming {
fp.RemoteIP = binary.BigEndian.Uint32(data[12:16])
fp.LocalIP = binary.BigEndian.Uint32(data[16:20])
if fp.Fragment || fp.Protocol == fwProtoICMP {
fp.RemoteIP = iputil.Ip2VpnIp(data[12:16])
fp.LocalIP = iputil.Ip2VpnIp(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
} else {
@ -234,9 +238,9 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
} else {
fp.LocalIP = binary.BigEndian.Uint32(data[12:16])
fp.RemoteIP = binary.BigEndian.Uint32(data[16:20])
if fp.Fragment || fp.Protocol == fwProtoICMP {
fp.LocalIP = iputil.Ip2VpnIp(data[12:16])
fp.RemoteIP = iputil.Ip2VpnIp(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
} else {
@ -248,15 +252,15 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
return nil
}
func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, header *Header, nb []byte) ([]byte, error) {
func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, h *header.H, nb []byte) ([]byte, error) {
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], mc, nb)
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], mc, nb)
if err != nil {
return nil, err
}
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
hostinfo.logger(f.l).WithField("header", header).
hostinfo.logger(f.l).WithField("header", h).
Debugln("dropping out of window packet")
return nil, errors.New("out of window packet")
}
@ -264,10 +268,10 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil
}
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int, localCache ConntrackCache) {
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
@ -298,18 +302,18 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return
}
f.connectionManager.In(hostinfo.hostId)
f.connectionManager.In(hostinfo.vpnIp)
_, err = f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
}
}
func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
f.messageMetrics.Tx(recvError, 0, 1)
func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
f.messageMetrics.Tx(header.RecvError, 0, 1)
//TODO: this should be a signed message so we can trust that we should drop the index
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
f.outside.WriteTo(b, endpoint)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", index).
@ -318,7 +322,7 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
}
}
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr).

View File

@ -4,12 +4,14 @@ import (
"net"
"testing"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
)
func Test_newPacket(t *testing.T) {
p := &FirewallPacket{}
p := &firewall.Packet{}
// length fail
err := newPacket([]byte{0, 1}, true, p)
@ -44,7 +46,7 @@ func Test_newPacket(t *testing.T) {
Src: net.IPv4(10, 0, 0, 1),
Dst: net.IPv4(10, 0, 0, 2),
Options: []byte{0, 1, 0, 2},
Protocol: fwProtoTCP,
Protocol: firewall.ProtoTCP,
}
b, _ = h.Marshal()
@ -52,9 +54,9 @@ func Test_newPacket(t *testing.T) {
err = newPacket(b, true, p)
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(fwProtoTCP))
assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 2)))
assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 1)))
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
assert.Equal(t, p.RemotePort, uint16(3))
assert.Equal(t, p.LocalPort, uint16(4))
@ -74,8 +76,8 @@ func Test_newPacket(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(2))
assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 1)))
assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 2)))
assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
assert.Equal(t, p.RemotePort, uint16(6))
assert.Equal(t, p.LocalPort, uint16(5))
}

View File

@ -1,6 +1,10 @@
package nebula
import "time"
import (
"time"
"github.com/slackhq/nebula/config"
)
type Punchy struct {
Punch bool
@ -8,7 +12,7 @@ type Punchy struct {
Delay time.Duration
}
func NewPunchyFromConfig(c *Config) *Punchy {
func NewPunchyFromConfig(c *config.C) *Punchy {
p := &Punchy{}
if c.IsSet("punchy.punch") {

View File

@ -4,12 +4,14 @@ import (
"testing"
"time"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func TestNewPunchyFromConfig(t *testing.T) {
l := NewTestLogger()
c := NewConfig(l)
l := util.NewTestLogger()
c := config.NewC(l)
// Test defaults
p := NewPunchyFromConfig(c)

View File

@ -5,14 +5,17 @@ import (
"net"
"sort"
"sync"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
// forEachFunc is used to benefit folks that want to do work inside the lock
type forEachFunc func(addr *udpAddr, preferred bool)
type forEachFunc func(addr *udp.Addr, preferred bool)
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
type checkFuncV4 func(vpnIp uint32, to *Ip4AndPort) bool
type checkFuncV6 func(vpnIp uint32, to *Ip6AndPort) bool
type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool
type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool
// CacheMap is a struct that better represents the lighthouse cache for humans
// The string key is the owners vpnIp
@ -21,8 +24,8 @@ type CacheMap map[string]*Cache
// Cache is the other part of CacheMap to better represent the lighthouse cache for humans
// We don't reason about ipv4 vs ipv6 here
type Cache struct {
Learned []*udpAddr `json:"learned,omitempty"`
Reported []*udpAddr `json:"reported,omitempty"`
Learned []*udp.Addr `json:"learned,omitempty"`
Reported []*udp.Addr `json:"reported,omitempty"`
}
//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
@ -53,16 +56,16 @@ type RemoteList struct {
sync.RWMutex
// A deduplicated set of addresses. Any accessor should lock beforehand.
addrs []*udpAddr
addrs []*udp.Addr
// These are maps to store v4 and v6 addresses per lighthouse
// Map key is the vpnIp of the person that told us about this the cached entries underneath.
// For learned addresses, this is the vpnIp that sent the packet
cache map[uint32]*cache
cache map[iputil.VpnIp]*cache
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
badRemotes []*udpAddr
badRemotes []*udp.Addr
// A flag that the cache may have changed and addrs needs to be rebuilt
shouldRebuild bool
@ -71,8 +74,8 @@ type RemoteList struct {
// NewRemoteList creates a new empty RemoteList
func NewRemoteList() *RemoteList {
return &RemoteList{
addrs: make([]*udpAddr, 0),
cache: make(map[uint32]*cache),
addrs: make([]*udp.Addr, 0),
cache: make(map[iputil.VpnIp]*cache),
}
}
@ -98,7 +101,7 @@ func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc)
// CopyAddrs locks and makes a deep copy of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
if r == nil {
return nil
}
@ -107,7 +110,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
r.RLock()
defer r.RUnlock()
c := make([]*udpAddr, len(r.addrs))
c := make([]*udp.Addr, len(r.addrs))
for i, v := range r.addrs {
c[i] = v.Copy()
}
@ -118,7 +121,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available
//TODO: this needs to support the allow list list
func (r *RemoteList) LearnRemote(ownerVpnIp uint32, addr *udpAddr) {
func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) {
r.Lock()
defer r.Unlock()
if v4 := addr.IP.To4(); v4 != nil {
@ -139,8 +142,8 @@ func (r *RemoteList) CopyCache() *CacheMap {
c := cm[vpnIp]
if c == nil {
c = &Cache{
Learned: make([]*udpAddr, 0),
Reported: make([]*udpAddr, 0),
Learned: make([]*udp.Addr, 0),
Reported: make([]*udp.Addr, 0),
}
cm[vpnIp] = c
}
@ -148,7 +151,7 @@ func (r *RemoteList) CopyCache() *CacheMap {
}
for owner, mc := range r.cache {
c := getOrMake(IntIp(owner).String())
c := getOrMake(owner.String())
if mc.v4 != nil {
if mc.v4.learned != nil {
@ -175,7 +178,7 @@ func (r *RemoteList) CopyCache() *CacheMap {
}
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
func (r *RemoteList) BlockRemote(bad *udpAddr) {
func (r *RemoteList) BlockRemote(bad *udp.Addr) {
r.Lock()
defer r.Unlock()
@ -192,11 +195,11 @@ func (r *RemoteList) BlockRemote(bad *udpAddr) {
}
// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
func (r *RemoteList) CopyBlockedRemotes() []*udpAddr {
func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr {
r.RLock()
defer r.RUnlock()
c := make([]*udpAddr, len(r.badRemotes))
c := make([]*udp.Addr, len(r.badRemotes))
for i, v := range r.badRemotes {
c[i] = v.Copy()
}
@ -228,7 +231,7 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
}
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
for _, v := range r.badRemotes {
if v.Equals(remote) {
return true
@ -239,14 +242,14 @@ func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) {
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
}
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4AndPort, check checkFuncV4) {
func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
@ -263,7 +266,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4And
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
@ -276,14 +279,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) {
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
}
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6AndPort, check checkFuncV6) {
func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
@ -300,7 +303,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6And
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
@ -313,7 +316,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
// The caller must dirty the learned address cache if required
func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
@ -328,7 +331,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
// The caller must dirty the learned address cache if required
func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp uint32) *cacheV6 {
func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}

View File

@ -4,6 +4,7 @@ import (
"net"
"testing"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
)
@ -13,18 +14,18 @@ func TestRemoteList_Rebuild(t *testing.T) {
0,
0,
[]*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is duped
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is duped
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe
},
func(uint32, *Ip4AndPort) bool { return true },
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
@ -37,7 +38,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
},
func(uint32, *Ip6AndPort) bool { return true },
func(iputil.VpnIp, *Ip6AndPort) bool { return true },
)
rl.Rebuild([]*net.IPNet{})
@ -106,16 +107,16 @@ func BenchmarkFullRebuild(b *testing.B) {
0,
0,
[]*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
},
func(uint32, *Ip4AndPort) bool { return true },
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
@ -127,7 +128,7 @@ func BenchmarkFullRebuild(b *testing.B) {
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
},
func(uint32, *Ip6AndPort) bool { return true },
func(iputil.VpnIp, *Ip6AndPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
@ -171,16 +172,16 @@ func BenchmarkSortRebuild(b *testing.B) {
0,
0,
[]*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
},
func(uint32, *Ip4AndPort) bool { return true },
func(iputil.VpnIp, *Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
@ -192,7 +193,7 @@ func BenchmarkSortRebuild(b *testing.B) {
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
},
func(uint32, *Ip6AndPort) bool { return true },
func(iputil.VpnIp, *Ip6AndPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {

56
ssh.go
View File

@ -15,7 +15,11 @@ import (
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp"
)
type sshListHostMapFlags struct {
@ -45,8 +49,8 @@ type sshCreateTunnelFlags struct {
Address string
}
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
c.RegisterReloadCallback(func(c *Config) {
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
c.RegisterReloadCallback(func(c *config.C) {
if c.GetBool("sshd.enabled", false) {
sshRun, err := configSSH(l, ssh, c)
if err != nil {
@ -66,7 +70,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
// updates the passed-in SSHServer. On success, it returns a function
// that callers may invoke to run the configured ssh server. On
// failure, it returns nil, error.
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) (func(), error) {
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
//TODO conntrack list
//TODO print firewall rules or hash?
@ -351,7 +355,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
hm := listHostMap(hostMap)
sort.Slice(hm, func(i, j int) bool {
return bytes.Compare(hm[i].VpnIP, hm[j].VpnIP) < 0
return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
})
if fs.Json || fs.Pretty {
@ -368,7 +372,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
} else {
for _, v := range hm {
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, v.RemoteAddrs))
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs))
if err != nil {
return err
}
@ -386,7 +390,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
}
type lighthouseInfo struct {
VpnIP net.IP `json:"vpnIp"`
VpnIp string `json:"vpnIp"`
Addrs *CacheMap `json:"addrs"`
}
@ -395,7 +399,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
x := 0
for k, v := range lightHouse.addrMap {
addrMap[x] = lighthouseInfo{
VpnIP: int2ip(k),
VpnIp: k.String(),
Addrs: v.CopyCache(),
}
x++
@ -403,7 +407,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
lightHouse.RUnlock()
sort.Slice(addrMap, func(i, j int) bool {
return bytes.Compare(addrMap[i].VpnIP, addrMap[j].VpnIP) < 0
return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0
})
if fs.Json || fs.Pretty {
@ -424,7 +428,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
if err != nil {
return err
}
err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, string(b)))
err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b)))
if err != nil {
return err
}
@ -470,7 +474,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
vpnIp := iputil.Ip2VpnIp(parsedIp)
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@ -499,19 +503,19 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
vpnIp := iputil.Ip2VpnIp(parsedIp)
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
if !flags.LocalOnly {
ifce.send(
closeTunnel,
header.CloseTunnel,
0,
hostInfo.ConnectionState,
hostInfo,
@ -542,30 +546,30 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
vpnIp := iputil.Ip2VpnIp(parsedIp)
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, _ := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp)
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
}
hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIP(uint32(vpnIp))
hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
}
var addr *udpAddr
var addr *udp.Addr
if flags.Address != "" {
addr = NewUDPAddrFromString(flags.Address)
addr = udp.NewAddrFromString(flags.Address)
if addr == nil {
return w.WriteLine("Address could not be parsed")
}
}
hostInfo = ifce.handshakeManager.AddVpnIP(vpnIp)
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
if addr != nil {
hostInfo.SetRemote(addr)
}
@ -589,7 +593,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("No address was provided")
}
addr := NewUDPAddrFromString(flags.Address)
addr := udp.NewAddrFromString(flags.Address)
if addr == nil {
return w.WriteLine("Address could not be parsed")
}
@ -599,12 +603,12 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
vpnIp := iputil.Ip2VpnIp(parsedIp)
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -680,12 +684,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
vpnIp := iputil.Ip2VpnIp(parsedIp)
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -742,12 +746,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
vpnIp := iputil.Ip2VpnIp(parsedIp)
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIP(vpnIp)
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}

View File

@ -15,12 +15,13 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
// startStats initializes stats from config. On success, if any futher work
// is needed to serve stats, it returns a func to handle that work. If no
// work is needed, it'll return nil. On failure, it returns nil, error.
func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) (func(), error) {
func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) {
mType := c.GetString("stats.type", "")
if mType == "" || mType == "none" {
return nil, nil
@ -57,7 +58,7 @@ func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest boo
return startFn, nil
}
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error {
proto := c.GetString("stats.protocol", "tcp")
host := c.GetString("stats.host", "")
if host == "" {
@ -77,7 +78,7 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest
return nil
}
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) (func(), error) {
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) {
namespace := c.GetString("stats.namespace", "")
subsystem := c.GetString("stats.subsystem", "")

View File

@ -2,12 +2,14 @@ package nebula
import (
"time"
"github.com/slackhq/nebula/firewall"
)
// How many timer objects should be cached
const timerCacheMax = 50000
var emptyFWPacket = FirewallPacket{}
var emptyFWPacket = firewall.Packet{}
type TimerWheel struct {
// Current tick
@ -42,7 +44,7 @@ type TimeoutList struct {
// Represents an item within a tick
type TimeoutItem struct {
Packet FirewallPacket
Packet firewall.Packet
Next *TimeoutItem
}
@ -73,8 +75,8 @@ func NewTimerWheel(min, max time.Duration) *TimerWheel {
return &tw
}
// Add will add a FirewallPacket to the wheel in it's proper timeout
func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem {
// Add will add a firewall.Packet to the wheel in it's proper timeout
func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem {
// Check and see if we should progress the tick
tw.advance(time.Now())
@ -103,7 +105,7 @@ func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem
return ti
}
func (tw *TimerWheel) Purge() (FirewallPacket, bool) {
func (tw *TimerWheel) Purge() (firewall.Packet, bool) {
if tw.expired.Head == nil {
return emptyFWPacket, false
}

View File

@ -3,6 +3,8 @@ package nebula
import (
"sync"
"time"
"github.com/slackhq/nebula/iputil"
)
// How many timer objects should be cached
@ -43,7 +45,7 @@ type SystemTimeoutList struct {
// Represents an item within a tick
type SystemTimeoutItem struct {
Item uint32
Item iputil.VpnIp
Next *SystemTimeoutItem
}
@ -74,7 +76,7 @@ func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel {
return &tw
}
func (tw *SystemTimerWheel) Add(v uint32, timeout time.Duration) *SystemTimeoutItem {
func (tw *SystemTimerWheel) Add(v iputil.VpnIp, timeout time.Duration) *SystemTimeoutItem {
tw.lock.Lock()
defer tw.lock.Unlock()

View File

@ -5,6 +5,7 @@ import (
"testing"
"time"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
)
@ -51,7 +52,7 @@ func TestSystemTimerWheel_findWheel(t *testing.T) {
func TestSystemTimerWheel_Add(t *testing.T) {
tw := NewSystemTimerWheel(time.Second, time.Second*10)
fp1 := ip2int(net.ParseIP("1.2.3.4"))
fp1 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
tw.Add(fp1, time.Second*1)
// Make sure we set head and tail properly
@ -62,7 +63,7 @@ func TestSystemTimerWheel_Add(t *testing.T) {
assert.Nil(t, tw.wheel[2].Tail.Next)
// Make sure we only modify head
fp2 := ip2int(net.ParseIP("1.2.3.4"))
fp2 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
tw.Add(fp2, time.Second*1)
assert.Equal(t, fp2, tw.wheel[2].Head.Item)
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
@ -85,7 +86,7 @@ func TestSystemTimerWheel_Purge(t *testing.T) {
assert.NotNil(t, tw.lastTick)
assert.Equal(t, 0, tw.current)
fps := []uint32{9, 10, 11, 12}
fps := []iputil.VpnIp{9, 10, 11, 12}
//fp1 := ip2int(net.ParseIP("1.2.3.4"))

View File

@ -4,6 +4,7 @@ import (
"testing"
"time"
"github.com/slackhq/nebula/firewall"
"github.com/stretchr/testify/assert"
)
@ -50,7 +51,7 @@ func TestTimerWheel_findWheel(t *testing.T) {
func TestTimerWheel_Add(t *testing.T) {
tw := NewTimerWheel(time.Second, time.Second*10)
fp1 := FirewallPacket{}
fp1 := firewall.Packet{}
tw.Add(fp1, time.Second*1)
// Make sure we set head and tail properly
@ -61,7 +62,7 @@ func TestTimerWheel_Add(t *testing.T) {
assert.Nil(t, tw.wheel[2].Tail.Next)
// Make sure we only modify head
fp2 := FirewallPacket{}
fp2 := firewall.Packet{}
tw.Add(fp2, time.Second*1)
assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
@ -84,7 +85,7 @@ func TestTimerWheel_Purge(t *testing.T) {
assert.NotNil(t, tw.lastTick)
assert.Equal(t, 0, tw.current)
fps := []FirewallPacket{
fps := []firewall.Packet{
{LocalIP: 1},
{LocalIP: 2},
{LocalIP: 3},

View File

@ -4,6 +4,8 @@ import (
"fmt"
"net"
"strconv"
"github.com/slackhq/nebula/config"
)
const DEFAULT_MTU = 1300
@ -14,10 +16,10 @@ type route struct {
via *net.IP
}
func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
func parseRoutes(c *config.C, network *net.IPNet) ([]route, error) {
var err error
r := config.Get("tun.routes")
r := c.Get("tun.routes")
if r == nil {
return []route{}, nil
}
@ -84,10 +86,10 @@ func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
return routes, nil
}
func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]route, error) {
var err error
r := config.Get("tun.unsafe_routes")
r := c.Get("tun.unsafe_routes")
if r == nil {
return []route{}, nil
}
@ -110,7 +112,7 @@ func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
rMtu, ok := m["mtu"]
if !ok {
rMtu = config.GetInt("tun.mtu", DEFAULT_MTU)
rMtu = c.GetInt("tun.mtu", DEFAULT_MTU)
}
mtu, ok := rMtu.(int)

View File

@ -5,12 +5,14 @@ import (
"net"
"testing"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func Test_parseRoutes(t *testing.T) {
l := NewTestLogger()
c := NewConfig(l)
l := util.NewTestLogger()
c := config.NewC(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config
@ -105,8 +107,8 @@ func Test_parseRoutes(t *testing.T) {
}
func Test_parseUnsafeRoutes(t *testing.T) {
l := NewTestLogger()
c := NewConfig(l)
l := util.NewTestLogger()
c := config.NewC(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config

20
udp/conn.go Normal file
View File

@ -0,0 +1,20 @@
package udp
import (
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
)
const MTU = 9001
type EncReader func(
addr *Addr,
out []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
)

14
udp/temp.go Normal file
View File

@ -0,0 +1,14 @@
package udp
import (
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
)
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
type EncWriter interface {
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
}
type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter)

View File

@ -1,4 +1,4 @@
package nebula
package udp
import (
"encoding/json"
@ -7,32 +7,34 @@ import (
"strconv"
)
type udpAddr struct {
type m map[string]interface{}
type Addr struct {
IP net.IP
Port uint16
}
func NewUDPAddr(ip net.IP, port uint16) *udpAddr {
addr := udpAddr{IP: make([]byte, net.IPv6len), Port: port}
func NewAddr(ip net.IP, port uint16) *Addr {
addr := Addr{IP: make([]byte, net.IPv6len), Port: port}
copy(addr.IP, ip.To16())
return &addr
}
func NewUDPAddrFromString(s string) *udpAddr {
ip, port, err := parseIPAndPort(s)
func NewAddrFromString(s string) *Addr {
ip, port, err := ParseIPAndPort(s)
//TODO: handle err
_ = err
return &udpAddr{IP: ip.To16(), Port: port}
return &Addr{IP: ip.To16(), Port: port}
}
func (ua *udpAddr) Equals(t *udpAddr) bool {
func (ua *Addr) Equals(t *Addr) bool {
if t == nil || ua == nil {
return t == nil && ua == nil
}
return ua.IP.Equal(t.IP) && ua.Port == t.Port
}
func (ua *udpAddr) String() string {
func (ua *Addr) String() string {
if ua == nil {
return "<nil>"
}
@ -40,7 +42,7 @@ func (ua *udpAddr) String() string {
return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
}
func (ua *udpAddr) MarshalJSON() ([]byte, error) {
func (ua *Addr) MarshalJSON() ([]byte, error) {
if ua == nil {
return nil, nil
}
@ -48,12 +50,12 @@ func (ua *udpAddr) MarshalJSON() ([]byte, error) {
return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
}
func (ua *udpAddr) Copy() *udpAddr {
func (ua *Addr) Copy() *Addr {
if ua == nil {
return nil
}
nu := udpAddr{
nu := Addr{
Port: ua.Port,
IP: make(net.IP, len(ua.IP)),
}
@ -62,7 +64,7 @@ func (ua *udpAddr) Copy() *udpAddr {
return &nu
}
func parseIPAndPort(s string) (net.IP, uint16, error) {
func ParseIPAndPort(s string) (net.IP, uint16, error) {
rIp, sPort, err := net.SplitHostPort(s)
if err != nil {
return nil, 0, err

View File

@ -1,7 +1,7 @@
//go:build !e2e_testing
// +build !e2e_testing
package nebula
package udp
import (
"fmt"
@ -34,6 +34,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *udpConn) Rebind() error {
func (u *Conn) Rebind() error {
return nil
}

View File

@ -1,7 +1,7 @@
//go:build !e2e_testing
// +build !e2e_testing
package nebula
package udp
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
@ -37,7 +37,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *udpConn) Rebind() error {
func (u *Conn) Rebind() error {
file, err := u.File()
if err != nil {
return err

View File

@ -1,7 +1,7 @@
//go:build !e2e_testing
// +build !e2e_testing
package nebula
package udp
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
@ -36,6 +36,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *udpConn) Rebind() error {
func (u *Conn) Rebind() error {
return nil
}

View File

@ -5,7 +5,7 @@
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows.
package nebula
package udp
import (
"context"
@ -13,36 +13,39 @@ import (
"net"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
)
type udpConn struct {
type Conn struct {
*net.UDPConn
l *logrus.Logger
}
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
if err != nil {
return nil, err
}
if uc, ok := pc.(*net.UDPConn); ok {
return &udpConn{UDPConn: uc, l: l}, nil
return &Conn{UDPConn: uc, l: l}, nil
}
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
}
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
func (uc *Conn) WriteTo(b []byte, addr *Addr) error {
_, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
return err
}
func (uc *udpConn) LocalAddr() (*udpAddr, error) {
func (uc *Conn) LocalAddr() (*Addr, error) {
a := uc.UDPConn.LocalAddr()
switch v := a.(type) {
case *net.UDPAddr:
addr := &udpAddr{IP: make([]byte, len(v.IP))}
addr := &Addr{IP: make([]byte, len(v.IP))}
copy(addr.IP, v.IP)
addr.Port = uint16(v.Port)
return addr, nil
@ -52,11 +55,11 @@ func (uc *udpConn) LocalAddr() (*udpAddr, error) {
}
}
func (u *udpConn) reloadConfig(c *Config) {
func (u *Conn) ReloadConfig(c *config.C) {
// TODO
}
func NewUDPStatsEmitter(udpConns []*udpConn) func() {
func NewUDPStatsEmitter(udpConns []*Conn) func() {
// No UDP stats for non-linux
return func() {}
}
@ -65,32 +68,24 @@ type rawMessage struct {
Len uint32
}
func (u *udpConn) ListenOut(f *Interface, q int) {
plaintext := make([]byte, mtu)
buffer := make([]byte, mtu)
header := &Header{}
fwPacket := &FirewallPacket{}
udpAddr := &udpAddr{IP: make([]byte, 16)}
func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
udpAddr := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12)
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
// Just read one packet at a time
n, rua, err := u.ReadFromUDP(buffer)
if err != nil {
f.l.WithError(err).Error("Failed to read packets")
u.l.WithError(err).Error("Failed to read packets")
continue
}
udpAddr.IP = rua.IP
udpAddr.Port = uint16(rua.Port)
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l))
r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
return !addr.Equals(newaddr)
}

View File

@ -1,7 +1,7 @@
//go:build !android && !e2e_testing
// +build !android,!e2e_testing
package nebula
package udp
import (
"encoding/binary"
@ -12,14 +12,18 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"golang.org/x/sys/unix"
)
//TODO: make it support reload as best you can!
type udpConn struct {
type Conn struct {
sysFd int
l *logrus.Logger
batch int
}
var x int
@ -41,7 +45,7 @@ const (
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) {
syscall.ForkLock.RLock()
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
if err == nil {
@ -73,36 +77,36 @@ func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, e
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
//l.Println(v, err)
return &udpConn{sysFd: fd, l: l}, err
return &Conn{sysFd: fd, l: l, batch: batch}, err
}
func (u *udpConn) Rebind() error {
func (u *Conn) Rebind() error {
return nil
}
func (u *udpConn) SetRecvBuffer(n int) error {
func (u *Conn) SetRecvBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
}
func (u *udpConn) SetSendBuffer(n int) error {
func (u *Conn) SetSendBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
}
func (u *udpConn) GetRecvBuffer() (int, error) {
func (u *Conn) GetRecvBuffer() (int, error) {
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
}
func (u *udpConn) GetSendBuffer() (int, error) {
func (u *Conn) GetSendBuffer() (int, error) {
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
}
func (u *udpConn) LocalAddr() (*udpAddr, error) {
func (u *Conn) LocalAddr() (*Addr, error) {
sa, err := unix.Getsockname(u.sysFd)
if err != nil {
return nil, err
}
addr := &udpAddr{}
addr := &Addr{}
switch sa := sa.(type) {
case *unix.SockaddrInet4:
addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
@ -115,25 +119,21 @@ func (u *udpConn) LocalAddr() (*udpAddr, error) {
return addr, nil
}
func (u *udpConn) ListenOut(f *Interface, q int) {
plaintext := make([]byte, mtu)
header := &Header{}
fwPacket := &FirewallPacket{}
udpAddr := &udpAddr{}
func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
udpAddr := &Addr{}
nb := make([]byte, 12, 12)
lhh := f.lightHouse.NewRequestHandler()
//TODO: should we track this?
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize)
msgs, buffers, names := u.PrepareRawMessages(u.batch)
read := u.ReadMulti
if f.udpBatchSize == 1 {
if u.batch == 1 {
read = u.ReadSingle
}
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := read(msgs)
if err != nil {
@ -145,12 +145,12 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
for i := 0; i < n; i++ {
udpAddr.IP = names[i][8:24]
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}
}
func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
for {
n, _, err := unix.Syscall6(
unix.SYS_RECVMSG,
@ -171,7 +171,7 @@ func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
}
}
func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) {
func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
for {
n, _, err := unix.Syscall6(
unix.SYS_RECVMMSG,
@ -191,7 +191,7 @@ func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) {
}
}
func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
func (u *Conn) WriteTo(b []byte, addr *Addr) error {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
@ -221,7 +221,7 @@ func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
}
}
func (u *udpConn) reloadConfig(c *Config) {
func (u *Conn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0)
if b > 0 {
err := u.SetRecvBuffer(b)
@ -253,7 +253,7 @@ func (u *udpConn) reloadConfig(c *Config) {
}
}
func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error {
func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error {
var vallen uint32 = 4 * _SK_MEMINFO_VARS
_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
if err != 0 {
@ -262,7 +262,7 @@ func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error {
return nil
}
func NewUDPStatsEmitter(udpConns []*udpConn) func() {
func NewUDPStatsEmitter(udpConns []*Conn) func() {
// Check if our kernel supports SO_MEMINFO before registering the gauges
var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
var meminfo _SK_MEMINFO
@ -293,7 +293,3 @@ func NewUDPStatsEmitter(udpConns []*udpConn) func() {
}
}
}
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
return !addr.Equals(newaddr)
}

View File

@ -4,7 +4,7 @@
// +build !android
// +build !e2e_testing
package nebula
package udp
import (
"golang.org/x/sys/unix"
@ -30,13 +30,13 @@ type rawMessage struct {
Len uint32
}
func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)
for i := range msgs {
buffers[i] = make([]byte, mtu)
buffers[i] = make([]byte, MTU)
names[i] = make([]byte, unix.SizeofSockaddrInet6)
//TODO: this is still silly, no need for an array

View File

@ -4,7 +4,7 @@
// +build !android
// +build !e2e_testing
package nebula
package udp
import (
"golang.org/x/sys/unix"
@ -33,13 +33,13 @@ type rawMessage struct {
Pad0 [4]byte
}
func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)
for i := range msgs {
buffers[i] = make([]byte, mtu)
buffers[i] = make([]byte, MTU)
names[i] = make([]byte, unix.SizeofSockaddrInet6)
//TODO: this is still silly, no need for an array

View File

@ -1,16 +1,19 @@
//go:build e2e_testing
// +build e2e_testing
package nebula
package udp
import (
"fmt"
"net"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
)
type UdpPacket struct {
type Packet struct {
ToIp net.IP
ToPort uint16
FromIp net.IP
@ -18,8 +21,8 @@ type UdpPacket struct {
Data []byte
}
func (u *UdpPacket) Copy() *UdpPacket {
n := &UdpPacket{
func (u *Packet) Copy() *Packet {
n := &Packet{
ToIp: make(net.IP, len(u.ToIp)),
ToPort: u.ToPort,
FromIp: make(net.IP, len(u.FromIp)),
@ -33,20 +36,20 @@ func (u *UdpPacket) Copy() *UdpPacket {
return n
}
type udpConn struct {
addr *udpAddr
type Conn struct {
Addr *Addr
rxPackets chan *UdpPacket // Packets to receive into nebula
txPackets chan *UdpPacket // Packets transmitted outside by nebula
RxPackets chan *Packet // Packets to receive into nebula
TxPackets chan *Packet // Packets transmitted outside by nebula
l *logrus.Logger
}
func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error) {
return &udpConn{
addr: &udpAddr{net.ParseIP(ip), uint16(port)},
rxPackets: make(chan *UdpPacket, 1),
txPackets: make(chan *UdpPacket, 1),
func NewListener(l *logrus.Logger, ip string, port int, _ bool, _ int) (*Conn, error) {
return &Conn{
Addr: &Addr{net.ParseIP(ip), uint16(port)},
RxPackets: make(chan *Packet, 1),
TxPackets: make(chan *Packet, 1),
l: l,
}, nil
}
@ -54,8 +57,8 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error
// Send will place a UdpPacket onto the receive queue for nebula to consume
// this is an encrypted packet or a handshake message in most cases
// packets were transmitted from another nebula node, you can send them with Tun.Send
func (u *udpConn) Send(packet *UdpPacket) {
h := &Header{}
func (u *Conn) Send(packet *Packet) {
h := &header.H{}
if err := h.Parse(packet.Data); err != nil {
panic(err)
}
@ -63,19 +66,19 @@ func (u *udpConn) Send(packet *UdpPacket) {
WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
WithField("dataLen", len(packet.Data)).
Info("UDP receiving injected packet")
u.rxPackets <- packet
u.RxPackets <- packet
}
// Get will pull a UdpPacket from the transmit queue
// nebula meant to send this message on the network, it will be encrypted
// packets were ingested from the tun side (in most cases), you can send them with Tun.Send
func (u *udpConn) Get(block bool) *UdpPacket {
func (u *Conn) Get(block bool) *Packet {
if block {
return <-u.txPackets
return <-u.TxPackets
}
select {
case p := <-u.txPackets:
case p := <-u.TxPackets:
return p
default:
return nil
@ -86,56 +89,49 @@ func (u *udpConn) Get(block bool) *UdpPacket {
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
p := &UdpPacket{
func (u *Conn) WriteTo(b []byte, addr *Addr) error {
p := &Packet{
Data: make([]byte, len(b), len(b)),
FromIp: make([]byte, 16),
FromPort: u.addr.Port,
FromPort: u.Addr.Port,
ToIp: make([]byte, 16),
ToPort: addr.Port,
}
copy(p.Data, b)
copy(p.ToIp, addr.IP.To16())
copy(p.FromIp, u.addr.IP.To16())
copy(p.FromIp, u.Addr.IP.To16())
u.txPackets <- p
u.TxPackets <- p
return nil
}
func (u *udpConn) ListenOut(f *Interface, q int) {
plaintext := make([]byte, mtu)
header := &Header{}
fwPacket := &FirewallPacket{}
ua := &udpAddr{IP: make([]byte, 16)}
func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
ua := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12)
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
p := <-u.rxPackets
p := <-u.RxPackets
ua.Port = p.FromPort
copy(ua.IP, p.FromIp.To16())
f.readOutsidePackets(ua, plaintext[:0], p.Data, header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}
func (u *udpConn) reloadConfig(*Config) {}
func (u *Conn) ReloadConfig(*config.C) {}
func NewUDPStatsEmitter(_ []*udpConn) func() {
func NewUDPStatsEmitter(_ []*Conn) func() {
// No UDP stats for non-linux
return func() {}
}
func (u *udpConn) LocalAddr() (*udpAddr, error) {
return u.addr, nil
func (u *Conn) LocalAddr() (*Addr, error) {
return u.Addr, nil
}
func (u *udpConn) Rebind() error {
func (u *Conn) Rebind() error {
return nil
}
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
return !addr.Equals(newaddr)
}

View File

@ -1,7 +1,7 @@
//go:build !e2e_testing
// +build !e2e_testing
package nebula
package udp
// Windows support is primarily implemented in udp_generic, besides NewListenConfig
@ -24,6 +24,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *udpConn) Rebind() error {
func (u *Conn) Rebind() error {
return nil
}

View File

@ -1,4 +1,4 @@
package nebula
package util
import (
"io/ioutil"
@ -17,13 +17,12 @@ func NewTestLogger() *logrus.Logger {
}
switch v {
case "1":
// This is the default level but we are being explicit
l.SetLevel(logrus.InfoLevel)
case "2":
l.SetLevel(logrus.DebugLevel)
case "3":
l.SetLevel(logrus.TraceLevel)
default:
l.SetLevel(logrus.InfoLevel)
}
return l