plugin: can execute plugins and communicate that way

This commit is contained in:
Mitchell Hashimoto 2014-05-28 21:09:47 -07:00
parent 91317a8608
commit 951b7b18eb
5 changed files with 763 additions and 0 deletions

356
plugin/client.go Normal file
View File

@ -0,0 +1,356 @@
package plugin
import (
"bufio"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/rpc"
"os"
"os/exec"
"strings"
"sync"
"time"
"unicode"
)
// If this is true, then the "unexpected EOF" panic will not be
// raised throughout the clients.
var Killed = false
// This is a slice of the "managed" clients which are cleaned up when
// calling Cleanup
var managedClients = make([]*Client, 0, 5)
// Client handles the lifecycle of a plugin application, determining its
// RPC address, and returning various types of Terraform interface implementations
// across the multi-process communication layer.
type Client struct {
config *ClientConfig
exited bool
doneLogging chan struct{}
l sync.Mutex
address net.Addr
service string
client *rpc.Client
}
// ClientConfig is the configuration used to initialize a new
// plugin client. After being used to initialize a plugin client,
// that configuration must not be modified again.
type ClientConfig struct {
// The unstarted subprocess for starting the plugin.
Cmd *exec.Cmd
// Managed represents if the client should be managed by the
// plugin package or not. If true, then by calling CleanupClients,
// it will automatically be cleaned up. Otherwise, the client
// user is fully responsible for making sure to Kill all plugin
// clients. By default the client is _not_ managed.
Managed bool
// The minimum and maximum port to use for communicating with
// the subprocess. If not set, this defaults to 10,000 and 25,000
// respectively.
MinPort, MaxPort uint
// StartTimeout is the timeout to wait for the plugin to say it
// has started successfully.
StartTimeout time.Duration
// If non-nil, then the stderr of the client will be written to here
// (as well as the log).
Stderr io.Writer
}
// This makes sure all the managed subprocesses are killed and properly
// logged. This should be called before the parent process running the
// plugins exits.
//
// This must only be called _once_.
func CleanupClients() {
// Set the killed to true so that we don't get unexpected panics
Killed = true
// Kill all the managed clients in parallel and use a WaitGroup
// to wait for them all to finish up.
var wg sync.WaitGroup
for _, client := range managedClients {
wg.Add(1)
go func(client *Client) {
client.Kill()
wg.Done()
}(client)
}
log.Println("waiting for all plugin processes to complete...")
wg.Wait()
}
// Creates a new plugin client which manages the lifecycle of an external
// plugin and gets the address for the RPC connection.
//
// The client must be cleaned up at some point by calling Kill(). If
// the client is a managed client (created with NewManagedClient) you
// can just call CleanupClients at the end of your program and they will
// be properly cleaned.
func NewClient(config *ClientConfig) (c *Client) {
if config.MinPort == 0 && config.MaxPort == 0 {
config.MinPort = 10000
config.MaxPort = 25000
}
if config.StartTimeout == 0 {
config.StartTimeout = 1 * time.Minute
}
if config.Stderr == nil {
config.Stderr = ioutil.Discard
}
c = &Client{config: config}
if config.Managed {
managedClients = append(managedClients, c)
}
return
}
// Client returns an RPC client for the plugin.
//
// Subsequent calls to this will return the same RPC client.
func (c *Client) Client() (*rpc.Client, error) {
addr, err := c.Start()
if err != nil {
return nil, err
}
c.l.Lock()
defer c.l.Unlock()
if c.client != nil {
return c.client, nil
}
conn, err := net.Dial(addr.Network(), addr.String())
if err != nil {
return nil, err
}
if tcpConn, ok := conn.(*net.TCPConn); ok {
// Make sure to set keep alive so that the connection doesn't die
tcpConn.SetKeepAlive(true)
}
c.client = rpc.NewClient(conn)
return c.client, nil
}
// Tells whether or not the underlying process has exited.
func (c *Client) Exited() bool {
c.l.Lock()
defer c.l.Unlock()
return c.exited
}
// End the executing subprocess (if it is running) and perform any cleanup
// tasks necessary such as capturing any remaining logs and so on.
//
// This method blocks until the process successfully exits.
//
// This method can safely be called multiple times.
func (c *Client) Kill() {
cmd := c.config.Cmd
if cmd.Process == nil {
return
}
cmd.Process.Kill()
// Wait for the client to finish logging so we have a complete log
<-c.doneLogging
}
// Service returns the name of the service to use.
func (c *Client) Service() (string, error) {
if _, err := c.Start(); err != nil {
return "", err
}
return c.service, nil
}
// Starts the underlying subprocess, communicating with it to negotiate
// a port for RPC connections, and returning the address to connect via RPC.
//
// This method is safe to call multiple times. Subsequent calls have no effect.
// Once a client has been started once, it cannot be started again, even if
// it was killed.
func (c *Client) Start() (addr net.Addr, err error) {
c.l.Lock()
defer c.l.Unlock()
if c.address != nil {
return c.address, nil
}
c.doneLogging = make(chan struct{})
env := []string{
fmt.Sprintf("%s=%s", MagicCookieKey, MagicCookieValue),
fmt.Sprintf("TF_PLUGIN_MIN_PORT=%d", c.config.MinPort),
fmt.Sprintf("TF_PLUGIN_MAX_PORT=%d", c.config.MaxPort),
}
stdout_r, stdout_w := io.Pipe()
stderr_r, stderr_w := io.Pipe()
cmd := c.config.Cmd
cmd.Env = append(cmd.Env, os.Environ()...)
cmd.Env = append(cmd.Env, env...)
cmd.Stdin = os.Stdin
cmd.Stderr = stderr_w
cmd.Stdout = stdout_w
log.Printf("Starting plugin: %s %#v", cmd.Path, cmd.Args)
err = cmd.Start()
if err != nil {
return
}
// Make sure the command is properly cleaned up if there is an error
defer func() {
r := recover()
if err != nil || r != nil {
cmd.Process.Kill()
}
if r != nil {
panic(r)
}
}()
// Start goroutine to wait for process to exit
exitCh := make(chan struct{})
go func() {
// Make sure we close the write end of our stderr/stdout so
// that the readers send EOF properly.
defer stderr_w.Close()
defer stdout_w.Close()
// Wait for the command to end.
cmd.Wait()
// Log and make sure to flush the logs write away
log.Printf("%s: plugin process exited\n", cmd.Path)
os.Stderr.Sync()
// Mark that we exited
close(exitCh)
// Set that we exited, which takes a lock
c.l.Lock()
defer c.l.Unlock()
c.exited = true
}()
// Start goroutine that logs the stderr
go c.logStderr(stderr_r)
// Start a goroutine that is going to be reading the lines
// out of stdout
linesCh := make(chan []byte)
go func() {
defer close(linesCh)
buf := bufio.NewReader(stdout_r)
for {
line, err := buf.ReadBytes('\n')
if line != nil {
linesCh <- line
}
if err == io.EOF {
return
}
}
}()
// Make sure after we exit we read the lines from stdout forever
// so they dont' block since it is an io.Pipe
defer func() {
go func() {
for _ = range linesCh {
}
}()
}()
// Some channels for the next step
timeout := time.After(c.config.StartTimeout)
// Start looking for the address
log.Printf("Waiting for RPC address for: %s", cmd.Path)
select {
case <-timeout:
err = errors.New("timeout while waiting for plugin to start")
case <-exitCh:
err = errors.New("plugin exited before we could connect")
case lineBytes := <-linesCh:
// Trim the line and split by "|" in order to get the parts of
// the output.
line := strings.TrimSpace(string(lineBytes))
parts := strings.SplitN(line, "|", 4)
if len(parts) < 4 {
err = fmt.Errorf("Unrecognized remote plugin message: %s", line)
return
}
// Test the API version
if parts[0] != APIVersion {
err = fmt.Errorf("Incompatible API version with plugin. "+
"Plugin version: %s, Ours: %s", parts[0], APIVersion)
return
}
switch parts[1] {
case "tcp":
addr, err = net.ResolveTCPAddr("tcp", parts[2])
case "unix":
addr, err = net.ResolveUnixAddr("unix", parts[2])
default:
err = fmt.Errorf("Unknown address type: %s", parts[1])
}
// Grab the services
c.service = parts[3]
}
c.address = addr
return
}
func (c *Client) logStderr(r io.Reader) {
bufR := bufio.NewReader(r)
for {
line, err := bufR.ReadString('\n')
if line != "" {
c.config.Stderr.Write([]byte(line))
line = strings.TrimRightFunc(line, unicode.IsSpace)
log.Printf("%s: %s", c.config.Cmd.Path, line)
}
if err == io.EOF {
break
}
}
// Flag that we've completed logging for others
close(c.doneLogging)
}

