def setup_graph_residual():
    """
    Generate a residual structure and store its parts in module-level globals.

    v0 --[op1]--> v1 -+----------------+--[op3]--> v3
                      |                |
                      +--[op2]--> v2 --+

    v1 is consumed both by op2 and directly by op3 (the skip connection);
    op3 takes inputs "v1" and "v2". The resulting Graph has input [v0] and
    output [v3].
    """
    # Published as module globals; presumably other tests in this file read
    # them after calling this fixture -- verify against callers.
    global graph, op1, op2, op3
    global v0, v1, v2, v3

    # Nodes are constructed in graph order: v0, op1, v1, op2, v2, op3, v3.
    v0 = Variable((1, 1), OrderNC)
    op1 = Operator("op1")
    v1 = Variable((1, 2), OrderNC)
    op2 = TestOperator("op2")
    v2 = Variable((1, 3), OrderNC)
    op3 = Operator("op3")
    v3 = Variable((1, 4), OrderNC)

    # main path: v0 -> op1 -> v1
    op1.append_input("v0", v0)
    op1.append_output("v1", v1)

    # residual branch: v1 -> op2 -> v2
    op2.append_input("v1", v1)
    op2.append_output("v2", v2)

    # merge: op3 consumes both the skip connection (v1) and the branch (v2)
    op3.append_input("v1", v1)
    op3.append_input("v2", v2)
    op3.append_output("v3", v3)

    graph = Graph([v0], [v3])
def fn(x: Variable):
    """Connect ``x`` to a new Operator and return that operator's output.

    The Operator is constructed with name ``None``; ``x`` is attached as its
    input ``"x"``. A new Variable with the same shape and order as ``x`` is
    attached as its output ``"y"`` and returned.
    """
    result = Variable(x.shape, x.order)
    node = Operator(None)
    node.append_input("x", x)
    node.append_output("y", result)
    return result
def test_get_input_name():
    """get_input_name returns the name each input variable was registered under."""
    op = Operator("op")
    registered = [
        ("v1", Variable((1, 2, 3, 4), OrderNHWC)),
        ("v2", Variable((1, 2, 3, 4), OrderNHWC)),
    ]

    for name, var in registered:
        op.append_input(name, var)

    for name, var in registered:
        assert op.get_input_name(var) == name
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Make sure input ``var_name`` of ``op`` has one of ``target_orders``.

    If the variable currently bound to ``var_name`` already has an accepted
    order, nothing changes. Otherwise a ``Transpose`` is applied to that
    variable, the Transpose output's order is changed to the FIRST accepted
    order via ``change_order``, and this output replaces the original variable
    as ``op``'s input ``var_name``.

    Args:
        op: operator whose input may be rewritten.
        var_name: key of the input in ``op.inputs``.
        target_orders: a single accepted ``Order`` or an iterable of accepted
            orders. When a transpose is needed, the first one is used.

    Returns:
        True if the graph was modified, False if the input was already fine.

    Raises:
        KeyError: if ``op`` has no input named ``var_name``.
        ValueError: if ``target_orders`` is empty. This is checked before the
            graph is touched.
    """
    orig_var = op.inputs[var_name]

    if isinstance(target_orders, Order):
        target_orders = [target_orders]
    else:
        # Materialize so any iterable (tuple, generator, ...) can be both
        # searched with `in` and indexed below.
        target_orders = list(target_orders)

    if not target_orders:
        # Validate before inserting the Transpose. Otherwise an empty list only
        # fails at target_orders[0], after a Transpose has already been
        # attached to orig_var, leaving a dangling node in the graph.
        raise ValueError("target_orders for input '{}' must not be empty".format(var_name))

    if orig_var.order in target_orders:
        return False

    trans, = Transpose(None)(orig_var)
    trans.change_order(target_orders[0])

    # Re-bind var_name to the transposed variable.
    op.remove_input(orig_var)
    op.append_input(var_name, trans)
    return True
def test_replace_input():
    """replace_input rebinds the input name and updates both variables' input_to sets."""
    op = Operator("op")
    old_var = Variable((1, 2, 3, 4), OrderNHWC)
    new_var = Variable((1, 2, 3, 4), OrderNHWC)
    op.append_input("v1", old_var)

    op.replace_input(old_var, new_var)

    # the name now refers to the new variable...
    assert op.inputs["v1"] == new_var
    # ...and the back-references moved with it
    assert old_var.input_to == set()
    assert new_var.input_to == {op}
def test_listup_nodes_hidden_output():
    """listup_nodes lists every node once, even when an intermediate variable is also a graph output."""
    x = Variable((1, 1), OrderNC)
    first = Operator("op1")
    hidden = Variable((1, 2), OrderNC)
    second = TestOperator("op2")
    out = Variable((1, 3), OrderNC)

    # chain: x --[first]--> hidden --[second]--> out
    wiring = (
        (first, ("v0", x), ("v1", hidden)),
        (second, ("v1", hidden), ("v2", out)),
    )
    for node, (in_name, in_var), (out_name, out_var) in wiring:
        node.append_input(in_name, in_var)
        node.append_output(out_name, out_var)

    graph = Graph([x], [hidden, out])  # outputs hidden variable

    nodes = listup_nodes(graph)
    assert tuple(nodes) == (x, first, hidden, second, out), str(nodes)
def test_replace_all():
    """Operator.replace moves every input/output connection from one operator to another."""
    source_op = Operator("op1")
    target_op = Operator("op2")
    x = Variable((1, 2, 3, 4), OrderNHWC)
    y = Variable((1, 2, 3, 4), OrderNHWC)
    source_op.append_input("v1", x)
    source_op.append_output("v2", y)

    source_op.replace(target_op)

    # the replaced operator is left with no connections
    assert len(source_op.inputs) == 0
    assert len(source_op.outputs) == 0

    # the replacement took over exactly the same names and variables
    assert len(target_op.inputs) == 1 and target_op.inputs["v1"] == x
    assert len(target_op.outputs) == 1 and target_op.outputs["v2"] == y

    # variable-side back-references point at the replacement
    assert x.input_to == {target_op}
    assert y.output_from == target_op
def test_remove_all():
    """remove_all detaches every input and output and clears the variables' back-references."""
    op = Operator("op")
    v1, v2, v3, v4 = (Variable((1, 2, 3, 4), OrderNHWC) for _ in range(4))

    ins = {"v1": v1, "v2": v2}
    outs = {"v3": v3, "v4": v4}
    for name, var in ins.items():
        op.append_input(name, var)
    for name, var in outs.items():
        op.append_output(name, var)

    op.remove_all()

    assert len(op.inputs) == 0
    assert len(op.outputs) == 0
    for var in ins.values():
        assert var.input_to == set()
    for var in outs.values():
        assert var.output_from is None