def optimize_pair(self, op1: ElementwiseMul, op2: ScalarMul):
        """Fold ``v * c`` followed by a scalar multiply into a single multiply.

        ``op1`` multiplies a variable by a constant and ``op2`` multiplies the
        result by the scalar ``op2.value``; the pair is replaced by
        ``v * (c * op2.value)``.

        Returns ``False`` (graph untouched) when ``_get_constant_and_variable``
        finds no constant operand on ``op1``; ``True`` after rewriting.
        """
        const, var = _get_constant_and_variable(op1, "x0", "x1")
        if const is None:
            return False

        old_output = op2.outputs["y"]
        # Detach both operators before building the replacement expression.
        for op in (op1, op2):
            op.remove_all()
        folded = const * op2.value
        new_output = var * folded  # type: Variable
        new_output.replace(old_output)
        return True
def _optimize_ElementwiseMul_ScalarMul(op1: ElementwiseMul,
                                       c1: ConstantVariable, v1: Variable,
                                       op2: Operator):
    """Fold ``v1 * c1`` (op1) followed by a ScalarMul (op2) into one multiply.

    Args:
        op1: the elementwise multiply whose operands are ``c1`` and ``v1``.
        c1: constant operand of ``op1``.
        v1: variable operand of ``op1``.
        op2: the operator consuming op1's output; only a ``ScalarMul`` is
            handled.

    Returns:
        ``True`` if both operators were replaced by ``v1 * (c1 * op2.value)``,
        ``False`` if ``op2`` is not a ``ScalarMul`` (graph untouched).
    """
    if isinstance(op2, ScalarMul):
        target = op2.outputs["y"]
        op1.remove_all()
        op2.remove_all()
        combined = c1 * op2.value
        result = v1 * combined
        result.replace(target)
        return True
    return False
    def optimize_pair(self, op1: ElementwiseAdd, op2: ElementwiseMul):
        """Distribute a constant multiply over a constant add.

        Rewrites ``(v1 + c1) * c2`` as ``(v1 * c2) + (c1 * c2)`` so the
        constant product can be folded later.

        Returns ``False`` (graph untouched) if either operator lacks a constant
        operand according to ``_get_constant_and_variable``; ``True`` after
        rewriting.
        """
        const_add, var = _get_constant_and_variable(op1, "x0", "x1")
        if const_add is None:
            return False

        # op2's non-constant operand is op1's output; it is not needed here.
        const_mul, _ = _get_constant_and_variable(op2, "x0", "x1")
        if const_mul is None:
            return False

        old_output = op2.outputs["y"]
        for op in (op2, op1):
            op.remove_all()
        scaled_var = var * const_mul
        scaled_const = const_add * const_mul
        new_output = scaled_var + scaled_const  # type: Variable
        new_output.replace(old_output, with_assert=False)
        return True
    def optimize_pair(self, op1: ElementwiseDiv, op2: ElementwiseMul):
        """Fold ``(v1 / c1) * c2`` into ``v1 * (c2 / c1)``.

        Division is not commutative, so the fold is only valid when the
        constant is the divisor. When the constant is the dividend
        (``c1 / v1``), ``(c1 / v1) * c2`` is NOT ``v1 * (c2 / c1)``, and the
        graph is left untouched.

        NOTE(review): assumes ElementwiseDiv computes ``x0 / x1`` — confirm.

        Returns:
            ``True`` if the pair was rewritten, ``False`` otherwise.
        """
        c1, v1 = _get_constant_and_variable(op1, "x0", "x1")
        if c1 is None:
            return False
        if op1.inputs["x1"] is not c1:
            # Constant is the dividend (c1 / v1): folding would be incorrect.
            return False

        # op2's non-constant operand is op1's output; it is not needed here.
        c2, _ = _get_constant_and_variable(op2, "x0", "x1")
        if c2 is None:
            return False

        y2 = op2.outputs["y"]
        op2.remove_all()
        op1.remove_all()
        y = v1 * (c2 / c1)  # type: Variable
        y.replace(y2)
        return True
    def optimize_pair(self, op1: ScalarMul, op2: ElementwiseMul):
        """Move a ScalarMul past the ElementwiseMul that consumes it.

        Rewrites ``(x0 * s) * w`` as ``(x0 * w) * s``, where ``s`` is
        ``op1.value`` and ``w`` is op2's operand other than op1's output.

        The graph is left untouched when op2 multiplies op1's output by itself
        (``y1 * y1``). In that case there is no separate ``w``, and the
        rewritten expression would still consume ``y1`` after its producer
        ``op1`` had been removed.

        Returns:
            ``True`` if the pair was rewritten, ``False`` otherwise.
        """
        x0 = op1.inputs["x0"]
        y1 = op1.outputs["y"]

        # Identify op1's output among op2's operands by identity, not equality.
        if op2.inputs["x0"] is y1:
            w = op2.inputs["x1"]
        else:
            w = op2.inputs["x0"]
        if w is y1:
            # y1 * y1: no independent operand to reorder with.
            return False
        y2 = op2.outputs["y"]

        op2.remove_all()
        op1.remove_all()
        y = (x0 * w) * op1.value  # type: Variable
        y.replace(y2)
        return True
