diff --git a/connection_manager_test.go b/connection_manager_test.go index 0dd3f7a..15baae2 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -49,9 +49,8 @@ func Test_NewConnectionManagerTest(t *testing.T) { // Add an ip we have established a connection w/ to hostmap hostinfo := nc.hostMap.AddVpnIP(vpnIP) hostinfo.ConnectionState = &ConnectionState{ - certState: cs, - H: &noise.HandshakeState{}, - messageCounter: new(uint64), + certState: cs, + H: &noise.HandshakeState{}, } // We saw traffic out to vpnIP @@ -115,9 +114,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Add an ip we have established a connection w/ to hostmap hostinfo := nc.hostMap.AddVpnIP(vpnIP) hostinfo.ConnectionState = &ConnectionState{ - certState: cs, - H: &noise.HandshakeState{}, - messageCounter: new(uint64), + certState: cs, + H: &noise.HandshakeState{}, } // We saw traffic out to vpnIP diff --git a/connection_state.go b/connection_state.go index 2583745..25cdc58 100644 --- a/connection_state.go +++ b/connection_state.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "encoding/json" "sync" + "sync/atomic" "github.com/flynn/noise" "github.com/slackhq/nebula/cert" @@ -12,17 +13,17 @@ import ( const ReplayWindow = 1024 type ConnectionState struct { - eKey *NebulaCipherState - dKey *NebulaCipherState - H *noise.HandshakeState - certState *CertState - peerCert *cert.NebulaCertificate - initiator bool - messageCounter *uint64 - window *Bits - queueLock sync.Mutex - writeLock sync.Mutex - ready bool + eKey *NebulaCipherState + dKey *NebulaCipherState + H *noise.HandshakeState + certState *CertState + peerCert *cert.NebulaCertificate + initiator bool + atomicMessageCounter uint64 + window *Bits + queueLock sync.Mutex + writeLock sync.Mutex + ready bool } func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { @@ -54,12 +55,11 @@ func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePa // The queue and ready params prevent a counter race that would happen when // sending stored packets and simultaneously accepting new traffic. ci := &ConnectionState{ - H: hs, - initiator: initiator, - window: b, - ready: false, - certState: curCertState, - messageCounter: new(uint64), + H: hs, + initiator: initiator, + window: b, + ready: false, + certState: curCertState, } return ci @@ -69,7 +69,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) { return json.Marshal(m{ "certificate": cs.peerCert, "initiator": cs.initiator, - "message_counter": cs.messageCounter, + "message_counter": atomic.LoadUint64(&cs.atomicMessageCounter), "ready": cs.ready, }) } diff --git a/control.go b/control.go index 8e7eb0c..4964164 100644 --- a/control.go +++ b/control.go @@ -4,6 +4,7 @@ import ( "net" "os" "os/signal" + "sync/atomic" "syscall" "github.com/sirupsen/logrus" @@ -156,7 +157,7 @@ func copyHostInfo(h *HostInfo) ControlHostInfo { RemoteIndex: h.remoteIndexId, RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)), CachedPackets: len(h.packetStore), - MessageCounter: *h.ConnectionState.messageCounter, + MessageCounter: atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter), } if c := h.GetCert(); c != nil { diff --git a/control_test.go b/control_test.go index f3ad7df..ca68c75 100644 --- a/control_test.go +++ b/control_test.go @@ -43,15 +43,13 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) { }, Signature: []byte{1, 2, 1, 2, 1, 3}, } - counter := uint64(0) remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)} hm.Add(ip2int(ipNet.IP), &HostInfo{ remote: remote1, Remotes: remotes, ConnectionState: &ConnectionState{ - peerCert: crt, - messageCounter: &counter, + peerCert: crt, }, remoteIndexId: 200, localIndexId: 201, @@ -62,8 +60,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) { remote: remote1, Remotes: remotes, ConnectionState: &ConnectionState{ - peerCert: nil, - messageCounter: &counter, + peerCert: nil, }, remoteIndexId: 200, localIndexId: 201, diff --git a/handshake_ix.go b/handshake_ix.go index f462fbc..d43804a 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -54,7 +54,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { } header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1) - atomic.AddUint64(ci.messageCounter, 1) + atomic.AddUint64(&ci.atomicMessageCounter, 1) msg, _, _, err := ci.H.WriteMessage(header, hsBytes) if err != nil { @@ -99,26 +99,31 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ } hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex) - if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) { - if msg, ok := hostinfo.HandshakePacket[2]; ok { - f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) - err := f.outside.WriteTo(msg, addr) - if err != nil { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - WithError(err).Error("Failed to send handshake message") - } else { - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - Info("Handshake message sent") - } - return false - } + if hostinfo != nil { + hostinfo.RLock() + defer hostinfo.RUnlock() - l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cached", true). - WithField("packets", hostinfo.HandshakePacket). - Error("Seen this handshake packet already but don't have a cached packet to return") + if bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) { + if msg, ok := hostinfo.HandshakePacket[2]; ok { + f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) + err := f.outside.WriteTo(msg, addr) + if err != nil { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). + WithError(err).Error("Failed to send handshake message") + } else { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). + Info("Handshake message sent") + } + return false + } + + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cached", true). + WithField("packets", hostinfo.HandshakePacket). + Error("Seen this handshake packet already but don't have a cached packet to return") + } } remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) @@ -150,6 +155,9 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ return true } + hostinfo.Lock() + defer hostinfo.Unlock() + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -272,6 +280,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ if hostinfo == nil { return true } + hostinfo.Lock() + defer hostinfo.Unlock() if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) { l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). diff --git a/handshake_manager.go b/handshake_manager.go index 9c4b445..2223e74 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -103,6 +103,8 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT if err != nil { return } + hostinfo.Lock() + defer hostinfo.Unlock() // If we haven't finished the handshake and we haven't hit max retries, query // lighthouse and then send the handshake packet again. diff --git a/hostmap.go b/hostmap.go index 391cdfb..e04c14d 100644 --- a/hostmap.go +++ b/hostmap.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" "time" "github.com/rcrowley/go-metrics" @@ -35,6 +36,8 @@ type HostMap struct { } type HostInfo struct { + sync.RWMutex + remote *udpAddr Remotes []*HostInfoDest promoteCounter uint32 @@ -231,6 +234,9 @@ func (hm *HostMap) DeleteIndex(index uint32) { hm.Lock() hostinfo, ok := hm.Indexes[index] if ok { + hostinfo.Lock() + defer hostinfo.Unlock() + delete(hm.Indexes, index) delete(hm.RemoteIndexes, hostinfo.remoteIndexId) @@ -513,8 +519,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) return } - i.promoteCounter++ - if i.promoteCounter%PromoteEvery == 0 { + if atomic.AddUint32(&i.promoteCounter, 1)&PromoteEvery == 0 { // return early if we are already on a preferred remote rIP := udp2ip(i.remote) for _, l := range preferredRanges { @@ -615,10 +620,12 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac copy(tempPacket, packet) //l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket) i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket}) - i.logger(). - WithField("length", len(i.packetStore)). - WithField("stored", true). - Debugf("Packet store") + if l.Level >= logrus.DebugLevel { + i.logger(). + WithField("length", len(i.packetStore)). + WithField("stored", true). + Debugf("Packet store") + } } else if l.Level >= logrus.DebugLevel { i.logger(). @@ -638,7 +645,7 @@ func (i *HostInfo) handshakeComplete() { i.HandshakeComplete = true //TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen. // Clamping it to 2 gets us out of the woods for now - *i.ConnectionState.messageCounter = 2 + atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2) i.logger().Debugf("Sending %d stored packets", len(i.packetStore)) nb := make([]byte, 12, 12) out := make([]byte, mtu) diff --git a/inside.go b/inside.go index 921fae6..1e0632a 100644 --- a/inside.go +++ b/inside.go @@ -103,16 +103,21 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo { // If we have already created the handshake packet, we don't want to call the function at all. if !hostinfo.HandshakeReady { - ixHandshakeStage0(f, vpnIp, hostinfo) - // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. - //xx_handshakeStage0(f, ip, hostinfo) + hostinfo.Lock() + defer hostinfo.Unlock() - // If this is a static host, we don't need to wait for the HostQueryReply - // We can trigger the handshake right now - if _, ok := f.lightHouse.staticList[vpnIp]; ok { - select { - case f.handshakeManager.trigger <- vpnIp: - default: + if !hostinfo.HandshakeReady { + ixHandshakeStage0(f, vpnIp, hostinfo) + // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. + //xx_handshakeStage0(f, ip, hostinfo) + + // If this is a static host, we don't need to wait for the HostQueryReply + // We can trigger the handshake right now + if _, ok := f.lightHouse.staticList[vpnIp]; ok { + select { + case f.handshakeManager.trigger <- vpnIp: + default: + } } } } @@ -139,8 +144,8 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, return } - f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0) - if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 { + messageCounter := f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0) + if f.lightHouse != nil && messageCounter%5000 == 0 { f.lightHouse.Query(fp.RemoteIP, f) } } @@ -223,7 +228,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, var err error //TODO: enable if we do more than 1 tun queue //ci.writeLock.Lock() - c := atomic.AddUint64(ci.messageCounter, 1) + c := atomic.AddUint64(&ci.atomicMessageCounter, 1) //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c) @@ -247,7 +252,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, if err != nil { hostinfo.logger().WithError(err). WithField("udpAddr", remote).WithField("counter", c). - WithField("attemptedCounter", ci.messageCounter). + WithField("attemptedCounter", c). Error("Failed to encrypt outgoing packet") return c } diff --git a/interface.go b/interface.go index 377dde0..e90fef0 100644 --- a/interface.go +++ b/interface.go @@ -134,11 +134,6 @@ func (f *Interface) run() { metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) - // Launch n queues to read packets from udp - for i := 0; i < f.routines; i++ { - go f.listenOut(i) - } - // Prepare n tun queues var reader io.ReadWriteCloser = f.inside for i := 0; i < f.routines; i++ { @@ -155,6 +150,11 @@ func (f *Interface) run() { l.Fatal(err) } + // Launch n queues to read packets from udp + for i := 0; i < f.routines; i++ { + go f.listenOut(i) + } + // Launch n queues to read packets from tun dev for i := 0; i < f.routines; i++ { go f.listenIn(f.readers[i], i) diff --git a/ssh.go b/ssh.go index 9e409bc..aff63ef 100644 --- a/ssh.go +++ b/ssh.go @@ -11,6 +11,7 @@ import ( "reflect" "runtime/pprof" "strings" + "sync/atomic" "syscall" "github.com/sirupsen/logrus" @@ -353,7 +354,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error } if v.ConnectionState != nil { - h["messageCounter"] = v.ConnectionState.messageCounter + h["messageCounter"] = atomic.LoadUint64(&v.ConnectionState.atomicMessageCounter) } d[x] = h