Refactor remotes and handshaking to give every address a fair shot (#437)

This commit is contained in:
Nathan Brown 2021-04-14 13:50:09 -05:00 committed by GitHub
parent 20bef975cd
commit 710df6a876
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1561 additions and 1385 deletions

View File

@ -67,23 +67,11 @@ func (c *Control) RebindUDPServer() {
// ListHostmap returns details about the actual or pending (handshaking) hostmap
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
var hm *HostMap
if pendingMap {
hm = c.f.handshakeManager.pendingHostMap
return listHostMap(c.f.handshakeManager.pendingHostMap)
} else {
hm = c.f.hostMap
return listHostMap(c.f.hostMap)
}
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v)
i++
}
hm.RUnlock()
return hosts
}
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
@ -100,7 +88,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
return nil
}
ch := copyHostInfo(h)
ch := copyHostInfo(h, c.f.hostMap.preferredRanges)
return &ch
}
@ -112,7 +100,7 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf
}
hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo)
ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges)
return &ch
}
@ -163,14 +151,17 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
return
}
func copyHostInfo(h *HostInfo) ControlHostInfo {
func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
chi := ControlHostInfo{
VpnIP: int2ip(h.hostId),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.CopyRemotes(),
CachedPackets: len(h.packetStore),
MessageCounter: atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter),
VpnIP: int2ip(h.hostId),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
CachedPackets: len(h.packetStore),
}
if h.ConnectionState != nil {
chi.MessageCounter = atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter)
}
if c := h.GetCert(); c != nil {
@ -183,3 +174,16 @@ func copyHostInfo(h *HostInfo) ControlHostInfo {
return chi
}
func listHostMap(hm *HostMap) []ControlHostInfo {
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v, hm.preferredRanges)
i++
}
hm.RUnlock()
return hosts
}

View File

@ -45,10 +45,12 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
Signature: []byte{1, 2, 1, 2, 1, 3},
}
remotes := []*udpAddr{remote1, remote2}
remotes := NewRemoteList()
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
hm.Add(ip2int(ipNet.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: crt,
},
@ -59,7 +61,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
hm.Add(ip2int(ipNet2.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: nil,
},
@ -81,7 +83,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
LocalIndex: 201,
RemoteIndex: 200,
RemoteAddrs: []*udpAddr{remote1, remote2},
RemoteAddrs: []*udpAddr{remote2, remote1},
CachedPackets: 0,
Cert: crt.Copy(),
MessageCounter: 0,

View File

@ -44,7 +44,18 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType,
// InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
// This is necessary if you did not configure static hosts or are not running a lighthouse
func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
c.f.lightHouse.AddRemote(ip2int(vpnIp), &udpAddr{IP: toAddr.IP, Port: uint16(toAddr.Port)}, false)
c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp))
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()
iVpnIp := ip2int(vpnIp)
if v4 := toAddr.IP.To4(); v4 != nil {
remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
} else {
remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)))
}
}
// GetFromTun will pull a packet off the tun side of nebula
@ -84,14 +95,17 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
SrcPort: layers.UDPPort(fromPort),
DstPort: layers.UDPPort(toPort),
}
udp.SetNetworkLayerForChecksum(&ip)
err := udp.SetNetworkLayerForChecksum(&ip)
if err != nil {
panic(err)
}
buffer := gopacket.NewSerializeBuffer()
opt := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err := gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
if err != nil {
panic(err)
}
@ -102,3 +116,13 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
func (c *Control) GetUDPAddr() string {
return c.f.outside.addr.String()
}
func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)]
if !ok {
return false
}
c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
return true
}

View File

@ -9,6 +9,7 @@ import (
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
)
func TestGoodHandshake(t *testing.T) {
@ -23,35 +24,35 @@ func TestGoodHandshake(t *testing.T) {
myControl.Start()
theirControl.Start()
// Send a udp packet through to begin standing up the tunnel, this should come out the other side
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
// Have them consume my stage 0 packet. They have a tunnel now
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
// Get their stage 1 packet so that we can play with it
t.Log("Get their stage 1 packet so that we can play with it")
stage1Packet := theirControl.GetFromUDP(true)
// I consume a garbage packet with a proper nebula header for our tunnel
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
badPacket := stage1Packet.Copy()
badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen]
myControl.InjectUDPPacket(badPacket)
// Have me consume their real stage 1 packet. I have a tunnel now
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
// Wait until we see my cached packet come through
t.Log("Wait until we see my cached packet come through")
myControl.WaitForType(1, 0, theirControl)
// Make sure our host infos are correct
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
// Get that cached packet and make sure it looks right
t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
// Do a bidirectional tunnel test
t.Log("Do a bidirectional tunnel test")
assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, router.NewR(myControl, theirControl))
myControl.Stop()
@ -62,14 +63,17 @@ func TestGoodHandshake(t *testing.T) {
func TestWrongResponderHandshake(t *testing.T) {
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1})
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 99})
// The IPs here are chosen on purpose:
// The current remote handling will sort by preference, public, and then lexically.
// So we need them to have a higher address than evil (we could apply a preference though)
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100})
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99})
evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2})
// Add their real udp addr, which should be tried after evil. Doing this first because learned addresses are prepended
// Add their real udp addr, which should be tried after evil.
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. This will now be the first attempted ip
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr)
// Build a router so we don't have to reason who gets which packet
@ -80,137 +84,98 @@ func TestWrongResponderHandshake(t *testing.T) {
theirControl.Start()
evilControl.Start()
t.Log("Stand up the tunnel with evil (because the lighthouse cache is lying to us about who it is)")
t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
r.OnceFrom(myControl)
r.OnceFrom(evilControl)
r.RouteForAllExitFunc(func(p *nebula.UdpPacket, c *nebula.Control) router.ExitType {
h := &nebula.Header{}
err := h.Parse(p.Data)
if err != nil {
panic(err)
}
t.Log("I should have a tunnel with evil now and there should not be a cached packet waiting for us")
assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
assertHostInfoPair(t, myUdpAddr, evilUdpAddr, myVpnIp, evilVpnIp, myControl, evilControl)
if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
return router.RouteAndExit
}
return router.KeepRouting
})
//TODO: Assert pending hostmap - I should have a correct hostinfo for them now
t.Log("Lets let the messages fly, this time we should have a tunnel with them")
r.OnceFrom(myControl)
r.OnceFrom(theirControl)
t.Log("I should now have a tunnel with them now and my original packet should get there")
r.RouteUntilAfterMsgType(myControl, 1, 0)
t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
t.Log("I should now have a proper tunnel with them")
t.Log("Test the tunnel with them")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
t.Log("Lets make sure evil is still good")
assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
t.Log("Flush all packets from all controllers")
r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false), "My main hostmap should not contain evil")
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
//TODO: assert hostmaps for everyone
t.Log("Success!")
//TODO: myControl is attempting to shut down 2 tunnels but is blocked on the udp txChan after the first close message
// what we really need here is a way to exit all the go routines loops (there are many)
//myControl.Stop()
//theirControl.Stop()
myControl.Stop()
theirControl.Stop()
}
////TODO: We need to test lies both as the race winner and race loser
//func TestManyWrongResponderHandshake(t *testing.T) {
// ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
//
// myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 99})
// theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
// evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 1})
//
// t.Log("Build a router so we don't have to reason who gets which packet")
// r := newRouter(myControl, theirControl, evilControl)
//
// t.Log("Lets add more than 10 evil addresses, this exceeds the hostinfo remotes limit")
// for i := 0; i < 10; i++ {
// addr := net.UDPAddr{IP: evilUdpAddr.IP, Port: evilUdpAddr.Port + i}
// myControl.InjectLightHouseAddr(theirVpnIp, &addr)
// // We also need to tell our router about it
// r.AddRoute(addr.IP, uint16(addr.Port), evilControl)
// }
//
// // Start the servers
// myControl.Start()
// theirControl.Start()
// evilControl.Start()
//
// t.Log("Stand up the tunnel with evil (because the lighthouse cache is lying to us about who it is)")
// myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
//
// t.Log("We need to spin until we get to the right remote for them")
// getOut := false
// injected := false
// for {
// t.Log("Routing for me and evil while we work through the bad ips")
// r.RouteExitFunc(myControl, func(packet *nebula.UdpPacket, receiver *nebula.Control) exitType {
// // We should stop routing right after we see a packet coming from us to them
// if *receiver == *theirControl {
// getOut = true
// return drainAndExit
// }
//
// // We need to poke our real ip in at some point, this is a well protected check looking for that moment
// if *receiver == *evilControl {
// hi := myControl.GetHostInfoByVpnIP(ip2int(theirVpnIp), true)
// if !injected && len(hi.RemoteAddrs) == 1 {
// t.Log("I am on my last ip for them, time to inject the real one into my lighthouse")
// myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
// injected = true
// }
// return drainAndExit
// }
//
// return keepRouting
// })
//
// if getOut {
// break
// }
//
// r.RouteForUntilAfterToAddr(evilControl, myUdpAddr, drainAndExit)
// }
//
// t.Log("I should have a tunnel with evil and them, evil should not have a cached packet")
// assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
// evilHostInfo := myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false)
// realEvilUdpAddr := &net.UDPAddr{IP: evilHostInfo.CurrentRemote.IP, Port: int(evilHostInfo.CurrentRemote.Port)}
//
// t.Log("Assert mine and evil's host pairs", evilUdpAddr, realEvilUdpAddr)
// assertHostInfoPair(t, myUdpAddr, realEvilUdpAddr, myVpnIp, evilVpnIp, myControl, evilControl)
//
// //t.Log("Draining everyones packets")
// //r.Drain(theirControl)
// //r.DrainAll(myControl, theirControl, evilControl)
// //
// //go func() {
// // for {
// // time.Sleep(10 * time.Millisecond)
// // t.Log(len(theirControl.GetUDPTxChan()))
// // t.Log(len(theirControl.GetTunTxChan()))
// // t.Log(len(myControl.GetUDPTxChan()))
// // t.Log(len(evilControl.GetUDPTxChan()))
// // t.Log("=====")
// // }
// //}()
//
// t.Log("I should have a tunnel with them now and my original packet should get there")
// r.RouteUntilAfterMsgType(myControl, 1, 0)
// myCachedPacket := theirControl.GetFromTun(true)
//
// t.Log("Got the cached packet, lets test the tunnel")
// assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
//
// t.Log("Testing tunnels with them")
// assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
// assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
//
// t.Log("Testing tunnels with evil")
// assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
//
// //TODO: assert hostmaps for everyone
//}
func Test_Case1_Stage1Race(t *testing.T) {
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1})
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
// Put their info in our lighthouse and vice versa
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIp, myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(myControl, theirControl)
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Trigger a handshake to start on both me and them")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIp, 80, 80, []byte("Hi from them"))
t.Log("Get both stage 1 handshake packets")
myHsForThem := myControl.GetFromUDP(true)
theirHsForMe := theirControl.GetFromUDP(true)
t.Log("Now inject both stage 1 handshake packets")
myControl.InjectUDPPacket(theirHsForMe)
theirControl.InjectUDPPacket(myHsForThem)
//TODO: they should win, grab their index for me and make sure I use it in the end.
t.Log("They should not have a stage 2 (won the race) but I should send one")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Route for me until I send a message packet to them")
myControl.WaitForType(1, 0, theirControl)
t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
t.Log("Route for them until I send a message packet to me")
theirControl.WaitForType(1, 0, myControl)
t.Log("Their cached packet should be received by me")
theirCachedPacket := myControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIp, myVpnIp, 80, 80)
t.Log("Do a bidirectional tunnel test")
assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
myControl.Stop()
theirControl.Stop()
//TODO: assert hostmaps
}
//TODO: add a test with many lies

View File

@ -64,6 +64,9 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
"host": "any",
}},
},
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": m{
"host": udpAddr.IP.String(),
"port": udpAddr.Port,

3
e2e/router/doc.go Normal file
View File

