def test_equal_nodes():
    """Test equal_nodes across the AST node types.

    Each check_equal_nodes call (helper defined elsewhere) receives a
    reference node followed by variants that each differ from it in a single
    respect: node type, operator, an operand, a name, or a list field.
    """
    assert equal_nodes(None, None)
    x = ref('x')
    y = ref('y')
    z = ref('z')
    ix = ampl.Indexing(x)
    iy = ampl.Indexing(y)
    check_equal_nodes(x, y)
    check_equal_nodes(ampl.ParenExpr(x), ampl.UnaryExpr('-', x))
    check_equal_nodes(ampl.UnaryExpr('-', x), ampl.UnaryExpr('+', x),
                      ampl.UnaryExpr('-', y))
    check_equal_nodes(ampl.BinaryExpr('*', x, y), ampl.BinaryExpr('+', x, y),
                      ampl.BinaryExpr('*', y, y), ampl.BinaryExpr('*', x, x))
    check_equal_nodes(ampl.IfExpr(x, y, z), ampl.IfExpr(z, y, z),
                      ampl.IfExpr(x, x, z), ampl.IfExpr(x, y, y))
    check_equal_nodes(ampl.CallExpr('sin', [x]), ampl.CallExpr('cos', [x]),
                      ampl.CallExpr('sin', [y]), ampl.CallExpr('sin', [x, y]))
    check_equal_nodes(ampl.SumExpr(ix, x), ampl.SumExpr(iy, x),
                      ampl.SumExpr(ix, y))
    check_equal_nodes(ix, iy)
    check_equal_nodes(ampl.Decl('var', 'x'), ampl.Decl('var', 'y'))
    check_equal_nodes(ampl.IncludeStmt('data'), ampl.IncludeStmt('model'))
    check_equal_nodes(
        ampl.DataStmt('param', 'S', ['a', 'b'], ['1', '2', '3']),
        ampl.DataStmt('var', 'S', ['a', 'b'], ['1', '2', '3']),
        ampl.DataStmt('param', 'T', ['a', 'b'], ['1', '2', '3']),
        ampl.DataStmt('param', 'S', ['a'], ['1', '2', '3']),
        ampl.DataStmt('param', 'S', ['a', 'c'], ['1', '2', '3']),
        # Param names match the reference so these two variants differ only
        # in their values (shorter list, then a different last value).
        ampl.DataStmt('param', 'S', ['a', 'b'], ['1', '2']),
        ampl.DataStmt('param', 'S', ['a', 'b'], ['1', '2', '4']))
    # Nodes of unrelated types must not compare equal.
    assert not equal_nodes(ampl.Decl('var', 'x'), ampl.CompoundStmt([]))
def test_pretty_print():
    """Test that each AST node kind prints to the expected AMPL text."""
    minimize_decl = ampl.Decl('minimize', 'o')
    minimize_decl.body = ampl.UnaryExpr('-', ref('x'))
    data_values = list(map(str, range(6)))
    expected_data = ('param:\n'
                     'S:a b :=\n'
                     '0 1 2\n'
                     '3 4 5\n'
                     ';\n')
    # (expected text, node to print), checked in order.
    cases = [
        ('a', ref('a')),
        ('a[b]', ampl.SubscriptExpr('a', ref('b'))),
        ('(a)', ampl.ParenExpr(ref('a'))),
        ('-a', ampl.UnaryExpr('-', ref('a'))),
        ('a + b', ampl.BinaryExpr('+', ref('a'), ref('b'))),
        ('if a then b else c',
         ampl.IfExpr(ref('a'), ref('b'), ref('c'))),
        ('if a then b', ampl.IfExpr(ref('a'), ref('b'), None)),
        ('f(a, b)', ampl.CallExpr('f', [ref('a'), ref('b')])),
        ('sum{s in S} x[s]',
         ampl.SumExpr(ampl.Indexing(ref('S'), 's'),
                      ampl.SubscriptExpr('x', ref('s')))),
        ('{s in S}', ampl.Indexing(ref('S'), 's')),
        ('{S}', ampl.Indexing(ref('S'))),
        ('= a', ampl.InitAttr(ref('a'))),
        ('in [a, b]', ampl.InAttr(ref('a'), ref('b'))),
        ('var x{S} = a;\n',
         ampl.Decl('var', 'x', ampl.Indexing(ref('S')),
                   [ampl.InitAttr(ref('a'))])),
        ('minimize o: -x;\n', minimize_decl),
        ('model;\n', ampl.IncludeStmt('model')),
        (expected_data,
         ampl.DataStmt('param', 'S', ['a', 'b'], data_values)),
        ('model;\nvar x;\n',
         ampl.CompoundStmt([ampl.IncludeStmt('model'),
                            ampl.Decl('var', 'x')])),
    ]
    for text, node in cases:
        check_print(text, node)
