def __call__(self, first: Node, second: Node) -> bool:
    """
    Decide whether the Interpolate node 'second' may be fused with the node 'first'.

    The object keeps a running set ``self.accumulated_axes`` of axes collected from the nodes
    seen so far; the set is cleared every time fusion is rejected.

    :param first: the earlier node of the candidate pair
    :param second: the later node of the candidate pair
    :return: True if the nodes can be fused, False otherwise
    """
    fusable = is_next(first, second) and self._compare_attributes(first, second)
    if fusable:
        axes_of_first = set(Interpolate.get_axes(first))
        axes_of_second = set(Interpolate.get_axes(second))
        self.accumulated_axes = self.accumulated_axes | axes_of_first

        # Interpolations along different axes do not affect each other, so the nodes can be
        # fused as long as the axes of 'second' were not seen before.
        if not (self.accumulated_axes & axes_of_second):
            return True

    # Fusion rejected: start accumulating axes from scratch.
    self.accumulated_axes = set()
    return False
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace the matched NearestNeighborUpsampling pattern with one 'nearest' Interpolate node.

    The target size is computed as input size times scale, where the input height/width are read
    from inputs 1 and 3 of 'pack_1' and the scales from dimensions -4 and -2 of the 'mul_const'
    shape. If these values cannot be read, the pattern is left untouched.

    :param graph: graph being transformed
    :param match: mapping from pattern node names to matched nodes
    :return: None
    """
    log.debug('Matched NearestNeighborUpsampling pattern: {}'.format([node.id for node in match.values()]))

    try:
        pack = match['pack_1']
        in_height = pack.in_node(1).value.item()
        in_width = pack.in_node(3).value.item()
        mul_const_shape = match['mul_const'].shape
        scale_h = mul_const_shape[-4]
        scale_w = mul_const_shape[-2]
    except Exception:
        log.warning('Failed to determine scaling parameters from the topology. Do not apply pattern.')
        return

    # Spatial axes depend on the layout recorded in the graph.
    if graph.graph['layout'] == 'NCHW':
        axes = int64_array([2, 3])
    else:
        axes = int64_array([1, 2])

    source_node = match['op']
    resample_node = Interpolate(
        graph,
        {'name': 'Resample_', 'antialias': 0, 'mode': 'nearest', 'axes': axes},
    ).create_node([source_node])

    target_shape = np.array([in_height * scale_h, in_width * scale_w])
    const = Const(graph, {'value': target_shape,
                          'name': resample_node.name + '/target_shape'}).create_node()

    match['reshape_2'].replace_node(resample_node)

    # Port 1 of the new node receives the constant with the target spatial shape.
    resample_node.add_input_port(1, skip_if_exist=True)
    assert resample_node.in_port(1).disconnected()
    const.out_port(0).connect(resample_node.in_port(1))

    # Drop every matched node except the producer that feeds the new Interpolate.
    kept_id = source_node.id
    graph.remove_nodes_from([node.id for node in match.values() if node.id != kept_id])
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace the matched resize pattern with an Interpolate node over axes [2, 3, 4].

    The transformation is applied only when 'mul_1', 'mul_2' and 'mul_3' use the same scale value
    on their input 1. The output shape fed to Interpolate is the output of 'slice' multiplied by
    that scale.

    :param graph: graph being transformed
    :param match: mapping from pattern node names to matched nodes
    :return: None
    """
    resize_node = match['resize']

    scale = match['mul_1'].in_node(1).value
    if scale != match['mul_2'].in_node(1).value or scale != match['mul_3'].in_node(1).value:
        log.info('Pattern matched around resize op {} has different scale values.'.format(resize_node.name))
        return

    base_name = resize_node.name
    interpolate_node = Interpolate(graph, {'name': base_name + '/Interpolate',
                                           'mode': resize_node.mode,
                                           'axes': int64_array([2, 3, 4])}).create_node()

    scale_const = Const(graph, {'value': int64_array([scale, scale, scale]),
                                'name': base_name + '/Scale'}).create_node()
    interpolated_shape = Mul(graph, {'name': base_name + '/OutputShape'}).create_node()

    # output shape = slice output * scale
    match['slice'].out_port(0).connect(interpolated_shape.in_port(0))
    scale_const.out_port(0).connect(interpolated_shape.in_port(1))

    # Re-route the data input, the computed shape and the consumers to the new node.
    resize_node.in_port(0).get_connection().set_destination(interpolate_node.in_port(0))
    interpolated_shape.out_port(0).connect(interpolate_node.in_port(1))
    resize_node.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
def test_interpolate4_using_scales_without_axes(self, pads_begin, pads_end, input_shape, output_shape, sizes, scales):
    """
    Check Interpolate shape inference in 'scales' shape calculation mode on the graph without axes.

    :param pads_begin: value for the 'pads_begin' attribute
    :param pads_end: value for the 'pads_end' attribute
    :param input_shape: shape assigned to 'input_data'
    :param output_shape: shape expected on 'interpolate_data' after inference
    :param sizes: value of the 'sizes' input
    :param scales: value of the 'scales' input
    """
    def const_attrs(array):
        # Shape/value attribute pair describing a constant holding 'array'.
        return {'shape': array.shape, 'value': array}

    overrides = {
        'input_data': {'shape': input_shape},
        'sizes': const_attrs(int64_array(sizes)),
        'sizes_data': const_attrs(int64_array(sizes)),
        'scales': const_attrs(np.array(scales)),
        'scales_data': const_attrs(np.array(scales)),
        'interpolate': {'pads_begin': int64_array(pads_begin),
                        'pads_end': int64_array(pads_end),
                        'shape_calculation_mode': 'scales'},
    }
    graph = build_graph(nodes_attrs=graph_node_attrs_without_axes,
                        edges=graph_edges_without_axes,
                        update_attributes=overrides)

    node = Node(graph, 'interpolate')
    Interpolate(graph=graph, attrs=node.attrs()).infer(node)

    actual_shape = graph.node['interpolate_data']['shape']
    msg = "Interpolate-4 infer failed for case: sizes={}, scales={}, pads_begin={}, pads_end={}," \
          " expected_shape={}, actual_shape={}"
    shapes_match = np.array_equal(actual_shape, int64_array(output_shape))
    self.assertTrue(shapes_match,
                    msg.format(sizes, scales, pads_begin, pads_end, output_shape, actual_shape))
def extract(node):
    """
    Fill Interpolate attributes of 'node' from its ``resample_param`` protobuf section.

    The numeric ``type`` field is translated to a mode name through a lookup table and stored
    under 'mode'; the axes are fixed to [2, 3].

    :param node: node whose ``pb`` attribute holds the layer definition
    :return: the ``enabled`` flag of the enclosing extractor class
    """
    param = node.pb.resample_param

    # Index in this table is the numeric resample type from the layer definition.
    resample_modes = ("", 'nearest', 'linear', 'cubic', 'area')

    mapping_rule = merge_attrs(param, {
        'antialias': int(param.antialias),
        'height': param.height,
        'width': param.width,
        'type': resample_modes[param.type],
        'factor': param.factor,
        'fw': 'caffe',
    })

    # Interpolate expects the mode name under 'mode' rather than 'type'.
    mapping_rule['mode'] = mapping_rule.pop('type')
    mapping_rule['axes'] = int64_array([2, 3])

    Interpolate.update_node_stat(node, mapping_rule)
    return __class__.enabled
def extract(node):
    """
    Mark 'node' as a 'nearest' Interpolate over axes [1, 2] without antialiasing.

    :param node: node to update
    :return: the ``enabled`` flag of the enclosing extractor class
    """
    nearest_attrs = dict(mode='nearest', antialias=0, axes=int64_array([1, 2]))
    Interpolate.update_node_stat(node, nearest_attrs)
    return __class__.enabled
def extract(cls, node):
    """
    Mark 'node' as a 'linear' Interpolate over axes [1, 2].

    'align_corners' is taken from the boolean attribute of the same name in ``node.pb``.

    :param node: node to update
    :return: the ``enabled`` flag of the extractor class
    """
    align_corners = int(node.pb.attr['align_corners'].b)
    linear_attrs = dict(align_corners=align_corners, mode='linear', axes=int64_array([1, 2]))
    Interpolate.update_node_stat(node, linear_attrs)
    return cls.enabled
