Example #1
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """Switch every eligible operator in ``graph`` to RGBA channel mode.

        An operator is converted when all of the following hold:

        * its class is listed in ``_rgba_support_operators``;
        * its ``ChannelMode`` attribute is not already ``RGBA``;
        * every input has the same shape as its first output (broadcasting
          is not supported in RGBA mode).

        For a converted operator:

        * each input ``x`` is swapped for the output of
          ``ConvertRtoRGBA(None)(x)``;
        * each output ``y`` gets a ``ConvertRGBAtoR`` whose result is given to
          every operator that previously consumed ``y``;
        * ``y`` itself is marked as RGBA.

        Args:
            graph: the graph to rewrite; its operators are modified in place.

        Returns:
            ``(graph, flag_changed)`` where ``flag_changed`` is ``True`` if at
            least one operator was converted.
        """
        # ``_rgba_support_operators`` is only read here, so no ``global``
        # declaration is required.
        flag_changed = False
        for op in traverse.listup_operators(graph):
            if op.__class__ not in _rgba_support_operators:
                # This operator doesn't support RGBA mode
                continue

            channel_mode = op.get_attribute(ChannelMode)[0]
            if channel_mode.mode == ChannelModeEnum.RGBA:
                # This operator is configured as RGBA mode already
                continue

            first_output = list(op.outputs.values())[0]
            if any(x.shape != first_output.shape for x in op.inputs.values()):
                # FIXME: RGBA mode cannot be used when broadcasting is involved
                continue

            channel_mode.mode = ChannelModeEnum.RGBA

            # Snapshot the inputs: the loop body removes and re-appends
            # entries, which must not disturb the iteration itself.
            for name, x in list(op.inputs.items()):
                op.remove_input(x)
                x_converted, = ConvertRtoRGBA(None)(x)
                op.append_input(name, x_converted)

            for y in list(op.outputs.values()):
                # Build the converter on a placeholder first, redirect y's
                # consumers to the converter's result, then (assuming
                # ``a.replace(b)`` puts ``b`` in ``a``'s place) wire ``y`` in
                # as the converter's input.
                y_dummy = Variable(y.shape, y.order)
                y_converted, = ConvertRGBAtoR(None)(y_dummy)
                # Snapshot: ``replace_input`` presumably detaches op2 from y,
                # which would mutate ``y.input_to`` during iteration.
                for op2 in list(y.input_to):  # type: Operator
                    op2.replace_input(y, y_converted)
                y_dummy.replace(y)
                y.get_attribute(ChannelMode)[0].mode = ChannelModeEnum.RGBA

            flag_changed = True

        return graph, flag_changed
Example #2
0
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """
    Remove ``op`` from ``graph`` and let ``v`` take the place of its output.

    before)

    x1 -+
        +-{op}- y -
    x2 -+

    after)

                v -

    ``y`` is first re-ordered to ``v.order`` so the two variables can be
    swapped. Which variable object survives depends on whether ``v`` is a
    graph input (assuming ``a.replace(b)`` puts ``b`` in ``a``'s place):

    * ``v`` is a graph input: ``v`` must keep its identity and connections,
      so ``y``'s consumers are moved onto ``v``. If ``y`` is a graph output,
      ``v`` also takes ``y``'s slot in ``graph.outputs``.
    * otherwise: ``y`` takes over ``v``'s producer and consumers. ``y`` keeps
      its identity, including its slot in ``graph.outputs`` if it has one.

    Args:
        graph: the graph
        op: the operator which will be removed
        v: variable with which output variable is replaced
    """
    y = op.outputs["y"]
    op.remove_all()
    y.change_order(v.order)

    # Do NOT call ``v.replace(y)`` unconditionally before this branch.
    # That would detach a graph-input ``v`` from its remaining consumers.
    if v in graph.inputs:
        # Move any downstream consumers of y onto v, even when y is also a
        # graph output; otherwise they would be left on an orphaned y.
        y.replace(v)
        if y in graph.outputs:
            graph.outputs[graph.outputs.index(y)] = v

    else:
        v.replace(y)
Example #3
0
    def replace_variable(graph: Graph, old_var: Variable, new_var: Variable, with_assert: bool = True):
        """Put ``new_var`` in place of ``old_var`` throughout ``graph``.

        First delegates the rewiring to ``old_var.replace`` (forwarding
        ``with_assert``). Then, in both ``graph.inputs`` and ``graph.outputs``,
        the first occurrence of ``old_var`` (if any) is swapped for
        ``new_var`` at the same position.

        Args:
            graph: the graph whose input/output lists are updated
            old_var: the variable being replaced
            new_var: the variable that takes its place
            with_assert: passed through to ``old_var.replace``
        """
        old_var.replace(new_var, with_assert=with_assert)

        for io_list in (graph.inputs, graph.outputs):
            if old_var not in io_list:
                continue
            # Keep the list position so input/output ordering is preserved.
            pos = io_list.index(old_var)
            io_list.remove(old_var)
            io_list.insert(pos, new_var)
Example #4
0
    def replace_variable(graph: Graph, old_var: Variable, new_var: Variable):
        """Put ``new_var`` in place of ``old_var`` throughout ``graph``.

        Calls ``old_var.replace(new_var)``. Then, in both ``graph.inputs``
        and ``graph.outputs``, swaps the first occurrence of ``old_var`` (if
        present) for ``new_var`` without changing its position.

        Args:
            graph: the graph whose input/output lists are updated
            old_var: the variable being replaced
            new_var: the variable that takes its place
        """
        old_var.replace(new_var)

        for registry in (graph.inputs, graph.outputs):
            if old_var in registry:
                slot = registry.index(old_var)
                registry.remove(old_var)
                registry.insert(slot, new_var)
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """Remove ``op`` from ``graph`` and let ``v`` take the place of its output ``y``.

    ``y`` is first re-ordered to ``v.order`` so the two variables can be
    swapped. Which variable object survives depends on whether ``v`` is a
    graph input (assuming ``a.replace(b)`` puts ``b`` in ``a``'s place):

    * ``v`` is a graph input: ``v`` must keep its identity and connections,
      so ``y``'s consumers are moved onto ``v``. If ``y`` is a graph output,
      ``v`` also takes ``y``'s slot in ``graph.outputs``.
    * otherwise: ``y`` takes over ``v``'s producer and consumers. ``y`` keeps
      its identity, including its slot in ``graph.outputs`` if it has one.

    Args:
        graph: the graph
        op: the operator which will be removed (its output is named "y")
        v: variable with which the output variable is replaced
    """
    y = op.outputs["y"]
    op.remove_all()
    y.change_order(v.order)

    # Do NOT call ``v.replace(y)`` unconditionally before this branch.
    # That would detach a graph-input ``v`` from its remaining consumers.
    if v in graph.inputs:
        # Move any downstream consumers of y onto v, even when y is also a
        # graph output; otherwise they would be left on an orphaned y.
        y.replace(v)
        if y in graph.outputs:
            graph.outputs[graph.outputs.index(y)] = v

    else:
        v.replace(y)