Example #1
0
def test_r2rgba():
    """Check that SimplifyNonsenseChannelModeConversion collapses a
    ConvertRGBAtoR -> ConvertRtoRGBA pair into a single Transpose.

    before)

    v0 -{ConvertRGBAtoR}- v1 -{ConvertRtoRGBA}- v2

    after)

    v0 -{Transpose}- v2

    The orders of both endpoints must survive the optimization unchanged.
    """
    v0 = Variable((2, 3, 4, 5), OrderNCHW)
    v1, = ConvertRGBAtoR(None)(v0)
    v2, = ConvertRtoRGBA(None)(v1)
    v2.change_order(OrderNHWC)

    # Remember the endpoint orders so we can verify the pass preserves them.
    expected_v0_order = v0.order
    expected_v2_order = v2.order

    graph = Graph([v0], [v2])
    SimplifyNonsenseChannelModeConversion().optimize(graph)

    # Graph boundary is untouched.
    assert len(graph.inputs) == 1
    assert graph.inputs[0] == v0
    assert len(graph.outputs) == 1
    assert graph.outputs[0] == v2

    assert v0.order == expected_v0_order

    # Only the replacement Transpose remains.
    assert len(traverse.listup_operators(graph)) == 1

    producer = v2.output_from
    assert isinstance(producer, Transpose)
    assert producer.inputs["x0"] == v0
    assert v2.order == expected_v2_order
Example #2
0
def test_r2rgba_2():
    """Check SimplifyRedundantChannelModeConversion on a chain of three
    ConvertRtoRGBA operators.

    before)

    v0[R] -{ConvertRtoRGBA}- v1[RGBA] -{ConvertRtoRGBA}- v2[RGBA] -{ConvertRtoRGBA}- v3[RGBA]

    after)

    v0[R] -{ConvertRtoRGBA}- v1[RGBA] -{ConvertRGBAtoR}- a[R] -{Transpose}- b[R] -{ConvertRtoRGBA}- v2[RGBA] -

    - v2[RGBA] -{ConvertRGBAtoR}- c[R] -{Transpose}- d[R] -{ConvertRtoRGBA}- v3[RGBA]

    (a, b, c, d are intermediate variables created by the pass.)
    """
    v0 = Variable((2, 3, 4, 5), OrderNCHW)
    v1, = ConvertRtoRGBA(None)(v0)
    v2, = ConvertRtoRGBA(None)(v1)
    v3, = ConvertRtoRGBA(None)(v2)

    graph = Graph([v0], [v3])
    SimplifyRedundantChannelModeConversion().optimize(graph)

    # Graph boundary is untouched.
    assert len(graph.inputs) == 1
    assert graph.inputs[0] == v0
    assert len(graph.outputs) == 1
    assert graph.outputs[0] == v3

    # Operator sequence expected after the rewrite, in traversal order.
    expected_types = [
        ConvertRtoRGBA, ConvertRGBAtoR, Transpose,
        ConvertRtoRGBA, ConvertRGBAtoR, Transpose,
        ConvertRtoRGBA,
    ]
    new_ops = traverse.listup_operators(graph)
    assert len(new_ops) == len(expected_types)
    for op, op_type in zip(new_ops, expected_types):
        assert isinstance(op, op_type)
