Refactor handshake_ix (#401)

There are some subtle race conditions with the previous handshake_ix implementation, mostly around collisions with localIndexId. This change refactors it so that we have a "commit" phase during the handshake where we grab the lock for the hostmap and ensure that we have a unique local index before storing it. We also now avoid using the pending hostmap at all for receiving stage1 packets, since we have everything we need to just store the completed handshake.

Co-authored-by: Nate Brown <nbrown.us@gmail.com>
Co-authored-by: Ryan Huber <rhuber@gmail.com>
Co-authored-by: forfuncsake <drussell@slack-corp.com>
This commit is contained in:
Wade Simmons 2021-03-12 14:16:25 -05:00 committed by GitHub
parent 64d8035d09
commit 6c55d67f18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 345 additions and 315 deletions

View File

@ -6,30 +6,23 @@ const (
) )
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) { func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
//TODO: For stage 1 we won't have hostinfo yet but stage 2 and above would require it, this check may be helpful in those cases
//if err != nil {
// l.WithError(err).WithField("udpAddr", addr).Error("Error while finding host info for handshake message")
// return
//}
if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) { if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) {
l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
tearDown := false
switch h.Subtype { switch h.Subtype {
case handshakeIXPSK0: case handshakeIXPSK0:
switch h.MessageCounter { switch h.MessageCounter {
case 1: case 1:
tearDown = ixHandshakeStage1(f, addr, newHostinfo, packet, h) ixHandshakeStage1(f, addr, packet, h)
case 2: case 2:
tearDown = ixHandshakeStage2(f, addr, newHostinfo, packet, h) newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
tearDown := ixHandshakeStage2(f, addr, newHostinfo, packet, h)
if tearDown && newHostinfo != nil {
f.handshakeManager.DeleteHostInfo(newHostinfo)
}
} }
} }
if tearDown && newHostinfo != nil {
f.handshakeManager.DeleteHostInfo(newHostinfo)
}
} }

View File