153
plugin/client_test.go Normal file
View File

@ -0,0 +1,153 @@
package plugin
import (
"bytes"
"io/ioutil"
"os"
"strings"
"testing"
"time"
)
func TestClient(t *testing.T) {
process := helperProcess("mock")
c := NewClient(&ClientConfig{Cmd: process})
defer c.Kill()
// Test that it parses the proper address
addr, err := c.Start()
if err != nil {
t.Fatalf("err should be nil, got %s", err)
}
if addr.Network() != "tcp" {
t.Fatalf("bad: %#v", addr)
}
if addr.String() != ":1234" {
t.Fatalf("bad: %#v", addr)
}
service, err := c.Service()
if err != nil {
t.Fatalf("err: %s", err)
}
if service != "foo" {
t.Fatalf("bad: %#v", service)
}
// Test that it exits properly if killed
c.Kill()
if process.ProcessState == nil {
t.Fatal("should have process state")
}
// Test that it knows it is exited
if !c.Exited() {
t.Fatal("should say client has exited")
}
}
func TestClientStart_badVersion(t *testing.T) {
config := &ClientConfig{
Cmd: helperProcess("bad-version"),
StartTimeout: 50 * time.Millisecond,
}
c := NewClient(config)
defer c.Kill()
_, err := c.Start()
if err == nil {
t.Fatal("err should not be nil")
}
}
func TestClient_Start_Timeout(t *testing.T) {
config := &ClientConfig{
Cmd: helperProcess("start-timeout"),
StartTimeout: 50 * time.Millisecond,
}
c := NewClient(config)
defer c.Kill()
_, err := c.Start()
if err == nil {
t.Fatal("err should not be nil")
}
}
func TestClient_Stderr(t *testing.T) {
stderr := new(bytes.Buffer)
process := helperProcess("stderr")
c := NewClient(&ClientConfig{
Cmd: process,
Stderr: stderr,
})
defer c.Kill()
if _, err := c.Start(); err != nil {
t.Fatalf("err: %s", err)
}
for !c.Exited() {
time.Sleep(10 * time.Millisecond)
}
if !strings.Contains(stderr.String(), "HELLO\n") {
t.Fatalf("bad log data: '%s'", stderr.String())
}
if !strings.Contains(stderr.String(), "WORLD\n") {
t.Fatalf("bad log data: '%s'", stderr.String())
}
}
func TestClient_Stdin(t *testing.T) {
// Overwrite stdin for this test with a temporary file
tf, err := ioutil.TempFile("", "packer")
if err != nil {
t.Fatalf("err: %s", err)
}
defer os.Remove(tf.Name())
defer tf.Close()
if _, err = tf.WriteString("hello"); err != nil {
t.Fatalf("error: %s", err)
}
if err = tf.Sync(); err != nil {
t.Fatalf("error: %s", err)
}
if _, err = tf.Seek(0, 0); err != nil {
t.Fatalf("error: %s", err)
}
oldStdin := os.Stdin
defer func() { os.Stdin = oldStdin }()
os.Stdin = tf
process := helperProcess("stdin")
c := NewClient(&ClientConfig{Cmd: process})
defer c.Kill()
_, err = c.Start()
if err != nil {
t.Fatalf("error: %s", err)
}
for {
if c.Exited() {
break
}
time.Sleep(50 * time.Millisecond)
}
if !process.ProcessState.Success() {
t.Fatal("process didn't exit cleanly")
}
}

