def optimize_pair(self, op1: ElementwiseMul, op2: ScalarMul):
    """Fold ``(v * c) * s`` into the single multiplication ``v * (c * s)``.

    ``op1`` multiplies a variable by a constant operand, and ``op2`` scales
    that product by the scalar ``op2.value``. When
    ``_get_constant_and_variable`` finds a constant operand on ``op1``, both
    operators are detached. Their combined effect is then rebuilt as one
    multiplication of the variable by the pre-scaled constant.

    Returns:
        True if the graph was rewritten; False if ``op1`` has no constant
        operand (the graph is left untouched).
    """
    constant, variable = _get_constant_and_variable(op1, "x0", "x1")
    if constant is None:
        return False

    original_output = op2.outputs["y"]

    # Detach both operators before building the replacement expression.
    op1.remove_all()
    op2.remove_all()

    folded = constant * op2.value
    replacement = variable * folded  # type: Variable
    replacement.replace(original_output)
    return True
def _optimize_ElementwiseMul_ScalarMul(op1: ElementwiseMul, c1: ConstantVariable, v1: Variable, op2: Operator):
    """Merge ``op1`` (``v1 * c1``) with a following ``ScalarMul`` into ``v1 * (c1 * s)``.

    Args:
        op1: The element-wise multiplication of ``v1`` by the constant ``c1``.
        c1: Constant operand of ``op1``.
        v1: Non-constant operand of ``op1``.
        op2: The operator consuming ``op1``'s output; only ``ScalarMul`` is folded.

    Returns:
        True after rewriting the graph; False (nothing changed) when ``op2``
        is not a ``ScalarMul``.
    """
    if isinstance(op2, ScalarMul):
        target = op2.outputs["y"]
        for op in (op1, op2):
            op.remove_all()
        merged = v1 * (c1 * op2.value)
        merged.replace(target)
        return True
    return False
def optimize_pair(self, op1: ElementwiseAdd, op2: ElementwiseMul):
    """Distribute a constant multiplication over a preceding constant addition.

    Rewrites ``(v + c1) * c2`` as ``v * c2 + c1 * c2``. The constant product
    ``c1 * c2`` can then be handled by other constant-folding rules.

    Returns:
        True if the graph was rewritten; False if either operator lacks a
        constant operand (as reported by ``_get_constant_and_variable``).
    """
    add_const, add_var = _get_constant_and_variable(op1, "x0", "x1")
    if add_const is None:
        return False

    mul_const, _ = _get_constant_and_variable(op2, "x0", "x1")
    if mul_const is None:
        return False

    target = op2.outputs["y"]
    for op in (op2, op1):
        op.remove_all()

    scaled_var = add_var * mul_const
    scaled_const = add_const * mul_const
    result = scaled_var + scaled_const  # type: Variable
    # with_assert=False presumably relaxes consistency checks in
    # Variable.replace (e.g. order/shape); NOTE(review): confirm.
    result.replace(target, with_assert=False)
    return True
def optimize_pair(self, op1: ElementwiseDiv, op2: ElementwiseMul):
    """Fold ``(v / c1) * c2`` into ``v * (c2 / c1)``.

    This method assumes, as the rewrite itself does, that ``ElementwiseDiv``
    computes ``x0 / x1``. The fold is only valid when the constant is the
    *divisor* of ``op1``. When the constant is the dividend (``c1 / v``), the
    product ``(c1 / v) * c2`` equals ``(c1 * c2) / v``, not
    ``v * (c2 / c1)``, so that case is left untouched.

    Returns:
        True if the graph was rewritten; False if the pattern does not apply.
    """
    c1, v1 = _get_constant_and_variable(op1, "x0", "x1")
    if c1 is None:
        return False
    # Division is not commutative: only fold when op1 computes v1 / c1.
    if op1.inputs["x1"] is not c1:
        return False

    c2, _ = _get_constant_and_variable(op2, "x0", "x1")
    if c2 is None:
        return False

    y2 = op2.outputs["y"]
    op2.remove_all()
    op1.remove_all()
    y = v1 * (c2 / c1)  # type: Variable
    y.replace(y2)
    return True
def optimize_pair(self, op1: ScalarMul, op2: ElementwiseMul):
    """Move a scalar multiplication past an element-wise multiplication.

    Rewrites ``(x * s) * w`` as ``(x * w) * s``, where ``s`` is
    ``op1.value`` and ``w`` is the operand of ``op2`` that is not ``op1``'s
    output.

    The squaring case ``(x * s) * (x * s)`` is rejected. In that case ``w``
    would be ``op1``'s own output, which is detached below, and the rewrite
    would also drop one factor of ``s``.

    Returns:
        True if the graph was rewritten; False if the pattern does not apply.
    """
    x0 = op1.inputs["x0"]
    y1 = op1.outputs["y"]
    a, b = op2.inputs["x0"], op2.inputs["x1"]

    if a is y1 and b is y1:
        return False
    w = b if a is y1 else a

    # Capture the scalar before detaching op1 from the graph.
    scale = op1.value
    y2 = op2.outputs["y"]

    op2.remove_all()
    op1.remove_all()

    y = (x0 * w) * scale  # type: Variable
    y.replace(y2)
    return True