def test_call():
    """Test CallExpr construction and attributes, then check_accept on it.

    Verifies the exact node type, that func_name and args are stored as
    given, and runs check_accept with 'visit_call'.
    """
    arg0 = ref('a')
    arg1 = ref('b')
    expr = ampl.CallExpr('foo', [arg0, arg1])
    # Exact-type check on purpose (not isinstance); classes are compared by
    # identity rather than with == (pycodestyle E721).
    assert type(expr) is ampl.CallExpr
    assert expr.func_name == 'foo'
    assert expr.args == [arg0, arg1]
    check_accept(expr, 'visit_call')
def test_parse_expr_precedence():
    """Test how the parser groups operators with differing precedence.

    Covers sum bodies, if-then(-else) arms, and '||' against relational and
    'in' operators. Each case names the check helper used to parse it.
    """
    x, y, z = ref('x'), ref('y'), ref('z')
    idx = ampl.Indexing(ref('S'))
    cases = [
        (check_parse_lexpr, '(x || y)',
         ampl.ParenExpr(ampl.BinaryExpr('||', x, y))),
        (check_parse_lexpr, 'sum{S} x * y',
         ampl.SumExpr(idx, ampl.BinaryExpr('*', x, y))),
        (check_parse_lexpr, 'sum{S} x + y',
         ampl.BinaryExpr('+', ampl.SumExpr(idx, x), y)),
        (check_parse_expr, 'if x || y then z',
         ampl.IfExpr(ampl.BinaryExpr('||', x, y), z)),
        (check_parse_expr, 'if x then y & z',
         ampl.IfExpr(x, ampl.BinaryExpr('&', y, z))),
        (check_parse_lexpr, 'if x then y in S',
         ampl.BinaryExpr('in', ampl.IfExpr(x, y), ref('S'))),
        (check_parse_expr, 'if x then y else x & z',
         ampl.IfExpr(x, y, ampl.BinaryExpr('&', x, z))),
        (check_parse_lexpr, 'if x then y else z in S',
         ampl.BinaryExpr('in', ampl.IfExpr(x, y, z), ref('S'))),
        (check_parse_expr, 'sin(if x then y)',
         ampl.CallExpr('sin', [ampl.IfExpr(x, y)])),
        (check_parse_expr, 'a[if x then y]',
         ampl.SubscriptExpr('a', ampl.IfExpr(x, y))),
        (check_parse_lexpr, 'x || y > z',
         ampl.BinaryExpr('||', x, ampl.BinaryExpr('>', y, z))),
        (check_parse_lexpr, 'x || y in z',
         ampl.BinaryExpr('||', x, ampl.BinaryExpr('in', y, z))),
    ]
    for check, text, expected in cases:
        check(text, expected)
def test_parse_expr():
    """Test parsing of atoms, sums, ifs, calls, subscripts and binary ops."""
    check_parse_expr('42', ref('42'))
    check_parse_expr('-x', ampl.UnaryExpr('-', ref('x')))
    check_parse_expr('(x)', ampl.ParenExpr(ref('x')))
    check_parse_expr('sum{S} x', ampl.SumExpr(ampl.Indexing(ref('S')), ref('x')))
    check_parse_expr('sum{s in S} x',
                     ampl.SumExpr(ampl.Indexing(ref('S'), 's'), ref('x')))
    check_parse_expr('if x then y', ampl.IfExpr(ref('x'), ref('y')))
    check_parse_expr('if x then y else z',
                     ampl.IfExpr(ref('x'), ref('y'), ref('z')))
    builtin_funcs = [
        'abs', 'acos', 'acosh', 'alias', 'asin', 'asinh', 'atan', 'atan2',
        'atanh', 'ceil', 'ctime', 'cos', 'exp', 'floor', 'log', 'log10',
        'max', 'min', 'precision', 'round', 'sin', 'sinh', 'sqrt', 'tan',
        'tanh', 'time', 'trunc',
    ]
    for f in builtin_funcs:
        check_parse_expr(f + '(x)', ampl.CallExpr(f, [ref('x')]))
    check_parse_expr('x[y]', ampl.SubscriptExpr('x', ref('y')))
    for op in ['^', '**', '+', '-', '*', '/']:
        check_parse_expr('x ' + op + ' y',
                         ampl.BinaryExpr(op, ref('x'), ref('y')))
    # These operators are checked via check_parse_lexpr rather than
    # check_parse_expr. Each relational operator appears exactly once,
    # including '>=' (previously '<=' was listed twice and '>=' not at all).
    for op in ['||', 'or', 'in', '<', '<=', '=', '==', '<>', '!=', '>=', '>']:
        check_parse_lexpr('x ' + op + ' y',
                          ampl.BinaryExpr(op, ref('x'), ref('y')))