From db23fdf9bc5a7d76e92e14e2300bb790569eb7bc Mon Sep 17 00:00:00 2001 From: Nathan Brown Date: Tue, 27 Apr 2021 21:15:34 -0500 Subject: [PATCH] Dont apply race avoidance to existing handshakes, use the handshake time to determine who wins (#451) Co-authored-by: Wade Simmons --- handshake_ix.go | 26 +++++++++++++------------- handshake_manager.go | 9 +++++---- hostmap.go | 5 +++++ 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/handshake_ix.go b/handshake_ix.go index de7a84c..6dec998 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -119,11 +119,12 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { } hostinfo := &HostInfo{ - ConnectionState: ci, - localIndexId: myIndex, - remoteIndexId: hs.Details.InitiatorIndex, - hostId: vpnIP, - HandshakePacket: make(map[uint8][]byte, 0), + ConnectionState: ci, + localIndexId: myIndex, + remoteIndexId: hs.Details.InitiatorIndex, + hostId: vpnIP, + HandshakePacket: make(map[uint8][]byte, 0), + lastHandshakeTime: hs.Details.Time, } hostinfo.Lock() @@ -138,6 +139,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { hs.Details.ResponderIndex = myIndex hs.Details.Cert = ci.certState.rawCertificateNoKey + // Update the time in case their clock is way off from ours + hs.Details.Time = uint64(time.Now().Unix()) hsBytes, err := proto.Marshal(hs) if err != nil { @@ -204,18 +207,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { } return case ErrExistingHostInfo: - // This means there was an existing tunnel and we didn't win - // handshake avoidance - - //TODO: sprinkle the new protobuf stuff in here, send a reply to get the recv_errors flowing - //TODO: if not new send a test packet like old - + // This means there was an existing tunnel and this handshake was older than the one we are currently based on f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). + WithField("oldHandshakeTime", existing.lastHandshakeTime). + WithField("newHandshakeTime", hostinfo.lastHandshakeTime). WithField("fingerprint", fingerprint). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Prevented a handshake race") + Info("Handshake too old") // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) @@ -394,7 +394,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ Info("Handshake message received") hostinfo.remoteIndexId = hs.Details.ResponderIndex - hs.Details.Cert = ci.certState.rawCertificateNoKey + hostinfo.lastHandshakeTime = hs.Details.Time // Store their cert and our symmetric keys ci.peerCert = remoteCert diff --git a/handshake_manager.go b/handshake_manager.go index fec80a0..90afc62 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -199,7 +199,7 @@ var ( // exact same handshake packet // // ErrExistingHostInfo if we already have an entry in the hostmap for this -// VpnIP and overwrite was false. +// VpnIP and the new handshake was older than the one we currently have // // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. @@ -217,10 +217,12 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket return existingHostInfo, ErrAlreadySeen } - if !overwrite { - // It's a new handshake and we lost the race + // Is this a newer handshake? + if existingHostInfo.lastHandshakeTime >= hostinfo.lastHandshakeTime { return existingHostInfo, ErrExistingHostInfo } + + existingHostInfo.logger(c.l).Info("Taking new handshake") } existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId] @@ -261,7 +263,6 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket } if existingHostInfo != nil { - hostinfo.logger(c.l).Info("Race lost, taking new handshake") // We are going to overwrite this entry, so remove the old references delete(c.mainHostMap.Hosts, existingHostInfo.hostId) delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) diff --git a/hostmap.go b/hostmap.go index 3f04793..4e988bf 100644 --- a/hostmap.go +++ b/hostmap.go @@ -59,6 +59,11 @@ type HostInfo struct { // with a handshake lastRebindCount int8 + // lastHandshakeTime records the time the remote side told us about at the stage when the handshake was completed locally + // Stage 1 packet will contain it if I am a responder, stage 2 packet if I am an initiator + // This is used to avoid an attack where a handshake packet is replayed after some time + lastHandshakeTime uint64 + lastRoam time.Time lastRoamRemote *udpAddr }