@ -0,0 +1,3 @@
package router
// This file exists to allow `go fmt` to traverse here on its own. The build tags were keeping it out before

View File

@ -5,6 +5,7 @@ package router
import (
"fmt"
"net"
"reflect"
"strconv"
"sync"
@ -28,18 +29,18 @@ type R struct {
sync.Mutex
}
type exitType int
type ExitType int
const (
// Keeps routing, the function will get called again on the next packet
keepRouting exitType = 0
KeepRouting ExitType = 0
// Does not route this packet and exits immediately
exitNow exitType = 1
ExitNow ExitType = 1
// Routes this packet and exits immediately afterwards
routeAndExit exitType = 2
RouteAndExit ExitType = 2
)
type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) exitType
type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType
func NewR(controls ...*nebula.Control) *R {
r := &R{
@ -77,8 +78,8 @@ func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
// OnceFrom will route a single packet from sender then return
// If the router doesn't have the nebula controller for that address, we panic
func (r *R) OnceFrom(sender *nebula.Control) {
r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) exitType {
return routeAndExit
r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType {
return RouteAndExit
})
}
@ -116,7 +117,6 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
// - exitNow: the packet will not be routed and this call will return immediately
// - routeAndExit: this call will return immediately after routing the last packet from sender
// - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
//TODO: is this RouteWhile?
func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
h := &nebula.Header{}
for {
@ -136,16 +136,16 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
e := whatDo(p, receiver)
switch e {
case exitNow:
case ExitNow:
r.Unlock()
return
case routeAndExit:
case RouteAndExit:
receiver.InjectUDPPacket(p)
r.Unlock()
return
case keepRouting:
case KeepRouting:
receiver.InjectUDPPacket(p)
default:
@ -160,35 +160,135 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
// If the router doesn't have the nebula controller for that address, we panic
func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) {
h := &nebula.Header{}
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) exitType {
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
if err := h.Parse(p.Data); err != nil {
panic(err)
}
if h.Type == msgType && h.Subtype == subType {
return routeAndExit
return RouteAndExit
}
return keepRouting
return KeepRouting
})
}
// RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
// finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
// If the router doesn't have the nebula controller for that address, we panic
func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish exitType) {
if finish == keepRouting {
finish = routeAndExit
func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
if finish == KeepRouting {
finish = RouteAndExit
}
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) exitType {
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
return finish
}
return keepRouting
return KeepRouting
})
}
// RouteForAllExitFunc will route for every registered controller and calls the whatDo func with each udp packet from
// whatDo can return:
// - exitNow: the packet will not be routed and this call will return immediately
// - routeAndExit: this call will return immediately after routing the last packet from sender
// - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
sc := make([]reflect.SelectCase, len(r.controls))
cm := make([]*nebula.Control, len(r.controls))
i := 0
for _, c := range r.controls {
sc[i] = reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(c.GetUDPTxChan()),
Send: reflect.Value{},
}
cm[i] = c
i++
}
for {
x, rx, _ := reflect.Select(sc)
r.Lock()
p := rx.Interface().(*nebula.UdpPacket)
outAddr := cm[x].GetUDPAddr()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
receiver := r.getControl(outAddr, inAddr, p)
if receiver == nil {
r.Unlock()
panic("Can't route for host: " + inAddr)
}
e := whatDo(p, receiver)
switch e {
case ExitNow:
r.Unlock()
return
case RouteAndExit:
receiver.InjectUDPPacket(p)
r.Unlock()
return
case KeepRouting:
receiver.InjectUDPPacket(p)
default:
panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
}
r.Unlock()
}
}
// FlushAll will route for every registered controller, exiting once there are no packets left to route
func (r *R) FlushAll() {
sc := make([]reflect.SelectCase, len(r.controls))
cm := make([]*nebula.Control, len(r.controls))
i := 0
for _, c := range r.controls {
sc[i] = reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(c.GetUDPTxChan()),
Send: reflect.Value{},
}
cm[i] = c
i++
}
// Add a default case to exit when nothing is left to send
sc = append(sc, reflect.SelectCase{
Dir: reflect.SelectDefault,
Chan: reflect.Value{},
Send: reflect.Value{},
})
for {
x, rx, ok := reflect.Select(sc)
if !ok {
return
}
r.Lock()
p := rx.Interface().(*nebula.UdpPacket)
outAddr := cm[x].GetUDPAddr()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
receiver := r.getControl(outAddr, inAddr, p)
if receiver == nil {
r.Unlock()
panic("Can't route for host: " + inAddr)
}
r.Unlock()
}
}
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
// This is an internal router function, the caller must hold the lock
func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control {
@ -216,6 +316,5 @@ func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Con
return c
}
//TODO: call receive hooks!
return r.controls[toAddr]
}

View File

@ -202,16 +202,16 @@ logging:
# Handshake Manger Settings
#handshakes:
# Total time to try a handshake = sequence of `try_interval * retries`
# With 100ms interval and 20 retries it is 23.5 seconds
# Handshakes are sent to all known addresses at each interval with a linear backoff,
# Wait try_interval after the 1st attempt, 2 * try_interval after the 2nd, etc, until the handshake is older than timeout
# A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out
#try_interval: 100ms
#retries: 20
# wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
#wait_rotation: 5
# trigger_buffer is the size of the buffer channel for quickly sending handshakes
# after receiving the response for lighthouse queries
#trigger_buffer: 64
# Nebula security group configuration
firewall:
conntrack:

View File

@ -14,14 +14,10 @@ import (
// Sending is done by the handshake manager
func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
// This queries the lighthouse if we don't know a remote for the host
// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
// more quickly, effect is a quicker handshake.
if hostinfo.remote == nil {
ips, err := f.lightHouse.Query(vpnIp, f)
if err != nil {
//l.Debugln(err)
}
for _, ip := range ips {
hostinfo.AddRemote(ip)
}
f.lightHouse.QueryServer(vpnIp, f)
}
err := f.handshakeManager.AddIndexHostInfo(hostinfo)
@ -69,7 +65,6 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
hostinfo.HandshakePacket[0] = msg
hostinfo.HandshakeReady = true
hostinfo.handshakeStart = time.Now()
}
func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
@ -125,13 +120,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
hostinfo := &HostInfo{
ConnectionState: ci,
Remotes: []*udpAddr{},
localIndexId: myIndex,
remoteIndexId: hs.Details.InitiatorIndex,
hostId: vpnIP,
HandshakePacket: make(map[uint8][]byte, 0),
}
hostinfo.Lock()
defer hostinfo.Unlock()
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
@ -182,16 +179,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
//l.Debugln("got symmetric pairs")
//hostinfo.ClearRemotes()
hostinfo.AddRemote(addr)
hostinfo.ForcePromoteBest(f.hostMap.preferredRanges)
hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
hostinfo.SetRemote(addr)
hostinfo.CreateRemoteCIDR(remoteCert)
hostinfo.Lock()
defer hostinfo.Unlock()
// Only overwrite existing record if we should win the handshake race
overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
@ -214,6 +206,10 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
case ErrExistingHostInfo:
// This means there was an existing tunnel and we didn't win
// handshake avoidance
//TODO: sprinkle the new protobuf stuff in here, send a reply to get the recv_errors flowing
//TODO: if not new send a test packet like old
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
@ -234,6 +230,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)).
Error("Failed to add HostInfo due to localIndex collision")
return
case ErrExistingHandshake:
// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Prevented a pending handshake race")
return
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
@ -286,6 +291,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Info("Handshake is already complete")
//TODO: evaluate addr for preference, if we handshook with a less preferred addr we can correct quickly here
// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
return false
}
@ -334,17 +341,13 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
// Ensure the right host responded
if vpnIP != hostinfo.hostId {
f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)).
WithField("udpAddr", addr).WithField("certName", certName).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake")
if ho, _ := f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIP); ho != nil {
// We might have a pending tunnel to this host already, clear out that attempt since we have a tunnel now
f.handshakeManager.pendingHostMap.DeleteHostInfo(ho)
}
// Release our old handshake from pending, it should not continue
f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
@ -354,26 +357,28 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
newHostInfo.Lock()
// Block the current used address
newHostInfo.unlockedBlockRemote(addr)
newHostInfo.remotes = hostinfo.remotes
newHostInfo.remotes.BlockRemote(addr)
// If this is an ongoing issue our previous hostmap will have some bad ips too
for _, v := range hostinfo.badRemotes {
newHostInfo.unlockedBlockRemote(v)
}
//TODO: this is me enabling tests
newHostInfo.ForcePromoteBest(f.hostMap.preferredRanges)
// Get the correct remote list for the host we did handshake with
hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
f.l.WithField("blockedUdpAddrs", newHostInfo.badRemotes).WithField("vpnIp", IntIp(vpnIP)).
WithField("remotes", newHostInfo.Remotes).
f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", IntIp(vpnIP)).
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
Info("Blocked addresses for handshakes")
// Swap the packet store to benefit the original intended recipient
hostinfo.ConnectionState.queueLock.Lock()
newHostInfo.packetStore = hostinfo.packetStore
hostinfo.packetStore = []*cachedPacket{}
hostinfo.ConnectionState.queueLock.Unlock()
// Set the current hostId to the new vpnIp
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.hostId = vpnIP
f.sendCloseTunnel(hostinfo)
newHostInfo.Unlock()
return true
}
// Mark packet 2 as seen so it doesn't show up as missed

View File

