From f3a6d8d990f2a714c2fed8580fe37024470f2d0d Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Fri, 31 Jul 2020 18:53:36 -0400 Subject: [PATCH] Preserve conntrack table during firewall rules reload (SIGHUP) (#233) Currently, we drop the conntrack table when firewall rules change during a SIGHUP reload. This means responses to inflight HTTP requests can be dropped, among other issues. This change copies the conntrack table over to the new firewall (it holds the conntrack mutex lock during this process, to be safe). This change also records which firewall rules hash each conntrack entry used, so that we can re-verify the rules after the new firewall has been loaded. --- firewall.go | 108 ++++++++++++++++++++++++++++++++++++----------- firewall_test.go | 106 ++++++++++++++++++++++++++++++++++++++-------- interface.go | 17 ++++++++ 3 files changed, 189 insertions(+), 42 deletions(-) diff --git a/firewall.go b/firewall.go index fd25098..91638e1 100644 --- a/firewall.go +++ b/firewall.go @@ -15,6 +15,7 @@ import ( "time" "github.com/rcrowley/go-metrics" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" ) @@ -37,13 +38,19 @@ type FirewallInterface interface { type conn struct { Expires time.Time // Time when this conntrack entry will expire - Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set + Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack + + // record why the original connection passed the firewall, so we can re-validate + // after ruleset changes. Note, rulesVersion is a uint16 so that these two + // fields pack for free after the uint32 above + incoming bool + rulesVersion uint16 } // TODO: need conntrack max tracked connections handling type Firewall struct { - Conns map[FirewallPacket]*conn + Conntrack *FirewallConntrack InRules *FirewallTable OutRules *FirewallTable @@ -54,18 +61,23 @@ type Firewall struct { UDPTimeout time.Duration //linux: 180s max DefaultTimeout time.Duration //linux: 600s - TimerWheel *TimerWheel - // Used to ensure we don't emit local packets for ips we don't own localIps *CIDRTree - connMutex sync.Mutex - rules string + rules string + rulesVersion uint16 trackTCPRTT bool metricTCPRTT metrics.Histogram } +type FirewallConntrack struct { + sync.Mutex + + Conns map[FirewallPacket]*conn + TimerWheel *TimerWheel +} + type FirewallTable struct { TCP firewallPort UDP firewallPort @@ -171,10 +183,12 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N } return &Firewall{ - Conns: make(map[FirewallPacket]*conn), + Conntrack: &FirewallConntrack{ + Conns: make(map[FirewallPacket]*conn), + TimerWheel: NewTimerWheel(min, max), + }, InRules: newFirewallTable(), OutRules: newFirewallTable(), - TimerWheel: NewTimerWheel(min, max), TCPTimeout: tcpTimeout, UDPTimeout: UDPTimeout, DefaultTimeout: defaultTimeout, @@ -354,7 +368,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // returns nil if the packet should not be dropped. func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error { // Check if we spoke to this tuple, if we did then allow this packet - if f.inConns(packet, fp, incoming) { + if f.inConns(packet, fp, incoming, h, caPool) { return nil } @@ -398,26 +412,66 @@ func (f *Firewall) Destroy() { } func (f *Firewall) EmitStats() { - conntrackCount := len(f.Conns) + conntrack := f.Conntrack + conntrack.Lock() + conntrackCount := len(conntrack.Conns) + conntrack.Unlock() metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount)) + metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion)) } -func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool { - f.connMutex.Lock() +func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool { + conntrack := f.Conntrack + conntrack.Lock() // Purge every time we test - ep, has := f.TimerWheel.Purge() + ep, has := conntrack.TimerWheel.Purge() if has { f.evict(ep) } - c, ok := f.Conns[fp] + c, ok := conntrack.Conns[fp] if !ok { - f.connMutex.Unlock() + conntrack.Unlock() return false } + if c.rulesVersion != f.rulesVersion { + // This conntrack entry was for an older rule set, validate + // it still passes with the current rule set + table := f.OutRules + if c.incoming { + table = f.InRules + } + + // We now know which firewall table to check against + if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) { + if l.Level >= logrus.DebugLevel { + h.logger(). + WithField("fwPacket", fp). + WithField("incoming", c.incoming). + WithField("rulesVersion", f.rulesVersion). + WithField("oldRulesVersion", c.rulesVersion). + Debugln("dropping old conntrack entry, does not match new ruleset") + } + delete(conntrack.Conns, fp) + conntrack.Unlock() + return false + } + + if l.Level >= logrus.DebugLevel { + h.logger(). + WithField("fwPacket", fp). + WithField("incoming", c.incoming). + WithField("rulesVersion", f.rulesVersion). + WithField("oldRulesVersion", c.rulesVersion). + Debugln("keeping old conntrack entry, does match new ruleset") + } + + c.rulesVersion = f.rulesVersion + } + switch fp.Protocol { case fwProtoTCP: c.Expires = time.Now().Add(f.TCPTimeout) @@ -432,7 +486,7 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool c.Expires = time.Now().Add(f.DefaultTimeout) } - f.connMutex.Unlock() + conntrack.Unlock() return true } @@ -453,14 +507,19 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) { timeout = f.DefaultTimeout } - f.connMutex.Lock() - if _, ok := f.Conns[fp]; !ok { - f.TimerWheel.Add(fp, timeout) + conntrack := f.Conntrack + conntrack.Lock() + if _, ok := conntrack.Conns[fp]; !ok { + conntrack.TimerWheel.Add(fp, timeout) } + // Record which rulesVersion allowed this connection, so we can retest after + // firewall reload + c.incoming = incoming + c.rulesVersion = f.rulesVersion c.Expires = time.Now().Add(timeout) - f.Conns[fp] = c - f.connMutex.Unlock() + conntrack.Conns[fp] = c + conntrack.Unlock() } // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel @@ -468,7 +527,8 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) { func (f *Firewall) evict(p FirewallPacket) { //TODO: report a stat if the tcp rtt tracking was never resolved? // Are we still tracking this conn? - t, ok := f.Conns[p] + conntrack := f.Conntrack + t, ok := conntrack.Conns[p] if !ok { return } @@ -477,12 +537,12 @@ func (f *Firewall) evict(p FirewallPacket) { // Timeout is in the future, re-add the timer if newT > 0 { - f.TimerWheel.Add(p, newT) + conntrack.TimerWheel.Add(p, newT) return } // This conn is done - delete(f.Conns, p) + delete(conntrack.Conns, p) } func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { diff --git a/firewall_test.go b/firewall_test.go index d7ca789..8068c8a 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -17,37 +17,39 @@ import ( func TestNewFirewall(t *testing.T) { c := &cert.NebulaCertificate{} fw := NewFirewall(time.Second, time.Minute, time.Hour, c) - assert.NotNil(t, fw.Conns) + conntrack := fw.Conntrack + assert.NotNil(t, conntrack) + assert.NotNil(t, conntrack.Conns) + assert.NotNil(t, conntrack.TimerWheel) assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) - assert.NotNil(t, fw.TimerWheel) assert.Equal(t, time.Second, fw.TCPTimeout) assert.Equal(t, time.Minute, fw.UDPTimeout) assert.Equal(t, time.Hour, fw.DefaultTimeout) - assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) - assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) - assert.Equal(t, 3601, fw.TimerWheel.wheelLen) + assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) + assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) + assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) fw = NewFirewall(time.Second, time.Hour, time.Minute, c) - assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) - assert.Equal(t, 3601, fw.TimerWheel.wheelLen) + assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) + assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) fw = NewFirewall(time.Hour, time.Second, time.Minute, c) - assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) - assert.Equal(t, 3601, fw.TimerWheel.wheelLen) + assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) + assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) fw = NewFirewall(time.Hour, time.Minute, time.Second, c) - assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) - assert.Equal(t, 3601, fw.TimerWheel.wheelLen) + assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) + assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) fw = NewFirewall(time.Minute, time.Hour, time.Second, c) - assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) - assert.Equal(t, 3601, fw.TimerWheel.wheelLen) + assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) + assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) fw = NewFirewall(time.Minute, time.Second, time.Hour, c) - assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) - assert.Equal(t, 3601, fw.TimerWheel.wheelLen) + assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) + assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) } func TestFirewall_AddRule(t *testing.T) { @@ -461,6 +463,74 @@ func TestFirewall_Drop3(t *testing.T) { assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule) } +func TestFirewall_DropConntrackReload(t *testing.T) { + ob := &bytes.Buffer{} + out := l.Out + l.SetOutput(ob) + defer l.SetOutput(out) + + p := FirewallPacket{ + ip2int(net.IPv4(1, 2, 3, 4)), + ip2int(net.IPv4(1, 2, 3, 4)), + 10, + 90, + fwProtoUDP, + false, + } + + ipNet := net.IPNet{ + IP: net.IPv4(1, 2, 3, 4), + Mask: net.IPMask{255, 255, 255, 0}, + } + + c := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host1", + Ips: []*net.IPNet{&ipNet}, + Groups: []string{"default-group"}, + InvertedGroups: map[string]struct{}{"default-group": {}}, + Issuer: "signer-shasum", + }, + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c, + }, + hostId: ip2int(ipNet.IP), + } + h.CreateRemoteCIDR(&c) + + fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) + cp := cert.NewCAPool() + + // Drop outbound + assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp)) + // Allow outbound because conntrack + assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp)) + + oldFw := fw + fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", "")) + fw.Conntrack = oldFw.Conntrack + fw.rulesVersion = oldFw.rulesVersion + 1 + + // Allow outbound because conntrack and new rules allow port 10 + assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp)) + + oldFw = fw + fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", "")) + fw.Conntrack = oldFw.Conntrack + fw.rulesVersion = oldFw.rulesVersion + 1 + + // Drop outbound because conntrack doesn't match new ruleset + assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule) +} + func BenchmarkLookup(b *testing.B) { ml := func(m map[string]struct{}, a [][]string) { for n := 0; n < b.N; n++ { @@ -861,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end } func resetConntrack(fw *Firewall) { - fw.connMutex.Lock() - fw.Conns = map[FirewallPacket]*conn{} - fw.connMutex.Unlock() + fw.Conntrack.Lock() + fw.Conntrack.Conns = map[FirewallPacket]*conn{} + fw.Conntrack.Unlock() } diff --git a/interface.go b/interface.go index 5739ea0..95caa12 100644 --- a/interface.go +++ b/interface.go @@ -219,11 +219,28 @@ func (f *Interface) reloadFirewall(c *Config) { } oldFw := f.firewall + conntrack := oldFw.Conntrack + conntrack.Lock() + defer conntrack.Unlock() + + fw.rulesVersion = oldFw.rulesVersion + 1 + // If rulesVersion is back to zero, we have wrapped all the way around. Be + // safe and just reset conntrack in this case. + if fw.rulesVersion == 0 { + l.WithField("firewallHash", fw.GetRuleHash()). + WithField("oldFirewallHash", oldFw.GetRuleHash()). + WithField("rulesVersion", fw.rulesVersion). + Warn("firewall rulesVersion has overflowed, resetting conntrack") + } else { + fw.Conntrack = conntrack + } + f.firewall = fw oldFw.Destroy() l.WithField("firewallHash", fw.GetRuleHash()). WithField("oldFirewallHash", oldFw.GetRuleHash()). + WithField("rulesVersion", fw.rulesVersion). Info("New firewall has been installed") }