From 883e09a39282ccc38a2ae40ee39ddba79f9ce222 Mon Sep 17 00:00:00 2001 From: Nathan Brown Date: Mon, 29 Mar 2021 12:10:19 -0500 Subject: [PATCH] Don't use a global ca pool (#426) --- cert.go | 2 -- handshake_ix.go | 4 ++-- inside.go | 4 ++-- interface.go | 8 ++++++-- main.go | 6 +++--- outside.go | 6 +++--- 6 files changed, 16 insertions(+), 14 deletions(-) diff --git a/cert.go b/cert.go index 0e2ce3c..dfee77a 100644 --- a/cert.go +++ b/cert.go @@ -11,8 +11,6 @@ import ( "github.com/slackhq/nebula/cert" ) -var trustedCAs *cert.NebulaCAPool - type CertState struct { certificate *cert.NebulaCertificate rawCertificate []byte diff --git a/handshake_ix.go b/handshake_ix.go index 63070de..1749c16 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -96,7 +96,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { return } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) if err != nil { f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). @@ -318,7 +318,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [ return true } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) if err != nil { f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). diff --git a/inside.go b/inside.go index c682f19..8cb1e51 100644 --- a/inside.go +++ b/inside.go @@ -52,7 +52,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, ci.queueLock.Unlock() } - dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs, localCache) + dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache) if dropReason == nil { mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q) if f.lightHouse != nil && mc%5000 == 0 { @@ -140,7 +140,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil) + dropReason := f.firewall.Drop(p, *fp, false, hostInfo, f.caPool, nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). diff --git a/interface.go b/interface.go index 8bd95ae..5afb84f 100644 --- a/interface.go +++ b/interface.go @@ -10,6 +10,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" ) const mtu = 9001 @@ -41,6 +42,7 @@ type InterfaceConfig struct { routines int MessageMetrics *MessageMetrics version string + caPool *cert.NebulaCAPool ConntrackCacheTimeout time.Duration l *logrus.Logger @@ -63,6 +65,7 @@ type Interface struct { dropMulticast bool udpBatchSize int routines int + caPool *cert.NebulaCAPool // rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse rebindCount int8 @@ -111,6 +114,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) { version: c.version, writers: make([]*udpConn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), + caPool: c.caPool, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -218,8 +222,8 @@ func (f *Interface) reloadCA(c *Config) { return } - trustedCAs = newCAs - f.l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed") + f.caPool = newCAs + f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed") } func (f *Interface) reloadCertKey(c *Config) { diff --git a/main.go b/main.go index 35dda76..ac70adf 100644 --- a/main.go +++ b/main.go @@ -42,13 +42,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L } }) - // trustedCAs is currently a global, so loadCA operates on that global directly - trustedCAs, err = loadCAFromConfig(l, config) + caPool, err := loadCAFromConfig(l, config) if err != nil { //The errors coming out of loadCA are already nicely formatted return nil, NewContextualError("Failed to load ca from config", nil, err) } - l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints") + l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") cs, err := NewCertStateFromConfig(config) if err != nil { @@ -365,6 +364,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L routines: routines, MessageMetrics: messageMetrics, version: buildVersion, + caPool: caPool, ConntrackCacheTimeout: conntrackCacheTimeout, l: l, diff --git a/outside.go b/outside.go index 9acd2e1..1832dbc 100644 --- a/outside.go +++ b/outside.go @@ -280,7 +280,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return } - dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache) + dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).WithField("fwPacket", fwPacket). @@ -368,7 +368,7 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N } */ -func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*cert.NebulaCertificate, error) { +func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) { pk := h.PeerStatic() if pk == nil { @@ -397,7 +397,7 @@ func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*ce } c, _ := cert.UnmarshalNebulaCertificate(recombined) - isValid, err := c.Verify(time.Now(), trustedCAs) + isValid, err := c.Verify(time.Now(), caPool) if err != nil { return c, fmt.Errorf("certificate validation failed: %s", err) } else if !isValid {