diff --git a/allow_list.go b/allow_list.go index 1f782b8..97c13a0 100644 --- a/allow_list.go +++ b/allow_list.go @@ -9,6 +9,18 @@ import ( type AllowList struct { // The values of this cidrTree are `bool`, signifying allow/deny cidrTree *CIDR6Tree +} + +type RemoteAllowList struct { + AllowList *AllowList + + // Inside Range Specific, keys of this tree are inside CIDRs and values + // are *AllowList + insideAllowLists *CIDR6Tree +} + +type LocalAllowList struct { + AllowList *AllowList // To avoid ambiguity, all rules must be true, or all rules must be false. nameRules []AllowListNameRule @@ -61,7 +73,14 @@ func (al *AllowList) AllowIpV6(hi, lo uint64) bool { } } -func (al *AllowList) AllowName(name string) bool { +func (al *LocalAllowList) Allow(ip net.IP) bool { + if al == nil { + return true + } + return al.AllowList.Allow(ip) +} + +func (al *LocalAllowList) AllowName(name string) bool { if al == nil || len(al.nameRules) == 0 { return true } @@ -75,3 +94,47 @@ func (al *AllowList) AllowName(name string) bool { // If no rules match, return the default, which is the inverse of the rules return !al.nameRules[0].Allow } + +func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool { + if al == nil { + return true + } + return al.AllowList.Allow(ip) +} + +func (al *RemoteAllowList) Allow(vpnIp uint32, ip net.IP) bool { + if !al.getInsideAllowList(vpnIp).Allow(ip) { + return false + } + return al.AllowList.Allow(ip) +} + +func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool { + if al == nil { + return true + } + if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) { + return false + } + return al.AllowList.AllowIpV4(ip) +} + +func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool { + if al == nil { + return true + } + if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) { + return false + } + return al.AllowList.AllowIpV6(hi, lo) +} + +func (al *RemoteAllowList) getInsideAllowList(vpnIp uint32) *AllowList { + if al.insideAllowLists != nil { + inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) + if inside != nil { + return inside.(*AllowList) + } + } + return nil +} diff --git a/allow_list_test.go b/allow_list_test.go index 9c03eb0..2dcc3a1 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -31,14 +31,14 @@ func TestAllowList_Allow(t *testing.T) { assert.Equal(t, false, al.Allow(net.ParseIP("::2"))) } -func TestAllowList_AllowName(t *testing.T) { - assert.Equal(t, true, ((*AllowList)(nil)).AllowName("docker0")) +func TestLocalAllowList_AllowName(t *testing.T) { + assert.Equal(t, true, ((*LocalAllowList)(nil)).AllowName("docker0")) rules := []AllowListNameRule{ {Name: regexp.MustCompile("^docker.*$"), Allow: false}, {Name: regexp.MustCompile("^tun.*$"), Allow: false}, } - al := &AllowList{nameRules: rules} + al := &LocalAllowList{nameRules: rules} assert.Equal(t, false, al.AllowName("docker0")) assert.Equal(t, false, al.AllowName("tun0")) @@ -48,7 +48,7 @@ func TestAllowList_AllowName(t *testing.T) { {Name: regexp.MustCompile("^eth.*$"), Allow: true}, {Name: regexp.MustCompile("^ens.*$"), Allow: true}, } - al = &AllowList{nameRules: rules} + al = &LocalAllowList{nameRules: rules} assert.Equal(t, false, al.AllowName("docker0")) assert.Equal(t, true, al.AllowName("eth0")) diff --git a/config.go b/config.go index a11b89a..152fd64 100644 --- a/config.go +++ b/config.go @@ -226,19 +226,94 @@ func (c *Config) GetDuration(k string, d time.Duration) time.Duration { return v } -func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error) { +func (c *Config) GetLocalAllowList(k string) (*LocalAllowList, error) { + var nameRules []AllowListNameRule + handleKey := func(key string, value interface{}) (bool, error) { + if key == "interfaces" { + var err error + nameRules, err = c.getAllowListInterfaces(k, value) + if err != nil { + return false, err + } + + return true, nil + } + return false, nil + } + + al, err := c.GetAllowList(k, handleKey) + if err != nil { + return nil, err + } + return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil +} + +func (c *Config) GetRemoteAllowList(k, rangesKey string) (*RemoteAllowList, error) { + al, err := c.GetAllowList(k, nil) + if err != nil { + return nil, err + } + remoteAllowRanges, err := c.getRemoteAllowRanges(rangesKey) + if err != nil { + return nil, err + } + return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil +} + +func (c *Config) getRemoteAllowRanges(k string) (*CIDR6Tree, error) { + value := c.Get(k) + if value == nil { + return nil, nil + } + + remoteAllowRanges := NewCIDR6Tree() + + rawMap, ok := value.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) + } + for rawKey, rawValue := range rawMap { + rawCIDR, ok := rawKey.(string) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) + } + + allowList, err := c.getAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil) + if err != nil { + return nil, err + } + + _, cidr, err := net.ParseCIDR(rawCIDR) + if err != nil { + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + } + + remoteAllowRanges.AddCIDR(cidr, allowList) + } + + return remoteAllowRanges, nil +} + +// If the handleKey func returns true, the rest of the parsing is skipped +// for this key. This allows parsing of special values like `interfaces`. +func (c *Config) GetAllowList(k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { r := c.Get(k) if r == nil { return nil, nil } - rawMap, ok := r.(map[interface{}]interface{}) + return c.getAllowList(k, r, handleKey) +} + +// If the handleKey func returns true, the rest of the parsing is skipped +// for this key. This allows parsing of special values like `interfaces`. +func (c *Config) getAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { + rawMap, ok := raw.(map[interface{}]interface{}) if !ok { - return nil, fmt.Errorf("config `%s` has invalid type: %T", k, r) + return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) } tree := NewCIDR6Tree() - var nameRules []AllowListNameRule // Keep track of the rules we have added for both ipv4 and ipv6 type allowListRules struct { @@ -256,18 +331,14 @@ func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) } - // Special rule for interface names - if rawCIDR == "interfaces" { - if !allowInterfaces { - return nil, fmt.Errorf("config `%s` does not support `interfaces`", k) - } - var err error - nameRules, err = c.getAllowListInterfaces(k, rawValue) + if handleKey != nil { + handled, err := handleKey(rawCIDR, rawValue) if err != nil { return nil, err } - - continue + if handled { + continue + } } value, ok := rawValue.(bool) @@ -325,7 +396,7 @@ func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error } } - return &AllowList{cidrTree: tree, nameRules: nameRules}, nil + return &AllowList{cidrTree: tree}, nil } func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) { diff --git a/config_test.go b/config_test.go index 5a1aea4..84848b8 100644 --- a/config_test.go +++ b/config_test.go @@ -97,21 +97,21 @@ func TestConfig_GetAllowList(t *testing.T) { c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0": true, } - r, err := c.GetAllowList("allowlist", false) + r, err := c.GetAllowList("allowlist", nil) assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0") assert.Nil(t, r) c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0/16": "abc", } - r, err = c.GetAllowList("allowlist", false) + r, err = c.GetAllowList("allowlist", nil) assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0/16": true, "10.0.0.0/8": false, } - r, err = c.GetAllowList("allowlist", false) + r, err = c.GetAllowList("allowlist", nil) assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") c.Settings["allowlist"] = map[interface{}]interface{}{ @@ -121,7 +121,7 @@ func TestConfig_GetAllowList(t *testing.T) { "fd00::/8": true, "fd00:fd00::/16": false, } - r, err = c.GetAllowList("allowlist", false) + r, err = c.GetAllowList("allowlist", nil) assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") c.Settings["allowlist"] = map[interface{}]interface{}{ @@ -129,7 +129,7 @@ func TestConfig_GetAllowList(t *testing.T) { "10.0.0.0/8": false, "10.42.42.0/24": true, } - r, err = c.GetAllowList("allowlist", false) + r, err = c.GetAllowList("allowlist", nil) if assert.NoError(t, err) { assert.NotNil(t, r) } @@ -142,27 +142,19 @@ func TestConfig_GetAllowList(t *testing.T) { "fd00::/8": true, "fd00:fd00::/16": false, } - r, err = c.GetAllowList("allowlist", false) + r, err = c.GetAllowList("allowlist", nil) if assert.NoError(t, err) { assert.NotNil(t, r) } // Test interface names - c.Settings["allowlist"] = map[interface{}]interface{}{ - "interfaces": map[interface{}]interface{}{ - `docker.*`: false, - }, - } - r, err = c.GetAllowList("allowlist", false) - assert.EqualError(t, err, "config `allowlist` does not support `interfaces`") - c.Settings["allowlist"] = map[interface{}]interface{}{ "interfaces": map[interface{}]interface{}{ `docker.*`: "foo", }, } - r, err = c.GetAllowList("allowlist", true) + lr, err := c.GetLocalAllowList("allowlist") assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") c.Settings["allowlist"] = map[interface{}]interface{}{ @@ -171,7 +163,7 @@ func TestConfig_GetAllowList(t *testing.T) { `eth.*`: true, }, } - r, err = c.GetAllowList("allowlist", true) + lr, err = c.GetLocalAllowList("allowlist") assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") c.Settings["allowlist"] = map[interface{}]interface{}{ @@ -179,9 +171,9 @@ func TestConfig_GetAllowList(t *testing.T) { `docker.*`: false, }, } - r, err = c.GetAllowList("allowlist", true) + lr, err = c.GetLocalAllowList("allowlist") if assert.NoError(t, err) { - assert.NotNil(t, r) + assert.NotNil(t, lr) } } diff --git a/examples/config.yml b/examples/config.yml index 08dfde7..baa4a1c 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -56,6 +56,14 @@ lighthouse: #"10.0.0.0/8": false #"10.42.42.0/24": true + # EXPERIMENTAL: This option my change or disappear in the future. + # Optionally allows the definition of remote_allow_list blocks + # specific to an inside VPN IP CIDR. + #remote_allow_ranges: + # This rule would only allow only private IPs for this VPN range + #"10.42.42.0/24": + #"192.168.0.0/16": true + # local_allow_list allows you to filter which local IP addresses we advertise # to the lighthouses. This uses the same logic as `remote_allow_list`, but # additionally, you can specify an `interfaces` map of regular expressions diff --git a/handshake.go b/handshake.go index a703ff8..8d8aef0 100644 --- a/handshake.go +++ b/handshake.go @@ -6,7 +6,8 @@ const ( ) func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) { - if !f.lightHouse.remoteAllowList.Allow(addr.IP) { + // First remote allow list check before we know the vpnIp + if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) { f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } diff --git a/handshake_ix.go b/handshake_ix.go index 7fcdfb7..46fd1ec 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -111,6 +111,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { return } + if !f.lightHouse.remoteAllowList.Allow(vpnIP, addr.IP) { + f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + return + } + myIndex, err := generateIndex(f.l) if err != nil { f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). @@ -313,6 +318,11 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ hostinfo.Lock() defer hostinfo.Unlock() + if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) { + f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + return false + } + ci := hostinfo.ConnectionState if ci.ready { f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). diff --git a/hostmap.go b/hostmap.go index 25982c0..c2b520e 100644 --- a/hostmap.go +++ b/hostmap.go @@ -622,7 +622,7 @@ func (i *HostInfo) Probes() []*Probe { // Utility functions -func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP { +func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { //FIXME: This function is pretty garbage var ips []net.IP ifaces, _ := net.Interfaces() diff --git a/lighthouse.go b/lighthouse.go index 7cbf411..56e2851 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -35,10 +35,10 @@ type LightHouse struct { // respond with. // - When we are not a lighthouse, this filters which addresses we accept // from lighthouses. - remoteAllowList *AllowList + remoteAllowList *RemoteAllowList // filters local addresses that we advertise to lighthouses - localAllowList *AllowList + localAllowList *LocalAllowList // used to trigger the HandshakeManager when we receive HostQueryReply handshakeTrigger chan<- uint32 @@ -93,14 +93,14 @@ func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, i return &h } -func (lh *LightHouse) SetRemoteAllowList(allowList *AllowList) { +func (lh *LightHouse) SetRemoteAllowList(allowList *RemoteAllowList) { lh.Lock() defer lh.Unlock() lh.remoteAllowList = allowList } -func (lh *LightHouse) SetLocalAllowList(allowList *AllowList) { +func (lh *LightHouse) SetLocalAllowList(allowList *LocalAllowList) { lh.Lock() defer lh.Unlock() @@ -223,14 +223,14 @@ func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) { if ipv4 := toAddr.IP.To4(); ipv4 != nil { to := NewIp4AndPort(ipv4, uint32(toAddr.Port)) - if !lh.unlockedShouldAddV4(to) { + if !lh.unlockedShouldAddV4(vpnIp, to) { return } am.unlockedPrependV4(lh.myVpnIp, to) } else { to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)) - if !lh.unlockedShouldAddV6(to) { + if !lh.unlockedShouldAddV6(vpnIp, to) { return } am.unlockedPrependV6(lh.myVpnIp, to) @@ -251,8 +251,8 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList { } // unlockedShouldAddV4 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV4(to *Ip4AndPort) bool { - allow := lh.remoteAllowList.AllowIpV4(to.Ip) +func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool { + allow := lh.remoteAllowList.AllowIpV4(vpnIp, to.Ip) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow") } @@ -265,8 +265,8 @@ func (lh *LightHouse) unlockedShouldAddV4(to *Ip4AndPort) bool { } // unlockedShouldAddV6 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV6(to *Ip6AndPort) bool { - allow := lh.remoteAllowList.AllowIpV6(to.Hi, to.Lo) +func (lh *LightHouse) unlockedShouldAddV6(vpnIp uint32, to *Ip6AndPort) bool { + allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") } @@ -549,8 +549,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) am.Unlock() // Non-blocking attempt to trigger, skip if it would block @@ -581,8 +581,8 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) am.Unlock() } diff --git a/lighthouse_test.go b/lighthouse_test.go index da4e22d..fcd9cc2 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -75,24 +75,26 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { hAddr2 := NewUDPAddrFromString("4.5.6.7:12346") lh.addrMap[3] = NewRemoteList() lh.addrMap[3].unlockedSetV4( + 3, 3, []*Ip4AndPort{ NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)), }, - func(*Ip4AndPort) bool { return true }, + func(uint32, *Ip4AndPort) bool { return true }, ) rAddr := NewUDPAddrFromString("1.2.2.3:12345") rAddr2 := NewUDPAddrFromString("1.2.2.3:12346") lh.addrMap[2] = NewRemoteList() lh.addrMap[2].unlockedSetV4( + 3, 3, []*Ip4AndPort{ NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)), }, - func(*Ip4AndPort) bool { return true }, + func(uint32, *Ip4AndPort) bool { return true }, ) mw := &mockEncWriter{} diff --git a/main.go b/main.go index a775599..67d4b51 100644 --- a/main.go +++ b/main.go @@ -278,13 +278,13 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L config.GetBool("stats.lighthouse_metrics", false), ) - remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false) + remoteAllowList, err := config.GetRemoteAllowList("lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") if err != nil { return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) } lightHouse.SetRemoteAllowList(remoteAllowList) - localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true) + localAllowList, err := config.GetLocalAllowList("lighthouse.local_allow_list") if err != nil { return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err) } diff --git a/outside.go b/outside.go index aad085d..57a00dc 100644 --- a/outside.go +++ b/outside.go @@ -153,7 +153,7 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { if hostDidRoam(hostinfo.remote, addr) { - if !f.lightHouse.remoteAllowList.Allow(addr.IP) { + if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) { hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") return } diff --git a/remote_list.go b/remote_list.go index e88a8b0..7c3e716 100644 --- a/remote_list.go +++ b/remote_list.go @@ -11,8 +11,8 @@ import ( type forEachFunc func(addr *udpAddr, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) -type checkFuncV4 func(to *Ip4AndPort) bool -type checkFuncV6 func(to *Ip6AndPort) bool +type checkFuncV4 func(vpnIp uint32, to *Ip4AndPort) bool +type checkFuncV6 func(vpnIp uint32, to *Ip6AndPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp @@ -246,7 +246,7 @@ func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) { // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, to []*Ip4AndPort, check checkFuncV4) { +func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4AndPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -255,7 +255,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, to []*Ip4AndPort, check ch // We can't take their array but we can take their pointers for _, v := range to[:minInt(len(to), MaxRemotes)] { - if check(v) { + if check(vpnIp, v) { c.reported = append(c.reported, v) } } @@ -283,7 +283,7 @@ func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) { // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, to []*Ip6AndPort, check checkFuncV6) { +func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6AndPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -292,7 +292,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, to []*Ip6AndPort, check ch // We can't take their array but we can take their pointers for _, v := range to[:minInt(len(to), MaxRemotes)] { - if check(v) { + if check(vpnIp, v) { c.reported = append(c.reported, v) } } diff --git a/remote_list_test.go b/remote_list_test.go index bceb16c..05d3887 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -10,6 +10,7 @@ import ( func TestRemoteList_Rebuild(t *testing.T) { rl := NewRemoteList() rl.unlockedSetV4( + 0, 0, []*Ip4AndPort{ {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped @@ -23,10 +24,11 @@ func TestRemoteList_Rebuild(t *testing.T) { {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe }, - func(*Ip4AndPort) bool { return true }, + func(uint32, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( + 1, 1, []*Ip6AndPort{ NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped @@ -35,7 +37,7 @@ func TestRemoteList_Rebuild(t *testing.T) { NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe }, - func(*Ip6AndPort) bool { return true }, + func(uint32, *Ip6AndPort) bool { return true }, ) rl.Rebuild([]*net.IPNet{}) @@ -101,6 +103,7 @@ func TestRemoteList_Rebuild(t *testing.T) { func BenchmarkFullRebuild(b *testing.B) { rl := NewRemoteList() rl.unlockedSetV4( + 0, 0, []*Ip4AndPort{ {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, @@ -112,10 +115,11 @@ func BenchmarkFullRebuild(b *testing.B) { {Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port }, - func(*Ip4AndPort) bool { return true }, + func(uint32, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( + 0, 0, []*Ip6AndPort{ NewIp6AndPort(net.ParseIP("1::1"), 1), @@ -123,7 +127,7 @@ func BenchmarkFullRebuild(b *testing.B) { NewIp6AndPort(net.ParseIP("1:100::1"), 1), NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe }, - func(*Ip6AndPort) bool { return true }, + func(uint32, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { @@ -164,6 +168,7 @@ func BenchmarkFullRebuild(b *testing.B) { func BenchmarkSortRebuild(b *testing.B) { rl := NewRemoteList() rl.unlockedSetV4( + 0, 0, []*Ip4AndPort{ {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, @@ -175,10 +180,11 @@ func BenchmarkSortRebuild(b *testing.B) { {Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe {Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port }, - func(*Ip4AndPort) bool { return true }, + func(uint32, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( + 0, 0, []*Ip6AndPort{ NewIp6AndPort(net.ParseIP("1::1"), 1), @@ -186,7 +192,7 @@ func BenchmarkSortRebuild(b *testing.B) { NewIp6AndPort(net.ParseIP("1:100::1"), 1), NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe }, - func(*Ip6AndPort) bool { return true }, + func(uint32, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) {