Exemple #1
0
    def replace_output(graph: Graph, op: Operator, old_var: Variable, new_var: Variable, with_assert: bool = True):
        """
        Replace an output variable of an operator and mirror the change in the graph's output list.

        Args:
            graph: graph whose `outputs` list is updated when it contains `old_var`
            op: operator whose output is replaced
            old_var: variable currently attached as an output of `op`
            new_var: variable which takes the place of `old_var`
            with_assert: forwarded unchanged to `op.replace_output`
        """
        op.replace_output(old_var, new_var, with_assert=with_assert)

        if old_var not in graph.outputs:
            return

        # Keep the position: new_var is inserted exactly where old_var was.
        position = graph.outputs.index(old_var)
        graph.outputs.remove(old_var)
        graph.outputs.insert(position, new_var)
def _remove_unary_operator(graph: Graph, op: Operator):
    """
    Disconnect an operator and let its first input stand in for its first output.

    When input and output share the same order and shape, the output is either
    swapped in `graph.outputs` (if listed there) or replaced through
    `Variable.replace`. Otherwise the output is swapped in `graph.outputs` (if
    listed there) and every operator consuming the output is re-attached to the
    input under the same input name.

    Args:
        graph: graph whose `outputs` list may reference the operator's output
        op: operator to be removed; only its first input and first output are used
    """
    source = list(op.inputs.values())[0]
    result = list(op.outputs.values())[0]
    op.remove_all()

    same_layout = source.order == result.order and source.shape == result.shape
    if same_layout:
        source.change_order(result.order)

    result_is_graph_output = result in graph.outputs
    if result_is_graph_output:
        position = graph.outputs.index(result)
        graph.outputs.remove(result)
        graph.outputs.insert(position, source)

    if same_layout:
        # NOTE(review): when `result` is a graph output, operators still consuming
        # `result` are not re-attached on this path - confirm that is intended.
        if not result_is_graph_output:
            result.replace(source)
        return

    # Copy the consumer collection first; it is modified while re-attaching inputs.
    for consumer in list(result.input_to):
        input_name = consumer.get_input_name(result)
        consumer.remove_input(result)
        consumer.append_input(input_name, source)
Exemple #3
0
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """
    Remove a binary elementwise operator and substitute its output with `v`.

    before)

    x1 -+
        +-{op}- y -
    x2 -+

    after)

                v -

    Args:
        graph: the graph
        op: the operator which will be removed
        v: variable with which output variable is replaced
    """
    output = op.outputs["y"]
    op.remove_all()
    output.change_order(v.order)
    v.replace(output)

    if v not in graph.inputs:
        # NOTE(review): `v.replace(output)` has already been called above; this
        # second call is kept to preserve behaviour - confirm whether it is needed.
        v.replace(output)
        return

    if output not in graph.outputs:
        output.replace(v)
        return

    # `v` is a graph input and `output` a graph output: put `v` at output's slot.
    position = graph.outputs.index(output)
    graph.outputs.remove(output)
    graph.outputs.insert(position, v)
Exemple #4
0
def test_get_input_name():
    """`Operator.get_input_name` returns the name each variable was appended under."""
    shape = (1, 2, 3, 4)
    operator = Operator("op")
    named_inputs = [
        ("v1", Variable(shape, OrderNHWC)),
        ("v2", Variable(shape, OrderNHWC)),
    ]

    for name, variable in named_inputs:
        operator.append_input(name, variable)

    for name, variable in named_inputs:
        assert operator.get_input_name(variable) == name
def _optimize_ScalarAdd_ScalarMul(op1: ScalarAdd, op2: Operator):
    """
    Fuse `op1` (ScalarAdd) followed by `op2` (ScalarMul) into one expression.

    Both operators are disconnected and `op2`'s output is substituted by
    `x0 * op2.value + op1.value * op2.value`, where `x0` is `op1`'s input.

    Args:
        op1: the ScalarAdd operator
        op2: candidate operator; nothing is done unless it is a ScalarMul

    Returns:
        True if the operators were fused, False if `op2` is not a ScalarMul
    """
    if not isinstance(op2, ScalarMul):
        return False

    source = op1.inputs["x0"]
    final_output = op2.outputs["y"]

    op2.remove_all()
    op1.remove_all()

    scaled_source = source * op2.value
    scaled_offset = op1.value * op2.value
    fused = scaled_source + scaled_offset
    fused.replace(final_output)
    return True
