diff --git a/CHANGELOG.md b/CHANGELOG.md index be410f0..5e86b77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,10 +53,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 will immediately switch to a preferred remote address after the reception of a handshake packet (instead of waiting until 1,000 packets have been sent). (#532) - + - A race condition when `punchy.respond` is enabled and ensures the correct vpn ip is sent a punch back response in highly queried node. (#566) +- Fix a rare crash during handshake due to a race condition. (#535) + ## [1.4.0] - 2021-05-11 ### Added diff --git a/connection_manager_test.go b/connection_manager_test.go index 9da6ddc..9f2fe6e 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -57,7 +57,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { out := make([]byte, mtu) nc.HandleMonitorTick(now, p, nb, out) // Add an ip we have established a connection w/ to hostmap - hostinfo := nc.hostMap.AddVpnIp(vpnIp) + hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil) hostinfo.ConnectionState = &ConnectionState{ certState: cs, H: &noise.HandshakeState{}, @@ -126,7 +126,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { out := make([]byte, mtu) nc.HandleMonitorTick(now, p, nb, out) // Add an ip we have established a connection w/ to hostmap - hostinfo := nc.hostMap.AddVpnIp(vpnIp) + hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil) hostinfo.ConnectionState = &ConnectionState{ certState: cs, H: &noise.HandshakeState{}, @@ -232,7 +232,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { defer cancel() nc := newConnectionManager(ctx, l, ifce, 5, 10) ifce.connectionManager = nc - hostinfo := nc.hostMap.AddVpnIp(vpnIp) + hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil) hostinfo.ConnectionState = &ConnectionState{ certState: cs, peerCert: &peerCert, diff --git a/handshake_manager.go b/handshake_manager.go index 7f50c5b..42db182 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -191,13 +191,13 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l } } -func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo { - hostinfo := c.pendingHostMap.AddVpnIp(vpnIp) - // We lock here and use an array to insert items to prevent locking the - // main receive thread for very long by waiting to add items to the pending map - //TODO: what lock? - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) - c.metricInitiated.Inc(1) +func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo { + hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init) + + if created { + c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) + c.metricInitiated.Inc(1) + } return hostinfo } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index b669050..dfc8d2c 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -27,7 +27,19 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { now := time.Now() blah.NextOutboundHandshakeTimerTick(now, mw) - i := blah.AddVpnIp(ip) + var initCalled bool + initFunc := func(*HostInfo) { + initCalled = true + } + + i := blah.AddVpnIp(ip, initFunc) + assert.True(t, initCalled) + + initCalled = false + i2 := blah.AddVpnIp(ip, initFunc) + assert.False(t, initCalled) + assert.Same(t, i, i2) + i.remotes = NewRemoteList() i.HandshakeReady = true @@ -71,7 +83,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) { assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) - hi := blah.AddVpnIp(ip) + hi := blah.AddVpnIp(ip, nil) hi.HandshakeReady = true assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet") diff --git a/hostmap.go b/hostmap.go index 6545307..d558100 100644 --- a/hostmap.go +++ b/hostmap.go @@ -134,24 +134,25 @@ func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) { hm.Unlock() } -func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo { - h := &HostInfo{} +func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (hostinfo *HostInfo, created bool) { hm.RLock() - if _, ok := hm.Hosts[vpnIp]; !ok { + if h, ok := hm.Hosts[vpnIp]; !ok { hm.RUnlock() h = &HostInfo{ promoteCounter: 0, vpnIp: vpnIp, HandshakePacket: make(map[uint8][]byte, 0), } + if init != nil { + init(h) + } hm.Lock() hm.Hosts[vpnIp] = h hm.Unlock() - return h + return h, true } else { - h = hm.Hosts[vpnIp] hm.RUnlock() - return h + return h, false } } diff --git a/inside.go b/inside.go index 8a7c990..7ab083a 100644 --- a/inside.go +++ b/inside.go @@ -83,7 +83,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { if err != nil { hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp) if err != nil { - hostinfo = f.handshakeManager.AddVpnIp(vpnIp) + hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo) } } ci := hostinfo.ConnectionState @@ -102,16 +102,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { return hostinfo } - if ci == nil { - // if we don't have a connection state, then send a handshake initiation - ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0) - // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. - //ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0) - hostinfo.ConnectionState = ci - } else if ci.eKey == nil { - // if we don't have any state at all, create it - } - // If we have already created the handshake packet, we don't want to call the function at all. if !hostinfo.HandshakeReady { ixHandshakeStage0(f, vpnIp, hostinfo) @@ -131,6 +121,12 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { return hostinfo } +// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that +// will create the initial Noise ConnectionState +func (f *Interface) initHostInfo(hostinfo *HostInfo) { + hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0) +} + func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) { fp := &firewall.Packet{} err := newPacket(p, false, fp) diff --git a/ssh.go b/ssh.go index e640dde..2f4374b 100644 --- a/ssh.go +++ b/ssh.go @@ -569,7 +569,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW } } - hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp) + hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo) if addr != nil { hostInfo.SetRemote(addr) }