def replace_op(self, graph: Graph, node: Node):
    """
    Replace an upsampling node with an Interpolate or Upsample operation.

    Three cases are handled, based on the attributes of ``node.module``:

    * (bi)linear mode without align_corners: an 'opset4' Interpolate with explicit sizes,
      scales and axes inputs;
    * any other mode with ``size`` set: an 'opset1' Interpolate with a single input;
    * otherwise: an UpsampleOp configured from ``scale_factor``.

    :param graph: graph being transformed
    :param node: node to replace; ``node.module`` provides mode, size, scale_factor, align_corners
    :return: list with the id of the created node
    :raises Error: if neither ``size`` nor ``scale_factor`` is set
    """
    module = node.module
    mode = module.mode
    if mode == 'bilinear':
        mode = 'linear'
    align_corners = module.align_corners

    if mode == 'linear' and not align_corners:
        height = module.size[0]
        width = module.size[1]
        attrs = {
            'name': node.name,
            'version': 'opset4',
            'height': height,
            'width': width,
            'mode': mode,
            'axes': [2, 3],
            'pads_begin': [0, 0],
            'pads_end': [0, 0],
            'align_corners': module.align_corners,
            'shape_calculation_mode': 'sizes',
        }

        sizes = Const(graph, {'value': np.array([height, width])}).create_node()
        axes = Const(graph, {'value': np.array([2, 3])}).create_node()
        # With shape_calculation_mode='sizes' the scales input is a placeholder of ones.
        scales = Const(graph, {'value': np.array([1, 1], dtype=np.float32)}).create_node()
        interp = Interpolate(graph, attrs).create_node([node.in_node(0), sizes, scales, axes])
    elif module.size:
        attrs = {
            'name': node.name,
            'version': 'opset1',
            'height': module.size[0],
            'width': module.size[1],
            'mode': mode,
            'axes': [2, 3],
            'align_corners': module.align_corners,
        }
        interp = Interpolate(graph, attrs).create_node([node.in_node(0)])
    else:
        if not module.scale_factor:
            raise Error('No scale_factor found')
        # 'np.float' was only an alias of the builtin float and has been removed from NumPy,
        # so the builtin is used directly.
        scale = float(module.scale_factor)
        attrs = {
            'name': node.name,
            'height_scale': scale,
            'width_scale': scale,
            'mode': mode,
            'align_corners': module.align_corners,
        }
        interp = UpsampleOp(graph, attrs).create_node([node.in_node(0)])

    return [interp.id]
def extract(node):
    """
    Mark 'node' as a 'linear' Interpolate over axes [1, 2] with zero pads.

    'align_corners' is taken from the boolean attribute of the same name in ``node.pb``.

    :param node: node to update
    :return: the ``enabled`` flag of the enclosing extractor class
    """
    align_corners = int(node.pb.attr['align_corners'].b)
    linear_attrs = dict(
        pads_begin=0,
        pads_end=0,
        align_corners=align_corners,
        mode='linear',
        axes=int64_array([1, 2]),
    )
    Interpolate.update_node_stat(node, linear_attrs)
    return __class__.enabled
def extract(cls, node):
    """
    Extract attributes of an MXNet up-sampling layer.

    'nearest' sampling is mapped to Interpolate over axes [2, 3]; 'bilinear' sampling is mapped
    to a Deconvolution whose kernel, stride and pad are derived from the scale. Any other
    sample type leaves the node untouched.

    :param node: node whose ``symbol_dict`` holds the layer attributes
    :return: the ``enabled`` flag of the extractor class
    """
    attrs = get_mxnet_layer_attrs(node.symbol_dict)
    scale = attrs.int("scale", 1)
    num_filter = attrs.int("num_filter", 0)
    mode = attrs.str("sample_type", None)

    if mode == 'nearest':
        Interpolate.update_node_stat(node, {
            'factor': attrs.int("scale", 1),
            'mode': mode,
            'antialias': 0,
            'axes': int64_array([2, 3]),
        })
    elif mode == 'bilinear':
        # Bilinear UpSampling uses deconvolution algorithm under the hood.
        # For MXNet Bilinear UpSampling op just wrapper over Deconvolution op.
        # Inputs data:
        #     input1 - input data
        #     input2 - deconvolution weight
        kernel = 2 * scale - scale % 2
        stride = scale
        pad = math.ceil((scale - 1) / 2)
        num_group = num_filter

        deconv_attrs = {
            'op': __class__.op,
            'type': 'Deconvolution',
            'bias_addable': True,
            'bias_term': False,
            'pad': int64_array([[0, 0], [0, 0], [pad, pad], [pad, pad]]),
            'pad_spatial_shape': int64_array([[pad, pad], [pad, pad]]),
            'dilation': None,
            'output_spatial_shape': None,
            'output_shape': None,
            'stride': int64_array([1, 1, stride, stride]),
            'group': num_group,
            'output': num_filter,
            'kernel_spatial': int64_array([kernel, kernel]),
            'input_feature_channel': 0,
            'output_feature_channel': 1,
            'kernel_spatial_idx': None,
            'reshape_kernel': True,
            'spatial_dims': None,
            'channel_dims': int64_array([1]),
            'batch_dims': int64_array([0]),
            'layout': 'NCHW',
            'get_pad': DeconvFrontExtractor.get_pad,
        }
        Convolution.update_node_stat(node, deconv_attrs)

    return cls.enabled
def __call__(self, first: Node, second: Node) -> bool:
    """
    Decide whether the Interpolate node 'second' may be fused with the node 'first'.

    Nodes are fusable only if their 'mode', 'align_corners', 'antialias', 'pads_begin' and
    'pads_end' attributes agree (missing attributes fall back to the Interpolate defaults) and
    the axes of 'second' do not intersect the axes accumulated in ``self.accumulated_axes``.

    :param first: the earlier node of the candidate pair
    :param second: the later node of the candidate pair
    :return: True if the nodes can be fused, False otherwise
    """
    # A fused node with mismatching attributes would produce an incorrect result.
    default_attrs = Interpolate(graph=first.graph, attrs={}).attrs
    for attr_name in ('mode', 'align_corners', 'antialias', 'pads_begin', 'pads_end'):
        fallback = default_attrs[attr_name]
        if first.soft_get(attr_name, default=fallback) != second.soft_get(attr_name, default=fallback):
            return False

    axes_of_first = set(first.axes)
    axes_of_second = set(second.axes)
    self.accumulated_axes = self.accumulated_axes | axes_of_first

    # Interpolations along different axes do not affect each other, so the nodes can be fused
    # when the accumulated axes and the axes of 'second' are disjoint.
    if self.accumulated_axes & axes_of_second:
        self.accumulated_axes = set()
        return False
    return True
def extract(node):
    """
    Fill Interpolate attributes of 'node' from its ``interp_param`` protobuf section.

    The node becomes a 'linear' Interpolate over axes [2, 3] with align_corners enabled; pads are
    taken from ``pad_beg`` / ``pad_end`` of the layer parameters.

    :param node: node whose ``pb`` attribute holds the layer definition
    :return: the ``enabled`` flag of the enclosing extractor class
    """
    param = node.pb.interp_param

    mapping_rule = merge_attrs(param, {
        'height': param.height,
        'width': param.width,
        'zoom_factor': param.zoom_factor,
        'shrink_factor': param.shrink_factor,
    })
    mapping_rule.update(
        fw='caffe',
        mode='linear',
        axes=int64_array([2, 3]),
        pads_begin=param.pad_beg,
        pads_end=param.pad_end,
        align_corners=1,
    )

    Interpolate.update_node_stat(node, mapping_rule)
    return __class__.enabled
def make_interpolate_reshape_able(self, interpolate: Node, concat: Node):
    """
    Feed the target shape of 'interpolate' from another input of 'concat'.

    Input 1 of the Interpolate node is re-connected to ShapeOf -> Gather(interpolate axes)
    built on the first Concat source returned by ``get_non_interpolate_concat_sources``.
    Nothing is changed when the axes are unknown, negative or include the concat axis, or
    when no such source exists.

    :param interpolate: node of type 'Interpolate'
    :param concat: node of type 'Concat'
    :return: None
    """
    assert interpolate.soft_get('type') == 'Interpolate'
    assert concat.soft_get('type') == 'Concat'

    interp_axes = Interpolate.get_axes(interpolate)
    concat_axis = self.get_concat_axis(concat)

    # Axes must be known, non-negative and must not intersect.
    if concat_axis is None or interp_axes is None:
        return
    if np.any(interp_axes < 0) or concat_axis < 0 or concat_axis in interp_axes:
        return

    candidates = self.get_non_interpolate_concat_sources(concat)
    if len(candidates) == 0:
        # there is no Concat input to take the shape from
        return

    graph = interpolate.graph
    src = candidates[0]

    shape_name = src.node.soft_get('name', src.node.id) + '/Shape'
    shape = Shape(graph, {'name': shape_name}).create_node()
    shape.in_port(0).connect(src)

    gather_inputs = {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)}
    gather = create_op_with_const_inputs(graph, Gather, gather_inputs,
                                         {'name': shape.name + '/Gathered'},
                                         input_node=shape)

    interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
