def __call__(self, first: Node, second: Node) -> bool:
    """
    This function checks whether Interpolate nodes 'first' and 'second' can be fused.

    The check is stateful: ``self.accumulated_axes`` holds the union of the axes of
    every 'first' node accepted since the last rejection. Whenever fusion is
    rejected, the set is reset to empty.

    :param first: the first of fused nodes
    :param second: the second of fused nodes
    :return: True, if nodes can be fused, and False otherwise
    """
    if not (is_next(first, second) and self._compare_attributes(first, second)):
        self.accumulated_axes = set()
        return False

    fst_axes = Interpolate.get_axes(first)
    snd_axes = Interpolate.get_axes(second)
    # Interpolate.get_axes may return None (axes unknown). In that case we cannot
    # prove the two interpolations are independent, so refuse to fuse instead of
    # failing with a TypeError when building the axis sets.
    if fst_axes is None or snd_axes is None:
        self.accumulated_axes = set()
        return False

    self.accumulated_axes = self.accumulated_axes | set(fst_axes)

    # If the set of accumulated axes and the set of axes of 'second' do not intersect then nodes can be fused,
    # because interpolations with respect to various axes do not affect each other.
    if not (self.accumulated_axes & set(snd_axes)):
        return True

    # Otherwise, nodes cannot be fused.
    self.accumulated_axes = set()
    return False
def make_interpolate_reshapeable(interpolate, concat):
    """
    Rewire input port 1 of 'interpolate' to ShapeOf -> Gather(interpolated axes)
    computed from a Concat input that is not produced by an Interpolate.

    The source used is the first connected input of 'concat' whose producer type
    is not 'Interpolate'. Nothing is changed when:
      * the interpolation axes are unknown (Interpolate.get_axes returned None),
      * the Concat axis is one of the interpolated axes,
      * every connected Concat input comes from an Interpolate.

    Axes are canonicalized against the output shape of 'interpolate' before they
    are compared or gathered.

    :param interpolate: node of type 'Interpolate' whose input port 1 is rewired
    :param concat: node of type 'Concat' whose inputs provide the runtime shape
    """
    assert interpolate.soft_get('type') == 'Interpolate'
    assert concat.soft_get('type') == 'Concat'

    output_shape = interpolate.out_port(0).data.get_shape()
    axes = Interpolate.get_axes(interpolate)
    if axes is None:
        # Without known interpolation axes there is nothing to gather.
        return

    interp_axes = [get_canonical_axis_index(output_shape, axis) for axis in axes]
    concat_axis = get_canonical_axis_index(output_shape, concat.axis)
    if concat_axis in interp_axes:
        # Concat inputs may legitimately differ along the concat axis, so their
        # size there cannot stand in for the interpolated size.
        return

    concat_srcs = [port.get_source() for port in concat.in_ports().values() if not port.disconnected()]
    non_interp_concat_srcs = [src for src in concat_srcs if src.node.soft_get('type') != 'Interpolate']
    if not non_interp_concat_srcs:
        return

    graph = interpolate.graph
    src = non_interp_concat_srcs[0]

    shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
    shape.in_port(0).connect(src)
    gather = create_op_with_const_inputs(graph, Gather,
                                         {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)},
                                         {'name': shape.name + '/Gathered'}, shape)
    interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
def make_interpolate_reshape_able(self, interpolate: Node, concat: Node):
    """
    Rewire input port 1 of 'interpolate' to ShapeOf -> Gather(interpolation axes)
    taken from the first source returned by
    ``self.get_non_interpolate_concat_sources(concat)``.

    The graph is left untouched when:
      * the interpolation axes or the Concat axis are unknown (None),
      * any of them is negative,
      * the Concat axis is one of the interpolation axes,
      * there is no suitable Concat source.
    """
    assert interpolate.soft_get('type') == 'Interpolate'
    assert concat.soft_get('type') == 'Concat'

    axes = Interpolate.get_axes(interpolate)
    concat_axis = self.get_concat_axis(concat)

    # Guard clauses below keep the short-circuit order of the checks: the None
    # tests must run before any comparison that would fail on None.
    if concat_axis is None or axes is None:
        return
    if np.any(axes < 0) or concat_axis < 0:
        return
    if concat_axis in axes:
        # interpolate axes and concat axis intersect
        return

    sources = self.get_non_interpolate_concat_sources(concat)
    if len(sources) == 0:
        # there is no Concat input to take input from
        return

    source = sources[0]
    graph = interpolate.graph
    source_name = source.node.soft_get('name', source.node.id)

    shape_of = Shape(graph, {'name': source_name + '/Shape'}).create_node()
    shape_of.in_port(0).connect(source)

    const_inputs = {1: np.array(axes, dtype=np.int32), 2: int64_array(0)}
    gathered = create_op_with_const_inputs(graph, Gather, const_inputs,
                                           {'name': shape_of.name + '/Gathered'},
                                           input_node=shape_of)
    interpolate.in_port(1).get_connection().set_source(gathered.out_port(0))
