From ea2c186a7776dd5b0413b56bb0e05da60f5db73f Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Tue, 19 Oct 2021 10:54:30 -0400 Subject: [PATCH] remote_allow_ranges: allow inside CIDR specific remote_allow_lists (#540) This allows you to configure remote allow lists specific to different subnets of the inside CIDR. Example: remote_allow_ranges: 10.42.42.0/24: 192.168.0.0/16: true This would only allow hosts with a VPN IP in the 10.42.42.0/24 range to have private IPs (and thus don't connect over public IPs). The PR also refactors AllowList into RemoteAllowList and LocalAllowList to make it clearer which methods are allowed on which allow list. --- allow_list.go | 65 ++++++++++++++++++++++++++++- allow_list_test.go | 8 ++-- config.go | 99 ++++++++++++++++++++++++++++++++++++++------- config_test.go | 28 +++++-------- examples/config.yml | 8 ++++ handshake.go | 3 +- handshake_ix.go | 10 +++++ hostmap.go | 2 +- lighthouse.go | 28 ++++++------- lighthouse_test.go | 6 ++- main.go | 4 +- outside.go | 2 +- remote_list.go | 12 +++--- remote_list_test.go | 18 ++++++--- 14 files changed, 223 insertions(+), 70 deletions(-) diff --git a/allow_list.go b/allow_list.go index 1f782b8..97c13a0 100644 --- a/allow_list.go +++ b/allow_list.go @@ -9,6 +9,18 @@ import ( type AllowList struct { // The values of this cidrTree are `bool`, signifying allow/deny cidrTree *CIDR6Tree +} + +type RemoteAllowList struct { + AllowList *AllowList + + // Inside Range Specific, keys of this tree are inside CIDRs and values + // are *AllowList + insideAllowLists *CIDR6Tree +} + +type LocalAllowList struct { + AllowList *AllowList // To avoid ambiguity, all rules must be true, or all rules must be false. nameRules []AllowListNameRule @@ -61,7 +73,14 @@ func (al *AllowList) AllowIpV6(hi, lo uint64) bool { } } -func (al *AllowList) AllowName(name string) bool { +func (al *LocalAllowList) Allow(ip net.IP) bool { + if al == nil { + return true + } + return al.AllowList.Allow(ip) +} + +func (al *LocalAllowList) AllowName(name string) bool { if al == nil || len(al.nameRules) == 0 { return true } @@ -75,3 +94,47 @@ func (al *AllowList) AllowName(name string) bool { // If no rules match, return the default, which is the inverse of the rules return !al.nameRules[0].Allow } + +func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool { + if al == nil { + return true + } + return al.AllowList.Allow(ip) +} + +func (al *RemoteAllowList) Allow(vpnIp uint32, ip net.IP) bool { + if !al.getInsideAllowList(vpnIp).Allow(ip) { + return false + } + return al.AllowList.Allow(ip) +} + +func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool { + if al == nil { + return true + } + if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) { + return false + } + return al.AllowList.AllowIpV4(ip) +} + +func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool { + if al == nil { + return true + } + if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) { + return false + } + return al.AllowList.AllowIpV6(hi, lo) +} + +func (al *RemoteAllowList) getInsideAllowList(vpnIp uint32) *AllowList { + if al.insideAllowLists != nil { + inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) + if inside != nil { + return inside.(*AllowList) + } + } + return nil +} diff --git a/allow_list_test.go b/allow_list_test.go index 9c03eb0..2dcc3a1 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -31,14 +31,14 @@ func TestAllowList_Allow(t *testing.T) { assert.Equal(t, false, al.Allow(net.ParseIP("::2"))) } -func TestAllowList_AllowName(t *testing.T) { - assert.Equal(t, true, ((*AllowList)(nil)).AllowName("docker0")) +func TestLocalAllowList_AllowName(t *testing.T) { + assert.Equal(t, true, ((*LocalAllowList)(nil)).AllowName("docker0")) rules := []AllowListNameRule{ {Name: regexp.MustCompile("^docker.*$"), Allow: false}, {Name: regexp.MustCompile("^tun.*$"), Allow: false}, } - al := &AllowList{nameRules: rules} + al := &LocalAllowList{nameRules: rules} assert.Equal(t, false, al.AllowName("docker0")) assert.Equal(t, false, al.AllowName("tun0")) @@ -48,7 +48,7 @@ func TestAllowList_AllowName(t *testing.T) { {Name: regexp.MustCompile("^eth.*$"), Allow: true}, {Name: regexp.MustCompile("^ens.*$"), Allow: true}, } - al = &AllowList{nameRules: rules} + al = &LocalAllowList{nameRules: rules} assert.Equal(t, false, al.AllowName("docker0")) assert.Equal(t, true, al.AllowName("eth0")) diff --git a/config.go b/config.go index a11b89a..152fd64 100644 --- a/config.go +++ b/config.go @@ -226,19 +226,94 @@ func (c *Config) GetDuration(k string, d time.Duration) time.Duration { return v } -func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error) { +func (c *Config) GetLocalAllowList(k string) (*LocalAllowList, error) { + var nameRules []AllowListNameRule + handleKey := func(key string, value interface{}) (bool, error) { + if key == "interfaces" { + var err error + nameRules, err = c.getAllowListInterfaces(k, value) + if err != nil { + return false, err + } + + return true, nil + } + return false, nil + } + + al, err := c.GetAllowList(k, handleKey) + if err != nil { + return nil, err + } + return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil +} + +func (c *Config) GetRemoteAllowList(k, rangesKey string) (*RemoteAllowList, error) { + al, err := c.GetAllowList(k, nil) + if err != nil { + return nil, err + } + remoteAllowRanges, err := c.getRemoteAllowRanges(rangesKey) + if err != nil { + return nil, err + } + return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil +} + +func (c *Config) getRemoteAllowRanges(k string) (*CIDR6Tree, error) { + value := c.Get(k) + if value == nil { + return nil, nil + } + + remoteAllowRanges := NewCIDR6Tree() + + rawMap, ok := value.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) + } + for rawKey, rawValue := range rawMap { + rawCIDR, ok := rawKey.(string) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) + } + + allowList, err := c.getAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil) + if err != nil { + return nil, err + } + + _, cidr, err := net.ParseCIDR(rawCIDR) + if err != nil { + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + } + + remoteAllowRanges.AddCIDR(cidr, allowList) + } + + return remoteAllowRanges, nil +} + +// If the handleKey func returns true, the rest of the parsing is skipped +// for this key. This allows parsing of special values like `interfaces`. +func (c *Config) GetAllowList(k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { r := c.Get(k) if r == nil { return nil, nil } - rawMap, ok := r.(map[interface{}]interface{}) + return c.getAllowList(k, r, handleKey) +} + +// If the handleKey func returns true, the rest of the parsing is skipped +// for this key. This allows parsing of special values like `interfaces`. +func (c *Config) getAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { + rawMap, ok := raw.(map[interface{}]interface{}) if !ok { - return nil, fmt.Errorf("config `%s` has invalid type: %T", k, r) + return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) } tree := NewCIDR6Tree() - var nameRules []AllowListNameRule // Keep track of the rules we have added for both ipv4 and ipv6 type allowListRules struct { @@ -256,18 +331,14 @@ func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) } - // Special rule for interface names - if rawCIDR == "interfaces" { - if !allowInterfaces { - return nil, fmt.Errorf("config `%s` does not support `interfaces`", k) - } - var err error - nameRules, err = c.getAllowListInterfaces(k, rawValue) + if handleKey != nil { + handled, err := handleKey(rawCIDR, rawValue) if err != nil { return nil, err } - - continue + if handled { + continue + } } value, ok := rawValue.(bool) @@ -325,7 +396,7 @@ func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error } } - return &AllowList{cidrTree: tree, nameRules: nameRules}, nil + return &AllowList{cidrTree: tree}, nil } func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) { diff --git a/config_test.go b/config_test.go index 5a1aea4..84848b8 100644 --- a/config_test.go +++ b/config_test.go @@ -97,21 +97,21 @@ func TestConfig_GetAllowList(t *testing.T) { c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0": true, } - r, err := c.GetAllowList("allowlist", false) + r, err := c.GetAllowList("allowlist", nil) assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0") assert.Nil(t, r) c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0/16": "abc", } - r, err = c.GetAllowList("allowlist", false) + r, err = c.GetAllowList("allowlist", nil) assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0/16": true, "10.0.0.0/8": false, } - r, err = c.GetAllowList("allowlist", false) + r, err = c.GetAllowList("allowlist", nil) assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") c.Settings["allowlist"] = map[interface{}]interface{}{ @@ -121,7 +121,7 @@ func TestConfig_GetAllowList(t *testing.T) { "fd00::/8": true, "fd00:fd00::/16": false, } - r, err = c.GetAllowList("allowlist", false) + r, err = c.GetAllowList("allowlist", nil) assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") c.Settings["allowlist"] = map[interface{}]interface{}{ @@ -129,7 +129,7 @@ func TestConfig_GetAllowList(t *testing.T) { "10.0.0.0/8": false, "10.42.42.0/24": true, } - r, err = c.GetAllowList("allowlist", false) + r, err = c.GetAllowList("allowlist", nil) if assert.NoError(t, err) { assert.NotNil(t, r) } @@ -142,27 +142,19 @@ func TestConfig_GetAllowList(t *testing.T) { "fd00::/8": true, "fd00:fd00::/16": false, } - r, err = c.GetAllowList("allowlist", false) + r, err = c.GetAllowList("allowlist", nil) if assert.NoError(t, err) { assert.NotNil(t, r) } // Test interface names - c.Settings["allowlist"] = map[interface{}]interface{}{ - "interfaces": map[interface{}]interface{}{ - `docker.*`: false, - }, - } - r, err = c.GetAllowList("allowlist", false) - assert.EqualError(t, err, "config `allowlist` does not support `interfaces`") - c.Settings["allowlist"] = map[interface{}]interface{}{ "interfaces": map[interface{}]interface{}{ `docker.*`: "foo", }, } - r, err = c.GetAllowList("allowlist", true) + lr, err := c.GetLocalAllowList("allowlist") assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") c.Settings["allowlist"] = map[interface{}]interface{}{ @@ -171,7 +163,7 @@ func TestConfig_GetAllowList(t *testing.T) { `eth.*`: true, }, } - r, err = c.GetAllowList("allowlist", true) + lr, err = c.GetLocalAllowList("allowlist") assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") c.Settings["allowlist"] = map[interface{}]interface{}{ @@ -179,9 +171,9 @@ func TestConfig_GetAllowList(t *testing.T) { `docker.*`: false, }, } - r, err = c.GetAllowList("allowlist", true) + lr, err = c.GetLocalAllowList("allowlist") if assert.NoError(t, err) { - assert.NotNil(t, r) + assert.NotNil(t, lr) } } diff --git a/examples/config.yml b/examples/config.yml index 08dfde7..baa4a1c 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -56,6 +56,14 @@ lighthouse: #"10.0.0.0/8": false #"10.42.42.0/24": true + # EXPERIMENTAL: This option my change or disappear in the future. + # Optionally allows the definition of remote_allow_list blocks + # specific to an inside VPN IP CIDR. + #remote_allow_ranges: + # This rule would only allow only private IPs for this VPN range + #"10.42.42.0/24": + #"192.168.0.0/16": true + # local_allow_list allows you to filter which local IP addresses we advertise # to the lighthouses. This uses the same logic as `remote_allow_list`, but # additionally, you can specify an `interfaces` map of regular expressions diff --git a/handshake.go b/handshake.go index a703ff8..8d8aef0 100644 --- a/handshake.go +++ b/handshake.go @@ -6,7 +6,8 @@ const ( ) func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) { - if !f.lightHouse.remoteAllowList.Allow(addr.IP) { + // First remote allow list check before we know the vpnIp + if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) { f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } diff --git a/handshake_ix.go b/handshake_ix.go index 7fcdfb7..46fd1ec 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -111,6 +111,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { return } + if !f.lightHouse.remoteAllowList.Allow(vpnIP, addr.IP) { + f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + return + } + myIndex, err := generateIndex(f.l) if err != nil { f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). @@ -313,6 +318,11 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ hostinfo.Lock() defer hostinfo.Unlock() + if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) { + f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + return false + } + ci := hostinfo.ConnectionState if ci.ready { f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). diff --git a/hostmap.go b/hostmap.go index 25982c0..c2b520e 100644 --- a/hostmap.go +++ b/hostmap.go @@ -622,7 +622,7 @@ func (i *HostInfo) Probes() []*Probe { // Utility functions -func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP { +func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { //FIXME: This function is pretty garbage var ips []net.IP ifaces, _ := net.Interfaces() diff --git a/lighthouse.go b/lighthouse.go index 7cbf411..56e2851 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -35,10 +35,10 @@ type LightHouse struct { // respond with. // - When we are not a lighthouse, this filters which addresses we accept // from lighthouses. - remoteAllowList *AllowList + remoteAllowList *RemoteAllowList // filters local addresses that we advertise to lighthouses - localAllowList *AllowList + localAllowList *LocalAllowList // used to trigger the HandshakeManager when we receive HostQueryReply handshakeTrigger chan<- uint32 @@ -93,14 +93,14 @@ func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, i return &h } -func (lh *LightHouse) SetRemoteAllowList(allowList *AllowList) { +func (lh *LightHouse) SetRemoteAllowList(allowList *RemoteAllowList) { lh.Lock() defer lh.Unlock() lh.remoteAllowList = allowList } -func (lh *LightHouse) SetLocalAllowList(allowList *AllowList) { +func (lh *LightHouse) SetLocalAllowList(allowList *LocalAllowList) { lh.Lock() defer lh.Unlock() @@ -223,14 +223,14 @@ func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) { if ipv4 := toAddr.IP.To4(); ipv4 != nil { to := NewIp4AndPort(ipv4, uint32(toAddr.Port)) - if !lh.unlockedShouldAddV4(to) { + if !lh.unlockedShouldAddV4(vpnIp, to) { return } am.unlockedPrependV4(lh.myVpnIp, to) } else { to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)) - if !lh.unlockedShouldAddV6(to) { + if !lh.unlockedShouldAddV6(vpnIp, to) { return } am.unlockedPrependV6(lh.myVpnIp, to) @@ -251,8 +251,8 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList { } // unlockedShouldAddV4 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV4(to *Ip4AndPort) bool { - allow := lh.remoteAllowList.AllowIpV4(to.Ip) +func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool { + allow := lh.remoteAllowList.AllowIpV4(vpnIp, to.Ip) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow") } @@ -265,8 +265,8 @@ func (lh *LightHouse) unlockedShouldAddV4(to *Ip4AndPort) bool { } // unlockedShouldAddV6 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV6(to *Ip6AndPort) bool { - allow := lh.remoteAllowList.AllowIpV6(to.Hi, to.Lo) +func (lh *LightHouse) unlockedShouldAddV6(vpnIp uint32, to *Ip6AndPort) bool { + allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") } @@ -549,8 +549,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) am.Unlock() // Non-blocking attempt to trigger, skip if it would block @@ -581,8 +581,8 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) am.Unlock() } diff --git a/lighthouse_test.go b/lighthouse_test.go index da4e22d..fcd9cc2 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -75,24 +75,26 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { hAddr2 := NewUDPAddrFromString("4.5.6.7:12346") lh.addrMap[3] = NewRemoteList() lh.addrMap[3].unlockedSetV4( + 3, 3, []*Ip4AndPort{ NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)), }, - func(*Ip4AndPort) bool { return true }, + func(uint32, *Ip4AndPort) bool { return true }, ) rAddr := NewUDPAddrFromString("1.2.2.3:12345") rAddr2 := NewUDPAddrFromString("1.2.2.3:12346") lh.addrMap[2] = NewRemoteList() lh.addrMap[2].unlockedSetV4( + 3, 3, []*Ip4AndPort{ NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)), }, - func(*Ip4AndPort) bool { return true }, + func(uint32, *Ip4AndPort) bool { return true }, ) mw := &mockEncWriter{} diff --git a/main.go b/main.go index a775599..67d4b51 100644 --- a/main.go +++ b/main.go @@ -278,13 +278,13 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L config.GetBool("stats.lighthouse_metrics", false), ) - remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false) + remoteAllowList, err := config.GetRemoteAllowList("lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") if err != nil { return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) } lightHouse.SetRemoteAllowList(remoteAllowList) - localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true) + localAllowList, err := config.GetLocalAllowList("lighthouse.local_allow_list") if err != nil { return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err) } diff --git a/outside.go b/outside.go index aad085d..57a00dc 100644 --- a/outside.go +++ b/outside.go @@ -153,7 +153,7 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { if hostDidRoam(hostinfo.remote, addr) { - if !f.lightHouse.remoteAllowList.Allow(addr.IP) { + if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) { hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") return } diff --git a/remote_list.go b/remote_list.go index e88a8b0..7c3e716 100644 --- a/remote_list.go +++ b/remote_list.go @@ -11,8 +11,8 @@ import ( type forEachFunc func(addr *udpAddr, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) -type checkFuncV4 func(to *Ip4AndPort) bool -type checkFuncV6 func(to *Ip6AndPort) bool +type checkFuncV4 func(vpnIp uint32, to *Ip4AndPort) bool +type checkFuncV6 func(vpnIp uint32, to *Ip6AndPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp @@ -246,7 +246,7 @@ func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) { // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, to []*Ip4AndPort, check checkFuncV4) { +func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4AndPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -255,7 +255,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, to []*Ip4AndPort, check ch // We can't take their array but we can take their pointers for _, v := range to[:minInt(len(to), MaxRemotes)] { - if check(v) { + if check(vpnIp, v) { c.reported = append(c.reported, v) } } @@ -283,7 +283,7 @@ func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) { // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, to []*Ip6AndPort, check checkFuncV6) { +func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6AndPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -292,7 +292,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, to []*Ip6AndPort, check ch // We can't take their array but we can take their pointers for _, v := range to[:minInt(len(to), MaxRemotes)] { - if check(v) { + if check(vpnIp, v) { c.reported = append(c.reported, v) } } diff --git a/remote_list_test.go b/remote_list_test.go index bceb16c..05d3887 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -10,6 +10,7 @@ import ( func TestRemoteList_Rebuild(t *testing.T) { rl := NewRemoteList() rl.unlockedSetV4( + 0, 0, []*Ip4AndPort{ {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped @@ -23,10 +24,11 @@ func TestRemoteList_Rebuild(t *testing.T) { {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe }, - func(*Ip4AndPort) bool { return true }, + func(uint32, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( + 1, 1, []*Ip6AndPort{ NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped @@ -35,7 +37,7 @@ func TestRemoteList_Rebuild(t *testing.T) { NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe }, - func(*Ip6AndPort) bool { return true }, + func(uint32, *Ip6AndPort) bool { return true }, ) rl.Rebuild([]*net.IPNet{}) @@ -101,6 +103,7 @@ func TestRemoteList_Rebuild(t *testing.T) { func BenchmarkFullRebuild(b *testing.B) { rl := NewRemoteList() rl.unlockedSetV4( + 0, 0, []*Ip4AndPort{ {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, @@ -112,10 +115,11 @@ func BenchmarkFullRebuild(b *testing.B) { {Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port }, - func(*Ip4AndPort) bool { return true }, + func(uint32, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( + 0, 0, []*Ip6AndPort{ NewIp6AndPort(net.ParseIP("1::1"), 1), @@ -123,7 +127,7 @@ func BenchmarkFullRebuild(b *testing.B) { NewIp6AndPort(net.ParseIP("1:100::1"), 1), NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe }, - func(*Ip6AndPort) bool { return true }, + func(uint32, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { @@ -164,6 +168,7 @@ func BenchmarkFullRebuild(b *testing.B) { func BenchmarkSortRebuild(b *testing.B) { rl := NewRemoteList() rl.unlockedSetV4( + 0, 0, []*Ip4AndPort{ {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, @@ -175,10 +180,11 @@ func BenchmarkSortRebuild(b *testing.B) { {Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port }, - func(*Ip4AndPort) bool { return true }, + func(uint32, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( + 0, 0, []*Ip6AndPort{ NewIp6AndPort(net.ParseIP("1::1"), 1), @@ -186,7 +192,7 @@ func BenchmarkSortRebuild(b *testing.B) { NewIp6AndPort(net.ParseIP("1:100::1"), 1), NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe }, - func(*Ip6AndPort) bool { return true }, + func(uint32, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) {