Пример #1
0
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """
    Remove a binary elementwise operator, leaving `v` where its output was.

    Sketch of the rewrite::

        x1 -+
            +-{op}- y -      ==>      v -
        x2 -+

    The "y" output of `op` is switched to `v.order` and `v.replace(y)` is
    called. What happens next depends on whether `v` is a graph input and
    whether the output is a graph output.

    Args:
        graph: the graph
        op: the operator which will be removed
        v: variable with which output variable is replaced
    """
    out = op.outputs["y"]
    op.remove_all()
    out.change_order(v.order)
    v.replace(out)

    if v not in graph.inputs:
        v.replace(out)
        return

    if out not in graph.outputs:
        # `v` is a graph input but the output is internal: call replace the other way round.
        out.replace(v)
        return

    # `v` is a graph input and the output is a graph output:
    # put `v` into the very same slot of `graph.outputs`.
    slot = graph.outputs.index(out)
    graph.outputs.remove(out)
    graph.outputs.insert(slot, v)
Пример #2
0
def _remove_unary_operator(graph: Graph, op: Operator):
    """
    Remove `op` and hook its first input up in place of its first output.

    If the output is listed in `graph.outputs`, the input takes over its
    position there. Beyond that:

    - when input and output have equal order and shape, `y.replace(x)` is
      called (only if the output was not a graph output);
    - otherwise every operator that reads the output is re-wired, input name
      by input name, to read the input instead.

    Args:
        graph: the graph
        op: the operator which will be removed; only its first input and
            first output are considered
    """
    src = tuple(op.inputs.values())[0]
    dst = tuple(op.outputs.values())[0]
    op.remove_all()

    compatible = bool(src.order == dst.order and src.shape == dst.shape)
    if compatible:
        src.change_order(dst.order)

    exposed = dst in graph.outputs
    if exposed:
        slot = graph.outputs.index(dst)
        graph.outputs.remove(dst)
        graph.outputs.insert(slot, src)

    if compatible:
        if not exposed:
            dst.replace(src)
        return

    # Iterate over a copy: detaching inputs presumably alters `dst.input_to`.
    for consumer in list(dst.input_to):
        input_name = consumer.get_input_name(dst)
        consumer.remove_input(dst)
        consumer.append_input(input_name, src)
def _optimize_ScalarAdd_ScalarMul(op1: ScalarAdd, op2: Operator):
    """
    Rewrite a `ScalarAdd` / `ScalarMul` pair as a single expression.

    Both operators are removed and, with `a = op1.value` and `b = op2.value`,
    the expression `x0 * b + a * b` is built from the "x0" input of `op1`.
    `replace` is then called on the new result with the original "y" output
    of `op2`.

    Args:
        op1: the `ScalarAdd` operator
        op2: the candidate operator; nothing is changed unless it is a `ScalarMul`

    Returns:
        True if the rewrite was applied, False if `op2` is not a `ScalarMul`.
    """
    if isinstance(op2, ScalarMul):
        source = op1.inputs["x0"]
        old_output = op2.outputs["y"]
        op2.remove_all()
        op1.remove_all()

        scaled = source * op2.value
        offset = op1.value * op2.value
        (scaled + offset).replace(old_output)
        return True

    return False
def _optimize_ElementwiseMul_ScalarMul(op1: ElementwiseMul,
                                       c1: ConstantVariable, v1: Variable,
                                       op2: Operator):
    """
    Rewrite an `ElementwiseMul` / `ScalarMul` pair as `v1 * (c1 * op2.value)`.

    Both operators are removed, the scalar of `op2` is multiplied into the
    constant `c1`, and `replace` is called on the new product with the
    original "y" output of `op2`.

    Args:
        op1: the `ElementwiseMul` operator
        c1: the constant operand
        v1: the other operand
        op2: the candidate operator; nothing is changed unless it is a `ScalarMul`

    Returns:
        True if the rewrite was applied, False if `op2` is not a `ScalarMul`.
    """
    if isinstance(op2, ScalarMul):
        old_output = op2.outputs["y"]
        op1.remove_all()
        op2.remove_all()

        merged_constant = c1 * op2.value
        (v1 * merged_constant).replace(old_output)
        return True

    return False