def make_interpolate_reshapeable(interpolate):
    """
    Rewire input port 1 of 'interpolate' to a sub-graph that derives the target
    sizes from the runtime input shape:

        ShapeOf(input 0 source) -> Gather(axes) -> Mul(output_shape[axes] / input_shape[axes])

    The rewrite is applied only when the full output shape is element-wise
    divisible by the input shape, or vice versa (checked with np.remainder).
    Nothing is changed when the interpolation axes or either static shape are
    unknown (None).

    :param interpolate: node of type 'Interpolate'
    """
    assert interpolate.soft_get('type') == 'Interpolate'
    axes = Interpolate.get_axes(interpolate)
    input_shape = interpolate.in_port(0).data.get_shape()
    output_shape = interpolate.out_port(0).data.get_shape()

    # Without known axes/shapes the multipliers cannot be computed; previously this
    # crashed in np.remainder or np.array(None, dtype=np.int32).
    if axes is None or input_shape is None or output_shape is None:
        return

    # Only integer up- or down-scaling ratios are turned into multipliers.
    if not np.all(np.remainder(output_shape, input_shape) == 0) and \
            not np.all(np.remainder(input_shape, output_shape) == 0):
        return

    graph = interpolate.graph
    name = interpolate.soft_get('name', interpolate.id)

    shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
    shape.in_port(0).connect(interpolate.in_port(0).get_source())

    gather = create_op_with_const_inputs(graph, Gather, {1: np.array(axes, dtype=np.int32), 2: int64_array(0)},
                                         {'name': shape.name + '/Gathered'}, shape)
    multipliers = output_shape[axes] / input_shape[axes]
    mul = create_op_node_with_second_input(graph, Mul, multipliers, {'name': gather.name + '/Multiplied'}, gather)
    interpolate.in_port(1).get_connection().set_source(mul.out_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Remove the matched 'transpose_1' -> Interpolate -> 'transpose_2' wrapping when
    the Interpolate works on axes [1, 2], and retarget the Interpolate to axes [2, 3].

    For opset1 the 'axes' attribute is rewritten. For opset4 the value on input
    port 3 is rewritten. Any other opset fails the assertion.
    """
    interpolate = match['interpolate']
    first_transpose = match['transpose_1']
    second_transpose = match['transpose_2']

    current_axes = Interpolate.get_axes(interpolate)
    if current_axes is None:
        return
    if not np.array_equal(current_axes, int64_array([1, 2])):
        return

    # because we remove Transpose layers the ResizeNearestNeighbor should be updated for NCHW layout
    opset = interpolate.get_opset()
    assert opset in ['opset1', 'opset4'], \
        'Interpolate node with name {} has unsupported opset'.format(interpolate.soft_get('name', interpolate.id))

    nchw_axes = int64_array([2, 3])
    if opset != 'opset1':
        interpolate.in_port(3).data.set_value(nchw_axes)
    else:
        interpolate.axes = nchw_axes

    # Bypass both Transposes, then drop them from the graph.
    first_transpose.in_port(0).get_connection().set_destination(interpolate.in_port(0))
    second_transpose.out_port(0).get_connection().set_source(interpolate.out_port(0))
    graph.remove_nodes_from([first_transpose.id, second_transpose.id])
def _collect_per_axis_values(seq: List[Node], value_ports):
    """
    Map every axis of the Interpolates in 'seq' to the values read from the
    constant sources of 'value_ports' at that axis' position.

    Returns a list of ``(axis, [value_for_port, ...])`` pairs sorted by axis. If
    an axis occurs more than once, the value from the later node in 'seq' wins.
    Duplicates are not expected, because fused nodes are checked for disjoint
    axes, but both opsets now resolve them the same way.
    """
    per_axis = {}
    for interp in seq:
        columns = [interp.in_port(port).get_connection().get_source().data.get_value() for port in value_ports]
        for axis, *values in zip(Interpolate.get_axes(interp), *columns):
            per_axis[axis] = values
    return sorted(per_axis.items(), key=lambda item: item[0])


def replace_sequence(seq: List[Node], graph: Graph):
    """
    This function replaces a sequence of consecutive Interpolate layers with one Interpolate layer,
    if modes of all nodes of a sequence are the same.

    The fused node takes its attributes from the first node of the sequence. It
    is connected between the input of the first node and the consumers of the
    last node, and it takes over the name of the last node. The last node is
    renamed with a '/delete_' suffix.

    :param seq: sequence of Interpolate layers
    :param graph: graph to which nodes of seq belong
    :return: Nothing
    """
    if not seq or len(seq) < 2:
        # Nothing to fuse.
        return

    if len({n.mode for n in seq}) != 1:
        return

    fst_interp_node = seq[0]
    last_interp_node = seq[-1]
    last_interp_node_name = last_interp_node.soft_get('name', last_interp_node.id)
    opset = fst_interp_node.get_opset()
    attributes = get_interpolate_attributes(fst_interp_node)

    if opset == 'opset1':
        # opset1: sizes on port 1, axes stored as an attribute.
        per_axis = _collect_per_axis_values(seq, value_ports=(1,))
        attributes['axes'] = int64_array([axis for axis, _ in per_axis])
        const_inputs = {1: int64_array([values[0] for _, values in per_axis])}
    else:
        # opset4: sizes on port 1, scales on port 2, axes on port 3.
        per_axis = _collect_per_axis_values(seq, value_ports=(1, 2))
        attributes['in_ports_count'] = 4
        const_inputs = {
            1: int64_array([values[0] for _, values in per_axis]),
            2: np.array([values[1] for _, values in per_axis]),
            3: int64_array([axis for axis, _ in per_axis]),
        }

    interp_node = create_op_with_const_inputs(graph, Interpolate, const_inputs, attributes)
    fst_interp_node.in_port(0).get_connection().set_destination(interp_node.in_port(0))
    last_interp_node.out_port(0).get_connection().set_source(interp_node.out_port(0))

    # Presumably done so the fused node keeps the name expected downstream
    # (e.g. for model outputs); confirm against the rename_nodes contract.
    rename_nodes([(last_interp_node, last_interp_node_name + '/delete_'),
                  (interp_node, last_interp_node_name)])