Exemple #6
0
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """
    Feed `op` a transposed copy of one input when that input's order is not acceptable.

    Args:
        op: operator whose input is inspected
        var_name: name of the input within `op.inputs`
        target_orders: a single acceptable order or a list of them; the first
            entry is the order given to the transposed variable

    Returns:
        False if the input already has one of the target orders, True if the
        input was replaced
    """
    current = op.inputs[var_name]
    accepted = [target_orders] if isinstance(target_orders, Order) else target_orders

    if current.order in accepted:
        return False

    (transposed,) = Transpose(None)(current)
    transposed.change_order(accepted[0])

    # Re-register the transposed variable under the original input name.
    op.remove_input(current)
    op.append_input(var_name, transposed)
    return True
Exemple #7
0
def _replace_input(op: Operator, var_name: str,
                   target_orders: Union[Order, List[Order]]):
    """
    Replace an input of `op` with its transposition when its order is not acceptable.

    Args:
        op: operator whose input is inspected
        var_name: name of the input within `op.inputs`
        target_orders: a single acceptable order or a list of them; the first
            entry is the order passed to `Variable.transpose`

    Returns:
        False if the input already has one of the target orders, True if the
        input was replaced
    """
    current = op.inputs[var_name]
    accepted = [target_orders] if isinstance(target_orders, Order) else target_orders

    if current.order in accepted:
        return False

    transposed = current.transpose(accepted[0])
    op.replace_input(current, transposed, with_assert=False)
    return True
Exemple #8
0
def test_append_input():
    """Appended inputs are retrievable by name and know the operator they feed."""
    shape = (1, 2, 3, 4)
    operator = Operator("op")
    named_inputs = [
        ("v1", Variable(shape, OrderNHWC)),
        ("v2", Variable(shape, OrderNHWC)),
    ]

    for name, variable in named_inputs:
        operator.append_input(name, variable)

    for name, variable in named_inputs:
        assert operator.inputs[name] == variable

    for _, variable in named_inputs:
        assert variable.input_to == {operator}
Exemple #9
0
def test_append_output():
    """Appended outputs are retrievable by name and know the operator producing them."""
    shape = (1, 2, 3, 4)
    operator = Operator("op")
    named_outputs = [
        ("v1", Variable(shape, OrderNHWC)),
        ("v2", Variable(shape, OrderNHWC)),
    ]

    for name, variable in named_outputs:
        operator.append_output(name, variable)

    for name, variable in named_outputs:
        assert operator.outputs[name] == variable

    for _, variable in named_outputs:
        assert variable.output_from == operator
def _optimize_ElementwiseMul_ScalarMul(op1: ElementwiseMul,
                                       c1: ConstantVariable, v1: Variable,
                                       op2: Operator):
    """
    Fuse `op1` (ElementwiseMul) followed by `op2` (ScalarMul) into one expression.

    Both operators are disconnected and `op2`'s output is substituted by
    `v1 * (c1 * op2.value)`.

    Args:
        op1: the ElementwiseMul operator
        c1: constant operand, multiplied by `op2.value` first
        v1: the other operand
        op2: candidate operator; nothing is done unless it is a ScalarMul

    Returns:
        True if the operators were fused, False if `op2` is not a ScalarMul
    """
    if not isinstance(op2, ScalarMul):
        return False

    final_output = op2.outputs["y"]

    op1.remove_all()
    op2.remove_all()

    folded_constant = c1 * op2.value
    fused = v1 * folded_constant
    fused.replace(final_output)
    return True
Exemple #11
0
def _replace_output(op: Operator, var_name: str,
                    target_orders: Union[Order, List[Order]]):
    """
    Make `op` emit a variable in an acceptable order and transpose it back afterwards.

    A fresh variable with the shape and order of the original output is
    created, switched to the first target order via `change_order`, and
    installed as `op`'s output. The first output of a Transpose applied to it
    then replaces the original output variable.

    Args:
        op: operator whose output is inspected
        var_name: name of the output within `op.outputs`
        target_orders: a single acceptable order or a list of them

    Returns:
        False if the output already has one of the target orders, True if the
        output was replaced
    """
    original = op.outputs[var_name]
    accepted = [target_orders] if isinstance(target_orders, Order) else target_orders

    if original.order in accepted:
        return False

    # Relies on `change_order` returning the variable to use as the new output.
    reordered = Variable(original.shape, original.order).change_order(accepted[0])
    op.replace_output(original, reordered, with_assert=False)

    restored = Transpose(None)(reordered)[0]
    restored.replace(original, with_assert=False)
    return True
