def replace_output(graph: Graph, op: Operator, old_var: Variable, new_var: Variable, with_assert: bool = True):
    """Swap *old_var* for *new_var* as an output of *op*.

    If *old_var* is also one of the graph's outputs, it is replaced there as
    well, keeping its position in ``graph.outputs`` unchanged.
    """
    op.replace_output(old_var, new_var, with_assert=with_assert)
    if old_var not in graph.outputs:
        return
    position = graph.outputs.index(old_var)
    graph.outputs.remove(old_var)
    graph.outputs.insert(position, new_var)
def _remove_unary_operator(graph: Graph, op: Operator):
    # Remove a single-input / single-output operator from the graph, splicing
    # its input variable "x" in place of its output variable "y".
    # NOTE(review): block was reconstructed from a whitespace-mangled source;
    # confirm the branch nesting against the original file.
    x = list(op.inputs.values())[0]
    y = list(op.outputs.values())[0]
    op.remove_all()
    if x.order == y.order and x.shape == y.shape:
        # Layouts match, so "x" can directly stand in for "y".
        x.change_order(y.order)  # no-op given the equality check above — presumably kept for safety
        if y in graph.outputs:
            # Preserve the position of the variable in the graph's output list.
            index = graph.outputs.index(y)
            graph.outputs.remove(y)
            graph.outputs.insert(index, x)
        else:
            y.replace(x)
    else:
        if y in graph.outputs:
            index = graph.outputs.index(y)
            graph.outputs.remove(y)
            graph.outputs.insert(index, x)
        # Rewire every consumer of "y" to read "x" instead, keeping each
        # consumer's input slot name. Iterate over a copy since remove_input
        # mutates y.input_to.
        for op2 in list(y.input_to):
            name = op2.get_input_name(y)
            op2.remove_input(y)
            op2.append_input(name, x)
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """Remove a binary elementwise operator whose output equals one input.

    before)

        x1 -+
            +-{op}- y -
        x2 -+

    after)

        v -

    Args:
        graph: the graph
        op: the operator which will be removed
        v: variable with which output variable is replaced
    """
    y = op.outputs["y"]
    op.remove_all()
    y.change_order(v.order)
    # NOTE(review): this unconditional replace appears to make the else-branch
    # below redundant (and possibly a no-op) — confirm against the original.
    v.replace(y)
    if v in graph.inputs:
        if y in graph.outputs:
            # Keep the graph-output position stable while swapping y -> v.
            index = graph.outputs.index(y)
            graph.outputs.remove(y)
            graph.outputs.insert(index, v)
        else:
            y.replace(v)
    else:
        v.replace(y)
def test_get_input_name():
    """get_input_name() returns the slot name each variable was registered under."""
    operator = Operator("op")
    registered = [("v1", Variable((1, 2, 3, 4), OrderNHWC)),
                  ("v2", Variable((1, 2, 3, 4), OrderNHWC))]
    for name, var in registered:
        operator.append_input(name, var)
    for name, var in registered:
        assert operator.get_input_name(var) == name
def _optimize_ScalarAdd_ScalarMul(op1: ScalarAdd, op2: Operator):
    """Fuse a ScalarAdd followed by a ScalarMul: (x + a) * b == x * b + a * b.

    Returns True when the pair was fused, False when op2 is not a ScalarMul.
    """
    if not isinstance(op2, ScalarMul):
        return False
    x = op1.inputs["x0"]
    final_out = op2.outputs["y"]
    op2.remove_all()
    op1.remove_all()
    fused = (x * op2.value) + (op1.value * op2.value)
    fused.replace(final_out)
    return True
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Insert a Transpose in front of input *var_name* unless its order is acceptable.

    Returns True when the graph was modified.
    """
    acceptable = [target_orders] if isinstance(target_orders, Order) else target_orders
    current = op.inputs[var_name]
    if current.order in acceptable:
        return False
    transposed, = Transpose(None)(current)
    transposed.change_order(acceptable[0])
    op.remove_input(current)
    op.append_input(var_name, transposed)
    return True
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Retarget input *var_name* to a transposed variable when its order is not acceptable."""
    acceptable = [target_orders] if isinstance(target_orders, Order) else target_orders
    var = op.inputs[var_name]
    if var.order in acceptable:
        return False
    op.replace_input(var, var.transpose(acceptable[0]), with_assert=False)
    return True
