diff --git a/lang/funcs/collection.go b/lang/funcs/collection.go index ee3ad2b3a..5422a0f9d 100644 --- a/lang/funcs/collection.go +++ b/lang/funcs/collection.go @@ -427,32 +427,47 @@ var FlattenFunc = function.New(&function.Spec{ Params: []function.Parameter{ { Name: "list", - Type: cty.List(cty.DynamicPseudoType), + Type: cty.DynamicPseudoType, }, }, - Type: function.StaticReturnType(cty.List(cty.DynamicPseudoType)), + Type: func(args []cty.Value) (cty.Type, error) { + if !args[0].IsWhollyKnown() { + return cty.DynamicPseudoType, nil + } + + argTy := args[0].Type() + if !argTy.IsListType() && !argTy.IsSetType() && !argTy.IsTupleType() { + return cty.NilType, fmt.Errorf("can only flatten lists, sets and tuples") + } + + outputList := make([]cty.Value, 0) + retVal := flattener(outputList, args[0]) + tys := make([]cty.Type, len(retVal)) + for i, ty := range retVal { + tys[i] = ty.Type() + } + return cty.Tuple(tys), nil + }, Impl: func(args []cty.Value, retType cty.Type) (ret cty.Value, err error) { inputList := args[0] if !inputList.IsWhollyKnown() { return cty.UnknownVal(retType), nil } - if inputList.LengthInt() == 0 { - return cty.ListValEmpty(retType.ElementType()), nil + return cty.EmptyTupleVal, nil } outputList := make([]cty.Value, 0) - return cty.ListVal(flattener(outputList, inputList)), nil + return cty.TupleVal(flattener(outputList, inputList)), nil }, }) // Flatten until it's not a cty.List func flattener(finalList []cty.Value, flattenList cty.Value) []cty.Value { - for it := flattenList.ElementIterator(); it.Next(); { _, val := it.Element() - if val.Type().IsListType() { + if val.Type().IsListType() || val.Type().IsSetType() || val.Type().IsTupleType() { finalList = flattener(finalList, val) } else { finalList = append(finalList, val) diff --git a/lang/funcs/collection_test.go b/lang/funcs/collection_test.go index ac686fa72..4339d3af4 100644 --- a/lang/funcs/collection_test.go +++ b/lang/funcs/collection_test.go @@ -1035,7 +1035,7 @@ func TestFlatten(t *testing.T) { cty.StringVal("d"), }), }), - cty.ListVal([]cty.Value{ + cty.TupleVal([]cty.Value{ cty.StringVal("a"), cty.StringVal("b"), cty.StringVal("c"), @@ -1054,12 +1054,50 @@ func TestFlatten(t *testing.T) { cty.StringVal("d"), }), }), - cty.UnknownVal(cty.List(cty.DynamicPseudoType)), + cty.DynamicVal, false, }, { cty.ListValEmpty(cty.String), - cty.ListValEmpty(cty.DynamicPseudoType), + cty.EmptyTupleVal, + false, + }, + { + cty.SetVal([]cty.Value{ + cty.SetVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("b"), + }), + cty.SetVal([]cty.Value{ + cty.StringVal("c"), + cty.StringVal("d"), + }), + }), + cty.TupleVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("b"), + cty.StringVal("c"), + cty.StringVal("d"), + }), + false, + }, + { + cty.TupleVal([]cty.Value{ + cty.SetVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("b"), + }), + cty.ListVal([]cty.Value{ + cty.StringVal("c"), + cty.StringVal("d"), + }), + }), + cty.TupleVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("b"), + cty.StringVal("c"), + cty.StringVal("d"), + }), false, }, } diff --git a/lang/functions_test.go b/lang/functions_test.go index 9bab6887d..9233d0ff9 100644 --- a/lang/functions_test.go +++ b/lang/functions_test.go @@ -309,8 +309,8 @@ func TestFunctions(t *testing.T) { "flatten": { { - `flatten([tolist(["a", "b"]), tolist(["c", "d"])])`, - cty.ListVal([]cty.Value{ + `flatten([["a", "b"], ["c", "d"]])`, + cty.TupleVal([]cty.Value{ cty.StringVal("a"), cty.StringVal("b"), cty.StringVal("c"),