@ -25,17 +25,17 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
} }
} }
myIndex, err := generateIndex() err := f.handshakeManager.AddIndexHostInfo(hostinfo)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return return
} }
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
f.handshakeManager.AddIndexHostInfo(myIndex, hostinfo)
hsProto := &NebulaHandshakeDetails{ hsProto := &NebulaHandshakeDetails{
InitiatorIndex: myIndex, InitiatorIndex: hostinfo.localIndexId,
Time: uint64(time.Now().Unix()), Time: uint64(time.Now().Unix()),
Cert: ci.certState.rawCertificateNoKey, Cert: ci.certState.rawCertificateNoKey,
} }
@ -73,122 +73,140 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
} }
func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
var ip uint32 ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0)
if h.RemoteIndex == 0 { // Mark packet 1 as seen so it doesn't show up as missed
ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0) ci.window.Update(1)
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
if err != nil { if err != nil {
l.WithError(err).WithField("udpAddr", addr). l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
return true return
} }
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = proto.Unmarshal(msg, hs) err = proto.Unmarshal(msg, hs)
/* /*
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
*/ */
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
l.WithError(err).WithField("udpAddr", addr). l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
return true return
} }
hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex) remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
if hostinfo != nil { if err != nil {
hostinfo.RLock() l.WithError(err).WithField("udpAddr", addr).
defer hostinfo.RUnlock() WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Invalid certificate from host")
return
}
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
if bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) { myIndex, err := generateIndex()
if msg, ok := hostinfo.HandshakePacket[2]; ok { if err != nil {
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
err := f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
return false
}
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cached", true).
WithField("packets", hostinfo.HandshakePacket).
Error("Seen this handshake packet already but don't have a cached packet to return")
}
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
if err != nil {
l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Invalid certificate from host")
return true
}
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
myIndex, err := generateIndex()
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
return true
}
hostinfo, err = f.handshakeManager.AddIndex(myIndex, ci)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager")
return true
}
hostinfo.Lock()
defer hostinfo.Unlock()
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). return
Info("Handshake message received") }
f.handshakeManager.addRemoteIndexHostInfo(hs.Details.InitiatorIndex, hostinfo) hostinfo := &HostInfo{
hs.Details.ResponderIndex = myIndex ConnectionState: ci,
hs.Details.Cert = ci.certState.rawCertificateNoKey Remotes: []*HostInfoDest{},
localIndexId: myIndex,
remoteIndexId: hs.Details.InitiatorIndex,
hostId: vpnIP,
HandshakePacket: make(map[uint8][]byte, 0),
}
hsBytes, err := proto.Marshal(hs) l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
if err != nil { WithField("certName", certName).
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("fingerprint", fingerprint).
WithField("certName", certName). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("fingerprint", fingerprint). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") Info("Handshake message received")
return true
}
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2) hs.Details.ResponderIndex = myIndex
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes) hs.Details.Cert = ci.certState.rawCertificateNoKey
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return true
}
if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) { hsBytes, err := proto.Marshal(hs)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return
}
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
} else if dKey == nil || eKey == nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
return
}
hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:]))
copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
// Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection.
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
copy(hostinfo.HandshakePacket[2], msg)
// We are sending handshake packet 2, so we don't expect to receive
// handshake packet 2 from the initiator.
ci.window.Update(2)
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
//l.Debugln("got symmetric pairs")
//hostinfo.ClearRemotes()
hostinfo.AddRemote(*addr)
hostinfo.ForcePromoteBest(f.hostMap.preferredRanges)
hostinfo.CreateRemoteCIDR(remoteCert)
hostinfo.Lock()
defer hostinfo.Unlock()
// Only overwrite existing record if we should win the handshake race
overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
if err != nil {
switch err {
case ErrAlreadySeen:
msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
return
case ErrExistingHostInfo:
// This means there was an existing tunnel and we didn't win
// handshake avoidance
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@ -198,82 +216,52 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return true return
} case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:])) l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
// Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection.
if dKey != nil && eKey != nil {
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
copy(hostinfo.HandshakePacket[2], msg)
// We are sending handshake packet 2, so we don't expect to receive
// handshake packet 2 from the initiator.
ci.window.Update(2)
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent")
}
ip = ip2int(remoteCert.Details.Ips[0].IP)
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
//l.Debugln("got symmetric pairs")
//hostinfo.ClearRemotes()
hostinfo.AddRemote(*addr)
hostinfo.CreateRemoteCIDR(remoteCert)
f.lightHouse.AddRemoteAndReset(ip, addr)
if f.serveDns {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
}
ho, err := f.hostMap.QueryVpnIP(vpnIP)
if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("action", "removing stale index").
WithField("index", ho.localIndexId).
WithField("remoteIndex", ho.remoteIndexId).
Debug("Handshake processing")
f.hostMap.DeleteHostInfo(ho)
}
f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo)
hostinfo.handshakeComplete()
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
Error("Noise did not arrive at a key") WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
return true WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)).
Error("Failed to add HostInfo due to localIndex collision")
return
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to add HostInfo to HostMap")
return
} }
} }
f.hostMap.AddRemote(ip, addr) // Do the send
return false f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err = f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent")
}
hostinfo.handshakeComplete()
return
} }
func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool {
@ -286,7 +274,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) { if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Already seen this handshake packet") Info("Already seen this handshake packet")
return false return false
} }
@ -307,6 +295,11 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the // to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
// near future // near future
return false return false
} else if dKey == nil || eKey == nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
return true
} }
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
@ -351,45 +344,20 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// Regardless of whether you are the sender or receiver, you should arrive here // Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection. // and complete standing up the connection.
if dKey != nil && eKey != nil {
ip := ip2int(remoteCert.Details.Ips[0].IP)
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
//l.Debugln("got symmetric pairs")
//hostinfo.ClearRemotes() ci.peerCert = remoteCert
f.hostMap.AddRemote(ip, addr) ci.dKey = NewNebulaCipherState(dKey)
hostinfo.CreateRemoteCIDR(remoteCert) ci.eKey = NewNebulaCipherState(eKey)
f.lightHouse.AddRemoteAndReset(ip, addr) //l.Debugln("got symmetric pairs")
if f.serveDns {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
}
ho, err := f.hostMap.QueryVpnIP(vpnIP) //hostinfo.ClearRemotes()
if err == nil && ho.localIndexId != 0 { hostinfo.AddRemote(*addr)
l.WithField("vpnIp", vpnIP). hostinfo.ForcePromoteBest(f.hostMap.preferredRanges)
WithField("certName", certName). hostinfo.CreateRemoteCIDR(remoteCert)
WithField("fingerprint", fingerprint).
WithField("action", "removing stale index").
WithField("index", ho.localIndexId).
WithField("remoteIndex", ho.remoteIndexId).
Debug("Handshake processing")
f.hostMap.DeleteHostInfo(ho)
}
f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo) f.handshakeManager.Complete(hostinfo, f)
hostinfo.handshakeComplete()
hostinfo.handshakeComplete() f.metricHandshakes.Update(duration)
f.metricHandshakes.Update(duration)
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
return true
}
return false return false
} }

View File