def test_append_input():
    """append_input() registers variables by name and records the consuming operator."""
    operator = Operator("op")
    registered = [("v1", Variable((1, 2, 3, 4), OrderNHWC)),
                  ("v2", Variable((1, 2, 3, 4), OrderNHWC))]
    for name, var in registered:
        operator.append_input(name, var)
    for name, var in registered:
        assert operator.inputs[name] == var
        assert var.input_to == {operator}
def test_append_output():
    """append_output() registers variables by name and records the producing operator."""
    operator = Operator("op")
    registered = [("v1", Variable((1, 2, 3, 4), OrderNHWC)),
                  ("v2", Variable((1, 2, 3, 4), OrderNHWC))]
    for name, var in registered:
        operator.append_output(name, var)
    for name, var in registered:
        assert operator.outputs[name] == var
        assert var.output_from == operator
def _optimize_ElementwiseMul_ScalarMul(op1: ElementwiseMul, c1: ConstantVariable, v1: Variable, op2: Operator):
    """Fold ElementwiseMul-by-constant followed by ScalarMul into v1 * (c1 * k).

    Returns True when the pair was fused, False when op2 is not a ScalarMul.
    """
    if not isinstance(op2, ScalarMul):
        return False
    final_out = op2.outputs["y"]
    op1.remove_all()
    op2.remove_all()
    fused = v1 * (c1 * op2.value)
    fused.replace(final_out)
    return True
