def is_spatial_squeeze(layout: str, input_shape: np.ndarray, squeeze_dims: np.ndarray):
    """
    Checks that the squeeze operation removes all spatial dimensions.

    The check succeeds only when:
      * the input is 4D or 5D,
      * every spatial dimension of ``input_shape`` (height and width, plus depth
        for 5D inputs, located according to ``layout``) is equal to 1, and
      * the number of squeezed dimensions equals the number of spatial dimensions.

    Note that only the *count* of ``squeeze_dims`` is compared with the spatial
    dimensions, not the individual indices.

    :param layout: graph layout.
    :param input_shape: numpy array with input shape.
    :param squeeze_dims: numpy array with dims to squeeze.
    :return: True if the squeeze is a spatial squeeze, False otherwise.
    """
    rank = len(input_shape)
    if rank not in (4, 5):
        return False

    spatial_dims = [get_height_dim(layout, rank), get_width_dim(layout, rank)]
    if rank == 5:
        spatial_dims.append(get_depth_dim(layout, rank))

    if any(input_shape[dim] != 1 for dim in spatial_dims) or len(squeeze_dims) != len(spatial_dims):
        log.debug('The reshape from "{}" with squeezed dims "{}" is not a spatial squeeze'.format(input_shape,
                                                                                               squeeze_dims))
        return False

    # Success path: the message must not claim the opposite of the returned value.
    log.debug('The reshape from "{}" with squeezed dims "{}" is a spatial squeeze'.format(input_shape, squeeze_dims))
    return True
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """
    Replaces a matched Upsample node with an Interpolate node.

    The Interpolate target spatial shape is computed inside the graph as
    ``StridedSlice(ShapeOf(input), begin=first spatial dim, end=width dim + 1) * factor``
    and is fed to Interpolate input port 1. The spatial axes are passed via the
    ``axes`` attribute.

    4D inputs use the height and width scales; 5D inputs additionally require a
    depth scale. When the scales come from the second input they must be a
    constant with one value per input dimension, ordered N, C, [D,] H, W, with
    the N and C scales equal to 1; otherwise the node is left untouched.

    :param graph: graph being transformed.
    :param match: matched nodes; ``match['upsample']`` is the Upsample node.
    """
    log.debug('UpsampleToResample is triggered')
    upsample = match['upsample']
    upsample_name = upsample.soft_get('name', upsample.id)
    input_shape = upsample.in_port(0).data.get_shape()
    input_shape_rank = len(input_shape)
    if input_shape_rank not in [4, 5]:
        log.warning('The input shape is not 4D or 5D for op {}'.format(
            upsample.soft_get('name')))
        return

    depth_scale = None
    if len(upsample.in_nodes()) == 2:
        if upsample.in_node(1).value is None:
            return
        scales = upsample.in_node(1).value
        # One scale per input dimension is required so that 5D inputs get a depth scale.
        assert len(scales) == input_shape_rank, \
            'Expected {} scales, but got {} for node {}'.format(input_shape_rank, len(scales), upsample_name)
        if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
            return
        if input_shape_rank == 5:
            depth_scale, height_scale, width_scale = scales[2], scales[3], scales[4]
        else:
            height_scale, width_scale = scales[2], scales[3]
    else:
        height_scale = upsample['height_scale']
        width_scale = upsample['width_scale']

    if input_shape_rank == 5 and depth_scale is None:
        # Without a depth scale the factor would hold 2 values while the sliced
        # shape holds 3 (D, H, W), which would produce an invalid Mul.
        log.warning('The depth scale is not defined for 5D op {}'.format(upsample_name))
        return

    if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
        upsample.in_port(1).disconnect()

    layout = graph.graph['layout']
    if input_shape_rank == 4:
        factor_value = np.array([height_scale, width_scale])
        begin_value = int64_array([get_height_dim(layout, input_shape_rank)])
        axes = int64_array([
            get_height_dim(layout, input_shape_rank),
            get_width_dim(layout, input_shape_rank)
        ])
    else:
        factor_value = np.array([depth_scale, height_scale, width_scale])
        begin_value = int64_array([get_depth_dim(layout, input_shape_rank)])
        axes = int64_array([
            get_depth_dim(layout, input_shape_rank),
            get_height_dim(layout, input_shape_rank),
            get_width_dim(layout, input_shape_rank)
        ])

    factor = Const(graph, {'value': factor_value}).create_node()
    shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()
    begin = Const(graph, {'value': begin_value}).create_node()
    end = Const(graph, {
        'value': int64_array([get_width_dim(layout, input_shape_rank) + 1])
    }).create_node()
    stride = Const(graph, {'value': int64_array([1])}).create_node()
    # begin_mask and end_mask are both 1 (as in the other Upsample/Resize
    # conversions) so that both the begin and the end values bound the slice.
    # NOTE(review): assumes the MO convention where mask value 1 means "use the value".
    ss = StridedSlice(
        graph, {
            'name': upsample_name + '/ss_0_port',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        }).create_node()
    mul = Mul(graph, {'name': upsample_name + '/factor_mul_'}).create_node()

    source = upsample.in_port(0).get_connection().get_source()
    source.connect(shape.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))
    begin.out_port(0).connect(ss.in_port(1))
    end.out_port(0).connect(ss.in_port(2))
    stride.out_port(0).connect(ss.in_port(3))
    ss.out_port(0).connect(mul.in_port(0))
    factor.out_port(0).connect(mul.in_port(1))

    # Create Interpolate operation
    resample_op = Interpolate(
        graph,
        dict(name='Interpolate/{}'.format(upsample_name),
             axes=axes,
             mode=upsample.attrs()['mode'],
             antialias=0,
             convert_to_resample=True)).create_node()

    upsample.add_input_port(1, skip_if_exist=True)
    assert upsample.in_port(1).disconnected()
    mul.out_port(0).connect(resample_op.in_port(1))

    upsample.in_port(0).get_connection().set_destination(resample_op.in_port(0))
    upsample.out_port(0).get_connection().set_source(resample_op.out_port(0))
