Example #1
0
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]) -> bool:
    """Ensure input ``var_name`` of ``op`` has one of ``target_orders``.

    If the current input variable already has an acceptable order, nothing is
    changed. Otherwise a ``Transpose`` is applied to the original variable, its
    output is switched to the first entry of ``target_orders``, and that output
    replaces the original variable as ``op``'s input under the same name.

    Args:
        op: Operator whose input may be rewired.
        var_name: Key under which the input is found in ``op.inputs``.
        target_orders: A single acceptable order, or a list of acceptable
            orders. When a transpose is needed, the first entry is the order
            the transposed variable is changed to.

    Returns:
        True if a transpose was inserted, False if the input was left as is.

    Raises:
        ValueError: If ``target_orders`` is an empty list.
    """
    orig_var = op.inputs[var_name]
    if isinstance(target_orders, Order):
        target_orders = [target_orders]
    # Check before building the Transpose: with an empty list the code below
    # would only fail at ``target_orders[0]``, after the Transpose had already
    # been applied to ``orig_var``.
    if not target_orders:
        raise ValueError(f"target_orders must not be empty (input '{var_name}')")

    if orig_var.order in target_orders:
        return False

    # NOTE(review): ``Transpose(None)`` presumably means "no explicit name" — confirm.
    trans, = Transpose(None)(orig_var)
    trans.change_order(target_orders[0])
    op.remove_input(orig_var)
    op.append_input(var_name, trans)
    return True
Example #2
0
def test_remove_input():
    """Removing one input detaches only that variable from the operator."""
    op = Operator("op")
    v1, v2 = (Variable((1, 2, 3, 4), OrderNHWC) for _ in range(2))

    for name, var in (("v1", v1), ("v2", v2)):
        op.append_input(name, var)
    op.remove_input(v1)

    # The removed input is gone from the mapping and no longer links back to op.
    assert "v1" not in op.inputs
    assert v1.input_to == set()
    # The remaining input is unaffected on both sides of the link.
    assert op.inputs["v2"] == v2
    assert v2.input_to == {op}