diff --git a/allow_list.go b/allow_list.go new file mode 100644 index 0000000..223d185 --- /dev/null +++ b/allow_list.go @@ -0,0 +1,48 @@ +package nebula + +import ( + "fmt" + "regexp" +) + +type AllowList struct { + // The values of this cidrTree are `bool`, signifying allow/deny + cidrTree *CIDRTree + + // To avoid ambiguity, all rules must be true, or all rules must be false. + nameRules []AllowListNameRule +} + +type AllowListNameRule struct { + Name *regexp.Regexp + Allow bool +} + +func (al *AllowList) Allow(ip uint32) bool { + if al == nil { + return true + } + + result := al.cidrTree.MostSpecificContains(ip) + switch v := result.(type) { + case bool: + return v + default: + panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result)) + } +} + +func (al *AllowList) AllowName(name string) bool { + if al == nil || len(al.nameRules) == 0 { + return true + } + + for _, rule := range al.nameRules { + if rule.Name.MatchString(name) { + return rule.Allow + } + } + + // If no rules match, return the default, which is the inverse of the rules + return !al.nameRules[0].Allow +} diff --git a/allow_list_test.go b/allow_list_test.go new file mode 100644 index 0000000..56d9ef8 --- /dev/null +++ b/allow_list_test.go @@ -0,0 +1,47 @@ +package nebula + +import ( + "net" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAllowList_Allow(t *testing.T) { + assert.Equal(t, true, ((*AllowList)(nil)).Allow(ip2int(net.ParseIP("1.1.1.1")))) + + tree := NewCIDRTree() + tree.AddCIDR(getCIDR("0.0.0.0/0"), true) + tree.AddCIDR(getCIDR("10.0.0.0/8"), false) + tree.AddCIDR(getCIDR("10.42.42.0/24"), true) + al := &AllowList{cidrTree: tree} + + assert.Equal(t, true, al.Allow(ip2int(net.ParseIP("1.1.1.1")))) + assert.Equal(t, false, al.Allow(ip2int(net.ParseIP("10.0.0.4")))) + assert.Equal(t, true, al.Allow(ip2int(net.ParseIP("10.42.42.42")))) +} + +func TestAllowList_AllowName(t *testing.T) { + assert.Equal(t, true, ((*AllowList)(nil)).AllowName("docker0")) + + rules := []AllowListNameRule{ + {Name: regexp.MustCompile("^docker.*$"), Allow: false}, + {Name: regexp.MustCompile("^tun.*$"), Allow: false}, + } + al := &AllowList{nameRules: rules} + + assert.Equal(t, false, al.AllowName("docker0")) + assert.Equal(t, false, al.AllowName("tun0")) + assert.Equal(t, true, al.AllowName("eth0")) + + rules = []AllowListNameRule{ + {Name: regexp.MustCompile("^eth.*$"), Allow: true}, + {Name: regexp.MustCompile("^ens.*$"), Allow: true}, + } + al = &AllowList{nameRules: rules} + + assert.Equal(t, false, al.AllowName("docker0")) + assert.Equal(t, true, al.AllowName("eth0")) + assert.Equal(t, true, al.AllowName("ens5")) +} diff --git a/config.go b/config.go index 8c88f0c..570cd85 100644 --- a/config.go +++ b/config.go @@ -6,9 +6,11 @@ import ( "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" "io/ioutil" + "net" "os" "os/signal" "path/filepath" + "regexp" "sort" "strconv" "strings" @@ -213,6 +215,129 @@ func (c *Config) GetDuration(k string, d time.Duration) time.Duration { return v } +func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error) { + r := c.Get(k) + if r == nil { + return nil, nil + } + + rawMap, ok := r.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid type: %T", k, r) + } + + tree := NewCIDRTree() + var nameRules []AllowListNameRule + + firstValue := true + allValuesMatch := true + defaultSet := false + var allValues bool + + for rawKey, rawValue := range rawMap { + rawCIDR, ok := rawKey.(string) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) + } + + // Special rule for interface names + if rawCIDR == "interfaces" { + if !allowInterfaces { + return nil, fmt.Errorf("config `%s` does not support `interfaces`", k) + } + var err error + nameRules, err = c.getAllowListInterfaces(k, rawValue) + if err != nil { + return nil, err + } + + continue + } + + value, ok := rawValue.(bool) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue) + } + + _, cidr, err := net.ParseCIDR(rawCIDR) + if err != nil { + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + } + + // TODO: should we error on duplicate CIDRs in the config? + tree.AddCIDR(cidr, value) + + if firstValue { + allValues = value + firstValue = false + } else { + if value != allValues { + allValuesMatch = false + } + } + + // Check if this is 0.0.0.0/0 + bits, size := cidr.Mask.Size() + if bits == 0 && size == 32 { + defaultSet = true + } + } + + if !defaultSet { + if allValuesMatch { + _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0") + tree.AddCIDR(zeroCIDR, !allValues) + } else { + return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k) + } + } + + return &AllowList{cidrTree: tree, nameRules: nameRules}, nil +} + +func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) { + var nameRules []AllowListNameRule + + rawRules, ok := v.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v) + } + + firstEntry := true + var allValues bool + for rawName, rawAllow := range rawRules { + name, ok := rawName.(string) + if !ok { + return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName) + } + allow, ok := rawAllow.(bool) + if !ok { + return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow) + } + + nameRE, err := regexp.Compile("^" + name + "$") + if err != nil { + return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err) + } + + nameRules = append(nameRules, AllowListNameRule{ + Name: nameRE, + Allow: allow, + }) + + if firstEntry { + allValues = allow + firstEntry = false + } else { + if allow != allValues { + return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k) + } + } + } + + return nameRules, nil +} + func (c *Config) Get(k string) interface{} { return c.get(k, c.Settings) } diff --git a/config_test.go b/config_test.go index 0e1036c..91485af 100644 --- a/config_test.go +++ b/config_test.go @@ -86,6 +86,76 @@ func TestConfig_GetBool(t *testing.T) { assert.Equal(t, false, c.GetBool("bool", true)) } +func TestConfig_GetAllowList(t *testing.T) { + c := NewConfig() + c.Settings["allowlist"] = map[interface{}]interface{}{ + "192.168.0.0": true, + } + r, err := c.GetAllowList("allowlist", false) + assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0") + assert.Nil(t, r) + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "192.168.0.0/16": "abc", + } + r, err = c.GetAllowList("allowlist", false) + assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "192.168.0.0/16": true, + "10.0.0.0/8": false, + } + r, err = c.GetAllowList("allowlist", false) + assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "0.0.0.0/0": true, + "10.0.0.0/8": false, + "10.42.42.0/24": true, + } + r, err = c.GetAllowList("allowlist", false) + if assert.NoError(t, err) { + assert.NotNil(t, r) + } + + // Test interface names + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "interfaces": map[interface{}]interface{}{ + `docker.*`: false, + }, + } + r, err = c.GetAllowList("allowlist", false) + assert.EqualError(t, err, "config `allowlist` does not support `interfaces`") + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "interfaces": map[interface{}]interface{}{ + `docker.*`: "foo", + }, + } + r, err = c.GetAllowList("allowlist", true) + assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "interfaces": map[interface{}]interface{}{ + `docker.*`: false, + `eth.*`: true, + }, + } + r, err = c.GetAllowList("allowlist", true) + assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") + + c.Settings["allowlist"] = map[interface{}]interface{}{ + "interfaces": map[interface{}]interface{}{ + `docker.*`: false, + }, + } + r, err = c.GetAllowList("allowlist", true) + if assert.NoError(t, err) { + assert.NotNil(t, r) + } +} + func TestConfig_HasChanged(t *testing.T) { // No reload has occurred, return false c := NewConfig() diff --git a/examples/config.yml b/examples/config.yml index 2a6aa0d..ef7714b 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -40,6 +40,37 @@ lighthouse: hosts: - "192.168.100.1" + # remoteAllowList allows you to control ip ranges that this node will + # consider when handshaking to another node. By default, any remote IPs are + # allowed. You can provide CIDRs here with `true` to allow and `false` to + # deny. The most specific CIDR rule applies to each remote. If all rules are + # "allow", the default will be "deny", and vice-versa. If both "allow" and + # "deny" rules are present, then you MUST set a rule for "0.0.0.0/0" as the + # default. + #remoteAllowList: + # Example to block IPs from this subnet from being used for remote IPs. + #"172.16.0.0/12": false + + # A more complicated example, allow public IPs but only private IPs from a specific subnet + #"0.0.0.0/0": true + #"10.0.0.0/8": false + #"10.42.42.0/24": true + + # localAllowList allows you to filter which local IP addresses we advertise + # to the lighthouses. This uses the same logic as `remoteAllowList`, but + # additionally, you can specify an `interfaces` map of regular expressions + # to match against interface names. The regexp must match the entire name. + # All interface rules must be either true or false (and the default will be + # the inverse). CIDR rules are matched after interface name rules. + # Default is all local IP addresses. + #localAllowList: + # Example to blacklist tun0 and all docker interfaces. + #interfaces: + #tun0: false + #'docker.*': false + # Example to only advertise this subnet to the lighthouse. + #"10.0.0.0/8": true + # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, # however using port 0 will dynamically assign a port and is recommended for roaming nodes. listen: diff --git a/handshake.go b/handshake.go index c9cf6e2..c6b6332 100644 --- a/handshake.go +++ b/handshake.go @@ -13,6 +13,11 @@ func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Head // return //} + if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) { + l.WithField("udpAddr", addr).Debug("lighthouse.remoteAllowList denied incoming handshake") + return + } + tearDown := false switch h.Subtype { case handshakeIXPSK0: diff --git a/hostmap.go b/hostmap.go index bb4e9fa..a93ffc5 100644 --- a/hostmap.go +++ b/hostmap.go @@ -755,11 +755,16 @@ func (d *HostInfoDest) ProbeReceived(probeCount int) { // Utility functions -func localIps() *[]net.IP { +func localIps(allowList *AllowList) *[]net.IP { //FIXME: This function is pretty garbage var ips []net.IP ifaces, _ := net.Interfaces() for _, i := range ifaces { + allow := allowList.AllowName(i.Name) + l.WithField("interfaceName", i.Name).WithField("allow", allow).Debug("localAllowList.AllowName") + if !allow { + continue + } addrs, _ := i.Addrs() for _, addr := range addrs { var ip net.IP @@ -771,6 +776,12 @@ func localIps() *[]net.IP { ip = v.IP } if ip.To4() != nil && ip.IsLoopback() == false { + allow := allowList.Allow(ip2int(ip)) + l.WithField("localIp", ip).WithField("allow", allow).Debug("localAllowList.Allow") + if !allow { + continue + } + ips = append(ips, ip) } } diff --git a/lighthouse.go b/lighthouse.go index ad49196..5ccb4a6 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -19,6 +19,16 @@ type LightHouse struct { // Local cache of answers from light houses addrMap map[uint32][]udpAddr + // filters remote addresses allowed for each host + // - When we are a lighthouse, this filters what addresses we store and + // respond with. + // - When we are not a lighthouse, this filters which addresses we accept + // from lighthouses. + remoteAllowList *AllowList + + // filters local addresses that we advertise to lighthouses + localAllowList *AllowList + // staticList exists to avoid having a bool in each addrMap entry // since static should be rare staticList map[uint32]struct{} @@ -55,6 +65,20 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n return &h } +func (lh *LightHouse) SetRemoteAllowList(allowList *AllowList) { + lh.Lock() + defer lh.Unlock() + + lh.remoteAllowList = allowList +} + +func (lh *LightHouse) SetLocalAllowList(allowList *AllowList) { + lh.Lock() + defer lh.Unlock() + + lh.localAllowList = allowList +} + func (lh *LightHouse) ValidateLHStaticEntries() error { for lhIP, _ := range lh.lighthouses { if _, ok := lh.staticList[lhIP]; !ok { @@ -135,6 +159,13 @@ func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) { return } } + + allow := lh.remoteAllowList.Allow(udp2ipInt(toIp)) + l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow") + if !allow { + return + } + //l.Debugf("Adding reply of %s as %s\n", IntIp(vpnIP), toIp) if static { lh.staticList[vpnIP] = struct{}{} @@ -203,7 +234,7 @@ func (lh *LightHouse) LhUpdateWorker(f EncWriter) { for { ipp := []*IpAndPort{} - for _, e := range *localIps() { + for _, e := range *localIps(lh.localAllowList) { // Only add IPs that aren't my VPN/tun IP if ip2int(e) != lh.myIp { ipp = append(ipp, &IpAndPort{Ip: ip2int(e), Port: uint32(lh.nebulaPort)}) diff --git a/main.go b/main.go index e34c51b..98b8425 100644 --- a/main.go +++ b/main.go @@ -228,6 +228,18 @@ func Main(configPath string, configTest bool, buildVersion string) { punchy.Delay, ) + remoteAllowList, err := config.GetAllowList("lighthouse.remoteAllowList", false) + if err != nil { + l.WithError(err).Fatal("Invalid lighthouse.remoteAllowList") + } + lightHouse.SetRemoteAllowList(remoteAllowList) + + localAllowList, err := config.GetAllowList("lighthouse.localAllowList", true) + if err != nil { + l.WithError(err).Fatal("Invalid lighthouse.localAllowList") + } + lightHouse.SetLocalAllowList(localAllowList) + //TODO: Move all of this inside functions in lighthouse.go for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) { vpnIp := net.ParseIP(fmt.Sprintf("%v", k)) diff --git a/outside.go b/outside.go index c36bb93..309e406 100644 --- a/outside.go +++ b/outside.go @@ -142,6 +142,10 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { if hostDidRoam(hostinfo.remote, addr) { + if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) { + hostinfo.logger().WithField("newAddr", addr).Debug("lighthouse.remoteAllowList denied roaming") + return + } if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSupressSeconds*time.Second { if l.Level >= logrus.DebugLevel { hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).