92
plugin/plugin_test.go Normal file
View File

@ -0,0 +1,92 @@
package plugin
import (
"fmt"
"log"
"os"
"os/exec"
"testing"
"time"
"github.com/hashicorp/terraform/terraform"
)
func helperProcess(s ...string) *exec.Cmd {
cs := []string{"-test.run=TestHelperProcess", "--"}
cs = append(cs, s...)
env := []string{
"GO_WANT_HELPER_PROCESS=1",
"TF_PLUGIN_MIN_PORT=10000",
"TF_PLUGIN_MAX_PORT=25000",
}
cmd := exec.Command(os.Args[0], cs...)
cmd.Env = append(env, os.Environ()...)
return cmd
}
// This is not a real test. This is just a helper process kicked off by
// tests.
func TestHelperProcess(*testing.T) {
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
return
}
defer os.Exit(0)
args := os.Args
for len(args) > 0 {
if args[0] == "--" {
args = args[1:]
break
}
args = args[1:]
}
if len(args) == 0 {
fmt.Fprintf(os.Stderr, "No command\n")
os.Exit(2)
}
cmd, args := args[0], args[1:]
switch cmd {
case "bad-version":
fmt.Printf("%s1|tcp|:1234|foo\n", APIVersion)
<-make(chan int)
case "resource-provider":
err := Serve(new(terraform.MockResourceProvider))
if err != nil {
log.Printf("[ERR] %s", err)
os.Exit(1)
}
case "invalid-rpc-address":
fmt.Println("lolinvalid")
case "mock":
fmt.Printf("%s|tcp|:1234|foo\n", APIVersion)
<-make(chan int)
case "start-timeout":
time.Sleep(1 * time.Minute)
os.Exit(1)
case "stderr":
fmt.Printf("%s|tcp|:1234|foo\n", APIVersion)
log.Println("HELLO")
log.Println("WORLD")
case "stdin":
fmt.Printf("%s|tcp|:1234|foo\n", APIVersion)
data := make([]byte, 5)
if _, err := os.Stdin.Read(data); err != nil {
log.Printf("stdin read error: %s", err)
os.Exit(100)
}
if string(data) == "hello" {
os.Exit(0)
}
os.Exit(1)
default:
fmt.Fprintf(os.Stderr, "Unknown command: %q\n", cmd)
os.Exit(2)
}
}

