def _split_tensorwise(graph: Graph, op: Operator, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Replace ``op`` by two copies, each handling one part of ``v`` along ``axis``.

    ``op`` is detached from the graph (``op.remove_all()``) and copied twice.

    Inputs of the copies are wired as follows:

    * an input that is ``v`` itself is replaced by ``v_pair[0]`` / ``v_pair[1]``;
    * an input that does not have ``axis``, or has size 1 along it
      (broadcasting), is shared unchanged by both copies;
    * any other input is split with ``SplitAxis`` at ``v_pair[0]``'s size
      along ``axis``.

    Outputs of the copies are wired as follows:

    * an output that is ``v`` itself becomes ``v_pair[0]`` / ``v_pair[1]``;
    * any other output ``y`` that has ``axis`` gets two freshly created
      variables (sized by ``v_pair[0]`` / ``v_pair[1]`` along ``axis``); their
      ``Concat`` replaces ``y`` via ``OptimizeRule.replace_variable``;
    * an output without ``axis`` is rejected.

    NOTE(review): a later definition named ``_split_tensorwise`` in this file
    rebinds this name at import time, so this version looks unused — confirm
    before relying on it.

    Args:
        graph: Graph that owns ``op``; passed to ``OptimizeRule.replace_variable``.
        op: Operator to split. Its connections are removed.
        v: Variable being split.
        v_pair: The two parts of ``v`` along ``axis`` (presumably first/second
            in order along ``axis`` — confirm with the caller).
        axis: Axis along which ``v`` is split.

    Raises:
        UnexpectedAndPleaseReportError: if an output other than ``v`` does not
            have ``axis``.
    """
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]

    # Snapshot connections before detaching ``op`` from the graph.
    xs = dict(op.inputs)
    ys = dict(op.outputs)
    op.remove_all()

    op_0 = op.copy()
    op_1 = op.copy()

    for key, x in xs.items():
        if x == v:
            x_0, x_1 = v_pair
        elif axis not in x.order.axes or x.shape_dict[axis] == 1:
            # ``x`` is constant along ``axis`` (absent, or size 1 and broadcast),
            # so both copies consume it as-is. Splitting a size-1 axis at ``s1``
            # would be invalid.
            x_0 = x_1 = x
        else:
            x_0, x_1 = SplitAxis(None, axis=axis, sections=[s1])(x)
        op_0.append_input(key, x_0)
        op_1.append_input(key, x_1)

    for key, y in ys.items():
        if y == v:
            y_0, y_1 = v_pair
        elif axis in y.order.axes:
            # TODO (Kiikurage)
            # Attribute attached to "y" is not copied to neither "y_0" or "y_1"
            y_0 = Variable([s1 if a == axis else y.shape_dict[a] for a in y.order.axes], y.order)
            y_1 = Variable([s2 if a == axis else y.shape_dict[a] for a in y.order.axes], y.order)
            y_new, = Concat(None, axis=axis)(y_0, y_1)
            OptimizeRule.replace_variable(graph, y, y_new)
        else:
            raise UnexpectedAndPleaseReportError
        op_0.append_output(key, y_0)
        op_1.append_output(key, y_1)
def _split_tensorwise(graph: Graph, op: Operator, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Replace ``op`` by two copies, each handling one part of ``v`` along ``axis``.

    ``op`` is detached from the graph (``op.remove_all()``) and copied twice.

    For each input of ``op``:

    * if it is ``v``, the copies receive ``v_pair[0]`` and ``v_pair[1]``;
    * if it lacks ``axis`` or has size 1 along it (broadcasting), both copies
      receive it unchanged;
    * otherwise it is split with ``SplitAxis`` at ``v_pair[0]``'s size along
      ``axis``.

    Both copies are then executed so their output variables are created.

    For each output of ``op``:

    * if it is ``v``, each copy's output (transposed to match the corresponding
      part's order) is replaced by ``v_pair[0]`` / ``v_pair[1]``;
    * otherwise the two copies' outputs are concatenated along ``axis``,
      transposed to the original output's order, and replaced by the original
      output variable.

    Args:
        graph: Graph that owns ``op``; passed to ``OptimizeRule.replace_variable``.
        op: Operator to split. Its connections are removed.
        v: Variable being split.
        v_pair: The two parts of ``v`` along ``axis`` (presumably first/second
            in order along ``axis`` — confirm with the caller).
        axis: Axis along which ``v`` is split.
    """
    head_size = v_pair[0].shape_dict[axis]

    # Snapshot connections before detaching ``op`` from the graph.
    input_map = dict(op.inputs)
    output_map = dict(op.outputs)
    op.remove_all()

    head_op = op.copy()
    tail_op = op.copy()

    for name, src in input_map.items():
        if src == v:
            head_in, tail_in = v_pair
        else:
            broadcasting = axis not in src.order.axes or src.shape_dict[axis] == 1
            if broadcasting:
                # Constant along ``axis``: both copies read the same variable.
                head_in = tail_in = src
            else:
                head_in, tail_in = SplitAxis(None, axis=axis, sections=[head_size])(src)
        head_op.append_input(name, head_in)
        tail_op.append_input(name, tail_in)

    # Let each copy create its own output variables from the new inputs.
    head_op.exec()
    tail_op.exec()

    for name, dst in output_map.items():
        if dst == v:
            # ``v`` is produced here: make each copy produce its part directly.
            for clone, part in ((head_op, v_pair[0]), (tail_op, v_pair[1])):
                OptimizeRule.replace_variable(graph, clone.outputs[name].transpose_like(part), part)
            continue

        # Any other output is rebuilt from both halves and substituted back.
        head_out = head_op.outputs[name]
        tail_out = tail_op.outputs[name]
        joined, = Concat(None, axis=axis)(head_out, tail_out)
        OptimizeRule.replace_variable(graph, joined.transpose_like(dst), dst)