Routine-local conntrack cache (#391)

Previously, every packet we see gets a lock on the conntrack table and updates it. When running with multiple routines, this can cause heavy lock contention and limit our ability for the threads to run independently. This change caches reads from the conntrack table for a very short period of time to reduce this lock contention. This cache will currently default to disabled unless you are running with multiple routines, in which case the default cache delay will be 1 second. This means that entries in the conntrack table may be up to 1 second out of date and remain in a routine local cache for up to 1 second longer than the global table.

Instead of calling time.Now() for every packet, this cache system relies on a tick thread that updates the current cache "version" each tick. Every packet we check if the cache version is out of date, and reset the cache if so.
This commit is contained in:
Wade Simmons 2021-03-01 19:52:17 -05:00 committed by GitHub
parent d232ccbfab
commit 2a4beb41b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 118 additions and 31 deletions

View File

@ -12,6 +12,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
@ -372,9 +373,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming, h, caPool) {
if f.inConns(packet, fp, incoming, h, caPool, localCache) {
return nil
}
@ -426,7 +427,12 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
}
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
}
}
conntrack := f.Conntrack
conntrack.Lock()
@ -494,6 +500,10 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
conntrack.Unlock()
if localCache != nil {
localCache[fp] = struct{}{}
}
return true
}
@ -923,3 +933,54 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
c.Seq = 0
return true
}
// ConntrackCache is used as a local routine cache to know if a given flow
// has been seen in the conntrack table.
type ConntrackCache map[FirewallPacket]struct{}
type ConntrackCacheTicker struct {
cacheV uint64
cacheTick uint64
cache ConntrackCache
}
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
if d == 0 {
return nil
}
c := &ConntrackCacheTicker{
cache: ConntrackCache{},
}
go c.tick(d)
return c
}
func (c *ConntrackCacheTicker) tick(d time.Duration) {
for {
time.Sleep(d)
atomic.AddUint64(&c.cacheTick, 1)
}
}
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get() ConntrackCache {
if c == nil {
return nil
}
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if l.GetLevel() == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
}
c.cache = make(ConntrackCache, ll)
}
}
return c.cache
}

View File

@ -182,44 +182,44 @@ func TestFirewall_Drop(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
// Allow outbound because conntrack
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteIP
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
@ -370,10 +370,10 @@ func TestFirewall_Drop2(t *testing.T) {
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp), ErrNoMatchingRule)
assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule)
// c has the proper groups
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
}
func TestFirewall_Drop3(t *testing.T) {
@ -454,13 +454,13 @@ func TestFirewall_Drop3(t *testing.T) {
cp := cert.NewCAPool()
// c1 should pass because host match
assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil))
// c2 should pass because ca sha match
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil))
// c3 should fail because no match
resetConntrack(fw)
assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_DropConntrackReload(t *testing.T) {
@ -505,12 +505,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
// Allow outbound because conntrack
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
oldFw := fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
@ -519,7 +519,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
oldFw = fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
@ -528,7 +528,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
}
func BenchmarkLookup(b *testing.B) {

View File

@ -7,7 +7,7 @@ import (
"github.com/sirupsen/logrus"
)
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int) {
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
@ -52,7 +52,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
ci.queueLock.Unlock()
}
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs)
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs, localCache)
if dropReason == nil {
mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
if f.lightHouse != nil && mc%5000 == 0 {
@ -129,7 +129,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs)
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
if dropReason != nil {
if l.Level >= logrus.DebugLevel {
l.WithField("fwPacket", fp).

View File

@ -40,6 +40,8 @@ type InterfaceConfig struct {
routines int
MessageMetrics *MessageMetrics
version string
ConntrackCacheTimeout time.Duration
}
type Interface struct {
@ -61,6 +63,8 @@ type Interface struct {
routines int
version string
conntrackCacheTimeout time.Duration
writers []*udpConn
readers []io.ReadWriteCloser
@ -102,6 +106,8 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
writers: make([]*udpConn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
conntrackCacheTimeout: c.ConntrackCacheTimeout,
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics,
}
@ -173,6 +179,8 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
fwPacket := &FirewallPacket{}
nb := make([]byte, 12, 12)
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := reader.Read(packet)
if err != nil {
@ -181,7 +189,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
os.Exit(2)
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i)
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
}
}

14
main.go
View File

@ -117,6 +117,18 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
}
// EXPERIMENTAL
// Intentionally not documented yet while we do more testing and determine
// a good default value.
conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") {
// Use a different default if we are running with multiple routines
conntrackCacheTimeout = 1 * time.Second
}
if conntrackCacheTimeout > 0 {
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
}
var tun Inside
if !configTest {
config.CatchHUP()
@ -359,6 +371,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
routines: routines,
MessageMetrics: messageMetrics,
version: buildVersion,
ConntrackCacheTimeout: conntrackCacheTimeout,
}
switch ifConfig.Cipher {

View File

@ -17,7 +17,7 @@ const (
minFwPacketLen = 4
)
func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int) {
func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int, localCache ConntrackCache) {
err := header.Parse(packet)
if err != nil {
// TODO: best if we return this and let caller log
@ -45,7 +45,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
return
}
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q)
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q, localCache)
// Fallthrough to the bottom to record incoming traffic
@ -257,7 +257,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil
}
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int) {
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int, localCache ConntrackCache) {
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
@ -281,7 +281,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return
}
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs)
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
if dropReason != nil {
if l.Level >= logrus.DebugLevel {
hostinfo.logger().WithField("fwPacket", fwPacket).

View File

@ -115,6 +115,8 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
// Just read one packet at a time
n, rua, err := u.ReadFromUDP(buffer)
@ -124,7 +126,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
}
udpAddr.UDPAddr = *rua
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q)
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get())
}
}

View File

@ -174,6 +174,8 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
read = u.ReadSingle
}
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := read(msgs)
if err != nil {
@ -186,7 +188,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
udpAddr.IP = binary.BigEndian.Uint32(names[i][4:8])
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q)
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get())
}
}
}