diff --git a/firewall.go b/firewall.go index 45373b6..7c907d3 100644 --- a/firewall.go +++ b/firewall.go @@ -354,8 +354,15 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host } // Make sure remote address matches nebula certificate - if h.remoteCidr.Contains(fp.RemoteIP) == nil { - return true + if remoteCidr := h.remoteCidr; remoteCidr != nil { + if remoteCidr.Contains(fp.RemoteIP) == nil { + return true + } + } else { + // Simple case: Certificate has one IP and no subnets + if fp.RemoteIP != h.hostId { + return true + } } // Make sure we are supposed to be handling this local ip address diff --git a/firewall_test.go b/firewall_test.go index ceb589d..e6f90d8 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -171,6 +171,7 @@ func TestFirewall_Drop(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, + hostId: ip2int(ipNet.IP), } h.CreateRemoteCIDR(&c) @@ -344,6 +345,7 @@ func TestFirewall_Drop2(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, + hostId: ip2int(ipNet.IP), } h.CreateRemoteCIDR(&c) @@ -410,6 +412,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c1, }, + hostId: ip2int(ipNet.IP), } h1.CreateRemoteCIDR(&c1) @@ -424,6 +427,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c2, }, + hostId: ip2int(ipNet.IP), } h2.CreateRemoteCIDR(&c2) @@ -438,6 +442,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c3, }, + hostId: ip2int(ipNet.IP), } h3.CreateRemoteCIDR(&c3) diff --git a/hostmap.go b/hostmap.go index 639bfb2..aaba557 100644 --- a/hostmap.go +++ b/hostmap.go @@ -623,6 +623,11 @@ func (i *HostInfo) RecvErrorExceeded() bool { } func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { + if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 { + // Simple case, no CIDRTree needed + return + } + remoteCidr := NewCIDRTree() for _, ip := range c.Details.Ips { remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})