def make_interpolate_reshapeable(interpolate, concat):
    """
    Feed the target shape of 'interpolate' from a non-Interpolate input of 'concat'.

    Input 1 of the Interpolate node is re-connected to ShapeOf -> Gather(interpolate axes)
    built on the first connected Concat source whose producer is not an Interpolate. Nothing is
    changed when the concat axis is one of the interpolated axes or no such source exists.

    :param interpolate: node of type 'Interpolate'
    :param concat: node of type 'Concat'
    :return: None
    """
    assert interpolate.soft_get('type') == 'Interpolate'
    assert concat.soft_get('type') == 'Concat'

    output_shape = interpolate.out_port(0).data.get_shape()

    # Canonicalise axes against the Interpolate output shape before comparing them.
    interp_axes = []
    for axis in Interpolate.get_axes(interpolate):
        interp_axes.append(get_canonical_axis_index(output_shape, axis))
    concat_axis = get_canonical_axis_index(output_shape, concat.axis)
    if concat_axis in interp_axes:
        return

    connected_srcs = [port.get_source() for port in concat.in_ports().values() if not port.disconnected()]
    candidates = [src for src in connected_srcs if src.node.soft_get('type') != 'Interpolate']
    if not candidates:
        return

    graph = interpolate.graph
    src = candidates[0]

    shape_name = src.node.soft_get('name', src.node.id) + '/Shape'
    shape = Shape(graph, {'name': shape_name}).create_node()
    shape.in_port(0).connect(src)

    gather_inputs = {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)}
    gather = create_op_with_const_inputs(graph, Gather, gather_inputs,
                                         {'name': shape.name + '/Gathered'}, shape)

    interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
def replace_interpolate_pattern(graph: Graph, match: dict):
    """
    Replace the matched Split ... Concat pattern with a 'nearest' Interpolate along the split axis.

    The target size is computed in the graph as ShapeOf(input)[axis] * scale, where 'axis' is
    read from input 1 of the Split node and 'scale' comes from ``get_split_scale``.

    :param graph: graph being transformed
    :param match: mapping from pattern node names to matched nodes ('split', 'concat')
    :return: None
    """
    split = match['split']
    scale = int64_array([get_split_scale(split)])
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    base_name = split.name

    shape_node = Shape(graph, {'name': base_name + '/Shape_'}).create_node()
    scales_node = Const(graph, {'name': base_name + '/scales_', 'value': scale}).create_node()
    mul_node = Mul(graph, {'name': base_name + '/Mul_'}).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    # StridedSlice [axis : axis + 1] picks the single dimension being interpolated.
    slice_begin = Const(graph, {'name': base_name + '/slice_begin_',
                                'value': int64_array([axis])}).create_node()
    slice_end = Const(graph, {'name': base_name + '/slice_end_',
                              'value': int64_array([axis + 1])}).create_node()
    slice_attrs = {
        'name': base_name + '/StridedSlice_',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0]),
    }
    strided_slice_node = StridedSlice(graph, slice_attrs).create_node([shape_node, slice_begin, slice_end])
    strided_slice_node.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph, {'name': base_name + '/Interpolate_',
                                      'axes': int64_array([axis]),
                                      'mode': 'nearest'}).create_node()
    mul_node.out_port(0).connect(interp_node.in_port(1))

    # Consumers of Concat now read from Interpolate; the Split input feeds Interpolate and ShapeOf.
    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
