From 2a4beb41b944d8dadeda212fb821fcc914e5cf08 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 1 Mar 2021 19:52:17 -0500 Subject: [PATCH] Routine-local conntrack cache (#391) Previously, every packet we see gets a lock on the conntrack table and updates it. When running with multiple routines, this can cause heavy lock contention and limit our ability for the threads to run independently. This change caches reads from the conntrack table for a very short period of time to reduce this lock contention. This cache will currently default to disabled unless you are running with multiple routines, in which case the default cache delay will be 1 second. This means that entries in the conntrack table may be up to 1 second out of date and remain in a routine local cache for up to 1 second longer than the global table. Instead of calling time.Now() for every packet, this cache system relies on a tick thread that updates the current cache "version" each tick. Every packet we check if the cache version is out of date, and reset the cache if so. --- firewall.go | 67 +++++++++++++++++++++++++++++++++++++++++++++--- firewall_test.go | 36 +++++++++++++------------- inside.go | 6 ++--- interface.go | 10 +++++++- main.go | 14 ++++++++++ outside.go | 8 +++--- udp_generic.go | 4 ++- udp_linux.go | 4 ++- 8 files changed, 118 insertions(+), 31 deletions(-) diff --git a/firewall.go b/firewall.go index 42919fc..f09a701 100644 --- a/firewall.go +++ b/firewall.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/rcrowley/go-metrics" @@ -372,9 +373,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error { +func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error { // Check if we spoke to this tuple, if we did then allow this packet - if f.inConns(packet, fp, incoming, h, caPool) { + if f.inConns(packet, fp, incoming, h, caPool, localCache) { return nil } @@ -426,7 +427,12 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion)) } -func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool { +func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool { + if localCache != nil { + if _, ok := localCache[fp]; ok { + return true + } + } conntrack := f.Conntrack conntrack.Lock() @@ -494,6 +500,10 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H conntrack.Unlock() + if localCache != nil { + localCache[fp] = struct{}{} + } + return true } @@ -923,3 +933,54 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool { c.Seq = 0 return true } + +// ConntrackCache is used as a local routine cache to know if a given flow +// has been seen in the conntrack table. +type ConntrackCache map[FirewallPacket]struct{} + +type ConntrackCacheTicker struct { + cacheV uint64 + cacheTick uint64 + + cache ConntrackCache +} + +func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { + if d == 0 { + return nil + } + + c := &ConntrackCacheTicker{ + cache: ConntrackCache{}, + } + + go c.tick(d) + + return c +} + +func (c *ConntrackCacheTicker) tick(d time.Duration) { + for { + time.Sleep(d) + atomic.AddUint64(&c.cacheTick, 1) + } +} + +// Get checks if the cache ticker has moved to the next version before returning +// the map. If it has moved, we reset the map. +func (c *ConntrackCacheTicker) Get() ConntrackCache { + if c == nil { + return nil + } + if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV { + c.cacheV = tick + if ll := len(c.cache); ll > 0 { + if l.GetLevel() == logrus.DebugLevel { + l.WithField("len", ll).Debug("resetting conntrack cache") + } + c.cache = make(ConntrackCache, ll) + } + } + + return c.cache +} diff --git a/firewall_test.go b/firewall_test.go index 8068c8a..3995e8d 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -182,44 +182,44 @@ func TestFirewall_Drop(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule) + assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteIP p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10)) - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteIP = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule) + assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) - assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule) + assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -370,10 +370,10 @@ func TestFirewall_Drop2(t *testing.T) { cp := cert.NewCAPool() // h1/c1 lacks the proper groups - assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp), ErrNoMatchingRule) + assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -454,13 +454,13 @@ func TestFirewall_Drop3(t *testing.T) { cp := cert.NewCAPool() // c1 should pass because host match - assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) - assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule) + assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -505,12 +505,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule) + assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) oldFw := fw fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) @@ -519,7 +519,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp)) + assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) oldFw = fw fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) @@ -528,7 +528,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { fw.rulesVersion = oldFw.rulesVersion + 1 // Drop outbound because conntrack doesn't match new ruleset - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule) + assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule) } func BenchmarkLookup(b *testing.B) { diff --git a/inside.go b/inside.go index 6192a1c..302b22b 100644 --- a/inside.go +++ b/inside.go @@ -7,7 +7,7 @@ import ( "github.com/sirupsen/logrus" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int) { +func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) @@ -52,7 +52,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, ci.queueLock.Unlock() } - dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs) + dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs, localCache) if dropReason == nil { mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q) if f.lightHouse != nil && mc%5000 == 0 { @@ -129,7 +129,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs) + dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil) if dropReason != nil { if l.Level >= logrus.DebugLevel { l.WithField("fwPacket", fp). diff --git a/interface.go b/interface.go index 825ba97..d17f6a8 100644 --- a/interface.go +++ b/interface.go @@ -40,6 +40,8 @@ type InterfaceConfig struct { routines int MessageMetrics *MessageMetrics version string + + ConntrackCacheTimeout time.Duration } type Interface struct { @@ -61,6 +63,8 @@ type Interface struct { routines int version string + conntrackCacheTimeout time.Duration + writers []*udpConn readers []io.ReadWriteCloser @@ -102,6 +106,8 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) { writers: make([]*udpConn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), + conntrackCacheTimeout: c.ConntrackCacheTimeout, + metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), messageMetrics: c.MessageMetrics, } @@ -173,6 +179,8 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { fwPacket := &FirewallPacket{} nb := make([]byte, 12, 12) + conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout) + for { n, err := reader.Read(packet) if err != nil { @@ -181,7 +189,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { os.Exit(2) } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i) + f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get()) } } diff --git a/main.go b/main.go index 0800ffc..2f81fac 100644 --- a/main.go +++ b/main.go @@ -117,6 +117,18 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } } + // EXPERIMENTAL + // Intentionally not documented yet while we do more testing and determine + // a good default value. + conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0) + if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") { + // Use a different default if we are running with multiple routines + conntrackCacheTimeout = 1 * time.Second + } + if conntrackCacheTimeout > 0 { + l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache") + } + var tun Inside if !configTest { config.CatchHUP() @@ -359,6 +371,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L routines: routines, MessageMetrics: messageMetrics, version: buildVersion, + + ConntrackCacheTimeout: conntrackCacheTimeout, } switch ifConfig.Cipher { diff --git a/outside.go b/outside.go index e0f9aaa..75f4eba 100644 --- a/outside.go +++ b/outside.go @@ -17,7 +17,7 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int) { +func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int, localCache ConntrackCache) { err := header.Parse(packet) if err != nil { // TODO: best if we return this and let caller log @@ -45,7 +45,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, return } - f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q) + f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q, localCache) // Fallthrough to the bottom to record incoming traffic @@ -257,7 +257,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return out, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int) { +func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int, localCache ConntrackCache) { var err error out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb) @@ -281,7 +281,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return } - dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs) + dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache) if dropReason != nil { if l.Level >= logrus.DebugLevel { hostinfo.logger().WithField("fwPacket", fwPacket). diff --git a/udp_generic.go b/udp_generic.go index 5a1d204..2de6e29 100644 --- a/udp_generic.go +++ b/udp_generic.go @@ -115,6 +115,8 @@ func (u *udpConn) ListenOut(f *Interface, q int) { lhh := f.lightHouse.NewRequestHandler() + conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout) + for { // Just read one packet at a time n, rua, err := u.ReadFromUDP(buffer) @@ -124,7 +126,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) { } udpAddr.UDPAddr = *rua - f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q) + f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get()) } } diff --git a/udp_linux.go b/udp_linux.go index 69eee31..dbdad2c 100644 --- a/udp_linux.go +++ b/udp_linux.go @@ -174,6 +174,8 @@ func (u *udpConn) ListenOut(f *Interface, q int) { read = u.ReadSingle } + conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout) + for { n, err := read(msgs) if err != nil { @@ -186,7 +188,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) { udpAddr.IP = binary.BigEndian.Uint32(names[i][4:8]) udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q) + f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get()) } } }