Split the application into modules

Splitting into modules will help keep concerns separate,
at the cost of a slightly more verbose code.
This commit is contained in:
kaiyou 2020-05-07 10:25:58 +02:00 committed by Leo Antunes
parent 740a9c44c6
commit dadfbee083
7 changed files with 87 additions and 81 deletions

View File

@ -1,4 +1,4 @@
package main
package cluster
import (
"crypto/rand"
@ -10,30 +10,33 @@ import (
"path"
"time"
"github.com/costela/wesher/common"
"github.com/hashicorp/memberlist"
"github.com/mattn/go-isatty"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// ClusterState keeps track of information needed to rejoin the cluster
type ClusterState struct {
const KeyLen = 32
// State keeps track of information needed to rejoin the cluster
type State struct {
ClusterKey []byte
Nodes []node
Nodes []common.Node
}
type cluster struct {
localName string // used to avoid LocalNode(); should not change
type Cluster struct {
LocalName string // used to avoid LocalNode(); should not change
ml *memberlist.Memberlist
getMeta func(int) []byte
state *ClusterState
state *State
events chan memberlist.NodeEvent
}
const statePath = "/var/lib/wesher/state.json"
func newCluster(init bool, clusterKey []byte, bindAddr string, bindPort int, useIPAsName bool, getMeta func(int) []byte) (*cluster, error) {
state := &ClusterState{}
func New(init bool, clusterKey []byte, bindAddr string, bindPort int, useIPAsName bool, getMeta func(int) []byte) (*Cluster, error) {
state := &State{}
if !init {
loadState(state)
}
@ -58,8 +61,8 @@ func newCluster(init bool, clusterKey []byte, bindAddr string, bindPort int, use
return nil, err
}
cluster := cluster{
localName: ml.LocalNode().Name,
cluster := Cluster{
LocalName: ml.LocalNode().Name,
ml: ml,
getMeta: getMeta,
// The big channel buffer is a work-around for https://github.com/hashicorp/memberlist/issues/23
@ -74,21 +77,21 @@ func newCluster(init bool, clusterKey []byte, bindAddr string, bindPort int, use
return &cluster, nil
}
func (c *cluster) NotifyConflict(node, other *memberlist.Node) {
func (c *Cluster) NotifyConflict(node, other *memberlist.Node) {
logrus.Errorf("node name conflict detected: %s", other.Name)
}
func (c *cluster) NodeMeta(limit int) []byte {
func (c *Cluster) NodeMeta(limit int) []byte {
return c.getMeta(limit)
}
// none of these are used
func (c *cluster) NotifyMsg([]byte) {}
func (c *cluster) GetBroadcasts(overhead, limit int) [][]byte { return nil }
func (c *cluster) LocalState(join bool) []byte { return nil }
func (c *cluster) MergeRemoteState(buf []byte, join bool) {}
func (c *Cluster) NotifyMsg([]byte) {}
func (c *Cluster) GetBroadcasts(overhead, limit int) [][]byte { return nil }
func (c *Cluster) LocalState(join bool) []byte { return nil }
func (c *Cluster) MergeRemoteState(buf []byte, join bool) {}
func (c *cluster) join(addrs []string) error {
func (c *Cluster) Join(addrs []string) error {
if len(addrs) == 0 {
for _, n := range c.state.Nodes {
addrs = append(addrs, n.Addr.String())
@ -103,22 +106,22 @@ func (c *cluster) join(addrs []string) error {
return nil
}
func (c *cluster) leave() {
func (c *Cluster) Leave() {
c.saveState()
c.ml.Leave(10 * time.Second)
c.ml.Shutdown() //nolint: errcheck
}
func (c *cluster) update() {
func (c *Cluster) Update() {
c.ml.UpdateNode(1 * time.Second) // we currently do not update after creation
}
func (c *cluster) members() <-chan []node {
changes := make(chan []node)
func (c *Cluster) Members() <-chan []common.Node {
changes := make(chan []common.Node)
go func() {
for {
event := <-c.events
if event.Node.Name == c.localName {
if event.Node.Name == c.LocalName {
// ignore events about ourselves
continue
}
@ -131,12 +134,12 @@ func (c *cluster) members() <-chan []node {
logrus.Infof("node %s left", event.Node)
}
nodes := make([]node, 0)
nodes := make([]common.Node, 0)
for _, n := range c.ml.Members() {
if n.Name == c.localName {
if n.Name == c.LocalName {
continue
}
nodes = append(nodes, node{
nodes = append(nodes, common.Node{
Name: n.Name,
Addr: n.Addr,
Meta: n.Meta,
@ -150,12 +153,12 @@ func (c *cluster) members() <-chan []node {
return changes
}
func computeClusterKey(state *ClusterState, clusterKey []byte) ([]byte, error) {
func computeClusterKey(state *State, clusterKey []byte) ([]byte, error) {
if len(clusterKey) == 0 {
clusterKey = state.ClusterKey
}
if len(clusterKey) == 0 {
clusterKey = make([]byte, clusterKeyLen)
clusterKey = make([]byte, KeyLen)
_, err := rand.Read(clusterKey)
if err != nil {
return nil, err
@ -169,7 +172,7 @@ func computeClusterKey(state *ClusterState, clusterKey []byte) ([]byte, error) {
return clusterKey, nil
}
func (c *cluster) saveState() error {
func (c *Cluster) saveState() error {
if err := os.MkdirAll(path.Dir(statePath), 0700); err != nil {
return err
}
@ -182,7 +185,7 @@ func (c *cluster) saveState() error {
return ioutil.WriteFile(statePath, stateOut, 0600)
}
func loadState(cs *ClusterState) {
func loadState(cs *State) {
content, err := ioutil.ReadFile(statePath)
if err != nil {
if !os.IsNotExist(err) {
@ -192,7 +195,7 @@ func loadState(cs *ClusterState) {
}
// avoid partially unmarshalled content by using a temp var
csTmp := &ClusterState{}
csTmp := &State{}
if err := json.Unmarshal(content, csTmp); err != nil {
logrus.Warnf("could not decode state: %s", err)
} else {

View File

@ -1,4 +1,4 @@
package main
package common
import (
"bytes"
@ -9,25 +9,25 @@ import (
"github.com/sirupsen/logrus"
)
// nodeMeta holds metadata sent over the cluster
type nodeMeta struct {
// NodeMeta holds metadata sent over the cluster
type NodeMeta struct {
OverlayAddr net.IPNet
PubKey string
}
// Node holds the memberlist node structure
type node struct {
type Node struct {
Name string
Addr net.IP
Meta []byte
nodeMeta
NodeMeta
}
func (n *node) String() string {
func (n *Node) String() string {
return n.Addr.String()
}
func encodeNodeMeta(nm nodeMeta, limit int) []byte {
func EncodeNodeMeta(nm NodeMeta, limit int) []byte {
buf := &bytes.Buffer{}
if err := gob.NewEncoder(buf).Encode(nm); err != nil {
logrus.Errorf("could not encode local state: %s", err)
@ -40,10 +40,10 @@ func encodeNodeMeta(nm nodeMeta, limit int) []byte {
return buf.Bytes()
}
func decodeNodeMeta(b []byte) (nodeMeta, error) {
func DecodeNodeMeta(b []byte) (NodeMeta, error) {
// TODO: we blindly trust the info we get from the peers; We should be more defensive to limit the damage a leaked
// PSK can cause.
nm := nodeMeta{}
nm := NodeMeta{}
if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&nm); err != nil {
return nm, errors.Wrap(err, "could not decode node meta")
}

View File

@ -4,12 +4,11 @@ import (
"fmt"
"net"
"github.com/costela/wesher/cluster"
"github.com/hashicorp/go-sockaddr"
"github.com/stevenroose/gonfig"
)
const clusterKeyLen = 32
type config struct {
ClusterKey []byte `id:"cluster-key" desc:"shared key for cluster membership; must be 32 bytes base64 encoded; will be generated if not provided"`
Join []string `desc:"comma separated list of hostnames or IP addresses to existing cluster members; if not provided, will attempt resuming any known state or otherwise wait for further members."`
@ -36,8 +35,8 @@ func loadConfig() (*config, error) {
}
// perform some validation
if len(config.ClusterKey) != 0 && len(config.ClusterKey) != clusterKeyLen {
return nil, fmt.Errorf("unsupported cluster key length; expected %d, got %d", clusterKeyLen, len(config.ClusterKey))
if len(config.ClusterKey) != 0 && len(config.ClusterKey) != cluster.KeyLen {
return nil, fmt.Errorf("unsupported cluster key length; expected %d, got %d", cluster.KeyLen, len(config.ClusterKey))
}
if bits, _ := ((*net.IPNet)(config.OverlayNet)).Mask.Size(); bits%8 != 0 {

33
main.go
View File

@ -9,7 +9,10 @@ import (
"time"
"github.com/cenkalti/backoff"
"github.com/costela/wesher/cluster"
"github.com/costela/wesher/common"
"github.com/costela/wesher/etchosts"
"github.com/costela/wesher/wg"
"github.com/sirupsen/logrus"
)
@ -30,28 +33,28 @@ func main() {
}
logrus.SetLevel(logLevel)
wg, err := newWGConfig(config.Interface, config.WireguardPort)
wg, err := wg.NewWGConfig(config.Interface, config.WireguardPort)
if err != nil {
logrus.WithError(err).Fatal("could not instantiate wireguard controller")
}
getMeta := func(limit int) []byte {
return encodeNodeMeta(nodeMeta{
return common.EncodeNodeMeta(common.NodeMeta{
OverlayAddr: wg.OverlayAddr,
PubKey: wg.PubKey.String(),
}, limit)
}
cluster, err := newCluster(config.Init, config.ClusterKey, config.BindAddr, config.ClusterPort, config.UseIPAsName, getMeta)
cluster, err := cluster.New(config.Init, config.ClusterKey, config.BindAddr, config.ClusterPort, config.UseIPAsName, getMeta)
if err != nil {
logrus.WithError(err).Fatal("could not create cluster")
}
wg.assignOverlayAddr((*net.IPNet)(config.OverlayNet), cluster.localName)
cluster.update()
wg.AssignOverlayAddr((*net.IPNet)(config.OverlayNet), cluster.LocalName)
cluster.Update()
nodec := cluster.members() // avoid deadlocks by starting before join
nodec := cluster.Members() // avoid deadlocks by starting before join
if err := backoff.RetryNotify(
func() error { return cluster.join(config.Join) },
func() error { return cluster.Join(config.Join) },
backoff.NewExponentialBackOff(),
func(err error, dur time.Duration) {
logrus.WithError(err).Errorf("could not join cluster, retrying in %s", dur)
@ -67,20 +70,20 @@ func main() {
select {
case rawNodes := <-nodec:
logrus.Info("cluster members:\n")
nodes := make([]node, 0, len(rawNodes))
nodes := make([]common.Node, 0, len(rawNodes))
for _, node := range rawNodes {
meta, err := decodeNodeMeta(node.Meta)
meta, err := common.DecodeNodeMeta(node.Meta)
if err != nil {
logrus.Warnf("\t addr: %s, could not decode metadata", node.Addr)
continue
}
node.nodeMeta = meta
node.NodeMeta = meta
nodes = append(nodes, node)
logrus.Infof("\taddr: %s, overlay: %s, pubkey: %s", node.Addr, node.OverlayAddr, node.PubKey)
}
if err := wg.setUpInterface(nodes); err != nil {
if err := wg.SetUpInterface(nodes); err != nil {
logrus.WithError(err).Error("could not up interface")
wg.downInterface()
wg.DownInterface()
}
if !config.NoEtcHosts {
if err := writeToEtcHosts(nodes); err != nil {
@ -89,13 +92,13 @@ func main() {
}
case <-incomingSigs:
logrus.Info("terminating...")
cluster.leave()
cluster.Leave()
if !config.NoEtcHosts {
if err := writeToEtcHosts(nil); err != nil {
logrus.WithError(err).Error("could not remove stale hosts entries")
}
}
if err := wg.downInterface(); err != nil {
if err := wg.DownInterface(); err != nil {
logrus.WithError(err).Error("could not down interface")
}
os.Exit(0)
@ -103,7 +106,7 @@ func main() {
}
}
func writeToEtcHosts(nodes []node) error {
func writeToEtcHosts(nodes []common.Node) error {
hosts := make(map[string][]string, len(nodes))
for _, n := range nodes {
hosts[n.OverlayAddr.IP.String()] = []string{n.Name}

View File

@ -1,4 +1,4 @@
package main
package wg
import "github.com/vishvananda/netlink"

View File

@ -1,17 +1,18 @@
package main
package wg
import (
"hash/fnv"
"net"
"os"
"github.com/costela/wesher/common"
"github.com/pkg/errors"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type wgState struct {
type WgState struct {
iface string
client *wgctrl.Client
OverlayAddr net.IPNet
@ -20,7 +21,7 @@ type wgState struct {
PubKey wgtypes.Key
}
func newWGConfig(iface string, port int) (*wgState, error) {
func NewWGConfig(iface string, port int) (*WgState, error) {
client, err := wgctrl.New()
if err != nil {
return nil, errors.Wrap(err, "could not instantiate wireguard client")
@ -32,7 +33,7 @@ func newWGConfig(iface string, port int) (*wgState, error) {
}
pubKey := privKey.PublicKey()
wgState := wgState{
wgState := WgState{
iface: iface,
client: client,
Port: port,
@ -42,7 +43,7 @@ func newWGConfig(iface string, port int) (*wgState, error) {
return &wgState, nil
}
func (wg *wgState) assignOverlayAddr(ipnet *net.IPNet, name string) {
func (wg *WgState) AssignOverlayAddr(ipnet *net.IPNet, name string) {
// TODO: this is way too brittle and opaque
bits, size := ipnet.Mask.Size()
ip := make([]byte, len(ipnet.IP))
@ -62,7 +63,7 @@ func (wg *wgState) assignOverlayAddr(ipnet *net.IPNet, name string) {
}
}
func (wg *wgState) downInterface() error {
func (wg *WgState) DownInterface() error {
if _, err := wg.client.Device(wg.iface); err != nil {
if os.IsNotExist(err) {
return nil // device already gone; noop
@ -76,12 +77,12 @@ func (wg *wgState) downInterface() error {
return netlink.LinkDel(link)
}
func (wg *wgState) setUpInterface(nodes []node) error {
func (wg *WgState) SetUpInterface(nodes []common.Node) error {
if err := netlink.LinkAdd(&wireguard{LinkAttrs: netlink.LinkAttrs{Name: wg.iface}}); err != nil && !os.IsExist(err) {
return errors.Wrapf(err, "could not create interface %s", wg.iface)
}
peerCfgs, err := wg.nodesToPeerConfigs(nodes)
peerCfgs, err := wg.NodesToPeerConfigs(nodes)
if err != nil {
return errors.Wrap(err, "error converting received node information to wireguard format")
}
@ -121,7 +122,7 @@ func (wg *wgState) setUpInterface(nodes []node) error {
return nil
}
func (wg *wgState) nodesToPeerConfigs(nodes []node) ([]wgtypes.PeerConfig, error) {
func (wg *WgState) NodesToPeerConfigs(nodes []common.Node) ([]wgtypes.PeerConfig, error) {
peerCfgs := make([]wgtypes.PeerConfig, len(nodes))
for i, node := range nodes {
pubKey, err := wgtypes.ParseKey(node.PubKey)

View File

@ -1,4 +1,4 @@
package main
package wg
import (
"net"
@ -31,8 +31,8 @@ func Test_wgState_assignOverlayAddr(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wg := &wgState{}
wg.assignOverlayAddr(tt.args.ipnet, tt.args.name)
wg := &WgState{}
wg.AssignOverlayAddr(tt.args.ipnet, tt.args.name)
if !reflect.DeepEqual(wg.OverlayAddr.IP.String(), tt.want) {
t.Errorf("assignOverlayAddr() set = %s, want %s", wg.OverlayAddr, tt.want)
@ -47,8 +47,8 @@ func Test_wgState_assignOverlayAddr_no_obvious_collisions(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("10.0.0.0/24")
assignments := make(map[string]string)
for _, n := range []string{"test", "test1", "test2", "1test", "2test"} {
wg := &wgState{}
wg.assignOverlayAddr(ipnet, n)
wg := &WgState{}
wg.AssignOverlayAddr(ipnet, n)
if assigned, ok := assignments[wg.OverlayAddr.String()]; ok {
t.Errorf("IP assignment collision: hash(%s) = hash(%s)", n, assigned)
}
@ -59,10 +59,10 @@ func Test_wgState_assignOverlayAddr_no_obvious_collisions(t *testing.T) {
// This should ensure the obvious fact that the same name should map to the same IP if called twice.
func Test_wgState_assignOverlayAddr_consistent(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("10.0.0.0/8")
wg1 := &wgState{}
wg1.assignOverlayAddr(ipnet, "test")
wg2 := &wgState{}
wg2.assignOverlayAddr(ipnet, "test")
wg1 := &WgState{}
wg1.AssignOverlayAddr(ipnet, "test")
wg2 := &WgState{}
wg2.AssignOverlayAddr(ipnet, "test")
if wg1.OverlayAddr.String() != wg2.OverlayAddr.String() {
t.Errorf("assignOverlayAddr() %s != %s", wg1.OverlayAddr, wg2.OverlayAddr)
}
@ -70,10 +70,10 @@ func Test_wgState_assignOverlayAddr_consistent(t *testing.T) {
func Test_wgState_assignOverlayAddr_repeatable(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("10.0.0.0/8")
wg := &wgState{}
wg.assignOverlayAddr(ipnet, "test")
wg := &WgState{}
wg.AssignOverlayAddr(ipnet, "test")
gen1 := wg.OverlayAddr.String()
wg.assignOverlayAddr(ipnet, "test")
wg.AssignOverlayAddr(ipnet, "test")
gen2 := wg.OverlayAddr.String()
if gen1 != gen2 {
t.Errorf("assignOverlayAddr() %s != %s", gen1, gen2)