fix makeRouteTree allowMTU (#611)

With the previous implementation, we check if route.MTU is greater than zero,
but it will always be because we set it to the default MTU in
parseUnsafeRoutes. This change leaves it as zero in parseUnsafeRoutes so
it can be examined later.
This commit is contained in:
Wade Simmons 2021-12-14 11:52:28 -05:00 committed by GitHub
parent 15fdabc3ab
commit 068a93d1f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 33 additions and 28 deletions

View File

@ -7,6 +7,7 @@ import (
"runtime"
"strconv"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
@ -19,11 +20,11 @@ type Route struct {
Via *iputil.VpnIp
}
func makeRouteTree(routes []Route, allowMTU bool) (*cidr.Tree4, error) {
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) {
routeTree := cidr.NewTree4()
for _, r := range routes {
if !allowMTU && r.MTU > 0 {
return nil, fmt.Errorf("route MTU is not supported in %s", runtime.GOOS)
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
}
if r.Via != nil {
@ -127,21 +128,19 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
}
rMtu, ok := m["mtu"]
if !ok {
rMtu = c.GetInt("tun.mtu", DefaultMTU)
}
mtu, ok := rMtu.(int)
if !ok {
mtu, err = strconv.Atoi(rMtu.(string))
if err != nil {
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is not an integer: %v", i+1, err)
var mtu int
if rMtu, ok := m["mtu"]; ok {
mtu, ok = rMtu.(int)
if !ok {
mtu, err = strconv.Atoi(rMtu.(string))
if err != nil {
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is not an integer: %v", i+1, err)
}
}
}
if mtu < 500 {
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is below 500: %v", i+1, mtu)
if mtu != 0 && mtu < 500 {
return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is below 500: %v", i+1, mtu)
}
}
rMetric, ok := m["metric"]

View File

@ -191,7 +191,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Len(t, routes, 1)
assert.Equal(t, DefaultMTU, routes[0].MTU)
assert.Equal(t, 0, routes[0].MTU)
// bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
@ -249,7 +249,7 @@ func Test_makeRouteTree(t *testing.T) {
routes, err := parseUnsafeRoutes(c, n)
assert.NoError(t, err)
assert.Len(t, routes, 2)
routeTree, err := makeRouteTree(routes, true)
routeTree, err := makeRouteTree(l, routes, true)
assert.NoError(t, err)
ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))

View File

@ -77,7 +77,7 @@ type ifreqMTU struct {
}
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(routes, false)
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}

View File

@ -43,7 +43,7 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(routes, false)
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}

View File

@ -64,7 +64,7 @@ type ifreqQLEN struct {
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) {
routeTree, err := makeRouteTree(routes, true)
routeTree, err := makeRouteTree(l, routes, true)
if err != nil {
return nil, err
}
@ -105,12 +105,16 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
maxMTU := defaultMTU
for _, r := range routes {
if r.MTU == 0 {
r.MTU = defaultMTU
}
if r.MTU > maxMTU {
maxMTU = r.MTU
}
}
routeTree, err := makeRouteTree(routes, true)
routeTree, err := makeRouteTree(l, routes, true)
if err != nil {
return nil, err
}

View File

@ -25,7 +25,7 @@ type TestTun struct {
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) {
routeTree, err := makeRouteTree(routes, false)
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}

View File

@ -7,6 +7,7 @@ import (
"os/exec"
"strconv"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/iputil"
"github.com/songgao/water"
@ -22,8 +23,8 @@ type waterTun struct {
*water.Interface
}
func newWaterTun(cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) {
routeTree, err := makeRouteTree(routes, false)
func newWaterTun(l *logrus.Logger, cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}

View File

@ -30,14 +30,14 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
}
if useWintun {
device, err := newWinTun(deviceName, cidr, defaultMTU, routes)
device, err := newWinTun(l, deviceName, cidr, defaultMTU, routes)
if err != nil {
return nil, fmt.Errorf("create Wintun interface failed, %w", err)
}
return device, nil
}
device, err := newWaterTun(cidr, defaultMTU, routes)
device, err := newWaterTun(l, cidr, defaultMTU, routes)
if err != nil {
return nil, fmt.Errorf("create wintap driver failed, %w", err)
}

View File

@ -7,6 +7,7 @@ import (
"net"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/wintun"
@ -45,7 +46,7 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
}
func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) {
func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) {
guid, err := generateGUIDByDeviceName(deviceName)
if err != nil {
return nil, fmt.Errorf("generate GUID failed: %w", err)
@ -56,7 +57,7 @@ func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []Rout
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
routeTree, err := makeRouteTree(routes, false)
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}