예제 #6
0
    def optimize_pair(self, graph: Graph, op1: ElementwiseAdd, op2: ElementwiseMul):
        """Distribute a constant multiply over a constant add within ``graph``.

        Rewrites ``(v1 + c1) * c2`` as ``(v1 * c2) + (c1 * c2)``. The result is
        converted to the original output's order before it takes that output's
        place via ``OptimizeRule.replace_variable``.

        Returns ``False`` (graph untouched) if either operator lacks a constant
        operand according to ``_get_constant_and_variable``; ``True`` after
        rewriting.
        """
        const_add, var = _get_constant_and_variable(op1, "x0", "x1")
        if const_add is None:
            return False
        const_mul, _ = _get_constant_and_variable(op2, "x0", "x1")
        if const_mul is None:
            return False

        old_output = op2.outputs["y"]
        op2.remove_all()
        op1.remove_all()
        scaled_var = var * const_mul
        scaled_const = const_add * const_mul
        distributed = scaled_var + scaled_const
        reordered = distributed.change_order(old_output.order)
        OptimizeRule.replace_variable(graph, old_output, reordered)
        return True
    def optimize_pair(self, graph: Graph, op1: Concat, op2: ElementwiseMul):
        """Push a 1-D constant multiply through a two-input Concat.

        Rewrites ``concat(x0, x1) * c`` as ``concat(x0 * c0, x1 * c1)``. Here
        ``c`` is a constant whose only axis is the concat axis, and ``c0`` /
        ``c1`` are its slices lining up with ``x0`` / ``x1``.

        The graph is left untouched (returns ``False``) when any of these hold:

          * op1 does not have exactly two inputs. Only ``x0``/``x1`` are
            rebuilt, so extra inputs would be silently dropped.
          * op2 has no constant operand.
          * The constant is not 1-D along ``op1.axis``.
          * The constant's length differs from the combined extent of ``x0``
            and ``x1`` on the concat axis, so the split would misalign.

        Returns:
            ``True`` if the pair was rewritten, ``False`` otherwise.
        """
        # Assumes op1.inputs holds one entry per concatenated input (x0, x1, ...).
        if len(op1.inputs) != 2:
            return False
        x0, x1 = op1.inputs["x0"], op1.inputs["x1"]
        c, _ = _get_constant_and_variable(op2, "x0", "x1")
        if c is None:
            return False
        if c.order != Order([op1.axis]):
            return False

        n0 = x0.shape_dict[op1.axis]
        if len(c.data) != n0 + x1.shape_dict[op1.axis]:
            return False

        y2 = op2.outputs["y"]
        c0 = ConstantVariable(c.data[:n0], c.order)
        c1 = ConstantVariable(c.data[n0:], c.order)

        op1.remove_all()
        op2.remove_all()

        y, = Concat(None, axis=op1.axis)((x0 * c0), (x1 * c1))
        OptimizeRule.replace_variable(graph, y2, y.change_order(y2.order))
        return True
    def optimize_pair(self, op1: Concat, op2: ElementwiseMul):
        """Push a 1-D constant multiply through a two-input Concat.

        Rewrites ``concat(x0, x1) * c`` as ``concat(x0 * c0, x1 * c1)``. Here
        ``c`` is a constant whose only axis is the concat axis, and ``c0`` /
        ``c1`` are its slices lining up with ``x0`` / ``x1``.

        The graph is left untouched (returns ``False``) when any of these hold:

          * op1 does not have exactly two inputs. Only ``x0``/``x1`` are
            rebuilt, so extra inputs would be silently dropped.
          * op2 has no constant operand.
          * The constant is not 1-D along ``op1.axis``.
          * The constant's length differs from the combined extent of ``x0``
            and ``x1`` on the concat axis, so the split would misalign.

        Returns:
            ``True`` if the pair was rewritten, ``False`` otherwise.
        """
        # Assumes op1.inputs holds one entry per concatenated input (x0, x1, ...).
        if len(op1.inputs) != 2:
            return False
        x0, x1 = op1.inputs["x0"], op1.inputs["x1"]
        c, _ = _get_constant_and_variable(op2, "x0", "x1")
        if c is None:
            return False
        if c.order != Order([op1.axis]):
            return False

        n0 = x0.shape_dict[op1.axis]
        if len(c.data) != n0 + x1.shape_dict[op1.axis]:
            return False

        y2 = op2.outputs["y"]
        c0 = ConstantVariable(c.data[:n0], c.order)
        c1 = ConstantVariable(c.data[n0:], c.order)

        op1.remove_all()
        op2.remove_all()

        y, = Concat(None, axis=op1.axis)((x0 * c0),
                                         (x1 * c1))  # type: Variable
        y.replace(y2, with_assert=False)
        return True
def _optimize_ElementwiseMul_ElementwiseMul(op1: ElementwiseMul,
                                            c1: ConstantVariable, v1: Variable,
                                            op2: Operator):
    """Fold two consecutive constant multiplies into one.

    Rewrites ``(v1 * c1) * c2`` as ``v1 * (c1 * c2)``, where ``c2`` is the
    first ``ConstantVariable`` operand of ``op2`` (``x0`` checked before
    ``x1``).

    Args:
        op1: the first elementwise multiply, with operands ``c1`` and ``v1``.
        c1: constant operand of ``op1``.
        v1: variable operand of ``op1``.
        op2: the operator consuming op1's output; only an ``ElementwiseMul``
            with a constant operand is handled.

    Returns:
        ``True`` if the pair was rewritten; ``False`` if ``op2`` is not an
        ``ElementwiseMul`` or has no constant operand (graph untouched).
    """
    if not isinstance(op2, ElementwiseMul):
        return False

    candidates = (op2.inputs["x0"], op2.inputs["x1"])
    y2 = op2.outputs["y"]
    c2 = next((x for x in candidates if isinstance(x, ConstantVariable)), None)
    if c2 is None:
        return False

    op2.remove_all()
    op1.remove_all()
    (v1 * (c1 * c2)).replace(y2)
    return True