From 6c55d67f18bbcde3cfbb059d4c477a80707be102 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Fri, 12 Mar 2021 14:16:25 -0500 Subject: [PATCH] Refactor handshake_ix (#401) There are some subtle race conditions with the previous handshake_ix implementation, mostly around collisions with localIndexId. This change refactors it so that we have a "commit" phase during the handshake where we grab the lock for the hostmap and ensure that we have a unique local index before storing it. We also now avoid using the pending hostmap at all for receiving stage1 packets, since we have everything we need to just store the completed handshake. Co-authored-by: Nate Brown Co-authored-by: Ryan Huber Co-authored-by: forfuncsake --- handshake.go | 19 +- handshake_ix.go | 402 ++++++++++++++++++-------------------- handshake_manager.go | 126 +++++++++++- handshake_manager_test.go | 30 ++- hostmap.go | 80 ++------ outside.go | 3 - 6 files changed, 345 insertions(+), 315 deletions(-) diff --git a/handshake.go b/handshake.go index cd929a6..6ed0014 100644 --- a/handshake.go +++ b/handshake.go @@ -6,30 +6,23 @@ const ( ) func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) { - newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex) - //TODO: For stage 1 we won't have hostinfo yet but stage 2 and above would require it, this check may be helpful in those cases - //if err != nil { - // l.WithError(err).WithField("udpAddr", addr).Error("Error while finding host info for handshake message") - // return - //} - if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) { l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } - tearDown := false switch h.Subtype { case handshakeIXPSK0: switch h.MessageCounter { case 1: - tearDown = ixHandshakeStage1(f, addr, newHostinfo, packet, h) + ixHandshakeStage1(f, addr, packet, h) case 2: - tearDown = ixHandshakeStage2(f, addr, newHostinfo, packet, h) + newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex) + tearDown := ixHandshakeStage2(f, addr, newHostinfo, packet, h) + if tearDown && newHostinfo != nil { + f.handshakeManager.DeleteHostInfo(newHostinfo) + } } } - if tearDown && newHostinfo != nil { - f.handshakeManager.DeleteHostInfo(newHostinfo) - } } diff --git a/handshake_ix.go b/handshake_ix.go index d43804a..813c68d 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -25,17 +25,17 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { } } - myIndex, err := generateIndex() + err := f.handshakeManager.AddIndexHostInfo(hostinfo) if err != nil { l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") return } + ci := hostinfo.ConnectionState - f.handshakeManager.AddIndexHostInfo(myIndex, hostinfo) hsProto := &NebulaHandshakeDetails{ - InitiatorIndex: myIndex, + InitiatorIndex: hostinfo.localIndexId, Time: uint64(time.Now().Unix()), Cert: ci.certState.rawCertificateNoKey, } @@ -73,122 +73,140 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { } -func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { - var ip uint32 - if h.RemoteIndex == 0 { - ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0) - // Mark packet 1 as seen so it doesn't show up as missed - ci.window.Update(1) +func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { + ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0) + // Mark packet 1 as seen so it doesn't show up as missed + ci.window.Update(1) - msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) - if err != nil { - l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") - return true - } + msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) + if err != nil { + l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") + return + } - hs := &NebulaHandshake{} - err = proto.Unmarshal(msg, hs) - /* - l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) - */ - if err != nil || hs.Details == nil { - l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") - return true - } + hs := &NebulaHandshake{} + err = proto.Unmarshal(msg, hs) + /* + l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) + */ + if err != nil || hs.Details == nil { + l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + return + } - hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex) - if hostinfo != nil { - hostinfo.RLock() - defer hostinfo.RUnlock() + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) + if err != nil { + l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). + Info("Invalid certificate from host") + return + } + vpnIP := ip2int(remoteCert.Details.Ips[0].IP) + certName := remoteCert.Details.Name + fingerprint, _ := remoteCert.Sha256Sum() - if bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) { - if msg, ok := hostinfo.HandshakePacket[2]; ok { - f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) - err := f.outside.WriteTo(msg, addr) - if err != nil { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - WithError(err).Error("Failed to send handshake message") - } else { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - Info("Handshake message sent") - } - return false - } - - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cached", true). - WithField("packets", hostinfo.HandshakePacket). - Error("Seen this handshake packet already but don't have a cached packet to return") - } - } - - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) - if err != nil { - l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). - Info("Invalid certificate from host") - return true - } - vpnIP := ip2int(remoteCert.Details.Ips[0].IP) - certName := remoteCert.Details.Name - fingerprint, _ := remoteCert.Sha256Sum() - - myIndex, err := generateIndex() - if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") - return true - } - - hostinfo, err = f.handshakeManager.AddIndex(myIndex, ci) - if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager") - - return true - } - hostinfo.Lock() - defer hostinfo.Unlock() - - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + myIndex, err := generateIndex() + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake message received") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") + return + } - f.handshakeManager.addRemoteIndexHostInfo(hs.Details.InitiatorIndex, hostinfo) - hs.Details.ResponderIndex = myIndex - hs.Details.Cert = ci.certState.rawCertificateNoKey + hostinfo := &HostInfo{ + ConnectionState: ci, + Remotes: []*HostInfoDest{}, + localIndexId: myIndex, + remoteIndexId: hs.Details.InitiatorIndex, + hostId: vpnIP, + HandshakePacket: make(map[uint8][]byte, 0), + } - hsBytes, err := proto.Marshal(hs) - if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") - return true - } + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Info("Handshake message received") - header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2) - msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes) - if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") - return true - } + hs.Details.ResponderIndex = myIndex + hs.Details.Cert = ci.certState.rawCertificateNoKey - if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) { + hsBytes, err := proto.Marshal(hs) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") + return + } + + header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2) + msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") + return + } else if dKey == nil || eKey == nil { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") + return + } + + hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:])) + copy(hostinfo.HandshakePacket[0], packet[HeaderLen:]) + + // Regardless of whether you are the sender or receiver, you should arrive here + // and complete standing up the connection. + hostinfo.HandshakePacket[2] = make([]byte, len(msg)) + copy(hostinfo.HandshakePacket[2], msg) + + // We are sending handshake packet 2, so we don't expect to receive + // handshake packet 2 from the initiator. + ci.window.Update(2) + + ci.peerCert = remoteCert + ci.dKey = NewNebulaCipherState(dKey) + ci.eKey = NewNebulaCipherState(eKey) + //l.Debugln("got symmetric pairs") + + //hostinfo.ClearRemotes() + hostinfo.AddRemote(*addr) + hostinfo.ForcePromoteBest(f.hostMap.preferredRanges) + hostinfo.CreateRemoteCIDR(remoteCert) + + hostinfo.Lock() + defer hostinfo.Unlock() + + // Only overwrite existing record if we should win the handshake race + overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP) + existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f) + if err != nil { + switch err { + case ErrAlreadySeen: + msg = existing.HandshakePacket[2] + f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) + err := f.outside.WriteTo(msg, addr) + if err != nil { + l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). + WithError(err).Error("Failed to send handshake message") + } else { + l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). + Info("Handshake message sent") + } + return + case ErrExistingHostInfo: + // This means there was an existing tunnel and we didn't win + // handshake avoidance l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -198,82 +216,52 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - return true - } - - hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:])) - copy(hostinfo.HandshakePacket[0], packet[HeaderLen:]) - - // Regardless of whether you are the sender or receiver, you should arrive here - // and complete standing up the connection. - if dKey != nil && eKey != nil { - hostinfo.HandshakePacket[2] = make([]byte, len(msg)) - copy(hostinfo.HandshakePacket[2], msg) - - // We are sending handshake packet 2, so we don't expect to receive - // handshake packet 2 from the initiator. - ci.window.Update(2) - - f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) - err := f.outside.WriteTo(msg, addr) - if err != nil { - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithError(err).Error("Failed to send handshake") - } else { - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Handshake message sent") - } - - ip = ip2int(remoteCert.Details.Ips[0].IP) - ci.peerCert = remoteCert - ci.dKey = NewNebulaCipherState(dKey) - ci.eKey = NewNebulaCipherState(eKey) - //l.Debugln("got symmetric pairs") - - //hostinfo.ClearRemotes() - hostinfo.AddRemote(*addr) - hostinfo.CreateRemoteCIDR(remoteCert) - f.lightHouse.AddRemoteAndReset(ip, addr) - if f.serveDns { - dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) - } - - ho, err := f.hostMap.QueryVpnIP(vpnIP) - if err == nil && ho.localIndexId != 0 { - l.WithField("vpnIp", vpnIP). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("action", "removing stale index"). - WithField("index", ho.localIndexId). - WithField("remoteIndex", ho.remoteIndexId). - Debug("Handshake processing") - f.hostMap.DeleteHostInfo(ho) - } - - f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo) - - hostinfo.handshakeComplete() - } else { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + return + case ErrLocalIndexCollision: + // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Noise did not arrive at a key") - return true + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)). + Error("Failed to add HostInfo due to localIndex collision") + return + default: + // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete + // And we forget to update it here + l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to add HostInfo to HostMap") + return } - } - f.hostMap.AddRemote(ip, addr) - return false + // Do the send + f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) + err = f.outside.WriteTo(msg, addr) + if err != nil { + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + WithError(err).Error("Failed to send handshake") + } else { + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Info("Handshake message sent") + } + + hostinfo.handshakeComplete() + + return } func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { @@ -286,7 +274,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) { l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). - Error("Already seen this handshake packet") + Info("Already seen this handshake packet") return false } @@ -307,6 +295,11 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ // to DOS us. Every other error condition after should to allow a possible good handshake to complete in the // near future return false + } else if dKey == nil || eKey == nil { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Error("Noise did not arrive at a key") + return true } hs := &NebulaHandshake{} @@ -351,45 +344,20 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ // Regardless of whether you are the sender or receiver, you should arrive here // and complete standing up the connection. - if dKey != nil && eKey != nil { - ip := ip2int(remoteCert.Details.Ips[0].IP) - ci.peerCert = remoteCert - ci.dKey = NewNebulaCipherState(dKey) - ci.eKey = NewNebulaCipherState(eKey) - //l.Debugln("got symmetric pairs") - //hostinfo.ClearRemotes() - f.hostMap.AddRemote(ip, addr) - hostinfo.CreateRemoteCIDR(remoteCert) - f.lightHouse.AddRemoteAndReset(ip, addr) - if f.serveDns { - dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) - } + ci.peerCert = remoteCert + ci.dKey = NewNebulaCipherState(dKey) + ci.eKey = NewNebulaCipherState(eKey) + //l.Debugln("got symmetric pairs") - ho, err := f.hostMap.QueryVpnIP(vpnIP) - if err == nil && ho.localIndexId != 0 { - l.WithField("vpnIp", vpnIP). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("action", "removing stale index"). - WithField("index", ho.localIndexId). - WithField("remoteIndex", ho.remoteIndexId). - Debug("Handshake processing") - f.hostMap.DeleteHostInfo(ho) - } + //hostinfo.ClearRemotes() + hostinfo.AddRemote(*addr) + hostinfo.ForcePromoteBest(f.hostMap.preferredRanges) + hostinfo.CreateRemoteCIDR(remoteCert) - f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo) - - hostinfo.handshakeComplete() - f.metricHandshakes.Update(duration) - } else { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Error("Noise did not arrive at a key") - return true - } + f.handshakeManager.Complete(hostinfo, f) + hostinfo.handshakeComplete() + f.metricHandshakes.Update(duration) return false } diff --git a/handshake_manager.go b/handshake_manager.go index 2223e74..f82e603 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -1,9 +1,10 @@ package nebula import ( + "bytes" "crypto/rand" "encoding/binary" - "fmt" + "errors" "net" "time" @@ -196,18 +197,123 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo { return hostinfo } -func (c *HandshakeManager) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) { - hostinfo, err := c.pendingHostMap.AddIndex(index, ci) - if err != nil { - return nil, fmt.Errorf("Issue adding index: %d", index) +var ( + ErrExistingHostInfo = errors.New("existing hostinfo") + ErrAlreadySeen = errors.New("already seen") + ErrLocalIndexCollision = errors.New("local index collision") +) + +// CheckAndComplete checks for any conflicts in the main and pending hostmap +// before adding hostinfo to main. If err is nil, it was added. Otherwise err will be: + +// ErrAlreadySeen if we already have an entry in the hostmap that has seen the +// exact same handshake packet +// +// ErrExistingHostInfo if we already have an entry in the hostmap for this +// VpnIP and overwrite was false. +// +// ErrLocalIndexCollision if we already have an entry in the main or pending +// hostmap for the hostinfo.localIndexId. +func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) { + c.pendingHostMap.RLock() + defer c.pendingHostMap.RUnlock() + c.mainHostMap.Lock() + defer c.mainHostMap.Unlock() + + existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId] + if found && existingHostInfo != nil { + if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) { + return existingHostInfo, ErrAlreadySeen + } + if !overwrite { + return existingHostInfo, ErrExistingHostInfo + } } - //c.mainHostMap.AddIndexHostInfo(index, hostinfo) - c.InboundHandshakeTimer.Add(index, time.Second*10) - return hostinfo, nil + + existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId] + if found { + // We have a collision, but for a different hostinfo + return existingIndex, ErrLocalIndexCollision + } + existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId] + if found && existingIndex != hostinfo { + // We have a collision, but for a different hostinfo + return existingIndex, ErrLocalIndexCollision + } + + existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] + if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId { + // We have a collision, but this can happen since we can't control + // the remote ID. Just log about the situation as a note. + hostinfo.logger(). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)). + Info("New host shadows existing host remoteIndex") + } + + if existingHostInfo != nil { + // We are going to overwrite this entry, so remove the old references + delete(c.mainHostMap.Hosts, existingHostInfo.hostId) + delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) + delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId) + } + + c.mainHostMap.addHostInfo(hostinfo, f) + return existingHostInfo, nil } -func (c *HandshakeManager) AddIndexHostInfo(index uint32, h *HostInfo) { - c.pendingHostMap.AddIndexHostInfo(index, h) +// Complete is a simpler version of CheckAndComplete when we already know we +// won't have a localIndexId collision because we already have an entry in the +// pendingHostMap +func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { + c.mainHostMap.Lock() + defer c.mainHostMap.Unlock() + + existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId] + if found && existingHostInfo != nil { + // We are going to overwrite this entry, so remove the old references + delete(c.mainHostMap.Hosts, existingHostInfo.hostId) + delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) + delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId) + } + + existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] + if found && existingRemoteIndex != nil { + // We have a collision, but this can happen since we can't control + // the remote ID. Just log about the situation as a note. + hostinfo.logger(). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)). + Info("New host shadows existing host remoteIndex") + } + + c.mainHostMap.addHostInfo(hostinfo, f) +} + +// AddIndexHostInfo generates a unique localIndexId for this HostInfo +// and adds it to the pendingHostMap. Will error if we are unable to generate +// a unique localIndexId +func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error { + c.pendingHostMap.Lock() + defer c.pendingHostMap.Unlock() + c.mainHostMap.RLock() + defer c.mainHostMap.RUnlock() + + for i := 0; i < 32; i++ { + index, err := generateIndex() + if err != nil { + return err + } + + _, inPending := c.pendingHostMap.Indexes[index] + _, inMain := c.mainHostMap.Indexes[index] + + if !inMain && !inPending { + h.localIndexId = index + c.pendingHostMap.Indexes[index] = h + return nil + } + } + + return errors.New("failed to generate unique localIndexId") } func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) { diff --git a/handshake_manager_test.go b/handshake_manager_test.go index c4f1685..b1e1808 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -8,12 +8,11 @@ import ( "github.com/stretchr/testify/assert" ) -var indexes []uint32 = []uint32{1000, 2000, 3000, 4000} - //var ips []uint32 = []uint32{9000, 9999999, 3, 292394923} var ips []uint32 func Test_NewHandshakeManagerIndex(t *testing.T) { + _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") @@ -26,9 +25,18 @@ func Test_NewHandshakeManagerIndex(t *testing.T) { now := time.Now() blah.NextInboundHandshakeTimerTick(now) + var indexes = make([]uint32, 4) + var hostinfo = make([]*HostInfo, len(indexes)) + for i := range indexes { + hostinfo[i] = &HostInfo{ConnectionState: &ConnectionState{}} + } + // Add four indexes - for _, v := range indexes { - blah.AddIndex(v, &ConnectionState{}) + for i := range indexes { + err := blah.AddIndexHostInfo(hostinfo[i]) + assert.NoError(t, err) + indexes[i] = hostinfo[i].localIndexId + blah.InboundHandshakeTimer.Add(indexes[i], time.Second*10) } // Confirm they are in the pending index list for _, v := range indexes { @@ -169,8 +177,11 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { hostinfo := blah.AddVpnIP(vpnIP) // Pretned we have an index too - blah.AddIndexHostInfo(12341234, hostinfo) - assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234)) + err := blah.AddIndexHostInfo(hostinfo) + assert.NoError(t, err) + blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10) + assert.NotZero(t, hostinfo.localIndexId) + assert.Contains(t, blah.pendingHostMap.Indexes, hostinfo.localIndexId) // Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending // but not main hostmap @@ -216,7 +227,10 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) { now := time.Now() blah.NextInboundHandshakeTimerTick(now) - hostinfo, _ := blah.AddIndex(12341234, &ConnectionState{}) + hostinfo := &HostInfo{ConnectionState: &ConnectionState{}} + err := blah.AddIndexHostInfo(hostinfo) + assert.NoError(t, err) + blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10) // Pretned we have an index too blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo) assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010)) @@ -229,7 +243,7 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) { next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3) blah.NextInboundHandshakeTimerTick(next_tick) assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010)) - assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234)) + assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(hostinfo.localIndexId)) } type mockEncWriter struct { diff --git a/hostmap.go b/hostmap.go index e04c14d..c252f42 100644 --- a/hostmap.go +++ b/hostmap.go @@ -166,40 +166,6 @@ func (hm *HostMap) DeleteVpnIP(vpnIP uint32) { } } -func (hm *HostMap) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) { - hm.Lock() - if _, ok := hm.Indexes[index]; !ok { - h := &HostInfo{ - ConnectionState: ci, - Remotes: []*HostInfoDest{}, - localIndexId: index, - HandshakePacket: make(map[uint8][]byte, 0), - } - hm.Indexes[index] = h - l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), - "hostinfo": m{"existing": false, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). - Debug("Hostmap index added") - - hm.Unlock() - return h, nil - } - hm.Unlock() - return nil, fmt.Errorf("refusing to overwrite existing index: %d", index) -} - -func (hm *HostMap) AddIndexHostInfo(index uint32, h *HostInfo) { - hm.Lock() - h.localIndexId = index - hm.Indexes[index] = h - hm.Unlock() - - if l.Level > logrus.DebugLevel { - l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), - "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). - Debug("Hostmap index added") - } -} - // Only used by pendingHostMap when the remote index is not initially known func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) { hm.Lock() @@ -234,16 +200,12 @@ func (hm *HostMap) DeleteIndex(index uint32) { hm.Lock() hostinfo, ok := hm.Indexes[index] if ok { - hostinfo.Lock() - defer hostinfo.Unlock() - delete(hm.Indexes, index) delete(hm.RemoteIndexes, hostinfo.remoteIndexId) // Check if we have an entry under hostId that matches the same hostinfo // instance. Clean it up as well if we do. - var hostinfo2 *HostInfo - hostinfo2, ok = hm.Hosts[hostinfo.hostId] + hostinfo2, ok := hm.Hosts[hostinfo.hostId] if ok && hostinfo2 == hostinfo { delete(hm.Hosts, hostinfo.hostId) } @@ -400,36 +362,26 @@ func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 { } } -func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool { - hm.RLock() - if i, ok := hm.Hosts[vpnIP]; ok { - if i == nil { - hm.RUnlock() - return false - } - complete := i.HandshakeComplete - hm.RUnlock() - return complete +// We already have the hm Lock when this is called, so make sure to not call +// any other methods that might try to grab it again +func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) { + remoteCert := hostinfo.ConnectionState.peerCert + ip := ip2int(remoteCert.Details.Ips[0].IP) + f.lightHouse.AddRemoteAndReset(ip, hostinfo.remote) + if f.serveDns { + dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) } - hm.RUnlock() - return false -} -func (hm *HostMap) CheckHandshakeCompleteIndex(index uint32) bool { - hm.RLock() - if i, ok := hm.Indexes[index]; ok { - if i == nil { - hm.RUnlock() - return false - } - complete := i.HandshakeComplete - hm.RUnlock() - return complete + hm.Hosts[hostinfo.hostId] = hostinfo + hm.Indexes[hostinfo.localIndexId] = hostinfo + hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo + if l.Level >= logrus.DebugLevel { + l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}). + Debug("Hostmap vpnIp added") } - hm.RUnlock() - return false } func (hm *HostMap) ClearRemotes(vpnIP uint32) { diff --git a/outside.go b/outside.go index 093c9ff..ec3b74f 100644 --- a/outside.go +++ b/outside.go @@ -106,7 +106,6 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, case recvError: f.messageMetrics.Rx(header.Type, header.Subtype, 1) - // TODO: Remove this with recv_error deprecation f.handleRecvError(addr, header) return @@ -312,8 +311,6 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) { } func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { - // This flag is to stop caring about recv_error from old versions - // This should go away when the old version is gone from prod if l.Level >= logrus.DebugLevel { l.WithField("index", h.RemoteIndex). WithField("udpAddr", addr).