def test_equal_nodes():
    """Check equal_nodes across every kind of AST node.

    Each group handed to check_equal_nodes starts with a base node. The
    remaining nodes differ from it either in node type or in a single field
    (operator, operand, name, argument list, ...).
    """
    assert equal_nodes(None, None)
    x, y, z = ref('x'), ref('y'), ref('z')
    ix, iy = ampl.Indexing(x), ampl.Indexing(y)

    groups = [
        (x, y),
        (ampl.ParenExpr(x), ampl.UnaryExpr('-', x)),
        (ampl.UnaryExpr('-', x), ampl.UnaryExpr('+', x),
         ampl.UnaryExpr('-', y)),
        (ampl.BinaryExpr('*', x, y), ampl.BinaryExpr('+', x, y),
         ampl.BinaryExpr('*', y, y), ampl.BinaryExpr('*', x, x)),
        (ampl.IfExpr(x, y, z), ampl.IfExpr(z, y, z),
         ampl.IfExpr(x, x, z), ampl.IfExpr(x, y, y)),
        (ampl.CallExpr('sin', [x]), ampl.CallExpr('cos', [x]),
         ampl.CallExpr('sin', [y]), ampl.CallExpr('sin', [x, y])),
        (ampl.SumExpr(ix, x), ampl.SumExpr(iy, x), ampl.SumExpr(ix, y)),
        (ix, iy),
        (ampl.Decl('var', 'x'), ampl.Decl('var', 'y')),
        (ampl.IncludeStmt('data'), ampl.IncludeStmt('model')),
        # Data statements vary kind, name, parameter names and values.
        (ampl.DataStmt('param', 'S', ['a', 'b'], ['1', '2', '3']),
         ampl.DataStmt('var', 'S', ['a', 'b'], ['1', '2', '3']),
         ampl.DataStmt('param', 'T', ['a', 'b'], ['1', '2', '3']),
         ampl.DataStmt('param', 'S', ['a'], ['1', '2', '3']),
         ampl.DataStmt('param', 'S', ['a', 'c'], ['1', '2', '3']),
         ampl.DataStmt('param', 'S', ['a', 'v'], ['1', '2']),
         ampl.DataStmt('param', 'S', ['a', 'v'], ['1', '2', '4'])),
    ]
    for group in groups:
        check_equal_nodes(*group)

    # Nodes of unrelated statement types never compare equal.
    assert not equal_nodes(ampl.Decl('var', 'x'), ampl.CompoundStmt([]))
def test_pretty_print():
    """Check the source text printed for each kind of AST node and statement."""
    objective = ampl.Decl('minimize', 'o')
    objective.body = ampl.UnaryExpr('-', ref('x'))
    data_values = list(map(str, range(6)))

    cases = [
        ('a', ref('a')),
        ('a[b]', ampl.SubscriptExpr('a', ref('b'))),
        ('(a)', ampl.ParenExpr(ref('a'))),
        ('-a', ampl.UnaryExpr('-', ref('a'))),
        ('a + b', ampl.BinaryExpr('+', ref('a'), ref('b'))),
        ('if a then b else c', ampl.IfExpr(ref('a'), ref('b'), ref('c'))),
        # The else branch is omitted from the output when it is None.
        ('if a then b', ampl.IfExpr(ref('a'), ref('b'), None)),
        ('f(a, b)', ampl.CallExpr('f', [ref('a'), ref('b')])),
        ('sum{s in S} x[s]',
         ampl.SumExpr(ampl.Indexing(ref('S'), 's'),
                      ampl.SubscriptExpr('x', ref('s')))),
        ('{s in S}', ampl.Indexing(ref('S'), 's')),
        ('{S}', ampl.Indexing(ref('S'))),
        ('= a', ampl.InitAttr(ref('a'))),
        ('in [a, b]', ampl.InAttr(ref('a'), ref('b'))),
        ('var x{S} = a;\n',
         ampl.Decl('var', 'x', ampl.Indexing(ref('S')),
                   [ampl.InitAttr(ref('a'))])),
        ('minimize o: -x;\n', objective),
        ('model;\n', ampl.IncludeStmt('model')),
        # Six values under two parameter names print as two rows of three.
        ('param:\n' 'S:a b :=\n' '0 1 2\n' '3 4 5\n' ';\n',
         ampl.DataStmt('param', 'S', ['a', 'b'], data_values)),
        ('model;\nvar x;\n',
         ampl.CompoundStmt([ampl.IncludeStmt('model'),
                            ampl.Decl('var', 'x')])),
    ]
    for expected, node in cases:
        check_print(expected, node)
def test_indexing():
    """Check Indexing construction with and without an index name.

    Also checks that an Indexing node dispatches to ``visit_indexing``.
    """
    index = 'a'
    set_expr = ref('b')
    expr = ampl.Indexing(set_expr, index)
    assert type(expr) is ampl.Indexing
    assert expr.index == index
    assert expr.set_expr == set_expr

    # The index name is optional. When it is omitted it must be None,
    # which is an identity check rather than an equality check.
    expr = ampl.Indexing(set_expr)
    assert expr.index is None
    assert expr.set_expr == set_expr
    check_accept(expr, 'visit_indexing')
def test_sum():
    """SumExpr keeps its indexing and argument and dispatches to visit_sum."""
    idx = ampl.Indexing(ref('a'))
    body = ref('b')
    node = ampl.SumExpr(idx, body)

    assert type(node) is ampl.SumExpr
    assert node.indexing == idx
    assert node.arg == body
    check_accept(node, 'visit_sum')
