def optimize_pair(self, op1: ElementwiseDiv, op2: ScalarMul):
        """Fold ``(v / c) * s`` into ``v * (s / c)``.

        ``op1`` divides a variable ``v`` by a constant ``c`` and ``op2`` scales
        the result by the scalar ``op2.value`` (``s``). Both operators are
        replaced by a single multiplication whose factor ``s / c`` involves
        only constants.

        Returns:
            True if the graph was rewritten; False if the pattern does not
            apply, in which case the graph is left untouched.
        """
        # NOTE(review): assumes _get_constant_and_variable returns
        # (constant_input, other_input), or (None, None) when neither input
        # is constant -- confirm against its definition.
        c1, v1 = _get_constant_and_variable(op1, "x0", "x1")
        if c1 is None:
            return False

        # Division is not commutative: only ``v / c`` (constant divisor) equals
        # ``v * (1 / c)``. For ``c / v`` this rewrite would be incorrect.
        if c1 is not op1.inputs["x1"]:
            return False

        # Read the scalar before detaching op2, in case remove_all() clears
        # operator state.
        scale = op2.value
        y2 = op2.outputs["y"]
        op1.remove_all()
        op2.remove_all()
        y = v1 * (scale / c1)  # type: Variable
        y.replace(y2)
        return True
    def optimize_pair(self, op1: ElementwiseAdd, op2: ElementwiseDiv):
        """Distribute a constant divisor over an addition with a constant.

        before: ``h = v + c1; y = h / c2``
        after:  ``y = v / c2 + c1 / c2``

        ``c1 / c2`` involves only constants.

        Returns:
            True if the graph was rewritten; False if the pattern does not
            apply, in which case the graph is left untouched.
        """
        # NOTE(review): assumes _get_constant_and_variable returns
        # (constant_input, other_input), or (None, None) when neither input
        # is constant -- confirm against its definition.
        c1, v1 = _get_constant_and_variable(op1, "x0", "x1")
        if c1 is None:
            return False

        c2, _ = _get_constant_and_variable(op2, "x0", "x1")
        if c2 is None:
            return False

        # Division is not commutative: ``c2 / (v + c1)`` cannot be distributed,
        # so the constant must be the divisor (x1).
        if c2 is not op2.inputs["x1"]:
            return False

        y2 = op2.outputs["y"]
        op2.remove_all()
        op1.remove_all()
        y = (v1 / c2) + (c1 / c2)  # type: Variable
        # NOTE(review): with_assert=False presumably relaxes a consistency
        # check inside replace(); confirm why it is needed here.
        y.replace(y2, with_assert=False)
        return True
    def optimize_pair(self, op1: ElementwiseDiv, op2: ElementwiseMul):
        """Fold ``(v / c1) * c2`` into ``v * (c2 / c1)``.

        ``c2 / c1`` involves only constants, so two elementwise operators
        collapse into one multiplication.

        Returns:
            True if the graph was rewritten; False if the pattern does not
            apply, in which case the graph is left untouched.
        """
        # NOTE(review): assumes _get_constant_and_variable returns
        # (constant_input, other_input), or (None, None) when neither input
        # is constant -- confirm against its definition.
        c1, v1 = _get_constant_and_variable(op1, "x0", "x1")
        if c1 is None:
            return False

        # Division is not commutative: only ``v / c1`` (constant divisor) can
        # be rewritten; ``c1 / v`` cannot.
        if c1 is not op1.inputs["x1"]:
            return False

        c2, _ = _get_constant_and_variable(op2, "x0", "x1")
        if c2 is None:
            return False

        y2 = op2.outputs["y"]
        op2.remove_all()
        op1.remove_all()
        y = v1 * (c2 / c1)  # type: Variable
        y.replace(y2)
        return True
    def optimize_pair(self, op1: ScalarMul, op2: ElementwiseDiv):
        """Move a scalar multiplication after a division.

        before: ``h = x * s; y = h / w``
        after:  ``y = (x / w) * s``

        The rewrite is only valid when the scaled value ``h`` is the dividend
        (``x0``) of the division.

        Returns:
            True if the graph was rewritten; False if the pattern does not
            apply, in which case the graph is left untouched.
        """
        x0 = op1.inputs["x0"]
        y1 = op1.outputs["y"]

        # Identity check: we need to know whether this exact variable object
        # is the dividend, not whether two variables compare equal.
        if op2.inputs["x0"] is not y1:
            # The graph computes ``w / (x * s)``; ``(x / w) * s`` would be wrong.
            return False

        w = op2.inputs["x1"]
        if w is y1:
            # ``h / h``: the divisor would be detached together with op1.
            return False

        # Read the scalar before detaching op1, in case remove_all() clears
        # operator state.
        scale = op1.value
        y2 = op2.outputs["y"]

        op2.remove_all()
        op1.remove_all()
        y = (x0 / w) * scale  # type: Variable
        y.replace(y2)
        return True
Exemplo n.º 5
0
    def optimize_pair(self, graph: Graph, op1: ElementwiseAdd, op2: ElementwiseDiv):
        """Distribute a constant divisor over an addition with a constant.

        before: ``h = v + c1; y = h / c2``
        after:  ``y = v / c2 + c1 / c2``

        The new output is converted to the original output's order and then
        substituted for it in ``graph``.

        Returns:
            True if the graph was rewritten; False if the pattern does not
            apply, in which case the graph is left untouched.
        """
        # NOTE(review): assumes _get_constant_and_variable returns
        # (constant_input, other_input), or (None, None) when neither input
        # is constant -- confirm against its definition.
        c1, v1 = _get_constant_and_variable(op1, "x0", "x1")
        if c1 is None:
            return False

        c2, _ = _get_constant_and_variable(op2, "x0", "x1")
        if c2 is None:
            return False

        # Division is not commutative: ``c2 / (v + c1)`` cannot be distributed,
        # so the constant must be the divisor (x1).
        if c2 is not op2.inputs["x1"]:
            return False

        y2 = op2.outputs["y"]
        op2.remove_all()
        op1.remove_all()
        y = (v1 / c2) + (c1 / c2)
        OptimizeRule.replace_variable(graph, y2, y.change_order(y2.order))
        return True