def _replace_output(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Make *op* emit output *var_name* in an acceptable order, transposing back after.

    Returns True when the graph was modified.
    """
    acceptable = [target_orders] if isinstance(target_orders, Order) else target_orders
    var = op.outputs[var_name]
    if var.order in acceptable:
        return False
    staged = Variable(var.shape, var.order).change_order(acceptable[0])
    op.replace_output(var, staged, with_assert=False)
    # Transpose the staged output back so downstream users still see "var".
    Transpose(None)(staged)[0].replace(var, with_assert=False)
    return True
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Insert a Transpose in front of input *var_name* when its order is not acceptable."""
    acceptable = [target_orders] if isinstance(target_orders, Order) else target_orders
    var = op.inputs[var_name]
    if var.order in acceptable:
        return False
    transposed, = Transpose(None)(var)
    op.replace_input(var, transposed, with_assert=False)
    transposed.change_order(acceptable[0])
    return True
def _replace_input(graph: Graph, op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Ensure input *var_name* arrives in an acceptable order.

    When the order is already acceptable, delegate to the redundant-transpose
    optimization instead; otherwise insert a transpose before the input.
    """
    acceptable = [target_orders] if isinstance(target_orders, Order) else target_orders
    var = op.inputs[var_name]
    if var.order in acceptable:
        return _optimize_redundant_transposed_input(graph, op, var_name, acceptable)
    op.replace_input(var, var.transpose(acceptable[0]), with_assert=False)
    return True
def _replace_output(graph: Graph, op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Ensure output *var_name* is produced in an acceptable order.

    When the order is already acceptable, delegate to the redundant-transpose
    optimization; otherwise stage a reordered output and transpose it back.
    """
    acceptable = [target_orders] if isinstance(target_orders, Order) else target_orders
    var = op.outputs[var_name]
    if var.order in acceptable:
        return _optimize_redundant_transposed_output(graph, op, var_name, acceptable)
    staged = Variable(var.shape, var.order).change_order(acceptable[0])
    op.replace_output(var, staged, with_assert=False)
    staged.transpose(var.order).replace(var, with_assert=False)
    return True
def _replace_output(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Re-route output *var_name* through a Transpose so *op* emits an acceptable order.

    The original variable keeps its downstream identity: it becomes the output
    of the inserted Transpose operator.
    """
    acceptable = [target_orders] if isinstance(target_orders, Order) else target_orders
    original = op.outputs[var_name]
    if original.order in acceptable:
        return False
    staged = Variable(original.shape, original.order)
    staged.change_order(acceptable[0])
    op.remove_output(original)
    op.append_output(var_name, staged)
    transpose_op = Transpose(None)
    placeholder, = transpose_op(staged)
    # Swap the transpose's auto-created output for the original variable.
    transpose_op.remove_output(placeholder)
    transpose_op.append_output("y", original)
    return True
def _split_tensorwise(graph: Graph, op: Operator, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    # Split operator "op" along "axis" into two copies, where variable "v"
    # (an input or output of op) has already been split into v_pair.
    # NOTE(review): reconstructed from a whitespace-mangled source; confirm the
    # loop/branch nesting against the original file.
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    # Snapshot I/O maps before remove_all() clears them.
    xs = dict(op.inputs)
    ys = dict(op.outputs)
    op.remove_all()
    op_0 = op.copy()
    op_1 = op.copy()
    for key, x in xs.items():
        if x == v:
            # This input is the variable that was already split by the caller.
            x_0, x_1 = v_pair
        else:
            if axis in x.order.axes:
                x_0, x_1 = SplitAxis(None, axis=axis, sections=[s1])(x)
            else:
                # splitting is not occurred
                x_0 = x_1 = x
        op_0.append_input(key, x_0)
        op_1.append_input(key, x_1)
    for key, y in ys.items():
        if y == v:
            y_0, y_1 = v_pair
        else:
            if axis in y.order.axes:
                # TODO (Kiikurage)
                # Attribute attached to "y" is not copied to neither "y_0" or "y_1"
                y_0 = Variable([
                    s1 if a == axis else y.shape_dict[a]
                    for a in y.order.axes
                ], y.order)
                y_1 = Variable([
                    s2 if a == axis else y.shape_dict[a]
                    for a in y.order.axes
                ], y.order)
                # Re-join the halves so the rest of the graph keeps seeing "y".
                y_new, = Concat(None, axis=axis)(y_0, y_1)
                OptimizeRule.replace_variable(graph, y, y_new)
            else:
                # An output that does not carry the split axis cannot be produced
                # by the two half-operators.
                raise UnexpectedAndPleaseReportError
        op_0.append_output(key, y_0)
        op_1.append_output(key, y_1)
def _listup_splittable_axis(v: Variable, op: Operator) -> List[Axis]:
    # Return the axes of "v" along which "op" can be split without changing
    # the computation's result, dispatching on the operator type.
    if isinstance(op, (Concat, SplitAxis)):
        return list(v.order.axes)
    elif isinstance(op, Reshape):
        """
        For more detail of this condition check, please see the comment document of `_split_reshape`
        """
        splittable_axes = []  # type: List[Axis]
        v1 = v
        # "v2" is the variable on the other side of the reshape.
        v2 = op.outputs["y"] if v == op.inputs["x"] else op.inputs["x"]
        for a1 in v1.order.axes:
            # d1 = product of v1's dimensions from axis a1 to the end.
            d1 = mul(v1.shape[v1.order.axes_dict[a1]:])
            d2 = 1
            # Accumulate v2's trailing dimensions until they match d1.
            # NOTE(review): after a match the loop continues and d2 keeps
            # growing; with size-1 dimensions the same axis could be appended
            # twice — confirm whether "break" was intended here.
            for a2 in reversed(v2.order.axes):
                d2 *= v2.shape_dict[a2]
                if d2 == d1:
                    splittable_axes.append(a1)
                    continue
                elif d2 > d1:
                    continue
        return splittable_axes
    elif isinstance(op, Im2Col):
        op = op  # type: Im2Col
        if v in op.outputs.values():
            # The channel axis is splittable only when it divides evenly by
            # the kernel footprint.
            if v.shape_dict[Axis.C] % (op.ksize[0] * op.ksize[1]) == 0:
                return [Axis.N, Axis.H, Axis.W, Axis.C]
            else:
                return [Axis.N, Axis.H, Axis.W]
        else:
            return []
    elif isinstance(op, PartialIm2Col):
        op = op  # type: PartialIm2Col
        if v in op.outputs.values():
            return []
        else:
            return [op.axis]
    elif isinstance(op, Sgemm):
        # The GEMM result cannot be split; its operands can, along any axis.
        if v == op.outputs["C"]:
            return []
        else:
            return list(v.order.axes)
    elif isinstance(op, Tensordot):
        if v == op.outputs["C"]:
            return []
        else:
            return list(v.order.axes)
    else:
        # Fall back to whatever axes the operator declares as tensorwise.
        return list(attr.axis for attr in op.get_attribute(Tensorwise))
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    # Remove a binary elementwise operator, replacing its output "y" with the
    # surviving input variable "v".
    y = op.outputs["y"]
    op.remove_all()
    y.change_order(v.order)
    # NOTE(review): this unconditional replace appears to make the else-branch
    # below redundant (and possibly a no-op) — confirm against the original.
    v.replace(y)
    if v in graph.inputs:
        if y in graph.outputs:
            # Keep the graph-output position stable while swapping y -> v.
            index = graph.outputs.index(y)
            graph.outputs.remove(y)
            graph.outputs.insert(index, v)
        else:
            y.replace(v)
    else:
        v.replace(y)
def __init__(self, op: Operator):
    """Capture the post-inline-inplace injector attached to *op*, if any.

    ``self.delegate`` stays an identity transform unless the operator carries
    a PostInlineInplace attribute whose injector has already been injected.
    """
    self.delegate = lambda exp: exp  # type: Callable[[str], str]
    self.has_inline = traverse.check_attribute_match(op, PostInlineInplace)
    if not self.has_inline:
        return
    attribute = op.get_attribute(PostInlineInplace)[0]  # type: PostInlineInplace
    if attribute.injected is not None:
        self.delegate = attribute.injected.injector
def _optimize_ScalarAdd_ElementwiseAdd(op1: ScalarAdd, op2: Operator):
    """Reassociate (x + a) + w into (x + w) + a so the scalar constant is applied last.

    Returns True when the pair was rewritten, False when op2 is not an ElementwiseAdd.
    """
    if not isinstance(op2, ElementwiseAdd):
        return False
    x = op1.inputs["x0"]
    intermediate = op1.outputs["y"]
    # The other operand of op2 is whichever input is not op1's output.
    other = op2.inputs["x1"] if intermediate == op2.inputs["x0"] else op2.inputs["x0"]
    final_out = op2.outputs["y"]
    op2.remove_all()
    op1.remove_all()
    rebuilt = (x + other) + op1.value
    rebuilt.replace(final_out)
    return True
def _split_tensorwise(graph: Graph, op: Operator, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Split *op* along *axis* into two half-operators, with *v_pair* as the halves of *v*.

    Inputs that do not carry *axis* (or have size 1 along it) are broadcast to
    both halves; other inputs are split. Each half is executed, and the halves
    of every output are concatenated (or bound to *v_pair* when the output is
    *v*) back into the graph.
    """
    section = v_pair[0].shape_dict[axis]
    # Snapshot I/O maps before remove_all() clears them.
    inputs = dict(op.inputs)
    outputs = dict(op.outputs)
    op.remove_all()
    first = op.copy()
    second = op.copy()
    for name, x in inputs.items():
        if x == v:
            x_first, x_second = v_pair
        elif axis not in x.order.axes or x.shape_dict[axis] == 1:
            # broadcasting: the same variable feeds both halves
            x_first = x_second = x
        else:
            x_first, x_second = SplitAxis(None, axis=axis, sections=[section])(x)
        first.append_input(name, x_first)
        second.append_input(name, x_second)
    first.exec()
    second.exec()
    for name, y in outputs.items():
        if y == v:
            OptimizeRule.replace_variable(
                graph, first.outputs[name].transpose_like(v_pair[0]), v_pair[0])
            OptimizeRule.replace_variable(
                graph, second.outputs[name].transpose_like(v_pair[1]), v_pair[1])
        else:
            merged, = Concat(None, axis=axis)(first.outputs[name], second.outputs[name])
            OptimizeRule.replace_variable(graph, merged.transpose_like(y), y)
def _remove_binary_elementwise(graph: Graph, op: Operator, v: Variable):
    """Remove a binary elementwise operator, replacing its output with input *v*.

    before)

        x1 -+
            +-{op}- y -
        x2 -+

    after)

        v -

    Args:
        graph: the graph
        op: the operator which will be removed
        v: variable with which output variable is replaced
    """
    survivor_target = op.outputs["y"]
    op.remove_all()
    OptimizeRule.replace_variable(graph, v, survivor_target, with_assert=False)
def test_replace_all():
    """replace() moves every input and output from one operator onto another."""
    source = Operator("op1")
    target = Operator("op2")
    v_in = Variable((1, 2, 3, 4), OrderNHWC)
    v_out = Variable((1, 2, 3, 4), OrderNHWC)
    source.append_input("v1", v_in)
    source.append_output("v2", v_out)
    source.replace(target)
    assert len(source.inputs) == 0
    assert len(source.outputs) == 0
    assert len(target.inputs) == 1 and target.inputs["v1"] == v_in
    assert len(target.outputs) == 1 and target.outputs["v2"] == v_out
    assert v_in.input_to == {target}
    assert v_out.output_from == target
def _replace_input(op: Operator, var_name: str, target: ChannelModeEnum):
    """Insert a channel-mode conversion in front of input *var_name*.

    before) v -{op}-

    after)  v -{conversion}- v' -{op}-
    """
    var = op.inputs[var_name]
    if ChannelMode.get(var) == target:
        return False
    if target == ChannelModeEnum.RGBA:
        converted, = ConvertRtoRGBA(None)(var)
    else:
        converted, = ConvertRGBAtoR(None)(var)
    op.replace_input(var, converted)
    return True
def _replace_input(op: Operator, var_name: str, target: ChannelModeEnum):
    """Insert a channel-mode conversion in front of input *var_name*, copying the texture shape.

    before) v -{op}-

    after)  v -{conversion}- v' -{op}-
    """
    var = op.inputs[var_name]
    if ChannelMode.get(var) == target:
        return False
    converted = convert_r_to_rgba(var) if target == ChannelModeEnum.RGBA else convert_rgba_to_r(var)
    # The converted variable inherits the source texture dimensions.
    shape = TextureShape.get(var)
    TextureShape.set(converted, height=shape[0], width=shape[1])
    op.replace_input(var, converted)
    return True
def fn(x: Variable):
    """Wrap *x* in a fresh anonymous operator and return its same-shaped output."""
    result = Variable(x.shape, x.order)
    anonymous = Operator(None)
    anonymous.append_input("x", x)
    anonymous.append_output("y", result)
    return result
def _optimize_ElementwiseMul_ElementwiseMul(op1: ElementwiseMul, c1: ConstantVariable, v1: Variable, op2: Operator):
    """Fold two chained constant multiplies into one: (v1 * c1) * c2 -> v1 * (c1 * c2).

    Returns True when the pair was fused; False when op2 is not an
    ElementwiseMul or neither of its inputs is constant.
    """
    if not isinstance(op2, ElementwiseMul):
        return False
    lhs = op2.inputs["x0"]
    rhs = op2.inputs["x1"]
    final_out = op2.outputs["y"]
    if isinstance(lhs, ConstantVariable):
        c2 = lhs
    elif isinstance(rhs, ConstantVariable):
        c2 = rhs
    else:
        # Neither operand of op2 is a constant; nothing to fold.
        return False
    op2.remove_all()
    op1.remove_all()
    fused = v1 * (c1 * c2)
    fused.replace(final_out)
    return True
def _replace_output(op: Operator, var_name: str, target: ChannelModeEnum):
    """Emit output *var_name* in *target* channel mode, converting back for downstream users.

    before) -{op}- v

    after)  -{op}- v' -{conversion}- v
    """
    var = op.outputs[var_name]
    if ChannelMode.get(var) == target:
        return False
    staged = Variable(var.shape, var.order)
    ChannelMode.set(staged, target)
    op.replace_output(var, staged)
    # Convert the staged output back so the rest of the graph still sees "var".
    if target == ChannelModeEnum.RGBA:
        convert_rgba_to_r(staged).change_order(var.order).replace(var)
    else:
        convert_r_to_rgba(staged).change_order(var.order).replace(var)
    return True
def _replace_output(op: Operator, var_name: str, target: ChannelModeEnum):
    """Emit output *var_name* in *target* channel mode, converting back for downstream users.

    before) -{op}- v

    after)  -{op}- v' -{conversion}- v
    """
    var = op.outputs[var_name]
    if ChannelMode.get(var) == target:
        return False
    staged = Variable(var.shape, var.order)
    ChannelMode.set(staged, target)
    op.replace_output(var, staged)
    # Pick the inverse conversion so downstream consumers still see "var".
    converter = ConvertRGBAtoR if target == ChannelModeEnum.RGBA else ConvertRtoRGBA
    converter(None)(staged)[0].replace(var)
    return True
def test_replace_output():
    """replace_output() rebinds the named slot and updates output_from on both variables."""
    operator = Operator("op")
    old_var = Variable((1, 2, 3, 4), OrderNHWC)
    new_var = Variable((1, 2, 3, 4), OrderNHWC)
    operator.append_output("v1", old_var)
    operator.replace_output(old_var, new_var)
    assert operator.outputs["v1"] == new_var
    assert old_var.output_from is None
    assert new_var.output_from == operator