@ -12,12 +12,8 @@ import (
)
const (
// Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries
// With 100ms interval and 20 retries is 23.5 seconds
DefaultHandshakeTryInterval = time.Millisecond * 100
DefaultHandshakeRetries = 20
// DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
DefaultHandshakeWaitRotation = 5
DefaultHandshakeTryInterval = time.Millisecond * 100
DefaultHandshakeRetries = 10
DefaultHandshakeTriggerBuffer = 64
)
@ -25,7 +21,6 @@ var (
defaultHandshakeConfig = HandshakeConfig{
tryInterval: DefaultHandshakeTryInterval,
retries: DefaultHandshakeRetries,
waitRotation: DefaultHandshakeWaitRotation,
triggerBuffer: DefaultHandshakeTriggerBuffer,
}
)
@ -33,45 +28,36 @@ var (
type HandshakeConfig struct {
tryInterval time.Duration
retries int
waitRotation int
triggerBuffer int
messageMetrics *MessageMetrics
}
type HandshakeManager struct {
pendingHostMap *HostMap
mainHostMap *HostMap
lightHouse *LightHouse
outside *udpConn
config HandshakeConfig
pendingHostMap *HostMap
mainHostMap *HostMap
lightHouse *LightHouse
outside *udpConn
config HandshakeConfig
OutboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics
l *logrus.Logger
// can be used to trigger outbound handshake for the given vpnIP
trigger chan uint32
OutboundHandshakeTimer *SystemTimerWheel
InboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics
l *logrus.Logger
}
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{
pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
mainHostMap: mainHostMap,
lightHouse: lightHouse,
outside: outside,
config: config,
trigger: make(chan uint32, config.triggerBuffer),
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
messageMetrics: config.messageMetrics,
l: l,
pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
mainHostMap: mainHostMap,
lightHouse: lightHouse,
outside: outside,
config: config,
trigger: make(chan uint32, config.triggerBuffer),
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
messageMetrics: config.messageMetrics,
l: l,
}
}
@ -84,7 +70,6 @@ func (c *HandshakeManager) Run(f EncWriter) {
c.handleOutbound(vpnIP, f, true)
case now := <-clockSource:
c.NextOutboundHandshakeTimerTick(now, f)
c.NextInboundHandshakeTimerTick(now)
}
}
}
@ -109,84 +94,84 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
hostinfo.Lock()
defer hostinfo.Unlock()
// If we haven't finished the handshake and we haven't hit max retries, query
// lighthouse and then send the handshake packet again.
if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete {
if hostinfo.remote == nil {
// We continue to query the lighthouse because hosts may
// come online during handshake retries. If the query
// succeeds (no error), add the lighthouse info to hostinfo
ips := c.lightHouse.QueryCache(vpnIP)
// If we have no responses yet, or only one IP (the host hadn't
// finished reporting its own IPs yet), then send another query to
// the LH.
if len(ips) <= 1 {
ips, err = c.lightHouse.Query(vpnIP, f)
}
if err == nil {
for _, ip := range ips {
hostinfo.AddRemote(ip)
}
hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
}
} else if lighthouseTriggered {
// We were triggered by a lighthouse HostQueryReply packet, but
// we have already picked a remote for this host (this can happen
// if we are configured with multiple lighthouses). So we can skip
// this trigger and let the timerwheel handle the rest of the
// process
return
}
hostinfo.HandshakeCounter++
// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
// all the others until we can stand up a connection.
if hostinfo.HandshakeCounter > c.config.waitRotation {
hostinfo.rotateRemote()
}
// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
if hostinfo.HandshakeReady && hostinfo.remote != nil {
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
if err != nil {
hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake message")
} else {
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
// keep the real packet struct around for logging purposes
hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent")
}
}
// Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
if !lighthouseTriggered {
//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
}
} else {
// We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
if hostinfo.HandshakeComplete {
// Ensure we don't exist in the pending hostmap anymore since we have completed
c.pendingHostMap.DeleteHostInfo(hostinfo)
return
}
}
func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
c.InboundHandshakeTimer.advance(now)
for {
ep := c.InboundHandshakeTimer.Purge()
if ep == nil {
break
// Check if we have a handshake packet to transmit yet
if !hostinfo.HandshakeReady {
// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
return
}
// If we are out of time, clean up
if hostinfo.HandshakeCounter >= c.config.retries {
hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
Info("Handshake timed out")
//TODO: emit metrics
c.pendingHostMap.DeleteHostInfo(hostinfo)
return
}
// We only care about a lighthouse trigger before the first handshake transmit attempt. This is a very specific
// optimization for a fast lighthouse reply
//TODO: it would feel better to do this once, anytime, as our delay increases over time
if lighthouseTriggered && hostinfo.HandshakeCounter > 0 {
// If we didn't return here a lighthouse could cause us to aggressively send handshakes
return
}
// Get a remotes object if we don't already have one.
// This is mainly to protect us as this should never be the case
if hostinfo.remotes == nil {
hostinfo.remotes = c.lightHouse.QueryCache(vpnIP)
}
//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
// Our vpnIP here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
// the learned public ip for them. Query again to short circuit the promotion counter
c.lightHouse.QueryServer(vpnIP, f)
}
// Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udpAddr
hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udpAddr, _ bool) {
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
hostinfo.logger(c.l).WithField("udpAddr", addr).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake message")
} else {
sentTo = append(sentTo, addr)
}
index := ep.(uint32)
})
c.pendingHostMap.DeleteIndex(index)
hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent")
// Increment the counter to increase our delay, linear backoff
hostinfo.HandshakeCounter++
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
if !lighthouseTriggered {
//TODO: feel like we dupe handshake real fast in a tight loop, why?
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
}
}
@ -194,6 +179,7 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
// We lock here and use an array to insert items to prevent locking the
// main receive thread for very long by waiting to add items to the pending map
//TODO: what lock?
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
return hostinfo
@ -203,6 +189,7 @@ var (
ErrExistingHostInfo = errors.New("existing hostinfo")
ErrAlreadySeen = errors.New("already seen")
ErrLocalIndexCollision = errors.New("local index collision")
ErrExistingHandshake = errors.New("existing handshake")
)
// CheckAndComplete checks for any conflicts in the main and pending hostmap
@ -217,17 +204,21 @@ var (
// ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
c.pendingHostMap.RLock()
defer c.pendingHostMap.RUnlock()
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
// Check if we already have a tunnel with this vpn ip
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
if found && existingHostInfo != nil {
// Is it just a delayed handshake packet?
if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
return existingHostInfo, ErrAlreadySeen
}
if !overwrite {
// It's a new handshake and we lost the race
return existingHostInfo, ErrExistingHostInfo
}
}
@ -237,6 +228,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
// We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision
}
existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
if found && existingIndex != hostinfo {
// We have a collision, but for a different hostinfo
@ -252,7 +244,24 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
Info("New host shadows existing host remoteIndex")
}
// Check if we are also handshaking with this vpn ip
pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.hostId]
if found && pendingHostInfo != nil {
if !overwrite {
// We won, let our pending handshake win
return pendingHostInfo, ErrExistingHandshake
}
// We lost, take this handshake and move any cached packets over so they get sent
pendingHostInfo.ConnectionState.queueLock.Lock()
hostinfo.packetStore = append(hostinfo.packetStore, pendingHostInfo.packetStore...)
c.pendingHostMap.unlockedDeleteHostInfo(pendingHostInfo)
pendingHostInfo.ConnectionState.queueLock.Unlock()
pendingHostInfo.logger(c.l).Info("Handshake race lost, replacing pending handshake with completed tunnel")
}
if existingHostInfo != nil {
hostinfo.logger(c.l).Info("Race lost, taking new handshake")
// We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
@ -267,6 +276,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
// won't have a localIndexId collision because we already have an entry in the
// pendingHostMap
func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
@ -288,6 +299,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
}
c.mainHostMap.addHostInfo(hostinfo, f)
c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
}
// AddIndexHostInfo generates a unique localIndexId for this HostInfo
@ -359,3 +371,7 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
}
return index, nil
}
func hsTimeout(tries int, interval time.Duration) time.Duration {
return time.Duration(tries / 2 * ((2 * int(interval)) + (tries-1)*int(interval)))
}

View File

@ -8,66 +8,12 @@ import (
"github.com/stretchr/testify/assert"
)
//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
var ips []uint32
func Test_NewHandshakeManagerIndex(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextInboundHandshakeTimerTick(now)
var indexes = make([]uint32, 4)
var hostinfo = make([]*HostInfo, len(indexes))
for i := range indexes {
hostinfo[i] = &HostInfo{ConnectionState: &ConnectionState{}}
}
// Add four indexes
for i := range indexes {
err := blah.AddIndexHostInfo(hostinfo[i])
assert.NoError(t, err)
indexes[i] = hostinfo[i].localIndexId
blah.InboundHandshakeTimer.Add(indexes[i], time.Second*10)
}
// Confirm they are in the pending index list
for _, v := range indexes {
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
}
// Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Indexes, 0)
// Jump ahead 8 seconds
for i := 1; i <= DefaultHandshakeRetries; i++ {
next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
blah.NextInboundHandshakeTimerTick(next_tick)
}
// Confirm they are still in the pending index list
for _, v := range indexes {
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
}
// Jump ahead 4 more seconds
next_tick := now.Add(12 * time.Second)
blah.NextInboundHandshakeTimerTick(next_tick)
// Confirm they have been removed
for _, v := range indexes {
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(v))
}
}
func Test_NewHandshakeManagerVpnIP(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
ip := ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
@ -77,39 +23,30 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
// Add four "IPs" - which are just uint32s
for _, v := range ips {
blah.AddVpnIP(v)
}
i := blah.AddVpnIP(ip)
i.remotes = NewRemoteList()
i.HandshakeReady = true
// Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Hosts, 0)
// Confirm they are in the pending index list
for _, v := range ips {
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
}
// Jump ahead `HandshakeRetries` ticks
cumulative := time.Duration(0)
for i := 0; i <= DefaultHandshakeRetries+1; i++ {
cumulative += time.Duration(i)*DefaultHandshakeTryInterval + 1
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
// Confirm they are in the pending index list
assert.Contains(t, blah.pendingHostMap.Hosts, ip)
// Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right
for i := 1; i <= DefaultHandshakeRetries+1; i++ {
now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval)
blah.NextOutboundHandshakeTimerTick(now, mw)
}
// Confirm they are still in the pending index list
for _, v := range ips {
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
}
// Jump ahead 1 more second
cumulative += time.Duration(DefaultHandshakeRetries+1) * DefaultHandshakeTryInterval
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
assert.Contains(t, blah.pendingHostMap.Hosts, ip)
// Tick 1 more time, a minute will certainly flush it out
blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw)
// Confirm they have been removed
for _, v := range ips {
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(v))
}
assert.NotContains(t, blah.pendingHostMap.Hosts, ip)
}
func Test_NewHandshakeManagerTrigger(t *testing.T) {
@ -121,7 +58,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
lh := &LightHouse{}
lh := &LightHouse{addrMap: make(map[uint32]*RemoteList), l: l}
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
@ -130,28 +67,25 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
blah.AddVpnIP(ip)
hi := blah.AddVpnIP(ip)
hi.HandshakeReady = true
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
// Trigger the same method the channel will
// Trigger the same method the channel will but, this should set our remotes pointer
blah.handleOutbound(ip, mw, true)
assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have done a handshake attempt")
assert.NotNil(t, hi.remotes, "Manager should have set my remotes pointer")
// Make sure the trigger doesn't schedule another timer entry
// Make sure the trigger doesn't double schedule the timer entry
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
hi := blah.pendingHostMap.Hosts[ip]
assert.Nil(t, hi.remote)
uaddr := NewUDPAddrFromString("10.1.1.1:4242")
lh.addrMap = map[uint32]*ip4And6{}
lh.addrMap[ip] = &ip4And6{
v4: []*Ip4AndPort{NewIp4AndPort(uaddr.IP, uint32(uaddr.Port))},
v6: []*Ip6AndPort{},
}
hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
// This should trigger the hostmap to populate the hostinfo
// We now have remotes but only the first trigger should have pushed things forward
blah.handleOutbound(ip, mw, true)
assert.NotNil(t, hi.remote)
assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have not done a handshake attempt")
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
}
@ -166,100 +100,9 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
return c
}
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
hostinfo := blah.AddVpnIP(vpnIP)
// Pretned we have an index too
err := blah.AddIndexHostInfo(hostinfo)
assert.NoError(t, err)
blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
assert.NotZero(t, hostinfo.localIndexId)
assert.Contains(t, blah.pendingHostMap.Indexes, hostinfo.localIndexId)
// Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending
// but not main hostmap
cumulative := time.Duration(0)
for i := 1; i <= DefaultHandshakeRetries+2; i++ {
cumulative += DefaultHandshakeTryInterval * time.Duration(i)
next_tick := now.Add(cumulative)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
}
/*
for i := 0; i <= HandshakeRetries+1; i++ {
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick)
}
*/
/*
for i := 0; i <= HandshakeRetries+1; i++ {
next_tick := now.Add(time.Duration(i) * time.Second)
blah.NextOutboundHandshakeTimerTick(next_tick)
}
*/
/*
cumulative += HandshakeTryInterval*time.Duration(HandshakeRetries) + 3
next_tick := now.Add(cumulative)
l.Infoln(cumulative, next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick)
*/
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(vpnIP))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
}
func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextInboundHandshakeTimerTick(now)
hostinfo := &HostInfo{ConnectionState: &ConnectionState{}}
err := blah.AddIndexHostInfo(hostinfo)
assert.NoError(t, err)
blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
// Pretned we have an index too
blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo)
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010))
for i := 1; i <= DefaultHandshakeRetries+2; i++ {
next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
blah.NextInboundHandshakeTimerTick(next_tick)
}
next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3)
blah.NextInboundHandshakeTimerTick(next_tick)
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(hostinfo.localIndexId))
}
type mockEncWriter struct {
}
func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
return
}
func (mw *mockEncWriter) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
return
}

View File

