Preserve conntrack table during firewall rules reload (SIGHUP) (#233)

Currently, we drop the conntrack table when firewall rules change during a SIGHUP reload. This means responses to inflight HTTP requests can be dropped, among other issues. This change copies the conntrack table over to the new firewall (it holds the conntrack mutex lock during this process, to be safe).

This change also records which firewall rules hash each conntrack entry used, so that we can re-verify the rules after the new firewall has been loaded.
This commit is contained in:
Wade Simmons 2020-07-31 18:53:36 -04:00 committed by GitHub
parent 9b06748506
commit f3a6d8d990
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 189 additions and 42 deletions

View File

@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
) )
@ -37,13 +38,19 @@ type FirewallInterface interface {
type conn struct { type conn struct {
Expires time.Time // Time when this conntrack entry will expire Expires time.Time // Time when this conntrack entry will expire
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
// record why the original connection passed the firewall, so we can re-validate
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
// fields pack for free after the uint32 above
incoming bool
rulesVersion uint16
} }
// TODO: need conntrack max tracked connections handling // TODO: need conntrack max tracked connections handling
type Firewall struct { type Firewall struct {
Conns map[FirewallPacket]*conn Conntrack *FirewallConntrack
InRules *FirewallTable InRules *FirewallTable
OutRules *FirewallTable OutRules *FirewallTable
@ -54,18 +61,23 @@ type Firewall struct {
UDPTimeout time.Duration //linux: 180s max UDPTimeout time.Duration //linux: 180s max
DefaultTimeout time.Duration //linux: 600s DefaultTimeout time.Duration //linux: 600s
TimerWheel *TimerWheel
// Used to ensure we don't emit local packets for ips we don't own // Used to ensure we don't emit local packets for ips we don't own
localIps *CIDRTree localIps *CIDRTree
connMutex sync.Mutex rules string
rules string rulesVersion uint16
trackTCPRTT bool trackTCPRTT bool
metricTCPRTT metrics.Histogram metricTCPRTT metrics.Histogram
} }
type FirewallConntrack struct {
sync.Mutex
Conns map[FirewallPacket]*conn
TimerWheel *TimerWheel
}
type FirewallTable struct { type FirewallTable struct {
TCP firewallPort TCP firewallPort
UDP firewallPort UDP firewallPort
@ -171,10 +183,12 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
} }
return &Firewall{ return &Firewall{
Conns: make(map[FirewallPacket]*conn), Conntrack: &FirewallConntrack{
Conns: make(map[FirewallPacket]*conn),
TimerWheel: NewTimerWheel(min, max),
},
InRules: newFirewallTable(), InRules: newFirewallTable(),
OutRules: newFirewallTable(), OutRules: newFirewallTable(),
TimerWheel: NewTimerWheel(min, max),
TCPTimeout: tcpTimeout, TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout, UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout, DefaultTimeout: defaultTimeout,
@ -354,7 +368,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// returns nil if the packet should not be dropped. // returns nil if the packet should not be dropped.
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error { func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
// Check if we spoke to this tuple, if we did then allow this packet // Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming) { if f.inConns(packet, fp, incoming, h, caPool) {
return nil return nil
} }
@ -398,26 +412,66 @@ func (f *Firewall) Destroy() {
} }
func (f *Firewall) EmitStats() { func (f *Firewall) EmitStats() {
conntrackCount := len(f.Conns) conntrack := f.Conntrack
conntrack.Lock()
conntrackCount := len(conntrack.Conns)
conntrack.Unlock()
metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount)) metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
} }
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool { func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
f.connMutex.Lock() conntrack := f.Conntrack
conntrack.Lock()
// Purge every time we test // Purge every time we test
ep, has := f.TimerWheel.Purge() ep, has := conntrack.TimerWheel.Purge()
if has { if has {
f.evict(ep) f.evict(ep)
} }
c, ok := f.Conns[fp] c, ok := conntrack.Conns[fp]
if !ok { if !ok {
f.connMutex.Unlock() conntrack.Unlock()
return false return false
} }
if c.rulesVersion != f.rulesVersion {
// This conntrack entry was for an older rule set, validate
// it still passes with the current rule set
table := f.OutRules
if c.incoming {
table = f.InRules
}
// We now know which firewall table to check against
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
if l.Level >= logrus.DebugLevel {
h.logger().
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("dropping old conntrack entry, does not match new ruleset")
}
delete(conntrack.Conns, fp)
conntrack.Unlock()
return false
}
if l.Level >= logrus.DebugLevel {
h.logger().
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("keeping old conntrack entry, does match new ruleset")
}
c.rulesVersion = f.rulesVersion
}
switch fp.Protocol { switch fp.Protocol {
case fwProtoTCP: case fwProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout) c.Expires = time.Now().Add(f.TCPTimeout)
@ -432,7 +486,7 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool
c.Expires = time.Now().Add(f.DefaultTimeout) c.Expires = time.Now().Add(f.DefaultTimeout)
} }
f.connMutex.Unlock() conntrack.Unlock()
return true return true
} }
@ -453,14 +507,19 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
timeout = f.DefaultTimeout timeout = f.DefaultTimeout
} }
f.connMutex.Lock() conntrack := f.Conntrack
if _, ok := f.Conns[fp]; !ok { conntrack.Lock()
f.TimerWheel.Add(fp, timeout) if _, ok := conntrack.Conns[fp]; !ok {
conntrack.TimerWheel.Add(fp, timeout)
} }
// Record which rulesVersion allowed this connection, so we can retest after
// firewall reload
c.incoming = incoming
c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout) c.Expires = time.Now().Add(timeout)
f.Conns[fp] = c conntrack.Conns[fp] = c
f.connMutex.Unlock() conntrack.Unlock()
} }
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
@ -468,7 +527,8 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
func (f *Firewall) evict(p FirewallPacket) { func (f *Firewall) evict(p FirewallPacket) {
//TODO: report a stat if the tcp rtt tracking was never resolved? //TODO: report a stat if the tcp rtt tracking was never resolved?
// Are we still tracking this conn? // Are we still tracking this conn?
t, ok := f.Conns[p] conntrack := f.Conntrack
t, ok := conntrack.Conns[p]
if !ok { if !ok {
return return
} }
@ -477,12 +537,12 @@ func (f *Firewall) evict(p FirewallPacket) {
// Timeout is in the future, re-add the timer // Timeout is in the future, re-add the timer
if newT > 0 { if newT > 0 {
f.TimerWheel.Add(p, newT) conntrack.TimerWheel.Add(p, newT)
return return
} }
// This conn is done // This conn is done
delete(f.Conns, p) delete(conntrack.Conns, p)
} }
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {

View File

@ -17,37 +17,39 @@ import (
func TestNewFirewall(t *testing.T) { func TestNewFirewall(t *testing.T) {
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
fw := NewFirewall(time.Second, time.Minute, time.Hour, c) fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.NotNil(t, fw.Conns) conntrack := fw.Conntrack
assert.NotNil(t, conntrack)
assert.NotNil(t, conntrack.Conns)
assert.NotNil(t, conntrack.TimerWheel)
assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules) assert.NotNil(t, fw.OutRules)
assert.NotNil(t, fw.TimerWheel)
assert.Equal(t, time.Second, fw.TCPTimeout) assert.Equal(t, time.Second, fw.TCPTimeout)
assert.Equal(t, time.Minute, fw.UDPTimeout) assert.Equal(t, time.Minute, fw.UDPTimeout)
assert.Equal(t, time.Hour, fw.DefaultTimeout) assert.Equal(t, time.Hour, fw.DefaultTimeout)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Second, time.Hour, time.Minute, c) fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Second, time.Minute, c) fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Minute, time.Second, c) fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Hour, time.Second, c) fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Second, time.Hour, c) fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
} }
func TestFirewall_AddRule(t *testing.T) { func TestFirewall_AddRule(t *testing.T) {
@ -461,6 +463,74 @@ func TestFirewall_Drop3(t *testing.T) {
assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule) assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
} }
func TestFirewall_DropConntrackReload(t *testing.T) {
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
ip2int(net.IPv4(1, 2, 3, 4)),
10,
90,
fwProtoUDP,
false,
}
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
c := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host1",
Ips: []*net.IPNet{&ipNet},
Groups: []string{"default-group"},
InvertedGroups: map[string]struct{}{"default-group": {}},
Issuer: "signer-shasum",
},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
hostId: ip2int(ipNet.IP),
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
// Allow outbound because conntrack
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
oldFw := fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
oldFw = fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
}
func BenchmarkLookup(b *testing.B) { func BenchmarkLookup(b *testing.B) {
ml := func(m map[string]struct{}, a [][]string) { ml := func(m map[string]struct{}, a [][]string) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
@ -861,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
} }
func resetConntrack(fw *Firewall) { func resetConntrack(fw *Firewall) {
fw.connMutex.Lock() fw.Conntrack.Lock()
fw.Conns = map[FirewallPacket]*conn{} fw.Conntrack.Conns = map[FirewallPacket]*conn{}
fw.connMutex.Unlock() fw.Conntrack.Unlock()
} }

View File

@ -219,11 +219,28 @@ func (f *Interface) reloadFirewall(c *Config) {
} }
oldFw := f.firewall oldFw := f.firewall
conntrack := oldFw.Conntrack
conntrack.Lock()
defer conntrack.Unlock()
fw.rulesVersion = oldFw.rulesVersion + 1
// If rulesVersion is back to zero, we have wrapped all the way around. Be
// safe and just reset conntrack in this case.
if fw.rulesVersion == 0 {
l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion).
Warn("firewall rulesVersion has overflowed, resetting conntrack")
} else {
fw.Conntrack = conntrack
}
f.firewall = fw f.firewall = fw
oldFw.Destroy() oldFw.Destroy()
l.WithField("firewallHash", fw.GetRuleHash()). l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()). WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion).
Info("New firewall has been installed") Info("New firewall has been installed")
} }