def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Ensure the input ``var_name`` of ``op`` has one of the given orders.

    If the current input's order is already among ``target_orders``, nothing
    is changed. Otherwise a ``Transpose`` operator is applied to the current
    input, the order of its output is changed to the first entry of
    ``target_orders``, and that output replaces the original variable as
    the ``var_name`` input of ``op``.

    Args:
        op: Operator whose input is inspected and possibly replaced.
        var_name: Key of the input in ``op.inputs``.
        target_orders: A single acceptable order, or a list of acceptable
            orders. When a transpose is inserted, the first entry is used.

    Returns:
        ``True`` if the input was replaced by a transposed variable,
        ``False`` if the input already had an acceptable order.
    """
    current = op.inputs[var_name]

    # Accept a bare Order as shorthand for a one-element list.
    acceptable = [target_orders] if isinstance(target_orders, Order) else target_orders

    if current.order in acceptable:
        return False

    # Transpose is expected to yield exactly one output; the tuple unpacking
    # fails loudly if it does not.
    (transposed,) = Transpose(None)(current)
    transposed.change_order(acceptable[0])

    # Swap the original variable for the transposed one under the same key.
    op.remove_input(current)
    op.append_input(var_name, transposed)
    return True
def test_remove_input():
    """Check that ``Operator.remove_input`` detaches only the given variable.

    Two variables are registered as inputs "v1" and "v2"; after removing the
    first, the operator must no longer list "v1", must still hold "v2", and
    the ``input_to`` sets of both variables must reflect that.
    """
    shape = (1, 2, 3, 4)

    op = Operator("op")
    v1 = Variable(shape, OrderNHWC)
    v2 = Variable(shape, OrderNHWC)

    # Register both inputs in the same order as their keys.
    for key, variable in (("v1", v1), ("v2", v2)):
        op.append_input(key, variable)

    op.remove_input(v1)

    # Operator side: only "v2" remains.
    assert "v1" not in op.inputs
    assert op.inputs["v2"] == v2

    # Variable side: v1 is no longer an input to anything, v2 still feeds op.
    assert v1.input_to == set()
    assert v2.input_to == {op}