diff --git a/handshake.go b/handshake.go index cd929a6..6ed0014 100644 --- a/handshake.go +++ b/handshake.go @@ -6,30 +6,23 @@ const ( ) func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) { - newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex) - //TODO: For stage 1 we won't have hostinfo yet but stage 2 and above would require it, this check may be helpful in those cases - //if err != nil { - // l.WithError(err).WithField("udpAddr", addr).Error("Error while finding host info for handshake message") - // return - //} - if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) { l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } - tearDown := false switch h.Subtype { case handshakeIXPSK0: switch h.MessageCounter { case 1: - tearDown = ixHandshakeStage1(f, addr, newHostinfo, packet, h) + ixHandshakeStage1(f, addr, packet, h) case 2: - tearDown = ixHandshakeStage2(f, addr, newHostinfo, packet, h) + newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex) + tearDown := ixHandshakeStage2(f, addr, newHostinfo, packet, h) + if tearDown && newHostinfo != nil { + f.handshakeManager.DeleteHostInfo(newHostinfo) + } } } - if tearDown && newHostinfo != nil { - f.handshakeManager.DeleteHostInfo(newHostinfo) - } } diff --git a/handshake_ix.go b/handshake_ix.go index d43804a..813c68d 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -25,17 +25,17 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { } } - myIndex, err := generateIndex() + err := f.handshakeManager.AddIndexHostInfo(hostinfo) if err != nil { l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") return } + ci := hostinfo.ConnectionState - f.handshakeManager.AddIndexHostInfo(myIndex, hostinfo) hsProto := &NebulaHandshakeDetails{ - InitiatorIndex: myIndex, + InitiatorIndex: hostinfo.localIndexId, Time: uint64(time.Now().Unix()), Cert: ci.certState.rawCertificateNoKey, } @@ -73,122 +73,140 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { } -func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { - var ip uint32 - if h.RemoteIndex == 0 { - ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0) - // Mark packet 1 as seen so it doesn't show up as missed - ci.window.Update(1) +func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { + ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0) + // Mark packet 1 as seen so it doesn't show up as missed + ci.window.Update(1) - msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) - if err != nil { - l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") - return true - } + msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) + if err != nil { + l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") + return + } - hs := &NebulaHandshake{} - err = proto.Unmarshal(msg, hs) - /* - l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) - */ - if err != nil || hs.Details == nil { - l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") - return true - } + hs := &NebulaHandshake{} + err = proto.Unmarshal(msg, hs) + /* + l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) + */ + if err != nil || hs.Details == nil { + l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + return + } - hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex) - if hostinfo != nil { - hostinfo.RLock() - defer hostinfo.RUnlock() + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) + if err != nil { + l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). + Info("Invalid certificate from host") + return + } + vpnIP := ip2int(remoteCert.Details.Ips[0].IP) + certName := remoteCert.Details.Name + fingerprint, _ := remoteCert.Sha256Sum() - if bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) { - if msg, ok := hostinfo.HandshakePacket[2]; ok { - f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) - err := f.outside.WriteTo(msg, addr) - if err != nil { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - WithError(err).Error("Failed to send handshake message") - } else { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - Info("Handshake message sent") - } - return false - } - - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cached", true). - WithField("packets", hostinfo.HandshakePacket). - Error("Seen this handshake packet already but don't have a cached packet to return") - } - } - - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) - if err != nil { - l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). - Info("Invalid certificate from host") - return true - } - vpnIP := ip2int(remoteCert.Details.Ips[0].IP) - certName := remoteCert.Details.Name - fingerprint, _ := remoteCert.Sha256Sum() - - myIndex, err := generateIndex() - if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") - return true - } - - hostinfo, err = f.handshakeManager.AddIndex(myIndex, ci) - if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager") - - return true - } - hostinfo.Lock() - defer hostinfo.Unlock() - - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + myIndex, err := generateIndex() + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake message received") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") + return + } - f.handshakeManager.addRemoteIndexHostInfo(hs.Details.InitiatorIndex, hostinfo) - hs.Details.ResponderIndex = myIndex - hs.Details.Cert = ci.certState.rawCertificateNoKey + hostinfo := &HostInfo{ + ConnectionState: ci, + Remotes: []*HostInfoDest{}, + localIndexId: myIndex, + remoteIndexId: hs.Details.InitiatorIndex, + hostId: vpnIP, + HandshakePacket: make(map[uint8][]byte, 0), + } - hsBytes, err := proto.Marshal(hs) - if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") - return true - } + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Info("Handshake message received") - header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2) - msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes) - if err != nil { - l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") - return true - } + hs.Details.ResponderIndex = myIndex + hs.Details.Cert = ci.certState.rawCertificateNoKey - if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) { + hsBytes, err := proto.Marshal(hs) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") + return + } + + header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2) + msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") + return + } else if dKey == nil || eKey == nil { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") + return + } + + hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:])) + copy(hostinfo.HandshakePacket[0], packet[HeaderLen:]) + + // Regardless of whether you are the sender or receiver, you should arrive here + // and complete standing up the connection. + hostinfo.HandshakePacket[2] = make([]byte, len(msg)) + copy(hostinfo.HandshakePacket[2], msg) + + // We are sending handshake packet 2, so we don't expect to receive + // handshake packet 2 from the initiator. + ci.window.Update(2) + + ci.peerCert = remoteCert + ci.dKey = NewNebulaCipherState(dKey) + ci.eKey = NewNebulaCipherState(eKey) + //l.Debugln("got symmetric pairs") + + //hostinfo.ClearRemotes() + hostinfo.AddRemote(*addr) + hostinfo.ForcePromoteBest(f.hostMap.preferredRanges) + hostinfo.CreateRemoteCIDR(remoteCert) + + hostinfo.Lock() + defer hostinfo.Unlock() + + // Only overwrite existing record if we should win the handshake race + overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP) + existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f) + if err != nil { + switch err { + case ErrAlreadySeen: + msg = existing.HandshakePacket[2] + f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) + err := f.outside.WriteTo(msg, addr) + if err != nil { + l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). + WithError(err).Error("Failed to send handshake message") + } else { + l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). + Info("Handshake message sent") + } + return + case ErrExistingHostInfo: + // This means there was an existing tunnel and we didn't win + // handshake avoidance l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -198,82 +216,52 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - return true - } - - hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:])) - copy(hostinfo.HandshakePacket[0], packet[HeaderLen:]) - - // Regardless of whether you are the sender or receiver, you should arrive here - // and complete standing up the connection. - if dKey != nil && eKey != nil { - hostinfo.HandshakePacket[2] = make([]byte, len(msg)) - copy(hostinfo.HandshakePacket[2], msg) - - // We are sending handshake packet 2, so we don't expect to receive - // handshake packet 2 from the initiator. - ci.window.Update(2) - - f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) - err := f.outside.WriteTo(msg, addr) - if err != nil { - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithError(err).Error("Failed to send handshake") - } else { - l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Handshake message sent") - } - - ip = ip2int(remoteCert.Details.Ips[0].IP) - ci.peerCert = remoteCert - ci.dKey = NewNebulaCipherState(dKey) - ci.eKey = NewNebulaCipherState(eKey) - //l.Debugln("got symmetric pairs") - - //hostinfo.ClearRemotes() - hostinfo.AddRemote(*addr) - hostinfo.CreateRemoteCIDR(remoteCert) - f.lightHouse.AddRemoteAndReset(ip, addr) - if f.serveDns { - dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) - } - - ho, err := f.hostMap.QueryVpnIP(vpnIP) - if err == nil && ho.localIndexId != 0 { - l.WithField("vpnIp", vpnIP). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("action", "removing stale index"). - WithField("index", ho.localIndexId). - WithField("remoteIndex", ho.remoteIndexId). - Debug("Handshake processing") - f.hostMap.DeleteHostInfo(ho) - } - - f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo) - - hostinfo.handshakeComplete() - } else { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + return + case ErrLocalIndexCollision: + // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Noise did not arrive at a key") - return true + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)). + Error("Failed to add HostInfo due to localIndex collision") + return + default: + // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete + // And we forget to update it here + l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to add HostInfo to HostMap") + return } - } - f.hostMap.AddRemote(ip, addr) - return false + // Do the send + f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) + err = f.outside.WriteTo(msg, addr) + if err != nil { + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + WithError(err).Error("Failed to send handshake") + } else { + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Info("Handshake message sent") + } + + hostinfo.handshakeComplete() + + return } func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { @@ -286,7 +274,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) { l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). - Error("Already seen this handshake packet") + Info("Already seen this handshake packet") return false } @@ -307,6 +295,11 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ // to DOS us. Every other error condition after should to allow a possible good handshake to complete in the // near future return false + } else if dKey == nil || eKey == nil { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Error("Noise did not arrive at a key") + return true } hs := &NebulaHandshake{} @@ -351,45 +344,20 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ // Regardless of whether you are the sender or receiver, you should arrive here // and complete standing up the connection. - if dKey != nil && eKey != nil { - ip := ip2int(remoteCert.Details.Ips[0].IP) - ci.peerCert = remoteCert - ci.dKey = NewNebulaCipherState(dKey) - ci.eKey = NewNebulaCipherState(eKey) - //l.Debugln("got symmetric pairs") - //hostinfo.ClearRemotes() - f.hostMap.AddRemote(ip, addr) - hostinfo.CreateRemoteCIDR(remoteCert) - f.lightHouse.AddRemoteAndReset(ip, addr) - if f.serveDns { - dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) - } + ci.peerCert = remoteCert + ci.dKey = NewNebulaCipherState(dKey) + ci.eKey = NewNebulaCipherState(eKey) + //l.Debugln("got symmetric pairs") - ho, err := f.hostMap.QueryVpnIP(vpnIP) - if err == nil && ho.localIndexId != 0 { - l.WithField("vpnIp", vpnIP). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("action", "removing stale index"). - WithField("index", ho.localIndexId). - WithField("remoteIndex", ho.remoteIndexId). - Debug("Handshake processing") - f.hostMap.DeleteHostInfo(ho) - } + //hostinfo.ClearRemotes() + hostinfo.AddRemote(*addr) + hostinfo.ForcePromoteBest(f.hostMap.preferredRanges) + hostinfo.CreateRemoteCIDR(remoteCert) - f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo) - - hostinfo.handshakeComplete() - f.metricHandshakes.Update(duration) - } else { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Error("Noise did not arrive at a key") - return true - } + f.handshakeManager.Complete(hostinfo, f) + hostinfo.handshakeComplete() + f.metricHandshakes.Update(duration) return false } diff --git a/handshake_manager.go b/handshake_manager.go index 2223e74..f82e603 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -1,9 +1,10 @@ package nebula import ( + "bytes" "crypto/rand" "encoding/binary" - "fmt" + "errors" "net" "time" @@ -196,18 +197,123 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo { return hostinfo } -func (c *HandshakeManager) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) { - hostinfo, err := c.pendingHostMap.AddIndex(index, ci) - if err != nil { - return nil, fmt.Errorf("Issue adding index: %d", index) +var ( + ErrExistingHostInfo = errors.New("existing hostinfo") + ErrAlreadySeen = errors.New("already seen") + ErrLocalIndexCollision = errors.New("local index collision") +) + +// CheckAndComplete checks for any conflicts in the main and pending hostmap +// before adding hostinfo to main. If err is nil, it was added. Otherwise err will be: + +// ErrAlreadySeen if we already have an entry in the hostmap that has seen the +// exact same handshake packet +// +// ErrExistingHostInfo if we already have an entry in the hostmap for this +// VpnIP and overwrite was false. +// +// ErrLocalIndexCollision if we already have an entry in the main or pending +// hostmap for the hostinfo.localIndexId. +func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) { + c.pendingHostMap.RLock() + defer c.pendingHostMap.RUnlock() + c.mainHostMap.Lock() + defer c.mainHostMap.Unlock() + + existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId] + if found && existingHostInfo != nil { + if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) { + return existingHostInfo, ErrAlreadySeen + } + if !overwrite { + return existingHostInfo, ErrExistingHostInfo + } } - //c.mainHostMap.AddIndexHostInfo(index, hostinfo) - c.InboundHandshakeTimer.Add(index, time.Second*10) - return hostinfo, nil + + existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId] + if found { + // We have a collision, but for a different hostinfo + return existingIndex, ErrLocalIndexCollision + } + existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId] + if found && existingIndex != hostinfo { + // We have a collision, but for a different hostinfo + return existingIndex, ErrLocalIndexCollision + } + + existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] + if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId { + // We have a collision, but this can happen since we can't control + // the remote ID. Just log about the situation as a note. + hostinfo.logger(). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)). + Info("New host shadows existing host remoteIndex") + } + + if existingHostInfo != nil { + // We are going to overwrite this entry, so remove the old references + delete(c.mainHostMap.Hosts, existingHostInfo.hostId) + delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) + delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId) + } + + c.mainHostMap.addHostInfo(hostinfo, f) + return existingHostInfo, nil } -func (c *HandshakeManager) AddIndexHostInfo(index uint32, h *HostInfo) { - c.pendingHostMap.AddIndexHostInfo(index, h) +// Complete is a simpler version of CheckAndComplete when we already know we +// won't have a localIndexId collision because we already have an entry in the +// pendingHostMap +func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { + c.mainHostMap.Lock() + defer c.mainHostMap.Unlock() + + existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId] + if found && existingHostInfo != nil { + // We are going to overwrite this entry, so remove the old references + delete(c.mainHostMap.Hosts, existingHostInfo.hostId) + delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) + delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId) + } + + existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] + if found && existingRemoteIndex != nil { + // We have a collision, but this can happen since we can't control + // the remote ID. Just log about the situation as a note. + hostinfo.logger(). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)). + Info("New host shadows existing host remoteIndex") + } + + c.mainHostMap.addHostInfo(hostinfo, f) +} + +// AddIndexHostInfo generates a unique localIndexId for this HostInfo +// and adds it to the pendingHostMap. Will error if we are unable to generate +// a unique localIndexId +func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error { + c.pendingHostMap.Lock() + defer c.pendingHostMap.Unlock() + c.mainHostMap.RLock() + defer c.mainHostMap.RUnlock() + + for i := 0; i < 32; i++ { + index, err := generateIndex() + if err != nil { + return err + } + + _, inPending := c.pendingHostMap.Indexes[index] + _, inMain := c.mainHostMap.Indexes[index] + + if !inMain && !inPending { + h.localIndexId = index + c.pendingHostMap.Indexes[index] = h + return nil + } + } + + return errors.New("failed to generate unique localIndexId") } func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) { diff --git a/handshake_manager_test.go b/handshake_manager_test.go index c4f1685..b1e1808 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -8,12 +8,11 @@ import ( "github.com/stretchr/testify/assert" ) -var indexes []uint32 = []uint32{1000, 2000, 3000, 4000} - //var ips []uint32 = []uint32{9000, 9999999, 3, 292394923} var ips []uint32 func Test_NewHandshakeManagerIndex(t *testing.T) { + _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") @@ -26,9 +25,18 @@ func Test_NewHandshakeManagerIndex(t *testing.T) { now := time.Now() blah.NextInboundHandshakeTimerTick(now) + var indexes = make([]uint32, 4) + var hostinfo = make([]*HostInfo, len(indexes)) + for i := range indexes { + hostinfo[i] = &HostInfo{ConnectionState: &ConnectionState{}} + } + // Add four indexes - for _, v := range indexes { - blah.AddIndex(v, &ConnectionState{}) + for i := range indexes { + err := blah.AddIndexHostInfo(hostinfo[i]) + assert.NoError(t, err) + indexes[i] = hostinfo[i].localIndexId + blah.InboundHandshakeTimer.Add(indexes[i], time.Second*10) } // Confirm they are in the pending index list for _, v := range indexes { @@ -169,8 +177,11 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { hostinfo := blah.AddVpnIP(vpnIP) // Pretned we have an index too - blah.AddIndexHostInfo(12341234, hostinfo) - assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234)) + err := blah.AddIndexHostInfo(hostinfo) + assert.NoError(t, err) + blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10) + assert.NotZero(t, hostinfo.localIndexId) + assert.Contains(t, blah.pendingHostMap.Indexes, hostinfo.localIndexId) // Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending // but not main hostmap @@ -216,7 +227,10 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) { now := time.Now() blah.NextInboundHandshakeTimerTick(now) - hostinfo, _ := blah.AddIndex(12341234, &ConnectionState{}) + hostinfo := &HostInfo{ConnectionState: &ConnectionState{}} + err := blah.AddIndexHostInfo(hostinfo) + assert.NoError(t, err) + blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10) // Pretned we have an index too blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo) assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010)) @@ -229,7 +243,7 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) { next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3) blah.NextInboundHandshakeTimerTick(next_tick) assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010)) - assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234)) + assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(hostinfo.localIndexId)) } type mockEncWriter struct { diff --git a/hostmap.go b/hostmap.go index e04c14d..c252f42 100644 --- a/hostmap.go +++ b/hostmap.go @@ -166,40 +166,6 @@ func (hm *HostMap) DeleteVpnIP(vpnIP uint32) { } } -func (hm *HostMap) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) { - hm.Lock() - if _, ok := hm.Indexes[index]; !ok { - h := &HostInfo{ - ConnectionState: ci, - Remotes: []*HostInfoDest{}, - localIndexId: index, - HandshakePacket: make(map[uint8][]byte, 0), - } - hm.Indexes[index] = h - l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), - "hostinfo": m{"existing": false, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). - Debug("Hostmap index added") - - hm.Unlock() - return h, nil - } - hm.Unlock() - return nil, fmt.Errorf("refusing to overwrite existing index: %d", index) -} - -func (hm *HostMap) AddIndexHostInfo(index uint32, h *HostInfo) { - hm.Lock() - h.localIndexId = index - hm.Indexes[index] = h - hm.Unlock() - - if l.Level > logrus.DebugLevel { - l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), - "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). - Debug("Hostmap index added") - } -} - // Only used by pendingHostMap when the remote index is not initially known func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) { hm.Lock() @@ -234,16 +200,12 @@ func (hm *HostMap) DeleteIndex(index uint32) { hm.Lock() hostinfo, ok := hm.Indexes[index] if ok { - hostinfo.Lock() - defer hostinfo.Unlock() - delete(hm.Indexes, index) delete(hm.RemoteIndexes, hostinfo.remoteIndexId) // Check if we have an entry under hostId that matches the same hostinfo // instance. Clean it up as well if we do. - var hostinfo2 *HostInfo - hostinfo2, ok = hm.Hosts[hostinfo.hostId] + hostinfo2, ok := hm.Hosts[hostinfo.hostId] if ok && hostinfo2 == hostinfo { delete(hm.Hosts, hostinfo.hostId) } @@ -400,36 +362,26 @@ func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 { } } -func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool { - hm.RLock() - if i, ok := hm.Hosts[vpnIP]; ok { - if i == nil { - hm.RUnlock() - return false - } - complete := i.HandshakeComplete - hm.RUnlock() - return complete +// We already have the hm Lock when this is called, so make sure to not call +// any other methods that might try to grab it again +func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) { + remoteCert := hostinfo.ConnectionState.peerCert + ip := ip2int(remoteCert.Details.Ips[0].IP) + f.lightHouse.AddRemoteAndReset(ip, hostinfo.remote) + if f.serveDns { + dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) } - hm.RUnlock() - return false -} -func (hm *HostMap) CheckHandshakeCompleteIndex(index uint32) bool { - hm.RLock() - if i, ok := hm.Indexes[index]; ok { - if i == nil { - hm.RUnlock() - return false - } - complete := i.HandshakeComplete - hm.RUnlock() - return complete + hm.Hosts[hostinfo.hostId] = hostinfo + hm.Indexes[hostinfo.localIndexId] = hostinfo + hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo + if l.Level >= logrus.DebugLevel { + l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}). + Debug("Hostmap vpnIp added") } - hm.RUnlock() - return false } func (hm *HostMap) ClearRemotes(vpnIP uint32) { diff --git a/outside.go b/outside.go index 093c9ff..ec3b74f 100644 --- a/outside.go +++ b/outside.go @@ -106,7 +106,6 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, case recvError: f.messageMetrics.Rx(header.Type, header.Subtype, 1) - // TODO: Remove this with recv_error deprecation f.handleRecvError(addr, header) return @@ -312,8 +311,6 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) { } func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { - // This flag is to stop caring about recv_error from old versions - // This should go away when the old version is gone from prod if l.Level >= logrus.DebugLevel { l.WithField("index", h.RemoteIndex). WithField("udpAddr", addr).