def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """
    Remove the binary elementwise operator ``op`` and let ``v`` stand in for its output ``y``.

    before)

        x1 -+
            +-{op}- y -
        x2 -+

    after)

        v -

    Args:
        graph: the graph
        op: the operator which will be removed
        v: variable with which output variable is replaced
    """
    y = op.outputs["y"]
    op.remove_all()  # after this, y has no producer
    y.change_order(v.order)

    if v in graph.inputs:
        # v is a graph input, so v must survive. Hand y's slot in graph.outputs
        # (if any) to v, and always move y's consumers onto v. Otherwise a y
        # that is both a graph output and consumed elsewhere would leave those
        # consumers reading a producer-less variable.
        if y in graph.outputs:
            index = graph.outputs.index(y)
            graph.outputs.remove(y)
            graph.outputs.insert(index, v)
        y.replace(v)
    else:
        # No unconditional v.replace(y) before this branch: it would detach v
        # even when v is a graph input that has to be kept.
        # NOTE(review): if v itself is in graph.outputs, that entry is not
        # updated here — confirm whether callers can hit that case.
        v.replace(y)
def _remove_unary_operator(graph: Graph, op: Operator):
    """
    Remove the unary operator ``op`` and let its input ``x`` stand in for its output ``y``.

    ``op`` is detached, ``y``'s slot in ``graph.outputs`` (if any) is handed to
    ``x``, and every operator consuming ``y`` is rewired to consume ``x``.

    Args:
        graph: the graph containing ``op``
        op: the operator which will be removed; only its first input and its
            first output are considered
    """
    x = list(op.inputs.values())[0]
    y = list(op.outputs.values())[0]
    op.remove_all()  # after this, y has no producer

    same_layout = x.order == y.order and x.shape == y.shape
    if same_layout:
        x.change_order(y.order)

    if y in graph.outputs:
        index = graph.outputs.index(y)
        graph.outputs.remove(y)
        graph.outputs.insert(index, x)

    if same_layout:
        # Done even when y is a graph output. Otherwise y's consumers would
        # stay attached to y, which has no producer any more.
        y.replace(x)
    else:
        # Orders/shapes differ, so consumers are rewired by hand instead of via
        # y.replace (presumably because replace expects matching layouts).
        for op2 in list(y.input_to):
            name = op2.get_input_name(y)
            op2.remove_input(y)
            op2.append_input(name, x)
def _optimize_ScalarAdd_ScalarMul(op1: ScalarAdd, op2: Operator):
    """
    Distribute a ScalarMul over a preceding ScalarAdd.

    Rewrites ``(x0 + a) * b`` as ``x0 * b + a * b``. Here ``a`` is ``op1.value``
    and ``b`` is ``op2.value``. The new expression replaces ``op2``'s output.
    op2 is assumed (by the caller) to consume op1's output.

    Returns:
        True if the graph was rewritten, False if ``op2`` is not a ScalarMul.
    """
    if isinstance(op2, ScalarMul):
        source = op1.inputs["x0"]
        old_result = op2.outputs["y"]

        op2.remove_all()
        op1.remove_all()

        scaled = source * op2.value
        rewritten = scaled + (op1.value * op2.value)
        rewritten.replace(old_result)
        return True

    return False
def _optimize_ElementwiseMul_ScalarMul(op1: ElementwiseMul, c1: ConstantVariable, v1: Variable, op2: Operator):
    """
    Fold a ScalarMul into the constant operand of a preceding ElementwiseMul.

    Rewrites the chain as ``v1 * (c1 * s)``, where ``s`` is ``op2.value``. The
    result replaces ``op2``'s output. op1 is presumably ``v1 * c1`` — this
    contract is the caller's, not checked here.

    Returns:
        True if the graph was rewritten, False if ``op2`` is not a ScalarMul.
    """
    if isinstance(op2, ScalarMul):
        old_result = op2.outputs["y"]

        op1.remove_all()
        op2.remove_all()

        folded_constant = c1 * op2.value
        (v1 * folded_constant).replace(old_result)
        return True

    return False