@ -1,7 +1,6 @@
package nebula
import (
"encoding/json"
"errors"
"fmt"
"net"
@ -16,6 +15,7 @@ import (
//const ProbeLen = 100
const PromoteEvery = 1000
const ReQueryEvery = 5000
const MaxRemotes = 10
// How long we should prevent roaming back to the previous IP.
@ -30,7 +30,6 @@ type HostMap struct {
Hosts map[uint32]*HostInfo
preferredRanges []*net.IPNet
vpnCIDR *net.IPNet
defaultRoute uint32
unsafeRoutes *CIDRTree
metricsEnabled bool
l *logrus.Logger
@ -40,25 +39,21 @@ type HostInfo struct {
sync.RWMutex
remote *udpAddr
Remotes []*udpAddr
remotes *RemoteList
promoteCounter uint32
ConnectionState *ConnectionState
handshakeStart time.Time
HandshakeReady bool
HandshakeCounter int
HandshakeComplete bool
HandshakePacket map[uint8][]byte
packetStore []*cachedPacket
handshakeStart time.Time //todo: this an entry in the handshake manager
HandshakeReady bool //todo: being in the manager means you are ready
HandshakeCounter int //todo: another handshake manager entry
HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready
HandshakePacket map[uint8][]byte //todo: this is other handshake manager entry
packetStore []*cachedPacket //todo: this is other handshake manager entry
remoteIndexId uint32
localIndexId uint32
hostId uint32
recvError int
remoteCidr *CIDRTree
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
badRemotes []*udpAddr
// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
// with a handshake
@ -88,7 +83,6 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
Hosts: h,
preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR,
defaultRoute: 0,
unsafeRoutes: NewCIDRTree(),
l: l,
}
@ -131,7 +125,6 @@ func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
if _, ok := hm.Hosts[vpnIP]; !ok {
hm.RUnlock()
h = &HostInfo{
Remotes: []*udpAddr{},
promoteCounter: 0,
hostId: vpnIP,
HandshakePacket: make(map[uint8][]byte, 0),
@ -239,7 +232,11 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
hm.Lock()
defer hm.Unlock()
hm.unlockedDeleteHostInfo(hostinfo)
}
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
// Check if this same hostId is in the hostmap with a different instance.
// This could happen if we have an entry in the pending hostmap with different
// index values than the one in the main hostmap.
@ -262,7 +259,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
if len(hm.RemoteIndexes) == 0 {
hm.RemoteIndexes = map[uint32]*HostInfo{}
}
hm.Unlock()
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
@ -294,30 +290,6 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
}
}
func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
hm.Lock()
i, v := hm.Hosts[vpnIp]
if v {
i.AddRemote(remote)
} else {
i = &HostInfo{
Remotes: []*udpAddr{remote.Copy()},
promoteCounter: 0,
hostId: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
}
i.remote = i.Remotes[0]
hm.Hosts[vpnIp] = i
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap remote ip added")
}
}
i.ForcePromoteBest(hm.preferredRanges)
hm.Unlock()
return i
}
func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
return hm.queryVpnIP(vpnIp, nil)
}
@ -331,12 +303,13 @@ func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostIn
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
if promoteIfce != nil {
// Do not attempt promotion if you are a lighthouse
if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
h.TryPromoteBest(hm.preferredRanges, promoteIfce)
}
//fmt.Println(h.remote)
hm.RUnlock()
return h, nil
} else {
//return &net.UDPAddr{}, nil, errors.New("Unable to find host")
hm.RUnlock()
@ -362,11 +335,8 @@ func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
// We already have the hm Lock when this is called, so make sure to not call
// any other methods that might try to grab it again
func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
remoteCert := hostinfo.ConnectionState.peerCert
ip := ip2int(remoteCert.Details.Ips[0].IP)
f.lightHouse.AddRemoteAndReset(ip, hostinfo.remote)
if f.serveDns {
remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
}
@ -381,38 +351,21 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
}
}
func (hm *HostMap) ClearRemotes(vpnIP uint32) {
hm.Lock()
i := hm.Hosts[vpnIP]
if i == nil {
hm.Unlock()
return
}
i.remote = nil
i.Remotes = nil
hm.Unlock()
}
func (hm *HostMap) SetDefaultRoute(ip uint32) {
hm.defaultRoute = ip
}
func (hm *HostMap) PunchList() []*udpAddr {
var list []*udpAddr
// punchList assembles a list of all non nil RemoteList pointer entries in this hostmap
// The caller can then do the its work outside of the read lock
func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
hm.RLock()
defer hm.RUnlock()
for _, v := range hm.Hosts {
for _, r := range v.Remotes {
list = append(list, r)
if v.remotes != nil {
rl = append(rl, v.remotes)
}
// if h, ok := hm.Hosts[vpnIp]; ok {
// hm.Hosts[vpnIp].PromoteBest(hm.preferredRanges, false)
//fmt.Println(h.remote)
// }
}
hm.RUnlock()
return list
return rl
}
// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
func (hm *HostMap) Punchy(conn *udpConn) {
var metricsTxPunchy metrics.Counter
if hm.metricsEnabled {
@ -421,13 +374,18 @@ func (hm *HostMap) Punchy(conn *udpConn) {
metricsTxPunchy = metrics.NilCounter{}
}
var remotes []*RemoteList
b := []byte{1}
for {
for _, addr := range hm.PunchList() {
metricsTxPunchy.Inc(1)
conn.WriteTo(b, addr)
remotes = hm.punchList(remotes[:0])
for _, rl := range remotes {
//TODO: CopyAddrs generates garbage but ForEach locks for the work here, figure out which way is better
for _, addr := range rl.CopyAddrs(hm.preferredRanges) {
metricsTxPunchy.Inc(1)
conn.WriteTo(b, addr)
}
}
time.Sleep(time.Second * 30)
time.Sleep(time.Second * 10)
}
}
@ -438,38 +396,15 @@ func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
}
}
func (i *HostInfo) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"remote": i.remote,
"remotes": i.Remotes,
"promote_counter": i.promoteCounter,
"connection_state": i.ConnectionState,
"handshake_start": i.handshakeStart,
"handshake_ready": i.HandshakeReady,
"handshake_counter": i.HandshakeCounter,
"handshake_complete": i.HandshakeComplete,
"handshake_packet": i.HandshakePacket,
"packet_store": i.packetStore,
"remote_index": i.remoteIndexId,
"local_index": i.localIndexId,
"host_id": int2ip(i.hostId),
"receive_errors": i.recvError,
"last_roam": i.lastRoam,
"last_roam_remote": i.lastRoamRemote,
})
}
func (i *HostInfo) BindConnectionState(cs *ConnectionState) {
i.ConnectionState = cs
}
// TryPromoteBest handles re-querying lighthouses and probing for better paths
// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
if i.remote == nil {
i.ForcePromoteBest(preferredRanges)
return
}
if atomic.AddUint32(&i.promoteCounter, 1)%PromoteEvery == 0 {
c := atomic.AddUint32(&i.promoteCounter, 1)
if c%PromoteEvery == 0 {
// return early if we are already on a preferred remote
rIP := i.remote.IP
for _, l := range preferredRanges {
@ -478,87 +413,21 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
}
}
// We re-query the lighthouse periodically while sending packets, so
// check for new remotes in our local lighthouse cache
ips := ifce.lightHouse.QueryCache(i.hostId)
for _, ip := range ips {
i.AddRemote(ip)
}
i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) {
if addr == nil || !preferred {
return
}
best, preferred := i.getBestRemote(preferredRanges)
if preferred && !best.Equals(i.remote) {
// Try to send a test packet to that host, this should
// cause it to detect a roaming event and switch remotes
ifce.send(test, testRequest, i.ConnectionState, i, best, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}
}
}
func (i *HostInfo) ForcePromoteBest(preferredRanges []*net.IPNet) {
best, _ := i.getBestRemote(preferredRanges)
if best != nil {
i.remote = best
}
}
func (i *HostInfo) getBestRemote(preferredRanges []*net.IPNet) (best *udpAddr, preferred bool) {
if len(i.Remotes) > 0 {
for _, r := range i.Remotes {
for _, l := range preferredRanges {
if l.Contains(r.IP) {
return r, true
}
}
if best == nil || !PrivateIP(r.IP) {
best = r
}
/*
for _, r := range i.Remotes {
// Must have > 80% probe success to be considered.
//fmt.Println("GRADE:", r.addr.IP, r.Grade())
if r.Grade() > float64(.8) {
if localToMe.Contains(r.addr.IP) == true {
best = r.addr
break
//i.remote = i.Remotes[c].addr
} else {
//}
}
*/
}
return best, false
ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
})
}
return nil, false
}
// rotateRemote will move remote to the next ip in the list of remote ips for this host
// This is different than PromoteBest in that what is algorithmically best may not actually work.
// Only known use case is when sending a stage 0 handshake.
// It may be better to just send stage 0 handshakes to all known ips and sort it out in the receiver.
func (i *HostInfo) rotateRemote() {
// We have 0, can't rotate
if len(i.Remotes) < 1 {
return
// Re query our lighthouses for new remotes occasionally
if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
ifce.lightHouse.QueryServer(i.hostId, ifce)
}
if i.remote == nil {
i.remote = i.Remotes[0]
return
}
// We want to look at all but the very last entry since that is handled at the end
for x := 0; x < len(i.Remotes)-1; x++ {
// Find our current position and move to the next one in the list
if i.Remotes[x].Equals(i.remote) {
i.remote = i.Remotes[x+1]
return
}
}
// Our current position was likely the last in the list, start over at 0
i.remote = i.Remotes[0]
}
func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
@ -607,23 +476,13 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger) {
}
}
i.badRemotes = make([]*udpAddr, 0)
i.remotes.ResetBlockedRemotes()
i.packetStore = make([]*cachedPacket, 0)
i.ConnectionState.ready = true
i.ConnectionState.queueLock.Unlock()
i.ConnectionState.certState = nil
}
func (i *HostInfo) CopyRemotes() []*udpAddr {
i.RLock()
rc := make([]*udpAddr, len(i.Remotes), len(i.Remotes))
for x, addr := range i.Remotes {
rc[x] = addr.Copy()
}
i.RUnlock()
return rc
}
func (i *HostInfo) GetCert() *cert.NebulaCertificate {
if i.ConnectionState != nil {
return i.ConnectionState.peerCert
@ -631,58 +490,12 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
return nil
}
func (i *HostInfo) AddRemote(remote *udpAddr) *udpAddr {
if i.unlockedIsBadRemote(remote) {
return i.remote
}
for _, r := range i.Remotes {
if r.Equals(remote) {
return r
}
}
// Trim this down if necessary
if len(i.Remotes) > MaxRemotes {
i.Remotes = i.Remotes[len(i.Remotes)-MaxRemotes:]
}
rc := remote.Copy()
i.Remotes = append(i.Remotes, rc)
return rc
}
func (i *HostInfo) SetRemote(remote *udpAddr) {
i.remote = i.AddRemote(remote)
}
func (i *HostInfo) unlockedBlockRemote(remote *udpAddr) {
if !i.unlockedIsBadRemote(remote) {
// We copy here because we are taking something else's memory and we can't trust everything
i.badRemotes = append(i.badRemotes, remote.Copy())
// We copy here because we likely got this remote from a source that reuses the object
if !i.remote.Equals(remote) {
i.remote = remote.Copy()
i.remotes.LearnRemote(i.hostId, remote.Copy())
}
for k, v := range i.Remotes {
if v.Equals(remote) {
i.Remotes[k] = i.Remotes[len(i.Remotes)-1]
i.Remotes = i.Remotes[:len(i.Remotes)-1]
return
}
}
}
func (i *HostInfo) unlockedIsBadRemote(remote *udpAddr) bool {
for _, v := range i.badRemotes {
if v.Equals(remote) {
return true
}
}
return false
}
func (i *HostInfo) ClearRemotes() {
i.remote = nil
i.Remotes = []*udpAddr{}
}
func (i *HostInfo) ClearConnectionState() {
@ -805,13 +618,3 @@ func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP {
}
return &ips
}
func PrivateIP(ip net.IP) bool {
//TODO: Private for ipv6 or just let it ride?
private := false
_, private24BitBlock, _ := net.ParseCIDR("10.0.0.0/8")
_, private20BitBlock, _ := net.ParseCIDR("172.16.0.0/12")
_, private16BitBlock, _ := net.ParseCIDR("192.168.0.0/16")
private = private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
return private
}

View File

@ -1,169 +1 @@
package nebula
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
/*
func TestHostInfoDestProbe(t *testing.T) {
a, _ := net.ResolveUDPAddr("udp", "1.0.0.1:22222")
d := NewHostInfoDest(a)
// 999 probes that all return should give a 100% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
d.ProbeReceived(meh)
}
assert.Equal(t, d.Grade(), float64(1))
// 999 probes of which only half return should give a 50% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
if i%2 == 0 {
d.ProbeReceived(meh)
}
}
assert.Equal(t, d.Grade(), float64(.5))
// 999 probes of which none return should give a 0% success rate
for i := 0; i < 999; i++ {
d.Probe()
}
assert.Equal(t, d.Grade(), float64(0))
// 999 probes of which only 1/4 return should give a 25% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
if i%4 == 0 {
d.ProbeReceived(meh)
}
}
assert.Equal(t, d.Grade(), float64(.25))
// 999 probes of which only half return and are duplicates should give a 50% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
if i%2 == 0 {
d.ProbeReceived(meh)
d.ProbeReceived(meh)
}
}
assert.Equal(t, d.Grade(), float64(.5))
// 999 probes of which only way old replies return should give a 0% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
d.ProbeReceived(meh - 101)
}
assert.Equal(t, d.Grade(), float64(0))
}
*/
func TestHostmap(t *testing.T) {
l := NewTestLogger()
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
myNets := []*net.IPNet{myNet}
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap(l, "test", myNet, preferredRanges)
a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222")
y := NewUDPAddrFromString("10.128.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
info, _ := m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
// There should be three remotes in the host map
assert.Equal(t, 3, len(info.Remotes))
// Adding an identical remote should not change the count
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
assert.Equal(t, 3, len(info.Remotes))
// Adding a fresh remote should add one
y = NewUDPAddrFromString("10.18.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
assert.Equal(t, 4, len(info.Remotes))
// Query and reference remote should get the first one (and not nil)
info, _ = m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
assert.NotNil(t, info.remote)
// Promotion should ensure that the best remote is chosen (y)
info.ForcePromoteBest(myNets)
assert.True(t, myNet.Contains(info.remote.IP))
}
func TestHostmapdebug(t *testing.T) {
l := NewTestLogger()
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap(l, "test", myNet, preferredRanges)
a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222")
y := NewUDPAddrFromString("10.128.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
//t.Errorf("%s", m.DebugRemotes(1))
}
func TestHostMap_rotateRemote(t *testing.T) {
h := HostInfo{}
// 0 remotes, no panic
h.rotateRemote()
assert.Nil(t, h.remote)
// 1 remote, no panic
h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 1}, 0))
h.rotateRemote()
assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 1})
h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 2}, 0))
h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 3}, 0))
h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 4}, 0))
//TODO: ensure we are copying and not storing the slice!
// Rotate through those 3
h.rotateRemote()
assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 2})
h.rotateRemote()
assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 3})
h.rotateRemote()
assert.Equal(t, h.remote, &udpAddr{IP: net.IP{1, 1, 1, 4}, Port: 0})
// Finally, we should start over
h.rotateRemote()
assert.Equal(t, h.remote, &udpAddr{IP: net.IP{1, 1, 1, 1}, Port: 0})
}
func BenchmarkHostmappromote2(b *testing.B) {
l := NewTestLogger()
for n := 0; n < b.N; n++ {
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap(l, "test", myNet, preferredRanges)
y := NewUDPAddrFromString("10.128.0.3:11111")
a := NewUDPAddrFromString("10.127.0.3:11111")
g := NewUDPAddrFromString("1.0.0.1:22222")
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
}
}

