def test_get_var_by_id():
    """Resolve debugprint-style ids in an outer graph and inside an inner-graph Op.

    NOTE(review): another function named `test_get_var_by_id` also appears in
    this file's visible source; if both live in the same module, the later
    definition shadows this one and pytest only collects the later one --
    confirm and rename one of them if so.
    """
    # Outer graph: a single `op1` application on two leaf variables.
    x, y = MyVariable("v1"), MyVariable("v2")
    op1_out = MyOp("op1")(x, y)
    op1_out.name = "o1"

    # Inner graph, wrapped in `MyInnerGraphOp`.
    inner_x = MyVariable("v4")
    inner_y = MyVariable("v5")
    inner_out = MyOp("op2")(inner_x, inner_y)
    inner_out.name = "igo1"
    inner_op = MyInnerGraphOp([inner_x, inner_y], [inner_out])

    z = MyVariable("v3")
    igo_out = inner_op(z, op1_out)

    # An id that is not present yields `None`.
    assert get_node_by_id(op1_out, "blah") is None

    # Id "C" is the second input of `op1` in the outer graph.
    assert get_node_by_id([op1_out, igo_out], "C") == y

    # Id "F" is the `op2` node held by the inner `Op`'s own `fgraph`.
    found = get_node_by_id([op1_out, igo_out], "F")
    assert found == inner_op.fgraph.outputs[0].owner
def test_PatternPrinter():
    """A `PatternPrinter` assigned to an Op fills its operands into the pattern."""
    left, right = MyVariable("1"), MyVariable("2")
    diff_op = MyOp("op1")
    diff_out = diff_op(left, right)
    diff_out.name = "o1"

    # Local name deliberately not `pprint`, to avoid shadowing that module-level name.
    printer = PPrinter()
    # Applications of `diff_op` are rendered through the "|a - b|" pattern.
    printer.assign(diff_op, PatternPrinter(("|%(0)s - %(1)s|", -1000)))
    # Catch-all predicate (always True): everything else uses `default_printer`.
    printer.assign(lambda pstate, r: True, default_printer)

    rendered = printer(diff_out)
    assert rendered == "|1 - 2|"
def test_remove_node_multi_out(self):
    """`FunctionGraph.remove_node` on graphs that contain a two-output node."""
    x = MyVariable("var1")
    y = MyVariable("var2")
    two_out_op = MyOp("mop", n_outs=2)

    # Case 1: remove the two-output node itself.  Both of its outputs leave
    # the graph, and so does every graph output built on top of them.
    pre = op1(x)
    first, second = two_out_op(pre, y)
    downstream = op3(second)
    graph = FunctionGraph([x, y], [first, downstream], clone=False)

    graph.remove_node(first.owner)
    graph.check_integrity()

    assert graph.inputs == [x, y]
    assert graph.outputs == []
    for gone in (first, second):
        assert gone not in graph.clients
    for gone in (first, second):
        assert gone not in graph.variables

    # Case 2: remove a node that consumes both outputs of a two-output node.
    # The other graph output (which uses only the first output) is kept.
    m1, m2 = two_out_op(x)
    keep = op2(m1)
    drop = op3(m1, m2)
    graph = FunctionGraph([x], [keep, drop], clone=False)

    graph.remove_node(drop.owner)
    graph.check_integrity()

    assert graph.inputs == [x]
    assert graph.outputs == [keep]
    # If only "active" variables were tracked in the graph, the following
    # would also need to hold:
    # assert m2 not in graph.clients
    # assert m2 not in graph.variables

    # Case 3: the same removal, but the otherwise-dangling second output is
    # itself a graph output, so it must stay in the graph.
    graph = FunctionGraph([x], [keep, drop, m2], clone=False)
    graph.remove_node(drop.owner)
    graph.check_integrity()

    assert graph.inputs == [x]
    assert graph.outputs == [keep, m2]
    assert m2 in graph.clients
    assert m2 in graph.variables
    assert m2 in graph.outputs
