From add1b2177791a0e392a711fd1ffe3bb52f778ff7 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 2 Mar 2020 16:21:33 -0500 Subject: [PATCH] only create a CIDRTree for each host if necessary (#198) A CIDRTree can be expensive to create, so only do it if we need it. If the remote host only has one IP address and no subnets, just do an exact IP match instead. Fixes: #171 --- firewall.go | 11 +++++++++-- firewall_test.go | 5 +++++ hostmap.go | 5 +++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/firewall.go b/firewall.go index 45373b6..7c907d3 100644 --- a/firewall.go +++ b/firewall.go @@ -354,8 +354,15 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host } // Make sure remote address matches nebula certificate - if h.remoteCidr.Contains(fp.RemoteIP) == nil { - return true + if remoteCidr := h.remoteCidr; remoteCidr != nil { + if remoteCidr.Contains(fp.RemoteIP) == nil { + return true + } + } else { + // Simple case: Certificate has one IP and no subnets + if fp.RemoteIP != h.hostId { + return true + } } // Make sure we are supposed to be handling this local ip address diff --git a/firewall_test.go b/firewall_test.go index ceb589d..e6f90d8 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -171,6 +171,7 @@ func TestFirewall_Drop(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, + hostId: ip2int(ipNet.IP), } h.CreateRemoteCIDR(&c) @@ -344,6 +345,7 @@ func TestFirewall_Drop2(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, + hostId: ip2int(ipNet.IP), } h.CreateRemoteCIDR(&c) @@ -410,6 +412,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c1, }, + hostId: ip2int(ipNet.IP), } h1.CreateRemoteCIDR(&c1) @@ -424,6 +427,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c2, }, + hostId: ip2int(ipNet.IP), } h2.CreateRemoteCIDR(&c2) @@ -438,6 +442,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c3, }, + hostId: ip2int(ipNet.IP), } h3.CreateRemoteCIDR(&c3) diff --git a/hostmap.go b/hostmap.go index 639bfb2..aaba557 100644 --- a/hostmap.go +++ b/hostmap.go @@ -623,6 +623,11 @@ func (i *HostInfo) RecvErrorExceeded() bool { } func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { + if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 { + // Simple case, no CIDRTree needed + return + } + remoteCidr := NewCIDRTree() for _, ip := range c.Details.Ips { remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})