Example #3
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """Switch every RGBA-capable operator in ``graph`` to RGBA channel mode.

        An operator is converted when all of the following hold:

        - its class is listed in ``_rgba_support_operators``;
        - its ``ChannelMode`` attribute is not already RGBA;
        - every input has the same shape as its first output
          (broadcasting is not supported).

        A converted operator gets three changes:

        - its ``ChannelMode`` attribute is set to RGBA;
        - each input is routed through a new ``ConvertRtoRGBA`` node;
        - each output variable is marked RGBA and followed by a new
          ``ConvertRGBAtoR`` node, whose result is what the former
          consumers now read.

        Args:
            graph: graph to rewrite in place.

        Returns:
            ``(graph, flag_changed)``; ``flag_changed`` is True when at least
            one operator was converted.
        """
        # ``_rgba_support_operators`` is only read here, so no ``global``
        # declaration is needed.
        flag_changed = False
        for op in traverse.listup_operators(graph):
            if op.__class__ not in _rgba_support_operators:
                # This operator doesn't support RGBA mode
                continue

            if op.get_attribute(ChannelMode)[0].mode == ChannelModeEnum.RGBA:
                # This operator is configured as RGBA mode already
                continue

            y = list(op.outputs.values())[0]
            if any(x.shape != y.shape for x in op.inputs.values()):
                # FIXME: RGBA mode cannot be used when broadcasting is involved
                continue

            op.get_attribute(ChannelMode)[0].mode = ChannelModeEnum.RGBA

            # Snapshot the inputs: the body removes and re-appends entries,
            # which must not happen while iterating a live mapping.
            for name, x in list(op.inputs.items()):
                op.remove_input(x)
                x_converted, = ConvertRtoRGBA(None)(x)
                op.append_input(name, x_converted)

            for y in list(op.outputs.values()):
                y_dummy = Variable(y.shape, y.order)
                y_converted, = ConvertRGBAtoR(None)(y_dummy)
                # Snapshot the consumers: replace_input presumably detaches
                # op2 from y, which would mutate the collection mid-loop.
                for op2 in list(y.input_to):  # type: Operator
                    op2.replace_input(y, y_converted)
                # Put y where the placeholder was, so the new ConvertRGBAtoR
                # takes y (now RGBA) as its input.
                y_dummy.replace(y)
                y.get_attribute(ChannelMode)[0].mode = ChannelModeEnum.RGBA

            flag_changed = True

        return graph, flag_changed
def template(x_shape=(2, 3, 4, 5),
             x_order: Order = OrderNHWC,
             y_order: Order = OrderNHWC,
             description: str = ""):
    """Generate a webgl kernel test case for ConvertRtoRGBA.

    The input is filled with ``0 .. prod(x_shape) - 1``. The expected output
    holds the same values, transposed from ``x_order`` into the output
    variable's order.

    Args:
        x_shape: shape of the input variable, laid out in ``x_order``.
        x_order: axis order of the input variable.
        y_order: axis order the output variable is changed to.
        description: suffix appended to the generated test-case description.
    """
    # np.prod, not np.product: the alias was deprecated in NumPy 1.25 and
    # removed in NumPy 2.0.
    vx = np.arange(np.prod(x_shape)).reshape(x_shape)
    vy = vx.copy()

    x = Variable(vx.shape, order=x_order)
    y, = ConvertRtoRGBA(None)(x)

    y.change_order(y_order)

    generate_kernel_test_case(
        description=f"ConvertRtoRGBA {description}",
        graph=Graph([x], [y]),
        backend=["webgl"],
        inputs={x: vx},
        expected={
            y: np.transpose(vy, [x_order.axes_dict[a] for a in y.order.axes])
        },
    )
Example #5
0
def _replace_input(op: Operator, var_name: str, target: ChannelModeEnum):
    """Route input ``var_name`` of ``op`` through a channel-mode conversion.

    before)

        v -{op}-

    after)

        v -{conversion}- v' -{op}-

    ``ConvertRtoRGBA`` is inserted when ``target`` is RGBA; otherwise
    ``ConvertRGBAtoR`` is inserted.

    Returns:
        False if ``v`` is already in ``target`` mode (nothing is changed),
        True if a conversion operator was inserted.
    """
    source = op.inputs[var_name]
    if ChannelMode.get(source) == target:
        return False

    conversion = ConvertRtoRGBA if target == ChannelModeEnum.RGBA else ConvertRGBAtoR
    converted, = conversion(None)(source)
    op.replace_input(source, converted)
    return True
Example #6
0
def _replace_output(op: Operator, var_name: str, target: ChannelModeEnum):
    """Make ``op`` produce output ``var_name`` in ``target`` channel mode.

    The changes are:

    - a fresh variable v' in ``target`` mode replaces v as the operator's
      output;
    - a conversion out of ``target`` mode is chained after v';
    - the conversion's result is replaced by the original v.

    before)

        -{op}- v

    after)

        -{op}- v' -{conversion}- v

    Returns:
        False if ``v`` is already in ``target`` mode (nothing is changed),
        True if the conversion was inserted.
    """
    original = op.outputs[var_name]
    if ChannelMode.get(original) == target:
        return False

    produced = Variable(original.shape, original.order)
    ChannelMode.set(produced, target)
    op.replace_output(original, produced)

    # Convert back out of ``target`` mode: RGBA -> R, or R -> RGBA.
    back_conversion = ConvertRGBAtoR if target == ChannelModeEnum.RGBA else ConvertRtoRGBA
    back_conversion(None)(produced)[0].replace(original)
    return True