def replace_resize(graph: Graph, resize: Node):
    """
    Replaces an ONNX Resize-11 node with an opset4 Interpolate node.

    Only 4D and 5D inputs are converted; for other ranks a warning is logged and
    the graph is left unchanged. The resize node must have 3 or 4 connected
    inputs and must not use the 'tf_crop_and_resize' coordinate transformation
    mode (both are enforced with asserts).

    The spatial axes are the contiguous range from the height dim (4D) or the
    depth dim (5D) up to and including the width dim, taken from the graph layout.

    Interpolate inputs are wired as: 0 - resize data input, 1 - sizes,
    2 - scales, 3 - axes.

    * With 3 connected inputs (shape_calculation_mode='scales'), resize input 2 is
      used as scales, and sizes are computed as
      Cast<int64>(Floor(Cast<float>(ShapeOf(data)) * scales + 1e-5)).
    * Otherwise (shape_calculation_mode='sizes'), resize input 3 is used as sizes,
      and scales are computed as Cast<float>(sizes) / Cast<float>(ShapeOf(data)) + 1e-5.

    In both cases sizes and scales are sliced to the spatial axes with StridedSlice.
    Finally the Interpolate node takes over the resize name (the resize is
    renamed with a '/delete' suffix) and the resize output consumers are moved
    to the Interpolate output.

    :param graph: graph being transformed.
    :param resize: the ONNX Resize node to replace.
    """
    log.debug("Converting of ONNX Resize-11 to Interpolate-4 "
              "is triggered for node {}.".format(
                  resize.soft_get('name', resize.id)))
    input_shape = resize.in_port(0).data.get_shape()
    input_rank = len(input_shape)
    resize_name = resize.soft_get('name', resize.id)
    if input_rank not in {4, 5}:
        log.warning(
            'The input shape is not 4D or 5D for op with name {}'.format(
                resize_name))
        return
    # Count only connected inputs; disconnected optional ports are ignored.
    num_of_inputs = len([
        port for port in resize.in_ports().values() if not port.disconnected()
    ])
    assert num_of_inputs in {3, 4}, \
        "Number of inputs of ONNXResize (with name {}) should be equal to 3 or 4".format(resize_name)
    assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
        'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize_name)
    layout = graph.graph['layout']
    # [begin_dim, end_dim) spans the spatial dimensions: H..W for 4D, D..W for 5D.
    if input_rank == 4:
        begin_dim = get_height_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1
    else:
        begin_dim = get_depth_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1
    # Slices the full-rank sizes tensor down to its spatial part.
    sizes_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([begin_dim]),
            2: int64_array([end_dim]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice_sizes',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })
    # Slices the full-rank scales tensor down to its spatial part.
    scales_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([begin_dim]),
            2: int64_array([end_dim]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice_scales',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })
    axes_node = Const(
        graph, {
            'name': resize_name + '/axis',
            'value': int64_array(np.arange(begin_dim, end_dim))
        }).create_node()
    # NOTE(review): presumably 3 connected inputs means no sizes input is given,
    # so the output shape is driven by scales; confirm against the ONNX frontend.
    shape_calculation_mode = 'scales' if num_of_inputs == 3 else 'sizes'
    interpolate_node = Interpolate(
        graph, {
            'version': 'opset4',
            'mode': convert_mode(resize.mode),
            'coordinate_transformation_mode': resize.coordinate_transformation_mode,
            'cube_coeff': resize.cube_coeff,
            'nearest_mode': resize.nearest_mode,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'antialias': 0,
            'shape_calculation_mode': shape_calculation_mode,
            'in_ports_count': 4
        }).create_node()
    axes_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()
    # Small epsilon added before Floor (scales mode) or to the computed scales
    # (sizes mode). NOTE(review): presumably compensates float rounding such as
    # 2.9999 -> 2 after flooring; confirm.
    add_node = create_op_with_const_inputs(graph, Add, {1: float_array([1.0e-5])}, {'name': resize_name + '/Add'})
    # Float type used for the shape arithmetic comes from the command-line data type.
    input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
    if num_of_inputs == 3:
        # sizes = Cast<int64>(Floor(Cast<float>(ShapeOf(data)) * scales + eps))
        cast_shape_to_float = Cast(graph, {
            'dst_type': input_data_type
        }).create_node()
        mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
        cast_add_result_to_int = Cast(graph, {
            'dst_type': np.int64
        }).create_node()
        floor_node = Floor(graph, {
            'name': resize_name + '/Floor'
        }).create_node()
        mul_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(floor_node.in_port(0))
        floor_node.out_port(0).connect(cast_add_result_to_int.in_port(0))
        cast_add_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))
        # Move the data input to Interpolate, then also feed its producer to ShapeOf.
        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))
        # Resize input 2 (scales) goes to the scales slice; its producer also feeds the Mul.
        connection_of_scales = resize.in_port(2).get_connection()
        connection_of_scales.set_destination(scales_ss.in_port(0))
        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_scales.get_source().connect(mul_node.in_port(1))
    else:
        # scales = Cast<float>(sizes) / Cast<float>(ShapeOf(data)) + eps
        cast_shape_to_float = Cast(graph, {
            'dst_type': input_data_type
        }).create_node()
        cast_sizes_to_float = Cast(graph, {
            'dst_type': input_data_type
        }).create_node()
        div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
        cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
        cast_shape_to_float.out_port(0).connect(div_node.in_port(1))
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(scales_ss.in_port(0))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))
        # Move the data input to Interpolate, then also feed its producer to ShapeOf.
        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))
        # Resize input 3 (sizes) goes to the sizes slice; its producer also feeds the float cast.
        connection_of_sizes = resize.in_port(3).get_connection()
        connection_of_sizes.set_destination(sizes_ss.in_port(0))
        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_sizes.get_source().connect(
            cast_sizes_to_float.in_port(0))
    # Interpolate takes over the original name; the old node is marked for deletion.
    rename_nodes([(resize, resize_name + '/delete'), (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(
        interpolate_node.out_port(0))
def test_get_depth_dim_NDHWC(self):
    """A 5D tensor with the 'NHWC' layout (i.e. NDHWC) is expected to keep depth at axis 1."""
    depth_axis = get_depth_dim('NHWC', 5)
    self.assertEqual(depth_axis, 1)
def test_get_depth_dim_NCDHW(self):
    """A 5D tensor with the 'NCHW' layout (i.e. NCDHW) is expected to keep depth at axis 2."""
    depth_axis = get_depth_dim('NCHW', 5)
    self.assertEqual(depth_axis, 2)
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """
    Replaces a matched Upsample node with an opset4 Interpolate node.

    Interpolate inputs are wired as:
      0 - the Upsample data input,
      1 - sizes: Cast<int64>(Cast<float32>(StridedSlice(ShapeOf(data))) * factor),
      2 - a constant with the scale factors,
      3 - a constant with the spatial axes.
    The Interpolate node takes over the Upsample name and the Upsample node is
    renamed with a '/delete' suffix.

    4D inputs use height and width scales; 5D inputs additionally require a
    depth scale. Scales from a constant second input are read at the layout's
    spatial indices, and the first two scales must be equal to 1; otherwise the
    node is left untouched.

    :param graph: graph being transformed.
    :param match: matched nodes; ``match['upsample']`` is the Upsample node.
    """
    log.debug('UpsampleToResample is triggered')
    upsample = match['upsample']
    upsample_name = upsample.soft_get('name', upsample.id)
    input_shape = upsample.in_port(0).data.get_shape()
    input_shape_rank = len(input_shape)
    if input_shape_rank not in [4, 5]:
        log.warning('The input shape is not 4D or 5D for op {}'.format(
            upsample.soft_get('name')))
        return

    depth_scale = None
    layout = graph.graph['layout']
    if len(upsample.in_nodes()) == 2:
        if upsample.in_node(1).value is None:
            return
        scales = upsample.in_node(1).value
        assert len(scales) in (
            4, 5
        ), 'Supported scales rank is 4 or 5, but it is {} for node {}'.format(
            len(scales), upsample_name)
        # Scales are indexed with dims computed for the input rank, so both must agree.
        assert len(scales) == input_shape_rank, \
            'The number of scales {} does not match the input rank {} for node {}'.format(
                len(scales), input_shape_rank, upsample_name)
        if not (math.isclose(scales[0], 1, rel_tol=1e-5)
                and math.isclose(scales[1], 1, rel_tol=1e-5)):
            return
        height_scale = scales[get_height_dim(layout, input_shape_rank)]
        width_scale = scales[get_width_dim(layout, input_shape_rank)]
        if len(scales) == 5:
            depth_scale = scales[get_depth_dim(layout, input_shape_rank)]
    else:
        height_scale = upsample['height_scale']
        width_scale = upsample['width_scale']

    if input_shape_rank == 5 and depth_scale is None:
        # A 5D factor needs a depth scale; float32_array([None, ...]) would fail.
        log.warning('The depth scale is not defined for 5D op {}'.format(upsample_name))
        return

    if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
        upsample.in_port(1).disconnect()

    shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()

    if input_shape_rank == 4:
        begin_value = int64_array([get_height_dim(layout, input_shape_rank)])
        factor_value = float32_array([height_scale, width_scale])
        axes = int64_array([
            get_height_dim(layout, input_shape_rank),
            get_width_dim(layout, input_shape_rank)
        ])
    else:
        begin_value = int64_array([get_depth_dim(layout, input_shape_rank)])
        factor_value = float32_array([depth_scale, height_scale, width_scale])
        axes = int64_array([
            get_depth_dim(layout, input_shape_rank),
            get_height_dim(layout, input_shape_rank),
            get_width_dim(layout, input_shape_rank)
        ])

    # Take the spatial part [begin, width dim] of the input shape.
    ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: begin_value,
            2: int64_array([get_width_dim(layout, input_shape_rank) + 1]),
            3: int64_array([1])
        }, {
            'name': upsample_name + '/ss_0_port',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })

    mul = create_op_node_with_second_input(
        graph, Mul, factor_value, {'name': upsample_name + '/factor_mul'})

    source = upsample.in_port(0).get_connection().get_source()
    source.connect(shape.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))
    ss.out_port(0).connect(mul.in_port(0))

    # Create Interpolate operation
    axes_node = Const(graph, {
        'name': upsample_name + '/axis',
        'value': axes
    }).create_node()

    interpolate = Interpolate(
        graph, {
            'mode': upsample.attrs()['mode'],
            'antialias': 0,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'coordinate_transformation_mode': 'half_pixel',
            'nearest_mode': 'round_prefer_floor',
            'cube_coeff': -0.75,
            'shape_calculation_mode': 'scales',
            'version': 'opset4',
            'in_ports_count': 4
        }).create_node()

    upsample.add_input_port(1, skip_if_exist=True)
    assert upsample.in_port(1).disconnected()
    mul.out_port(0).connect(interpolate.in_port(1))
    axes_node.out_port(0).connect(interpolate.in_port(3))

    scales_node = Const(graph, {
        'name': upsample_name + '/scales',
        'value': factor_value
    }).create_node()
    scales_node.out_port(0).connect(interpolate.in_port(2))

    upsample.in_port(0).get_connection().set_destination(interpolate.in_port(0))
    upsample.out_port(0).get_connection().set_source(interpolate.out_port(0))

    rename_nodes([(upsample, upsample_name + '/delete'),
                  (interpolate, upsample_name)])

    # The sliced shape is int64 while the factor is float32: cast around the Mul.
    convert_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    convert_to_int = Cast(graph, dict(dst_type=np.int64)).create_node()

    mul.in_port(0).get_connection().insert_node(convert_to_float)
    mul.out_port(0).get_connection().insert_node(convert_to_int)
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """
    Replaces a matched Upsample node with an Interpolate node when all spatial
    scales are equal.

    The Interpolate target spatial shape is computed inside the graph as
    Cast<int64>(Cast<float32>(StridedSlice(ShapeOf(data))) * factor) and fed to
    Interpolate input port 1. The spatial axes are passed via the ``axes`` attribute.

    Scales from a constant second input are expected in N, C, [D,] H, W order,
    with one value per input dimension and the N and C scales equal to 1.
    4D inputs need height and width scales, 5D inputs also a depth scale. The
    node is left untouched when the spatial scales differ from each other.

    :param graph: graph being transformed.
    :param match: matched nodes; ``match['upsample']`` is the Upsample node.
    """
    log.debug('UpsampleToResample is triggered')
    upsample = match['upsample']
    upsample_name = upsample.soft_get('name', upsample.id)
    input_shape = upsample.in_port(0).data.get_shape()
    input_shape_rank = len(input_shape)
    if input_shape_rank not in [4, 5]:
        log.warning('The input shape is not 4D or 5D for op {}'.format(
            upsample.soft_get('name')))
        return

    depth_scale = None
    if len(upsample.in_nodes()) == 2:
        if upsample.in_node(1).value is None:
            return
        scales = upsample.in_node(1).value
        assert len(scales) in (
            4, 5
        ), 'Supported scales rank is 4 or 5, but it is {} for node {}'.format(
            len(scales), upsample_name)
        assert len(scales) == input_shape_rank, \
            'The number of scales {} does not match the input rank {} for node {}'.format(
                len(scales), input_shape_rank, upsample_name)
        if not (math.isclose(scales[0], 1, rel_tol=1e-5)
                and math.isclose(scales[1], 1, rel_tol=1e-5)):
            return
        # Scales are ordered N, C, [D,] H, W: for 5D the depth scale comes before H and W.
        if len(scales) == 5:
            depth_scale, height_scale, width_scale = scales[2], scales[3], scales[4]
        else:
            height_scale, width_scale = scales[2], scales[3]
    else:
        height_scale = upsample['height_scale']
        width_scale = upsample['width_scale']

    if input_shape_rank == 5 and depth_scale is None:
        # A 5D factor needs a depth scale; [None, h, w] is not a valid factor.
        log.warning('The depth scale is not defined for 5D op {}'.format(upsample_name))
        return

    if not math.isclose(height_scale, width_scale, rel_tol=1e-5):
        log.debug(
            'Width and height scales are not equal: {} vs {} for node {}'.
            format(width_scale, height_scale, upsample_name))
        return
    if depth_scale is not None and not math.isclose(
            height_scale, depth_scale, rel_tol=1e-5):
        log.debug(
            'Depth and height scales are not equal: {} vs {} for node {}'.
            format(depth_scale, height_scale, upsample_name))
        return

    if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
        upsample.in_port(1).disconnect()

    shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()

    layout = graph.graph['layout']

    if input_shape_rank == 4:
        begin_value = int64_array([get_height_dim(layout, input_shape_rank)])
        factor_value = np.array([height_scale, width_scale])
        axes = int64_array([
            get_height_dim(layout, input_shape_rank),
            get_width_dim(layout, input_shape_rank)
        ])
    else:
        begin_value = int64_array([get_depth_dim(layout, input_shape_rank)])
        factor_value = np.array([depth_scale, height_scale, width_scale])
        axes = int64_array([
            get_depth_dim(layout, input_shape_rank),
            get_height_dim(layout, input_shape_rank),
            get_width_dim(layout, input_shape_rank)
        ])

    # Take the spatial part [begin, width dim] of the input shape.
    ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: begin_value,
            2: int64_array([get_width_dim(layout, input_shape_rank) + 1]),
            3: int64_array([1])
        }, {
            'name': upsample_name + '/ss_0_port',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })

    mul = create_op_node_with_second_input(
        graph, Mul, factor_value, {'name': upsample_name + '/factor_mul_'})

    source = upsample.in_port(0).get_connection().get_source()
    source.connect(shape.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))
    ss.out_port(0).connect(mul.in_port(0))

    # Create Interpolate operation
    resample_op = Interpolate(
        graph,
        dict(name=upsample_name + '/Interpolate',
             axes=axes,
             mode=upsample.attrs()['mode'],
             antialias=0,
             convert_to_resample=True)).create_node()

    upsample.add_input_port(1, skip_if_exist=True)
    assert upsample.in_port(1).disconnected()
    mul.out_port(0).connect(resample_op.in_port(1))

    upsample.in_port(0).get_connection().set_destination(resample_op.in_port(0))
    upsample.out_port(0).get_connection().set_source(resample_op.out_port(0))

    # The sliced shape is int64 while the factor is floating point: cast around the Mul.
    convert_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    convert_to_int = Cast(graph, dict(dst_type=np.int64)).create_node()

    mul.in_port(0).get_connection().insert_node(convert_to_float)
    mul.out_port(0).get_connection().insert_node(convert_to_int)