def replace_input(graph: Graph, op: Operator, old_var: Variable, new_var: Variable, with_assert: bool = True):
    """Swap ``old_var`` for ``new_var`` on ``op`` and, if present, in ``graph.inputs``.

    The operator-level rewiring is delegated to ``op.replace_input``; ``with_assert``
    is forwarded to it unchanged. When ``old_var`` is also one of the graph's inputs,
    ``new_var`` is put back at the index ``old_var`` occupied, so the ordering of
    ``graph.inputs`` is kept.
    """
    op.replace_input(old_var, new_var, with_assert=with_assert)

    graph_inputs = graph.inputs
    if old_var not in graph_inputs:
        # Not a graph-level input: nothing more to update.
        return

    position = graph_inputs.index(old_var)
    graph_inputs.remove(old_var)
    graph_inputs.insert(position, new_var)
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Make ``op.inputs[var_name]`` use one of ``target_orders``.

    ``target_orders`` may be a single ``Order`` or a list of them. If the current
    input's order is already among them, nothing changes and ``False`` is returned.
    Otherwise the input is replaced (with ``with_assert=False``) by the result of
    ``v.transpose(<first target order>)`` and ``True`` is returned.
    """
    v = op.inputs[var_name]
    accepted = [target_orders] if isinstance(target_orders, Order) else target_orders

    if v.order in accepted:
        return False

    transposed = v.transpose(accepted[0])
    op.replace_input(v, transposed, with_assert=False)
    return True
def test_replace_input():
    """Replacing an operator input rewires both the operator and the two variables."""
    op = Operator("op")
    before = Variable((1, 2, 3, 4), OrderNHWC)
    after = Variable((1, 2, 3, 4), OrderNHWC)
    op.append_input("v1", before)

    op.replace_input(before, after)

    # The slot name is kept, but now points at the replacement variable.
    assert op.inputs["v1"] == after
    # The old variable is detached; the new one is registered as feeding `op`.
    assert before.input_to == set()
    assert after.input_to == {op}
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Insert a ``Transpose`` in front of ``op.inputs[var_name]`` when its order is not accepted.

    ``target_orders`` may be a single ``Order`` or a list of them. Returns ``False``
    without touching the graph when the input's order is already accepted. Otherwise
    the input is fed through ``Transpose(None)``, the operator is rewired to the
    transpose output (``with_assert=False``), that output's order is changed to the
    first accepted order, and ``True`` is returned.
    """
    v = op.inputs[var_name]
    orders = [target_orders] if isinstance(target_orders, Order) else target_orders

    if v.order in orders:
        return False

    transpose = Transpose(None)
    (v_new,) = transpose(v)
    op.replace_input(v, v_new, with_assert=False)
    # change_order is applied only after the operator has been rewired.
    v_new.change_order(orders[0])
    return True
def _replace_input(graph: Graph, op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Ensure ``op.inputs[var_name]`` is in one of ``target_orders``.

    ``target_orders`` may be a single ``Order`` or a list of them. If the input's
    order is not accepted, it is replaced (``with_assert=False``) by
    ``v.transpose(<first target order>)`` and ``True`` is returned. If it is already
    accepted, the result of ``_optimize_redundant_transposed_input(graph, op,
    var_name, target_orders)`` is returned instead. That helper's exact semantics are
    defined elsewhere; presumably it drops a redundant transpose — TODO confirm.
    """
    v = op.inputs[var_name]
    if isinstance(target_orders, Order):
        target_orders = [target_orders]

    if v.order not in target_orders:
        op.replace_input(v, v.transpose(target_orders[0]), with_assert=False)
        return True

    return _optimize_redundant_transposed_input(graph, op, var_name, target_orders)
def _replace_input(op: Operator, var_name: str, target: ChannelModeEnum):
    """Insert a channel-mode conversion in front of ``op.inputs[var_name]`` if needed.

    before)

        v -{op}-

    after)

        v -{conversion}- v' -{op}-

    Returns ``False`` (no change) when ``ChannelMode.get(v)`` already equals
    ``target``. Otherwise ``v`` is fed through ``ConvertRtoRGBA`` (when ``target`` is
    ``ChannelModeEnum.RGBA``) or ``ConvertRGBAtoR`` (any other target), ``op`` is
    rewired to the converted variable, and ``True`` is returned.
    """
    v = op.inputs[var_name]
    if ChannelMode.get(v) == target:
        return False

    converter_cls = ConvertRtoRGBA if target == ChannelModeEnum.RGBA else ConvertRGBAtoR
    (v_new,) = converter_cls(None)(v)
    op.replace_input(v, v_new)
    return True
def _replace_input(op: Operator, var_name: str, target: ChannelModeEnum):
    """Insert a channel-mode conversion in front of ``op.inputs[var_name]`` if needed.

    before)

        v -{op}-

    after)

        v -{conversion}- v' -{op}-

    Returns ``False`` (no change) when ``ChannelMode.get(v)`` already equals
    ``target``. Otherwise ``v`` is converted with ``convert_r_to_rgba`` (target
    ``ChannelModeEnum.RGBA``) or ``convert_rgba_to_r`` (any other target). The new
    variable is given the same ``TextureShape`` height/width as ``v``, ``op`` is
    rewired to it, and ``True`` is returned.
    """
    v = op.inputs[var_name]
    if ChannelMode.get(v) == target:
        return False

    convert = convert_r_to_rgba if target == ChannelModeEnum.RGBA else convert_rgba_to_r
    v_new = convert(v)

    # Copy v's texture (height, width) onto the converted variable.
    tex_height = TextureShape.get(v)[0]
    tex_width = TextureShape.get(v)[1]
    TextureShape.set(v_new, height=tex_height, width=tex_width)

    op.replace_input(v, v_new)
    return True