def test_get_var_by_id():
    """`get_node_by_id` resolves debugprint ids in outer and inner graphs."""
    r1, r2 = MyVariable("v1"), MyVariable("v2")
    o1 = MyOp("op1")(r1, r2)
    o1.name = "o1"

    # Inner graph
    igo_in_1 = MyVariable("v4")
    igo_in_2 = MyVariable("v5")
    igo_out_1 = MyOp("op2")(igo_in_1, igo_in_2)
    igo_out_1.name = "igo1"

    igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])

    r3 = MyVariable("v3")
    o2 = igo(r3, o1)

    # Ids assigned when printing `[o1, o2]` with `debugprint`, roughly (the
    # layout is approximate; only the ids matter for this test):
    #
    # op1 [id A] 'o1'
    #  |v1 [id B]
    #  |v2 [id C]
    # MyInnerGraphOp [id D]
    #  |v3 [id E]
    #  |op1 [id A] 'o1'
    #
    # Inner graphs:
    #
    # MyInnerGraphOp [id D]
    #  >op2 [id F] 'igo1'
    #  > |*0-<MyType()> [id G]
    #  > |*1-<MyType()> [id H]

    res = get_node_by_id(o1, "blah")
    assert res is None

    res = get_node_by_id([o1, o2], "C")
    assert res == r2

    # Id "F" is found inside `igo.fgraph`.  That inner graph's inputs are
    # rendered as positional placeholders (`*0-`, `*1-`), i.e. the inner graph
    # is rebuilt from `igo_in_*`, so the node found there need not be
    # `igo_out_1.owner`.  Compare against the inner graph's own output node.
    res = get_node_by_id([o1, o2], "F")
    assert res == igo.fgraph.outputs[0].owner
def test_debugprint_inner_graph():
    """`debugprint` lists inner-graph `Op`s (including nested ones) in an
    "Inner graphs:" section after the outer graph."""

    def assert_lines_match(output_str, exp_res):
        # Compare line by line, ignoring per-line indentation/trailing spaces
        # and surrounding blank lines.  Comparing the complete lists (instead
        # of `zip`, which stops at the shorter input) also catches output
        # that is missing lines or has extra ones.
        res_lines = [line.strip() for line in output_str.strip().split("\n")]
        exp_lines = [line.strip() for line in exp_res.strip().split("\n")]
        assert res_lines == exp_lines

    # Inner graph
    igo_in_1 = MyVariable("4")
    igo_in_2 = MyVariable("5")
    igo_out_1 = MyOp("op2")(igo_in_1, igo_in_2)
    igo_out_1.name = "igo1"

    igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])

    r3, r4 = MyVariable("3"), MyVariable("4")
    out = igo(r3, r4)

    output_str = debugprint(out, file="str")

    exp_res = """MyInnerGraphOp [id A]
 |3 [id B]
 |4 [id C]

Inner graphs:

MyInnerGraphOp [id A]
 >op2 [id D] 'igo1'
 > |*0-<MyType()> [id E]
 > |*1-<MyType()> [id F]
"""

    assert_lines_match(output_str, exp_res)

    # Test nested inner-graph `Op`s
    igo_2 = MyInnerGraphOp([r3, r4], [out])

    r5 = MyVariable("5")
    out_2 = igo_2(r5)

    output_str = debugprint(out_2, file="str")

    exp_res = """MyInnerGraphOp [id A]
 |5 [id B]

Inner graphs:

MyInnerGraphOp [id A]
 >MyInnerGraphOp [id C]
 > |*0-<MyType()> [id D]
 > |*1-<MyType()> [id E]

MyInnerGraphOp [id C]
 >op2 [id F] 'igo1'
 > |*0-<MyType()> [id D]
 > |*1-<MyType()> [id E]
"""

    assert_lines_match(output_str, exp_res)