Dont apply race avoidance to existing handshakes, use the handshake time to determine who wins (#451)

Co-authored-by: Wade Simmons <wadey@slack-corp.com>
This commit is contained in:
Nathan Brown 2021-04-27 21:15:34 -05:00 committed by GitHub
parent df7c7eec4a
commit db23fdf9bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 17 deletions

View File

@ -119,11 +119,12 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
} }
hostinfo := &HostInfo{ hostinfo := &HostInfo{
ConnectionState: ci, ConnectionState: ci,
localIndexId: myIndex, localIndexId: myIndex,
remoteIndexId: hs.Details.InitiatorIndex, remoteIndexId: hs.Details.InitiatorIndex,
hostId: vpnIP, hostId: vpnIP,
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time,
} }
hostinfo.Lock() hostinfo.Lock()
@ -138,6 +139,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
hs.Details.ResponderIndex = myIndex hs.Details.ResponderIndex = myIndex
hs.Details.Cert = ci.certState.rawCertificateNoKey hs.Details.Cert = ci.certState.rawCertificateNoKey
// Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().Unix())
hsBytes, err := proto.Marshal(hs) hsBytes, err := proto.Marshal(hs)
if err != nil { if err != nil {
@ -204,18 +207,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
} }
return return
case ErrExistingHostInfo: case ErrExistingHostInfo:
// This means there was an existing tunnel and we didn't win // This means there was an existing tunnel and this handshake was older than the one we are currently based on
// handshake avoidance
//TODO: sprinkle the new protobuf stuff in here, send a reply to get the recv_errors flowing
//TODO: if not new send a test packet like old
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("oldHandshakeTime", existing.lastHandshakeTime).
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Prevented a handshake race") Info("Handshake too old")
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
@ -394,7 +394,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
Info("Handshake message received") Info("Handshake message received")
hostinfo.remoteIndexId = hs.Details.ResponderIndex hostinfo.remoteIndexId = hs.Details.ResponderIndex
hs.Details.Cert = ci.certState.rawCertificateNoKey hostinfo.lastHandshakeTime = hs.Details.Time
// Store their cert and our symmetric keys // Store their cert and our symmetric keys
ci.peerCert = remoteCert ci.peerCert = remoteCert

View File

@ -199,7 +199,7 @@ var (
// exact same handshake packet // exact same handshake packet
// //
// ErrExistingHostInfo if we already have an entry in the hostmap for this // ErrExistingHostInfo if we already have an entry in the hostmap for this
// VpnIP and overwrite was false. // VpnIP and the new handshake was older than the one we currently have
// //
// ErrLocalIndexCollision if we already have an entry in the main or pending // ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId. // hostmap for the hostinfo.localIndexId.
@ -217,10 +217,12 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
return existingHostInfo, ErrAlreadySeen return existingHostInfo, ErrAlreadySeen
} }
if !overwrite { // Is this a newer handshake?
// It's a new handshake and we lost the race if existingHostInfo.lastHandshakeTime >= hostinfo.lastHandshakeTime {
return existingHostInfo, ErrExistingHostInfo return existingHostInfo, ErrExistingHostInfo
} }
existingHostInfo.logger(c.l).Info("Taking new handshake")
} }
existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId] existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
@ -261,7 +263,6 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
} }
if existingHostInfo != nil { if existingHostInfo != nil {
hostinfo.logger(c.l).Info("Race lost, taking new handshake")
// We are going to overwrite this entry, so remove the old references // We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId) delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)

View File

@ -59,6 +59,11 @@ type HostInfo struct {
// with a handshake // with a handshake
lastRebindCount int8 lastRebindCount int8
// lastHandshakeTime records the time the remote side told us about at the stage when the handshake was completed locally
// Stage 1 packet will contain it if I am a responder, stage 2 packet if I am an initiator
// This is used to avoid an attack where a handshake packet is replayed after some time
lastHandshakeTime uint64
lastRoam time.Time lastRoam time.Time
lastRoamRemote *udpAddr lastRoamRemote *udpAddr
} }