def replace_output(graph: Graph, op: Operator, old_var: Variable, new_var: Variable,
                   with_assert: bool = True):
    """Swap ``old_var`` for ``new_var`` as an output of ``op``.

    ``with_assert`` is forwarded unchanged to ``op.replace_output``.

    When ``old_var`` is also listed in ``graph.outputs``, its first occurrence
    is taken out and ``new_var`` is inserted at that same index. This keeps the
    ordering of graph outputs stable.
    """
    op.replace_output(old_var, new_var, with_assert=with_assert)

    if old_var not in graph.outputs:
        return

    position = graph.outputs.index(old_var)
    graph.outputs.remove(old_var)
    graph.outputs.insert(position, new_var)
def test_replace_output():
    """Replacing an output rebinds the named slot and moves ``output_from``."""
    operator = Operator("op")
    original = Variable((1, 2, 3, 4), OrderNHWC)
    replacement = Variable((1, 2, 3, 4), OrderNHWC)

    # Arrange: register the first variable under the "v1" slot.
    operator.append_output("v1", original)

    # Act: swap it for the second variable.
    operator.replace_output(original, replacement)

    # The slot now holds the replacement, and only the replacement is produced by the operator.
    assert operator.outputs["v1"] == replacement
    assert original.output_from is None
    assert replacement.output_from == operator
def _replace_output(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Ensure output ``var_name`` of ``op`` has one of ``target_orders``.

    before) -{op}- v
    after)  -{op}- v' -{Transpose}- v

    ``target_orders`` may be a single ``Order`` or a list of them. If the
    output's current order is not among them, the first entry is used for the
    new intermediate variable.

    Returns:
        ``False`` if the output already has an acceptable order (nothing is
        changed); ``True`` after the Transpose has been inserted.
    """
    current = op.outputs[var_name]
    candidates = [target_orders] if isinstance(target_orders, Order) else target_orders

    if current.order in candidates:
        return False

    intermediate = Variable(current.shape, current.order).change_order(candidates[0])
    op.replace_output(current, intermediate, with_assert=False)

    # Route the intermediate through a Transpose; its output takes over the original variable.
    transpose_outputs = Transpose(None)(intermediate)
    transpose_outputs[0].replace(current, with_assert=False)
    return True
def _replace_output(graph: Graph, op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Ensure output ``var_name`` of ``op`` has one of ``target_orders``.

    ``target_orders`` may be a single ``Order`` or a list of them.

    If the output already uses an acceptable order, the work is delegated to
    ``_optimize_redundant_transposed_output`` and its result is returned as is.
    Otherwise ``op`` is rewired to write a new variable in ``target_orders[0]``.
    A ``transpose`` back to the original order then takes the place of the old
    output variable, and ``True`` is returned.
    """
    original = op.outputs[var_name]
    candidates = [target_orders] if isinstance(target_orders, Order) else target_orders

    if original.order in candidates:
        return _optimize_redundant_transposed_output(graph, op, var_name, candidates)

    reordered = Variable(original.shape, original.order).change_order(candidates[0])
    op.replace_output(original, reordered, with_assert=False)

    restored = reordered.transpose(original.order)
    restored.replace(original, with_assert=False)
    return True
def _replace_output(op: Operator, var_name: str, target: ChannelModeEnum):
    """Make output ``var_name`` of ``op`` use channel mode ``target``.

    before) -{op}- v
    after)  -{op}- v' -{conversion}- v

    ``v'`` is a new variable with the same shape and order as ``v``, tagged
    with ``target``. The conversion operator is ``ConvertRGBAtoR`` when
    ``target`` is RGBA, and ``ConvertRtoRGBA`` otherwise.

    Returns:
        ``False`` if ``v`` is already in ``target`` mode (graph untouched);
        ``True`` once the conversion has been inserted.
    """
    original = op.outputs[var_name]
    if ChannelMode.get(original) == target:
        return False

    converted = Variable(original.shape, original.order)
    ChannelMode.set(converted, target)
    op.replace_output(original, converted)

    converter_cls = ConvertRGBAtoR if target == ChannelModeEnum.RGBA else ConvertRtoRGBA
    converter_cls(None)(converted)[0].replace(original)
    return True
def _replace_output(op: Operator, var_name: str, target: ChannelModeEnum):
    """Make output ``var_name`` of ``op`` use channel mode ``target``.

    before) -{op}- v
    after)  -{op}- v' -{conversion}- v

    ``v'`` is a new variable with the same shape and order as ``v``, tagged
    with ``target``. It is passed through ``convert_rgba_to_r`` when
    ``target`` is RGBA, and through ``convert_r_to_rgba`` otherwise. The
    result is switched back to ``v``'s order before it replaces ``v``.

    Returns:
        ``False`` if ``v`` is already in ``target`` mode (graph untouched);
        ``True`` once the conversion has been inserted.
    """
    original = op.outputs[var_name]
    if ChannelMode.get(original) == target:
        return False

    converted = Variable(original.shape, original.order)
    ChannelMode.set(converted, target)
    op.replace_output(original, converted)

    convert = convert_rgba_to_r if target == ChannelModeEnum.RGBA else convert_r_to_rgba
    restored = convert(converted).change_order(original.order)
    restored.replace(original)
    return True