def test_remove_output():
    """Removing one output detaches it from the operator and leaves the other output intact."""
    operator = Operator("op")
    first = Variable((1, 2, 3, 4), OrderNHWC)
    second = Variable((1, 2, 3, 4), OrderNHWC)

    for name, var in (("v1", first), ("v2", second)):
        operator.append_output(name, var)

    operator.remove_output(first)

    # The removed variable is gone from the output map and no longer references the operator.
    assert "v1" not in operator.outputs
    assert first.output_from is None

    # The remaining variable is still registered and still references the operator.
    assert operator.outputs["v2"] == second
    assert second.output_from == operator
def _replace_output(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]) -> bool:
    """Make ``op`` write its output ``var_name`` in one of ``target_orders``.

    If the current output variable already has an acceptable order, nothing is
    changed. Otherwise ``op`` is given a new output variable with the same shape,
    laid out in the first of ``target_orders``. A ``Transpose`` operator is then
    created that reads this new variable and writes into the original one.

    Args:
        op: Operator whose output is to be rewired.
        var_name: Name of the output slot in ``op.outputs``.
        target_orders: A single acceptable order, or a list of acceptable orders.
            The first entry is used for the new variable.

    Returns:
        ``False`` if the output already had an acceptable order (graph untouched);
        ``True`` if a Transpose was inserted.
    """
    original = op.outputs[var_name]
    allowed_orders = [target_orders] if isinstance(target_orders, Order) else target_orders

    if original.order in allowed_orders:
        return False

    # New variable with the same shape, re-laid into the preferred order; op will write into it.
    relaid = Variable(original.shape, original.order)
    relaid.change_order(allowed_orders[0])

    op.remove_output(original)
    op.append_output(var_name, relaid)

    # Build a Transpose consuming `relaid`, then swap its auto-created output for the
    # original variable, so the original variable is now produced by the Transpose.
    converter = Transpose(None)
    placeholder, = converter(relaid)
    converter.remove_output(placeholder)
    converter.append_output("y", original)
    return True