def optimize_pair(self, graph: Graph, op1: ElementwiseAdd, op2: ElementwiseMul):
    """Distribute a constant multiplication over a preceding constant addition.

    Rewrites ``(v + c1) * c2`` as ``v * c2 + c1 * c2`` inside ``graph``. The
    new output is converted to the original output's memory order before it
    is swapped in with ``OptimizeRule.replace_variable``.

    Returns:
        True if the graph was rewritten; False if either operator lacks a
        constant operand (as reported by ``_get_constant_and_variable``).
    """
    add_const, add_var = _get_constant_and_variable(op1, "x0", "x1")
    if add_const is None:
        return False

    mul_const, _ = _get_constant_and_variable(op2, "x0", "x1")
    if mul_const is None:
        return False

    old_output = op2.outputs["y"]
    for op in (op2, op1):
        op.remove_all()

    new_output = add_var * mul_const + add_const * mul_const
    OptimizeRule.replace_variable(graph, old_output, new_output.change_order(old_output.order))
    return True
def optimize_pair(self, graph: Graph, op1: Concat, op2: ElementwiseMul):
    """Push a 1-D constant multiplication through a two-input ``Concat``.

    Rewrites ``concat(x0, x1) * c`` as ``concat(x0 * c[:n0], x1 * c[n0:])``,
    where ``c`` is a constant laid out along the concat axis alone and
    ``n0`` is ``x0``'s size on that axis.

    The rule is skipped when:
      * the Concat has a number of inputs other than two (extra inputs
        ``x2, ...`` would otherwise be dropped by the rewrite);
      * ``c`` is not ordered solely along the concat axis;
      * ``c``'s length does not match the concatenated axis size, so
        splitting it at ``n0`` would not align with the inputs.

    Returns:
        True if the graph was rewritten; False if the pattern does not apply.
    """
    if len(op1.inputs) != 2:
        return False
    x0, x1 = op1.inputs["x0"], op1.inputs["x1"]

    c, _ = _get_constant_and_variable(op2, "x0", "x1")
    if c is None:
        return False

    axis = op1.axis
    if c.order != Order([axis]):
        return False

    n0 = x0.shape_dict[axis]
    if c.shape_dict[axis] != n0 + x1.shape_dict[axis]:
        return False

    y2 = op2.outputs["y"]
    c0 = ConstantVariable(c.data[:n0], c.order)
    c1 = ConstantVariable(c.data[n0:], c.order)

    op1.remove_all()
    op2.remove_all()

    y, = Concat(None, axis=axis)(x0 * c0, x1 * c1)
    OptimizeRule.replace_variable(graph, y2, y.change_order(y2.order))
    return True
def optimize_pair(self, op1: Concat, op2: ElementwiseMul):
    """Push a 1-D constant multiplication through a two-input ``Concat``.

    Rewrites ``concat(x0, x1) * c`` as ``concat(x0 * c[:n0], x1 * c[n0:])``,
    where ``c`` is a constant laid out along the concat axis alone and
    ``n0`` is ``x0``'s size on that axis.

    The rule is skipped when:
      * the Concat has a number of inputs other than two (extra inputs
        ``x2, ...`` would otherwise be dropped by the rewrite);
      * ``c`` is not ordered solely along the concat axis;
      * ``c``'s length does not match the concatenated axis size, so
        splitting it at ``n0`` would not align with the inputs.

    Returns:
        True if the graph was rewritten; False if the pattern does not apply.
    """
    if len(op1.inputs) != 2:
        return False
    x0, x1 = op1.inputs["x0"], op1.inputs["x1"]

    c, _ = _get_constant_and_variable(op2, "x0", "x1")
    if c is None:
        return False

    axis = op1.axis
    if c.order != Order([axis]):
        return False

    n0 = x0.shape_dict[axis]
    if c.shape_dict[axis] != n0 + x1.shape_dict[axis]:
        return False

    y2 = op2.outputs["y"]
    c0 = ConstantVariable(c.data[:n0], c.order)
    c1 = ConstantVariable(c.data[n0:], c.order)

    op1.remove_all()
    op2.remove_all()

    y, = Concat(None, axis=axis)(x0 * c0, x1 * c1)  # type: Variable
    # with_assert=False presumably relaxes consistency checks in
    # Variable.replace (e.g. order/shape); NOTE(review): confirm.
    y.replace(y2, with_assert=False)
    return True
def _optimize_ElementwiseMul_ElementwiseMul(op1: ElementwiseMul, c1: ConstantVariable, v1: Variable, op2: Operator):
    """Merge two consecutive constant multiplications: ``(v1 * c1) * c2`` -> ``v1 * (c1 * c2)``.

    Args:
        op1: The element-wise multiplication of ``v1`` by the constant ``c1``.
        c1: Constant operand of ``op1``.
        v1: Non-constant operand of ``op1``.
        op2: The operator consuming ``op1``'s output.

    Returns:
        True after rewriting the graph. False (nothing changed) when ``op2``
        is not an ``ElementwiseMul`` or has no ``ConstantVariable`` input.
        When both of ``op2``'s inputs are constants, ``x0`` is used.
    """
    if not isinstance(op2, ElementwiseMul):
        return False

    target = op2.outputs["y"]
    candidates = (op2.inputs["x0"], op2.inputs["x1"])
    c2 = next((v for v in candidates if isinstance(v, ConstantVariable)), None)
    if c2 is None:
        return False

    for op in (op2, op1):
        op.remove_all()

    merged = v1 * (c1 * c2)
    merged.replace(target)
    return True