Ejemplo n.º 1
0
def setup_graph_residual():
    """Build a small residual graph and publish its pieces as module globals.

    Topology::

        v0 --[op1]--> v1 -+----------------+--[op3]--> v3
                          |                |
                          +--[op2]--> v2 --+

    After the call, the module-level names ``v0``..``v3``, ``op1``..``op3``
    and ``graph`` are rebound.  ``graph`` has ``[v0]`` as its inputs and
    ``[v3]`` as its outputs.
    """
    global graph, op1, op2, op3, v0, v1, v2, v3

    # Create nodes in the same order the data flows through them.
    v0 = Variable((1, 1), OrderNC)
    op1 = Operator("op1")
    v1 = Variable((1, 2), OrderNC)
    op2 = TestOperator("op2")
    v2 = Variable((1, 3), OrderNC)
    op3 = Operator("op3")
    v3 = Variable((1, 4), OrderNC)

    # op1: v0 -> v1
    op1.append_input("v0", v0)
    op1.append_output("v1", v1)

    # op2: v1 -> v2 (the side branch)
    op2.append_input("v1", v1)
    op2.append_output("v2", v2)

    # op3 joins the skip connection (v1) with the branch result (v2).
    for input_name, input_var in (("v1", v1), ("v2", v2)):
        op3.append_input(input_name, input_var)
    op3.append_output("v3", v3)

    graph = Graph([v0], [v3])
def fn(x: Variable):
    """Attach a fresh unnamed operator to ``x`` and return its output variable.

    The returned variable has the same shape and order as ``x``.  The new
    operator (constructed with name ``None``) consumes ``x`` under the input
    name ``"x"`` and produces the returned variable under the output name
    ``"y"``.
    """
    result = Variable(x.shape, x.order)
    node = Operator(None)
    node.append_input("x", x)
    node.append_output("y", result)
    return result
Ejemplo n.º 3
0
def test_get_output_name():
    """Each output variable reports the name it was registered under."""
    op = Operator("op")
    names = ("v1", "v2")
    outputs = {name: Variable((1, 2, 3, 4), OrderNHWC) for name in names}

    for name in names:
        op.append_output(name, outputs[name])

    for name in names:
        assert op.get_output_name(outputs[name]) == name
Ejemplo n.º 4
0
def test_replace_output():
    """replace_output keeps the slot name but swaps in the new variable."""
    op = Operator("op")
    old_var = Variable((1, 2, 3, 4), OrderNHWC)
    new_var = Variable((1, 2, 3, 4), OrderNHWC)

    op.append_output("v1", old_var)
    op.replace_output(old_var, new_var)

    # The "v1" slot now points at the replacement...
    assert op.outputs["v1"] == new_var
    # ...the old variable is detached, and the new one is produced by op.
    assert old_var.output_from is None
    assert new_var.output_from == op
Ejemplo n.º 5
0
def test_append_output():
    """append_output stores each variable by name and records op as its producer."""
    op = Operator("op")
    registered = [
        ("v1", Variable((1, 2, 3, 4), OrderNHWC)),
        ("v2", Variable((1, 2, 3, 4), OrderNHWC)),
    ]

    for name, var in registered:
        op.append_output(name, var)

    # Forward link: operator -> variable, looked up by slot name.
    for name, var in registered:
        assert op.outputs[name] == var
    # Back link: variable -> producing operator.
    for _, var in registered:
        assert var.output_from == op
Ejemplo n.º 6
0
def _replace_output(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Make ``op`` emit ``var_name`` in an acceptable order, transposing back afterwards.

    Before::

        op -{var_name}-> V

    After (only when V's order is not already acceptable)::

        op -{var_name}-> V' -[Transpose]-> V

    ``V'`` is a new variable with V's shape/order, converted with
    ``change_order(target_orders[0])``.  V itself is reused as the Transpose
    operator's ``"y"`` output rather than replaced.

    Args:
        op: operator whose output slot is rewired.
        var_name: name of the output slot on ``op`` to rewire.
        target_orders: a single acceptable order or a list of them.  When a
            conversion is needed, the first entry is used.

    Returns:
        ``True`` if a Transpose was inserted, or ``False`` if V already had
        one of the acceptable orders (the graph is left unchanged).
    """
    original = op.outputs[var_name]
    acceptable = [target_orders] if isinstance(target_orders, Order) else target_orders

    if original.order in acceptable:
        return False

    # Intermediate variable: start from the original layout, then re-order it.
    converted = Variable(original.shape, original.order)
    converted.change_order(acceptable[0])

    # Rewire op so it now produces the intermediate variable under the same name.
    op.remove_output(original)
    op.append_output(var_name, converted)

    # Calling the Transpose creates its own output variable.  Discard it and
    # route the transposed result into the original variable instead.
    transpose_op = Transpose(None)
    discarded, = transpose_op(converted)
    transpose_op.remove_output(discarded)
    transpose_op.append_output("y", original)
    return True
Ejemplo n.º 7
0
def test_listup_nodes_hidden_output():
    """listup_nodes lists every node once, in data-flow order, when an
    intermediate variable is also declared as a graph output.

    Chain: v0 --[op1]--> v1 --[op2]--> v2, with outputs [v1, v2].
    """
    v0 = Variable((1, 1), OrderNC)
    op1 = Operator("op1")
    v1 = Variable((1, 2), OrderNC)
    op2 = TestOperator("op2")
    v2 = Variable((1, 3), OrderNC)

    links = (
        (op1, "v0", v0, "v1", v1),
        (op2, "v1", v1, "v2", v2),
    )
    for op, in_name, in_var, out_name, out_var in links:
        op.append_input(in_name, in_var)
        op.append_output(out_name, out_var)

    # v1 is both an intermediate value and an exposed graph output.
    graph = Graph([v0], [v1, v2])

    expected = (v0, op1, v1, op2, v2)
    nodes = listup_nodes(graph)
    assert tuple(nodes) == expected, str(nodes)
Ejemplo n.º 8
0
def test_replace_all():
    """Operator.replace moves every input/output link from op1 onto op2."""
    source_op = Operator("op1")
    target_op = Operator("op2")
    in_var = Variable((1, 2, 3, 4), OrderNHWC)
    out_var = Variable((1, 2, 3, 4), OrderNHWC)

    source_op.append_input("v1", in_var)
    source_op.append_output("v2", out_var)

    source_op.replace(target_op)

    # The replaced operator is left with no connections...
    assert len(source_op.inputs) == 0
    assert len(source_op.outputs) == 0
    # ...and the replacement inherits them under the same slot names,
    # with the variables' back-links updated to point at it.
    assert len(target_op.inputs) == 1 and target_op.inputs["v1"] == in_var
    assert len(target_op.outputs) == 1 and target_op.outputs["v2"] == out_var
    assert in_var.input_to == {target_op}
    assert out_var.output_from == target_op
Ejemplo n.º 9
0
def test_remove_all():
    """remove_all detaches every input and output variable from the operator."""
    op = Operator("op")
    v1, v2, v3, v4 = (Variable((1, 2, 3, 4), OrderNHWC) for _ in range(4))

    for name, var in (("v1", v1), ("v2", v2)):
        op.append_input(name, var)
    for name, var in (("v3", v3), ("v4", v4)):
        op.append_output(name, var)

    op.remove_all()

    # Operator side: both slot maps are emptied.
    assert len(op.inputs) == 0
    assert len(op.outputs) == 0
    # Variable side: consumer sets are emptied and producer links are cleared.
    for var in (v1, v2):
        assert var.input_to == set()
    for var in (v3, v4):
        assert var.output_from is None