def _split_tensorwise(graph: Graph, op: Operator, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """
    Split ``op`` into two copies, each working on one half of ``v`` along ``axis``.

    ``v`` (one of ``op``'s inputs or outputs) has already been divided into
    ``v_pair``. The original ``op`` is detached and two copies are wired as follows.

    Inputs:
        * ``v`` maps to the matching half of ``v_pair``.
        * Inputs without ``axis``, or with size 1 along it (broadcast), are
          shared by both copies.
        * All other inputs are cut with ``SplitAxis`` at ``v_pair[0]``'s size.

    Outputs:
        * ``v`` maps to ``v_pair``.
        * Any other output ``y`` gets two fresh Variables. They are joined with
          ``Concat`` along ``axis``, and the join is substituted for ``y`` via
          ``OptimizeRule.replace_variable``.

    Args:
        graph: the graph being rewritten
        op: the operator to split
        v: the variable that has already been split
        v_pair: the two halves of ``v``
        axis: the axis along which ``v`` was split

    Raises:
        UnexpectedAndPleaseReportError: if an output other than ``v`` lacks ``axis``.
    """
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    xs = dict(op.inputs)
    ys = dict(op.outputs)
    op.remove_all()

    op_0 = op.copy()
    op_1 = op.copy()

    for key, x in xs.items():
        if x == v:
            x_0, x_1 = v_pair
        elif axis not in x.order.axes or x.shape_dict[axis] == 1:
            # Splitting does not apply: either the axis is absent, or x is
            # broadcast along it (size 1 cannot be cut into [s1] and the rest).
            x_0 = x_1 = x
        else:
            x_0, x_1 = SplitAxis(None, axis=axis, sections=[s1])(x)
        op_0.append_input(key, x_0)
        op_1.append_input(key, x_1)

    for key, y in ys.items():
        if y == v:
            y_0, y_1 = v_pair
        elif axis in y.order.axes:
            # TODO (Kiikurage)
            # Attribute attached to "y" is not copied to neither "y_0" or "y_1"
            y_0 = Variable([s1 if a == axis else y.shape_dict[a] for a in y.order.axes], y.order)
            y_1 = Variable([s2 if a == axis else y.shape_dict[a] for a in y.order.axes], y.order)
            y_new, = Concat(None, axis=axis)(y_0, y_1)
            OptimizeRule.replace_variable(graph, y, y_new)
        else:
            raise UnexpectedAndPleaseReportError
        op_0.append_output(key, y_0)
        op_1.append_output(key, y_1)
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """
    Remove the binary elementwise operator ``op`` and let ``v`` stand in for its output ``y``.

    before)

        x1 -+
            +-{op}- y -
        x2 -+

    after)

        v -

    Args:
        graph: the graph
        op: the operator which will be removed
        v: variable with which output variable is replaced
    """
    y = op.outputs["y"]
    op.remove_all()  # after this, y has no producer
    y.change_order(v.order)

    if v in graph.inputs:
        # v is a graph input, so v must survive. Hand y's slot in graph.outputs
        # (if any) to v, and always move y's consumers onto v. Otherwise a y
        # that is both a graph output and consumed elsewhere would leave those
        # consumers reading a producer-less variable.
        if y in graph.outputs:
            index = graph.outputs.index(y)
            graph.outputs.remove(y)
            graph.outputs.insert(index, v)
        y.replace(v)
    else:
        # No unconditional v.replace(y) before this branch: it would detach v
        # even when v is a graph input that has to be kept.
        # NOTE(review): if v itself is in graph.outputs, that entry is not
        # updated here — confirm whether callers can hit that case.
        v.replace(y)
def _optimize_ScalarAdd_ElementwiseAdd(op1: ScalarAdd, op2: Operator):
    """
    Move the scalar of a ScalarAdd past a following ElementwiseAdd.

    Rewrites ``(x0 + c) + w`` as ``(x0 + w) + c``. Here ``c`` is ``op1.value``
    and ``w`` is the operand of ``op2`` that is not ``op1``'s output ``y1``.
    The new expression replaces ``op2``'s output.

    Args:
        op1: the ScalarAdd computing ``y1 = x0 + c``
        op2: an operator assumed (by the caller) to consume ``y1``

    Returns:
        True if the graph was rewritten. False if ``op2`` is not an
        ElementwiseAdd, or if it adds ``y1`` to itself.
    """
    if not isinstance(op2, ElementwiseAdd):
        return False

    x0 = op1.inputs["x0"]
    y1 = op1.outputs["y"]
    lhs = op2.inputs["x0"]
    rhs = op2.inputs["x1"]

    if lhs == y1 and rhs == y1:
        # For ``y1 + y1`` the "other" operand is y1 itself. Once op1 is
        # removed, y1 has no producer, so the rewritten graph would read an
        # orphaned variable. Leave this pattern untouched.
        return False

    w = rhs if lhs == y1 else lhs
    y2 = op2.outputs["y"]

    op2.remove_all()
    op1.remove_all()
    y = (x0 + w) + op1.value
    y.replace(y2)
    return True
def _split_tensorwise(graph: Graph, op: Operator, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """
    Split ``op`` into two copies, each working on one half of ``v`` along ``axis``.

    ``v`` is expected to be one of ``op``'s inputs or outputs and has already
    been divided into ``v_pair``. The original ``op`` is detached and two
    copies are built.

    Inputs:
        * ``v`` maps to the matching half of ``v_pair``.
        * An input without ``axis``, or with size 1 along it (broadcast), is
          shared by both copies.
        * Any other input is cut with ``SplitAxis`` at ``v_pair[0]``'s size.

    Both copies are then executed. For each output:
        * If it was ``v``, each copy's result (after ``transpose_like`` the
          matching half) is passed to ``OptimizeRule.replace_variable`` with
          that half of ``v_pair``.
        * Otherwise the two results are joined with ``Concat`` along ``axis``,
          and the join (after ``transpose_like`` the original output) is passed
          to ``OptimizeRule.replace_variable`` with the original output.

    Args:
        graph: the graph being rewritten
        op: the operator to split
        v: the variable that has already been split
        v_pair: the two halves of ``v``
        axis: the axis along which ``v`` was split
    """
    head_size = v_pair[0].shape_dict[axis]
    original_inputs = dict(op.inputs)
    original_outputs = dict(op.outputs)
    op.remove_all()

    op_0 = op.copy()
    op_1 = op.copy()

    for name, x in original_inputs.items():
        if x == v:
            x_0, x_1 = v_pair
        elif axis not in x.order.axes or x.shape_dict[axis] == 1:
            # broadcasting: both halves read the same variable
            x_0 = x_1 = x
        else:
            x_0, x_1 = SplitAxis(None, axis=axis, sections=[head_size])(x)
        op_0.append_input(name, x_0)
        op_1.append_input(name, x_1)

    op_0.exec()
    op_1.exec()

    for name, y in original_outputs.items():
        if y == v:
            OptimizeRule.replace_variable(graph, op_0.outputs[name].transpose_like(v_pair[0]), v_pair[0])
            OptimizeRule.replace_variable(graph, op_1.outputs[name].transpose_like(v_pair[1]), v_pair[1])
        else:
            joined, = Concat(None, axis=axis)(op_0.outputs[name], op_1.outputs[name])
            OptimizeRule.replace_variable(graph, joined.transpose_like(y), y)
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """
    Remove the binary elementwise operator ``op``, merging its output with ``v``.

    before)

        x1 -+
            +-{op}- y -
        x2 -+

    after)

        v -

    The merge itself is delegated to
    ``OptimizeRule.replace_variable(graph, v, y, with_assert=False)``.

    Args:
        graph: the graph
        op: the operator which will be removed
        v: variable with which output variable is replaced
    """
    removed_output = op.outputs["y"]
    op.remove_all()
    OptimizeRule.replace_variable(graph, v, removed_output, with_assert=False)
def test_remove_all():
    """``Operator.remove_all`` detaches every input and output of the operator."""
    op = Operator("op")
    v1, v2, v3, v4 = [Variable((1, 2, 3, 4), OrderNHWC) for _ in range(4)]

    for name, var in (("v1", v1), ("v2", v2)):
        op.append_input(name, var)
    for name, var in (("v3", v3), ("v4", v4)):
        op.append_output(name, var)

    op.remove_all()

    assert len(op.inputs) == 0
    assert len(op.outputs) == 0
    for consumed in (v1, v2):
        assert consumed.input_to == set()
    for produced in (v3, v4):
        assert produced.output_from is None
def _optimize_ElementwiseMul_ElementwiseMul(op1: ElementwiseMul, c1: ConstantVariable, v1: Variable, op2: Operator):
    """
    Fold two consecutive constant multiplications into one.

    If ``op2`` is an ElementwiseMul with a ConstantVariable operand ``c2``, the
    chain is rewritten as ``v1 * (c1 * c2)``. The result replaces ``op2``'s
    output. op1 is presumably ``v1 * c1`` — this contract is the caller's.

    Returns:
        True if the graph was rewritten. False if ``op2`` is not an
        ElementwiseMul or has no constant operand.
    """
    if not isinstance(op2, ElementwiseMul):
        return False

    operands = (op2.inputs["x0"], op2.inputs["x1"])
    old_result = op2.outputs["y"]

    # The first constant operand wins, matching an x0-then-x1 check.
    c2 = next((operand for operand in operands if isinstance(operand, ConstantVariable)), None)
    if c2 is None:
        return False

    op2.remove_all()
    op1.remove_all()
    combined = v1 * (c1 * c2)
    combined.replace(old_result)
    return True
def remove_operator(op: Operator):
    """
    Remove ``op``, whose input is named ``x0`` and output ``y``, from its graph.

    After ``op`` is detached, ``x0.replace(y)`` is called (presumably so that
    ``y`` takes over ``x0``'s connections — confirm against ``Variable.replace``).
    """
    source = op.inputs["x0"]
    result = op.outputs["y"]
    op.remove_all()
    source.replace(result)
def _remove_unary_operator(graph: Graph, op: Operator):
    """
    Drop a unary operator from ``graph``.

    Only the first input and the first output of ``op`` are considered. After
    ``op`` is detached, ``OptimizeRule.replace_variable(graph, output, input,
    with_assert=False)`` is called. Presumably the input thereby takes the
    output's place.

    Args:
        graph: the graph containing ``op``
        op: the operator which will be removed
    """
    operand = [*op.inputs.values()][0]
    result = [*op.outputs.values()][0]
    op.remove_all()
    OptimizeRule.replace_variable(graph, result, operand, with_assert=False)