package wg import ( "hash/fnv" "net" "os" "time" "github.com/costela/wesher/common" "github.com/pkg/errors" "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) // State holds the configured state of a Wesher Wireguard interface type State struct { iface string client *wgctrl.Client OverlayAddr net.IPNet Port int Mtu int PrivKey wgtypes.Key PubKey wgtypes.Key KeepaliveInterval *time.Duration } // New creates a new Wesher Wireguard state // The Wireguard keys are generated for every new interface // The interface must later be setup using SetUpInterface func New(iface string, port int, mtu int, ipnet *net.IPNet, name string, keepaliveInterval *time.Duration) (*State, *common.Node, error) { client, err := wgctrl.New() if err != nil { return nil, nil, errors.Wrap(err, "could not instantiate wireguard client") } privKey, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, nil, err } pubKey := privKey.PublicKey() state := State{ iface: iface, client: client, Port: port, Mtu: mtu, PrivKey: privKey, PubKey: pubKey, KeepaliveInterval: keepaliveInterval, } state.assignOverlayAddr(ipnet, name) node := &common.Node{} node.OverlayAddr = state.OverlayAddr node.PubKey = state.PubKey.String() return &state, node, nil } // assignOverlayAddr assigns a new address to the interface // The address is assigned inside the provided network and depends on the // provided name deterministically // Currently, the address is assigned by hashing the name and mapping that // hash in the target network space func (s *State) assignOverlayAddr(ipnet *net.IPNet, name string) { // TODO: this is way too brittle and opaque bits, size := ipnet.Mask.Size() ip := make([]byte, len(ipnet.IP)) copy(ip, []byte(ipnet.IP)) if ipnet.IP.Mask(ipnet.Mask).Equal(ipnet.IP) { h := fnv.New128a() h.Write([]byte(name)) hb := h.Sum(nil) for i := 1; i <= (size-bits)/8; i++ { ip[len(ip)-i] = hb[len(hb)-i] } } s.OverlayAddr = net.IPNet{ IP: net.IP(ip), Mask: net.CIDRMask(size, size), // either /32 or /128, depending if ipv4 or ipv6 } } // DownInterface shuts down the associated network interface func (s *State) DownInterface() error { if _, err := s.client.Device(s.iface); err != nil { if os.IsNotExist(err) { return nil // device already gone; noop } return err } link, err := netlink.LinkByName(s.iface) if err != nil { return err } return netlink.LinkDel(link) } // SetUpInterface creates and sets up the associated network interface func (s *State) SetUpInterface(nodes []common.Node, routedNet *net.IPNet) error { if err := netlink.LinkAdd(&wireguard{LinkAttrs: netlink.LinkAttrs{Name: s.iface}}); err != nil && !os.IsExist(err) { return errors.Wrapf(err, "could not create interface %s", s.iface) } peerCfgs, err := s.nodesToPeerConfigs(nodes) if err != nil { return errors.Wrap(err, "error converting received node information to wireguard format") } if err := s.client.ConfigureDevice(s.iface, wgtypes.Config{ PrivateKey: &s.PrivKey, ListenPort: &s.Port, ReplacePeers: true, Peers: peerCfgs, }); err != nil { return errors.Wrapf(err, "could not set wireguard configuration for %s", s.iface) } link, err := netlink.LinkByName(s.iface) if err != nil { return errors.Wrapf(err, "could not get link information for %s", s.iface) } if err := netlink.AddrReplace(link, &netlink.Addr{ IPNet: &s.OverlayAddr, }); err != nil { return errors.Wrapf(err, "could not set address for %s", s.iface) } if err := netlink.LinkSetMTU(link, s.Mtu-80); err != nil { return errors.Wrapf(err, "could not set MTU for %s", s.iface) } if err := netlink.LinkSetUp(link); err != nil { return errors.Wrapf(err, "could not enable interface %s", s.iface) } // first compute routes currentRoutes, err := netlink.RouteList(link, netlink.FAMILY_ALL) if err != nil { return errors.Wrapf(err, "could not update the routing table for %s", s.iface) } routes := make([]netlink.Route, 0) for index, node := range nodes { // dev route routes = append(routes, netlink.Route{ LinkIndex: link.Attrs().Index, Dst: &nodes[index].OverlayAddr, Scope: netlink.SCOPE_LINK, }) // via routes for _, route := range node.Routes { if !routedNet.Contains(route.IP) { continue } routes = append(routes, netlink.Route{ LinkIndex: link.Attrs().Index, Dst: &route, Gw: node.OverlayAddr.IP, Scope: netlink.SCOPE_SITE, }) } } // then actually update the routing table for _, route := range routes { match := matchRoute(currentRoutes, route) if match == nil { netlink.RouteAdd(&route) } else if match.Gw.String() != route.Gw.String() { netlink.RouteReplace(&route) } } for _, route := range routes { // only delete a reoute if it is a site scope route that belongs to the routed net, mainly to // avoid deleting otherwise manually set routes if matchRoute(currentRoutes, route) == nil && route.Scope == netlink.SCOPE_LINK && routedNet.Contains(route.Dst.IP) { netlink.RouteDel(&route) } } return nil } func (s *State) nodesToPeerConfigs(nodes []common.Node) ([]wgtypes.PeerConfig, error) { peerCfgs := make([]wgtypes.PeerConfig, len(nodes)) for i, node := range nodes { pubKey, err := wgtypes.ParseKey(node.PubKey) if err != nil { return nil, err } peerCfgs[i] = wgtypes.PeerConfig{ PublicKey: pubKey, ReplaceAllowedIPs: true, Endpoint: &net.UDPAddr{ IP: node.Addr, Port: s.Port, }, AllowedIPs: append([]net.IPNet{node.OverlayAddr}, node.Routes...), PersistentKeepaliveInterval: s.KeepaliveInterval, } } return peerCfgs, nil } func matchRoute(set []netlink.Route, needle netlink.Route) *netlink.Route { // routes are considered equal if they overlap and have the same prefix length prefixn, _ := needle.Dst.Mask.Size() for _, route := range set { prefixr, _ := route.Dst.Mask.Size() if prefixn == prefixr && route.Dst.Contains(needle.Dst.IP) { return &route } } return nil }