diff --git a/config/lang/check_types.go b/config/lang/check_types.go index 882c64ca3..0165eadfc 100644 --- a/config/lang/check_types.go +++ b/config/lang/check_types.go @@ -24,9 +24,19 @@ type TypeCheck struct { // value is the function to call (which must be registered in the Scope). Implicit map[ast.Type]map[ast.Type]string - stack []ast.Type - err error - lock sync.Mutex + // Stack of types. This shouldn't be used directly except by implementations + // of TypeCheckNode. + Stack []ast.Type + + err error + lock sync.Mutex +} + +// TypeCheckNode is the interface that must be implemented by any +// ast.Node that wants to support type-checking. If the type checker +// encounters a node that doesn't implement this, it will error. +type TypeCheckNode interface { + TypeCheck(*TypeCheck) (ast.Node, error) } func (v *TypeCheck) Visit(root ast.Node) error { @@ -42,49 +52,69 @@ func (v *TypeCheck) visit(raw ast.Node) ast.Node { return raw } + var result ast.Node + var err error switch n := raw.(type) { case *ast.Call: - v.visitCall(n) + tc := &typeCheckCall{n} + result, err = tc.TypeCheck(v) case *ast.Concat: - v.visitConcat(n) + tc := &typeCheckConcat{n} + result, err = tc.TypeCheck(v) case *ast.LiteralNode: - v.visitLiteral(n) + tc := &typeCheckLiteral{n} + result, err = tc.TypeCheck(v) case *ast.VariableAccess: - v.visitVariableAccess(n) + tc := &typeCheckVariableAccess{n} + result, err = tc.TypeCheck(v) default: - v.createErr(n, fmt.Sprintf("unknown node: %#v", raw)) + tc, ok := raw.(TypeCheckNode) + if !ok { + err = fmt.Errorf("unknown node: %#v", raw) + break + } + + result, err = tc.TypeCheck(v) } - return raw + if err != nil { + pos := raw.Pos() + v.err = fmt.Errorf("At column %d, line %d: %s", + pos.Column, pos.Line, err) + } + + return result } -func (v *TypeCheck) visitCall(n *ast.Call) { +type typeCheckCall struct { + n *ast.Call +} + +func (tc *typeCheckCall) TypeCheck(v *TypeCheck) (ast.Node, error) { // Look up the function in the map - function, ok := v.Scope.LookupFunc(n.Func) + function, ok := v.Scope.LookupFunc(tc.n.Func) if !ok { - v.createErr(n, fmt.Sprintf("unknown function called: %s", n.Func)) - return + return nil, fmt.Errorf("unknown function called: %s", tc.n.Func) } // The arguments are on the stack in reverse order, so pop them off. - args := make([]ast.Type, len(n.Args)) - for i, _ := range n.Args { - args[len(n.Args)-1-i] = v.stackPop() + args := make([]ast.Type, len(tc.n.Args)) + for i, _ := range tc.n.Args { + args[len(tc.n.Args)-1-i] = v.StackPop() } // Verify the args for i, expected := range function.ArgTypes { if args[i] != expected { - cn := v.implicitConversion(args[i], expected, n.Args[i]) + cn := v.ImplicitConversion(args[i], expected, tc.n.Args[i]) if cn != nil { - n.Args[i] = cn + tc.n.Args[i] = cn continue } - v.createErr(n, fmt.Sprintf( + return nil, fmt.Errorf( "%s: argument %d should be %s, got %s", - n.Func, i+1, expected, args[i])) - return + tc.n.Func, i+1, expected, args[i]) } } @@ -94,75 +124,86 @@ func (v *TypeCheck) visitCall(n *ast.Call) { for i, t := range args { if t != function.VariadicType { realI := i + len(function.ArgTypes) - cn := v.implicitConversion( - t, function.VariadicType, n.Args[realI]) + cn := v.ImplicitConversion( + t, function.VariadicType, tc.n.Args[realI]) if cn != nil { - n.Args[realI] = cn + tc.n.Args[realI] = cn continue } - v.createErr(n, fmt.Sprintf( + return nil, fmt.Errorf( "%s: argument %d should be %s, got %s", - n.Func, realI, - function.VariadicType, t)) - return + tc.n.Func, realI, + function.VariadicType, t) } } } // Return type - v.stackPush(function.ReturnType) + v.StackPush(function.ReturnType) + + return tc.n, nil } -func (v *TypeCheck) visitConcat(n *ast.Concat) { +type typeCheckConcat struct { + n *ast.Concat +} + +func (tc *typeCheckConcat) TypeCheck(v *TypeCheck) (ast.Node, error) { + n := tc.n types := make([]ast.Type, len(n.Exprs)) for i, _ := range n.Exprs { - types[len(n.Exprs)-1-i] = v.stackPop() + types[len(n.Exprs)-1-i] = v.StackPop() } // All concat args must be strings, so validate that for i, t := range types { if t != ast.TypeString { - cn := v.implicitConversion(t, ast.TypeString, n.Exprs[i]) + cn := v.ImplicitConversion(t, ast.TypeString, n.Exprs[i]) if cn != nil { n.Exprs[i] = cn continue } - v.createErr(n, fmt.Sprintf( - "argument %d must be a string", i+1)) - return + return nil, fmt.Errorf( + "argument %d must be a string", i+1) } } // This always results in type string - v.stackPush(ast.TypeString) + v.StackPush(ast.TypeString) + + return n, nil } -func (v *TypeCheck) visitLiteral(n *ast.LiteralNode) { - v.stackPush(n.Typex) +type typeCheckLiteral struct { + n *ast.LiteralNode } -func (v *TypeCheck) visitVariableAccess(n *ast.VariableAccess) { +func (tc *typeCheckLiteral) TypeCheck(v *TypeCheck) (ast.Node, error) { + v.StackPush(tc.n.Typex) + return tc.n, nil +} + +type typeCheckVariableAccess struct { + n *ast.VariableAccess +} + +func (tc *typeCheckVariableAccess) TypeCheck(v *TypeCheck) (ast.Node, error) { // Look up the variable in the map - variable, ok := v.Scope.LookupVar(n.Name) + variable, ok := v.Scope.LookupVar(tc.n.Name) if !ok { - v.createErr(n, fmt.Sprintf( - "unknown variable accessed: %s", n.Name)) - return + return nil, fmt.Errorf( + "unknown variable accessed: %s", tc.n.Name) } // Add the type to the stack - v.stackPush(variable.Type) + v.StackPush(variable.Type) + + return tc.n, nil } -func (v *TypeCheck) createErr(n ast.Node, str string) { - pos := n.Pos() - v.err = fmt.Errorf("At column %d, line %d: %s", - pos.Column, pos.Line, str) -} - -func (v *TypeCheck) implicitConversion( +func (v *TypeCheck) ImplicitConversion( actual ast.Type, expected ast.Type, n ast.Node) ast.Node { if v.Implicit == nil { return nil @@ -186,16 +227,16 @@ func (v *TypeCheck) implicitConversion( } func (v *TypeCheck) reset() { - v.stack = nil + v.Stack = nil v.err = nil } -func (v *TypeCheck) stackPush(t ast.Type) { - v.stack = append(v.stack, t) +func (v *TypeCheck) StackPush(t ast.Type) { + v.Stack = append(v.Stack, t) } -func (v *TypeCheck) stackPop() ast.Type { +func (v *TypeCheck) StackPop() ast.Type { var x ast.Type - x, v.stack = v.stack[len(v.stack)-1], v.stack[:len(v.stack)-1] + x, v.Stack = v.Stack[len(v.Stack)-1], v.Stack[:len(v.Stack)-1] return x }