def test_equal_nodes():
    """Exercise equal_nodes on every node kind via check_equal_nodes.

    Each check_equal_nodes call passes a reference node first, followed by
    nodes that differ from it in exactly one component.
    """
    # Short local aliases for the node constructors used below.
    Paren = ampl.ParenExpr
    Unary = ampl.UnaryExpr
    Binary = ampl.BinaryExpr
    If = ampl.IfExpr
    Call = ampl.CallExpr
    Sum = ampl.SumExpr
    Decl = ampl.Decl
    Include = ampl.IncludeStmt
    Data = ampl.DataStmt

    assert equal_nodes(None, None)

    x, y, z = ref('x'), ref('y'), ref('z')
    ix, iy = ampl.Indexing(x), ampl.Indexing(y)

    # Expressions.
    check_equal_nodes(x, y)
    check_equal_nodes(Paren(x), Unary('-', x))
    check_equal_nodes(Unary('-', x), Unary('+', x), Unary('-', y))
    check_equal_nodes(
        Binary('*', x, y),
        Binary('+', x, y),
        Binary('*', y, y),
        Binary('*', x, x),
    )
    check_equal_nodes(
        If(x, y, z),
        If(z, y, z),
        If(x, x, z),
        If(x, y, y),
    )
    check_equal_nodes(
        Call('sin', [x]),
        Call('cos', [x]),
        Call('sin', [y]),
        Call('sin', [x, y]),
    )
    check_equal_nodes(Sum(ix, x), Sum(iy, x), Sum(ix, y))
    check_equal_nodes(ix, iy)

    # Declarations and statements.
    check_equal_nodes(Decl('var', 'x'), Decl('var', 'y'))
    check_equal_nodes(Include('data'), Include('model'))
    check_equal_nodes(
        Data('param', 'S', ['a', 'b'], ['1', '2', '3']),
        Data('var', 'S', ['a', 'b'], ['1', '2', '3']),
        Data('param', 'T', ['a', 'b'], ['1', '2', '3']),
        Data('param', 'S', ['a'], ['1', '2', '3']),
        Data('param', 'S', ['a', 'c'], ['1', '2', '3']),
        Data('param', 'S', ['a', 'v'], ['1', '2']),
        Data('param', 'S', ['a', 'v'], ['1', '2', '4']),
    )

    # Nodes of different types never compare equal.
    assert not equal_nodes(Decl('var', 'x'), ampl.CompoundStmt([]))
def merge_models(models):
    """
    Merge given AMPL models into a single one using product composition
    of objective functions.

    For example, two models

        minimize o: f1(x);

    and

        minimize o: f2(x);

    are combined into a single model

        minimize f: f1(x1) * f2(x2);

    Each model is passed to prepare_for_merge together with its 1-based
    position in *models*; the returned head and tail statement lists are
    concatenated in order, and the returned best objective values are
    multiplied together.

    Arguments:
      models: an iterable of AMPL models accepted by prepare_for_merge.

    Returns a pair (merged_model, merged_best_obj) where merged_model is an
    ampl.CompoundStmt holding all head statements, the merged objective
    declaration named 'f', and all tail statements, and merged_best_obj is
    the product of the individual best objective values (1 if *models* is
    empty).
    """
    merged_head = []
    merged_tail = []
    merged_best_obj = 1
    merged_obj = ampl.Decl('minimize', 'f')
    # Models are numbered from 1 when handed to prepare_for_merge.
    for index, model in enumerate(models, 1):
        head, obj, tail, best_obj = prepare_for_merge(model, index)
        merged_head += head
        merged_tail += tail
        merged_best_obj *= best_obj
        if merged_obj.body:
            merged_obj.body = ampl.BinaryExpr('*', merged_obj.body, obj.body)
        else:
            # First objective seen: use its body as is.
            merged_obj.body = obj.body
        # Invert sign if objectives are of different kinds.
        # NOTE(review): no sign inversion is performed in this function;
        # confirm whether prepare_for_merge handles it or it is still to do.
    return ampl.CompoundStmt(merged_head + [merged_obj] + merged_tail), merged_best_obj