Exemple #12
0
def _replace_input(op: Operator, var_name: str,
                   target_orders: Union[Order, List[Order]]):
    """
    Route an input of `op` through a Transpose when its order is not acceptable.

    The Transpose output replaces the input first and is switched to the first
    target order afterwards.

    Args:
        op: operator whose input is inspected
        var_name: name of the input within `op.inputs`
        target_orders: a single acceptable order or a list of them

    Returns:
        False if the input already has one of the target orders, True if the
        input was replaced
    """
    current = op.inputs[var_name]
    accepted = [target_orders] if isinstance(target_orders, Order) else target_orders

    if current.order in accepted:
        return False

    (transposed,) = Transpose(None)(current)
    op.replace_input(current, transposed, with_assert=False)
    transposed.change_order(accepted[0])
    return True
Exemple #13
0
def _replace_input(graph: Graph, op: Operator, var_name: str,
                   target_orders: Union[Order, List[Order]]):
    """
    Give `op` an input in an acceptable order, transposing the current one if needed.

    Args:
        graph: the graph, forwarded to `_optimize_redundant_transposed_input`
        op: operator whose input is inspected
        var_name: name of the input within `op.inputs`
        target_orders: a single acceptable order or a list of them; the first
            entry is the order passed to `Variable.transpose`

    Returns:
        The result of `_optimize_redundant_transposed_input` when the input
        already has one of the target orders, otherwise True after the input
        was replaced by its transposition
    """
    current = op.inputs[var_name]
    accepted = [target_orders] if isinstance(target_orders, Order) else target_orders

    if current.order not in accepted:
        transposed = current.transpose(accepted[0])
        op.replace_input(current, transposed, with_assert=False)
        return True

    return _optimize_redundant_transposed_input(graph, op, var_name, accepted)
Exemple #14
0
def _replace_output(graph: Graph, op: Operator, var_name: str,
                    target_orders: Union[Order, List[Order]]):
    """
    Make `op` emit a variable in an acceptable order and transpose it back afterwards.

    Args:
        graph: the graph, forwarded to `_optimize_redundant_transposed_output`
        op: operator whose output is inspected
        var_name: name of the output within `op.outputs`
        target_orders: a single acceptable order or a list of them; the first
            entry is the order given to the new output variable

    Returns:
        The result of `_optimize_redundant_transposed_output` when the output
        already has one of the target orders, otherwise True after the output
        was replaced
    """
    original = op.outputs[var_name]
    accepted = [target_orders] if isinstance(target_orders, Order) else target_orders

    if original.order not in accepted:
        # Relies on `change_order` returning the variable to use as the new output.
        reordered = Variable(original.shape, original.order).change_order(accepted[0])
        op.replace_output(original, reordered, with_assert=False)

        # Transpose back to the original order and substitute the original variable.
        restored = reordered.transpose(original.order)
        restored.replace(original, with_assert=False)
        return True

    return _optimize_redundant_transposed_output(graph, op, var_name, accepted)