View File

@ -54,10 +54,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
if dropReason == nil {
mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
if f.lightHouse != nil && mc%5000 == 0 {
f.lightHouse.Query(fwPacket.RemoteIP, f)
}
f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
} else if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
@ -84,15 +81,13 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
hostinfo = f.handshakeManager.AddVpnIP(vpnIp)
}
}
ci := hostinfo.ConnectionState
if ci != nil && ci.eKey != nil && ci.ready {
return hostinfo
}
// Handshake is not ready, we need to grab the lock now before we start
// the handshake process
// Handshake is not ready, we need to grab the lock now before we start the handshake process
hostinfo.Lock()
defer hostinfo.Unlock()
@ -150,10 +145,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
return
}
messageCounter := f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
if f.lightHouse != nil && messageCounter%5000 == 0 {
f.lightHouse.Query(fp.RemoteIP, f)
}
f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
}
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
@ -187,50 +179,15 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
}
// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
}
return
}
if hostInfo.ConnectionState.ready == false {
// Because we might be sending stored packets, lock here to stop new things going to
// the packet queue.
hostInfo.ConnectionState.queueLock.Lock()
if !hostInfo.ConnectionState.ready {
hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToAll)
hostInfo.ConnectionState.queueLock.Unlock()
return
}
hostInfo.ConnectionState.queueLock.Unlock()
}
f.sendMessageToAll(t, st, hostInfo, p, nb, out)
return
}
func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, b []byte) {
hostInfo.RLock()
for _, r := range hostInfo.Remotes {
f.send(t, st, hostInfo.ConnectionState, hostInfo, r, p, nb, b)
}
hostInfo.RUnlock()
}
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
}
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) uint64 {
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) {
if ci.eKey == nil {
//TODO: log warning
return 0
return
}
var err error
@ -262,7 +219,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet")
return c
return
}
err = f.writers[q].WriteTo(out, remote)
@ -270,7 +227,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
return c
return
}
func isMulticast(ip uint32) bool {

View File

@ -13,26 +13,11 @@ import (
"github.com/sirupsen/logrus"
)
//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake?
//TODO: nodes are roaming lighthouses, this is bad. How are they learning?
var ErrHostNotKnown = errors.New("host not known")
// The maximum number of ip addresses to store for a given vpnIp per address family
const maxAddrs = 10
type ip4And6 struct {
//TODO: adding a lock here could allow us to release the lock on lh.addrMap quicker
// v4 and v6 store addresses that have been self reported by the client in a server or where all addresses are stored on a client
v4 []*Ip4AndPort
v6 []*Ip6AndPort
// Learned addresses are ones that a client does not know about but a lighthouse learned from as a result of the received packet
// This is only used if you are a lighthouse server
learnedV4 []*Ip4AndPort
learnedV6 []*Ip6AndPort
}
type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
sync.RWMutex //Because we concurrently read and write to our maps
@ -42,7 +27,8 @@ type LightHouse struct {
punchConn *udpConn
// Local cache of answers from light houses
addrMap map[uint32]*ip4And6
// map of vpn Ip to answers
addrMap map[uint32]*RemoteList
// filters remote addresses allowed for each host
// - When we are a lighthouse, this filters what addresses we store and
@ -81,7 +67,7 @@ func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, i
amLighthouse: amLighthouse,
myVpnIp: ip2int(myVpnIpNet.IP),
myVpnZeros: uint32(32 - ones),
addrMap: make(map[uint32]*ip4And6),
addrMap: make(map[uint32]*RemoteList),
nebulaPort: nebulaPort,
lighthouses: make(map[uint32]struct{}),
staticList: make(map[uint32]struct{}),
@ -130,57 +116,79 @@ func (lh *LightHouse) ValidateLHStaticEntries() error {
return nil
}
func (lh *LightHouse) Query(ip uint32, f EncWriter) ([]*udpAddr, error) {
//TODO: we need to hold the lock through the next func
func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip, f)
}
lh.RLock()
if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock()
return TransformLHReplyToUdpAddrs(v), nil
}
lh.RUnlock()
return nil, ErrHostNotKnown
}
// This is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
if !lh.amLighthouse {
// Send a query to the lighthouses and hope for the best next time
query, err := proto.Marshal(NewLhQueryByInt(ip))
if err != nil {
lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
return
}
lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for n := range lh.lighthouses {
f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
}
}
}
func (lh *LightHouse) QueryCache(ip uint32) []*udpAddr {
//TODO: we need to hold the lock through the next func
lh.RLock()
if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock()
return TransformLHReplyToUdpAddrs(v)
return v
}
lh.RUnlock()
return nil
}
//
func (lh *LightHouse) queryAndPrepMessage(ip uint32, f func(*ip4And6) (int, error)) (bool, int, error) {
// This is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
if lh.amLighthouse {
return
}
if lh.IsLighthouseIP(ip) {
return
}
// Send a query to the lighthouses and hope for the best next time
query, err := proto.Marshal(NewLhQueryByInt(ip))
if err != nil {
lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
return
}
lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for n := range lh.lighthouses {
f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
}
}
func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
lh.RLock()
if v, ok := lh.addrMap[ip]; ok {
n, err := f(v)
lh.RUnlock()
return true, n, err
return v
}
lh.RUnlock()
lh.Lock()
defer lh.Unlock()
// Add an entry if we don't already have one
return lh.unlockedGetRemoteList(ip)
}
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
// If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, error)) (bool, int, error) {
lh.RLock()
// Do we have an entry in the main cache?
if v, ok := lh.addrMap[vpnIp]; ok {
// Swap lh lock for remote list lock
v.RLock()
defer v.RUnlock()
lh.RUnlock()
// vpnIp should also be the owner here since we are a lighthouse.
c := v.cache[vpnIp]
// Make sure we have
if c != nil {
n, err := f(c)
return true, n, err
}
return false, 0, nil
}
lh.RUnlock()
return false, 0, nil
@ -203,70 +211,47 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
lh.Unlock()
}
// AddRemote is correct way for non LightHouse members to add an address. toAddr will be placed in the learned map
// static means this is a static host entry from the config file, it should only be used on start up
func (lh *LightHouse) AddRemote(vpnIP uint32, toAddr *udpAddr, static bool) {
// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
lh.Lock()
am := lh.unlockedGetRemoteList(vpnIp)
am.Lock()
defer am.Unlock()
lh.Unlock()
if ipv4 := toAddr.IP.To4(); ipv4 != nil {
lh.addRemoteV4(vpnIP, NewIp4AndPort(ipv4, uint32(toAddr.Port)), static, true)
to := NewIp4AndPort(ipv4, uint32(toAddr.Port))
if !lh.unlockedShouldAddV4(to) {
return
}
am.unlockedPrependV4(lh.myVpnIp, to)
} else {
lh.addRemoteV6(vpnIP, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)), static, true)
to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))
if !lh.unlockedShouldAddV6(to) {
return
}
am.unlockedPrependV6(lh.myVpnIp, to)
}
//TODO: if we do not add due to a config filter we may end up not having any addresses here
if static {
lh.staticList[vpnIP] = struct{}{}
}
// Mark it as static
lh.staticList[vpnIp] = struct{}{}
}
// unlockedGetAddrs assumes you have the lh lock
func (lh *LightHouse) unlockedGetAddrs(vpnIP uint32) *ip4And6 {
// unlockedGetRemoteList assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList {
am, ok := lh.addrMap[vpnIP]
if !ok {
am = &ip4And6{}
am = NewRemoteList()
lh.addrMap[vpnIP] = am
}
return am
}
// addRemoteV4 is a lighthouse internal method that prepends a remote if it is allowed by the allow list and not duplicated
func (lh *LightHouse) addRemoteV4(vpnIP uint32, to *Ip4AndPort, static bool, learned bool) {
// First we check if the sender thinks this is a static entry
// and do nothing if it is not, but should be considered static
if static == false {
if _, ok := lh.staticList[vpnIP]; ok {
return
}
}
lh.Lock()
defer lh.Unlock()
am := lh.unlockedGetAddrs(vpnIP)
if learned {
if !lh.unlockedShouldAddV4(am.learnedV4, to) {
return
}
am.learnedV4 = prependAndLimitV4(am.learnedV4, to)
} else {
if !lh.unlockedShouldAddV4(am.v4, to) {
return
}
am.v4 = prependAndLimitV4(am.v4, to)
}
}
func prependAndLimitV4(cache []*Ip4AndPort, to *Ip4AndPort) []*Ip4AndPort {
cache = append(cache, nil)
copy(cache[1:], cache)
cache[0] = to
if len(cache) > MaxRemotes {
cache = cache[:maxAddrs]
}
return cache
}
// unlockedShouldAddV4 checks if to is allowed by our allow list and is not already present in the cache
func (lh *LightHouse) unlockedShouldAddV4(am []*Ip4AndPort, to *Ip4AndPort) bool {
// unlockedShouldAddV4 checks if to is allowed by our allow list
func (lh *LightHouse) unlockedShouldAddV4(to *Ip4AndPort) bool {
allow := lh.remoteAllowList.AllowIpV4(to.Ip)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow")
@ -276,69 +261,21 @@ func (lh *LightHouse) unlockedShouldAddV4(am []*Ip4AndPort, to *Ip4AndPort) bool
return false
}
for _, v := range am {
if v.Ip == to.Ip && v.Port == to.Port {
return false
}
}
return true
}
// addRemoteV6 is a lighthouse internal method that prepends a remote if it is allowed by the allow list and not duplicated
func (lh *LightHouse) addRemoteV6(vpnIP uint32, to *Ip6AndPort, static bool, learned bool) {
// First we check if the sender thinks this is a static entry
// and do nothing if it is not, but should be considered static
if static == false {
if _, ok := lh.staticList[vpnIP]; ok {
return
}
}
lh.Lock()
defer lh.Unlock()
am := lh.unlockedGetAddrs(vpnIP)
if learned {
if !lh.unlockedShouldAddV6(am.learnedV6, to) {
return
}
am.learnedV6 = prependAndLimitV6(am.learnedV6, to)
} else {
if !lh.unlockedShouldAddV6(am.v6, to) {
return
}
am.v6 = prependAndLimitV6(am.v6, to)
}
}
func prependAndLimitV6(cache []*Ip6AndPort, to *Ip6AndPort) []*Ip6AndPort {
cache = append(cache, nil)
copy(cache[1:], cache)
cache[0] = to
if len(cache) > MaxRemotes {
cache = cache[:maxAddrs]
}
return cache
}
// unlockedShouldAddV6 checks if to is allowed by our allow list and is not already present in the cache
func (lh *LightHouse) unlockedShouldAddV6(am []*Ip6AndPort, to *Ip6AndPort) bool {
// unlockedShouldAddV6 checks if to is allowed by our allow list
func (lh *LightHouse) unlockedShouldAddV6(to *Ip6AndPort) bool {
allow := lh.remoteAllowList.AllowIpV6(to.Hi, to.Lo)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
// We don't check our vpn network here because nebula does not support ipv6 on the inside
if !allow {
return false
}
for _, v := range am {
if v.Hi == to.Hi && v.Lo == to.Lo && v.Port == to.Port {
return false
}
}
return true
}
@ -349,13 +286,6 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
return ip
}
func (lh *LightHouse) AddRemoteAndReset(vpnIP uint32, toIp *udpAddr) {
if lh.amLighthouse {
lh.DeleteVpnIP(vpnIP)
lh.AddRemote(vpnIP, toIp, false)
}
}
func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
if _, ok := lh.lighthouses[vpnIP]; ok {
return true
@ -496,7 +426,6 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
return lhh.meta
}
//TODO: do we need c here?
func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) {
n := lhh.resetMeta()
err := n.Unmarshal(p)
@ -544,13 +473,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
//TODO: we can DRY this further
reqVpnIP := n.Details.VpnIp
//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
//TODO: If we use a lock on cache we can avoid holding it on lh.addrMap and keep things moving better
found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(cache *ip4And6) (int, error) {
found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply
n.Details.VpnIp = reqVpnIP
lhh.coalesceAnswers(cache, n)
lhh.coalesceAnswers(c, n)
return n.MarshalTo(lhh.pb)
})
@ -568,12 +496,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
// This signals the other side to punch some zero byte udp packets
found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(cache *ip4And6) (int, error) {
found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification
n.Details.VpnIp = vpnIp
lhh.coalesceAnswers(cache, n)
lhh.coalesceAnswers(c, n)
return n.MarshalTo(lhh.pb)
})
@ -591,12 +519,24 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0])
}
func (lhh *LightHouseHandler) coalesceAnswers(cache *ip4And6, n *NebulaMeta) {
n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, cache.v4...)
n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, cache.learnedV4...)
func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
if c.v4 != nil {
if c.v4.learned != nil {
n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.learned)
}
if c.v4.reported != nil && len(c.v4.reported) > 0 {
n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.reported...)
}
}
n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, cache.v6...)
n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, cache.learnedV6...)
if c.v6 != nil {
if c.v6.learned != nil {
n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.learned)
}
if c.v6.reported != nil && len(c.v6.reported) > 0 {
n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.reported...)
}
}
}
func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) {
@ -604,14 +544,14 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32)
return
}
// We can't just slam the responses in as they may come from multiple lighthouses and we should coalesce the answers
for _, to := range n.Details.Ip4AndPorts {
lhh.lh.addRemoteV4(n.Details.VpnIp, to, false, false)
}
lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList(n.Details.VpnIp)
am.Lock()
lhh.lh.Unlock()
for _, to := range n.Details.Ip6AndPorts {
lhh.lh.addRemoteV6(n.Details.VpnIp, to, false, false)
}
am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
am.Unlock()
// Non-blocking attempt to trigger, skip if it would block
select {
@ -637,35 +577,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
}
lhh.lh.Lock()
defer lhh.lh.Unlock()
am := lhh.lh.unlockedGetAddrs(vpnIp)
am := lhh.lh.unlockedGetRemoteList(vpnIp)
am.Lock()
lhh.lh.Unlock()
//TODO: other note on a lock for am so we can release more quickly and lock our real unit of change which is far less contended
// We don't accumulate addresses being told to us
am.v4 = am.v4[:0]
am.v6 = am.v6[:0]
for _, v := range n.Details.Ip4AndPorts {
if lhh.lh.unlockedShouldAddV4(am.v4, v) {
am.v4 = append(am.v4, v)
}
}
for _, v := range n.Details.Ip6AndPorts {
if lhh.lh.unlockedShouldAddV6(am.v6, v) {
am.v6 = append(am.v6, v)
}
}
// We prefer the first n addresses if we got too big
if len(am.v4) > MaxRemotes {
am.v4 = am.v4[:MaxRemotes]
}
if len(am.v6) > MaxRemotes {
am.v6 = am.v6[:MaxRemotes]
}
am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
am.Unlock()
}
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) {
@ -716,33 +634,6 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
}
}
func TransformLHReplyToUdpAddrs(ips *ip4And6) []*udpAddr {
addrs := make([]*udpAddr, len(ips.v4)+len(ips.v6)+len(ips.learnedV4)+len(ips.learnedV6))
i := 0
for _, v := range ips.learnedV4 {
addrs[i] = NewUDPAddrFromLH4(v)
i++
}
for _, v := range ips.v4 {
addrs[i] = NewUDPAddrFromLH4(v)
i++
}
for _, v := range ips.learnedV6 {
addrs[i] = NewUDPAddrFromLH6(v)
i++
}
for _, v := range ips.v6 {
addrs[i] = NewUDPAddrFromLH6(v)
i++
}
return addrs
}
// ipMaskContains checks if testIp is contained by ip after applying a cidr
// zeros is 32 - bits from net.IPMask.Size()
func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool {

View File

@ -48,16 +48,16 @@ func Test_lhStaticMapping(t *testing.T) {
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
err := meh.ValidateLHStaticEntries()
assert.Nil(t, err)
lh2 := "10.128.0.3"
lh2IP := net.ParseIP(lh2)
meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
err = meh.ValidateLHStaticEntries()
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
}
@ -73,17 +73,27 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
hAddr := NewUDPAddrFromString("4.5.6.7:12345")
hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
lh.addrMap[3] = &ip4And6{v4: []*Ip4AndPort{
NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port))},
}
lh.addrMap[3] = NewRemoteList()
lh.addrMap[3].unlockedSetV4(
3,
[]*Ip4AndPort{
NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
},
func(*Ip4AndPort) bool { return true },
)
rAddr := NewUDPAddrFromString("1.2.2.3:12345")
rAddr2 := NewUDPAddrFromString("1.2.2.3:12346")
lh.addrMap[2] = &ip4And6{v4: []*Ip4AndPort{
NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port))},
}
lh.addrMap[2] = NewRemoteList()
lh.addrMap[2].unlockedSetV4(
3,
[]*Ip4AndPort{
NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
},
func(*Ip4AndPort) bool { return true },
)
mw := &mockEncWriter{}
@ -173,7 +183,7 @@ func TestLighthouse_Memory(t *testing.T) {
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
// Ensure proper ordering and limiting
// Send 12 addrs, get 10 back, one removed on a dupe check the other by limiting
// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
newLHHostUpdate(
myUdpAddr0,
myVpnIp,
@ -191,11 +201,12 @@ func TestLighthouse_Memory(t *testing.T) {
myUdpAddr10,
myUdpAddr11, // This should get cut
}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(
t,
r.msg.Details.Ip4AndPorts,
myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, myUdpAddr10,
myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
)
// Make sure we won't add ips in our vpn network
@ -247,71 +258,71 @@ func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *Lig
lhh.HandleRequest(fromAddr, vpnIp, b, w)
}
func Test_lhRemoteAllowList(t *testing.T) {
l := NewTestLogger()
c := NewConfig(l)
c.Settings["remoteallowlist"] = map[interface{}]interface{}{
"10.20.0.0/12": false,
}
allowList, err := c.GetAllowList("remoteallowlist", false)
assert.Nil(t, err)
lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
lh.SetRemoteAllowList(allowList)
// A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
remote1IP := net.ParseIP("10.20.0.3")
lh.AddRemote(ip2int(remote1IP), NewUDPAddr(remote1IP, uint16(4242)), true)
assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
assert.Empty(t, lh.addrMap[ip2int(remote1IP)].v4)
assert.Empty(t, lh.addrMap[ip2int(remote1IP)].v6)
// Make sure a good ip enters the cache and addrMap
remote2IP := net.ParseIP("10.128.0.3")
remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
lh.AddRemote(ip2int(remote2IP), remote2UDPAddr, true)
assertIp4InArray(t, lh.addrMap[ip2int(remote2IP)].learnedV4, remote2UDPAddr)
// Another good ip gets into the cache, ordering is inverted
remote3IP := net.ParseIP("10.128.0.4")
remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
lh.AddRemote(ip2int(remote2IP), remote3UDPAddr, true)
assertIp4InArray(t, lh.addrMap[ip2int(remote2IP)].learnedV4, remote3UDPAddr, remote2UDPAddr)
// If we exceed the length limit we should only have the most recent addresses
addedAddrs := []*udpAddr{}
for i := 0; i < 11; i++ {
remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
lh.AddRemote(ip2int(remote2IP), remoteUDPAddr, true)
// The first entry here is a duplicate, don't add it to the assert list
if i != 0 {
addedAddrs = append(addedAddrs, remoteUDPAddr)
}
}
// We should only have the last 10 of what we tried to add
assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
ln := len(addedAddrs)
assertIp4InArray(
t,
lh.addrMap[ip2int(remote2IP)].learnedV4,
addedAddrs[ln-1],
addedAddrs[ln-2],
addedAddrs[ln-3],
addedAddrs[ln-4],
addedAddrs[ln-5],
addedAddrs[ln-6],
addedAddrs[ln-7],
addedAddrs[ln-8],
addedAddrs[ln-9],
addedAddrs[ln-10],
)
}
//TODO: this is a RemoteList test
//func Test_lhRemoteAllowList(t *testing.T) {
// l := NewTestLogger()
// c := NewConfig(l)
// c.Settings["remoteallowlist"] = map[interface{}]interface{}{
// "10.20.0.0/12": false,
// }
// allowList, err := c.GetAllowList("remoteallowlist", false)
// assert.Nil(t, err)
//
// lh1 := "10.128.0.2"
// lh1IP := net.ParseIP(lh1)
//
// udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
//
// lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
// lh.SetRemoteAllowList(allowList)
//
// // A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
// remote1IP := net.ParseIP("10.20.0.3")
// remotes := lh.unlockedGetRemoteList(ip2int(remote1IP))
// remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242))
// assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
// assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{}))
//
// // Make sure a good ip enters the cache and addrMap
// remote2IP := net.ParseIP("10.128.0.3")
// remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false)
// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr)
//
// // Another good ip gets into the cache, ordering is inverted
// remote3IP := net.ParseIP("10.128.0.4")
// remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false)
// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr)
//
// // If we exceed the length limit we should only have the most recent addresses
// addedAddrs := []*udpAddr{}
// for i := 0; i < 11; i++ {
// remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false)
// // The first entry here is a duplicate, don't add it to the assert list
// if i != 0 {
// addedAddrs = append(addedAddrs, remoteUDPAddr)
// }
// }
//
// // We should only have the last 10 of what we tried to add
// assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
// assertUdpAddrInArray(
// t,
// lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}),
// addedAddrs[0],
// addedAddrs[1],
// addedAddrs[2],
// addedAddrs[3],
// addedAddrs[4],
// addedAddrs[5],
// addedAddrs[6],
// addedAddrs[7],
// addedAddrs[8],
// addedAddrs[9],
// )
//}
func Test_ipMaskContains(t *testing.T) {
assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255"))))
@ -354,6 +365,16 @@ func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udpAddr) {
}
}
// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
assert.Len(t, have, len(want))
for k, w := range want {
if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))
}
}
}
func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr {
addrs := make([]*udpAddr, len(ips))
for k, v := range ips {

View File

@ -221,7 +221,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
hostMap.addUnsafeRoutes(&unsafeRoutes)
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
@ -302,14 +302,14 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip, port), true)
lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
}
} else {
ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip, port), true)
lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
}
}
@ -328,7 +328,6 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
handshakeConfig := HandshakeConfig{
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation),
triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
messageMetrics: messageMetrics,