View File

@ -0,0 +1,23 @@
package plugin
import (
"testing"
)
func TestResourceProvider(t *testing.T) {
c := NewClient(&ClientConfig{Cmd: helperProcess("resource-provider")})
defer c.Kill()
_, err := c.Client()
if err != nil {
t.Fatalf("should not have error: %s", err)
}
service, err := c.Service()
if err != nil {
t.Fatalf("err: %s", err)
}
if service == "" {
t.Fatal("service should not be blank")
}
}

139
plugin/server.go Normal file
View File

@ -0,0 +1,139 @@
package plugin
import (
"errors"
"fmt"
"io/ioutil"
"log"
"net"
"net/rpc"
"os"
"os/signal"
"runtime"
"strconv"
"sync/atomic"
tfrpc "github.com/hashicorp/terraform/rpc"
)
// The APIVersion is outputted along with the RPC address. The plugin
// client validates this API version and will show an error if it doesn't
// know how to speak it.
const APIVersion = "1"
// The "magic cookie" is used to verify that the user intended to
// actually run this binary. If this cookie isn't present as an
// environmental variable, then we bail out early with an error.
const MagicCookieKey = "TF_PLUGIN_MAGIC_COOKIE"
const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d6991ca9872b2"
func Serve(svc interface{}) error {
// First check the cookie
if os.Getenv(MagicCookieKey) != MagicCookieValue {
return errors.New(
"Please do not execute plugins directly. " +
"Terraform will execute these for you.")
}
// Create the server to serve our interface
server := rpc.NewServer()
// Register the service
name, err := tfrpc.Register(server, svc)
if err != nil {
return err
}
// Register a listener so we can accept a connection
listener, err := serverListener()
if err != nil {
return err
}
defer listener.Close()
// Output the address and service name to stdout
log.Printf("Plugin address: %s %s\n",
listener.Addr().Network(), listener.Addr().String())
fmt.Printf("%s|%s|%s|%s\n",
APIVersion,
listener.Addr().Network(),
listener.Addr().String(),
name)
os.Stdout.Sync()
// Accept a connection
log.Println("Waiting for connection...")
conn, err := listener.Accept()
if err != nil {
log.Printf("Error accepting connection: %s\n", err.Error())
return err
}
// Eat the interrupts
ch := make(chan os.Signal, 1)
signal.Notify(ch, os.Interrupt)
go func() {
var count int32 = 0
for {
<-ch
newCount := atomic.AddInt32(&count, 1)
log.Printf(
"Received interrupt signal (count: %d). Ignoring.",
newCount)
}
}()
// Serve a single connection
log.Println("Serving a plugin connection...")
server.ServeConn(conn)
return nil
}
func serverListener() (net.Listener, error) {
if runtime.GOOS == "windows" {
return serverListener_tcp()
}
return serverListener_unix()
}
func serverListener_tcp() (net.Listener, error) {
minPort, err := strconv.ParseInt(os.Getenv("TF_PLUGIN_MIN_PORT"), 10, 32)
if err != nil {
return nil, err
}
maxPort, err := strconv.ParseInt(os.Getenv("TF_PLUGIN_MAX_PORT"), 10, 32)
if err != nil {
return nil, err
}
for port := minPort; port <= maxPort; port++ {
address := fmt.Sprintf("127.0.0.1:%d", port)
listener, err := net.Listen("tcp", address)
if err == nil {
return listener, nil
}
}
return nil, errors.New("Couldn't bind plugin TCP listener")
}
func serverListener_unix() (net.Listener, error) {
tf, err := ioutil.TempFile("", "tf-plugin")
if err != nil {
return nil, err
}
path := tf.Name()
// Close the file and remove it because it has to not exist for
// the domain socket.
if err := tf.Close(); err != nil {
return nil, err
}
if err := os.Remove(path); err != nil {
return nil, err
}
return net.Listen("unix", path)
}