From 32e2619323e830dbef20d983d7e46ee73f8727f3 Mon Sep 17 00:00:00 2001 From: Donatas Abraitis Date: Wed, 20 Oct 2021 21:23:33 +0300 Subject: [PATCH] Teardown tunnel automatically if peer's certificate expired (#370) --- connection_manager.go | 83 +++++++++++++++++++++++++-------- connection_manager_test.go | 95 ++++++++++++++++++++++++++++++++++++++ examples/config.yml | 4 +- interface.go | 3 ++ main.go | 1 + 5 files changed, 167 insertions(+), 19 deletions(-) diff --git a/connection_manager.go b/connection_manager.go index db58274..78b1a8a 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -166,7 +166,23 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) // Check for traffic coming back in from this host. traf := n.CheckIn(vpnIP) - // If we saw incoming packets from this ip, just return + hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) + if err != nil { + n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) + + if !n.intf.disconnectInvalid { + n.ClearIP(vpnIP) + n.ClearPendingDeletion(vpnIP) + continue + } + } + + if n.handleInvalidCertificate(now, vpnIP, hostinfo) { + continue + } + + // If we saw an incoming packets from this ip and peer's certificate is not + // expired, just ignore. if traf { if n.l.Level >= logrus.DebugLevel { n.l.WithField("vpnIp", IntIp(vpnIP)). @@ -178,15 +194,6 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) continue } - // If we didn't we may need to probe or destroy the conn - hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) - if err != nil { - n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) - n.ClearIP(vpnIP) - n.ClearPendingDeletion(vpnIP) - continue - } - hostinfo.logger(n.l). WithField("tunnelCheck", m{"state": "testing", "method": "active"}). Debug("Tunnel status") @@ -213,22 +220,31 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) { vpnIP := ep.(uint32) - // If we saw incoming packets from this ip, just return + hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) + if err != nil { + n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) + + if !n.intf.disconnectInvalid { + n.ClearIP(vpnIP) + n.ClearPendingDeletion(vpnIP) + continue + } + } + + if n.handleInvalidCertificate(now, vpnIP, hostinfo) { + continue + } + + // If we saw an incoming packets from this ip and peer's certificate is not + // expired, just ignore. traf := n.CheckIn(vpnIP) if traf { n.l.WithField("vpnIp", IntIp(vpnIP)). WithField("tunnelCheck", m{"state": "alive", "method": "active"}). Debug("Tunnel status") - n.ClearIP(vpnIP) - n.ClearPendingDeletion(vpnIP) - continue - } - hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) - if err != nil { n.ClearIP(vpnIP) n.ClearPendingDeletion(vpnIP) - n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) continue } @@ -256,3 +272,34 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) { } } } + +// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid +func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32, hostinfo *HostInfo) bool { + if !n.intf.disconnectInvalid { + return false + } + + remoteCert := hostinfo.GetCert() + if remoteCert == nil { + return false + } + + valid, err := remoteCert.Verify(now, n.intf.caPool) + if valid { + return false + } + + fingerprint, _ := remoteCert.Sha256Sum() + n.l.WithField("vpnIp", IntIp(vpnIP)).WithError(err). + WithField("certName", remoteCert.Details.Name). + WithField("fingerprint", fingerprint). + Info("Remote certificate is no longer valid, tearing down the tunnel") + + // Inform the remote and close the tunnel locally + n.intf.sendCloseTunnel(hostinfo) + n.intf.closeTunnel(hostinfo, false) + + n.ClearIP(vpnIP) + n.ClearPendingDeletion(vpnIP) + return true +} diff --git a/connection_manager_test.go b/connection_manager_test.go index d88aed2..d3b2b49 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -1,6 +1,8 @@ package nebula import ( + "crypto/ed25519" + "crypto/rand" "net" "testing" "time" @@ -148,3 +150,96 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.Contains(t, nc.hostMap.Hosts, vpnIP) } + +// Check if we can disconnect the peer. +// Validate if the peer's certificate is invalid (expired, etc.) +// Disconnect only if disconnectInvalid: true is set. +func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { + now := time.Now() + l := NewTestLogger() + ipNet := net.IPNet{ + IP: net.IPv4(172, 1, 1, 2), + Mask: net.IPMask{255, 255, 255, 0}, + } + _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") + _, localrange, _ := net.ParseCIDR("10.1.1.1/24") + preferredRanges := []*net.IPNet{localrange} + hostMap := NewHostMap(l, "test", vpncidr, preferredRanges) + + // Generate keys for CA and peer's cert. + pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader) + caCert := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "ca", + NotBefore: now, + NotAfter: now.Add(1 * time.Hour), + IsCA: true, + PublicKey: pubCA, + }, + } + caCert.Sign(privCA) + ncp := &cert.NebulaCAPool{ + CAs: cert.NewCAPool().CAs, + } + ncp.CAs["ca"] = &caCert + + pubCrt, _, _ := ed25519.GenerateKey(rand.Reader) + peerCert := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host", + Ips: []*net.IPNet{&ipNet}, + Subnets: []*net.IPNet{}, + NotBefore: now, + NotAfter: now.Add(60 * time.Second), + PublicKey: pubCrt, + IsCA: false, + Issuer: "ca", + }, + } + peerCert.Sign(privCA) + + cs := &CertState{ + rawCertificate: []byte{}, + privateKey: []byte{}, + certificate: &cert.NebulaCertificate{}, + rawCertificateNoKey: []byte{}, + } + + lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) + ifce := &Interface{ + hostMap: hostMap, + inside: &Tun{}, + outside: &udpConn{}, + certState: cs, + firewall: &Firewall{}, + lightHouse: lh, + handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), + l: l, + disconnectInvalid: true, + caPool: ncp, + } + + // Create manager + nc := newConnectionManager(l, ifce, 5, 10) + ifce.connectionManager = nc + hostinfo := nc.hostMap.AddVpnIP(vpnIP) + hostinfo.ConnectionState = &ConnectionState{ + certState: cs, + peerCert: &peerCert, + H: &noise.HandshakeState{}, + } + + // Move ahead 45s. + // Check if to disconnect with invalid certificate. + // Should be alive. + nextTick := now.Add(45 * time.Second) + destroyed := nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo) + assert.False(t, destroyed) + + // Move ahead 61s. + // Check if to disconnect with invalid certificate. + // Should be disconnected. + nextTick = now.Add(61 * time.Second) + destroyed = nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo) + assert.True(t, destroyed) +} diff --git a/examples/config.yml b/examples/config.yml index baa4a1c..dce8ef9 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -7,9 +7,11 @@ pki: ca: /etc/nebula/ca.crt cert: /etc/nebula/host.crt key: /etc/nebula/host.key - #blocklist is a list of certificate fingerprints that we will refuse to talk to + # blocklist is a list of certificate fingerprints that we will refuse to talk to #blocklist: # - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72 + # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. + #disconnect_invalid: false # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. diff --git a/interface.go b/interface.go index 108ca05..9ea6c3b 100644 --- a/interface.go +++ b/interface.go @@ -43,6 +43,7 @@ type InterfaceConfig struct { MessageMetrics *MessageMetrics version string caPool *cert.NebulaCAPool + disconnectInvalid bool ConntrackCacheTimeout time.Duration l *logrus.Logger @@ -67,6 +68,7 @@ type Interface struct { udpBatchSize int routines int caPool *cert.NebulaCAPool + disconnectInvalid bool // rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse rebindCount int8 @@ -118,6 +120,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) { writers: make([]*udpConn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), caPool: c.caPool, + disconnectInvalid: c.disconnectInvalid, myVpnIp: ip2int(c.certState.certificate.Details.Ips[0].IP), conntrackCacheTimeout: c.ConntrackCacheTimeout, diff --git a/main.go b/main.go index 67d4b51..f18a971 100644 --- a/main.go +++ b/main.go @@ -371,6 +371,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L MessageMetrics: messageMetrics, version: buildVersion, caPool: caPool, + disconnectInvalid: config.GetBool("pki.disconnect_invalid", false), ConntrackCacheTimeout: conntrackCacheTimeout, l: l,