diff --git a/control.go b/control.go index 089e8ac..d7a1c1f 100644 --- a/control.go +++ b/control.go @@ -67,23 +67,11 @@ func (c *Control) RebindUDPServer() { // ListHostmap returns details about the actual or pending (handshaking) hostmap func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo { - var hm *HostMap if pendingMap { - hm = c.f.handshakeManager.pendingHostMap + return listHostMap(c.f.handshakeManager.pendingHostMap) } else { - hm = c.f.hostMap + return listHostMap(c.f.hostMap) } - - hm.RLock() - hosts := make([]ControlHostInfo, len(hm.Hosts)) - i := 0 - for _, v := range hm.Hosts { - hosts[i] = copyHostInfo(v) - i++ - } - hm.RUnlock() - - return hosts } // GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found @@ -100,7 +88,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf return nil } - ch := copyHostInfo(h) + ch := copyHostInfo(h, c.f.hostMap.preferredRanges) return &ch } @@ -112,7 +100,7 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf } hostInfo.SetRemote(addr.Copy()) - ch := copyHostInfo(hostInfo) + ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges) return &ch } @@ -163,14 +151,17 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { return } -func copyHostInfo(h *HostInfo) ControlHostInfo { +func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { chi := ControlHostInfo{ - VpnIP: int2ip(h.hostId), - LocalIndex: h.localIndexId, - RemoteIndex: h.remoteIndexId, - RemoteAddrs: h.CopyRemotes(), - CachedPackets: len(h.packetStore), - MessageCounter: atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter), + VpnIP: int2ip(h.hostId), + LocalIndex: h.localIndexId, + RemoteIndex: h.remoteIndexId, + RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), + CachedPackets: len(h.packetStore), + } + + if h.ConnectionState != nil { + chi.MessageCounter = atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter) } if c := h.GetCert(); c != nil { @@ -183,3 +174,16 @@ func copyHostInfo(h *HostInfo) ControlHostInfo { return chi } + +func listHostMap(hm *HostMap) []ControlHostInfo { + hm.RLock() + hosts := make([]ControlHostInfo, len(hm.Hosts)) + i := 0 + for _, v := range hm.Hosts { + hosts[i] = copyHostInfo(v, hm.preferredRanges) + i++ + } + hm.RUnlock() + + return hosts +} diff --git a/control_test.go b/control_test.go index 9dc461f..5679ce6 100644 --- a/control_test.go +++ b/control_test.go @@ -45,10 +45,12 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) { Signature: []byte{1, 2, 1, 2, 1, 3}, } - remotes := []*udpAddr{remote1, remote2} + remotes := NewRemoteList() + remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) + remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) hm.Add(ip2int(ipNet.IP), &HostInfo{ remote: remote1, - Remotes: remotes, + remotes: remotes, ConnectionState: &ConnectionState{ peerCert: crt, }, @@ -59,7 +61,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) { hm.Add(ip2int(ipNet2.IP), &HostInfo{ remote: remote1, - Remotes: remotes, + remotes: remotes, ConnectionState: &ConnectionState{ peerCert: nil, }, @@ -81,7 +83,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) { VpnIP: net.IPv4(1, 2, 3, 4).To4(), LocalIndex: 201, RemoteIndex: 200, - RemoteAddrs: []*udpAddr{remote1, remote2}, + RemoteAddrs: []*udpAddr{remote2, remote1}, CachedPackets: 0, Cert: crt.Copy(), MessageCounter: 0, diff --git a/control_tester.go b/control_tester.go index 574682a..ff79141 
100644 --- a/control_tester.go +++ b/control_tester.go @@ -44,7 +44,18 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType, // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp // This is necessary if you did not configure static hosts or are not running a lighthouse func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) { - c.f.lightHouse.AddRemote(ip2int(vpnIp), &udpAddr{IP: toAddr.IP, Port: uint16(toAddr.Port)}, false) + c.f.lightHouse.Lock() + remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp)) + remoteList.Lock() + defer remoteList.Unlock() + c.f.lightHouse.Unlock() + + iVpnIp := ip2int(vpnIp) + if v4 := toAddr.IP.To4(); v4 != nil { + remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port))) + } else { + remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))) + } } // GetFromTun will pull a packet off the tun side of nebula @@ -84,14 +95,17 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16 SrcPort: layers.UDPPort(fromPort), DstPort: layers.UDPPort(toPort), } - udp.SetNetworkLayerForChecksum(&ip) + err := udp.SetNetworkLayerForChecksum(&ip) + if err != nil { + panic(err) + } buffer := gopacket.NewSerializeBuffer() opt := gopacket.SerializeOptions{ ComputeChecksums: true, FixLengths: true, } - err := gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data)) + err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data)) if err != nil { panic(err) } @@ -102,3 +116,13 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16 func (c *Control) GetUDPAddr() string { return c.f.outside.addr.String() } + +func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { + hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)] + if !ok { + return false + } + + c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo) + return true +} diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 75ea2de..07920fe 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -9,6 +9,7 @@ import ( "github.com/slackhq/nebula" "github.com/slackhq/nebula/e2e/router" + "github.com/stretchr/testify/assert" ) func TestGoodHandshake(t *testing.T) { @@ -23,35 +24,35 @@ func TestGoodHandshake(t *testing.T) { myControl.Start() theirControl.Start() - // Send a udp packet through to begin standing up the tunnel, this should come out the other side + t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) - // Have them consume my stage 0 packet. They have a tunnel now + t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) - // Get their stage 1 packet so that we can play with it + t.Log("Get their stage 1 packet so that we can play with it") stage1Packet := theirControl.GetFromUDP(true) - // I consume a garbage packet with a proper nebula header for our tunnel + t.Log("I consume a garbage packet with a proper nebula header for our tunnel") // this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel badPacket := stage1Packet.Copy() badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen] myControl.InjectUDPPacket(badPacket) - // Have me consume their real stage 1 packet. 
I have a tunnel now + t.Log("Have me consume their real stage 1 packet. I have a tunnel now") myControl.InjectUDPPacket(stage1Packet) - // Wait until we see my cached packet come through + t.Log("Wait until we see my cached packet come through") myControl.WaitForType(1, 0, theirControl) - // Make sure our host infos are correct + t.Log("Make sure our host infos are correct") assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) - // Get that cached packet and make sure it looks right + t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) - // Do a bidirectional tunnel test + t.Log("Do a bidirectional tunnel test") assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, router.NewR(myControl, theirControl)) myControl.Stop() @@ -62,14 +63,17 @@ func TestGoodHandshake(t *testing.T) { func TestWrongResponderHandshake(t *testing.T) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}) - evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 99}) + // The IPs here are chosen on purpose: + // The current remote handling will sort by preference, public, and then lexically. + // So we need them to have a higher address than evil (we could apply a preference though) + myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}) + theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}) + evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}) - // Add their real udp addr, which should be tried after evil. Doing this first because learned addresses are prepended + // Add their real udp addr, which should be tried after evil. myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) - // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. This will now be the first attempted ip + // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. 
myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr)

	// Build a router so we don't have to reason who gets which packet
@@ -80,137 +84,98 @@ func TestWrongResponderHandshake(t *testing.T) {
	theirControl.Start()
	evilControl.Start()

-	t.Log("Stand up the tunnel with evil (because the lighthouse cache is lying to us about who it is)")
+	t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
-	r.OnceFrom(myControl)
-	r.OnceFrom(evilControl)
+	r.RouteForAllExitFunc(func(p *nebula.UdpPacket, c *nebula.Control) router.ExitType {
+		h := &nebula.Header{}
+		err := h.Parse(p.Data)
+		if err != nil {
+			panic(err)
+		}
-	t.Log("I should have a tunnel with evil now and there should not be a cached packet waiting for us")
-	assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
-	assertHostInfoPair(t, myUdpAddr, evilUdpAddr, myVpnIp, evilVpnIp, myControl, evilControl)
+		if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
+			return router.RouteAndExit
+		}
+
+		return router.KeepRouting
+	})

	//TODO: Assert pending hostmap - I should have a correct hostinfo for them now

-	t.Log("Lets let the messages fly, this time we should have a tunnel with them")
-	r.OnceFrom(myControl)
-	r.OnceFrom(theirControl)
-
-	t.Log("I should now have a tunnel with them now and my original packet should get there")
-	r.RouteUntilAfterMsgType(myControl, 1, 0)
+	t.Log("My cached packet should be received by them")
	myCachedPacket := theirControl.GetFromTun(true)
	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)

-	t.Log("I should now have a proper tunnel with them")
+	t.Log("Test the tunnel with them")
	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)

-	t.Log("Lets make sure evil is still good")
-	assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
+	t.Log("Flush all packets from all controllers")
+	r.FlushAll()
+
+	t.Log("Ensure I don't have any hostinfo artifacts from evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), true), "My pending hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false), "My main hostmap should not contain evil")
+
+	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete

	//TODO: assert hostmaps for everyone
	t.Log("Success!")
-	//TODO: myControl is attempting to shut down 2 tunnels but is blocked on the udp txChan after the first close message
-	// what we really need here is a way to exit all the go routines loops (there are many)
-	//myControl.Stop()
-	//theirControl.Stop()
+	myControl.Stop()
+	theirControl.Stop()
}

-////TODO: We need to test lies both as the race winner and race loser
-//func TestManyWrongResponderHandshake(t *testing.T) {
-//	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-//
-//	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 99})
-//	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
-//	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 1})
-//
-//	t.Log("Build a router so we don't have to reason who gets which packet")
-//	r :=
newRouter(myControl, theirControl, evilControl) -// -// t.Log("Lets add more than 10 evil addresses, this exceeds the hostinfo remotes limit") -// for i := 0; i < 10; i++ { -// addr := net.UDPAddr{IP: evilUdpAddr.IP, Port: evilUdpAddr.Port + i} -// myControl.InjectLightHouseAddr(theirVpnIp, &addr) -// // We also need to tell our router about it -// r.AddRoute(addr.IP, uint16(addr.Port), evilControl) -// } -// -// // Start the servers -// myControl.Start() -// theirControl.Start() -// evilControl.Start() -// -// t.Log("Stand up the tunnel with evil (because the lighthouse cache is lying to us about who it is)") -// myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) -// -// t.Log("We need to spin until we get to the right remote for them") -// getOut := false -// injected := false -// for { -// t.Log("Routing for me and evil while we work through the bad ips") -// r.RouteExitFunc(myControl, func(packet *nebula.UdpPacket, receiver *nebula.Control) exitType { -// // We should stop routing right after we see a packet coming from us to them -// if *receiver == *theirControl { -// getOut = true -// return drainAndExit -// } -// -// // We need to poke our real ip in at some point, this is a well protected check looking for that moment -// if *receiver == *evilControl { -// hi := myControl.GetHostInfoByVpnIP(ip2int(theirVpnIp), true) -// if !injected && len(hi.RemoteAddrs) == 1 { -// t.Log("I am on my last ip for them, time to inject the real one into my lighthouse") -// myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) -// injected = true -// } -// return drainAndExit -// } -// -// return keepRouting -// }) -// -// if getOut { -// break -// } -// -// r.RouteForUntilAfterToAddr(evilControl, myUdpAddr, drainAndExit) -// } -// -// t.Log("I should have a tunnel with evil and them, evil should not have a cached packet") -// assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r) -// evilHostInfo := myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false) -// realEvilUdpAddr := &net.UDPAddr{IP: evilHostInfo.CurrentRemote.IP, Port: int(evilHostInfo.CurrentRemote.Port)} -// -// t.Log("Assert mine and evil's host pairs", evilUdpAddr, realEvilUdpAddr) -// assertHostInfoPair(t, myUdpAddr, realEvilUdpAddr, myVpnIp, evilVpnIp, myControl, evilControl) -// -// //t.Log("Draining everyones packets") -// //r.Drain(theirControl) -// //r.DrainAll(myControl, theirControl, evilControl) -// // -// //go func() { -// // for { -// // time.Sleep(10 * time.Millisecond) -// // t.Log(len(theirControl.GetUDPTxChan())) -// // t.Log(len(theirControl.GetTunTxChan())) -// // t.Log(len(myControl.GetUDPTxChan())) -// // t.Log(len(evilControl.GetUDPTxChan())) -// // t.Log("=====") -// // } -// //}() -// -// t.Log("I should have a tunnel with them now and my original packet should get there") -// r.RouteUntilAfterMsgType(myControl, 1, 0) -// myCachedPacket := theirControl.GetFromTun(true) -// -// t.Log("Got the cached packet, lets test the tunnel") -// assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) -// -// t.Log("Testing tunnels with them") -// assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) -// assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) -// -// t.Log("Testing tunnels with evil") -// assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r) -// -// //TODO: assert hostmaps for everyone -//} +func Test_Case1_Stage1Race(t *testing.T) { + ca, _, caKey, _ := newTestCaCert(time.Now(), 
time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}) + theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}) + + // Put their info in our lighthouse and vice versa + myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIp, myUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(myControl, theirControl) + + // Start the servers + myControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake to start on both me and them") + myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIp, 80, 80, []byte("Hi from them")) + + t.Log("Get both stage 1 handshake packets") + myHsForThem := myControl.GetFromUDP(true) + theirHsForMe := theirControl.GetFromUDP(true) + + t.Log("Now inject both stage 1 handshake packets") + myControl.InjectUDPPacket(theirHsForMe) + theirControl.InjectUDPPacket(myHsForThem) + //TODO: they should win, grab their index for me and make sure I use it in the end. + + t.Log("They should not have a stage 2 (won the race) but I should send one") + theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) + + t.Log("Route for me until I send a message packet to them") + myControl.WaitForType(1, 0, theirControl) + + t.Log("My cached packet should be received by them") + myCachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) + + t.Log("Route for them until I send a message packet to me") + theirControl.WaitForType(1, 0, myControl) + + t.Log("Their cached packet should be received by me") + theirCachedPacket := myControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIp, myVpnIp, 80, 80) + + t.Log("Do a bidirectional tunnel test") + assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() + //TODO: assert hostmaps +} + +//TODO: add a test with many lies diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 14cd741..1c31d67 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -64,6 +64,9 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u "host": "any", }}, }, + //"handshakes": m{ + // "try_interval": "1s", + //}, "listen": m{ "host": udpAddr.IP.String(), "port": udpAddr.Port, diff --git a/e2e/router/doc.go b/e2e/router/doc.go new file mode 100644 index 0000000..94e868c --- /dev/null +++ b/e2e/router/doc.go @@ -0,0 +1,3 @@ +package router + +// This file exists to allow `go fmt` to traverse here on its own. 
The build tags were keeping it out before diff --git a/e2e/router/router.go b/e2e/router/router.go index 0cf486c..e656c16 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -5,6 +5,7 @@ package router import ( "fmt" "net" + "reflect" "strconv" "sync" @@ -28,18 +29,18 @@ type R struct { sync.Mutex } -type exitType int +type ExitType int const ( // Keeps routing, the function will get called again on the next packet - keepRouting exitType = 0 + KeepRouting ExitType = 0 // Does not route this packet and exits immediately - exitNow exitType = 1 + ExitNow ExitType = 1 // Routes this packet and exits immediately afterwards - routeAndExit exitType = 2 + RouteAndExit ExitType = 2 ) -type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) exitType +type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType func NewR(controls ...*nebula.Control) *R { r := &R{ @@ -77,8 +78,8 @@ func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) { // OnceFrom will route a single packet from sender then return // If the router doesn't have the nebula controller for that address, we panic func (r *R) OnceFrom(sender *nebula.Control) { - r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) exitType { - return routeAndExit + r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType { + return RouteAndExit }) } @@ -116,7 +117,6 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] // - exitNow: the packet will not be routed and this call will return immediately // - routeAndExit: this call will return immediately after routing the last packet from sender // - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender -//TODO: is this RouteWhile? 
func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { h := &nebula.Header{} for { @@ -136,16 +136,16 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { e := whatDo(p, receiver) switch e { - case exitNow: + case ExitNow: r.Unlock() return - case routeAndExit: + case RouteAndExit: receiver.InjectUDPPacket(p) r.Unlock() return - case keepRouting: + case KeepRouting: receiver.InjectUDPPacket(p) default: @@ -160,35 +160,135 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { // If the router doesn't have the nebula controller for that address, we panic func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) { h := &nebula.Header{} - r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) exitType { + r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType { if err := h.Parse(p.Data); err != nil { panic(err) } if h.Type == msgType && h.Subtype == subType { - return routeAndExit + return RouteAndExit } - return keepRouting + return KeepRouting }) } // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit` // If the router doesn't have the nebula controller for that address, we panic -func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish exitType) { - if finish == keepRouting { - finish = routeAndExit +func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) { + if finish == KeepRouting { + finish = RouteAndExit } - r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) exitType { + r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType { if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) { return finish } - return keepRouting + return KeepRouting }) } +// RouteForAllExitFunc will route for every registered controller and calls the whatDo func with each udp packet from +// whatDo can return: +// - exitNow: the packet will not be routed and this call will return immediately +// - routeAndExit: this call will return immediately after routing the last packet from sender +// - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender +func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { + sc := make([]reflect.SelectCase, len(r.controls)) + cm := make([]*nebula.Control, len(r.controls)) + + i := 0 + for _, c := range r.controls { + sc[i] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(c.GetUDPTxChan()), + Send: reflect.Value{}, + } + + cm[i] = c + i++ + } + + for { + x, rx, _ := reflect.Select(sc) + r.Lock() + + p := rx.Interface().(*nebula.UdpPacket) + + outAddr := cm[x].GetUDPAddr() + inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) + receiver := r.getControl(outAddr, inAddr, p) + if receiver == nil { + r.Unlock() + panic("Can't route for host: " + inAddr) + } + + e := whatDo(p, receiver) + switch e { + case ExitNow: + r.Unlock() + return + + case RouteAndExit: + receiver.InjectUDPPacket(p) + r.Unlock() + return + + case KeepRouting: + receiver.InjectUDPPacket(p) + + default: + panic(fmt.Sprintf("Unknown exitFunc return: %v", e)) + } + r.Unlock() + } +} + +// FlushAll will route for every registered controller, exiting once there are no packets left to route +func (r 
*R) FlushAll() { + sc := make([]reflect.SelectCase, len(r.controls)) + cm := make([]*nebula.Control, len(r.controls)) + + i := 0 + for _, c := range r.controls { + sc[i] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(c.GetUDPTxChan()), + Send: reflect.Value{}, + } + + cm[i] = c + i++ + } + + // Add a default case to exit when nothing is left to send + sc = append(sc, reflect.SelectCase{ + Dir: reflect.SelectDefault, + Chan: reflect.Value{}, + Send: reflect.Value{}, + }) + + for { + x, rx, ok := reflect.Select(sc) + if !ok { + return + } + r.Lock() + + p := rx.Interface().(*nebula.UdpPacket) + + outAddr := cm[x].GetUDPAddr() + inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) + receiver := r.getControl(outAddr, inAddr, p) + if receiver == nil { + r.Unlock() + panic("Can't route for host: " + inAddr) + } + r.Unlock() + } +} + // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // This is an internal router function, the caller must hold the lock func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control { @@ -216,6 +316,5 @@ func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Con return c } - //TODO: call receive hooks! return r.controls[toAddr] } diff --git a/examples/config.yml b/examples/config.yml index 768742d..7d4cf23 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -202,16 +202,16 @@ logging: # Handshake Manger Settings #handshakes: - # Total time to try a handshake = sequence of `try_interval * retries` - # With 100ms interval and 20 retries it is 23.5 seconds + # Handshakes are sent to all known addresses at each interval with a linear backoff, + # Wait try_interval after the 1st attempt, 2 * try_interval after the 2nd, etc, until the handshake is older than timeout + # A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out #try_interval: 100ms #retries: 20 - # wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses - #wait_rotation: 5 # trigger_buffer is the size of the buffer channel for quickly sending handshakes # after receiving the response for lighthouse queries #trigger_buffer: 64 + # Nebula security group configuration firewall: conntrack: diff --git a/handshake_ix.go b/handshake_ix.go index 0e2f53f..de7a84c 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -14,14 +14,10 @@ import ( // Sending is done by the handshake manager func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { // This queries the lighthouse if we don't know a remote for the host + // We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send + // more quickly, effect is a quicker handshake. 
if hostinfo.remote == nil { - ips, err := f.lightHouse.Query(vpnIp, f) - if err != nil { - //l.Debugln(err) - } - for _, ip := range ips { - hostinfo.AddRemote(ip) - } + f.lightHouse.QueryServer(vpnIp, f) } err := f.handshakeManager.AddIndexHostInfo(hostinfo) @@ -69,7 +65,6 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { hostinfo.HandshakePacket[0] = msg hostinfo.HandshakeReady = true hostinfo.handshakeStart = time.Now() - } func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { @@ -125,13 +120,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { hostinfo := &HostInfo{ ConnectionState: ci, - Remotes: []*udpAddr{}, localIndexId: myIndex, remoteIndexId: hs.Details.InitiatorIndex, hostId: vpnIP, HandshakePacket: make(map[uint8][]byte, 0), } + hostinfo.Lock() + defer hostinfo.Unlock() + f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -182,16 +179,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { ci.peerCert = remoteCert ci.dKey = NewNebulaCipherState(dKey) ci.eKey = NewNebulaCipherState(eKey) - //l.Debugln("got symmetric pairs") - //hostinfo.ClearRemotes() - hostinfo.AddRemote(addr) - hostinfo.ForcePromoteBest(f.hostMap.preferredRanges) + hostinfo.remotes = f.lightHouse.QueryCache(vpnIP) + hostinfo.SetRemote(addr) hostinfo.CreateRemoteCIDR(remoteCert) - hostinfo.Lock() - defer hostinfo.Unlock() - // Only overwrite existing record if we should win the handshake race overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f) @@ -214,6 +206,10 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { case ErrExistingHostInfo: // This means there was an existing tunnel and we didn't win // handshake avoidance + + //TODO: sprinkle the new protobuf stuff in here, send a reply to get the recv_errors flowing + //TODO: if not new send a test packet like old + f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -234,6 +230,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)). Error("Failed to add HostInfo due to localIndex collision") return + case ErrExistingHandshake: + // We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish + f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Prevented a pending handshake race") + return default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here @@ -286,6 +291,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). 
Info("Handshake is already complete") + //TODO: evaluate addr for preference, if we handshook with a less preferred addr we can correct quickly here + // We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets return false } @@ -334,17 +341,13 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ certName := remoteCert.Details.Name fingerprint, _ := remoteCert.Sha256Sum() + // Ensure the right host responded if vpnIP != hostinfo.hostId { f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)). WithField("udpAddr", addr).WithField("certName", certName). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Info("Incorrect host responded to handshake") - if ho, _ := f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIP); ho != nil { - // We might have a pending tunnel to this host already, clear out that attempt since we have a tunnel now - f.handshakeManager.pendingHostMap.DeleteHostInfo(ho) - } - // Release our old handshake from pending, it should not continue f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo) @@ -354,26 +357,28 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ newHostInfo.Lock() // Block the current used address - newHostInfo.unlockedBlockRemote(addr) + newHostInfo.remotes = hostinfo.remotes + newHostInfo.remotes.BlockRemote(addr) - // If this is an ongoing issue our previous hostmap will have some bad ips too - for _, v := range hostinfo.badRemotes { - newHostInfo.unlockedBlockRemote(v) - } - //TODO: this is me enabling tests - newHostInfo.ForcePromoteBest(f.hostMap.preferredRanges) + // Get the correct remote list for the host we did handshake with + hostinfo.remotes = f.lightHouse.QueryCache(vpnIP) - f.l.WithField("blockedUdpAddrs", newHostInfo.badRemotes).WithField("vpnIp", IntIp(vpnIP)). - WithField("remotes", newHostInfo.Remotes). + f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", IntIp(vpnIP)). + WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). 
Info("Blocked addresses for handshakes") // Swap the packet store to benefit the original intended recipient + hostinfo.ConnectionState.queueLock.Lock() newHostInfo.packetStore = hostinfo.packetStore hostinfo.packetStore = []*cachedPacket{} + hostinfo.ConnectionState.queueLock.Unlock() - // Set the current hostId to the new vpnIp + // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down hostinfo.hostId = vpnIP + f.sendCloseTunnel(hostinfo) newHostInfo.Unlock() + + return true } // Mark packet 2 as seen so it doesn't show up as missed diff --git a/handshake_manager.go b/handshake_manager.go index 099d002..fec80a0 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -12,12 +12,8 @@ import ( ) const ( - // Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries - // With 100ms interval and 20 retries is 23.5 seconds - DefaultHandshakeTryInterval = time.Millisecond * 100 - DefaultHandshakeRetries = 20 - // DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses - DefaultHandshakeWaitRotation = 5 + DefaultHandshakeTryInterval = time.Millisecond * 100 + DefaultHandshakeRetries = 10 DefaultHandshakeTriggerBuffer = 64 ) @@ -25,7 +21,6 @@ var ( defaultHandshakeConfig = HandshakeConfig{ tryInterval: DefaultHandshakeTryInterval, retries: DefaultHandshakeRetries, - waitRotation: DefaultHandshakeWaitRotation, triggerBuffer: DefaultHandshakeTriggerBuffer, } ) @@ -33,45 +28,36 @@ var ( type HandshakeConfig struct { tryInterval time.Duration retries int - waitRotation int triggerBuffer int messageMetrics *MessageMetrics } type HandshakeManager struct { - pendingHostMap *HostMap - mainHostMap *HostMap - lightHouse *LightHouse - outside *udpConn - config HandshakeConfig + pendingHostMap *HostMap + mainHostMap *HostMap + lightHouse *LightHouse + outside *udpConn + config HandshakeConfig + OutboundHandshakeTimer *SystemTimerWheel + messageMetrics *MessageMetrics + l *logrus.Logger // can be used to trigger outbound handshake for the given vpnIP trigger chan uint32 - - OutboundHandshakeTimer *SystemTimerWheel - InboundHandshakeTimer *SystemTimerWheel - - messageMetrics *MessageMetrics - l *logrus.Logger } func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ - pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges), - mainHostMap: mainHostMap, - lightHouse: lightHouse, - outside: outside, - - config: config, - - trigger: make(chan uint32, config.triggerBuffer), - - OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)), - InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)), - - messageMetrics: config.messageMetrics, - l: l, + pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges), + mainHostMap: mainHostMap, + lightHouse: lightHouse, + outside: outside, + config: config, + trigger: make(chan uint32, config.triggerBuffer), + OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)), + messageMetrics: config.messageMetrics, + l: l, } } @@ -84,7 +70,6 @@ func (c *HandshakeManager) Run(f EncWriter) { c.handleOutbound(vpnIP, f, true) case now := <-clockSource: c.NextOutboundHandshakeTimerTick(now, f) - 
c.NextInboundHandshakeTimerTick(now) } } } @@ -109,84 +94,84 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT hostinfo.Lock() defer hostinfo.Unlock() - // If we haven't finished the handshake and we haven't hit max retries, query - // lighthouse and then send the handshake packet again. - if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete { - if hostinfo.remote == nil { - // We continue to query the lighthouse because hosts may - // come online during handshake retries. If the query - // succeeds (no error), add the lighthouse info to hostinfo - ips := c.lightHouse.QueryCache(vpnIP) - // If we have no responses yet, or only one IP (the host hadn't - // finished reporting its own IPs yet), then send another query to - // the LH. - if len(ips) <= 1 { - ips, err = c.lightHouse.Query(vpnIP, f) - } - if err == nil { - for _, ip := range ips { - hostinfo.AddRemote(ip) - } - hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges) - } - } else if lighthouseTriggered { - // We were triggered by a lighthouse HostQueryReply packet, but - // we have already picked a remote for this host (this can happen - // if we are configured with multiple lighthouses). So we can skip - // this trigger and let the timerwheel handle the rest of the - // process - return - } - - hostinfo.HandshakeCounter++ - - // We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through - // all the others until we can stand up a connection. - if hostinfo.HandshakeCounter > c.config.waitRotation { - hostinfo.rotateRemote() - } - - // Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation - if hostinfo.HandshakeReady && hostinfo.remote != nil { - c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote) - if err != nil { - hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("remoteIndex", hostinfo.remoteIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithError(err).Error("Failed to send handshake message") - } else { - //TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should - // keep the real packet struct around for logging purposes - hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("remoteIndex", hostinfo.remoteIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake message sent") - } - } - - // Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try - if !lighthouseTriggered { - //l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter)) - c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) - } - } else { + // We may have raced to completion but now that we have a lock we should ensure we have not yet completed. 
+ if hostinfo.HandshakeComplete { + // Ensure we don't exist in the pending hostmap anymore since we have completed c.pendingHostMap.DeleteHostInfo(hostinfo) + return } -} -func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) { - c.InboundHandshakeTimer.advance(now) - for { - ep := c.InboundHandshakeTimer.Purge() - if ep == nil { - break + // Check if we have a handshake packet to transmit yet + if !hostinfo.HandshakeReady { + // There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly + // Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState + c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) + return + } + + // If we are out of time, clean up + if hostinfo.HandshakeCounter >= c.config.retries { + hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)). + WithField("initiatorIndex", hostinfo.localIndexId). + WithField("remoteIndex", hostinfo.remoteIndexId). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()). + Info("Handshake timed out") + //TODO: emit metrics + c.pendingHostMap.DeleteHostInfo(hostinfo) + return + } + + // We only care about a lighthouse trigger before the first handshake transmit attempt. This is a very specific + // optimization for a fast lighthouse reply + //TODO: it would feel better to do this once, anytime, as our delay increases over time + if lighthouseTriggered && hostinfo.HandshakeCounter > 0 { + // If we didn't return here a lighthouse could cause us to aggressively send handshakes + return + } + + // Get a remotes object if we don't already have one. + // This is mainly to protect us as this should never be the case + if hostinfo.remotes == nil { + hostinfo.remotes = c.lightHouse.QueryCache(vpnIP) + } + + //TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped) + if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 { + // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse + // Our vpnIP here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about + // the learned public ip for them. Query again to short circuit the promotion counter + c.lightHouse.QueryServer(vpnIP, f) + } + + // Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply + var sentTo []*udpAddr + hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udpAddr, _ bool) { + c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1) + err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr) + if err != nil { + hostinfo.logger(c.l).WithField("udpAddr", addr). + WithField("initiatorIndex", hostinfo.localIndexId). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithError(err).Error("Failed to send handshake message") + + } else { + sentTo = append(sentTo, addr) } - index := ep.(uint32) + }) - c.pendingHostMap.DeleteIndex(index) + hostinfo.logger(c.l).WithField("udpAddrs", sentTo). + WithField("initiatorIndex", hostinfo.localIndexId). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
+ Info("Handshake message sent") + + // Increment the counter to increase our delay, linear backoff + hostinfo.HandshakeCounter++ + + // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add + if !lighthouseTriggered { + //TODO: feel like we dupe handshake real fast in a tight loop, why? + c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) } } @@ -194,6 +179,7 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo { hostinfo := c.pendingHostMap.AddVpnIP(vpnIP) // We lock here and use an array to insert items to prevent locking the // main receive thread for very long by waiting to add items to the pending map + //TODO: what lock? c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval) return hostinfo @@ -203,6 +189,7 @@ var ( ErrExistingHostInfo = errors.New("existing hostinfo") ErrAlreadySeen = errors.New("already seen") ErrLocalIndexCollision = errors.New("local index collision") + ErrExistingHandshake = errors.New("existing handshake") ) // CheckAndComplete checks for any conflicts in the main and pending hostmap @@ -217,17 +204,21 @@ var ( // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) { - c.pendingHostMap.RLock() - defer c.pendingHostMap.RUnlock() + c.pendingHostMap.Lock() + defer c.pendingHostMap.Unlock() c.mainHostMap.Lock() defer c.mainHostMap.Unlock() + // Check if we already have a tunnel with this vpn ip existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId] if found && existingHostInfo != nil { + // Is it just a delayed handshake packet? if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) { return existingHostInfo, ErrAlreadySeen } + if !overwrite { + // It's a new handshake and we lost the race return existingHostInfo, ErrExistingHostInfo } } @@ -237,6 +228,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket // We have a collision, but for a different hostinfo return existingIndex, ErrLocalIndexCollision } + existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId] if found && existingIndex != hostinfo { // We have a collision, but for a different hostinfo @@ -252,7 +244,24 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket Info("New host shadows existing host remoteIndex") } + // Check if we are also handshaking with this vpn ip + pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.hostId] + if found && pendingHostInfo != nil { + if !overwrite { + // We won, let our pending handshake win + return pendingHostInfo, ErrExistingHandshake + } + + // We lost, take this handshake and move any cached packets over so they get sent + pendingHostInfo.ConnectionState.queueLock.Lock() + hostinfo.packetStore = append(hostinfo.packetStore, pendingHostInfo.packetStore...) 
+ c.pendingHostMap.unlockedDeleteHostInfo(pendingHostInfo) + pendingHostInfo.ConnectionState.queueLock.Unlock() + pendingHostInfo.logger(c.l).Info("Handshake race lost, replacing pending handshake with completed tunnel") + } + if existingHostInfo != nil { + hostinfo.logger(c.l).Info("Race lost, taking new handshake") // We are going to overwrite this entry, so remove the old references delete(c.mainHostMap.Hosts, existingHostInfo.hostId) delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) @@ -267,6 +276,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket // won't have a localIndexId collision because we already have an entry in the // pendingHostMap func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { + c.pendingHostMap.Lock() + defer c.pendingHostMap.Unlock() c.mainHostMap.Lock() defer c.mainHostMap.Unlock() @@ -288,6 +299,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { } c.mainHostMap.addHostInfo(hostinfo, f) + c.pendingHostMap.unlockedDeleteHostInfo(hostinfo) } // AddIndexHostInfo generates a unique localIndexId for this HostInfo @@ -359,3 +371,7 @@ func generateIndex(l *logrus.Logger) (uint32, error) { } return index, nil } + +func hsTimeout(tries int, interval time.Duration) time.Duration { + return time.Duration(tries / 2 * ((2 * int(interval)) + (tries-1)*int(interval))) +} diff --git a/handshake_manager_test.go b/handshake_manager_test.go index b97da5c..c34b0cf 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -8,66 +8,12 @@ import ( "github.com/stretchr/testify/assert" ) -//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923} -var ips []uint32 - -func Test_NewHandshakeManagerIndex(t *testing.T) { - l := NewTestLogger() - _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))} - preferredRanges := []*net.IPNet{localrange} - mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - - blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) - - now := time.Now() - blah.NextInboundHandshakeTimerTick(now) - - var indexes = make([]uint32, 4) - var hostinfo = make([]*HostInfo, len(indexes)) - for i := range indexes { - hostinfo[i] = &HostInfo{ConnectionState: &ConnectionState{}} - } - - // Add four indexes - for i := range indexes { - err := blah.AddIndexHostInfo(hostinfo[i]) - assert.NoError(t, err) - indexes[i] = hostinfo[i].localIndexId - blah.InboundHandshakeTimer.Add(indexes[i], time.Second*10) - } - // Confirm they are in the pending index list - for _, v := range indexes { - assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v)) - } - // Adding something to pending should not affect the main hostmap - assert.Len(t, mainHM.Indexes, 0) - // Jump ahead 8 seconds - for i := 1; i <= DefaultHandshakeRetries; i++ { - next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i)) - blah.NextInboundHandshakeTimerTick(next_tick) - } - // Confirm they are still in the pending index list - for _, v := range indexes { - assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v)) - } - // Jump ahead 4 more seconds - next_tick := now.Add(12 * time.Second) - blah.NextInboundHandshakeTimerTick(next_tick) - // Confirm they have been removed - for _, v := range indexes { - assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(v)) - } -} - func 
Test_NewHandshakeManagerVpnIP(t *testing.T) { l := NewTestLogger() _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))} + ip := ip2int(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) @@ -77,39 +23,30 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) { now := time.Now() blah.NextOutboundHandshakeTimerTick(now, mw) - // Add four "IPs" - which are just uint32s - for _, v := range ips { - blah.AddVpnIP(v) - } + i := blah.AddVpnIP(ip) + i.remotes = NewRemoteList() + i.HandshakeReady = true + // Adding something to pending should not affect the main hostmap assert.Len(t, mainHM.Hosts, 0) - // Confirm they are in the pending index list - for _, v := range ips { - assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v)) - } - // Jump ahead `HandshakeRetries` ticks - cumulative := time.Duration(0) - for i := 0; i <= DefaultHandshakeRetries+1; i++ { - cumulative += time.Duration(i)*DefaultHandshakeTryInterval + 1 - next_tick := now.Add(cumulative) - //l.Infoln(next_tick) - blah.NextOutboundHandshakeTimerTick(next_tick, mw) + // Confirm they are in the pending index list + assert.Contains(t, blah.pendingHostMap.Hosts, ip) + + // Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right + for i := 1; i <= DefaultHandshakeRetries+1; i++ { + now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval) + blah.NextOutboundHandshakeTimerTick(now, mw) } // Confirm they are still in the pending index list - for _, v := range ips { - assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v)) - } - // Jump ahead 1 more second - cumulative += time.Duration(DefaultHandshakeRetries+1) * DefaultHandshakeTryInterval - next_tick := now.Add(cumulative) - //l.Infoln(next_tick) - blah.NextOutboundHandshakeTimerTick(next_tick, mw) + assert.Contains(t, blah.pendingHostMap.Hosts, ip) + + // Tick 1 more time, a minute will certainly flush it out + blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw) + // Confirm they have been removed - for _, v := range ips { - assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(v)) - } + assert.NotContains(t, blah.pendingHostMap.Hosts, ip) } func Test_NewHandshakeManagerTrigger(t *testing.T) { @@ -121,7 +58,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) { preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - lh := &LightHouse{} + lh := &LightHouse{addrMap: make(map[uint32]*RemoteList), l: l} blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig) @@ -130,28 +67,25 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) { assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) - blah.AddVpnIP(ip) - + hi := blah.AddVpnIP(ip) + hi.HandshakeReady = true assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) + assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet") - // Trigger the same method the channel will + // Trigger the same method the channel will but, this should set our remotes pointer blah.handleOutbound(ip, mw, true) + assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have done a handshake attempt") + assert.NotNil(t, hi.remotes, "Manager should have set my remotes pointer") - // Make sure the trigger 
doesn't schedule another timer entry + // Make sure the trigger doesn't double schedule the timer entry assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) - hi := blah.pendingHostMap.Hosts[ip] - assert.Nil(t, hi.remote) uaddr := NewUDPAddrFromString("10.1.1.1:4242") - lh.addrMap = map[uint32]*ip4And6{} - lh.addrMap[ip] = &ip4And6{ - v4: []*Ip4AndPort{NewIp4AndPort(uaddr.IP, uint32(uaddr.Port))}, - v6: []*Ip6AndPort{}, - } + hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port))) - // This should trigger the hostmap to populate the hostinfo + // We now have remotes but only the first trigger should have pushed things forward blah.handleOutbound(ip, mw, true) - assert.NotNil(t, hi.remote) + assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have not done a handshake attempt") assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) } @@ -166,100 +100,9 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) { return c } -func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { - l := NewTestLogger() - _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - vpnIP = ip2int(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} - mw := &mockEncWriter{} - mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - - blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) - - now := time.Now() - blah.NextOutboundHandshakeTimerTick(now, mw) - - hostinfo := blah.AddVpnIP(vpnIP) - // Pretned we have an index too - err := blah.AddIndexHostInfo(hostinfo) - assert.NoError(t, err) - blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10) - assert.NotZero(t, hostinfo.localIndexId) - assert.Contains(t, blah.pendingHostMap.Indexes, hostinfo.localIndexId) - - // Jump ahead `HandshakeRetries` ticks. 
Eviction should happen in pending - // but not main hostmap - cumulative := time.Duration(0) - for i := 1; i <= DefaultHandshakeRetries+2; i++ { - cumulative += DefaultHandshakeTryInterval * time.Duration(i) - next_tick := now.Add(cumulative) - blah.NextOutboundHandshakeTimerTick(next_tick, mw) - } - /* - for i := 0; i <= HandshakeRetries+1; i++ { - next_tick := now.Add(cumulative) - //l.Infoln(next_tick) - blah.NextOutboundHandshakeTimerTick(next_tick) - } - */ - /* - for i := 0; i <= HandshakeRetries+1; i++ { - next_tick := now.Add(time.Duration(i) * time.Second) - blah.NextOutboundHandshakeTimerTick(next_tick) - } - */ - - /* - cumulative += HandshakeTryInterval*time.Duration(HandshakeRetries) + 3 - next_tick := now.Add(cumulative) - l.Infoln(cumulative, next_tick) - blah.NextOutboundHandshakeTimerTick(next_tick) - */ - assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(vpnIP)) - assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234)) -} - -func Test_NewHandshakeManagerIndexcleanup(t *testing.T) { - l := NewTestLogger() - _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - preferredRanges := []*net.IPNet{localrange} - mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - - blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) - - now := time.Now() - blah.NextInboundHandshakeTimerTick(now) - - hostinfo := &HostInfo{ConnectionState: &ConnectionState{}} - err := blah.AddIndexHostInfo(hostinfo) - assert.NoError(t, err) - blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10) - // Pretned we have an index too - blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo) - assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010)) - - for i := 1; i <= DefaultHandshakeRetries+2; i++ { - next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i)) - blah.NextInboundHandshakeTimerTick(next_tick) - } - - next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3) - blah.NextInboundHandshakeTimerTick(next_tick) - assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010)) - assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(hostinfo.localIndexId)) -} - type mockEncWriter struct { } func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { return } - -func (mw *mockEncWriter) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { - return -} diff --git a/hostmap.go b/hostmap.go index c1333b8..3f04793 100644 --- a/hostmap.go +++ b/hostmap.go @@ -1,7 +1,6 @@ package nebula import ( - "encoding/json" "errors" "fmt" "net" @@ -16,6 +15,7 @@ import ( //const ProbeLen = 100 const PromoteEvery = 1000 +const ReQueryEvery = 5000 const MaxRemotes = 10 // How long we should prevent roaming back to the previous IP. 
@@ -30,7 +30,6 @@ type HostMap struct { Hosts map[uint32]*HostInfo preferredRanges []*net.IPNet vpnCIDR *net.IPNet - defaultRoute uint32 unsafeRoutes *CIDRTree metricsEnabled bool l *logrus.Logger @@ -40,25 +39,21 @@ type HostInfo struct { sync.RWMutex remote *udpAddr - Remotes []*udpAddr + remotes *RemoteList promoteCounter uint32 ConnectionState *ConnectionState - handshakeStart time.Time - HandshakeReady bool - HandshakeCounter int - HandshakeComplete bool - HandshakePacket map[uint8][]byte - packetStore []*cachedPacket + handshakeStart time.Time //todo: this is an entry in the handshake manager + HandshakeReady bool //todo: being in the manager means you are ready + HandshakeCounter int //todo: another handshake manager entry + HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready + HandshakePacket map[uint8][]byte //todo: this is another handshake manager entry + packetStore []*cachedPacket //todo: this is another handshake manager entry remoteIndexId uint32 localIndexId uint32 hostId uint32 recvError int remoteCidr *CIDRTree - // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. - // They should not be tried again during a handshake - badRemotes []*udpAddr - // lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH // for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like // with a handshake @@ -88,7 +83,6 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang Hosts: h, preferredRanges: preferredRanges, vpnCIDR: vpnCIDR, - defaultRoute: 0, unsafeRoutes: NewCIDRTree(), l: l, } @@ -131,7 +125,6 @@ func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo { if _, ok := hm.Hosts[vpnIP]; !ok { hm.RUnlock() h = &HostInfo{ - Remotes: []*udpAddr{}, promoteCounter: 0, hostId: vpnIP, HandshakePacket: make(map[uint8][]byte, 0), @@ -239,7 +232,11 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) { func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) { hm.Lock() + defer hm.Unlock() + hm.unlockedDeleteHostInfo(hostinfo) +} +func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { // Check if this same hostId is in the hostmap with a different instance. // This could happen if we have an entry in the pending hostmap with different // index values than the one in the main hostmap. @@ -262,7 +259,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) { if len(hm.RemoteIndexes) == 0 { hm.RemoteIndexes = map[uint32]*HostInfo{} } - hm.Unlock() if hm.l.Level >= logrus.DebugLevel { hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), @@ -294,30 +290,6 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) { } } -func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo { - hm.Lock() - i, v := hm.Hosts[vpnIp] - if v { - i.AddRemote(remote) - } else { - i = &HostInfo{ - Remotes: []*udpAddr{remote.Copy()}, - promoteCounter: 0, - hostId: vpnIp, - HandshakePacket: make(map[uint8][]byte, 0), - } - i.remote = i.Remotes[0] - hm.Hosts[vpnIp] = i - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}). 
- Debug("Hostmap remote ip added") - } - } - i.ForcePromoteBest(hm.preferredRanges) - hm.Unlock() - return i -} - func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) { return hm.queryVpnIP(vpnIp, nil) } @@ -331,12 +303,13 @@ func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostIn func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { - if promoteIfce != nil { + // Do not attempt promotion if you are a lighthouse + if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse { h.TryPromoteBest(hm.preferredRanges, promoteIfce) } - //fmt.Println(h.remote) hm.RUnlock() return h, nil + } else { //return &net.UDPAddr{}, nil, errors.New("Unable to find host") hm.RUnlock() @@ -362,11 +335,8 @@ func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 { // We already have the hm Lock when this is called, so make sure to not call // any other methods that might try to grab it again func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) { - remoteCert := hostinfo.ConnectionState.peerCert - ip := ip2int(remoteCert.Details.Ips[0].IP) - - f.lightHouse.AddRemoteAndReset(ip, hostinfo.remote) if f.serveDns { + remoteCert := hostinfo.ConnectionState.peerCert dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) } @@ -381,38 +351,21 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) { } } -func (hm *HostMap) ClearRemotes(vpnIP uint32) { - hm.Lock() - i := hm.Hosts[vpnIP] - if i == nil { - hm.Unlock() - return - } - i.remote = nil - i.Remotes = nil - hm.Unlock() -} - -func (hm *HostMap) SetDefaultRoute(ip uint32) { - hm.defaultRoute = ip -} - -func (hm *HostMap) PunchList() []*udpAddr { - var list []*udpAddr +// punchList assembles a list of all non nil RemoteList pointer entries in this hostmap +// The caller can then do the its work outside of the read lock +func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList { hm.RLock() + defer hm.RUnlock() + for _, v := range hm.Hosts { - for _, r := range v.Remotes { - list = append(list, r) + if v.remotes != nil { + rl = append(rl, v.remotes) } - // if h, ok := hm.Hosts[vpnIp]; ok { - // hm.Hosts[vpnIp].PromoteBest(hm.preferredRanges, false) - //fmt.Println(h.remote) - // } } - hm.RUnlock() - return list + return rl } +// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them func (hm *HostMap) Punchy(conn *udpConn) { var metricsTxPunchy metrics.Counter if hm.metricsEnabled { @@ -421,13 +374,18 @@ func (hm *HostMap) Punchy(conn *udpConn) { metricsTxPunchy = metrics.NilCounter{} } + var remotes []*RemoteList b := []byte{1} for { - for _, addr := range hm.PunchList() { - metricsTxPunchy.Inc(1) - conn.WriteTo(b, addr) + remotes = hm.punchList(remotes[:0]) + for _, rl := range remotes { + //TODO: CopyAddrs generates garbage but ForEach locks for the work here, figure out which way is better + for _, addr := range rl.CopyAddrs(hm.preferredRanges) { + metricsTxPunchy.Inc(1) + conn.WriteTo(b, addr) + } } - time.Sleep(time.Second * 30) + time.Sleep(time.Second * 10) } } @@ -438,38 +396,15 @@ func (hm *HostMap) addUnsafeRoutes(routes *[]route) { } } -func (i *HostInfo) MarshalJSON() ([]byte, error) { - return json.Marshal(m{ - "remote": i.remote, - "remotes": i.Remotes, - "promote_counter": i.promoteCounter, - "connection_state": i.ConnectionState, - "handshake_start": i.handshakeStart, - "handshake_ready": i.HandshakeReady, - 
"handshake_counter": i.HandshakeCounter, - "handshake_complete": i.HandshakeComplete, - "handshake_packet": i.HandshakePacket, - "packet_store": i.packetStore, - "remote_index": i.remoteIndexId, - "local_index": i.localIndexId, - "host_id": int2ip(i.hostId), - "receive_errors": i.recvError, - "last_roam": i.lastRoam, - "last_roam_remote": i.lastRoamRemote, - }) -} - func (i *HostInfo) BindConnectionState(cs *ConnectionState) { i.ConnectionState = cs } +// TryPromoteBest handles re-querying lighthouses and probing for better paths +// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { - if i.remote == nil { - i.ForcePromoteBest(preferredRanges) - return - } - - if atomic.AddUint32(&i.promoteCounter, 1)%PromoteEvery == 0 { + c := atomic.AddUint32(&i.promoteCounter, 1) + if c%PromoteEvery == 0 { // return early if we are already on a preferred remote rIP := i.remote.IP for _, l := range preferredRanges { @@ -478,87 +413,21 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } } - // We re-query the lighthouse periodically while sending packets, so - // check for new remotes in our local lighthouse cache - ips := ifce.lightHouse.QueryCache(i.hostId) - for _, ip := range ips { - i.AddRemote(ip) - } + i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) { + if addr == nil || !preferred { + return + } - best, preferred := i.getBestRemote(preferredRanges) - if preferred && !best.Equals(i.remote) { // Try to send a test packet to that host, this should // cause it to detect a roaming event and switch remotes - ifce.send(test, testRequest, i.ConnectionState, i, best, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - } - } -} - -func (i *HostInfo) ForcePromoteBest(preferredRanges []*net.IPNet) { - best, _ := i.getBestRemote(preferredRanges) - if best != nil { - i.remote = best - } -} - -func (i *HostInfo) getBestRemote(preferredRanges []*net.IPNet) (best *udpAddr, preferred bool) { - if len(i.Remotes) > 0 { - for _, r := range i.Remotes { - for _, l := range preferredRanges { - if l.Contains(r.IP) { - return r, true - } - } - - if best == nil || !PrivateIP(r.IP) { - best = r - } - /* - for _, r := range i.Remotes { - // Must have > 80% probe success to be considered. - //fmt.Println("GRADE:", r.addr.IP, r.Grade()) - if r.Grade() > float64(.8) { - if localToMe.Contains(r.addr.IP) == true { - best = r.addr - break - //i.remote = i.Remotes[c].addr - } else { - //} - } - */ - } - return best, false + ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + }) } - return nil, false -} - -// rotateRemote will move remote to the next ip in the list of remote ips for this host -// This is different than PromoteBest in that what is algorithmically best may not actually work. -// Only known use case is when sending a stage 0 handshake. -// It may be better to just send stage 0 handshakes to all known ips and sort it out in the receiver. 
-func (i *HostInfo) rotateRemote() { - // We have 0, can't rotate - if len(i.Remotes) < 1 { - return + // Re query our lighthouses for new remotes occasionally + if c%ReQueryEvery == 0 && ifce.lightHouse != nil { + ifce.lightHouse.QueryServer(i.hostId, ifce) } - - if i.remote == nil { - i.remote = i.Remotes[0] - return - } - - // We want to look at all but the very last entry since that is handled at the end - for x := 0; x < len(i.Remotes)-1; x++ { - // Find our current position and move to the next one in the list - if i.Remotes[x].Equals(i.remote) { - i.remote = i.Remotes[x+1] - return - } - } - - // Our current position was likely the last in the list, start over at 0 - i.remote = i.Remotes[0] } func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) { @@ -607,23 +476,13 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger) { } } - i.badRemotes = make([]*udpAddr, 0) + i.remotes.ResetBlockedRemotes() i.packetStore = make([]*cachedPacket, 0) i.ConnectionState.ready = true i.ConnectionState.queueLock.Unlock() i.ConnectionState.certState = nil } -func (i *HostInfo) CopyRemotes() []*udpAddr { - i.RLock() - rc := make([]*udpAddr, len(i.Remotes), len(i.Remotes)) - for x, addr := range i.Remotes { - rc[x] = addr.Copy() - } - i.RUnlock() - return rc -} - func (i *HostInfo) GetCert() *cert.NebulaCertificate { if i.ConnectionState != nil { return i.ConnectionState.peerCert @@ -631,58 +490,12 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate { return nil } -func (i *HostInfo) AddRemote(remote *udpAddr) *udpAddr { - if i.unlockedIsBadRemote(remote) { - return i.remote - } - - for _, r := range i.Remotes { - if r.Equals(remote) { - return r - } - } - - // Trim this down if necessary - if len(i.Remotes) > MaxRemotes { - i.Remotes = i.Remotes[len(i.Remotes)-MaxRemotes:] - } - - rc := remote.Copy() - i.Remotes = append(i.Remotes, rc) - return rc -} - func (i *HostInfo) SetRemote(remote *udpAddr) { - i.remote = i.AddRemote(remote) -} - -func (i *HostInfo) unlockedBlockRemote(remote *udpAddr) { - if !i.unlockedIsBadRemote(remote) { - // We copy here because we are taking something else's memory and we can't trust everything - i.badRemotes = append(i.badRemotes, remote.Copy()) + // We copy here because we likely got this remote from a source that reuses the object + if !i.remote.Equals(remote) { + i.remote = remote.Copy() + i.remotes.LearnRemote(i.hostId, remote.Copy()) } - - for k, v := range i.Remotes { - if v.Equals(remote) { - i.Remotes[k] = i.Remotes[len(i.Remotes)-1] - i.Remotes = i.Remotes[:len(i.Remotes)-1] - return - } - } -} - -func (i *HostInfo) unlockedIsBadRemote(remote *udpAddr) bool { - for _, v := range i.badRemotes { - if v.Equals(remote) { - return true - } - } - return false -} - -func (i *HostInfo) ClearRemotes() { - i.remote = nil - i.Remotes = []*udpAddr{} } func (i *HostInfo) ClearConnectionState() { @@ -805,13 +618,3 @@ func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP { } return &ips } - -func PrivateIP(ip net.IP) bool { - //TODO: Private for ipv6 or just let it ride? 
- private := false - _, private24BitBlock, _ := net.ParseCIDR("10.0.0.0/8") - _, private20BitBlock, _ := net.ParseCIDR("172.16.0.0/12") - _, private16BitBlock, _ := net.ParseCIDR("192.168.0.0/16") - private = private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip) - return private -} diff --git a/hostmap_test.go b/hostmap_test.go index f158b9e..2808317 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -1,169 +1 @@ package nebula - -import ( - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -/* -func TestHostInfoDestProbe(t *testing.T) { - a, _ := net.ResolveUDPAddr("udp", "1.0.0.1:22222") - d := NewHostInfoDest(a) - - // 999 probes that all return should give a 100% success rate - for i := 0; i < 999; i++ { - meh := d.Probe() - d.ProbeReceived(meh) - } - assert.Equal(t, d.Grade(), float64(1)) - - // 999 probes of which only half return should give a 50% success rate - for i := 0; i < 999; i++ { - meh := d.Probe() - if i%2 == 0 { - d.ProbeReceived(meh) - } - } - assert.Equal(t, d.Grade(), float64(.5)) - - // 999 probes of which none return should give a 0% success rate - for i := 0; i < 999; i++ { - d.Probe() - } - assert.Equal(t, d.Grade(), float64(0)) - - // 999 probes of which only 1/4 return should give a 25% success rate - for i := 0; i < 999; i++ { - meh := d.Probe() - if i%4 == 0 { - d.ProbeReceived(meh) - } - } - assert.Equal(t, d.Grade(), float64(.25)) - - // 999 probes of which only half return and are duplicates should give a 50% success rate - for i := 0; i < 999; i++ { - meh := d.Probe() - if i%2 == 0 { - d.ProbeReceived(meh) - d.ProbeReceived(meh) - } - } - assert.Equal(t, d.Grade(), float64(.5)) - - // 999 probes of which only way old replies return should give a 0% success rate - for i := 0; i < 999; i++ { - meh := d.Probe() - d.ProbeReceived(meh - 101) - } - assert.Equal(t, d.Grade(), float64(0)) - -} -*/ - -func TestHostmap(t *testing.T) { - l := NewTestLogger() - _, myNet, _ := net.ParseCIDR("10.128.0.0/16") - _, localToMe, _ := net.ParseCIDR("192.168.1.0/24") - myNets := []*net.IPNet{myNet} - preferredRanges := []*net.IPNet{localToMe} - - m := NewHostMap(l, "test", myNet, preferredRanges) - - a := NewUDPAddrFromString("10.127.0.3:11111") - b := NewUDPAddrFromString("1.0.0.1:22222") - y := NewUDPAddrFromString("10.128.0.3:11111") - m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a) - m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b) - m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y) - - info, _ := m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1"))) - - // There should be three remotes in the host map - assert.Equal(t, 3, len(info.Remotes)) - - // Adding an identical remote should not change the count - m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y) - assert.Equal(t, 3, len(info.Remotes)) - - // Adding a fresh remote should add one - y = NewUDPAddrFromString("10.18.0.3:11111") - m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y) - assert.Equal(t, 4, len(info.Remotes)) - - // Query and reference remote should get the first one (and not nil) - info, _ = m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1"))) - assert.NotNil(t, info.remote) - - // Promotion should ensure that the best remote is chosen (y) - info.ForcePromoteBest(myNets) - assert.True(t, myNet.Contains(info.remote.IP)) - -} - -func TestHostmapdebug(t *testing.T) { - l := NewTestLogger() - _, myNet, _ := net.ParseCIDR("10.128.0.0/16") - _, localToMe, _ := net.ParseCIDR("192.168.1.0/24") - preferredRanges := []*net.IPNet{localToMe} - m := NewHostMap(l, "test", 
myNet, preferredRanges) - - a := NewUDPAddrFromString("10.127.0.3:11111") - b := NewUDPAddrFromString("1.0.0.1:22222") - y := NewUDPAddrFromString("10.128.0.3:11111") - m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a) - m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b) - m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y) - - //t.Errorf("%s", m.DebugRemotes(1)) -} - -func TestHostMap_rotateRemote(t *testing.T) { - h := HostInfo{} - // 0 remotes, no panic - h.rotateRemote() - assert.Nil(t, h.remote) - - // 1 remote, no panic - h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 1}, 0)) - h.rotateRemote() - assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 1}) - - h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 2}, 0)) - h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 3}, 0)) - h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 4}, 0)) - - //TODO: ensure we are copying and not storing the slice! - - // Rotate through those 3 - h.rotateRemote() - assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 2}) - - h.rotateRemote() - assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 3}) - - h.rotateRemote() - assert.Equal(t, h.remote, &udpAddr{IP: net.IP{1, 1, 1, 4}, Port: 0}) - - // Finally, we should start over - h.rotateRemote() - assert.Equal(t, h.remote, &udpAddr{IP: net.IP{1, 1, 1, 1}, Port: 0}) -} - -func BenchmarkHostmappromote2(b *testing.B) { - l := NewTestLogger() - for n := 0; n < b.N; n++ { - _, myNet, _ := net.ParseCIDR("10.128.0.0/16") - _, localToMe, _ := net.ParseCIDR("192.168.1.0/24") - preferredRanges := []*net.IPNet{localToMe} - m := NewHostMap(l, "test", myNet, preferredRanges) - y := NewUDPAddrFromString("10.128.0.3:11111") - a := NewUDPAddrFromString("10.127.0.3:11111") - g := NewUDPAddrFromString("1.0.0.1:22222") - m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a) - m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g) - m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y) - } -} diff --git a/inside.go b/inside.go index b92a76f..46371bd 100644 --- a/inside.go +++ b/inside.go @@ -54,10 +54,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache) if dropReason == nil { - mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q) - if f.lightHouse != nil && mc%5000 == 0 { - f.lightHouse.Query(fwPacket.RemoteIP, f) - } + f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q) } else if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l). 
@@ -84,15 +81,13 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo { hostinfo = f.handshakeManager.AddVpnIP(vpnIp) } } - ci := hostinfo.ConnectionState if ci != nil && ci.eKey != nil && ci.ready { return hostinfo } - // Handshake is not ready, we need to grab the lock now before we start - // the handshake process + // Handshake is not ready, we need to grab the lock now before we start the handshake process hostinfo.Lock() defer hostinfo.Unlock() @@ -150,10 +145,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, return } - messageCounter := f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0) - if f.lightHouse != nil && messageCounter%5000 == 0 { - f.lightHouse.Query(fp.RemoteIP, f) - } + f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0) } // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp @@ -187,50 +179,15 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out) } -// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp -func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { - hostInfo := f.getOrHandshake(vpnIp) - if hostInfo == nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", IntIp(vpnIp)). - Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes") - } - return - } - - if hostInfo.ConnectionState.ready == false { - // Because we might be sending stored packets, lock here to stop new things going to - // the packet queue. - hostInfo.ConnectionState.queueLock.Lock() - if !hostInfo.ConnectionState.ready { - hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToAll) - hostInfo.ConnectionState.queueLock.Unlock() - return - } - hostInfo.ConnectionState.queueLock.Unlock() - } - - f.sendMessageToAll(t, st, hostInfo, p, nb, out) - return -} - -func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, b []byte) { - hostInfo.RLock() - for _, r := range hostInfo.Remotes { - f.send(t, st, hostInfo.ConnectionState, hostInfo, r, p, nb, b) - } - hostInfo.RUnlock() -} - func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) } -func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) uint64 { +func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) { if ci.eKey == nil { //TODO: log warning - return 0 + return } var err error @@ -262,7 +219,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, WithField("udpAddr", remote).WithField("counter", c). WithField("attemptedCounter", c). Error("Failed to encrypt outgoing packet") - return c + return } err = f.writers[q].WriteTo(out, remote) @@ -270,7 +227,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, hostinfo.logger(f.l).WithError(err). 
WithField("udpAddr", remote).Error("Failed to write outgoing packet") } - return c + return } func isMulticast(ip uint32) bool { diff --git a/lighthouse.go b/lighthouse.go index 4f0a468..d3f8d29 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -13,26 +13,11 @@ import ( "github.com/sirupsen/logrus" ) +//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake? //TODO: nodes are roaming lighthouses, this is bad. How are they learning? var ErrHostNotKnown = errors.New("host not known") -// The maximum number of ip addresses to store for a given vpnIp per address family -const maxAddrs = 10 - -type ip4And6 struct { - //TODO: adding a lock here could allow us to release the lock on lh.addrMap quicker - - // v4 and v6 store addresses that have been self reported by the client in a server or where all addresses are stored on a client - v4 []*Ip4AndPort - v6 []*Ip6AndPort - - // Learned addresses are ones that a client does not know about but a lighthouse learned from as a result of the received packet - // This is only used if you are a lighthouse server - learnedV4 []*Ip4AndPort - learnedV6 []*Ip6AndPort -} - type LightHouse struct { //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps @@ -42,7 +27,8 @@ type LightHouse struct { punchConn *udpConn // Local cache of answers from light houses - addrMap map[uint32]*ip4And6 + // map of vpn Ip to answers + addrMap map[uint32]*RemoteList // filters remote addresses allowed for each host // - When we are a lighthouse, this filters what addresses we store and @@ -81,7 +67,7 @@ func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, i amLighthouse: amLighthouse, myVpnIp: ip2int(myVpnIpNet.IP), myVpnZeros: uint32(32 - ones), - addrMap: make(map[uint32]*ip4And6), + addrMap: make(map[uint32]*RemoteList), nebulaPort: nebulaPort, lighthouses: make(map[uint32]struct{}), staticList: make(map[uint32]struct{}), @@ -130,57 +116,79 @@ func (lh *LightHouse) ValidateLHStaticEntries() error { return nil } -func (lh *LightHouse) Query(ip uint32, f EncWriter) ([]*udpAddr, error) { - //TODO: we need to hold the lock through the next func +func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList { if !lh.IsLighthouseIP(ip) { lh.QueryServer(ip, f) } lh.RLock() if v, ok := lh.addrMap[ip]; ok { lh.RUnlock() - return TransformLHReplyToUdpAddrs(v), nil - } - lh.RUnlock() - return nil, ErrHostNotKnown -} - -// This is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) { - if !lh.amLighthouse { - // Send a query to the lighthouses and hope for the best next time - query, err := proto.Marshal(NewLhQueryByInt(ip)) - if err != nil { - lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload") - return - } - - lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses))) - nb := make([]byte, 12, 12) - out := make([]byte, mtu) - for n := range lh.lighthouses { - f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out) - } - } -} - -func (lh *LightHouse) QueryCache(ip uint32) []*udpAddr { - //TODO: we need to hold the lock through the next func - lh.RLock() - if v, ok := lh.addrMap[ip]; ok { - lh.RUnlock() - return TransformLHReplyToUdpAddrs(v) + return v } lh.RUnlock() return nil } -// -func (lh *LightHouse) queryAndPrepMessage(ip uint32, f func(*ip4And6) (int, error)) (bool, int, error) { +// This is 
asynchronous so no reply should be expected +func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) { + if lh.amLighthouse { + return + } + + if lh.IsLighthouseIP(ip) { + return + } + + // Send a query to the lighthouses and hope for the best next time + query, err := proto.Marshal(NewLhQueryByInt(ip)) + if err != nil { + lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload") + return + } + + lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses))) + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + for n := range lh.lighthouses { + f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out) + } +} + +func (lh *LightHouse) QueryCache(ip uint32) *RemoteList { lh.RLock() if v, ok := lh.addrMap[ip]; ok { - n, err := f(v) lh.RUnlock() - return true, n, err + return v + } + lh.RUnlock() + + lh.Lock() + defer lh.Unlock() + // Add an entry if we don't already have one + return lh.unlockedGetRemoteList(ip) +} + +// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing +// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp. +// If one is found then f() is called with proper locking; f() must return the result of n.MarshalTo() +func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, error)) (bool, int, error) { + lh.RLock() + // Do we have an entry in the main cache? + if v, ok := lh.addrMap[vpnIp]; ok { + // Swap lh lock for remote list lock + v.RLock() + defer v.RUnlock() + + lh.RUnlock() + + // vpnIp should also be the owner here since we are a lighthouse. + c := v.cache[vpnIp] + // Make sure we have a cache entry to answer from + if c != nil { + n, err := f(c) + return true, n, err + } + return false, 0, nil } lh.RUnlock() return false, 0, nil @@ -203,70 +211,47 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) { lh.Unlock() } -// AddRemote is correct way for non LightHouse members to add an address. 
toAddr will be placed in the learned map -// static means this is a static host entry from the config file, it should only be used on start up -func (lh *LightHouse) AddRemote(vpnIP uint32, toAddr *udpAddr, static bool) { +// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner +// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with +// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client +func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) { + lh.Lock() + am := lh.unlockedGetRemoteList(vpnIp) + am.Lock() + defer am.Unlock() + lh.Unlock() + if ipv4 := toAddr.IP.To4(); ipv4 != nil { - lh.addRemoteV4(vpnIP, NewIp4AndPort(ipv4, uint32(toAddr.Port)), static, true) + to := NewIp4AndPort(ipv4, uint32(toAddr.Port)) + if !lh.unlockedShouldAddV4(to) { + return + } + am.unlockedPrependV4(lh.myVpnIp, to) + } else { - lh.addRemoteV6(vpnIP, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)), static, true) + to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)) + if !lh.unlockedShouldAddV6(to) { + return + } + am.unlockedPrependV6(lh.myVpnIp, to) } - //TODO: if we do not add due to a config filter we may end up not having any addresses here - if static { - lh.staticList[vpnIP] = struct{}{} - } + // Mark it as static + lh.staticList[vpnIp] = struct{}{} } -// unlockedGetAddrs assumes you have the lh lock -func (lh *LightHouse) unlockedGetAddrs(vpnIP uint32) *ip4And6 { +// unlockedGetRemoteList assumes you have the lh lock +func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList { am, ok := lh.addrMap[vpnIP] if !ok { - am = &ip4And6{} + am = NewRemoteList() lh.addrMap[vpnIP] = am } return am } -// addRemoteV4 is a lighthouse internal method that prepends a remote if it is allowed by the allow list and not duplicated -func (lh *LightHouse) addRemoteV4(vpnIP uint32, to *Ip4AndPort, static bool, learned bool) { - // First we check if the sender thinks this is a static entry - // and do nothing if it is not, but should be considered static - if static == false { - if _, ok := lh.staticList[vpnIP]; ok { - return - } - } - - lh.Lock() - defer lh.Unlock() - am := lh.unlockedGetAddrs(vpnIP) - - if learned { - if !lh.unlockedShouldAddV4(am.learnedV4, to) { - return - } - am.learnedV4 = prependAndLimitV4(am.learnedV4, to) - } else { - if !lh.unlockedShouldAddV4(am.v4, to) { - return - } - am.v4 = prependAndLimitV4(am.v4, to) - } -} - -func prependAndLimitV4(cache []*Ip4AndPort, to *Ip4AndPort) []*Ip4AndPort { - cache = append(cache, nil) - copy(cache[1:], cache) - cache[0] = to - if len(cache) > MaxRemotes { - cache = cache[:maxAddrs] - } - return cache -} - -// unlockedShouldAddV4 checks if to is allowed by our allow list and is not already present in the cache -func (lh *LightHouse) unlockedShouldAddV4(am []*Ip4AndPort, to *Ip4AndPort) bool { +// unlockedShouldAddV4 checks if to is allowed by our allow list +func (lh *LightHouse) unlockedShouldAddV4(to *Ip4AndPort) bool { allow := lh.remoteAllowList.AllowIpV4(to.Ip) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow") @@ -276,69 +261,21 @@ func (lh *LightHouse) unlockedShouldAddV4(am []*Ip4AndPort, to *Ip4AndPort) bool return false } - for _, v := range am { - if v.Ip == to.Ip && v.Port == to.Port { - return false - } - } - return true } -// addRemoteV6 is a lighthouse internal method that prepends a remote if it is allowed by the 
allow list and not duplicated -func (lh *LightHouse) addRemoteV6(vpnIP uint32, to *Ip6AndPort, static bool, learned bool) { - // First we check if the sender thinks this is a static entry - // and do nothing if it is not, but should be considered static - if static == false { - if _, ok := lh.staticList[vpnIP]; ok { - return - } - } - - lh.Lock() - defer lh.Unlock() - am := lh.unlockedGetAddrs(vpnIP) - - if learned { - if !lh.unlockedShouldAddV6(am.learnedV6, to) { - return - } - am.learnedV6 = prependAndLimitV6(am.learnedV6, to) - } else { - if !lh.unlockedShouldAddV6(am.v6, to) { - return - } - am.v6 = prependAndLimitV6(am.v6, to) - } -} - -func prependAndLimitV6(cache []*Ip6AndPort, to *Ip6AndPort) []*Ip6AndPort { - cache = append(cache, nil) - copy(cache[1:], cache) - cache[0] = to - if len(cache) > MaxRemotes { - cache = cache[:maxAddrs] - } - return cache -} - -// unlockedShouldAddV6 checks if to is allowed by our allow list and is not already present in the cache -func (lh *LightHouse) unlockedShouldAddV6(am []*Ip6AndPort, to *Ip6AndPort) bool { +// unlockedShouldAddV6 checks if to is allowed by our allow list +func (lh *LightHouse) unlockedShouldAddV6(to *Ip6AndPort) bool { allow := lh.remoteAllowList.AllowIpV6(to.Hi, to.Lo) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") } + // We don't check our vpn network here because nebula does not support ipv6 on the inside if !allow { return false } - for _, v := range am { - if v.Hi == to.Hi && v.Lo == to.Lo && v.Port == to.Port { - return false - } - } - return true } @@ -349,13 +286,6 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP { return ip } -func (lh *LightHouse) AddRemoteAndReset(vpnIP uint32, toIp *udpAddr) { - if lh.amLighthouse { - lh.DeleteVpnIP(vpnIP) - lh.AddRemote(vpnIP, toIp, false) - } -} - func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool { if _, ok := lh.lighthouses[vpnIP]; ok { return true @@ -496,7 +426,6 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { return lhh.meta } -//TODO: do we need c here? 
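The new lighthouse paths (AddStaticRemote above and handleHostUpdateNotification below) share one locking shape: hold the lighthouse-wide lock only long enough to find or create the per-host RemoteList, lock that entry, then release the lighthouse lock before doing the real work so one vpnIp cannot stall the others. A small self-contained sketch of that ordering, using simplified stand-in types rather than the real ones:

package main

import (
	"fmt"
	"sync"
)

// remoteList and lightHouse are simplified stand-ins; only the locking shape matters here.
type remoteList struct {
	sync.Mutex
	addrs []string
}

type lightHouse struct {
	sync.Mutex
	addrMap map[uint32]*remoteList
}

// update mirrors the pattern: map lock, find or create the entry, entry lock, drop map lock, mutate the entry.
func (lh *lightHouse) update(vpnIp uint32, addr string) {
	lh.Lock()
	rl, ok := lh.addrMap[vpnIp]
	if !ok {
		rl = &remoteList{}
		lh.addrMap[vpnIp] = rl
	}
	rl.Lock()
	defer rl.Unlock()
	lh.Unlock()

	rl.addrs = append(rl.addrs, addr)
}

func main() {
	lh := &lightHouse{addrMap: map[uint32]*remoteList{}}
	lh.update(1, "192.0.2.1:4242")
	fmt.Println(lh.addrMap[1].addrs)
}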
func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) @@ -544,13 +473,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr //TODO: we can DRY this further reqVpnIP := n.Details.VpnIp //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data - //TODO: If we use a lock on cache we can avoid holding it on lh.addrMap and keep things moving better - found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(cache *ip4And6) (int, error) { + found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply n.Details.VpnIp = reqVpnIP - lhh.coalesceAnswers(cache, n) + lhh.coalesceAnswers(c, n) return n.MarshalTo(lhh.pb) }) @@ -568,12 +496,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) // This signals the other side to punch some zero byte udp packets - found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(cache *ip4And6) (int, error) { + found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification n.Details.VpnIp = vpnIp - lhh.coalesceAnswers(cache, n) + lhh.coalesceAnswers(c, n) return n.MarshalTo(lhh.pb) }) @@ -591,12 +519,24 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) coalesceAnswers(cache *ip4And6, n *NebulaMeta) { - n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, cache.v4...) - n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, cache.learnedV4...) +func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { + if c.v4 != nil { + if c.v4.learned != nil { + n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.learned) + } + if c.v4.reported != nil && len(c.v4.reported) > 0 { + n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.reported...) + } + } - n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, cache.v6...) - n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, cache.learnedV6...) + if c.v6 != nil { + if c.v6.learned != nil { + n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.learned) + } + if c.v6.reported != nil && len(c.v6.reported) > 0 { + n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.reported...) 
+ } + } } func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) { @@ -604,14 +544,14 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) return } - // We can't just slam the responses in as they may come from multiple lighthouses and we should coalesce the answers - for _, to := range n.Details.Ip4AndPorts { - lhh.lh.addRemoteV4(n.Details.VpnIp, to, false, false) - } + lhh.lh.Lock() + am := lhh.lh.unlockedGetRemoteList(n.Details.VpnIp) + am.Lock() + lhh.lh.Unlock() - for _, to := range n.Details.Ip6AndPorts { - lhh.lh.addRemoteV6(n.Details.VpnIp, to, false, false) - } + am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + am.Unlock() // Non-blocking attempt to trigger, skip if it would block select { @@ -637,35 +577,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp } lhh.lh.Lock() - defer lhh.lh.Unlock() - am := lhh.lh.unlockedGetAddrs(vpnIp) + am := lhh.lh.unlockedGetRemoteList(vpnIp) + am.Lock() + lhh.lh.Unlock() - //TODO: other note on a lock for am so we can release more quickly and lock our real unit of change which is far less contended - - // We don't accumulate addresses being told to us - am.v4 = am.v4[:0] - am.v6 = am.v6[:0] - - for _, v := range n.Details.Ip4AndPorts { - if lhh.lh.unlockedShouldAddV4(am.v4, v) { - am.v4 = append(am.v4, v) - } - } - - for _, v := range n.Details.Ip6AndPorts { - if lhh.lh.unlockedShouldAddV6(am.v6, v) { - am.v6 = append(am.v6, v) - } - } - - // We prefer the first n addresses if we got too big - if len(am.v4) > MaxRemotes { - am.v4 = am.v4[:MaxRemotes] - } - - if len(am.v6) > MaxRemotes { - am.v6 = am.v6[:MaxRemotes] - } + am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + am.Unlock() } func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) { @@ -716,33 +634,6 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u } } -func TransformLHReplyToUdpAddrs(ips *ip4And6) []*udpAddr { - addrs := make([]*udpAddr, len(ips.v4)+len(ips.v6)+len(ips.learnedV4)+len(ips.learnedV6)) - i := 0 - - for _, v := range ips.learnedV4 { - addrs[i] = NewUDPAddrFromLH4(v) - i++ - } - - for _, v := range ips.v4 { - addrs[i] = NewUDPAddrFromLH4(v) - i++ - } - - for _, v := range ips.learnedV6 { - addrs[i] = NewUDPAddrFromLH6(v) - i++ - } - - for _, v := range ips.v6 { - addrs[i] = NewUDPAddrFromLH6(v) - i++ - } - - return addrs -} - // ipMaskContains checks if testIp is contained by ip after applying a cidr // zeros is 32 - bits from net.IPMask.Size() func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool { diff --git a/lighthouse_test.go b/lighthouse_test.go index 9b7f044..da4e22d 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -48,16 +48,16 @@ func Test_lhStaticMapping(t *testing.T) { udpServer, _ := NewListener(l, "0.0.0.0", 0, true) - meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) - meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true) + meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) + meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, 
uint16(4242))) err := meh.ValidateLHStaticEntries() assert.Nil(t, err) lh2 := "10.128.0.3" lh2IP := net.ParseIP(lh2) - meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false) - meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true) + meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false) + meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242))) err = meh.ValidateLHStaticEntries() assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry") } @@ -73,17 +73,27 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { hAddr := NewUDPAddrFromString("4.5.6.7:12345") hAddr2 := NewUDPAddrFromString("4.5.6.7:12346") - lh.addrMap[3] = &ip4And6{v4: []*Ip4AndPort{ - NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), - NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port))}, - } + lh.addrMap[3] = NewRemoteList() + lh.addrMap[3].unlockedSetV4( + 3, + []*Ip4AndPort{ + NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), + NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)), + }, + func(*Ip4AndPort) bool { return true }, + ) rAddr := NewUDPAddrFromString("1.2.2.3:12345") rAddr2 := NewUDPAddrFromString("1.2.2.3:12346") - lh.addrMap[2] = &ip4And6{v4: []*Ip4AndPort{ - NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), - NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port))}, - } + lh.addrMap[2] = NewRemoteList() + lh.addrMap[2].unlockedSetV4( + 3, + []*Ip4AndPort{ + NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), + NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)), + }, + func(*Ip4AndPort) bool { return true }, + ) mw := &mockEncWriter{} @@ -173,7 +183,7 @@ func TestLighthouse_Memory(t *testing.T) { assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) // Ensure proper ordering and limiting - // Send 12 addrs, get 10 back, one removed on a dupe check the other by limiting + // Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe) newLHHostUpdate( myUdpAddr0, myVpnIp, @@ -191,11 +201,12 @@ func TestLighthouse_Memory(t *testing.T) { myUdpAddr10, myUdpAddr11, // This should get cut }, lhh) + r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray( t, r.msg.Details.Ip4AndPorts, - myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, myUdpAddr10, + myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, ) // Make sure we won't add ips in our vpn network @@ -247,71 +258,71 @@ func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *Lig lhh.HandleRequest(fromAddr, vpnIp, b, w) } -func Test_lhRemoteAllowList(t *testing.T) { - l := NewTestLogger() - c := NewConfig(l) - c.Settings["remoteallowlist"] = map[interface{}]interface{}{ - "10.20.0.0/12": false, - } - allowList, err := c.GetAllowList("remoteallowlist", false) - assert.Nil(t, err) - - lh1 := "10.128.0.2" - lh1IP := net.ParseIP(lh1) - - udpServer, _ := NewListener(l, "0.0.0.0", 0, true) - - lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) - lh.SetRemoteAllowList(allowList) - - // A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap - remote1IP := 
net.ParseIP("10.20.0.3") - lh.AddRemote(ip2int(remote1IP), NewUDPAddr(remote1IP, uint16(4242)), true) - assert.NotNil(t, lh.addrMap[ip2int(remote1IP)]) - assert.Empty(t, lh.addrMap[ip2int(remote1IP)].v4) - assert.Empty(t, lh.addrMap[ip2int(remote1IP)].v6) - - // Make sure a good ip enters the cache and addrMap - remote2IP := net.ParseIP("10.128.0.3") - remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242)) - lh.AddRemote(ip2int(remote2IP), remote2UDPAddr, true) - assertIp4InArray(t, lh.addrMap[ip2int(remote2IP)].learnedV4, remote2UDPAddr) - - // Another good ip gets into the cache, ordering is inverted - remote3IP := net.ParseIP("10.128.0.4") - remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243)) - lh.AddRemote(ip2int(remote2IP), remote3UDPAddr, true) - assertIp4InArray(t, lh.addrMap[ip2int(remote2IP)].learnedV4, remote3UDPAddr, remote2UDPAddr) - - // If we exceed the length limit we should only have the most recent addresses - addedAddrs := []*udpAddr{} - for i := 0; i < 11; i++ { - remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i)) - lh.AddRemote(ip2int(remote2IP), remoteUDPAddr, true) - // The first entry here is a duplicate, don't add it to the assert list - if i != 0 { - addedAddrs = append(addedAddrs, remoteUDPAddr) - } - } - - // We should only have the last 10 of what we tried to add - assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses") - ln := len(addedAddrs) - assertIp4InArray( - t, - lh.addrMap[ip2int(remote2IP)].learnedV4, - addedAddrs[ln-1], - addedAddrs[ln-2], - addedAddrs[ln-3], - addedAddrs[ln-4], - addedAddrs[ln-5], - addedAddrs[ln-6], - addedAddrs[ln-7], - addedAddrs[ln-8], - addedAddrs[ln-9], - addedAddrs[ln-10], - ) -} +//TODO: this is a RemoteList test +//func Test_lhRemoteAllowList(t *testing.T) { +// l := NewTestLogger() +// c := NewConfig(l) +// c.Settings["remoteallowlist"] = map[interface{}]interface{}{ +// "10.20.0.0/12": false, +// } +// allowList, err := c.GetAllowList("remoteallowlist", false) +// assert.Nil(t, err) +// +// lh1 := "10.128.0.2" +// lh1IP := net.ParseIP(lh1) +// +// udpServer, _ := NewListener(l, "0.0.0.0", 0, true) +// +// lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) +// lh.SetRemoteAllowList(allowList) +// +// // A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap +// remote1IP := net.ParseIP("10.20.0.3") +// remotes := lh.unlockedGetRemoteList(ip2int(remote1IP)) +// remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242)) +// assert.NotNil(t, lh.addrMap[ip2int(remote1IP)]) +// assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{})) +// +// // Make sure a good ip enters the cache and addrMap +// remote2IP := net.ParseIP("10.128.0.3") +// remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242)) +// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false) +// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr) +// +// // Another good ip gets into the cache, ordering is inverted +// remote3IP := net.ParseIP("10.128.0.4") +// remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243)) +// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false) +// assertUdpAddrInArray(t, 
lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr) +// +// // If we exceed the length limit we should only have the most recent addresses +// addedAddrs := []*udpAddr{} +// for i := 0; i < 11; i++ { +// remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i)) +// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false) +// // The first entry here is a duplicate, don't add it to the assert list +// if i != 0 { +// addedAddrs = append(addedAddrs, remoteUDPAddr) +// } +// } +// +// // We should only have the last 10 of what we tried to add +// assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses") +// assertUdpAddrInArray( +// t, +// lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), +// addedAddrs[0], +// addedAddrs[1], +// addedAddrs[2], +// addedAddrs[3], +// addedAddrs[4], +// addedAddrs[5], +// addedAddrs[6], +// addedAddrs[7], +// addedAddrs[8], +// addedAddrs[9], +// ) +//} func Test_ipMaskContains(t *testing.T) { assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255")))) @@ -354,6 +365,16 @@ func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udpAddr) { } } +// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match +func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) { + assert.Len(t, have, len(want)) + for k, w := range want { + if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) { + assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have)) + } + } +} + func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr { addrs := make([]*udpAddr, len(ips)) for k, v := range ips { diff --git a/main.go b/main.go index 6d8bd3a..6abe5b6 100644 --- a/main.go +++ b/main.go @@ -221,7 +221,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } hostMap := NewHostMap(l, "main", tunCidr, preferredRanges) - hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0")))) + hostMap.addUnsafeRoutes(&unsafeRoutes) hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false) @@ -302,14 +302,14 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L if err != nil { return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) } - lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip, port), true) + lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port)) } } else { ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v)) if err != nil { return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) } - lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip, port), true) + lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port)) } } @@ -328,7 +328,6 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L handshakeConfig := HandshakeConfig{ tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries), - waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation), triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), messageMetrics: messageMetrics, diff --git a/outside.go b/outside.go index f8148a5..b2fd6e2 100644 --- a/outside.go +++ 
b/outside.go @@ -132,6 +132,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, f.connectionManager.In(hostinfo.hostId) } +// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote func (f *Interface) closeTunnel(hostInfo *HostInfo) { //TODO: this would be better as a single function in ConnectionManager that handled locks appropriately f.connectionManager.ClearIP(hostInfo.hostId) @@ -140,6 +141,11 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) { f.hostMap.DeleteHostInfo(hostInfo) } +// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote +func (f *Interface) sendCloseTunnel(h *HostInfo) { + f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) +} + func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { if hostDidRoam(hostinfo.remote, addr) { if !f.lightHouse.remoteAllowList.Allow(addr.IP) { @@ -160,9 +166,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { remoteCopy := *hostinfo.remote hostinfo.lastRoamRemote = &remoteCopy hostinfo.SetRemote(addr) - if f.lightHouse.amLighthouse { - f.lightHouse.AddRemote(hostinfo.hostId, addr, false) - } } } diff --git a/remote_list.go b/remote_list.go new file mode 100644 index 0000000..cf40ad8 --- /dev/null +++ b/remote_list.go @@ -0,0 +1,500 @@ +package nebula + +import ( + "bytes" + "net" + "sort" + "sync" +) + +// forEachFunc is used to benefit folks that want to do work inside the lock +type forEachFunc func(addr *udpAddr, preferred bool) + +// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) +type checkFuncV4 func(to *Ip4AndPort) bool +type checkFuncV6 func(to *Ip6AndPort) bool + +// CacheMap is a struct that better represents the lighthouse cache for humans +// The string key is the owner's vpnIp +type CacheMap map[string]*Cache + +// Cache is the other part of CacheMap to better represent the lighthouse cache for humans +// We don't reason about ipv4 vs ipv6 here +type Cache struct { + Learned []*udpAddr `json:"learned,omitempty"` + Reported []*udpAddr `json:"reported,omitempty"` +} + +//TODO: Seems like we should plop static host entries in here too since they are protected by the lighthouse from deletion +// We will never clean learned/reported information for them as it stands today + +// cache is an internal struct that splits v4 and v6 addresses inside the cache map +type cache struct { + v4 *cacheV4 + v6 *cacheV6 +} + +// cacheV4 stores learned and reported ipv4 records under cache +type cacheV4 struct { + learned *Ip4AndPort + reported []*Ip4AndPort +} + +// cacheV6 stores learned and reported ipv6 records under cache +type cacheV6 struct { + learned *Ip6AndPort + reported []*Ip6AndPort +} + +// RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos. +// It serves as a local cache of query replies, host update notifications, and locally learned addresses +type RemoteList struct { + // Every interaction with internals requires a lock! + sync.RWMutex + + // A deduplicated set of addresses. Any accessor should lock beforehand. + addrs []*udpAddr + + // These are maps to store v4 and v6 addresses per lighthouse + // Map key is the vpnIp of the person that told us about the cached entries underneath. 
+ // For learned addresses, this is the vpnIp that sent the packet + cache map[uint32]*cache + + // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. + // They should not be tried again during a handshake + badRemotes []*udpAddr + + // A flag that the cache may have changed and addrs needs to be rebuilt + shouldRebuild bool +} + +// NewRemoteList creates a new empty RemoteList +func NewRemoteList() *RemoteList { + return &RemoteList{ + addrs: make([]*udpAddr, 0), + cache: make(map[uint32]*cache), + } +} + +// Len locks and reports the size of the deduplicated address list +// The deduplication work may need to occur here, so you must pass preferredRanges +func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { + r.Rebuild(preferredRanges) + r.RLock() + defer r.RUnlock() + return len(r.addrs) +} + +// ForEach locks and will call the forEachFunc for every deduplicated address in the list +// The deduplication work may need to occur here, so you must pass preferredRanges +func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) { + r.Rebuild(preferredRanges) + r.RLock() + for _, v := range r.addrs { + forEach(v, isPreferred(v.IP, preferredRanges)) + } + r.RUnlock() +} + +// CopyAddrs locks and makes a deep copy of the deduplicated address list +// The deduplication work may need to occur here, so you must pass preferredRanges +func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr { + r.Rebuild(preferredRanges) + + r.RLock() + defer r.RUnlock() + c := make([]*udpAddr, len(r.addrs)) + for i, v := range r.addrs { + c[i] = v.Copy() + } + return c +} + +// LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr +// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. +// It will mark the deduplicated address list as dirty, so do not call it unless new information is available +//TODO: this needs to support the allow list list +func (r *RemoteList) LearnRemote(ownerVpnIp uint32, addr *udpAddr) { + r.Lock() + defer r.Unlock() + if v4 := addr.IP.To4(); v4 != nil { + r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port))) + } else { + r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port))) + } +} + +// CopyCache locks and creates a more human friendly form of the internal address cache. 
+// This may contain duplicates and blocked addresses +func (r *RemoteList) CopyCache() *CacheMap { + r.RLock() + defer r.RUnlock() + + cm := make(CacheMap) + getOrMake := func(vpnIp string) *Cache { + c := cm[vpnIp] + if c == nil { + c = &Cache{ + Learned: make([]*udpAddr, 0), + Reported: make([]*udpAddr, 0), + } + cm[vpnIp] = c + } + return c + } + + for owner, mc := range r.cache { + c := getOrMake(IntIp(owner).String()) + + if mc.v4 != nil { + if mc.v4.learned != nil { + c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned)) + } + + for _, a := range mc.v4.reported { + c.Reported = append(c.Reported, NewUDPAddrFromLH4(a)) + } + } + + if mc.v6 != nil { + if mc.v6.learned != nil { + c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned)) + } + + for _, a := range mc.v6.reported { + c.Reported = append(c.Reported, NewUDPAddrFromLH6(a)) + } + } + } + + return &cm +} + +// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list +func (r *RemoteList) BlockRemote(bad *udpAddr) { + r.Lock() + defer r.Unlock() + + // Check if we already blocked this addr + if r.unlockedIsBad(bad) { + return + } + + // We copy here because we are taking something else's memory and we can't trust everything + r.badRemotes = append(r.badRemotes, bad.Copy()) + + // Mark the next interaction must recollect/dedupe + r.shouldRebuild = true +} + +// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list +func (r *RemoteList) CopyBlockedRemotes() []*udpAddr { + r.RLock() + defer r.RUnlock() + + c := make([]*udpAddr, len(r.badRemotes)) + for i, v := range r.badRemotes { + c[i] = v.Copy() + } + return c +} + +// ResetBlockedRemotes locks and clears the blocked remotes list +func (r *RemoteList) ResetBlockedRemotes() { + r.Lock() + r.badRemotes = nil + r.Unlock() +} + +// Rebuild locks and generates the deduplicated address list only if there is work to be done +// There is generally no reason to call this directly but it is safe to do so +func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { + r.Lock() + defer r.Unlock() + + // Only rebuild if the cache changed + //TODO: shouldRebuild is probably pointless as we don't check for actual change when lighthouse updates come in + if r.shouldRebuild { + r.unlockedCollect() + r.shouldRebuild = false + } + + // Always re-sort, preferredRanges can change via HUP + r.unlockedSort(preferredRanges) +} + +// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list +func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool { + for _, v := range r.badRemotes { + if v.Equals(remote) { + return true + } + } + return false +} + +// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the +// deduplicated address list as dirty +func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) { + r.shouldRebuild = true + r.unlockedGetOrMakeV4(ownerVpnIp).learned = to +} + +// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided +// and marks the deduplicated address list as dirty +func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, to []*Ip4AndPort, check checkFuncV4) { + r.shouldRebuild = true + c := r.unlockedGetOrMakeV4(ownerVpnIp) + + // Reset the slice + c.reported = c.reported[:0] + + // We can't take their array but we can take their pointers + for _, v := range to[:minInt(len(to), MaxRemotes)] { 
+// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
+func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
+	for _, v := range r.badRemotes {
+		if v.Equals(remote) {
+			return true
+		}
+	}
+	return false
+}
+
+// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
+// deduplicated address list as dirty
+func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) {
+	r.shouldRebuild = true
+	r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
+}
+
+// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
+// and marks the deduplicated address list as dirty
+func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, to []*Ip4AndPort, check checkFuncV4) {
+	r.shouldRebuild = true
+	c := r.unlockedGetOrMakeV4(ownerVpnIp)
+
+	// Reset the slice
+	c.reported = c.reported[:0]
+
+	// We can't take their array but we can take their pointers
+	for _, v := range to[:minInt(len(to), MaxRemotes)] {
+		if check(v) {
+			c.reported = append(c.reported, v)
+		}
+	}
+}
+
+// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
+// This is only useful for establishing static hosts
+func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
+	r.shouldRebuild = true
+	c := r.unlockedGetOrMakeV4(ownerVpnIp)
+
+	// We are doing the easy append because this is rarely called
+	c.reported = append([]*Ip4AndPort{to}, c.reported...)
+	if len(c.reported) > MaxRemotes {
+		c.reported = c.reported[:MaxRemotes]
+	}
+}
+
+// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
+// deduplicated address list as dirty
+func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) {
+	r.shouldRebuild = true
+	r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
+}
+
+// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
+// and marks the deduplicated address list as dirty
+func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, to []*Ip6AndPort, check checkFuncV6) {
+	r.shouldRebuild = true
+	c := r.unlockedGetOrMakeV6(ownerVpnIp)
+
+	// Reset the slice
+	c.reported = c.reported[:0]
+
+	// We can't take their array but we can take their pointers
+	for _, v := range to[:minInt(len(to), MaxRemotes)] {
+		if check(v) {
+			c.reported = append(c.reported, v)
+		}
+	}
+}
+
+// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
+// This is only useful for establishing static hosts
+func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
+	r.shouldRebuild = true
+	c := r.unlockedGetOrMakeV6(ownerVpnIp)
+
+	// We are doing the easy append because this is rarely called
+	c.reported = append([]*Ip6AndPort{to}, c.reported...)
+	if len(c.reported) > MaxRemotes {
+		c.reported = c.reported[:MaxRemotes]
+	}
+}
+
+// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
+// The caller must dirty the learned address cache if required
+func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
+	am := r.cache[ownerVpnIp]
+	if am == nil {
+		am = &cache{}
+		r.cache[ownerVpnIp] = am
+	}
+	// Avoid occupying memory for v6 addresses if we never have any
+	if am.v4 == nil {
+		am.v4 = &cacheV4{}
+	}
+	return am.v4
+}
+
+// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
+// The caller must dirty the learned address cache if required
+func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp uint32) *cacheV6 {
+	am := r.cache[ownerVpnIp]
+	if am == nil {
+		am = &cache{}
+		r.cache[ownerVpnIp] = am
+	}
+	// Avoid occupying memory for v4 addresses if we never have any
+	if am.v6 == nil {
+		am.v6 = &cacheV6{}
+	}
+	return am.v6
+}
+
+// unlockedCollect assumes you have the write lock and collects/transforms the cache into the deduped address list.
+// The result of this function can contain duplicates. unlockedSort handles cleaning it.
+func (r *RemoteList) unlockedCollect() {
+	addrs := r.addrs[:0]
+
+	for _, c := range r.cache {
+		if c.v4 != nil {
+			if c.v4.learned != nil {
+				u := NewUDPAddrFromLH4(c.v4.learned)
+				if !r.unlockedIsBad(u) {
+					addrs = append(addrs, u)
+				}
+			}
+
+			for _, v := range c.v4.reported {
+				u := NewUDPAddrFromLH4(v)
+				if !r.unlockedIsBad(u) {
+					addrs = append(addrs, u)
+				}
+			}
+		}
+
+		if c.v6 != nil {
+			if c.v6.learned != nil {
+				u := NewUDPAddrFromLH6(c.v6.learned)
+				if !r.unlockedIsBad(u) {
+					addrs = append(addrs, u)
+				}
+			}
+
+			for _, v := range c.v6.reported {
+				u := NewUDPAddrFromLH6(v)
+				if !r.unlockedIsBad(u) {
+					addrs = append(addrs, u)
+				}
+			}
+		}
+	}
+
+	r.addrs = addrs
+}
+
+// unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
+func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
+	n := len(r.addrs)
+	if n < 2 {
+		return
+	}
+
+	lessFunc := func(i, j int) bool {
+		a := r.addrs[i]
+		b := r.addrs[j]
+
+		// Preferred addresses first
+		aPref := isPreferred(a.IP, preferredRanges)
+		bPref := isPreferred(b.IP, preferredRanges)
+		switch {
+		case aPref && !bPref:
+			// If i is preferred and j is not, i is less than j
+			return true
+
+		case !aPref && bPref:
+			// If j is preferred and i is not, i is not less than j
+			return false
+
+		default:
+			// Both i and j are either preferred or not, sort within that
+		}
+
+		// ipv6 addresses 2nd
+		a4 := a.IP.To4()
+		b4 := b.IP.To4()
+		switch {
+		case a4 == nil && b4 != nil:
+			// If i is v6 and j is v4, i is less than j
+			return true
+
+		case a4 != nil && b4 == nil:
+			// If i is v4 and j is v6, i is not less than j
+			return false
+
+		case a4 != nil && b4 != nil:
+			// Special case for ipv4, a4 and b4 are not nil
+			aPrivate := isPrivateIP(a4)
+			bPrivate := isPrivateIP(b4)
+			switch {
+			case !aPrivate && bPrivate:
+				// If i is a public ip (not private) and j is a private ip, i is less than j
+				return true
+
+			case aPrivate && !bPrivate:
+				// If i is private and j is public (not private), i is not less than j
+				return false
+
+			default:
+				// Both i and j are either public or private, sort within that
+			}
+
+		default:
+			// Both i and j are either ipv4 or ipv6, sort within that
+		}
+
+		// lexical order of ips 3rd
+		c := bytes.Compare(a.IP, b.IP)
+		if c == 0 {
+			// IPs are the same, lexical order of ports 4th
+			return a.Port < b.Port
+		}
+
+		// Ip wasn't the same
+		return c < 0
+	}
+
+	// Sort it
+	sort.Slice(r.addrs, lessFunc)
+
+	// Deduplicate
+	a, b := 0, 1
+	for b < n {
+		if !r.addrs[a].Equals(r.addrs[b]) {
+			a++
+			if a != b {
+				r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
+			}
+		}
+		b++
+	}
+
+	r.addrs = r.addrs[:a+1]
+	return
+}
+
+// minInt returns the minimum integer of a or b
+func minInt(a, b int) int {
+	if a < b {
+		return a
+	}
+	return b
+}
+
+// isPreferred returns true if the ip is contained in the preferredRanges list
+func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
+	//TODO: this would be better in a CIDR6Tree
+	for _, p := range preferredRanges {
+		if p.Contains(ip) {
+			return true
+		}
+	}
+	return false
+}
+
+var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
+var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
+var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
+
+// isPrivateIP returns true if the ip is contained by an RFC 1918 private range
+func isPrivateIP(ip net.IP) bool {
+	//TODO: another great cidrtree option
+	//TODO: Private for ipv6 or just let it ride?
+	return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
+}
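unlockedSort orders candidates in tiers: addresses inside a preferred range first, then IPv6, then public IPv4, then RFC 1918 IPv4, with ties broken lexically by IP and then port, and equal neighbours compacted away afterwards (which is why the test below expects 10 unique addresses out of 15 reported entries). A standalone, stdlib-only sketch of that tiering; the rank helper and sample addresses are invented for illustration and are not part of this change:

	package main

	import (
		"fmt"
		"net"
		"sort"
	)

	func main() {
		_, preferred, _ := net.ParseCIDR("172.17.0.0/16")

		addrs := []net.IP{
			net.ParseIP("172.18.0.1"),    // private v4
			net.ParseIP("70.199.182.92"), // public v4
			net.ParseIP("1::1"),          // v6
			net.ParseIP("172.17.1.1"),    // v4 inside the preferred range
		}

		// rank mirrors the comparator's tiers: preferred, then v6, then public v4, then private v4
		rank := func(ip net.IP) int {
			switch {
			case preferred.Contains(ip):
				return 0
			case ip.To4() == nil:
				return 1
			case !ip.IsPrivate(): // go1.17+; same intent as isPrivateIP above
				return 2
			default:
				return 3
			}
		}

		sort.Slice(addrs, func(i, j int) bool { return rank(addrs[i]) < rank(addrs[j]) })
		fmt.Println(addrs) // [172.17.1.1 1::1 70.199.182.92 172.18.0.1]
	}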
diff --git a/remote_list_test.go b/remote_list_test.go
new file mode 100644
index 0000000..bceb16c
--- /dev/null
+++ b/remote_list_test.go
@@ -0,0 +1,228 @@
+package nebula
+
+import (
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestRemoteList_Rebuild(t *testing.T) {
+	rl := NewRemoteList()
+	rl.unlockedSetV4(
+		0,
+		[]*Ip4AndPort{
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped
+			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is duped
+			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},   // this is duped
+			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},   // this is a dupe
+			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
+
+	rl.unlockedSetV6(
+		1,
+		[]*Ip6AndPort{
+			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped
+			NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped
+			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+			NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
+		},
+		func(*Ip6AndPort) bool { return true },
+	)
+
+	rl.Rebuild([]*net.IPNet{})
+	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
+
+	// ipv6 first, sorted lexically within
+	assert.Equal(t, "[1::1]:1", rl.addrs[0].String())
+	assert.Equal(t, "[1::1]:2", rl.addrs[1].String())
+	assert.Equal(t, "[1:100::1]:1", rl.addrs[2].String())
+
+	// ipv4 last, sorted by public first, then private, lexically within them
+	assert.Equal(t, "70.199.182.92:1475", rl.addrs[3].String())
+	assert.Equal(t, "70.199.182.92:1476", rl.addrs[4].String())
+	assert.Equal(t, "172.17.0.182:10101", rl.addrs[5].String())
+	assert.Equal(t, "172.17.1.1:10101", rl.addrs[6].String())
+	assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String())
+	assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String())
+	assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
+
+	// Now ensure we can hoist ipv4 up
+	_, ipNet, err := net.ParseCIDR("0.0.0.0/0")
+	assert.NoError(t, err)
+	rl.Rebuild([]*net.IPNet{ipNet})
+	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
+
+	// ipv4 first, public then private, lexically within them
+	assert.Equal(t, "70.199.182.92:1475", rl.addrs[0].String())
+	assert.Equal(t, "70.199.182.92:1476", rl.addrs[1].String())
+	assert.Equal(t, "172.17.0.182:10101", rl.addrs[2].String())
+	assert.Equal(t, "172.17.1.1:10101", rl.addrs[3].String())
+	assert.Equal(t, "172.18.0.1:10101", rl.addrs[4].String())
+	assert.Equal(t, "172.19.0.1:10101", rl.addrs[5].String())
+	assert.Equal(t, "172.31.0.1:10101", rl.addrs[6].String())
+
+	// ipv6 last, sorted by public first, then private, lexically within them
+	assert.Equal(t, "[1::1]:1", rl.addrs[7].String())
+	assert.Equal(t, "[1::1]:2", rl.addrs[8].String())
+	assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
+
+	// Ensure we can hoist a specific ipv4 range over anything else
+	_, ipNet, err = net.ParseCIDR("172.17.0.0/16")
+	assert.NoError(t, err)
+	rl.Rebuild([]*net.IPNet{ipNet})
+	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
+
+	// Preferred ipv4 first
+	assert.Equal(t, "172.17.0.182:10101", rl.addrs[0].String())
+	assert.Equal(t, "172.17.1.1:10101", rl.addrs[1].String())
+
+	// ipv6 next
+	assert.Equal(t, "[1::1]:1", rl.addrs[2].String())
+	assert.Equal(t, "[1::1]:2", rl.addrs[3].String())
+	assert.Equal(t, "[1:100::1]:1", rl.addrs[4].String())
+
+	// the remaining ipv4 last
+	assert.Equal(t, "70.199.182.92:1475", rl.addrs[5].String())
+	assert.Equal(t, "70.199.182.92:1476", rl.addrs[6].String())
+	assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String())
+	assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String())
+	assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
+}
+
+func BenchmarkFullRebuild(b *testing.B) {
+	rl := NewRemoteList()
+	rl.unlockedSetV4(
+		0,
+		[]*Ip4AndPort{
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
+			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
+
+	rl.unlockedSetV6(
+		0,
+		[]*Ip6AndPort{
+			NewIp6AndPort(net.ParseIP("1::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
+			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+		},
+		func(*Ip6AndPort) bool { return true },
+	)
+
+	b.Run("no preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{})
+		}
+	})
+
+	_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
+	assert.NoError(b, err)
+	b.Run("1 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{ipNet})
+		}
+	})
+
+	_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
+	assert.NoError(b, err)
+	b.Run("2 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+		}
+	})
+
+	_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
+	assert.NoError(b, err)
+	b.Run("3 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+		}
+	})
+}
+
+func BenchmarkSortRebuild(b *testing.B) {
+	rl := NewRemoteList()
+	rl.unlockedSetV4(
+		0,
+		[]*Ip4AndPort{
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
+			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
+
+	rl.unlockedSetV6(
+		0,
+		[]*Ip6AndPort{
+			NewIp6AndPort(net.ParseIP("1::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
+			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+		},
+		func(*Ip6AndPort) bool { return true },
+	)
+
preferred", func(b *testing.B) { + for i := 0; i < b.N; i++ { + rl.shouldRebuild = true + rl.Rebuild([]*net.IPNet{}) + } + }) + + _, ipNet, err := net.ParseCIDR("172.17.0.0/16") + rl.Rebuild([]*net.IPNet{ipNet}) + + assert.NoError(b, err) + b.Run("1 preferred", func(b *testing.B) { + for i := 0; i < b.N; i++ { + rl.Rebuild([]*net.IPNet{ipNet}) + } + }) + + _, ipNet2, err := net.ParseCIDR("70.0.0.0/8") + rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + + assert.NoError(b, err) + b.Run("2 preferred", func(b *testing.B) { + for i := 0; i < b.N; i++ { + rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + } + }) + + _, ipNet3, err := net.ParseCIDR("0.0.0.0/0") + rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + + assert.NoError(b, err) + b.Run("3 preferred", func(b *testing.B) { + for i := 0; i < b.N; i++ { + rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + } + }) +} diff --git a/ssh.go b/ssh.go index bd49c3d..2516ee0 100644 --- a/ssh.go +++ b/ssh.go @@ -10,8 +10,8 @@ import ( "os" "reflect" "runtime/pprof" + "sort" "strings" - "sync/atomic" "syscall" "github.com/sirupsen/logrus" @@ -335,8 +335,10 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error return nil } - hostMap.RLock() - defer hostMap.RUnlock() + hm := listHostMap(hostMap) + sort.Slice(hm, func(i, j int) bool { + return bytes.Compare(hm[i].VpnIP, hm[j].VpnIP) < 0 + }) if fs.Json || fs.Pretty { js := json.NewEncoder(w.GetWriter()) @@ -344,35 +346,15 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error js.SetIndent("", " ") } - d := make([]m, len(hostMap.Hosts)) - x := 0 - var h m - for _, v := range hostMap.Hosts { - h = m{ - "vpnIp": int2ip(v.hostId), - "localIndex": v.localIndexId, - "remoteIndex": v.remoteIndexId, - "remoteAddrs": v.CopyRemotes(), - "cachedPackets": len(v.packetStore), - "cert": v.GetCert(), - } - - if v.ConnectionState != nil { - h["messageCounter"] = atomic.LoadUint64(&v.ConnectionState.atomicMessageCounter) - } - - d[x] = h - x++ - } - - err := js.Encode(d) + err := js.Encode(hm) if err != nil { //TODO return nil } + } else { - for i, v := range hostMap.Hosts { - err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(i), v.CopyRemotes())) + for _, v := range hm { + err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, v.RemoteAddrs)) if err != nil { return err } @@ -389,8 +371,26 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr return nil } + type lighthouseInfo struct { + VpnIP net.IP `json:"vpnIp"` + Addrs *CacheMap `json:"addrs"` + } + lightHouse.RLock() - defer lightHouse.RUnlock() + addrMap := make([]lighthouseInfo, len(lightHouse.addrMap)) + x := 0 + for k, v := range lightHouse.addrMap { + addrMap[x] = lighthouseInfo{ + VpnIP: int2ip(k), + Addrs: v.CopyCache(), + } + x++ + } + lightHouse.RUnlock() + + sort.Slice(addrMap, func(i, j int) bool { + return bytes.Compare(addrMap[i].VpnIP, addrMap[j].VpnIP) < 0 + }) if fs.Json || fs.Pretty { js := json.NewEncoder(w.GetWriter()) @@ -398,27 +398,19 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr js.SetIndent("", " ") } - d := make([]m, len(lightHouse.addrMap)) - x := 0 - var h m - for vpnIp, v := range lightHouse.addrMap { - h = m{ - "vpnIp": int2ip(vpnIp), - "addrs": TransformLHReplyToUdpAddrs(v), - } - - d[x] = h - x++ - } - - err := js.Encode(d) + err := js.Encode(addrMap) if err != nil { //TODO return nil } + } else { - for vpnIp, v := range lightHouse.addrMap { - err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(vpnIp), TransformLHReplyToUdpAddrs(v))) + for 
+		for _, v := range addrMap {
+			b, err := json.Marshal(v.Addrs)
+			if err != nil {
+				return err
+			}
+			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, string(b)))
 			if err != nil {
 				return err
 			}
@@ -469,8 +461,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	ips, _ := ifce.lightHouse.Query(vpnIp, ifce)
-	return json.NewEncoder(w.GetWriter()).Encode(ips)
+	return json.NewEncoder(w.GetWriter()).Encode(ifce.lightHouse.Query(vpnIp, ifce).CopyCache())
 }
 
 func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
@@ -727,7 +718,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
+	hostInfo, err := ifce.hostMap.QueryVpnIP(vpnIp)
 	if err != nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
@@ -737,7 +728,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		enc.SetIndent("", " ")
 	}
 
-	return enc.Encode(hostInfo)
+	return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges))
 }
 
 func sshReload(fs interface{}, a []string, w sshd.StringWriter) error {
diff --git a/tun_tester.go b/tun_tester.go
index 7c10cd5..01b3c9d 100644
--- a/tun_tester.go
+++ b/tun_tester.go
@@ -41,9 +41,7 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []route, _ []r
 // These are unencrypted ip layer frames destined for another nebula node.
 // packets should exit the udp side, capture them with udpConn.Get
 func (c *Tun) Send(packet []byte) {
-	if c.l.Level >= logrus.DebugLevel {
-		c.l.Debug("Tun injecting packet")
-	}
+	c.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet")
 	c.rxPackets <- packet
 }
 
diff --git a/udp_all.go b/udp_all.go
index 70827aa..05883b9 100644
--- a/udp_all.go
+++ b/udp_all.go
@@ -13,8 +13,8 @@ type udpAddr struct {
 }
 
 func NewUDPAddr(ip net.IP, port uint16) *udpAddr {
-	addr := udpAddr{IP: make([]byte, len(ip)), Port: port}
-	copy(addr.IP, ip)
+	addr := udpAddr{IP: make([]byte, net.IPv6len), Port: port}
+	copy(addr.IP, ip.To16())
 	return &addr
 }
 
@@ -22,7 +22,7 @@ func NewUDPAddrFromString(s string) *udpAddr {
 	ip, port, err := parseIPAndPort(s)
 	//TODO: handle err
 	_ = err
-	return &udpAddr{IP: ip, Port: port}
+	return &udpAddr{IP: ip.To16(), Port: port}
 }
 
 func (ua *udpAddr) Equals(t *udpAddr) bool {
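The udpAddr constructors above now normalize stored IPs to the 16-byte form. A standalone stdlib sketch of the pitfall this avoids when addresses are compared or sorted byte-wise; the sample address is illustrative only:

	package main

	import (
		"bytes"
		"fmt"
		"net"
	)

	func main() {
		a := net.IPv4(10, 1, 1, 1).To4() // 4-byte representation
		b := net.IPv4(10, 1, 1, 1)       // 16-byte (IPv4-in-IPv6) representation

		fmt.Println(bytes.Equal(a, b))               // false: same address, different lengths
		fmt.Println(bytes.Equal(a.To16(), b.To16())) // true once both are normalized
	}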
diff --git a/udp_linux.go b/udp_linux.go
index c49aea5..6eb22f4 100644
--- a/udp_linux.go
+++ b/udp_linux.go
@@ -97,40 +97,21 @@ func (u *udpConn) GetSendBuffer() (int, error) {
 }
 
 func (u *udpConn) LocalAddr() (*udpAddr, error) {
-	var rsa unix.RawSockaddrAny
-	var rLen = unix.SizeofSockaddrAny
-
-	_, _, err := unix.Syscall(
-		unix.SYS_GETSOCKNAME,
-		uintptr(u.sysFd),
-		uintptr(unsafe.Pointer(&rsa)),
-		uintptr(unsafe.Pointer(&rLen)),
-	)
-
-	if err != 0 {
+	sa, err := unix.Getsockname(u.sysFd)
+	if err != nil {
 		return nil, err
 	}
 
 	addr := &udpAddr{}
-	if rsa.Addr.Family == unix.AF_INET {
-		pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(&rsa))
-		addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
-		copy(addr.IP, pp.Addr[:])
-
-	} else if rsa.Addr.Family == unix.AF_INET6 {
-		//TODO: this cast sucks and we can do better
-		pp := (*unix.RawSockaddrInet6)(unsafe.Pointer(&rsa))
-		addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
-		copy(addr.IP, pp.Addr[:])
-
-	} else {
-		addr.Port = 0
-		addr.IP = []byte{}
+	switch sa := sa.(type) {
+	case *unix.SockaddrInet4:
+		addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
+		addr.Port = uint16(sa.Port)
+	case *unix.SockaddrInet6:
+		addr.IP = sa.Addr[0:]
+		addr.Port = uint16(sa.Port)
 	}
 
-	//TODO: Just use this instead?
-	//a, b := unix.Getsockname(u.sysFd)
-
 	return addr, nil
 }
diff --git a/udp_tester.go b/udp_tester.go
index 622c6ac..cc7e181 100644
--- a/udp_tester.go
+++ b/udp_tester.go
@@ -3,6 +3,7 @@ package nebula
 import (
+	"fmt"
 	"net"
 
 	"github.com/sirupsen/logrus"
@@ -53,7 +54,14 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error
 // this is an encrypted packet or a handshake message in most cases
 // packets were transmitted from another nebula node, you can send them with Tun.Send
 func (u *udpConn) Send(packet *UdpPacket) {
-	u.l.Infof("UDP injecting packet %+v", packet)
+	h := &Header{}
+	if err := h.Parse(packet.Data); err != nil {
+		panic(err)
+	}
+	u.l.WithField("header", h).
+		WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
+		WithField("dataLen", len(packet.Data)).
+		Info("UDP receiving injected packet")
 	u.rxPackets <- packet
 }