@ -1,9 +1,10 @@
package nebula package nebula
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"fmt" "errors"
"net" "net"
"time" "time"
@ -196,18 +197,123 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
return hostinfo return hostinfo
} }
func (c *HandshakeManager) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) { var (
hostinfo, err := c.pendingHostMap.AddIndex(index, ci) ErrExistingHostInfo = errors.New("existing hostinfo")
if err != nil { ErrAlreadySeen = errors.New("already seen")
return nil, fmt.Errorf("Issue adding index: %d", index) ErrLocalIndexCollision = errors.New("local index collision")
)
// CheckAndComplete checks for any conflicts in the main and pending hostmap
// before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
// ErrAlreadySeen if we already have an entry in the hostmap that has seen the
// exact same handshake packet
//
// ErrExistingHostInfo if we already have an entry in the hostmap for this
// VpnIP and overwrite was false.
//
// ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
c.pendingHostMap.RLock()
defer c.pendingHostMap.RUnlock()
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
if found && existingHostInfo != nil {
if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
return existingHostInfo, ErrAlreadySeen
}
if !overwrite {
return existingHostInfo, ErrExistingHostInfo
}
} }
//c.mainHostMap.AddIndexHostInfo(index, hostinfo)
c.InboundHandshakeTimer.Add(index, time.Second*10) existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
return hostinfo, nil if found {
// We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision
}
existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
if found && existingIndex != hostinfo {
// We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision
}
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger().
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
Info("New host shadows existing host remoteIndex")
}
if existingHostInfo != nil {
// We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
}
c.mainHostMap.addHostInfo(hostinfo, f)
return existingHostInfo, nil
} }
func (c *HandshakeManager) AddIndexHostInfo(index uint32, h *HostInfo) { // Complete is a simpler version of CheckAndComplete when we already know we
c.pendingHostMap.AddIndexHostInfo(index, h) // won't have a localIndexId collision because we already have an entry in the
// pendingHostMap
func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
if found && existingHostInfo != nil {
// We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
}
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger().
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
Info("New host shadows existing host remoteIndex")
}
c.mainHostMap.addHostInfo(hostinfo, f)
}
// AddIndexHostInfo generates a unique localIndexId for this HostInfo
// and adds it to the pendingHostMap. Will error if we are unable to generate
// a unique localIndexId
func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.mainHostMap.RLock()
defer c.mainHostMap.RUnlock()
for i := 0; i < 32; i++ {
index, err := generateIndex()
if err != nil {
return err
}
_, inPending := c.pendingHostMap.Indexes[index]
_, inMain := c.mainHostMap.Indexes[index]
if !inMain && !inPending {
h.localIndexId = index
c.pendingHostMap.Indexes[index] = h
return nil
}
}
return errors.New("failed to generate unique localIndexId")
} }
func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) { func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) {

View File

@ -8,12 +8,11 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var indexes []uint32 = []uint32{1000, 2000, 3000, 4000}
//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923} //var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
var ips []uint32 var ips []uint32
func Test_NewHandshakeManagerIndex(t *testing.T) { func Test_NewHandshakeManagerIndex(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@ -26,9 +25,18 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
now := time.Now() now := time.Now()
blah.NextInboundHandshakeTimerTick(now) blah.NextInboundHandshakeTimerTick(now)
var indexes = make([]uint32, 4)
var hostinfo = make([]*HostInfo, len(indexes))
for i := range indexes {
hostinfo[i] = &HostInfo{ConnectionState: &ConnectionState{}}
}
// Add four indexes // Add four indexes
for _, v := range indexes { for i := range indexes {
blah.AddIndex(v, &ConnectionState{}) err := blah.AddIndexHostInfo(hostinfo[i])
assert.NoError(t, err)
indexes[i] = hostinfo[i].localIndexId
blah.InboundHandshakeTimer.Add(indexes[i], time.Second*10)
} }
// Confirm they are in the pending index list // Confirm they are in the pending index list
for _, v := range indexes { for _, v := range indexes {
@ -169,8 +177,11 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
hostinfo := blah.AddVpnIP(vpnIP) hostinfo := blah.AddVpnIP(vpnIP)
// Pretned we have an index too // Pretned we have an index too
blah.AddIndexHostInfo(12341234, hostinfo) err := blah.AddIndexHostInfo(hostinfo)
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234)) assert.NoError(t, err)
blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
assert.NotZero(t, hostinfo.localIndexId)
assert.Contains(t, blah.pendingHostMap.Indexes, hostinfo.localIndexId)
// Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending // Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending
// but not main hostmap // but not main hostmap
@ -216,7 +227,10 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
now := time.Now() now := time.Now()
blah.NextInboundHandshakeTimerTick(now) blah.NextInboundHandshakeTimerTick(now)
hostinfo, _ := blah.AddIndex(12341234, &ConnectionState{}) hostinfo := &HostInfo{ConnectionState: &ConnectionState{}}
err := blah.AddIndexHostInfo(hostinfo)
assert.NoError(t, err)
blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
// Pretned we have an index too // Pretned we have an index too
blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo) blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo)
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010)) assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010))
@ -229,7 +243,7 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3) next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3)
blah.NextInboundHandshakeTimerTick(next_tick) blah.NextInboundHandshakeTimerTick(next_tick)
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010)) assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234)) assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(hostinfo.localIndexId))
} }
type mockEncWriter struct { type mockEncWriter struct {

View File

@ -166,40 +166,6 @@ func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
} }
} }
func (hm *HostMap) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) {
hm.Lock()
if _, ok := hm.Indexes[index]; !ok {
h := &HostInfo{
ConnectionState: ci,
Remotes: []*HostInfoDest{},
localIndexId: index,
HandshakePacket: make(map[uint8][]byte, 0),
}
hm.Indexes[index] = h
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": false, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap index added")
hm.Unlock()
return h, nil
}
hm.Unlock()
return nil, fmt.Errorf("refusing to overwrite existing index: %d", index)
}
func (hm *HostMap) AddIndexHostInfo(index uint32, h *HostInfo) {
hm.Lock()
h.localIndexId = index
hm.Indexes[index] = h
hm.Unlock()
if l.Level > logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap index added")
}
}
// Only used by pendingHostMap when the remote index is not initially known // Only used by pendingHostMap when the remote index is not initially known
func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) { func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
hm.Lock() hm.Lock()
@ -234,16 +200,12 @@ func (hm *HostMap) DeleteIndex(index uint32) {
hm.Lock() hm.Lock()
hostinfo, ok := hm.Indexes[index] hostinfo, ok := hm.Indexes[index]
if ok { if ok {
hostinfo.Lock()
defer hostinfo.Unlock()
delete(hm.Indexes, index) delete(hm.Indexes, index)
delete(hm.RemoteIndexes, hostinfo.remoteIndexId) delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
// Check if we have an entry under hostId that matches the same hostinfo // Check if we have an entry under hostId that matches the same hostinfo
// instance. Clean it up as well if we do. // instance. Clean it up as well if we do.
var hostinfo2 *HostInfo hostinfo2, ok := hm.Hosts[hostinfo.hostId]
hostinfo2, ok = hm.Hosts[hostinfo.hostId]
if ok && hostinfo2 == hostinfo { if ok && hostinfo2 == hostinfo {
delete(hm.Hosts, hostinfo.hostId) delete(hm.Hosts, hostinfo.hostId)
} }
@ -400,36 +362,26 @@ func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
} }
} }
func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool { // We already have the hm Lock when this is called, so make sure to not call
hm.RLock() // any other methods that might try to grab it again
if i, ok := hm.Hosts[vpnIP]; ok { func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
if i == nil { remoteCert := hostinfo.ConnectionState.peerCert
hm.RUnlock() ip := ip2int(remoteCert.Details.Ips[0].IP)
return false
}
complete := i.HandshakeComplete
hm.RUnlock()
return complete
f.lightHouse.AddRemoteAndReset(ip, hostinfo.remote)
if f.serveDns {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
} }
hm.RUnlock()
return false
}
func (hm *HostMap) CheckHandshakeCompleteIndex(index uint32) bool { hm.Hosts[hostinfo.hostId] = hostinfo
hm.RLock() hm.Indexes[hostinfo.localIndexId] = hostinfo
if i, ok := hm.Indexes[index]; ok { hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if i == nil {
hm.RUnlock()
return false
}
complete := i.HandshakeComplete
hm.RUnlock()
return complete
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
Debug("Hostmap vpnIp added")
} }
hm.RUnlock()
return false
} }
func (hm *HostMap) ClearRemotes(vpnIP uint32) { func (hm *HostMap) ClearRemotes(vpnIP uint32) {

View File

@ -106,7 +106,6 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
case recvError: case recvError:
f.messageMetrics.Rx(header.Type, header.Subtype, 1) f.messageMetrics.Rx(header.Type, header.Subtype, 1)
// TODO: Remove this with recv_error deprecation
f.handleRecvError(addr, header) f.handleRecvError(addr, header)
return return
@ -312,8 +311,6 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
} }
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
// This flag is to stop caring about recv_error from old versions
// This should go away when the old version is gone from prod
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
l.WithField("index", h.RemoteIndex). l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr). WithField("udpAddr", addr).