diff --git a/connection_manager.go b/connection_manager.go index bb38e3d..bc2ce05 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -247,8 +247,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) { if n.intf.lightHouse != nil { n.intf.lightHouse.DeleteVpnIP(vpnIP) } - n.hostMap.DeleteVpnIP(vpnIP) - n.hostMap.DeleteIndex(hostinfo.localIndexId) + n.hostMap.DeleteHostInfo(hostinfo) } else { n.ClearIP(vpnIP) n.ClearPendingDeletion(vpnIP) diff --git a/handshake.go b/handshake.go index 0a10d7b..cd929a6 100644 --- a/handshake.go +++ b/handshake.go @@ -30,7 +30,6 @@ func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Head } if tearDown && newHostinfo != nil { - f.handshakeManager.DeleteIndex(newHostinfo.localIndexId) - f.handshakeManager.DeleteVpnIP(newHostinfo.hostId) + f.handshakeManager.DeleteHostInfo(newHostinfo) } } diff --git a/handshake_ix.go b/handshake_ix.go index f2afb51..f462fbc 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -157,7 +157,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Info("Handshake message received") - hostinfo.remoteIndexId = hs.Details.InitiatorIndex + f.handshakeManager.addRemoteIndexHostInfo(hs.Details.InitiatorIndex, hostinfo) hs.Details.ResponderIndex = myIndex hs.Details.Cert = ci.certState.rawCertificateNoKey @@ -245,11 +245,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ WithField("fingerprint", fingerprint). WithField("action", "removing stale index"). WithField("index", ho.localIndexId). + WithField("remoteIndex", ho.remoteIndexId). Debug("Handshake processing") - f.hostMap.DeleteIndex(ho.localIndexId) + f.hostMap.DeleteHostInfo(ho) } - f.hostMap.AddIndexHostInfo(hostinfo.localIndexId, hostinfo) f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo) hostinfo.handshakeComplete() @@ -363,12 +363,12 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ WithField("fingerprint", fingerprint). WithField("action", "removing stale index"). WithField("index", ho.localIndexId). + WithField("remoteIndex", ho.remoteIndexId). Debug("Handshake processing") - f.hostMap.DeleteIndex(ho.localIndexId) + f.hostMap.DeleteHostInfo(ho) } f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo) - f.hostMap.AddIndexHostInfo(hostinfo.localIndexId, hostinfo) hostinfo.handshakeComplete() f.metricHandshakes.Update(duration) diff --git a/handshake_manager.go b/handshake_manager.go index ca496bc..1e55453 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -99,10 +99,6 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr } func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) { - index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP) - if err != nil { - return - } hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP) if err != nil { return @@ -172,8 +168,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) } } else { - c.pendingHostMap.DeleteVpnIP(vpnIP) - c.pendingHostMap.DeleteIndex(index) + c.pendingHostMap.DeleteHostInfo(hostinfo) } } @@ -186,12 +181,11 @@ func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) { } index := ep.(uint32) - vpnIP, err := c.pendingHostMap.GetVpnIPByIndex(index) + hostinfo, err := c.pendingHostMap.QueryIndex(index) if err != nil { continue } - c.pendingHostMap.DeleteIndex(index) - c.pendingHostMap.DeleteVpnIP(vpnIP) + c.pendingHostMap.DeleteHostInfo(hostinfo) } } @@ -204,11 +198,6 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo { return hostinfo } -func (c *HandshakeManager) DeleteVpnIP(vpnIP uint32) { - //l.Debugln("Deleting pending vpn ip :", IntIp(vpnIP)) - c.pendingHostMap.DeleteVpnIP(vpnIP) -} - func (c *HandshakeManager) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) { hostinfo, err := c.pendingHostMap.AddIndex(index, ci) if err != nil { @@ -223,9 +212,13 @@ func (c *HandshakeManager) AddIndexHostInfo(index uint32, h *HostInfo) { c.pendingHostMap.AddIndexHostInfo(index, h) } -func (c *HandshakeManager) DeleteIndex(index uint32) { - //l.Debugln("Deleting pending index :", index) - c.pendingHostMap.DeleteIndex(index) +func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) { + c.pendingHostMap.addRemoteIndexHostInfo(index, h) +} + +func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { + //l.Debugln("Deleting pending hostinfo :", hostinfo) + c.pendingHostMap.DeleteHostInfo(hostinfo) } func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) { @@ -241,13 +234,19 @@ func (c *HandshakeManager) EmitStats() { func generateIndex() (uint32, error) { b := make([]byte, 4) - _, err := rand.Read(b) - if err != nil { - l.Errorln(err) - return 0, err + + // Let zero mean we don't know the ID, so don't generate zero + var index uint32 + for index == 0 { + _, err := rand.Read(b) + if err != nil { + l.Errorln(err) + return 0, err + } + + index = binary.BigEndian.Uint32(b) } - index := binary.BigEndian.Uint32(b) if l.Level >= logrus.DebugLevel { l.WithField("index", index). Debug("Generated index") diff --git a/hostmap.go b/hostmap.go index 0069319..c15c3cf 100644 --- a/hostmap.go +++ b/hostmap.go @@ -25,6 +25,7 @@ type HostMap struct { sync.RWMutex //Because we concurrently read and write to our maps name string Indexes map[uint32]*HostInfo + RemoteIndexes map[uint32]*HostInfo Hosts map[uint32]*HostInfo preferredRanges []*net.IPNet vpnCIDR *net.IPNet @@ -77,9 +78,11 @@ type Probe struct { func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { h := map[uint32]*HostInfo{} i := map[uint32]*HostInfo{} + r := map[uint32]*HostInfo{} m := HostMap{ name: name, Indexes: i, + RemoteIndexes: r, Hosts: h, preferredRanges: preferredRanges, vpnCIDR: vpnCIDR, @@ -94,10 +97,12 @@ func (hm *HostMap) EmitStats(name string) { hm.RLock() hostLen := len(hm.Hosts) indexLen := len(hm.Indexes) + remoteIndexLen := len(hm.RemoteIndexes) hm.RUnlock() metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen)) metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen)) + metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen)) } func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) { @@ -111,17 +116,6 @@ func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) { return 0, errors.New("vpn IP not found") } -func (hm *HostMap) GetVpnIPByIndex(index uint32) (uint32, error) { - hm.RLock() - if i, ok := hm.Indexes[index]; ok { - vpnIP := i.hostId - hm.RUnlock() - return vpnIP, nil - } - hm.RUnlock() - return 0, errors.New("vpn IP not found") -} - func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) { hm.Lock() hm.Hosts[ip] = hostinfo @@ -198,10 +192,26 @@ func (hm *HostMap) AddIndexHostInfo(index uint32, h *HostInfo) { } } +// Only used by pendingHostMap when the remote index is not initially known +func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) { + hm.Lock() + h.remoteIndexId = index + hm.RemoteIndexes[index] = h + hm.Unlock() + + if l.Level > logrus.DebugLevel { + l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), + "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). + Debug("Hostmap remoteIndex added") + } +} + func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) { hm.Lock() h.hostId = vpnIP hm.Hosts[vpnIP] = h + hm.Indexes[h.localIndexId] = h + hm.RemoteIndexes[h.remoteIndexId] = h hm.Unlock() if l.Level > logrus.DebugLevel { @@ -225,6 +235,29 @@ func (hm *HostMap) DeleteIndex(index uint32) { } } +func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) { + hm.Lock() + delete(hm.Hosts, hostinfo.hostId) + if len(hm.Hosts) == 0 { + hm.Hosts = map[uint32]*HostInfo{} + } + delete(hm.Indexes, hostinfo.localIndexId) + if len(hm.Indexes) == 0 { + hm.Indexes = map[uint32]*HostInfo{} + } + delete(hm.RemoteIndexes, hostinfo.remoteIndexId) + if len(hm.RemoteIndexes) == 0 { + hm.RemoteIndexes = map[uint32]*HostInfo{} + } + hm.Unlock() + + if l.Level >= logrus.DebugLevel { + l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), + "vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + Debug("Hostmap hostInfo deleted") + } +} + func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) { //TODO: we probably just want ot return bool instead of error, or at least a static error hm.RLock() @@ -237,23 +270,15 @@ func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) { } } -// This function needs to range because we don't keep a map of remote indexes. func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) { hm.RLock() - for _, h := range hm.Indexes { - if h.ConnectionState != nil && h.remoteIndexId == index { - hm.RUnlock() - return h, nil - } + if h, ok := hm.RemoteIndexes[index]; ok { + hm.RUnlock() + return h, nil + } else { + hm.RUnlock() + return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name) } - for _, h := range hm.Hosts { - if h.ConnectionState != nil && h.remoteIndexId == index { - hm.RUnlock() - return h, nil - } - } - hm.RUnlock() - return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name) } func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo { diff --git a/outside.go b/outside.go index 064b0a1..7738e80 100644 --- a/outside.go +++ b/outside.go @@ -138,8 +138,7 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) { f.connectionManager.ClearIP(hostInfo.hostId) f.connectionManager.ClearPendingDeletion(hostInfo.hostId) f.lightHouse.DeleteVpnIP(hostInfo.hostId) - f.hostMap.DeleteVpnIP(hostInfo.hostId) - f.hostMap.DeleteIndex(hostInfo.localIndexId) + f.hostMap.DeleteHostInfo(hostInfo) } func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { @@ -335,17 +334,13 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { return } - id := hostinfo.localIndexId - host := hostinfo.hostId // We delete this host from the main hostmap - f.hostMap.DeleteIndex(id) - f.hostMap.DeleteVpnIP(host) + f.hostMap.DeleteHostInfo(hostinfo) // We also delete it from pending to allow for // fast reconnect. We must null the connectionstate // or a counter reuse may happen hostinfo.ConnectionState = nil - f.handshakeManager.DeleteIndex(id) - f.handshakeManager.DeleteVpnIP(host) + f.handshakeManager.DeleteHostInfo(hostinfo) } /*