diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0962f92..b6b57a7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -40,6 +40,9 @@ jobs: - name: Test run: make test + - name: End 2 end + run: make e2e + test: name: Build and test on ${{ matrix.os }} runs-on: ${{ matrix.os }} @@ -72,3 +75,6 @@ jobs: - name: Test run: go test -v ./... + + - name: End 2 end + run: go test -tags=e2e_testing -count=1 ./e2e diff --git a/Makefile b/Makefile index 3af3429..fcef8e5 100644 --- a/Makefile +++ b/Makefile @@ -33,7 +33,19 @@ ALL = $(ALL_LINUX) \ windows-amd64 e2e: - go test -v -tags=e2e_testing ./e2e + $(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e + +e2ev: TEST_FLAGS = -v +e2ev: e2e + +e2evv: TEST_ENV += TEST_LOGS=1 +e2evv: e2ev + +e2evvv: TEST_ENV += TEST_LOGS=2 +e2evvv: e2ev + +e2evvvv: TEST_ENV += TEST_LOGS=3 +e2evvvv: e2ev all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert) @@ -138,5 +150,5 @@ smoke-docker-race: BUILD_ARGS = -race smoke-docker-race: smoke-docker .FORCE: -.PHONY: e2e test test-cov-html bench bench-cpu bench-cpu-long bin proto release service smoke-docker smoke-docker-race +.PHONY: e2e e2ev e2evv e2evvv e2evvvv test test-cov-html bench bench-cpu bench-cpu-long bin proto release service smoke-docker smoke-docker-race .DEFAULT_GOAL := bin diff --git a/control.go b/control.go index a5df2d5..089e8ac 100644 --- a/control.go +++ b/control.go @@ -164,12 +164,11 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { } func copyHostInfo(h *HostInfo) ControlHostInfo { - addrs := h.RemoteUDPAddrs() chi := ControlHostInfo{ VpnIP: int2ip(h.hostId), LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, - RemoteAddrs: make([]*udpAddr, len(addrs), len(addrs)), + RemoteAddrs: h.CopyRemotes(), CachedPackets: len(h.packetStore), MessageCounter: atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter), } @@ -182,9 +181,5 @@ func copyHostInfo(h *HostInfo) ControlHostInfo { chi.CurrentRemote = h.remote.Copy() } - for i, addr := range addrs { - chi.RemoteAddrs[i] = addr.Copy() - } - return chi } diff --git a/control_test.go b/control_test.go index a411fc1..9dc461f 100644 --- a/control_test.go +++ b/control_test.go @@ -45,7 +45,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) { Signature: []byte{1, 2, 1, 2, 1, 3}, } - remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)} + remotes := []*udpAddr{remote1, remote2} hm.Add(ip2int(ipNet.IP), &HostInfo{ remote: remote1, Remotes: remotes, diff --git a/control_tester.go b/control_tester.go index 01e4d1f..574682a 100644 --- a/control_tester.go +++ b/control_tester.go @@ -57,6 +57,14 @@ func (c *Control) GetFromUDP(block bool) *UdpPacket { return c.f.outside.Get(block) } +func (c *Control) GetUDPTxChan() <-chan *UdpPacket { + return c.f.outside.txPackets +} + +func (c *Control) GetTunTxChan() <-chan []byte { + return c.f.inside.(*Tun).txPackets +} + // InjectUDPPacket will inject a packet into the udp side of nebula func (c *Control) InjectUDPPacket(p *UdpPacket) { c.f.outside.Send(p) @@ -90,3 +98,7 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16 c.f.inside.(*Tun).Send(buffer.Bytes()) } + +func (c *Control) GetUDPAddr() string { + return c.f.outside.addr.String() +} diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 87c5d93..c64e62b 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -6,29 +6,24 @@ import ( "net" "testing" "time" + + "github.com/slackhq/nebula/e2e/router" ) func TestGoodHandshake(t *testing.T) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - defMask := net.IPMask{0, 0, 0, 0} - - myUdpAddr := &net.UDPAddr{IP: net.IP{10, 0, 0, 1}, Port: 4242} - myVpnIpNet := &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: defMask} - myControl := newSimpleServer(ca, caKey, "me", myUdpAddr, myVpnIpNet) - - theirUdpAddr := &net.UDPAddr{IP: net.IP{10, 0, 0, 2}, Port: 4242} - theirVpnIpNet := &net.IPNet{IP: net.IP{10, 128, 0, 2}, Mask: defMask} - theirControl := newSimpleServer(ca, caKey, "them", theirUdpAddr, theirVpnIpNet) + myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}) + theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() // Send a udp packet through to begin standing up the tunnel, this should come out the other side - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) // Have them consume my stage 0 packet. They have a tunnel now theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -40,21 +35,172 @@ func TestGoodHandshake(t *testing.T) { myControl.WaitForType(1, 0, theirControl) // Make sure our host infos are correct - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) // Get that cached packet and make sure it looks right myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) - // Send a packet from them to me - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) - myControl.InjectUDPPacket(theirControl.GetFromUDP(true)) - theirPacket := myControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from them"), theirPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + // Do a bidirectional tunnel test + assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, router.NewR(myControl, theirControl)) - // And once more from me to them - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hello again from me")) - theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) - myPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hello again from me"), myPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + myControl.Stop() + theirControl.Stop() + //TODO: assert hostmaps } + +func TestWrongResponderHandshake(t *testing.T) { + ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + + myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}) + theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}) + evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 99}) + + // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse + myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr) + + // But also add their real udp addr, which should be tried after evil + myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(myControl, theirControl, evilControl) + + // Start the servers + myControl.Start() + theirControl.Start() + evilControl.Start() + + t.Log("Stand up the tunnel with evil (because the lighthouse cache is lying to us about who it is)") + myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) + r.OnceFrom(myControl) + r.OnceFrom(evilControl) + + t.Log("I should have a tunnel with evil now and there should not be a cached packet waiting for us") + assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r) + assertHostInfoPair(t, myUdpAddr, evilUdpAddr, myVpnIp, evilVpnIp, myControl, evilControl) + + //TODO: Assert pending hostmap - I should have a correct hostinfo for them now + + t.Log("Lets let the messages fly, this time we should have a tunnel with them") + r.OnceFrom(myControl) + r.OnceFrom(theirControl) + + t.Log("I should now have a tunnel with them now and my original packet should get there") + r.RouteUntilAfterMsgType(myControl, 1, 0) + myCachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) + + t.Log("I should now have a proper tunnel with them") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) + assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) + + t.Log("Lets make sure evil is still good") + assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r) + + //TODO: assert hostmaps for everyone + t.Log("Success!") + //TODO: myControl is attempting to shut down 2 tunnels but is blocked on the udp txChan after the first close message + // what we really need here is a way to exit all the go routines loops (there are many) + //myControl.Stop() + //theirControl.Stop() +} + +////TODO: We need to test lies both as the race winner and race loser +//func TestManyWrongResponderHandshake(t *testing.T) { +// ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) +// +// myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 99}) +// theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}) +// evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 1}) +// +// t.Log("Build a router so we don't have to reason who gets which packet") +// r := newRouter(myControl, theirControl, evilControl) +// +// t.Log("Lets add more than 10 evil addresses, this exceeds the hostinfo remotes limit") +// for i := 0; i < 10; i++ { +// addr := net.UDPAddr{IP: evilUdpAddr.IP, Port: evilUdpAddr.Port + i} +// myControl.InjectLightHouseAddr(theirVpnIp, &addr) +// // We also need to tell our router about it +// r.AddRoute(addr.IP, uint16(addr.Port), evilControl) +// } +// +// // Start the servers +// myControl.Start() +// theirControl.Start() +// evilControl.Start() +// +// t.Log("Stand up the tunnel with evil (because the lighthouse cache is lying to us about who it is)") +// myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) +// +// t.Log("We need to spin until we get to the right remote for them") +// getOut := false +// injected := false +// for { +// t.Log("Routing for me and evil while we work through the bad ips") +// r.RouteExitFunc(myControl, func(packet *nebula.UdpPacket, receiver *nebula.Control) exitType { +// // We should stop routing right after we see a packet coming from us to them +// if *receiver == *theirControl { +// getOut = true +// return drainAndExit +// } +// +// // We need to poke our real ip in at some point, this is a well protected check looking for that moment +// if *receiver == *evilControl { +// hi := myControl.GetHostInfoByVpnIP(ip2int(theirVpnIp), true) +// if !injected && len(hi.RemoteAddrs) == 1 { +// t.Log("I am on my last ip for them, time to inject the real one into my lighthouse") +// myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) +// injected = true +// } +// return drainAndExit +// } +// +// return keepRouting +// }) +// +// if getOut { +// break +// } +// +// r.RouteForUntilAfterToAddr(evilControl, myUdpAddr, drainAndExit) +// } +// +// t.Log("I should have a tunnel with evil and them, evil should not have a cached packet") +// assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r) +// evilHostInfo := myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false) +// realEvilUdpAddr := &net.UDPAddr{IP: evilHostInfo.CurrentRemote.IP, Port: int(evilHostInfo.CurrentRemote.Port)} +// +// t.Log("Assert mine and evil's host pairs", evilUdpAddr, realEvilUdpAddr) +// assertHostInfoPair(t, myUdpAddr, realEvilUdpAddr, myVpnIp, evilVpnIp, myControl, evilControl) +// +// //t.Log("Draining everyones packets") +// //r.Drain(theirControl) +// //r.DrainAll(myControl, theirControl, evilControl) +// // +// //go func() { +// // for { +// // time.Sleep(10 * time.Millisecond) +// // t.Log(len(theirControl.GetUDPTxChan())) +// // t.Log(len(theirControl.GetTunTxChan())) +// // t.Log(len(myControl.GetUDPTxChan())) +// // t.Log(len(evilControl.GetUDPTxChan())) +// // t.Log("=====") +// // } +// //}() +// +// t.Log("I should have a tunnel with them now and my original packet should get there") +// r.RouteUntilAfterMsgType(myControl, 1, 0) +// myCachedPacket := theirControl.GetFromTun(true) +// +// t.Log("Got the cached packet, lets test the tunnel") +// assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) +// +// t.Log("Testing tunnels with them") +// assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) +// assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) +// +// t.Log("Testing tunnels with evil") +// assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r) +// +// //TODO: assert hostmaps for everyone +//} diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index b85da66..bd553d7 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -7,7 +7,9 @@ import ( "encoding/binary" "fmt" "io" + "io/ioutil" "net" + "os" "testing" "time" @@ -16,6 +18,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/e2e/router" "github.com/stretchr/testify/assert" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/ed25519" @@ -25,9 +28,17 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, listenAddr *net.UDPAddr, vpnIp *net.IPNet) *nebula.Control { - l := logrus.New() - _, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIp, nil, []string{}) +func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP) (*nebula.Control, net.IP, *net.UDPAddr) { + l := NewTestLogger() + + vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{0, 0, 0, 0}} + copy(vpnIpNet.IP, udpIp) + vpnIpNet.IP[1] += 128 + udpAddr := net.UDPAddr{ + IP: udpIp, + Port: 4242, + } + _, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) caB, err := caCrt.MarshalToPEM() if err != nil { @@ -54,12 +65,12 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, l }}, }, "listen": m{ - "host": listenAddr.IP.String(), - "port": listenAddr.Port, + "host": udpAddr.IP.String(), + "port": udpAddr.Port, }, "logging": m{ "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), - "level": "info", + "level": l.Level.String(), }, } cb, err := yaml.Marshal(mc) @@ -76,7 +87,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, l panic(err) } - return control + return control, vpnIpNet.IP, &udpAddr } // newTestCaCert will generate a CA cert @@ -193,6 +204,36 @@ func int2ip(nn uint32) net.IP { return ip } +type doneCb func() + +func deadline(t *testing.T, seconds time.Duration) doneCb { + timeout := time.After(seconds * time.Second) + done := make(chan bool) + go func() { + select { + case <-timeout: + t.Fatal("Test did not finish in time") + case <-done: + } + }() + + return func() { + done <- true + } +} + +func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) { + // Send a packet from them to me + controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) + bPacket := r.RouteUntilTxTun(controlB, controlA) + assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) + + // And once more from me to them + controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A")) + aPacket := r.RouteUntilTxTun(controlA, controlB) + assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) +} + func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) { // Get both host infos hBinA := controlA.GetHostInfoByVpnIP(ip2int(vpnIpB), false) @@ -202,14 +243,14 @@ func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB assert.NotNil(t, hAinB, "Host A was not found by vpnIP in controlB") // Check that both vpn and real addr are correct - assert.Equal(t, vpnIpB, hBinA.VpnIP, "HostA VpnIp is wrong in controlB") - assert.Equal(t, vpnIpA, hAinB.VpnIP, "HostB VpnIp is wrong in controlA") + assert.Equal(t, vpnIpB, hBinA.VpnIP, "Host B VpnIp is wrong in control A") + assert.Equal(t, vpnIpA, hAinB.VpnIP, "Host A VpnIp is wrong in control B") - assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "HostA remote ip is wrong in controlB") - assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "HostB remote ip is wrong in controlA") + assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A") + assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B") - assert.Equal(t, uint16(addrA.Port), hBinA.CurrentRemote.Port, "HostA remote ip is wrong in controlB") - assert.Equal(t, uint16(addrB.Port), hAinB.CurrentRemote.Port, "HostB remote ip is wrong in controlA") + assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A") + assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B") // Check that our indexes match assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index") @@ -250,3 +291,24 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, from assert.NotNil(t, data) assert.Equal(t, expected, data.Payload(), "Data was incorrect") } + +func NewTestLogger() *logrus.Logger { + l := logrus.New() + + v := os.Getenv("TEST_LOGS") + if v == "" { + l.SetOutput(ioutil.Discard) + return l + } + + switch v { + case "2": + l.SetLevel(logrus.DebugLevel) + case "3": + l.SetLevel(logrus.TraceLevel) + default: + l.SetLevel(logrus.InfoLevel) + } + + return l +} diff --git a/e2e/router/router.go b/e2e/router/router.go new file mode 100644 index 0000000..0cf486c --- /dev/null +++ b/e2e/router/router.go @@ -0,0 +1,221 @@ +// +build e2e_testing + +package router + +import ( + "fmt" + "net" + "strconv" + "sync" + + "github.com/slackhq/nebula" +) + +type R struct { + // Simple map of the ip:port registered on a control to the control + // Basically a router, right? + controls map[string]*nebula.Control + + // A map for inbound packets for a control that doesn't know about this address + inNat map[string]*nebula.Control + + // A last used map, if an inbound packet hit the inNat map then + // all return packets should use the same last used inbound address for the outbound sender + // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver + outNat map[string]net.UDPAddr + + // All interactions are locked to help serialize behavior + sync.Mutex +} + +type exitType int + +const ( + // Keeps routing, the function will get called again on the next packet + keepRouting exitType = 0 + // Does not route this packet and exits immediately + exitNow exitType = 1 + // Routes this packet and exits immediately afterwards + routeAndExit exitType = 2 +) + +type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) exitType + +func NewR(controls ...*nebula.Control) *R { + r := &R{ + controls: make(map[string]*nebula.Control), + inNat: make(map[string]*nebula.Control), + outNat: make(map[string]net.UDPAddr), + } + + for _, c := range controls { + addr := c.GetUDPAddr() + if _, ok := r.controls[addr]; ok { + panic("Duplicate listen address: " + addr) + } + r.controls[addr] = c + } + + return r +} + +// AddRoute will place the nebula controller at the ip and port specified. +// It does not look at the addr attached to the instance. +// If a route is used, this will behave like a NAT for the return path. +// Rewriting the source ip:port to what was last sent to from the origin +func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) { + r.Lock() + defer r.Unlock() + + inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)) + if _, ok := r.inNat[inAddr]; ok { + panic("Duplicate listen address inNat: " + inAddr) + } + r.inNat[inAddr] = c +} + +// OnceFrom will route a single packet from sender then return +// If the router doesn't have the nebula controller for that address, we panic +func (r *R) OnceFrom(sender *nebula.Control) { + r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) exitType { + return routeAndExit + }) +} + +// RouteUntilTxTun will route for sender and return when a packet is seen on receivers tun +// If the router doesn't have the nebula controller for that address, we panic +func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []byte { + tunTx := receiver.GetTunTxChan() + udpTx := sender.GetUDPTxChan() + + for { + select { + // Maybe we already have something on the tun for us + case b := <-tunTx: + return b + + // Nope, lets push the sender along + case p := <-udpTx: + outAddr := sender.GetUDPAddr() + r.Lock() + inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) + c := r.getControl(outAddr, inAddr, p) + if c == nil { + r.Unlock() + panic("No control for udp tx") + } + + c.InjectUDPPacket(p) + r.Unlock() + } + } +} + +// RouteExitFunc will call the whatDo func with each udp packet from sender. +// whatDo can return: +// - exitNow: the packet will not be routed and this call will return immediately +// - routeAndExit: this call will return immediately after routing the last packet from sender +// - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender +//TODO: is this RouteWhile? +func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { + h := &nebula.Header{} + for { + p := sender.GetFromUDP(true) + r.Lock() + if err := h.Parse(p.Data); err != nil { + panic(err) + } + + outAddr := sender.GetUDPAddr() + inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) + receiver := r.getControl(outAddr, inAddr, p) + if receiver == nil { + r.Unlock() + panic("Can't route for host: " + inAddr) + } + + e := whatDo(p, receiver) + switch e { + case exitNow: + r.Unlock() + return + + case routeAndExit: + receiver.InjectUDPPacket(p) + r.Unlock() + return + + case keepRouting: + receiver.InjectUDPPacket(p) + + default: + panic(fmt.Sprintf("Unknown exitFunc return: %v", e)) + } + + r.Unlock() + } +} + +// RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender +// If the router doesn't have the nebula controller for that address, we panic +func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) { + h := &nebula.Header{} + r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) exitType { + if err := h.Parse(p.Data); err != nil { + panic(err) + } + if h.Type == msgType && h.Subtype == subType { + return routeAndExit + } + + return keepRouting + }) +} + +// RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr +// finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit` +// If the router doesn't have the nebula controller for that address, we panic +func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish exitType) { + if finish == keepRouting { + finish = routeAndExit + } + + r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) exitType { + if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) { + return finish + } + + return keepRouting + }) +} + +// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change +// This is an internal router function, the caller must hold the lock +func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control { + if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok { + p.FromIp = newAddr.IP + p.FromPort = uint16(newAddr.Port) + } + + c, ok := r.inNat[toAddr] + if ok { + sHost, sPort, err := net.SplitHostPort(toAddr) + if err != nil { + panic(err) + } + + port, err := strconv.Atoi(sPort) + if err != nil { + panic(err) + } + + r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{ + IP: net.ParseIP(sHost), + Port: port, + } + return c + } + + //TODO: call receive hooks! + return r.controls[toAddr] +} diff --git a/handshake_ix.go b/handshake_ix.go index 1749c16..1587d13 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -1,7 +1,6 @@ package nebula import ( - "bytes" "sync/atomic" "time" @@ -126,7 +125,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { hostinfo := &HostInfo{ ConnectionState: ci, - Remotes: []*HostInfoDest{}, + Remotes: []*udpAddr{}, localIndexId: myIndex, remoteIndexId: hs.Details.InitiatorIndex, hostId: vpnIP, @@ -274,25 +273,24 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { if hostinfo == nil { + // Nothing here to tear down, got a bogus stage 2 packet return true } + hostinfo.Lock() defer hostinfo.Unlock() - if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) { + ci := hostinfo.ConnectionState + if ci.ready { f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). - Info("Already seen this handshake packet") + Info("Handshake is already complete") + + // We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets return false } - ci := hostinfo.ConnectionState - // Mark packet 2 as seen so it doesn't show up as missed - ci.window.Update(f.l, 2) - - hostinfo.HandshakePacket[2] = make([]byte, len(packet[HeaderLen:])) - copy(hostinfo.HandshakePacket[2], packet[HeaderLen:]) - + //TODO: we need this to merge in https://github.com/flynn/noise/pull/39 msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) if err != nil { f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). @@ -307,6 +305,9 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Noise did not arrive at a key") + + // This should be impossible in IX but just in case, if we get here then there is no chance to recover + // the handshake state machine. Tear it down return true } @@ -315,6 +316,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ if err != nil || hs.Details == nil { f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + + // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again return true } @@ -323,12 +326,58 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Invalid certificate from host") + + // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again return true } + vpnIP := ip2int(remoteCert.Details.Ips[0].IP) certName := remoteCert.Details.Name fingerprint, _ := remoteCert.Sha256Sum() + if vpnIP != hostinfo.hostId { + f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)). + WithField("udpAddr", addr).WithField("certName", certName). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Info("Incorrect host responded to handshake") + + if ho, _ := f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIP); ho != nil { + // We might have a pending tunnel to this host already, clear out that attempt since we have a tunnel now + f.handshakeManager.pendingHostMap.DeleteHostInfo(ho) + } + + // Create a new hostinfo/handshake for the intended vpn ip, but first release the lock + f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo) + //TODO: this adds it to the timer wheel in a way that aggressively retries + newHostInfo := f.getOrHandshake(hostinfo.hostId) + newHostInfo.Lock() + + // Block the current used address + newHostInfo.unlockedBlockRemote(addr) + + // If this is an ongoing issue our previous hostmap will have some bad ips too + for _, v := range hostinfo.badRemotes { + newHostInfo.unlockedBlockRemote(v) + } + //TODO: this is me enabling tests + newHostInfo.ForcePromoteBest(f.hostMap.preferredRanges) + + f.l.WithField("blockedUdpAddrs", newHostInfo.badRemotes).WithField("vpnIp", IntIp(vpnIP)). + WithField("remotes", newHostInfo.Remotes). + Info("Blocked addresses for handshakes") + + // Swap the packet store to benefit the original intended recipient + newHostInfo.packetStore = hostinfo.packetStore + hostinfo.packetStore = []*cachedPacket{} + + // Set the current hostId to the new vpnIp + hostinfo.hostId = vpnIP + newHostInfo.Unlock() + } + + // Mark packet 2 as seen so it doesn't show up as missed + ci.window.Update(f.l, 2) + duration := time.Since(hostinfo.handshakeStart).Nanoseconds() f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). @@ -338,29 +387,22 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ WithField("durationNs", duration). Info("Handshake message received") - //ci.remoteIndex = hs.ResponderIndex hostinfo.remoteIndexId = hs.Details.ResponderIndex hs.Details.Cert = ci.certState.rawCertificateNoKey - /* - hsBytes, err := proto.Marshal(hs) - if err != nil { - l.Debugln("Failed to marshal handshake: ", err) - return - } - */ - - // Regardless of whether you are the sender or receiver, you should arrive here - // and complete standing up the connection. - + // Store their cert and our symmetric keys ci.peerCert = remoteCert ci.dKey = NewNebulaCipherState(dKey) ci.eKey = NewNebulaCipherState(eKey) - //l.Debugln("got symmetric pairs") + // Make sure the current udpAddr being used is set for responding hostinfo.SetRemote(addr) + + // Build up the radix for the firewall if we have subnets in the cert hostinfo.CreateRemoteCIDR(remoteCert) + // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp + //TODO: Complete here does not do a race avoidance, it will just take the new tunnel. Is this ok? f.handshakeManager.Complete(hostinfo, f) hostinfo.handshakeComplete(f.l) f.metricHandshakes.Update(duration) diff --git a/hostmap.go b/hostmap.go index 6dc0fec..754b6c8 100644 --- a/hostmap.go +++ b/hostmap.go @@ -40,7 +40,7 @@ type HostInfo struct { sync.RWMutex remote *udpAddr - Remotes []*HostInfoDest + Remotes []*udpAddr promoteCounter uint32 ConnectionState *ConnectionState handshakeStart time.Time @@ -55,6 +55,10 @@ type HostInfo struct { recvError int remoteCidr *CIDRTree + // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. + // They should not be tried again during a handshake + badRemotes []*udpAddr + // lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH // for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like // with a handshake @@ -138,7 +142,7 @@ func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo { if _, ok := hm.Hosts[vpnIP]; !ok { hm.RUnlock() h = &HostInfo{ - Remotes: []*HostInfoDest{}, + Remotes: []*udpAddr{}, promoteCounter: 0, hostId: vpnIP, HandshakePacket: make(map[uint8][]byte, 0), @@ -308,12 +312,12 @@ func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo { i.AddRemote(remote) } else { i = &HostInfo{ - Remotes: []*HostInfoDest{NewHostInfoDest(remote)}, + Remotes: []*udpAddr{remote.Copy()}, promoteCounter: 0, hostId: vpnIp, HandshakePacket: make(map[uint8][]byte, 0), } - i.remote = i.Remotes[0].addr + i.remote = i.Remotes[0] hm.Hosts[vpnIp] = i if hm.l.Level >= logrus.DebugLevel { hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}). @@ -409,7 +413,7 @@ func (hm *HostMap) PunchList() []*udpAddr { hm.RLock() for _, v := range hm.Hosts { for _, r := range v.Remotes { - list = append(list, r.addr) + list = append(list, r) } // if h, ok := hm.Hosts[vpnIp]; ok { // hm.Hosts[vpnIp].PromoteBest(hm.preferredRanges, false) @@ -511,16 +515,14 @@ func (i *HostInfo) ForcePromoteBest(preferredRanges []*net.IPNet) { func (i *HostInfo) getBestRemote(preferredRanges []*net.IPNet) (best *udpAddr, preferred bool) { if len(i.Remotes) > 0 { for _, r := range i.Remotes { - rIP := r.addr.IP - for _, l := range preferredRanges { - if l.Contains(rIP) { - return r.addr, true + if l.Contains(r.IP) { + return r, true } } - if best == nil || !PrivateIP(rIP) { - best = r.addr + if best == nil || !PrivateIP(r.IP) { + best = r } /* for _, r := range i.Remotes { @@ -553,21 +555,21 @@ func (i *HostInfo) rotateRemote() { } if i.remote == nil { - i.remote = i.Remotes[0].addr + i.remote = i.Remotes[0] return } // We want to look at all but the very last entry since that is handled at the end for x := 0; x < len(i.Remotes)-1; x++ { // Find our current position and move to the next one in the list - if i.Remotes[x].addr.Equals(i.remote) { - i.remote = i.Remotes[x+1].addr + if i.Remotes[x].Equals(i.remote) { + i.remote = i.Remotes[x+1] return } } // Our current position was likely the last in the list, start over at 0 - i.remote = i.Remotes[0].addr + i.remote = i.Remotes[0] } func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) { @@ -616,18 +618,21 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger) { } } + i.badRemotes = make([]*udpAddr, 0) i.packetStore = make([]*cachedPacket, 0) i.ConnectionState.ready = true i.ConnectionState.queueLock.Unlock() i.ConnectionState.certState = nil } -func (i *HostInfo) RemoteUDPAddrs() []*udpAddr { - var addrs []*udpAddr - for _, r := range i.Remotes { - addrs = append(addrs, r.addr) +func (i *HostInfo) CopyRemotes() []*udpAddr { + i.RLock() + rc := make([]*udpAddr, len(i.Remotes), len(i.Remotes)) + for x, addr := range i.Remotes { + rc[x] = addr.Copy() } - return addrs + i.RUnlock() + return rc } func (i *HostInfo) GetCert() *cert.NebulaCertificate { @@ -638,30 +643,57 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate { } func (i *HostInfo) AddRemote(remote *udpAddr) *udpAddr { - //add := true + if i.unlockedIsBadRemote(remote) { + return i.remote + } + for _, r := range i.Remotes { - if r.addr.Equals(remote) { - return r.addr - //add = false + if r.Equals(remote) { + return r } } + // Trim this down if necessary if len(i.Remotes) > MaxRemotes { i.Remotes = i.Remotes[len(i.Remotes)-MaxRemotes:] } - r := NewHostInfoDest(remote) - i.Remotes = append(i.Remotes, r) - return r.addr - //l.Debugf("Added remote %s for vpn ip", remote) + + rc := remote.Copy() + i.Remotes = append(i.Remotes, rc) + return rc } func (i *HostInfo) SetRemote(remote *udpAddr) { i.remote = i.AddRemote(remote) } +func (i *HostInfo) unlockedBlockRemote(remote *udpAddr) { + if !i.unlockedIsBadRemote(remote) { + // We copy here because we are taking something else's memory and we can't trust everything + i.badRemotes = append(i.badRemotes, remote.Copy()) + } + + for k, v := range i.Remotes { + if v.Equals(remote) { + i.Remotes[k] = i.Remotes[len(i.Remotes)-1] + i.Remotes = i.Remotes[:len(i.Remotes)-1] + return + } + } +} + +func (i *HostInfo) unlockedIsBadRemote(remote *udpAddr) bool { + for _, v := range i.badRemotes { + if v.Equals(remote) { + return true + } + } + return false +} + func (i *HostInfo) ClearRemotes() { i.remote = nil - i.Remotes = []*HostInfoDest{} + i.Remotes = []*udpAddr{} } func (i *HostInfo) ClearConnectionState() { @@ -711,20 +743,6 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { //######################## -func NewHostInfoDest(addr *udpAddr) *HostInfoDest { - i := &HostInfoDest{ - addr: addr.Copy(), - } - return i -} - -func (hid *HostInfoDest) MarshalJSON() ([]byte, error) { - return json.Marshal(m{ - "address": hid.addr, - "probe_count": hid.probeCounter, - }) -} - /* func (hm *HostMap) DebugRemotes(vpnIp uint32) string { @@ -814,7 +832,10 @@ func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP { ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) - l.WithField("interfaceName", i.Name).WithField("allow", allow).Debug("localAllowList.AllowName") + if l.Level >= logrus.TraceLevel { + l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName") + } + if !allow { continue } @@ -833,7 +854,9 @@ func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP { //TODO: Would be nice to filter out SLAAC MAC based ips as well if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() { allow := allowList.Allow(ip) - l.WithField("localIp", ip).WithField("allow", allow).Debug("localAllowList.Allow") + if l.Level >= logrus.TraceLevel { + l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow") + } if !allow { continue } diff --git a/inside.go b/inside.go index 8cb1e51..239b3c3 100644 --- a/inside.go +++ b/inside.go @@ -215,9 +215,11 @@ func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubTyp } func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, b []byte) { - for _, r := range hostInfo.RemoteUDPAddrs() { + hostInfo.RLock() + for _, r := range hostInfo.Remotes { f.send(t, st, hostInfo.ConnectionState, hostInfo, r, p, nb, b) } + hostInfo.RUnlock() } func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) { diff --git a/ssh.go b/ssh.go index a9b0729..57c74d8 100644 --- a/ssh.go +++ b/ssh.go @@ -352,7 +352,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error "vpnIp": int2ip(v.hostId), "localIndex": v.localIndexId, "remoteIndex": v.remoteIndexId, - "remoteAddrs": v.RemoteUDPAddrs(), + "remoteAddrs": v.CopyRemotes(), "cachedPackets": len(v.packetStore), "cert": v.GetCert(), } @@ -372,7 +372,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error } } else { for i, v := range hostMap.Hosts { - err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(i), v.RemoteUDPAddrs())) + err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(i), v.CopyRemotes())) if err != nil { return err } diff --git a/tun_tester.go b/tun_tester.go index a7bbd4e..7c10cd5 100644 --- a/tun_tester.go +++ b/tun_tester.go @@ -28,8 +28,8 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int MTU: defaultMTU, UnsafeRoutes: unsafeRoutes, l: l, - rxPackets: make(chan []byte, 100), - txPackets: make(chan []byte, 100), + rxPackets: make(chan []byte, 1), + txPackets: make(chan []byte, 1), }, nil } @@ -41,6 +41,9 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []route, _ []r // These are unencrypted ip layer frames destined for another nebula node. // packets should exit the udp side, capture them with udpConn.Get func (c *Tun) Send(packet []byte) { + if c.l.Level >= logrus.DebugLevel { + c.l.Debug("Tun injecting packet") + } c.rxPackets <- packet } diff --git a/tun_windows.go b/tun_windows.go index 594675e..234d653 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -1,3 +1,5 @@ +// +build !e2e_testing + package nebula import ( diff --git a/udp_tester.go b/udp_tester.go index b527837..8c71346 100644 --- a/udp_tester.go +++ b/udp_tester.go @@ -38,6 +38,7 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error // this is an encrypted packet or a handshake message in most cases // packets were transmitted from another nebula node, you can send them with Tun.Send func (u *udpConn) Send(packet *UdpPacket) { + u.l.Infof("UDP injecting packet %+v", packet) u.rxPackets <- packet } @@ -71,8 +72,8 @@ func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error { } copy(p.Data, b) - copy(p.ToIp, addr.IP) - copy(p.FromIp, u.addr.IP) + copy(p.ToIp, addr.IP.To16()) + copy(p.FromIp, u.addr.IP.To16()) u.txPackets <- p return nil diff --git a/udp_windows.go b/udp_windows.go index dcfe884..2d3918c 100644 --- a/udp_windows.go +++ b/udp_windows.go @@ -1,3 +1,5 @@ +// +build !e2e_testing + package nebula // Windows support is primarily implemented in udp_generic, besides NewListenConfig