Пример #5
0
def _split_tensorwise(graph: Graph, op: Operator, v: Variable,
                      v_pair: Sequence[Variable], axis: Axis):
    """
    Replace `op` with two copies of itself, one per section of `axis`.

    The inputs and outputs of `op` are recorded, `op` is removed, and two
    copies are wired up as follows:

    - wherever `v` was attached, the matching element of `v_pair` is attached
      to the first / second copy;
    - any other input that has `axis` is split by `SplitAxis` at the size of
      `v_pair[0]` along `axis`; inputs without `axis` are shared by both copies;
    - any other output that has `axis` gets two new variables (sized like
      `v_pair[0]` and `v_pair[1]` along `axis`), which are joined by `Concat`
      and handed to `OptimizeRule.replace_variable` together with the
      original output.

    Args:
        graph: the graph
        op: the operator which will be split
        v: the variable attached to `op` for which `v_pair` is supplied
        v_pair: the two variables used in place of `v`
        axis: the axis along which the split happens

    Raises:
        UnexpectedAndPleaseReportError: if an output other than `v` does not
            have `axis`.
    """
    size_0 = v_pair[0].shape_dict[axis]
    size_1 = v_pair[1].shape_dict[axis]
    original_inputs = dict(op.inputs)
    original_outputs = dict(op.outputs)
    op.remove_all()

    op_head = op.copy()
    op_tail = op.copy()

    for name, x in original_inputs.items():
        if x == v:
            x_0, x_1 = v_pair
        elif axis in x.order.axes:
            x_0, x_1 = SplitAxis(None, axis=axis, sections=[size_0])(x)
        else:
            # The input has no such axis, so no split takes place.
            x_0 = x_1 = x

        op_head.append_input(name, x_0)
        op_tail.append_input(name, x_1)

    for name, y in original_outputs.items():
        if y == v:
            y_0, y_1 = v_pair
        elif axis not in y.order.axes:
            raise UnexpectedAndPleaseReportError
        else:
            # TODO (Kiikurage)
            # Attributes attached to "y" are copied to neither "y_0" nor "y_1".
            axes = y.order.axes
            y_0 = Variable([size_0 if a == axis else y.shape_dict[a] for a in axes], y.order)
            y_1 = Variable([size_1 if a == axis else y.shape_dict[a] for a in axes], y.order)
            y_new, = Concat(None, axis=axis)(y_0, y_1)
            OptimizeRule.replace_variable(graph, y, y_new)

        op_head.append_output(name, y_0)
        op_tail.append_output(name, y_1)
Пример #6
0
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """
    Remove a binary elementwise operator, leaving `v` where its output was.

    The "y" output of `op` is switched to `v.order` and `v.replace(y)` is
    called. What happens next depends on whether `v` is a graph input and
    whether the output is a graph output.

    Args:
        graph: the graph
        op: the operator which will be removed
        v: variable with which output variable is replaced
    """
    out = op.outputs["y"]
    op.remove_all()
    out.change_order(v.order)
    v.replace(out)

    if v not in graph.inputs:
        v.replace(out)
        return

    if out not in graph.outputs:
        # `v` is a graph input but the output is internal: call replace the other way round.
        out.replace(v)
        return

    # `v` is a graph input and the output is a graph output:
    # put `v` into the very same slot of `graph.outputs`.
    slot = graph.outputs.index(out)
    graph.outputs.remove(out)
    graph.outputs.insert(slot, v)
def _optimize_ScalarAdd_ElementwiseAdd(op1: ScalarAdd, op2: Operator):
    """
    Rewrite a `ScalarAdd` / `ElementwiseAdd` pair as `(x0 + w) + op1.value`.

    `w` is the operand of `op2` that is not the "y" output of `op1`: "x1" when
    that output equals "x0" of `op2`, otherwise "x0". Both operators are
    removed, and `replace` is called on the new sum with the original "y"
    output of `op2`.

    Args:
        op1: the `ScalarAdd` operator
        op2: the candidate operator; nothing is changed unless it is an
            `ElementwiseAdd`

    Returns:
        True if the rewrite was applied, False if `op2` is not an `ElementwiseAdd`.
    """
    if isinstance(op2, ElementwiseAdd):
        source = op1.inputs["x0"]
        intermediate = op1.outputs["y"]

        # Pick whichever operand of op2 is not op1's own output.
        other = op2.inputs["x1"] if intermediate == op2.inputs["x0"] else op2.inputs["x0"]
        old_output = op2.outputs["y"]

        op2.remove_all()
        op1.remove_all()

        combined = (source + other) + op1.value
        combined.replace(old_output)
        return True

    return False