View File

@ -132,6 +132,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
f.connectionManager.In(hostinfo.hostId)
}
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
f.connectionManager.ClearIP(hostInfo.hostId)
@ -140,6 +141,11 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
f.hostMap.DeleteHostInfo(hostInfo)
}
// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
}
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
if hostDidRoam(hostinfo.remote, addr) {
if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
@ -160,9 +166,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
remoteCopy := *hostinfo.remote
hostinfo.lastRoamRemote = &remoteCopy
hostinfo.SetRemote(addr)
if f.lightHouse.amLighthouse {
f.lightHouse.AddRemote(hostinfo.hostId, addr, false)
}
}
}

500
remote_list.go Normal file
View File

@ -0,0 +1,500 @@
package nebula
import (
"bytes"
"net"
"sort"
"sync"
)
// forEachFunc is used to benefit folks that want to do work inside the lock
type forEachFunc func(addr *udpAddr, preferred bool)
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
type checkFuncV4 func(to *Ip4AndPort) bool
type checkFuncV6 func(to *Ip6AndPort) bool
// CacheMap is a struct that better represents the lighthouse cache for humans
// The string key is the owners vpnIp
type CacheMap map[string]*Cache
// Cache is the other part of CacheMap to better represent the lighthouse cache for humans
// We don't reason about ipv4 vs ipv6 here
type Cache struct {
Learned []*udpAddr `json:"learned,omitempty"`
Reported []*udpAddr `json:"reported,omitempty"`
}
//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
// We will never clean learned/reported information for them as it stands today
// cache is an internal struct that splits v4 and v6 addresses inside the cache map
type cache struct {
v4 *cacheV4
v6 *cacheV6
}
// cacheV4 stores learned and reported ipv4 records under cache
type cacheV4 struct {
learned *Ip4AndPort
reported []*Ip4AndPort
}
// cacheV4 stores learned and reported ipv6 records under cache
type cacheV6 struct {
learned *Ip6AndPort
reported []*Ip6AndPort
}
// RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos.
// It serves as a local cache of query replies, host update notifications, and locally learned addresses
type RemoteList struct {
// Every interaction with internals requires a lock!
sync.RWMutex
// A deduplicated set of addresses. Any accessor should lock beforehand.
addrs []*udpAddr
// These are maps to store v4 and v6 addresses per lighthouse
// Map key is the vpnIp of the person that told us about this the cached entries underneath.
// For learned addresses, this is the vpnIp that sent the packet
cache map[uint32]*cache
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
badRemotes []*udpAddr
// A flag that the cache may have changed and addrs needs to be rebuilt
shouldRebuild bool
}
// NewRemoteList creates a new empty RemoteList
func NewRemoteList() *RemoteList {
return &RemoteList{
addrs: make([]*udpAddr, 0),
cache: make(map[uint32]*cache),
}
}
// Len locks and reports the size of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
r.Rebuild(preferredRanges)
r.RLock()
defer r.RUnlock()
return len(r.addrs)
}
// ForEach locks and will call the forEachFunc for every deduplicated address in the list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) {
r.Rebuild(preferredRanges)
r.RLock()
for _, v := range r.addrs {
forEach(v, isPreferred(v.IP, preferredRanges))
}
r.RUnlock()
}
// CopyAddrs locks and makes a deep copy of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
r.Rebuild(preferredRanges)
r.RLock()
defer r.RUnlock()
c := make([]*udpAddr, len(r.addrs))
for i, v := range r.addrs {
c[i] = v.Copy()
}
return c
}
// LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available
//TODO: this needs to support the allow list list
func (r *RemoteList) LearnRemote(ownerVpnIp uint32, addr *udpAddr) {
r.Lock()
defer r.Unlock()
if v4 := addr.IP.To4(); v4 != nil {
r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port)))
} else {
r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port)))
}
}
// CopyCache locks and creates a more human friendly form of the internal address cache.
// This may contain duplicates and blocked addresses
func (r *RemoteList) CopyCache() *CacheMap {
r.RLock()
defer r.RUnlock()
cm := make(CacheMap)
getOrMake := func(vpnIp string) *Cache {
c := cm[vpnIp]
if c == nil {
c = &Cache{
Learned: make([]*udpAddr, 0),
Reported: make([]*udpAddr, 0),
}
cm[vpnIp] = c
}
return c
}
for owner, mc := range r.cache {
c := getOrMake(IntIp(owner).String())
if mc.v4 != nil {
if mc.v4.learned != nil {
c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned))
}
for _, a := range mc.v4.reported {
c.Reported = append(c.Reported, NewUDPAddrFromLH4(a))
}
}
if mc.v6 != nil {
if mc.v6.learned != nil {
c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned))
}
for _, a := range mc.v6.reported {
c.Reported = append(c.Reported, NewUDPAddrFromLH6(a))
}
}
}
return &cm
}
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
func (r *RemoteList) BlockRemote(bad *udpAddr) {
r.Lock()
defer r.Unlock()
// Check if we already blocked this addr
if r.unlockedIsBad(bad) {
return
}
// We copy here because we are taking something else's memory and we can't trust everything
r.badRemotes = append(r.badRemotes, bad.Copy())
// Mark the next interaction must recollect/dedupe
r.shouldRebuild = true
}
// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
func (r *RemoteList) CopyBlockedRemotes() []*udpAddr {
r.RLock()
defer r.RUnlock()
c := make([]*udpAddr, len(r.badRemotes))
for i, v := range r.badRemotes {
c[i] = v.Copy()
}
return c
}
// ResetBlockedRemotes locks and clears the blocked remotes list
func (r *RemoteList) ResetBlockedRemotes() {
r.Lock()
r.badRemotes = nil
r.Unlock()
}
// Rebuild locks and generates the deduplicated address list only if there is work to be done
// There is generally no reason to call this directly but it is safe to do so
func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
r.Lock()
defer r.Unlock()
// Only rebuild if the cache changed
//TODO: shouldRebuild is probably pointless as we don't check for actual change when lighthouse updates come in
if r.shouldRebuild {
r.unlockedCollect()
r.shouldRebuild = false
}
// Always re-sort, preferredRanges can change via HUP
r.unlockedSort(preferredRanges)
}
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
for _, v := range r.badRemotes {
if v.Equals(remote) {
return true
}
}
return false
}
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
}
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, to []*Ip4AndPort, check checkFuncV4) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
// Reset the slice
c.reported = c.reported[:0]
// We can't take their array but we can take their pointers
for _, v := range to[:minInt(len(to), MaxRemotes)] {
if check(v) {
c.reported = append(c.reported, v)
}
}
}
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
// We are doing the easy append because this is rarely called
c.reported = append([]*Ip4AndPort{to}, c.reported...)
if len(c.reported) > MaxRemotes {
c.reported = c.reported[:MaxRemotes]
}
}
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
}
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, to []*Ip6AndPort, check checkFuncV6) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
// Reset the slice
c.reported = c.reported[:0]
// We can't take their array but we can take their pointers
for _, v := range to[:minInt(len(to), MaxRemotes)] {
if check(v) {
c.reported = append(c.reported, v)
}
}
}
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
// We are doing the easy append because this is rarely called
c.reported = append([]*Ip6AndPort{to}, c.reported...)
if len(c.reported) > MaxRemotes {
c.reported = c.reported[:MaxRemotes]
}
}
// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
// The caller must dirty the learned address cache if required
func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
r.cache[ownerVpnIp] = am
}
// Avoid occupying memory for v6 addresses if we never have any
if am.v4 == nil {
am.v4 = &cacheV4{}
}
return am.v4
}
// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
// The caller must dirty the learned address cache if required
func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp uint32) *cacheV6 {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
r.cache[ownerVpnIp] = am
}
// Avoid occupying memory for v4 addresses if we never have any
if am.v6 == nil {
am.v6 = &cacheV6{}
}
return am.v6
}
// unlockedCollect assumes you have the write lock and collects/transforms the cache into the deduped address list.
// The result of this function can contain duplicates. unlockedSort handles cleaning it.
func (r *RemoteList) unlockedCollect() {
addrs := r.addrs[:0]
for _, c := range r.cache {
if c.v4 != nil {
if c.v4.learned != nil {
u := NewUDPAddrFromLH4(c.v4.learned)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
for _, v := range c.v4.reported {
u := NewUDPAddrFromLH4(v)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
}
if c.v6 != nil {
if c.v6.learned != nil {
u := NewUDPAddrFromLH6(c.v6.learned)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
for _, v := range c.v6.reported {
u := NewUDPAddrFromLH6(v)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
}
}
r.addrs = addrs
}
// unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
n := len(r.addrs)
if n < 2 {
return
}
lessFunc := func(i, j int) bool {
a := r.addrs[i]
b := r.addrs[j]
// Preferred addresses first
aPref := isPreferred(a.IP, preferredRanges)
bPref := isPreferred(b.IP, preferredRanges)
switch {
case aPref && !bPref:
// If i is preferred and j is not, i is less than j
return true
case !aPref && bPref:
// If j is preferred then i is not due to the else, i is not less than j
return false
default:
// Both i an j are either preferred or not, sort within that
}
// ipv6 addresses 2nd
a4 := a.IP.To4()
b4 := b.IP.To4()
switch {
case a4 == nil && b4 != nil:
// If i is v6 and j is v4, i is less than j
return true
case a4 != nil && b4 == nil:
// If j is v6 and i is v4, i is not less than j
return false
case a4 != nil && b4 != nil:
// Special case for ipv4, a4 and b4 are not nil
aPrivate := isPrivateIP(a4)
bPrivate := isPrivateIP(b4)
switch {
case !aPrivate && bPrivate:
// If i is a public ip (not private) and j is a private ip, i is less then j
return true
case aPrivate && !bPrivate:
// If j is public (not private) then i is private due to the else, i is not less than j
return false
default:
// Both i an j are either public or private, sort within that
}
default:
// Both i an j are either ipv4 or ipv6, sort within that
}
// lexical order of ips 3rd
c := bytes.Compare(a.IP, b.IP)
if c == 0 {
// Ips are the same, Lexical order of ports 4th
return a.Port < b.Port
}
// Ip wasn't the same
return c < 0
}
// Sort it
sort.Slice(r.addrs, lessFunc)
// Deduplicate
a, b := 0, 1
for b < n {
if !r.addrs[a].Equals(r.addrs[b]) {
a++
if a != b {
r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
}
}
b++
}
r.addrs = r.addrs[:a+1]
return
}
// minInt returns the minimum integer of a or b
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
// isPreferred returns true of the ip is contained in the preferredRanges list
func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
//TODO: this would be better in a CIDR6Tree
for _, p := range preferredRanges {
if p.Contains(ip) {
return true
}
}
return false
}
var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
// isPrivateIP returns true if the ip is contained by a rfc 1918 private range
func isPrivateIP(ip net.IP) bool {
//TODO: another great cidrtree option
//TODO: Private for ipv6 or just let it ride?
return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
}

228
remote_list_test.go Normal file
View File

@ -0,0 +1,228 @@
package nebula
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList()
rl.unlockedSetV4(
0,
[]*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is duped
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is duped
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe
},
func(*Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
1,
[]*Ip6AndPort{
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped
NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
},
func(*Ip6AndPort) bool { return true },
)
rl.Rebuild([]*net.IPNet{})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// ipv6 first, sorted lexically within
assert.Equal(t, "[1::1]:1", rl.addrs[0].String())
assert.Equal(t, "[1::1]:2", rl.addrs[1].String())
assert.Equal(t, "[1:100::1]:1", rl.addrs[2].String())
// ipv4 last, sorted by public first, then private, lexically within them
assert.Equal(t, "70.199.182.92:1475", rl.addrs[3].String())
assert.Equal(t, "70.199.182.92:1476", rl.addrs[4].String())
assert.Equal(t, "172.17.0.182:10101", rl.addrs[5].String())
assert.Equal(t, "172.17.1.1:10101", rl.addrs[6].String())
assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String())
assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String())
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
// Now ensure we can hoist ipv4 up
_, ipNet, err := net.ParseCIDR("0.0.0.0/0")
assert.NoError(t, err)
rl.Rebuild([]*net.IPNet{ipNet})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// ipv4 first, public then private, lexically within them
assert.Equal(t, "70.199.182.92:1475", rl.addrs[0].String())
assert.Equal(t, "70.199.182.92:1476", rl.addrs[1].String())
assert.Equal(t, "172.17.0.182:10101", rl.addrs[2].String())
assert.Equal(t, "172.17.1.1:10101", rl.addrs[3].String())
assert.Equal(t, "172.18.0.1:10101", rl.addrs[4].String())
assert.Equal(t, "172.19.0.1:10101", rl.addrs[5].String())
assert.Equal(t, "172.31.0.1:10101", rl.addrs[6].String())
// ipv6 last, sorted by public first, then private, lexically within them
assert.Equal(t, "[1::1]:1", rl.addrs[7].String())
assert.Equal(t, "[1::1]:2", rl.addrs[8].String())
assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
// Ensure we can hoist a specific ipv4 range over anything else
_, ipNet, err = net.ParseCIDR("172.17.0.0/16")
assert.NoError(t, err)
rl.Rebuild([]*net.IPNet{ipNet})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// Preferred ipv4 first
assert.Equal(t, "172.17.0.182:10101", rl.addrs[0].String())
assert.Equal(t, "172.17.1.1:10101", rl.addrs[1].String())
// ipv6 next
assert.Equal(t, "[1::1]:1", rl.addrs[2].String())
assert.Equal(t, "[1::1]:2", rl.addrs[3].String())
assert.Equal(t, "[1:100::1]:1", rl.addrs[4].String())
// the remaining ipv4 last
assert.Equal(t, "70.199.182.92:1475", rl.addrs[5].String())
assert.Equal(t, "70.199.182.92:1476", rl.addrs[6].String())
assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String())
assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String())
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
}
func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList()
rl.unlockedSetV4(
0,
[]*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
},
func(*Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
0,
[]*Ip6AndPort{
NewIp6AndPort(net.ParseIP("1::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
},
func(*Ip6AndPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{})
}
})
_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
assert.NoError(b, err)
b.Run("1 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{ipNet})
}
})
_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
assert.NoError(b, err)
b.Run("2 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
}
})
_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
assert.NoError(b, err)
b.Run("3 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
}
})
}
func BenchmarkSortRebuild(b *testing.B) {
rl := NewRemoteList()
rl.unlockedSetV4(
0,
[]*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
},
func(*Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
0,
[]*Ip6AndPort{
NewIp6AndPort(net.ParseIP("1::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
},
func(*Ip6AndPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{})
}
})
_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
rl.Rebuild([]*net.IPNet{ipNet})
assert.NoError(b, err)
b.Run("1 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.Rebuild([]*net.IPNet{ipNet})
}
})
_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
assert.NoError(b, err)
b.Run("2 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
}
})
_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
assert.NoError(b, err)
b.Run("3 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
}
})
}