def replace_interpolate_pattern(graph: Graph, match: dict):
    """
    Replace the matched Split ... Concat pattern with an 'opset4' nearest Interpolate.

    The created Interpolate has four inputs: data, sizes, scales and axes. Sizes are computed in
    the graph as int64(floor(float32(ShapeOf(input)[axis]) * scale)); 'axis' is read from input 1
    of the Split node and 'scale' comes from ``get_split_scale``.

    :param graph: graph being transformed
    :param match: mapping from pattern node names to matched nodes ('split', 'concat')
    :return: None
    """
    split = match['split']
    scale = np.array([get_split_scale(split)], dtype=np.float32)
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    base_name = split.name

    axis_node = Const(graph, {'name': base_name + '/axis', 'value': int64_array([axis])}).create_node()
    shape_node = Shape(graph, {'name': base_name + '/Shape'}).create_node()
    scales_node = Const(graph, {'name': base_name + '/scales', 'value': scale}).create_node()
    mul_node = Mul(graph, {'name': base_name + '/Mul'}).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    # StridedSlice [axis : axis + 1] picks the single dimension being interpolated.
    slice_bounds = {1: int64_array([axis]), 2: int64_array([axis + 1])}
    slice_attrs = {
        'name': base_name + '/StridedSlice',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0]),
    }
    strided_slice_node = create_op_with_const_inputs(graph, StridedSlice, slice_bounds, slice_attrs)
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    # The dimension is multiplied by a float32 scale, so it is cast to float first.
    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()
    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_attrs = {
        'name': base_name + '/Interpolate',
        'mode': 'nearest',
        'antialias': 0,
        'pads_begin': int64_array([0]),
        'pads_end': int64_array([0]),
        'coordinate_transformation_mode': 'half_pixel',
        'nearest_mode': 'round_prefer_floor',
        'cube_coeff': -0.75,
        'version': 'opset4',
        'shape_calculation_mode': 'scales',
        'in_ports_count': 4,
        'maybe_part_of_sequence': True,
    }
    interp_node = Interpolate(graph, interp_attrs).create_node()

    # sizes = int64(floor(dimension * scale))
    floor_node = Floor(graph, {'name': base_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()
    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    # Consumers of Concat now read from Interpolate; the Split input feeds Interpolate and ShapeOf.
    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
def _compare_attributes_of_interpolate1(self, first: Node, second: Node) -> bool:
    """
    Tell whether two Interpolate-1 nodes have identical attributes, 'axes' excluded.

    The attributes 'mode', 'align_corners', 'antialias', 'pads_begin' and 'pads_end' are compared;
    an attribute missing on a node is replaced by the default of a freshly created Interpolate op.

    :param first: the first of compared nodes
    :param second: the second of compared nodes
    :return: True if all compared attributes are equal, False otherwise
    """
    op = Interpolate(graph=first.graph, attrs={})
    compared = ('mode', 'align_corners', 'antialias', 'pads_begin', 'pads_end')
    has_difference = any(
        first.soft_get(attr, default=op.attrs[attr]) != second.soft_get(attr, default=op.attrs[attr])
        for attr in compared
    )
    return not has_difference
def make_interpolate_reshapeable(interpolate):
    """
    Compute the target shape of 'interpolate' from its own input shape inside the graph.

    Input 1 of the node is re-connected to ShapeOf(input) -> Gather(axes) -> Mul(multipliers),
    where multipliers are output_shape[axes] / input_shape[axes] taken from the current shapes.
    Nothing is changed unless one of the two shapes is an element-wise multiple of the other.

    :param interpolate: node of type 'Interpolate'
    :return: None
    """
    assert interpolate.soft_get('type') == 'Interpolate'

    axes = Interpolate.get_axes(interpolate)
    input_shape = interpolate.in_port(0).data.get_shape()
    output_shape = interpolate.out_port(0).data.get_shape()

    # Proceed only when the output divides evenly by the input, or the other way round.
    if not (np.all(np.remainder(output_shape, input_shape) == 0)
            or np.all(np.remainder(input_shape, output_shape) == 0)):
        return

    graph = interpolate.graph
    name = interpolate.soft_get('name', interpolate.id)

    shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
    shape.in_port(0).connect(interpolate.in_port(0).get_source())

    gather_inputs = {1: np.array(axes, dtype=np.int32), 2: int64_array(0)}
    gather = create_op_with_const_inputs(graph, Gather, gather_inputs,
                                         {'name': shape.name + '/Gathered'}, shape)

    multipliers = output_shape[axes] / input_shape[axes]
    mul = create_op_node_with_second_input(graph, Mul, multipliers,
                                           {'name': gather.name + '/Multiplied'}, gather)

    interpolate.in_port(1).get_connection().set_source(mul.out_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Remove the Transpose nodes around an Interpolate and switch its axes from [1, 2] to [2, 3].

    The transformation is applied only when the Interpolate axes are exactly [1, 2]. For 'opset1'
    the axes are stored in the node attribute; for 'opset4' they are the value of input port 3.

    :param graph: graph being transformed
    :param match: mapping with the nodes 'interpolate', 'transpose_1' and 'transpose_2'
    :return: None
    """
    interpolate = match['interpolate']
    transpose_1 = match['transpose_1']
    transpose_2 = match['transpose_2']

    axes = Interpolate.get_axes(interpolate)
    if axes is None or not np.array_equal(axes, int64_array([1, 2])):
        return

    # The Transpose layers are removed, so the axes have to be updated for NCHW layout.
    opset = interpolate.get_opset()
    assert opset in ['opset1', 'opset4'], \
        'Interpolate node with name {} has unsupported opset'.format(interpolate.soft_get('name', interpolate.id))

    nchw_axes = int64_array([2, 3])
    if opset == 'opset1':
        interpolate.axes = nchw_axes
    else:
        interpolate.in_port(3).data.set_value(nchw_axes)

    # Bypass both Transposes and delete them.
    transpose_1.in_port(0).get_connection().set_destination(interpolate.in_port(0))
    transpose_2.out_port(0).get_connection().set_source(interpolate.out_port(0))
    graph.remove_nodes_from([transpose_1.id, transpose_2.id])
def replace_op(self, graph: Graph, node: Node):
    """
    Replace an upsampling node with an Interpolate or Upsample operation.

    Three cases are handled, based on the attributes of ``node.module``:

    * any '*linear' mode (e.g. bilinear, trilinear): an 'opset4' Interpolate over axes
      2 .. dims-1 with explicit sizes, scales and axes inputs;
    * any other mode with ``size`` set: an 'opset1' Interpolate with a single input;
    * otherwise: an UpsampleOp configured from ``scale_factor``.

    :param graph: graph being transformed
    :param node: node to replace; ``node.module`` provides mode, size, scale_factor,
                 align_corners and dims
    :return: list with the id of the created node
    :raises Error: if neither ``size`` nor ``scale_factor`` is set in the non-linear case
    """
    module = node.module
    mode = module.mode
    if mode.endswith('linear'):  # like bilinear or trilinear
        mode = 'linear'
    align_corners = module.align_corners

    if mode == 'linear':
        has_size = module.size is not None
        # -1 marks an unknown size when the output is defined through scales.
        height = module.size[0] if has_size else -1
        width = module.size[1] if has_size else -1
        dims = module.dims
        axes = np.arange(2, dims)
        pads = np.zeros(dims, dtype=np.int32)
        # NOTE(review): assumes scale_factor is numeric here; a None value would fail in astype.
        scales = np.repeat(module.scale_factor, dims - 2).astype(np.float32)
        attrs = {
            'name': node.name,
            'version': 'opset4',
            'height': height,
            'width': width,
            'mode': mode,
            'axes': axes,
            'pads_begin': pads,
            'pads_end': pads,
            'coordinate_transformation_mode': 'align_corners' if align_corners else 'half_pixel',
            'shape_calculation_mode': 'sizes' if has_size else 'scales',
        }

        sizes = Const(graph, {'value': np.array([height, width])}).create_node()
        axes = Const(graph, {'value': axes}).create_node()
        scales = Const(graph, {'value': scales}).create_node()
        interp = Interpolate(graph, attrs).create_node([node.in_node(0), sizes, scales, axes])
    elif module.size:
        attrs = {
            'name': node.name,
            'version': 'opset1',
            'height': module.size[0],
            'width': module.size[1],
            'mode': mode,
            'axes': [2, 3],
            'align_corners': module.align_corners,
        }
        interp = Interpolate(graph, attrs).create_node([node.in_node(0)])
    else:
        if not module.scale_factor:
            raise Error('No scale_factor found')
        # 'np.float' was only an alias of the builtin float and has been removed from NumPy,
        # so the builtin is used directly.
        scale = float(module.scale_factor)
        attrs = {
            'name': node.name,
            'height_scale': scale,
            'width_scale': scale,
            'mode': mode,
            'align_corners': module.align_corners,
        }
        interp = UpsampleOp(graph, attrs).create_node([node.in_node(0)])

    return [interp.id]
def replace_pattern(self, graph: Graph, match: dict):
    """
    Replace the matched Unsqueeze -> Tile -> Reshape pattern with an 'opset4' nearest Interpolate.

    The unsqueeze index 'd_idx' is read from input 1 of Unsqueeze; the scale is element 'd_idx'
    of the Tile repeats and the interpolated axis is 'd_idx - 1'. The pattern is skipped when
    either constant is unavailable or the Unsqueeze input is not 4D or 5D. The new node takes
    over the name of the Reshape node.

    :param graph: graph being transformed
    :param match: mapping with the nodes 'unsqueeze', 'tile' and 'reshape'
    :return: None
    """
    unsqueeze_node = match['unsqueeze']
    base_name = unsqueeze_node.name

    unsqueeze_dim_src = unsqueeze_node.in_port(1).get_connection().get_source().node
    if not unsqueeze_dim_src.has_valid('value'):
        return
    d_idx = int(unsqueeze_dim_src.value)

    tile_repeats_src = match['tile'].in_port(1).get_connection().get_source().node
    if not tile_repeats_src.has_valid('value'):
        return

    unsqueeze_input_shape = unsqueeze_node.in_port(0).data.get_shape()
    if len(unsqueeze_input_shape) not in {4, 5}:
        return

    scale = float32_array([tile_repeats_src.value[d_idx]])
    axis = d_idx - 1

    axis_node = Const(graph, {'name': base_name + '/axis', 'value': int64_array([axis])}).create_node()
    shape_node = Shape(graph, {'name': base_name + '/Shape'}).create_node()
    scales_node = Const(graph, {'name': base_name + '/scales', 'value': scale}).create_node()
    mul_node = Mul(graph, {'name': base_name + '/Mul'}).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    # StridedSlice [axis : axis + 1] picks the single dimension being interpolated.
    slice_begin = Const(graph, {'name': base_name + '/slice_begin',
                                'value': int64_array([axis])}).create_node()
    slice_end = Const(graph, {'name': base_name + '/slice_end',
                              'value': int64_array([axis + 1])}).create_node()
    slice_attrs = {
        'name': base_name + '/StridedSlice',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0]),
    }
    strided_slice_node = StridedSlice(graph, slice_attrs).create_node()
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))
    slice_begin.out_port(0).connect(strided_slice_node.in_port(1))
    slice_end.out_port(0).connect(strided_slice_node.in_port(2))

    # The dimension is multiplied by a float32 scale, so it is cast to float first.
    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()
    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_attrs = {
        'mode': 'nearest',
        'antialias': 0,
        'pads_begin': int64_array([0]),
        'pads_end': int64_array([0]),
        'coordinate_transformation_mode': 'half_pixel',
        'nearest_mode': 'round_prefer_floor',
        'cube_coeff': -0.75,
        'version': 'opset4',
        'shape_calculation_mode': 'scales',
        'in_ports_count': 4,
        'maybe_part_of_sequence': True,
    }
    interp_node = Interpolate(graph, interp_attrs).create_node()

    # sizes = int64(floor(dimension * scale))
    floor_node = Floor(graph, {'name': base_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()
    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    # Consumers of Reshape now read from Interpolate, which also inherits the Reshape name.
    reshape_node = match['reshape']
    reshape_node.out_port(0).get_connection().set_source(interp_node.out_port(0))
    reshape_name = reshape_node.soft_get('name', reshape_node.id)
    rename_nodes([(reshape_node, reshape_name + '/delete'), (interp_node, reshape_name)])

    # The former Unsqueeze input feeds both Interpolate and ShapeOf.
    unsqueeze_connection = unsqueeze_node.in_port(0).get_connection()
    before_unsqueeze = unsqueeze_connection.get_source().node
    unsqueeze_connection.set_destination(interp_node.in_port(0))
    before_unsqueeze.out_port(0).connect(shape_node.in_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    """
    Replace the matched Unsqueeze -> Tile -> Reshape pattern with a 'nearest' Interpolate.

    The unsqueeze index 'd_idx' is read from input 1 of Unsqueeze; the scale is element 'd_idx'
    of the Tile repeats and the interpolated axis is 'd_idx - 1'. The target size is computed in
    the graph as ShapeOf(input)[axis] * scale. The pattern is skipped when either constant is
    unavailable or the Unsqueeze input is not 4D or 5D.

    :param graph: graph being transformed
    :param match: mapping with the nodes 'unsqueeze', 'tile' and 'reshape'
    :return: None
    """
    unsqueeze_node = match['unsqueeze']
    base_name = unsqueeze_node.name

    unsqueeze_dim_src = unsqueeze_node.in_port(1).get_connection().get_source().node
    if not unsqueeze_dim_src.has_valid('value'):
        return
    d_idx = int(unsqueeze_dim_src.value)

    tile_repeats_src = match['tile'].in_port(1).get_connection().get_source().node
    if not tile_repeats_src.has_valid('value'):
        return

    unsqueeze_input_shape = unsqueeze_node.in_port(0).data.get_shape()
    if len(unsqueeze_input_shape) not in {4, 5}:
        return

    scale = int64_array([tile_repeats_src.value[d_idx]])
    axis = d_idx - 1

    shape_node = Shape(graph, {'name': base_name + '/Shape_'}).create_node()
    scales_node = Const(graph, {'name': base_name + '/scales_', 'value': scale}).create_node()
    mul_node = Mul(graph, {'name': base_name + '/Mul_'}).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    # StridedSlice [axis : axis + 1] picks the single dimension being interpolated.
    slice_begin = Const(graph, {'name': base_name + '/slice_begin_',
                                'value': int64_array([axis])}).create_node()
    slice_end = Const(graph, {'name': base_name + '/slice_end_',
                              'value': int64_array([axis + 1])}).create_node()
    slice_attrs = {
        'name': base_name + '/StridedSlice_',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0]),
    }
    strided_slice_node = StridedSlice(graph, slice_attrs).create_node()
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))
    slice_begin.out_port(0).connect(strided_slice_node.in_port(1))
    slice_end.out_port(0).connect(strided_slice_node.in_port(2))
    strided_slice_node.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph, {'name': base_name + '/Interpolate_',
                                      'axes': int64_array([axis]),
                                      'mode': 'nearest'}).create_node()
    mul_node.out_port(0).connect(interp_node.in_port(1))

    # Consumers of Reshape now read from Interpolate.
    match['reshape'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    # The former Unsqueeze input feeds both Interpolate and ShapeOf.
    unsqueeze_connection = unsqueeze_node.in_port(0).get_connection()
    before_unsqueeze = unsqueeze_connection.get_source().node
    unsqueeze_connection.set_destination(interp_node.in_port(0))
    before_unsqueeze.out_port(0).connect(shape_node.in_port(0))
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """
    Replace an Upsample node with an Interpolate node whose target shape is computed in the graph.

    Scales come either from the constant on input 1 (which must have shape (4,) with the first
    two values close to 1) or from the 'height_scale' / 'width_scale' attributes. The target
    shape is StridedSlice(ShapeOf(input)) * [height_scale, width_scale]. Inputs that are not
    4D or 5D, non-constant scales, and scales that differ from 1 in the first two positions
    leave the graph untouched.

    :param graph: graph being transformed
    :param match: mapping with the node 'upsample'
    :return: None
    """
    log.debug('UpsampleToResample is triggered')
    upsample = match['upsample']

    input_shape = upsample.in_port(0).data.get_shape()
    rank = len(input_shape)
    if rank not in [4, 5]:
        log.warning('The input shape is not 4D or 5D for op {}'.format(upsample.soft_get('name')))
        return

    if len(upsample.in_nodes()) == 2:
        if upsample.in_node(1).value is None:
            return
        scales = upsample.in_node(1).value
        assert scales.shape == (4,)
        leading_scales_are_one = math.isclose(scales[0], 1, rel_tol=1e-5) and \
            math.isclose(scales[1], 1, rel_tol=1e-5)
        if not leading_scales_are_one:
            return
        height_scale = scales[2]
        width_scale = scales[3]
    else:
        height_scale = upsample['height_scale']
        width_scale = upsample['width_scale']

    # The scales are baked into a constant below, so the original scales input is detached.
    if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
        upsample.in_port(1).disconnect()

    factor = Const(graph, {'value': np.array([height_scale, width_scale])}).create_node()
    shape = Shape(graph, {'name': upsample.name + '/0_port'}).create_node()

    layout = graph.graph['layout']

    # The slice starts at the first spatial dimension: height for 4D, depth for 5D.
    if rank == 4:
        first_spatial_dim = get_height_dim(layout, rank)
    else:
        first_spatial_dim = get_depth_dim(layout, rank)
    begin = Const(graph, {'value': int64_array([first_spatial_dim])}).create_node()
    end = Const(graph, {'value': int64_array([get_width_dim(layout, rank) + 1])}).create_node()
    stride = Const(graph, {'value': int64_array([1])}).create_node()

    ss_attrs = {
        'name': upsample.name + '/ss_0_port',
        'begin_mask': np.array([1]),
        'end_mask': np.array([0]),
        'new_axis_mask': np.array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0]),
    }
    ss = StridedSlice(graph, ss_attrs).create_node()
    mul = Mul(graph, {'name': upsample.name + '/factor_mul_'}).create_node()

    source = upsample.in_port(0).get_connection().get_source()
    source.connect(shape.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))
    begin.out_port(0).connect(ss.in_port(1))
    end.out_port(0).connect(ss.in_port(2))
    stride.out_port(0).connect(ss.in_port(3))
    ss.out_port(0).connect(mul.in_port(0))
    factor.out_port(0).connect(mul.in_port(1))

    # Create Interpolate operation
    if rank == 4:
        spatial_dims = [get_height_dim(layout, rank),
                        get_width_dim(layout, rank)]
    else:
        spatial_dims = [get_depth_dim(layout, rank),
                        get_height_dim(layout, rank),
                        get_width_dim(layout, rank)]
    axes = int64_array(spatial_dims)

    resample_op = Interpolate(graph, {
        'name': 'Interpolate/{}'.format(upsample.name),
        'axes': axes,
        'mode': upsample.attrs()['mode'],
        'antialias': 0,
        'convert_to_resample': True,
    }).create_node()

    upsample.add_input_port(1, skip_if_exist=True)
    assert upsample.in_port(1).disconnected()
    mul.out_port(0).connect(resample_op.in_port(1))

    # Re-route the data input and the consumers from Upsample to the new node.
    upsample.in_port(0).get_connection().set_destination(resample_op.in_port(0))
    upsample.out_port(0).get_connection().set_source(resample_op.out_port(0))
def replace_sequence(seq: List[Node], graph: Graph):
    """
    Merge a sequence of consecutive Interpolate layers into one Interpolate layer.

    Nothing is done for sequences shorter than two nodes or when the nodes use different modes.
    The merged node gets the union of the axes with their output sizes (a later node overrides
    an earlier one for the same axis) and copies the remaining attributes from the first node.

    :param seq: sequence of Interpolate layers
    :param graph: graph to which nodes of seq belong
    :return: Nothing
    """
    if len(seq) < 2:
        return

    if len({n.mode for n in seq}) != 1:
        return

    # Map every interpolated axis to the output size for this axis.
    size_by_axis = {}
    for interp in seq:
        sizes_of_node = interp.in_port(1).get_connection().get_source().node.value
        for axis, size in zip(interp.axes, sizes_of_node):
            size_by_axis[axis] = size

    axis_to_size = sorted(size_by_axis.items(), key=lambda x: x[0])
    axes_of_node = int64_array([z[0] for z in axis_to_size])
    sizes = int64_array([z[1] for z in axis_to_size])

    head = seq[0]
    tail = seq[-1]
    head_name = head.name

    merged_attrs = {
        'name': head_name + '/Interpolate_',
        'axes': axes_of_node,
        'mode': head.mode,
        'align_corners': head.soft_get('align_corners', default=0),
        'antialias': head.soft_get('antialias', default=0),
        'pads_begin': head.soft_get('pads_begin', default=0),
        'pads_end': head.soft_get('pads_end', default=0),
    }
    interp_node = Interpolate(graph, merged_attrs).create_node()

    scales_node = Const(graph, {'name': head_name + '/scales_', 'value': sizes}).create_node()
    scales_node.out_port(0).connect(interp_node.in_port(1))

    # The merged node takes the input of the first node and the consumers of the last one.
    head.in_port(0).get_connection().set_destination(interp_node.in_port(0))
    tail.out_port(0).get_connection().set_source(interp_node.out_port(0))
def replace_resize(graph: Graph, resize: Node):
    """
    Replace an ONNX Resize-10 node with a sub-graph built around an opset4 Interpolate node.

    The created Interpolate node receives:
      * port 0 - the data input of 'resize';
      * port 1 - 'sizes', computed in the graph as floor(ShapeOf(data) * (scales + 1.0e-5)) and then sliced;
      * port 2 - the 'scales' input of 'resize' (its input port 1), sliced the same way;
      * port 3 - the output of a Range node fed with the constants 2 and 1 and with Rank(data).
    Finally the Interpolate node takes over the name and the output connection of 'resize'.

    :param graph: graph to which 'resize' belongs
    :param resize: the Resize-10 node to be replaced
    :return: Nothing
    """
    node_name = resize.soft_get('name', resize.id)
    log.debug('Converting of ONNX Resize-10 to Interpolate-4 is triggered for node {}.'.format(node_name))

    def make_slice_from_2(suffix):
        # begin = 2, end = 0 with end_mask = 0, stride = 1.
        # NOTE(review): presumably this keeps everything from index 2 onwards (i.e. drops the two leading
        # entries) - confirm against the StridedSlice mask semantics.
        # Fresh arrays are built on every call so the two slices never share attribute objects.
        return create_op_with_const_inputs(
            graph, StridedSlice,
            {1: int64_array([2]), 2: int64_array([0]), 3: int64_array([1])},
            {
                'name': node_name + suffix,
                'begin_mask': int64_array([1]),
                'end_mask': int64_array([0]),
                'new_axis_mask': int64_array([0]),
                'shrink_axis_mask': int64_array([0]),
                'ellipsis_mask': int64_array([0]),
            })

    rank = Rank(graph, {'name': node_name + '/max_axes'}).create_node()
    axes_range = create_op_with_const_inputs(graph, Range,
                                             {0: int64_array(2), 2: int64_array(1)},
                                             {'name': node_name + '/axes'})
    sizes_slice = make_slice_from_2('/sizes_ss')
    scales_slice = make_slice_from_2('/scales_ss')

    rank.out_port(0).connect(axes_range.in_port(1))

    if resize.mode == 'linear':
        interpolate_mode = 'linear_onnx'
    else:
        interpolate_mode = 'nearest'

    interpolate = Interpolate(graph, {
        'version': 'opset4',
        'mode': interpolate_mode,
        'coordinate_transformation_mode': 'asymmetric',
        'cube_coeff': -0.75,
        'nearest_mode': 'simple',
        'pads_begin': int64_array([0]),
        'pads_end': int64_array([0]),
        'antialias': 0,
        'shape_calculation_mode': 'scales',
        'in_ports_count': 4,
    }).create_node()

    axes_range.out_port(0).connect(interpolate.in_port(3))
    shape_of = Shape(graph, {'name': node_name + '/ShapeOf'}).create_node()

    # Computing 'sizes' as floor(input_shape * scales) can be off by one. E.g. for
    # scales = [1.0, 1.0, 1.33333, 2.0] and input_shape = [1, 3, 30, 200] we get
    # input_shape * scales = [1, 3, 39.9999, 400], so floor(...)[2] == 39 instead of 40.
    # Adding a small eps after the multiplication is not enough either: for scales[2] = 1.333333 and eps = 1.0e-5,
    # input_shape[2] * scales[2] + eps = 39.99991, which still floors to 39.
    # Hence eps is added to the scales first: sizes = floor(input_shape * (scales + eps)).
    eps_add = create_op_with_const_inputs(graph, Add,
                                          {1: float_array([1.0e-5])},
                                          {'name': node_name + '/Add'})

    float_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    shape_as_float = Cast(graph, {'dst_type': float_type}).create_node()
    shape_of.out_port(0).connect(shape_as_float.in_port(0))
    scaled_shape = Mul(graph, {'name': node_name + '/Mul'}).create_node([shape_as_float, eps_add])
    floored_shape = Floor(graph, {'name': node_name + '/Floor'}).create_node([scaled_shape])
    sizes_as_int = Cast(graph, {'dst_type': np.int64}).create_node([floored_shape])
    sizes_as_int.out_port(0).connect(sizes_slice.in_port(0))

    sizes_slice.out_port(0).connect(interpolate.in_port(1))
    scales_slice.out_port(0).connect(interpolate.in_port(2))

    # Move the data and scales inputs of 'resize' to the new sub-graph, then tap the same sources
    # for the shape / rank / eps computations.
    data_connection = resize.in_port(0).get_connection()
    data_connection.set_destination(interpolate.in_port(0))

    scales_connection = resize.in_port(1).get_connection()
    scales_connection.set_destination(scales_slice.in_port(0))

    data_connection.get_source().connect(shape_of.in_port(0))
    data_connection.get_source().connect(rank.in_port(0))
    scales_connection.get_source().connect(eps_add.in_port(0))

    rename_nodes([(resize, node_name + '/delete'), (interpolate, node_name)])
    resize.out_port(0).get_connection().set_source(interpolate.out_port(0))
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """
    Replace the matched Upsample node with an Interpolate node whose target sizes are computed in the graph.

    The transformation is skipped (the graph is left untouched) when:
      * the input is neither 4D nor 5D;
      * the scales input exists but has no constant value;
      * the first two scales are not (approximately) equal to 1;
      * the spatial scales are not (approximately) equal to each other;
      * the input is 5D but no depth scale could be determined.

    Otherwise the sub-graph ShapeOf -> StridedSlice -> Cast(float32) -> Mul(scales) -> Cast(int64) is created
    and connected to input port 1 of the new Interpolate node, which takes over the data input and the output
    of the Upsample node.

    :param graph: graph to operate on
    :param match: dictionary with the matched nodes; the key 'upsample' holds the Upsample node
    :return: Nothing
    """
    log.debug('UpsampleToResample is triggered')
    upsample = match['upsample']
    upsample_name = upsample.soft_get('name', upsample.id)
    input_shape = upsample.in_port(0).data.get_shape()
    input_shape_rank = len(input_shape)
    if input_shape_rank not in [4, 5]:
        # Report the node with the same id-fallback name that is used everywhere else in this function.
        log.warning('The input shape is not 4D or 5D for op {}'.format(upsample_name))
        return

    depth_scale = None
    if len(upsample.in_nodes()) == 2:
        scales = upsample.in_node(1).value
        if scales is None:
            return
        assert len(scales) in (4, 5), 'Supported scales rank is 4 or 5, but it is {} for node {}'.format(
            len(scales), upsample_name)
        if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
            return
        # The scales at indices 2, 3 (and 4, if present) are read as height, width (and depth). Because all of
        # them are required to be equal below, the exact assignment of names does not affect the result.
        height_scale = scales[2]
        width_scale = scales[3]
        if len(scales) == 5:
            depth_scale = scales[4]
    else:
        height_scale = upsample['height_scale']
        width_scale = upsample['width_scale']

    if not math.isclose(height_scale, width_scale, rel_tol=1e-5):
        log.debug('Width and height scales are not equal: {} vs {} for node {}'.format(
            width_scale, height_scale, upsample_name))
        return
    if depth_scale is not None and not math.isclose(height_scale, depth_scale, rel_tol=1e-5):
        log.debug('Depth and height scales are not equal: {} vs {} for node {}'.format(
            depth_scale, height_scale, upsample_name))
        return

    if input_shape_rank == 5 and depth_scale is None:
        # Without this guard a None would be packed into the Mul constant below and produce an
        # object-dtype array. Bail out before the graph is modified.
        log.warning('Depth scale is not defined for 5D input of op {}. Do not apply the transformation.'.format(
            upsample_name))
        return

    if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
        upsample.in_port(1).disconnect()

    shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()

    layout = graph.graph['layout']
    if input_shape_rank == 4:
        begin_value = int64_array([get_height_dim(layout, input_shape_rank)])
        factor_value = np.array([height_scale, width_scale])
    else:
        begin_value = int64_array([get_depth_dim(layout, input_shape_rank)])
        factor_value = np.array([depth_scale, height_scale, width_scale])

    # Slice of the input shape from the first spatial dimension up to and including the width dimension.
    ss = create_op_with_const_inputs(graph, StridedSlice,
                                     {1: begin_value,
                                      2: int64_array([get_width_dim(layout, input_shape_rank) + 1]),
                                      3: int64_array([1])},
                                     {'name': upsample_name + '/ss_0_port',
                                      'begin_mask': int64_array([1]),
                                      'end_mask': int64_array([1]),
                                      'new_axis_mask': int64_array([0]),
                                      'shrink_axis_mask': int64_array([0]),
                                      'ellipsis_mask': int64_array([0])})

    mul = create_op_node_with_second_input(graph, Mul, factor_value, {'name': upsample_name + '/factor_mul_'})

    source = upsample.in_port(0).get_connection().get_source()
    source.connect(shape.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))

    ss.out_port(0).connect(mul.in_port(0))

    # Create Interpolate operation
    if input_shape_rank == 4:
        axes = int64_array([get_height_dim(layout, input_shape_rank),
                            get_width_dim(layout, input_shape_rank)])
    else:
        axes = int64_array([get_depth_dim(layout, input_shape_rank),
                            get_height_dim(layout, input_shape_rank),
                            get_width_dim(layout, input_shape_rank)])

    resample_op = Interpolate(graph, dict(name=upsample_name + '/Interpolate',
                                          axes=axes, mode=upsample.attrs()['mode'],
                                          antialias=0, convert_to_resample=True)).create_node()

    upsample.add_input_port(1, skip_if_exist=True)
    assert upsample.in_port(1).disconnected()
    mul.out_port(0).connect(resample_op.in_port(1))

    upsample.in_port(0).get_connection().set_destination(resample_op.in_port(0))
    upsample.out_port(0).get_connection().set_source(resample_op.out_port(0))

    # The shape slice is an integer tensor while the scales are floating point: multiply in float32 and
    # convert the product back to int64 before it reaches Interpolate.
    convert_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    convert_to_int = Cast(graph, dict(dst_type=np.int64)).create_node()

    mul.in_port(0).get_connection().insert_node(convert_to_float)
    mul.out_port(0).get_connection().insert_node(convert_to_int)
def replace_resize(graph: Graph, resize: Node):
    """
    Replace an ONNX Resize-11 node with a sub-graph built around an opset4 Interpolate node.

    Only 4D and 5D inputs are handled; for other ranks a warning is logged and the graph is left untouched.
    The number of connected inputs of 'resize' selects how the missing Interpolate input is derived:
      * 3 inputs ('scales' given on port 2): sizes = floor(ShapeOf(data) * scales + 1.0e-5),
        'shape_calculation_mode' is 'scales';
      * 4 inputs ('sizes' given on port 3): scales = sizes / ShapeOf(data) + 1.0e-5,
        'shape_calculation_mode' is 'sizes'.
    Both 'sizes' and 'scales' are cut by a StridedSlice to the range [first spatial dim, width dim], and the
    same range of dimension indices is passed to Interpolate as 'axes'.

    :param graph: graph to which 'resize' belongs
    :param resize: the Resize-11 node to be replaced
    :return: Nothing
    """
    resize_name = resize.soft_get('name', resize.id)
    log.debug('Converting of ONNX Resize-11 to Interpolate-4 is triggered for node {}.'.format(resize_name))

    input_rank = len(resize.in_port(0).data.get_shape())
    if input_rank not in {4, 5}:
        log.warning('The input shape is not 4D or 5D for op with name {}'.format(resize_name))
        return

    num_of_inputs = sum(1 for port in resize.in_ports().values() if not port.disconnected())
    assert num_of_inputs in {3, 4}, \
        "Number of inputs of ONNXResize (with name {}) should be equal to 3 or 4".format(resize_name)

    assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
        'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize_name)

    layout = graph.graph['layout']

    # The interpolated range starts at height for 4D inputs and at depth for 5D inputs;
    # in both cases it ends right after the width dimension.
    if input_rank == 4:
        begin_dim = get_height_dim(layout, input_rank)
    else:
        begin_dim = get_depth_dim(layout, input_rank)
    end_dim = get_width_dim(layout, input_rank) + 1

    def make_spatial_slice(suffix):
        # Fresh arrays are built on every call so the two slices never share attribute objects.
        return create_op_with_const_inputs(
            graph, StridedSlice,
            {1: int64_array([begin_dim]), 2: int64_array([end_dim]), 3: int64_array([1])},
            {
                'name': resize_name + suffix,
                'begin_mask': int64_array([1]),
                'end_mask': int64_array([1]),
                'new_axis_mask': int64_array([0]),
                'shrink_axis_mask': int64_array([0]),
                'ellipsis_mask': int64_array([0]),
            })

    sizes_slice = make_spatial_slice('/StridedSlice_sizes')
    scales_slice = make_spatial_slice('/StridedSlice_scales')

    axes_const = Const(graph, {'name': resize_name + '/axis',
                               'value': int64_array(np.arange(begin_dim, end_dim))}).create_node()

    scales_given = num_of_inputs == 3
    shape_calculation_mode = 'scales' if scales_given else 'sizes'

    interpolate = Interpolate(graph, {
        'version': 'opset4',
        'mode': convert_mode(resize.mode),
        'coordinate_transformation_mode': resize.coordinate_transformation_mode,
        'cube_coeff': resize.cube_coeff,
        'nearest_mode': resize.nearest_mode,
        'pads_begin': int64_array([0]),
        'pads_end': int64_array([0]),
        'antialias': 0,
        'shape_calculation_mode': shape_calculation_mode,
        'in_ports_count': 4,
    }).create_node()

    axes_const.out_port(0).connect(interpolate.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    eps_add = create_op_with_const_inputs(graph, Add,
                                          {1: float_array([1.0e-5])},
                                          {'name': resize_name + '/Add'})

    float_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
    shape_as_float = Cast(graph, {'dst_type': float_type}).create_node()

    if scales_given:
        # sizes = floor(shape * scales + eps)
        product = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
        shape_of.out_port(0).connect(shape_as_float.in_port(0))
        shape_as_float.out_port(0).connect(product.in_port(0))

        sizes_as_int = Cast(graph, {'dst_type': np.int64}).create_node()
        floor = Floor(graph, {'name': resize_name + '/Floor'}).create_node()
        product.out_port(0).connect(eps_add.in_port(0))
        eps_add.out_port(0).connect(floor.in_port(0))
        floor.out_port(0).connect(sizes_as_int.in_port(0))
        sizes_as_int.out_port(0).connect(sizes_slice.in_port(0))
        sizes_slice.out_port(0).connect(interpolate.in_port(1))
        scales_slice.out_port(0).connect(interpolate.in_port(2))

        data_connection = resize.in_port(0).get_connection()
        data_connection.set_destination(interpolate.in_port(0))

        scales_connection = resize.in_port(2).get_connection()
        scales_connection.set_destination(scales_slice.in_port(0))

        data_connection.get_source().connect(shape_of.in_port(0))
        scales_connection.get_source().connect(product.in_port(1))
    else:
        # scales = sizes / shape + eps
        sizes_as_float = Cast(graph, {'dst_type': float_type}).create_node()
        quotient = Div(graph, {'name': resize_name + '/Div'}).create_node()
        sizes_as_float.out_port(0).connect(quotient.in_port(0))
        shape_as_float.out_port(0).connect(quotient.in_port(1))
        shape_of.out_port(0).connect(shape_as_float.in_port(0))
        quotient.out_port(0).connect(eps_add.in_port(0))
        eps_add.out_port(0).connect(scales_slice.in_port(0))
        scales_slice.out_port(0).connect(interpolate.in_port(2))
        sizes_slice.out_port(0).connect(interpolate.in_port(1))

        data_connection = resize.in_port(0).get_connection()
        data_connection.set_destination(interpolate.in_port(0))

        sizes_connection = resize.in_port(3).get_connection()
        sizes_connection.set_destination(sizes_slice.in_port(0))

        data_connection.get_source().connect(shape_of.in_port(0))
        sizes_connection.get_source().connect(sizes_as_float.in_port(0))

    # The Interpolate node inherits the name and the consumers of the original Resize node.
    rename_nodes([(resize, resize_name + '/delete'), (interpolate, resize_name)])
    resize.out_port(0).get_connection().set_source(interpolate.out_port(0))
def replace_sequence(seq: List[Node], graph: Graph):
    """
    Fuse a chain of consecutive Interpolate nodes into a single Interpolate node.

    Nothing is done when 'seq' is empty, holds a single node, or when the nodes do not all share one 'mode'.
    The opset of the first node decides how the fused node is assembled:
      * 'opset1': axes become the 'axes' attribute and the sizes are connected as a constant to port 1;
      * otherwise: sizes, scales and axes are connected as constants to ports 1, 2 and 3.
    The fused node inherits the attributes of the first node and the name of the last node of the chain.

    :param seq: sequence of Interpolate layers
    :param graph: graph to which nodes of seq belong
    :return: Nothing
    """
    if not seq or len(seq) == 1:
        return

    if len({node.mode for node in seq}) != 1:
        return

    head = seq[0]
    tail = seq[-1]
    opset = head.get_opset()

    if opset == 'opset1':
        # axis -> output size; when an axis occurs more than once the value from the later node wins
        size_by_axis = {}
        for node in seq:
            node_axes = Interpolate.get_axes(node)
            node_sizes = node.in_port(1).get_connection().get_source().data.get_value()
            size_by_axis.update(zip(node_axes, node_sizes))

        ordered_axes = sorted(size_by_axis)
        axes_of_node = int64_array(ordered_axes)
        sizes = int64_array([size_by_axis[axis] for axis in ordered_axes])
    else:
        # (axis, output size, scale) triples; unlike the opset1 branch, repeated axes are not collapsed here
        triples = []
        for node in seq:
            node_axes = Interpolate.get_axes(node)
            node_sizes = node.in_port(1).get_connection().get_source().data.get_value()
            node_scales = node.in_port(2).get_connection().get_source().data.get_value()
            triples.extend(zip(node_axes, node_sizes, node_scales))

        triples.sort(key=lambda triple: triple[0])
        axes_of_node = int64_array([axis for axis, _, _ in triples])
        sizes = int64_array([size for _, size, _ in triples])
        scales = np.array([scale for _, _, scale in triples])

    tail_name = tail.soft_get('name', tail.id)
    attributes = get_interpolate_attributes(head)

    if opset == 'opset1':
        attributes['axes'] = axes_of_node
        const_inputs = {1: sizes}
    else:
        attributes['in_ports_count'] = 4
        const_inputs = {1: sizes, 2: scales, 3: axes_of_node}

    fused = create_op_with_const_inputs(graph, Interpolate, const_inputs, attributes)

    # Feed the fused node from the producer of the first node and hand the consumers
    # of the last node over to the fused node.
    head.in_port(0).get_connection().set_destination(fused.in_port(0))
    tail.out_port(0).get_connection().set_source(fused.out_port(0))

    rename_nodes([(tail, tail_name + '/delete_'), (fused, tail_name)])
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """
    Replace the matched Upsample node with an opset4 Interpolate node.

    The transformation is skipped (the graph is left untouched) when:
      * the input is neither 4D nor 5D;
      * the scales input exists but has no constant value;
      * the first two scales are not (approximately) equal to 1;
      * the input is 5D but no depth scale could be determined.

    Otherwise the Interpolate node gets:
      * port 0 - the data input of the Upsample node;
      * port 1 - sizes computed as ShapeOf -> StridedSlice -> Cast(float32) -> Mul(scales) -> Cast(int64);
      * port 2 - a constant with the spatial scales;
      * port 3 - a constant with the spatial axes for the current layout.
    The Interpolate node takes over the name and the output of the Upsample node.

    :param graph: graph to operate on
    :param match: dictionary with the matched nodes; the key 'upsample' holds the Upsample node
    :return: Nothing
    """
    log.debug('UpsampleToResample is triggered')
    upsample = match['upsample']
    upsample_name = upsample.soft_get('name', upsample.id)
    input_shape = upsample.in_port(0).data.get_shape()
    input_shape_rank = len(input_shape)
    if input_shape_rank not in [4, 5]:
        # Report the node with the same id-fallback name that is used everywhere else in this function.
        log.warning('The input shape is not 4D or 5D for op {}'.format(upsample_name))
        return

    depth_scale = None
    layout = graph.graph['layout']

    if len(upsample.in_nodes()) == 2:
        scales = upsample.in_node(1).value
        if scales is None:
            return
        assert len(scales) in (4, 5), 'Supported scales rank is 4 or 5, but it is {} for node {}'.format(
            len(scales), upsample_name)
        if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
            return
        height_scale = scales[get_height_dim(layout, input_shape_rank)]
        width_scale = scales[get_width_dim(layout, input_shape_rank)]
        if len(scales) == 5:
            depth_scale = scales[get_depth_dim(layout, input_shape_rank)]
    else:
        height_scale = upsample['height_scale']
        width_scale = upsample['width_scale']

    if input_shape_rank == 5 and depth_scale is None:
        # Without this guard a None would be packed into the float32 scales constant below.
        # Bail out before the graph is modified.
        log.warning('Depth scale is not defined for 5D input of op {}. Do not apply the transformation.'.format(
            upsample_name))
        return

    if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
        upsample.in_port(1).disconnect()

    shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()

    if input_shape_rank == 4:
        begin_value = int64_array([get_height_dim(layout, input_shape_rank)])
        factor_value = float32_array([height_scale, width_scale])
    else:
        begin_value = int64_array([get_depth_dim(layout, input_shape_rank)])
        factor_value = float32_array([depth_scale, height_scale, width_scale])

    # Slice of the input shape from the first spatial dimension up to and including the width dimension.
    ss = create_op_with_const_inputs(graph, StridedSlice,
                                     {1: begin_value,
                                      2: int64_array([get_width_dim(layout, input_shape_rank) + 1]),
                                      3: int64_array([1])},
                                     {'name': upsample_name + '/ss_0_port',
                                      'begin_mask': int64_array([1]),
                                      'end_mask': int64_array([1]),
                                      'new_axis_mask': int64_array([0]),
                                      'shrink_axis_mask': int64_array([0]),
                                      'ellipsis_mask': int64_array([0])})

    mul = create_op_node_with_second_input(graph, Mul, factor_value, {'name': upsample_name + '/factor_mul'})

    source = upsample.in_port(0).get_connection().get_source()
    source.connect(shape.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))

    ss.out_port(0).connect(mul.in_port(0))

    # Create Interpolate operation
    if input_shape_rank == 4:
        axes = int64_array([get_height_dim(layout, input_shape_rank),
                            get_width_dim(layout, input_shape_rank)])
    else:
        axes = int64_array([get_depth_dim(layout, input_shape_rank),
                            get_height_dim(layout, input_shape_rank),
                            get_width_dim(layout, input_shape_rank)])

    axes_node = Const(graph, {'name': upsample_name + '/axis', 'value': axes}).create_node()

    interpolate = Interpolate(graph, {'mode': upsample.attrs()['mode'],
                                      'antialias': 0,
                                      'pads_begin': int64_array([0]),
                                      'pads_end': int64_array([0]),
                                      'coordinate_transformation_mode': 'half_pixel',
                                      'nearest_mode': 'round_prefer_floor',
                                      'cube_coeff': -0.75,
                                      'shape_calculation_mode': 'scales',
                                      'version': 'opset4',
                                      'in_ports_count': 4}).create_node()

    upsample.add_input_port(1, skip_if_exist=True)
    assert upsample.in_port(1).disconnected()
    mul.out_port(0).connect(interpolate.in_port(1))
    axes_node.out_port(0).connect(interpolate.in_port(3))

    scales_node = Const(graph, {'name': upsample_name + '/scales', 'value': factor_value}).create_node()
    scales_node.out_port(0).connect(interpolate.in_port(2))

    upsample.in_port(0).get_connection().set_destination(interpolate.in_port(0))
    upsample.out_port(0).get_connection().set_source(interpolate.out_port(0))

    # The Interpolate node inherits the name of the original Upsample node.
    rename_nodes([(upsample, upsample_name + '/delete'), (interpolate, upsample_name)])

    # The shape slice is an integer tensor while the scales are float32: multiply in float32 and
    # convert the product back to int64 before it reaches Interpolate.
    convert_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    convert_to_int = Cast(graph, dict(dst_type=np.int64)).create_node()

    mul.in_port(0).get_connection().insert_node(convert_to_float)
    mul.out_port(0).get_connection().insert_node(convert_to_int)
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace the matched NearestNeighborUpsampling sub-graph with an opset4 Interpolate node in 'nearest' mode.

    The input height and width are read from inputs 1 and 3 of the 'pack_1' node, the scales from the shape of
    the 'mul_const' node (dimensions -4 and -2). If any of these cannot be obtained, a warning is logged and the
    graph is left untouched. Otherwise the Interpolate node replaces 'reshape_2', receives constant sizes,
    scales and axes on ports 1, 2 and 3, and all matched nodes except 'op' are removed from the graph.

    :param graph: graph to operate on
    :param match: dictionary with the matched nodes
    :return: Nothing
    """
    log.debug('Matched NearestNeighborUpsampling pattern: {}'.format([node.id for node in match.values()]))

    try:
        pack = match['pack_1']
        input_height = pack.in_node(1).value.item()
        input_width = pack.in_node(3).value.item()

        scale_shape = match['mul_const'].shape
        height_scale = scale_shape[-4]
        width_scale = scale_shape[-2]
    except Exception:
        log.warning('Failed to determine scaling parameters from the topology. Do not apply pattern.')
        return

    reshape2_name = match['reshape_2'].name
    interpolate_attrs = {
        'mode': 'nearest',
        'antialias': 0,
        'pads_begin': int64_array([0]),
        'pads_end': int64_array([0]),
        'coordinate_transformation_mode': 'half_pixel',
        'nearest_mode': 'round_prefer_floor',
        'cube_coeff': -0.75,
        'version': 'opset4',
        'name': reshape2_name + '/Resample',
        'shape_calculation_mode': 'scales',
        'in_ports_count': 4,
    }
    resample_node = Interpolate(graph, interpolate_attrs).create_node([match['op']])

    # Spatial axes depend on the layout of the graph.
    if graph.graph['layout'] == 'NCHW':
        spatial_axes = int64_array([2, 3])
    else:
        spatial_axes = int64_array([1, 2])
    axes_node = Const(graph, {'name': resample_node.name + '/axes', 'value': spatial_axes}).create_node()

    target_shape = np.array([input_height * height_scale, input_width * width_scale])
    sizes_node = Const(graph, {'value': target_shape,
                               'name': resample_node.name + '/target_shape'}).create_node()

    scale_values = np.array([height_scale, width_scale], dtype=np.float32)
    scales_node = Const(graph, {'value': scale_values,
                                'name': resample_node.name + '/scales'}).create_node()

    match['reshape_2'].replace_node(resample_node)

    resample_node.add_input_port(1, skip_if_exist=True)
    assert resample_node.in_port(1).disconnected()
    sizes_node.out_port(0).connect(resample_node.in_port(1))
    scales_node.out_port(0).connect(resample_node.in_port(2))
    axes_node.out_port(0).connect(resample_node.in_port(3))

    # Everything that was matched, except the producer node 'op', is no longer needed.
    kept_id = match['op'].id
    graph.remove_nodes_from([node.id for node in match.values() if node.id != kept_id])