def test_parse_expr():
    """Check parsing of each expression form in isolation.

    Covers literals, unary minus, parentheses, sums, if-expressions, builtin
    function calls, subscripts and binary operators. Arithmetic operators
    are checked with check_parse_expr. Relational and logical operators are
    checked with check_parse_lexpr.
    """
    check_parse_expr('42', ref('42'))
    check_parse_expr('-x', ampl.UnaryExpr('-', ref('x')))
    check_parse_expr('(x)', ampl.ParenExpr(ref('x')))
    check_parse_expr('sum{S} x', ampl.SumExpr(ampl.Indexing(ref('S')), ref('x')))
    check_parse_expr('sum{s in S} x',
                     ampl.SumExpr(ampl.Indexing(ref('S'), 's'), ref('x')))
    check_parse_expr('if x then y', ampl.IfExpr(ref('x'), ref('y')))
    check_parse_expr('if x then y else z',
                     ampl.IfExpr(ref('x'), ref('y'), ref('z')))

    builtin_functions = [
        'abs', 'acos', 'acosh', 'alias', 'asin', 'asinh', 'atan', 'atan2',
        'atanh', 'ceil', 'ctime', 'cos', 'exp', 'floor', 'log', 'log10',
        'max', 'min', 'precision', 'round', 'sin', 'sinh', 'sqrt', 'tan',
        'tanh', 'time', 'trunc']
    for f in builtin_functions:
        check_parse_expr(f + '(x)', ampl.CallExpr(f, [ref('x')]))

    check_parse_expr('x[y]', ampl.SubscriptExpr('x', ref('y')))

    for op in ['^', '**', '+', '-', '*', '/']:
        check_parse_expr('x ' + op + ' y',
                         ampl.BinaryExpr(op, ref('x'), ref('y')))

    # Each relational/logical operator is listed exactly once. '>=' sits
    # alongside '<=' so both directions of non-strict comparison are covered.
    for op in ['||', 'or', 'in', '<', '<=', '=', '==', '<>', '!=', '>=', '>']:
        check_parse_lexpr('x ' + op + ' y',
                          ampl.BinaryExpr(op, ref('x'), ref('y')))
def test_decl():
    """Check Decl construction, its optional arguments, and visitor dispatch."""
    indexing = ampl.Indexing(ref('a'))
    attrs = [ampl.InitAttr(ref('a'))]
    decl = ampl.Decl('var', 'x', indexing, attrs)
    assert type(decl) is ampl.Decl
    assert decl.kind == 'var'
    assert decl.name == 'x'
    assert decl.indexing == indexing
    assert decl.attrs == attrs
    # No body is supplied through the constructor.
    assert decl.body is None

    # Omitting attrs gives an empty list.
    decl = ampl.Decl('var', 'x', indexing)
    assert decl.attrs == []

    # Omitting indexing leaves it unset (None).
    decl = ampl.Decl('var', 'x')
    assert decl.indexing is None
    check_accept(decl, 'visit_decl')
def test_parse_expr_precedence():
    """Check operator precedence and how far each construct extends when parsed."""
    x, y, z = ref('x'), ref('y'), ref('z')
    binary, cond = ampl.BinaryExpr, ampl.IfExpr
    over_s = ampl.Indexing(ref('S'))

    check_parse_lexpr('(x || y)', ampl.ParenExpr(binary('||', x, y)))

    # A sum's body absorbs '*' but stops before '+'.
    check_parse_lexpr('sum{S} x * y', ampl.SumExpr(over_s, binary('*', x, y)))
    check_parse_lexpr('sum{S} x + y', binary('+', ampl.SumExpr(over_s, x), y))

    # Condition and branches of an if-expression, and where the branches end.
    check_parse_expr('if x || y then z', cond(binary('||', x, y), z))
    check_parse_expr('if x then y & z', cond(x, binary('&', y, z)))
    check_parse_lexpr('if x then y in S',
                      binary('in', cond(x, y), ref('S')))
    check_parse_expr('if x then y else x & z',
                     cond(x, y, binary('&', x, z)))
    check_parse_lexpr('if x then y else z in S',
                      binary('in', cond(x, y, z), ref('S')))

    # If-expressions nested inside call arguments and subscripts.
    check_parse_expr('sin(if x then y)', ampl.CallExpr('sin', [cond(x, y)]))
    check_parse_expr('a[if x then y]', ampl.SubscriptExpr('a', cond(x, y)))

    # '||' binds more loosely than relational operators and 'in'.
    check_parse_lexpr('x || y > z', binary('||', x, binary('>', y, z)))
    check_parse_lexpr('x || y in z', binary('||', x, binary('in', y, z)))
def test_parse_decls():
    """Check that every declaration keyword parses, both plain and indexed over a set."""
    set_decl = ampl.Decl('set', 'S')
    over_s = ampl.Indexing(ref('S'))
    keywords = ('param', 'var', 'set', 'minimize', 'maximize')
    for keyword in keywords:
        head = keyword + ' a'
        check_parse(head + ';', ampl.Decl(keyword, 'a'))
        # Indexed form: the set is declared first, then the indexed declaration.
        check_parse('set S; ' + head + '{S};',
                    set_decl, ampl.Decl(keyword, 'a', over_s))