87
ssh.go
View File

@ -10,8 +10,8 @@ import (
"os"
"reflect"
"runtime/pprof"
"sort"
"strings"
"sync/atomic"
"syscall"
"github.com/sirupsen/logrus"
@ -335,8 +335,10 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
return nil
}
hostMap.RLock()
defer hostMap.RUnlock()
hm := listHostMap(hostMap)
sort.Slice(hm, func(i, j int) bool {
return bytes.Compare(hm[i].VpnIP, hm[j].VpnIP) < 0
})
if fs.Json || fs.Pretty {
js := json.NewEncoder(w.GetWriter())
@ -344,35 +346,15 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
js.SetIndent("", " ")
}
d := make([]m, len(hostMap.Hosts))
x := 0
var h m
for _, v := range hostMap.Hosts {
h = m{
"vpnIp": int2ip(v.hostId),
"localIndex": v.localIndexId,
"remoteIndex": v.remoteIndexId,
"remoteAddrs": v.CopyRemotes(),
"cachedPackets": len(v.packetStore),
"cert": v.GetCert(),
}
if v.ConnectionState != nil {
h["messageCounter"] = atomic.LoadUint64(&v.ConnectionState.atomicMessageCounter)
}
d[x] = h
x++
}
err := js.Encode(d)
err := js.Encode(hm)
if err != nil {
//TODO
return nil
}
} else {
for i, v := range hostMap.Hosts {
err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(i), v.CopyRemotes()))
for _, v := range hm {
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, v.RemoteAddrs))
if err != nil {
return err
}
@ -389,8 +371,26 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
return nil
}
type lighthouseInfo struct {
VpnIP net.IP `json:"vpnIp"`
Addrs *CacheMap `json:"addrs"`
}
lightHouse.RLock()
defer lightHouse.RUnlock()
addrMap := make([]lighthouseInfo, len(lightHouse.addrMap))
x := 0
for k, v := range lightHouse.addrMap {
addrMap[x] = lighthouseInfo{
VpnIP: int2ip(k),
Addrs: v.CopyCache(),
}
x++
}
lightHouse.RUnlock()
sort.Slice(addrMap, func(i, j int) bool {
return bytes.Compare(addrMap[i].VpnIP, addrMap[j].VpnIP) < 0
})
if fs.Json || fs.Pretty {
js := json.NewEncoder(w.GetWriter())
@ -398,27 +398,19 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
js.SetIndent("", " ")
}
d := make([]m, len(lightHouse.addrMap))
x := 0
var h m
for vpnIp, v := range lightHouse.addrMap {
h = m{
"vpnIp": int2ip(vpnIp),
"addrs": TransformLHReplyToUdpAddrs(v),
}
d[x] = h
x++
}
err := js.Encode(d)
err := js.Encode(addrMap)
if err != nil {
//TODO
return nil
}
} else {
for vpnIp, v := range lightHouse.addrMap {
err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(vpnIp), TransformLHReplyToUdpAddrs(v)))
for _, v := range addrMap {
b, err := json.Marshal(v.Addrs)
if err != nil {
return err
}
err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, string(b)))
if err != nil {
return err
}
@ -469,8 +461,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
ips, _ := ifce.lightHouse.Query(vpnIp, ifce)
return json.NewEncoder(w.GetWriter()).Encode(ips)
return json.NewEncoder(w.GetWriter()).Encode(ifce.lightHouse.Query(vpnIp, ifce).CopyCache())
}
func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
@ -727,7 +718,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
hostInfo, err := ifce.hostMap.QueryVpnIP(vpnIp)
if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -737,7 +728,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
enc.SetIndent("", " ")
}
return enc.Encode(hostInfo)
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges))
}
func sshReload(fs interface{}, a []string, w sshd.StringWriter) error {

View File

@ -41,9 +41,7 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []route, _ []r
// These are unencrypted ip layer frames destined for another nebula node.
// packets should exit the udp side, capture them with udpConn.Get
func (c *Tun) Send(packet []byte) {
if c.l.Level >= logrus.DebugLevel {
c.l.Debug("Tun injecting packet")
}
c.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet")
c.rxPackets <- packet
}

View File

@ -13,8 +13,8 @@ type udpAddr struct {
}
func NewUDPAddr(ip net.IP, port uint16) *udpAddr {
addr := udpAddr{IP: make([]byte, len(ip)), Port: port}
copy(addr.IP, ip)
addr := udpAddr{IP: make([]byte, net.IPv6len), Port: port}
copy(addr.IP, ip.To16())
return &addr
}
@ -22,7 +22,7 @@ func NewUDPAddrFromString(s string) *udpAddr {
ip, port, err := parseIPAndPort(s)
//TODO: handle err
_ = err
return &udpAddr{IP: ip, Port: port}
return &udpAddr{IP: ip.To16(), Port: port}
}
func (ua *udpAddr) Equals(t *udpAddr) bool {

View File

@ -97,40 +97,21 @@ func (u *udpConn) GetSendBuffer() (int, error) {
}
func (u *udpConn) LocalAddr() (*udpAddr, error) {
var rsa unix.RawSockaddrAny
var rLen = unix.SizeofSockaddrAny
_, _, err := unix.Syscall(
unix.SYS_GETSOCKNAME,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&rsa)),
uintptr(unsafe.Pointer(&rLen)),
)
if err != 0 {
sa, err := unix.Getsockname(u.sysFd)
if err != nil {
return nil, err
}
addr := &udpAddr{}
if rsa.Addr.Family == unix.AF_INET {
pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(&rsa))
addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
copy(addr.IP, pp.Addr[:])
} else if rsa.Addr.Family == unix.AF_INET6 {
//TODO: this cast sucks and we can do better
pp := (*unix.RawSockaddrInet6)(unsafe.Pointer(&rsa))
addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
copy(addr.IP, pp.Addr[:])
} else {
addr.Port = 0
addr.IP = []byte{}
switch sa := sa.(type) {
case *unix.SockaddrInet4:
addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
addr.Port = uint16(sa.Port)
case *unix.SockaddrInet6:
addr.IP = sa.Addr[0:]
addr.Port = uint16(sa.Port)
}
//TODO: Just use this instead?
//a, b := unix.Getsockname(u.sysFd)
return addr, nil
}

View File

@ -3,6 +3,7 @@
package nebula
import (
"fmt"
"net"
"github.com/sirupsen/logrus"
@ -53,7 +54,14 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error
// this is an encrypted packet or a handshake message in most cases
// packets were transmitted from another nebula node, you can send them with Tun.Send
func (u *udpConn) Send(packet *UdpPacket) {
u.l.Infof("UDP injecting packet %+v", packet)
h := &Header{}
if err := h.Parse(packet.Data); err != nil {
panic(err)
}
u.l.WithField("header", h).
WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
WithField("dataLen", len(packet.Data)).
Info("UDP receiving injected packet")
u.rxPackets <- packet
}