Beispiel #1
0
    def replace_input(graph: Graph, op: Operator, old_var: Variable, new_var: Variable, with_assert: bool = True):
        """
        Make ``new_var`` take the place of ``old_var`` as an input of ``op``.

        If ``old_var`` is also listed in ``graph.inputs``, it is replaced there by
        ``new_var`` at the same position.

        Args:
            graph: graph whose ``inputs`` list may contain ``old_var``.
            op: operator whose input is being swapped.
            old_var: variable currently consumed by ``op``.
            new_var: variable that replaces it.
            with_assert: forwarded unchanged to ``op.replace_input``.
        """
        op.replace_input(old_var, new_var, with_assert=with_assert)

        graph_inputs = graph.inputs
        if old_var not in graph_inputs:
            return

        # Overwrite the slot of the first occurrence so the ordering of the
        # graph inputs is preserved.
        graph_inputs[graph_inputs.index(old_var)] = new_var
Beispiel #2
0
def _replace_input(op: Operator, var_name: str,
                   target_orders: Union[Order, List[Order]]):
    """
    Ensure the input ``var_name`` of ``op`` has one of ``target_orders``.

    A single ``Order`` is treated as a one-element list. If the current input's
    order is already acceptable, nothing is changed. Otherwise the input is
    replaced by ``v.transpose(first_acceptable_order)``.

    Returns:
        True if the input was replaced, False if it was left as-is.
    """
    current = op.inputs[var_name]
    acceptable = [target_orders] if isinstance(target_orders, Order) else target_orders

    if current.order not in acceptable:
        # NOTE(review): with_assert=False presumably lets the replacement have
        # a different order than the original input; confirm against Operator.
        op.replace_input(current, current.transpose(acceptable[0]), with_assert=False)
        return True

    return False
Beispiel #3
0
def test_replace_input():
    """``Operator.replace_input`` rewires the named slot and updates ``input_to`` on both variables."""
    consumer = Operator("op")
    original = Variable((1, 2, 3, 4), OrderNHWC)
    replacement = Variable((1, 2, 3, 4), OrderNHWC)

    consumer.append_input("v1", original)
    consumer.replace_input(original, replacement)

    # The slot keeps its name but now refers to the replacement variable.
    assert consumer.inputs["v1"] == replacement
    # The old variable feeds nothing any more; the new one feeds only ``consumer``.
    assert original.input_to == set()
    assert replacement.input_to == {consumer}
Beispiel #4
0
def _replace_input(op: Operator, var_name: str,
                   target_orders: Union[Order, List[Order]]):
    """
    Route input ``var_name`` of ``op`` through a ``Transpose`` operator when its
    order is not one of ``target_orders``.

    A single ``Order`` is treated as a one-element list. When a transpose is
    inserted, its output is given the first acceptable order.

    Returns:
        True if a ``Transpose`` was inserted, False if the input already had an
        acceptable order.
    """
    src = op.inputs[var_name]
    candidates = [target_orders] if isinstance(target_orders, Order) else target_orders

    if src.order in candidates:
        return False

    transposed, = Transpose(None)(src)
    op.replace_input(src, transposed, with_assert=False)
    # The output's order is changed only after it has been wired into ``op``.
    transposed.change_order(candidates[0])
    return True
Beispiel #5
0
def _replace_input(graph: Graph, op: Operator, var_name: str,
                   target_orders: Union[Order, List[Order]]):
    """
    Ensure input ``var_name`` of ``op`` has one of ``target_orders``.

    A single ``Order`` is treated as a one-element list. If the input's order is
    not acceptable, it is replaced by a transposed copy in the first acceptable
    order and True is returned. Otherwise the call is delegated to
    ``_optimize_redundant_transposed_input``, and its result is returned.
    """
    var = op.inputs[var_name]
    if isinstance(target_orders, Order):
        target_orders = [target_orders]

    if var.order not in target_orders:
        op.replace_input(var, var.transpose(target_orders[0]), with_assert=False)
        return True

    # Order already acceptable: let the helper decide whether anything changes.
    return _optimize_redundant_transposed_input(graph, op, var_name,
                                                target_orders)
Beispiel #6
0
def _replace_input(op: Operator, var_name: str, target: ChannelModeEnum):
    """
    Insert a channel-mode conversion in front of ``op`` so that its input
    ``var_name`` is in channel mode ``target``.

    before)

        v -{op}-

    after)

        v -{conversion}- v' -{op}-

    The conversion is ``ConvertRtoRGBA`` when ``target`` is RGBA, and
    ``ConvertRGBAtoR`` otherwise.

    Returns:
        True if a conversion was inserted, False if ``v`` already had
        channel mode ``target``.
    """
    v = op.inputs[var_name]

    if ChannelMode.get(v) == target:
        return False

    converter = ConvertRtoRGBA if target == ChannelModeEnum.RGBA else ConvertRGBAtoR
    v_new, = converter(None)(v)
    op.replace_input(v, v_new)
    return True
Beispiel #7
0
def _replace_input(op: Operator, var_name: str, target: ChannelModeEnum):
    """
    Insert a channel-mode conversion in front of ``op`` so that its input
    ``var_name`` is in channel mode ``target``.

    before)

        v -{op}-

    after)

        v -{conversion}- v' -{op}-

    The converted variable ``v'`` is given the same texture height and width
    as ``v``.

    Returns:
        True if a conversion was inserted, False if ``v`` already had
        channel mode ``target``.
    """
    v = op.inputs[var_name]
    if ChannelMode.get(v) == target:
        return False

    convert = convert_r_to_rgba if target == ChannelModeEnum.RGBA else convert_rgba_to_r
    v_new = convert(v)

    # Read the source texture shape after the conversion, as before, and copy it.
    texture_shape = TextureShape.get(v)
    TextureShape.set(v_new, height=texture_shape[0], width=texture_shape[1])
    op.replace_input(v, v_new)
    return True