def test_pretty_print():
    """Check the pretty-printed form of every node kind via check_print."""
    # Objective declaration whose body is assigned after construction.
    objective = ampl.Decl('minimize', 'o')
    objective.body = ampl.UnaryExpr('-', ref('x'))

    # Data statement with two parameter columns and six values.
    columns = ['a', 'b']
    data_values = list(map(str, range(6)))
    data_text = (
        'param:\n' +
        'S:a b :=\n' +
        '0 1 2\n' +
        '3 4 5\n' +
        ';\n'
    )

    # (expected text, node) pairs, checked in this order.
    cases = [
        ('a', ref('a')),
        ('a[b]', ampl.SubscriptExpr('a', ref('b'))),
        ('(a)', ampl.ParenExpr(ref('a'))),
        ('-a', ampl.UnaryExpr('-', ref('a'))),
        ('a + b', ampl.BinaryExpr('+', ref('a'), ref('b'))),
        ('if a then b else c', ampl.IfExpr(ref('a'), ref('b'), ref('c'))),
        ('if a then b', ampl.IfExpr(ref('a'), ref('b'), None)),
        ('f(a, b)', ampl.CallExpr('f', [ref('a'), ref('b')])),
        ('sum{s in S} x[s]',
         ampl.SumExpr(ampl.Indexing(ref('S'), 's'),
                      ampl.SubscriptExpr('x', ref('s')))),
        ('{s in S}', ampl.Indexing(ref('S'), 's')),
        ('{S}', ampl.Indexing(ref('S'))),
        ('= a', ampl.InitAttr(ref('a'))),
        ('in [a, b]', ampl.InAttr(ref('a'), ref('b'))),
        ('var x{S} = a;\n',
         ampl.Decl('var', 'x', ampl.Indexing(ref('S')),
                   [ampl.InitAttr(ref('a'))])),
        ('minimize o: -x;\n', objective),
        ('model;\n', ampl.IncludeStmt('model')),
        (data_text, ampl.DataStmt('param', 'S', columns, data_values)),
        ('model;\nvar x;\n',
         ampl.CompoundStmt([ampl.IncludeStmt('model'),
                            ampl.Decl('var', 'x')])),
    ]
    for expected, node in cases:
        check_print(expected, node)
def check_equal_nodes(node, *inequal_nodes):
    """Check that equal_nodes compares nodes recursively.

    Asserts that *node* equals itself and a deep copy of itself, that it
    differs from None and from a node of another type, and that it differs
    from every node in *inequal_nodes*.
    """
    assert equal_nodes(node, node)
    assert equal_nodes(node, copy.deepcopy(node))
    assert not equal_nodes(node, None)

    # Pick a node whose type is guaranteed to differ from the type of node.
    if type(node) != ampl.Reference:
        different_type = ref('x')
    else:
        different_type = ampl.CompoundStmt([])
    assert not equal_nodes(node, different_type)

    for other in inequal_nodes:
        assert not equal_nodes(node, other)
def check_parse_expr(input, expr):
    """Check parsing of an arithmetic expression.

    Wraps *input* as the initializer of a ``var x`` declaration, parses the
    resulting text and asserts that the parse result equals a compound
    statement holding that single declaration with *expr* as its
    initializer.
    """
    source = 'var x = {};'.format(input)
    parsed = ampl.parse(source, 'in')
    declaration = ampl.Decl('var', 'x', None, [ampl.InitAttr(expr)])
    expected = ampl.CompoundStmt([declaration])
    assert equal_nodes(parsed, expected)
def check_parse(input, *nodes):
    """Check that parsing *input* yields a compound statement of *nodes*."""
    parsed = ampl.parse(input, 'in')
    expected = ampl.CompoundStmt(nodes)
    assert equal_nodes(parsed, expected)
def test_compound():
    """Check construction of CompoundStmt and its visitor dispatch."""
    children = [ampl.IncludeStmt('model'), ampl.Decl('var', 'x')]
    compound = ampl.CompoundStmt(children)

    assert type(compound) == ampl.CompoundStmt
    # The statement exposes the list it was constructed with.
    assert compound.nodes == children
    check_accept(compound, 'visit_compound')