Пример #8
0
def _split_tensorwise(graph: Graph, op: Operator, v: Variable,
                      v_pair: Sequence[Variable], axis: Axis):
    """
    Replace `op` with two copies of itself, one per section of `axis`.

    The inputs and outputs of `op` are recorded, `op` is removed, and two
    copies are wired up as follows:

    - wherever `v` was an input, the matching element of `v_pair` is attached
      to the first / second copy;
    - any other input that lacks `axis`, or has size 1 along it, is shared by
      both copies (broadcasting); the remaining inputs are split by
      `SplitAxis` at the size of `v_pair[0]` along `axis`.

    After `exec()` has been called on both copies, their outputs are passed
    to `OptimizeRule.replace_variable`:

    - for an output equal to `v`, each copy's output (transposed like the
      matching element of `v_pair`) is paired with that element;
    - for any other output, the two copies' outputs are joined by `Concat`
      along `axis`, transposed like the original output and paired with it.

    Args:
        graph: the graph
        op: the operator which will be split
        v: the variable attached to `op` for which `v_pair` is supplied
        v_pair: the two variables used in place of `v`
        axis: the axis along which the split happens
    """
    size_0 = v_pair[0].shape_dict[axis]
    original_inputs = dict(op.inputs)
    original_outputs = dict(op.outputs)
    op.remove_all()

    op_head = op.copy()
    op_tail = op.copy()

    for name, x in original_inputs.items():
        if x == v:
            x_0, x_1 = v_pair
        elif axis not in x.order.axes or x.shape_dict[axis] == 1:
            # broadcasting
            x_0 = x_1 = x
        else:
            x_0, x_1 = SplitAxis(None, axis=axis, sections=[size_0])(x)

        op_head.append_input(name, x_0)
        op_tail.append_input(name, x_1)

    op_head.exec()
    op_tail.exec()

    for name, y in original_outputs.items():
        if y == v:
            for idx, part_op in enumerate((op_head, op_tail)):
                OptimizeRule.replace_variable(
                    graph, part_op.outputs[name].transpose_like(v_pair[idx]), v_pair[idx])
            continue

        head_out = op_head.outputs[name]
        tail_out = op_tail.outputs[name]
        joined, = Concat(None, axis=axis)(head_out, tail_out)
        OptimizeRule.replace_variable(graph, joined.transpose_like(y), y)
Пример #9
0
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """
    Remove a binary elementwise operator, leaving `v` where its output was.

    Sketch of the rewrite::

        x1 -+
            +-{op}- y -      ==>      v -
        x2 -+

    After `op` is removed, `v` and the former "y" output are handed to
    `OptimizeRule.replace_variable` with `with_assert=False`.

    Args:
        graph: the graph
        op: the operator which will be removed
        v: variable with which output variable is replaced
    """
    former_output = op.outputs["y"]
    op.remove_all()
    OptimizeRule.replace_variable(graph, v, former_output, with_assert=False)
Пример #10
0
def test_remove_all():
    """`remove_all` leaves the operator and its former variables disconnected."""
    operator = Operator("op")
    in_a, in_b, out_a, out_b = (Variable((1, 2, 3, 4), OrderNHWC) for _ in range(4))

    operator.append_input("v1", in_a)
    operator.append_input("v2", in_b)
    operator.append_output("v3", out_a)
    operator.append_output("v4", out_b)

    operator.remove_all()

    # The operator itself keeps no connections ...
    assert len(operator.inputs) == 0
    assert len(operator.outputs) == 0

    # ... and neither do the variables that used to be attached to it.
    for variable in (in_a, in_b):
        assert variable.input_to == set()
    for variable in (out_a, out_b):
        assert variable.output_from is None
Пример #11
0
def _optimize_ElementwiseMul_ElementwiseMul(op1: ElementwiseMul,
                                            c1: ConstantVariable, v1: Variable,
                                            op2: Operator):
    """
    Rewrite two `ElementwiseMul` operators as `v1 * (c1 * c2)`.

    `c2` is the constant operand of `op2`; "x0" is examined before "x1". When
    `op2` has one, both operators are removed and `replace` is called on the
    new product with the original "y" output of `op2`.

    Args:
        op1: the first `ElementwiseMul` operator
        c1: the constant operand
        v1: the other operand
        op2: the candidate operator; nothing is changed unless it is an
            `ElementwiseMul` with a constant operand

    Returns:
        True if the rewrite was applied, False if `op2` is not an
        `ElementwiseMul` or neither of its operands is a `ConstantVariable`.
    """
    if not isinstance(op2, ElementwiseMul):
        return False

    lhs = op2.inputs["x0"]
    rhs = op2.inputs["x1"]
    old_output = op2.outputs["y"]

    # First constant operand wins, with "x0" taking precedence over "x1".
    c2 = next((operand for operand in (lhs, rhs) if isinstance(operand, ConstantVariable)), None)
    if c2 is None:
        return False

    op2.remove_all()
    op1.remove_all()

    merged_constant = c1 * c2
    (v1 * merged_constant).replace(old_output)
    return True
def remove_operator(op: Operator):
    """
    Remove `op`, then call `replace` on its "x0" input with its "y" output.

    Args:
        op: the operator which will be removed; it must have an input named
            "x0" and an output named "y"
    """
    source, result = op.inputs["x0"], op.outputs["y"]
    op.remove_all()
    source.replace(result)
Пример #13
0
def _remove_unary_operator(graph: Graph, op: Operator):
    """
    Remove `op` and pass its first output and first input to `replace_variable`.

    After `op` is removed, the first output and the first input are handed, in
    that order, to `OptimizeRule.replace_variable` with `with_assert=False`.

    Args:
        graph: the graph
        op: the operator which will be removed; only its first input and
            first output are considered
    """
    first_input = tuple(op.inputs.values())[0]
    first_output = tuple(op.outputs.values())[0]
    op.remove_all()

    OptimizeRule.replace_variable(graph, first_output, first_input, with_assert=False)