예제 #1
0
def setup_graph_residual():
    """
    Build a small graph with a residual (skip) connection and publish it
    through module-level globals.

    Topology::

        v0 --[op1]--> v1 -+----------------+--[op3]--> v3
                          |                |
                          +--[op2]--> v2 --+

    ``v1`` is consumed by both ``op2`` and ``op3``; ``op3`` additionally
    consumes ``v2``. The resulting ``graph`` takes ``v0`` as its only input
    and yields ``v3`` as its only output.
    """
    global graph, op1, op2, op3, v0, v1, v2, v3

    # Stage 1: v0 --[op1]--> v1
    v0 = Variable((1, 1), OrderNC)
    op1 = Operator("op1")
    v1 = Variable((1, 2), OrderNC)
    op1.append_input("v0", v0)
    op1.append_output("v1", v1)

    # Stage 2 (the non-skip branch): v1 --[op2]--> v2
    op2 = TestOperator("op2")
    v2 = Variable((1, 3), OrderNC)
    op2.append_input("v1", v1)
    op2.append_output("v2", v2)

    # Stage 3 (merge of skip path v1 and branch result v2): --[op3]--> v3
    op3 = Operator("op3")
    v3 = Variable((1, 4), OrderNC)
    op3.append_input("v1", v1)
    op3.append_input("v2", v2)
    op3.append_output("v3", v3)

    graph = Graph([v0], [v3])
def fn(x: Variable):
    """
    Attach a fresh unnamed operator to ``x`` and return its output variable.

    The returned variable is created with the same shape and order as ``x``.
    ``x`` is registered as the operator's input ``"x"`` and the new variable
    as its output ``"y"``.
    """
    out = Variable(x.shape, x.order)

    link = Operator(None)
    link.append_input("x", x)
    link.append_output("y", out)

    return out
예제 #3
0
def test_get_input_name():
    """``get_input_name`` returns the name each input variable was registered under."""
    op = Operator("op")
    first = Variable((1, 2, 3, 4), OrderNHWC)
    second = Variable((1, 2, 3, 4), OrderNHWC)
    registered = (("v1", first), ("v2", second))

    for name, variable in registered:
        op.append_input(name, variable)

    # Look-ups are checked only after both inputs have been registered.
    for name, variable in registered:
        assert op.get_input_name(variable) == name
예제 #4
0
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """
    Make sure the input ``var_name`` of ``op`` has one of the accepted orders.

    ``target_orders`` may be a single ``Order`` or a list of them. If the
    current input's order is already among them, nothing is changed and
    ``False`` is returned. Otherwise the current input is fed through a new
    ``Transpose`` operator, the transposed variable's order is changed to the
    first accepted order, the original variable is removed from ``op``'s
    inputs, the transposed variable is appended under ``var_name``, and
    ``True`` is returned.
    """
    current = op.inputs[var_name]
    accepted = [target_orders] if isinstance(target_orders, Order) else target_orders

    if orig_order_ok := (current.order in accepted):
        return False

    # Transpose yields a sequence of outputs; exactly one is expected here.
    (transposed,) = Transpose(None)(current)
    transposed.change_order(accepted[0])

    op.remove_input(current)
    op.append_input(var_name, transposed)
    return True
예제 #5
0
def test_replace_input():
    """``replace_input`` swaps the variable bound to an input name and updates both variables' links."""
    op = Operator("op")
    old_var = Variable((1, 2, 3, 4), OrderNHWC)
    new_var = Variable((1, 2, 3, 4), OrderNHWC)

    op.append_input("v1", old_var)
    op.replace_input(old_var, new_var)

    # The input slot keeps its name but now refers to the replacement.
    assert op.inputs["v1"] == new_var
    # The old variable no longer feeds any operator; the new one feeds `op`.
    assert old_var.input_to == set()
    assert new_var.input_to == {op}
예제 #6
0
def test_listup_nodes_hidden_output():
    """
    ``listup_nodes`` handles a graph whose outputs include an intermediate variable.

    The chain is ``v0 --[op1]--> v1 --[op2]--> v2`` and the graph declares
    both ``v1`` (an intermediate variable) and ``v2`` as outputs.
    """
    # v0 --[op1]--> v1
    v0 = Variable((1, 1), OrderNC)
    op1 = Operator("op1")
    v1 = Variable((1, 2), OrderNC)
    op1.append_input("v0", v0)
    op1.append_output("v1", v1)

    # v1 --[op2]--> v2
    op2 = TestOperator("op2")
    v2 = Variable((1, 3), OrderNC)
    op2.append_input("v1", v1)
    op2.append_output("v2", v2)

    # v1 is consumed by op2 yet is also exposed as a graph output.
    graph = Graph([v0], [v1, v2])

    nodes = listup_nodes(graph)
    expected = (v0, op1, v1, op2, v2)

    assert tuple(nodes) == expected, str(nodes)
예제 #7
0
def test_replace_all():
    """``Operator.replace`` moves every input and output from one operator to another."""
    source = Operator("op1")
    target = Operator("op2")
    in_var = Variable((1, 2, 3, 4), OrderNHWC)
    out_var = Variable((1, 2, 3, 4), OrderNHWC)

    source.append_input("v1", in_var)
    source.append_output("v2", out_var)

    source.replace(target)

    # The replaced operator is left with no connections.
    assert len(source.inputs) == 0
    assert len(source.outputs) == 0

    # The replacement carries the same connections under the same names.
    assert len(target.inputs) == 1
    assert target.inputs["v1"] == in_var
    assert len(target.outputs) == 1
    assert target.outputs["v2"] == out_var

    # The variables now point at the replacement.
    assert in_var.input_to == {target}
    assert out_var.output_from == target
예제 #8
0
def test_remove_all():
    """``remove_all`` detaches every input and output variable from an operator."""
    op = Operator("op")
    v1, v2, v3, v4 = [Variable((1, 2, 3, 4), OrderNHWC) for _ in range(4)]
    input_vars = (("v1", v1), ("v2", v2))
    output_vars = (("v3", v3), ("v4", v4))

    for name, variable in input_vars:
        op.append_input(name, variable)
    for name, variable in output_vars:
        op.append_output(name, variable)

    op.remove_all()

    assert len(op.inputs) == 0
    assert len(op.outputs) == 0
    # Former inputs no longer feed any operator.
    for _, variable in input_vars:
        assert variable.input_to == set()
    # Former outputs no longer have a producer.
    for _, variable in output_vars:
        assert variable.output_from is None