Don't use a global ca pool (#426)

This commit is contained in:
Nathan Brown 2021-03-29 12:10:19 -05:00 committed by GitHub
parent 4603b5b2dd
commit 883e09a392
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 16 additions and 14 deletions

View File

@ -11,8 +11,6 @@ import (
"github.com/slackhq/nebula/cert"
)
var trustedCAs *cert.NebulaCAPool
type CertState struct {
certificate *cert.NebulaCertificate
rawCertificate []byte

View File

@ -96,7 +96,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
return
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
@ -318,7 +318,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
return true
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).

View File

@ -52,7 +52,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
ci.queueLock.Unlock()
}
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs, localCache)
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
if dropReason == nil {
mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
if f.lightHouse != nil && mc%5000 == 0 {
@ -140,7 +140,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, f.caPool, nil)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).

View File

@ -10,6 +10,7 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
const mtu = 9001
@ -41,6 +42,7 @@ type InterfaceConfig struct {
routines int
MessageMetrics *MessageMetrics
version string
caPool *cert.NebulaCAPool
ConntrackCacheTimeout time.Duration
l *logrus.Logger
@ -63,6 +65,7 @@ type Interface struct {
dropMulticast bool
udpBatchSize int
routines int
caPool *cert.NebulaCAPool
// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
rebindCount int8
@ -111,6 +114,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
version: c.version,
writers: make([]*udpConn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
caPool: c.caPool,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
@ -218,8 +222,8 @@ func (f *Interface) reloadCA(c *Config) {
return
}
trustedCAs = newCAs
f.l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
f.caPool = newCAs
f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
}
func (f *Interface) reloadCertKey(c *Config) {

View File

@ -42,13 +42,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
})
// trustedCAs is currently a global, so loadCA operates on that global directly
trustedCAs, err = loadCAFromConfig(l, config)
caPool, err := loadCAFromConfig(l, config)
if err != nil {
//The errors coming out of loadCA are already nicely formatted
return nil, NewContextualError("Failed to load ca from config", nil, err)
}
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(config)
if err != nil {
@ -365,6 +364,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
routines: routines,
MessageMetrics: messageMetrics,
version: buildVersion,
caPool: caPool,
ConntrackCacheTimeout: conntrackCacheTimeout,
l: l,

View File

@ -280,7 +280,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return
}
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
@ -368,7 +368,7 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N
}
*/
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*cert.NebulaCertificate, error) {
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) {
pk := h.PeerStatic()
if pk == nil {
@ -397,7 +397,7 @@ func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*ce
}
c, _ := cert.UnmarshalNebulaCertificate(recombined)
isValid, err := c.Verify(time.Now(), trustedCAs)
isValid, err := c.Verify(time.Now(), caPool)
if err != nil {
return c, fmt.Errorf("certificate validation failed: %s", err)
} else if !isValid {