Beispiel #1
0
    def replace_output(graph: Graph, op: Operator, old_var: Variable, new_var: Variable, with_assert: bool = True):
        """Swap an output variable of ``op`` and mirror the swap in ``graph.outputs``.

        The replacement is first delegated to ``op.replace_output``. If
        ``old_var`` is also listed in ``graph.outputs``, its first occurrence
        is overwritten with ``new_var`` so the position in the list is kept.

        Args:
            graph: Object exposing a mutable ``outputs`` list.
            op: Operator whose output is being replaced.
            old_var: Variable currently produced by ``op``.
            new_var: Variable that takes the place of ``old_var``.
            with_assert: Forwarded unchanged to ``op.replace_output``.
        """
        op.replace_output(old_var, new_var, with_assert=with_assert)

        graph_outputs = graph.outputs
        if old_var not in graph_outputs:
            # Not a graph-level output: nothing else to update.
            return

        # Overwrite in place; equivalent to remove + insert at the same index.
        slot = graph_outputs.index(old_var)
        graph_outputs[slot] = new_var
Beispiel #2
0
def test_replace_output():
    """Check that ``Operator.replace_output`` rebinds a named output.

    After the replacement the operator must expose the new variable under
    the same output name, the old variable must no longer record a producer,
    and the new variable must record the operator as its producer.
    """
    shape = (1, 2, 3, 4)
    operator = Operator("op")
    original = Variable(shape, OrderNHWC)
    replacement = Variable(shape, OrderNHWC)

    operator.append_output("v1", original)
    operator.replace_output(original, replacement)

    # The output slot keeps its name but now holds the replacement.
    assert operator.outputs["v1"] == replacement
    # Producer links follow the swap.
    assert original.output_from is None
    assert replacement.output_from == operator
Beispiel #3
0
def _replace_output(op: Operator, var_name: str,
                    target_orders: Union[Order, List[Order]]):
    """Make the output ``var_name`` of ``op`` use one of the target orders.

    If the output's current order is already among ``target_orders`` nothing
    is changed. Otherwise a new variable in the first target order becomes
    the operator's output, and the output of a ``Transpose`` applied to it
    is substituted for the original variable via ``replace``.

    Args:
        op: Operator whose output is inspected.
        var_name: Key of the output in ``op.outputs``.
        target_orders: A single acceptable order or a list of them; the
            first entry is the one used when a change is needed.

    Returns:
        ``False`` when the output already has an acceptable order,
        ``True`` when the graph was modified.
    """
    original = op.outputs[var_name]

    # Normalise the single-order form to a list.
    accepted = [target_orders] if isinstance(target_orders, Order) else target_orders
    if original.order in accepted:
        return False

    template = Variable(original.shape, original.order)
    reordered = template.change_order(accepted[0])
    op.replace_output(original, reordered, with_assert=False)

    # Route the reordered variable through a Transpose and let its output
    # stand in for the original variable.
    transposed = Transpose(None)(reordered)[0]
    transposed.replace(original, with_assert=False)
    return True
Beispiel #4
0
def _replace_output(graph: Graph, op: Operator, var_name: str,
                    target_orders: Union[Order, List[Order]]):
    """Make the output ``var_name`` of ``op`` use one of the target orders.

    When the output's order is already acceptable, the decision is handed to
    ``_optimize_redundant_transposed_output`` and its result is returned.
    Otherwise a new variable in the first target order becomes the
    operator's output, and that variable transposed back to the original
    order is substituted for the original variable via ``replace``.

    Args:
        graph: Graph passed through to the redundant-transpose helper.
        op: Operator whose output is inspected.
        var_name: Key of the output in ``op.outputs``.
        target_orders: A single acceptable order or a list of them; the
            first entry is the one used when a change is needed.

    Returns:
        The helper's result when the order is already acceptable,
        otherwise ``True``.
    """
    original = op.outputs[var_name]

    # Normalise the single-order form to a list.
    accepted = [target_orders] if isinstance(target_orders, Order) else target_orders

    if original.order in accepted:
        return _optimize_redundant_transposed_output(graph, op, var_name,
                                                     accepted)

    template = Variable(original.shape, original.order)
    reordered = template.change_order(accepted[0])
    op.replace_output(original, reordered, with_assert=False)

    # Transpose back to the original order and let the result stand in
    # for the original variable.
    restored = reordered.transpose(original.order)
    restored.replace(original, with_assert=False)
    return True
Beispiel #5
0
def _replace_output(op: Operator, var_name: str, target: ChannelModeEnum):
    """Give the output ``var_name`` of ``op`` the channel mode ``target``.

    Graph shape before the call::

        -{op}- v

    Graph shape after a change was made::

        -{op}- v' -{conversion}- v

    ``v'`` is a new variable with the shape and order of ``v`` and channel
    mode ``target``. The conversion operator is ``ConvertRGBAtoR`` when
    ``target`` is RGBA and ``ConvertRtoRGBA`` otherwise.

    Returns:
        ``False`` if ``v`` already has the requested channel mode,
        ``True`` if the conversion was inserted.
    """
    current = op.outputs[var_name]

    if ChannelMode.get(current) == target:
        return False

    converted_source = Variable(current.shape, current.order)
    ChannelMode.set(converted_source, target)

    op.replace_output(current, converted_source)

    # The new output is in `target` mode, so convert away from it.
    conversion_cls = ConvertRGBAtoR if target == ChannelModeEnum.RGBA else ConvertRtoRGBA
    conversion_cls(None)(converted_source)[0].replace(current)
    return True
def _replace_output(op: Operator, var_name: str, target: ChannelModeEnum):
    """Give the output ``var_name`` of ``op`` the channel mode ``target``.

    Graph shape before the call::

        -{op}- v

    Graph shape after a change was made::

        -{op}- v' -{conversion}- v

    ``v'`` is a new variable with the shape and order of ``v`` and channel
    mode ``target``. The conversion is ``convert_rgba_to_r`` when ``target``
    is RGBA and ``convert_r_to_rgba`` otherwise; its result is changed to
    the order of ``v`` before being substituted for ``v`` via ``replace``.

    Returns:
        ``False`` if ``v`` already has the requested channel mode,
        ``True`` if the conversion was inserted.
    """
    current = op.outputs[var_name]

    if ChannelMode.get(current) == target:
        return False

    converted_source = Variable(current.shape, current.order)
    ChannelMode.set(converted_source, target)

    op.replace_output(current, converted_source)

    # The new output is in `target` mode, so convert away from it.
    convert = convert_rgba_to_r if target == ChannelModeEnum.RGBA else convert_r_to_rgba
    restored = convert(converted_source).change_order(current.order)
    restored.replace(current)
    return True