diff --git a/lang/funcs/collection.go b/lang/funcs/collection.go index b7b00db35..a806d1caf 100644 --- a/lang/funcs/collection.go +++ b/lang/funcs/collection.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/zclconf/go-cty/cty" + "github.com/zclconf/go-cty/cty/convert" "github.com/zclconf/go-cty/cty/function" "github.com/zclconf/go-cty/cty/function/stdlib" "github.com/zclconf/go-cty/cty/gocty" @@ -106,6 +107,60 @@ var LengthFunc = function.New(&function.Spec{ }, }) +// CoalesceListFunc contructs a function that takes any number of list arguments +// and returns the first one that isn't empty. +var CoalesceListFunc = function.New(&function.Spec{ + Params: []function.Parameter{}, + VarParam: &function.Parameter{ + Name: "vals", + Type: cty.DynamicPseudoType, + AllowUnknown: true, + AllowDynamicType: true, + AllowNull: true, + }, + Type: func(args []cty.Value) (ret cty.Type, err error) { + if len(args) == 0 { + return cty.NilType, fmt.Errorf("at least one argument is required") + } + + argTypes := make([]cty.Type, len(args)) + + for i, arg := range args { + arg, err = convert.Convert(arg, cty.DynamicPseudoType) + if err != nil { + return cty.NilType, fmt.Errorf("all arguments must be lists or tuples") + } + + argTypes[i] = arg.Type() + } + + fmt.Printf("%#v\n", argTypes) + + return cty.List(cty.DynamicPseudoType), nil + }, + Impl: func(args []cty.Value, retType cty.Type) (ret cty.Value, err error) { + + vals := make([]cty.Value, 0, len(args)) + for _, arg := range args { + + // We already know this will succeed because of the checks in our Type func above + arg, _ = convert.Convert(arg, retType) + + it := arg.ElementIterator() + for it.Next() { + _, v := it.Element() + vals = append(vals, v) + } + + if len(vals) > 0 { + return cty.ListVal(vals), nil + } + } + + return cty.NilVal, fmt.Errorf("no non-null arguments") + }, +}) + // Element returns a single element from a given list at the given index. If // index is greater than the length of the list then it is wrapped modulo // the list length. @@ -118,3 +173,8 @@ func Element(list, index cty.Value) (cty.Value, error) { func Length(collection cty.Value) (cty.Value, error) { return LengthFunc.Call([]cty.Value{collection}) } + +// CoalesceList takes any number of list arguments and returns the first one that isn't empty. +func CoalesceList(args ...cty.Value) (cty.Value, error) { + return CoalesceListFunc.Call(args) +} diff --git a/lang/funcs/collection_test.go b/lang/funcs/collection_test.go index 62e3c9376..b30709a3c 100644 --- a/lang/funcs/collection_test.go +++ b/lang/funcs/collection_test.go @@ -222,3 +222,98 @@ func TestLength(t *testing.T) { }) } } + +func TestCoalesceList(t *testing.T) { + tests := []struct { + Values []cty.Value + Want cty.Value + Err bool + }{ + { + []cty.Value{ + cty.ListVal([]cty.Value{ + cty.StringVal("first"), cty.StringVal("second"), + }), + cty.ListVal([]cty.Value{ + cty.StringVal("third"), cty.StringVal("fourth"), + }), + }, + cty.ListVal([]cty.Value{ + cty.StringVal("first"), cty.StringVal("second"), + }), + false, + }, + { + []cty.Value{ + cty.ListValEmpty(cty.String), + cty.ListVal([]cty.Value{ + cty.StringVal("third"), cty.StringVal("fourth"), + }), + }, + cty.ListVal([]cty.Value{ + cty.StringVal("third"), cty.StringVal("fourth"), + }), + false, + }, + { + []cty.Value{ + cty.ListValEmpty(cty.Number), + cty.ListVal([]cty.Value{ + cty.NumberIntVal(1), + cty.NumberIntVal(2), + }), + }, + cty.ListVal([]cty.Value{ + cty.NumberIntVal(1), + cty.NumberIntVal(2), + }), + false, + }, + { // lists with value type mismatch + []cty.Value{ + cty.ListVal([]cty.Value{ + cty.StringVal("first"), cty.StringVal("second"), + }), + cty.ListVal([]cty.Value{ + cty.NumberIntVal(1), + cty.NumberIntVal(2), + }), + }, + cty.NilVal, + true, + }, + { // mixed list and tuple + []cty.Value{ + cty.ListVal([]cty.Value{ + cty.StringVal("first"), cty.StringVal("second"), + }), + cty.TupleVal([]cty.Value{ + cty.StringVal("third"), + }), + }, + cty.ListVal([]cty.Value{ + cty.StringVal("first"), cty.StringVal("second"), + }), + false, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("coalescelist(%#v)", test.Values), func(t *testing.T) { + got, err := CoalesceList(test.Values...) + + if test.Err { + if err == nil { + t.Fatal("succeeded; want error") + } + return + } else if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if !got.RawEquals(test.Want) { + t.Errorf("wrong result\ngot: %#v\nwant: %#v", got, test.Want) + } + }) + } +} diff --git a/lang/functions.go b/lang/functions.go index 0863ef6b3..fd69d9002 100644 --- a/lang/functions.go +++ b/lang/functions.go @@ -43,7 +43,7 @@ func (s *Scope) Functions() map[string]function.Function { "cidrnetmask": funcs.CidrNetmaskFunc, "cidrsubnet": funcs.CidrSubnetFunc, "coalesce": stdlib.CoalesceFunc, - "coalescelist": unimplFunc, // TODO + "coalescelist": funcs.CoalesceListFunc, "compact": unimplFunc, // TODO "concat": stdlib.ConcatFunc, "contains": unimplFunc, // TODO