Exemple #15
0
def _replace_output(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """
    Make `op` emit a variable in an acceptable order, followed by a Transpose
    which produces the original output variable.

    Args:
        op: operator whose output is inspected
        var_name: name of the output within `op.outputs`
        target_orders: a single acceptable order or a list of them; the first
            entry is the order given to the new output variable

    Returns:
        False if the output already has one of the target orders, True if the
        output was replaced
    """
    original = op.outputs[var_name]
    accepted = [target_orders] if isinstance(target_orders, Order) else target_orders

    if original.order in accepted:
        return False

    reordered = Variable(original.shape, original.order)
    reordered.change_order(accepted[0])

    # Swap the operator's output for the reordered variable.
    op.remove_output(original)
    op.append_output(var_name, reordered)

    # Calling the Transpose creates a placeholder output; swap it for the original variable.
    transposer = Transpose(None)
    (placeholder,) = transposer(reordered)
    transposer.remove_output(placeholder)
    transposer.append_output("y", original)
    return True
Exemple #16
0
def _split_tensorwise(graph: Graph, op: Operator, v: Variable,
                      v_pair: Sequence[Variable], axis: Axis):
    """
    Replace `op` with two copies of itself, each handling one part of a split along `axis`.

    `op` is disconnected and copied twice. Every original input and output is
    attached to the copies under its original name: the variable equal to `v`
    is substituted by the two entries of `v_pair`; other inputs are split with
    SplitAxis when they have `axis` (otherwise the same variable is shared);
    other outputs are rebuilt as two new variables joined by Concat, whose
    result is registered through `OptimizeRule.replace_variable`.

    Args:
        graph: the graph, forwarded to `OptimizeRule.replace_variable`
        op: operator to be split
        v: the variable of `op` which has already been split
        v_pair: the two parts substituted for `v`
        axis: axis along which the split is made

    Raises:
        UnexpectedAndPleaseReportError: an output other than `v` does not have `axis`
    """
    first_size = v_pair[0].shape_dict[axis]
    second_size = v_pair[1].shape_dict[axis]
    original_inputs = dict(op.inputs)
    original_outputs = dict(op.outputs)
    op.remove_all()

    first_op = op.copy()
    second_op = op.copy()

    for name, source in original_inputs.items():
        if source == v:
            first_in, second_in = v_pair

        elif axis in source.order.axes:
            first_in, second_in = SplitAxis(None, axis=axis, sections=[first_size])(source)

        else:
            # This input has no such axis, so both copies read the same variable.
            first_in = second_in = source

        first_op.append_input(name, first_in)
        second_op.append_input(name, second_in)

    for name, target in original_outputs.items():
        if target == v:
            first_out, second_out = v_pair

        elif axis in target.order.axes:
            # TODO (Kiikurage)
            # Attributes attached to `target` are copied to neither `first_out` nor `second_out`
            first_shape = [first_size if a == axis else target.shape_dict[a] for a in target.order.axes]
            first_out = Variable(first_shape, target.order)
            second_shape = [second_size if a == axis else target.shape_dict[a] for a in target.order.axes]
            second_out = Variable(second_shape, target.order)

            (joined,) = Concat(None, axis=axis)(first_out, second_out)
            OptimizeRule.replace_variable(graph, target, joined)

        else:
            raise UnexpectedAndPleaseReportError

        first_op.append_output(name, first_out)
        second_op.append_output(name, second_out)
Exemple #17
0
def _listup_splittable_axis(v: Variable, op: Operator) -> List[Axis]:
    """
    List the axes of `v` along which it may be split, given that `v` is attached to `op`.

    The answer depends on the operator type:

    - Concat, SplitAxis: every axis of `v`.
    - Reshape: each axis of `v` whose trailing size (computed by `mul` over `v`'s
      shape from that axis to the last one, presumably a product) equals the
      product of some trailing run of the dimensions of the variable on the
      other side of the reshape. Each axis is listed at most once.
    - Im2Col: for an output, N, H, W, plus C when the C size is divisible by
      `ksize[0] * ksize[1]`; nothing for other variables.
    - PartialIm2Col: nothing for an output, otherwise `[op.axis]`.
    - Sgemm, Tensordot: nothing for output "C", otherwise every axis of `v`.
    - anything else: the `axis` of each Tensorwise attribute of `op`.

    Args:
        v: variable to be split
        op: operator which `v` is attached to

    Returns:
        list of splittable axes
    """
    if isinstance(op, (Concat, SplitAxis)):
        return list(v.order.axes)

    if isinstance(op, Reshape):
        # For more detail of this condition check, please see the comment document of `_split_reshape`
        splittable_axes = []  # type: List[Axis]
        v1 = v
        v2 = op.outputs["y"] if v == op.inputs["x"] else op.inputs["x"]

        for a1 in v1.order.axes:
            d1 = mul(v1.shape[v1.order.axes_dict[a1]:])
            d2 = 1
            for a2 in reversed(v2.order.axes):
                d2 *= v2.shape_dict[a2]

                if d2 == d1:
                    # Stop at the first match. Continuing would let later
                    # size-1 dimensions of v2 (e.g. a leading batch of 1) match
                    # again and list `a1` twice.
                    splittable_axes.append(a1)
                    break

                if d2 > d1:
                    # With positive dimension sizes the running product never
                    # shrinks, so no later match is possible.
                    break

        return splittable_axes

    if isinstance(op, Im2Col):
        if v not in op.outputs.values():
            return []

        if v.shape_dict[Axis.C] % (op.ksize[0] * op.ksize[1]) == 0:
            return [Axis.N, Axis.H, Axis.W, Axis.C]

        return [Axis.N, Axis.H, Axis.W]

    if isinstance(op, PartialIm2Col):
        if v in op.outputs.values():
            return []

        return [op.axis]

    if isinstance(op, (Sgemm, Tensordot)):
        if v == op.outputs["C"]:
            return []

        return list(v.order.axes)

    return [attr.axis for attr in op.get_attribute(Tensorwise)]
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """
    Remove a binary elementwise operator and substitute its output "y" with `v`.

    Args:
        graph: graph whose `inputs` and `outputs` lists are consulted
        op: the operator which will be removed
        v: variable with which the output variable is replaced
    """
    removed_output = op.outputs["y"]
    op.remove_all()
    removed_output.change_order(v.order)
    v.replace(removed_output)

    v_is_graph_input = v in graph.inputs

    if v_is_graph_input and removed_output in graph.outputs:
        # Put `v` at the slot the removed output occupied in the graph's outputs.
        slot = graph.outputs.index(removed_output)
        graph.outputs.remove(removed_output)
        graph.outputs.insert(slot, v)

    elif v_is_graph_input:
        removed_output.replace(v)

    else:
        # NOTE(review): `v.replace(removed_output)` has already been called above;
        # this second call is kept to preserve behaviour - confirm whether it is needed.
        v.replace(removed_output)
Exemple #19
0
    def __init__(self, op: Operator):
        """
        Set up the delegate used for `op`.

        `delegate` starts as the identity function. When `op` matches the
        PostInlineInplace attribute and the first such attribute has something
        injected, `delegate` becomes that injected object's `injector`.

        Args:
            op: operator whose PostInlineInplace attribute is inspected
        """
        self.delegate = lambda exp: exp  # type: Callable[[str], str]
        self.has_inline = traverse.check_attribute_match(op, PostInlineInplace)

        if not self.has_inline:
            return

        attribute = op.get_attribute(PostInlineInplace)[0]  # type: PostInlineInplace
        injected = attribute.injected
        if injected is None:
            return

        self.delegate = injected.injector
def _optimize_ScalarAdd_ElementwiseAdd(op1: ScalarAdd, op2: Operator):
    """
    Fuse `op1` (ScalarAdd) followed by `op2` (ElementwiseAdd) into one expression.

    Both operators are disconnected and `op2`'s output is substituted by
    `(x0 + w) + op1.value`, where `x0` is `op1`'s input and `w` is the input of
    `op2` other than `op1`'s output.

    Args:
        op1: the ScalarAdd operator
        op2: candidate operator; nothing is done unless it is an ElementwiseAdd

    Returns:
        True if the operators were fused, False if `op2` is not an ElementwiseAdd
    """
    if not isinstance(op2, ElementwiseAdd):
        return False

    source = op1.inputs["x0"]
    intermediate = op1.outputs["y"]

    # Pick whichever input of op2 is not the output of op1.
    other = op2.inputs["x1"] if intermediate == op2.inputs["x0"] else op2.inputs["x0"]
    final_output = op2.outputs["y"]

    op2.remove_all()
    op1.remove_all()

    summed = source + other
    fused = summed + op1.value
    fused.replace(final_output)
    return True
Exemple #21
0
def _split_tensorwise(graph: Graph, op: Operator, v: Variable,
                      v_pair: Sequence[Variable], axis: Axis):
    """
    Replace `op` with two copies of itself, each handling one part of a split along `axis`.

    `op` is disconnected and copied twice. Inputs are attached to the copies
    under their original names: the input equal to `v` is substituted by the
    two entries of `v_pair`; inputs without `axis`, or with size 1 along it,
    are shared by both copies; other inputs are split with SplitAxis. Both
    copies are then run with `exec()`, and their outputs are registered through
    `OptimizeRule.replace_variable`: directly against `v_pair` for the output
    equal to `v`, or joined by Concat for any other output.

    Args:
        graph: the graph, forwarded to `OptimizeRule.replace_variable`
        op: operator to be split
        v: the variable of `op` which has already been split
        v_pair: the two parts substituted for `v`
        axis: axis along which the split is made
    """
    first_size = v_pair[0].shape_dict[axis]
    original_inputs = dict(op.inputs)
    original_outputs = dict(op.outputs)
    op.remove_all()

    first_op = op.copy()
    second_op = op.copy()

    for name, source in original_inputs.items():
        if source == v:
            first_in, second_in = v_pair

        elif axis not in source.order.axes or source.shape_dict[axis] == 1:
            # broadcasting
            first_in = second_in = source

        else:
            first_in, second_in = SplitAxis(None, axis=axis, sections=[first_size])(source)

        first_op.append_input(name, first_in)
        second_op.append_input(name, second_in)

    first_op.exec()
    second_op.exec()

    for name, target in original_outputs.items():
        if target == v:
            for part_op, part_var in ((first_op, v_pair[0]), (second_op, v_pair[1])):
                aligned = part_op.outputs[name].transpose_like(part_var)
                OptimizeRule.replace_variable(graph, aligned, part_var)

        else:
            first_out = first_op.outputs[name]
            second_out = second_op.outputs[name]
            (joined,) = Concat(None, axis=axis)(first_out, second_out)
            OptimizeRule.replace_variable(graph, joined.transpose_like(target), target)
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """
    Remove a binary elementwise operator and hand its output over to `v`.

    before)

    x1 -+
        +-{op}- y -
    x2 -+

    after)

                v -

    Args:
        graph: the graph
        op: the operator which will be removed
        v: variable with which output variable is replaced
    """
    removed_output = op.outputs["y"]

    # Disconnect the operator first, then let OptimizeRule do the substitution.
    op.remove_all()
    OptimizeRule.replace_variable(graph, v, removed_output, with_assert=False)
Exemple #23
0
def test_replace_all():
    """`Operator.replace` moves every input and output from one operator to another."""
    shape = (1, 2, 3, 4)
    old_op = Operator("op1")
    new_op = Operator("op2")
    in_var = Variable(shape, OrderNHWC)
    out_var = Variable(shape, OrderNHWC)

    old_op.append_input("v1", in_var)
    old_op.append_output("v2", out_var)

    old_op.replace(new_op)

    # The replaced operator no longer has any variable.
    assert len(old_op.inputs) == 0
    assert len(old_op.outputs) == 0

    # The replacing operator has them under the same names.
    assert len(new_op.inputs) == 1
    assert new_op.inputs["v1"] == in_var
    assert len(new_op.outputs) == 1
    assert new_op.outputs["v2"] == out_var

    # The variables point at the replacing operator.
    assert in_var.input_to == {new_op}
    assert out_var.output_from == new_op
Exemple #24
0
def _replace_input(op: Operator, var_name: str, target: ChannelModeEnum):
    """
    Insert a channel-mode conversion in front of an input of `op` when needed.

    before)

        v -{op}-

    after)

        v -{conversion}- v' -{op}-

    Args:
        op: operator whose input is inspected
        var_name: name of the input within `op.inputs`
        target: channel mode the input is compared against

    Returns:
        False if the input's channel mode already equals `target`, True if a
        conversion was inserted
    """
    current = op.inputs[var_name]

    if ChannelMode.get(current) == target:
        return False

    converter_class = ConvertRtoRGBA if target == ChannelModeEnum.RGBA else ConvertRGBAtoR
    (converted,) = converter_class(None)(current)

    op.replace_input(current, converted)
    return True
Exemple #25
0
def _replace_input(op: Operator, var_name: str, target: ChannelModeEnum):
    """
    Insert a channel-mode conversion in front of an input of `op` when needed.

    before)

        v -{op}-

    after)

        v -{conversion}- v' -{op}-

    The converted variable is given the texture shape values read from the
    original input (index 0 as height, index 1 as width).

    Args:
        op: operator whose input is inspected
        var_name: name of the input within `op.inputs`
        target: channel mode the input is compared against

    Returns:
        False if the input's channel mode already equals `target`, True if a
        conversion was inserted
    """
    current = op.inputs[var_name]

    if ChannelMode.get(current) == target:
        return False

    convert = convert_r_to_rgba if target == ChannelModeEnum.RGBA else convert_rgba_to_r
    converted = convert(current)

    TextureShape.set(
        converted,
        height=TextureShape.get(current)[0],
        width=TextureShape.get(current)[1],
    )
    op.replace_input(current, converted)
    return True
def fn(x: Variable):
    """
    Attach `x` to a new unnamed operator and return that operator's output.

    Args:
        x: variable registered as input "x" of the new operator

    Returns:
        a new variable with the shape and order of `x`, registered as output "y"
    """
    result = Variable(x.shape, x.order)
    node = Operator(None)

    node.append_input("x", x)
    node.append_output("y", result)

    return result
def _optimize_ElementwiseMul_ElementwiseMul(op1: ElementwiseMul,
                                            c1: ConstantVariable, v1: Variable,
                                            op2: Operator):
    """
    Fuse two consecutive ElementwiseMul operators when the second also has a constant operand.

    Both operators are disconnected and `op2`'s output is substituted by
    `v1 * (c1 * c2)`, where `c2` is the first of `op2`'s inputs "x0", "x1"
    which is a ConstantVariable.

    Args:
        op1: the first ElementwiseMul operator
        c1: constant operand, multiplied by `c2` first
        v1: the other operand
        op2: candidate operator; nothing is done unless it is an ElementwiseMul

    Returns:
        True if the operators were fused, False if `op2` is not an
        ElementwiseMul or has no ConstantVariable input
    """
    if not isinstance(op2, ElementwiseMul):
        return False

    first_operand = op2.inputs["x0"]
    second_operand = op2.inputs["x1"]
    final_output = op2.outputs["y"]

    # "x0" takes precedence when both operands are constants.
    c2 = next(
        (operand for operand in (first_operand, second_operand)
         if isinstance(operand, ConstantVariable)),
        None,
    )
    if c2 is None:
        return False

    op2.remove_all()
    op1.remove_all()

    folded_constant = c1 * c2
    fused = v1 * folded_constant
    fused.replace(final_output)
    return True
def _replace_output(op: Operator, var_name: str, target: ChannelModeEnum):
    """
    Make `op` emit a variable in the target channel mode and convert it back afterwards.

    before)

        -{op}- v

    after)

        -{op}- v' -{conversion}- v

    Args:
        op: operator whose output is inspected
        var_name: name of the output within `op.outputs`
        target: channel mode the output is compared against and which is set
            on the new output variable

    Returns:
        False if the output's channel mode already equals `target`, True if
        the output was replaced
    """
    original = op.outputs[var_name]

    if ChannelMode.get(original) == target:
        return False

    retargeted = Variable(original.shape, original.order)
    ChannelMode.set(retargeted, target)
    op.replace_output(original, retargeted)

    # The conversion goes the opposite way: from `target` back to the other mode.
    convert_back = convert_rgba_to_r if target == ChannelModeEnum.RGBA else convert_r_to_rgba
    restored = convert_back(retargeted).change_order(original.order)
    restored.replace(original)
    return True
Exemple #29
0
def _replace_output(op: Operator, var_name: str, target: ChannelModeEnum):
    """
    Make `op` emit a variable in the target channel mode and convert it back afterwards.

    before)

        -{op}- v

    after)

        -{op}- v' -{conversion}- v

    Args:
        op: operator whose output is inspected
        var_name: name of the output within `op.outputs`
        target: channel mode the output is compared against and which is set
            on the new output variable

    Returns:
        False if the output's channel mode already equals `target`, True if
        the output was replaced
    """
    original = op.outputs[var_name]

    if ChannelMode.get(original) == target:
        return False

    retargeted = Variable(original.shape, original.order)
    ChannelMode.set(retargeted, target)
    op.replace_output(original, retargeted)

    # The conversion goes the opposite way: from `target` back to the other mode.
    converter_class = ConvertRGBAtoR if target == ChannelModeEnum.RGBA else ConvertRtoRGBA
    restored = converter_class(None)(retargeted)[0]
    restored.replace(original)
    return True
Exemple #30
0
def test_replace_output():
    """`Operator.replace_output` swaps the variable and updates both `output_from` links."""
    shape = (1, 2, 3, 4)
    operator = Operator("op")
    old_var = Variable(shape, OrderNHWC)
    new_var = Variable(shape, OrderNHWC)

    operator.append_output("v1", old_var)
    operator.replace_output(old_var, new_var)

    # The output name is kept; only the variable behind it changes.
    assert operator.outputs["v1"] == new_var
    assert old_var.output_from is None
    assert new_var.output_from == operator