def find_and_replace_pattern(self, graph: Graph):
    for roll_node in graph.get_op_nodes(op='Roll'):
        if not roll_node.in_port(2).disconnected():
            # axes input is already provided; nothing to do for this node
            continue
        node_name = roll_node.soft_get('name', roll_node.id)

        # reshape to a 1d tensor
        reshape_to_1d = create_op_node_with_second_input(graph, Reshape, int64_array([-1]),
                                                         {'name': node_name + '/reshape'})
        roll_node.in_port(0).get_connection().insert_node(reshape_to_1d)

        # add a zero const as the axes input to Roll
        const_zero = Const(graph, {'value': int64_array([0]), 'name': node_name + '/axes'}).create_node()
        const_zero.out_port(0).connect(roll_node.in_port(2))

        # reshape back to the original shape
        shape_of = Shape(graph, {'name': node_name + '/shape_of'}).create_node()
        reshape_to_1d.in_port(0).get_connection().add_destination(shape_of.in_port(0))
        reshape_to_orig_shape = Reshape(graph, {}).create_node()
        rename_nodes([(roll_node, node_name + '/roll'), (reshape_to_orig_shape, node_name)])
        shape_of.out_port(0).connect(reshape_to_orig_shape.in_port(1))
        roll_node.out_port(0).get_connection().insert_node(reshape_to_orig_shape)
def replace_op(self, graph: Graph, node: Node):
    node_name = node.soft_get('name', node.id)
    assert node.has_valid('axis'), 'The node "{}" does not have mandatory attribute "axis"'.format(node_name)

    flatten_node = FlattenONNX(graph, {'name': node_name + '/FlattenONNX_', 'axis': node.axis}).create_node()
    shape_node = Shape(graph, {'name': node_name + '/ShapeOf_'}).create_node()
    logsoftmax_node = LogSoftmax(graph, {'name': node_name + '/LogSoftmax_', 'axis': 1}).create_node()
    reshape_node = Reshape(graph, {}).create_node()
    rename_nodes([(node, node_name + '/delete'), (reshape_node, node_name)])

    shape_node.out_port(0).connect(reshape_node.in_port(1))
    logsoftmax_node.out_port(0).connect(reshape_node.in_port(0))
    flatten_node.out_port(0).connect(logsoftmax_node.in_port(0))

    source = node.in_port(0).get_source()
    flatten_node.in_port(0).connect(source)
    shape_node.in_port(0).connect(source)

    return [reshape_node.id]
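# --- Illustration (editor's sketch, not part of the transformation above) ---
# A minimal NumPy check, with hypothetical shapes, of why this decomposition is
# exact: ONNX LogSoftmax (opset < 13) coerces its input to 2D over `axis`, so
# FlattenONNX(axis) -> LogSoftmax(axis=1) -> Reshape(original shape) reproduces it.
import numpy as np

x = np.random.randn(2, 3, 4).astype(np.float32)
axis = 1

flat = x.reshape(int(np.prod(x.shape[:axis])), -1)   # what FlattenONNX produces
shifted = flat - flat.max(axis=1, keepdims=True)     # numerically stable log-softmax
log_sm = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
out = log_sm.reshape(x.shape)                        # Reshape back using ShapeOf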
def replace_interpolate_pattern(graph: Graph, match: dict):
    split = match['split']
    scale = float32_array([get_split_scale(split)])
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name

    axis_node = Const(graph, {'name': split_node_name + '/axis', 'value': int64_array([axis])}).create_node()
    shape_node = Shape(graph, dict(name=split_node_name + '/Shape')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    strided_slice_node = create_op_with_const_inputs(graph, StridedSlice,
                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
                                                     {'name': split_node_name + '/StridedSlice',
                                                      'begin_mask': int64_array([1]),
                                                      'end_mask': int64_array([1]),
                                                      'new_axis_mask': int64_array([0]),
                                                      'shrink_axis_mask': int64_array([0]),
                                                      'ellipsis_mask': int64_array([0])})
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()
    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph,
                              dict(name=split_node_name + '/Interpolate',
                                   mode='nearest', antialias=0,
                                   pads_begin=int64_array([0]), pads_end=int64_array([0]),
                                   coordinate_transformation_mode='half_pixel',
                                   nearest_mode='round_prefer_floor', cube_coeff=-0.75,
                                   version='opset4', shape_calculation_mode='scales',
                                   in_ports_count=4, maybe_part_of_sequence=True)).create_node()

    floor_node = Floor(graph, {'name': split_node_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()

    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))
    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))

    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
def replace(node: Node, const: Node):
    graph = node.graph
    shape = const.shape
    const_name = const.soft_get('name', const.id)

    non_one_dims = np.argwhere(shape != 1).flatten()
    one_dims = np.argwhere(shape == 1).flatten()

    # the (5, 500) element-count range is chosen empirically so that the transformation affects fewer models
    if not (non_one_dims.size == 1 and 5 < np.prod(shape) < 500):
        return

    value = const.value
    if not np.array_equal(np.arange(0, np.prod(shape), 1).reshape(shape), value):
        return

    positive_idx = non_one_dims.item(0)
    negative_idx = positive_idx - len(shape)

    node_name = node.soft_get('name', node.id)
    gather = create_op_with_const_inputs(graph, Gather, {1: int64_array(negative_idx), 2: int64_array(0)},
                                         {'name': node_name + '/BroadcastingDim'})
    gather_for_const = create_op_with_const_inputs(graph, Gather,
                                                   {1: int64_array(negative_idx), 2: int64_array(0)},
                                                   {'name': const_name + '/BroadcastingDim'})
    shapeof_node = Shape(graph, {'name': const_name + '/ShapeOf'}).create_node()
    shapeof_node.out_port(0).connect(gather_for_const.in_port(0))

    equal_node = create_op_with_const_inputs(graph, Equal, {1: int64_array(1)},
                                             {'name': node_name + '/ConstOne'})
    gather.out_port(0).connect(equal_node.in_port(0))

    select_node = Select(graph, {'name': node_name + '/Select',
                                 'auto_broadcast': 'numpy'}).create_node([equal_node, gather_for_const, gather])

    const.out_port(0).connect(shapeof_node.in_port(0))

    range_node = create_op_with_const_inputs(graph, Range,
                                             {0: mo_array(0, dtype=value.dtype),
                                              2: mo_array(1, dtype=value.dtype)},
                                             {'name': const_name + '/Range', 'dtype': value.dtype})
    select_node.out_port(0).connect(range_node.in_port(1))

    node.in_port(1).get_connection().add_destination(gather.in_port(0))
    node.in_port(0).get_connection().set_source(range_node.out_port(0))

    if one_dims.size:
        unsqueeze = create_op_node_with_second_input(graph, Unsqueeze, one_dims,
                                                     {'name': const_name + '/KeepShape'})
        range_node.out_port(0).get_connection().insert_node(unsqueeze)
        rename_nodes([(const, const_name + '/ToBeDeleted'), (unsqueeze, const_name)])
    else:
        rename_nodes([(const, const_name + '/ToBeDeleted'), (range_node, const_name)])
def replace_sub_graph(self, graph: Graph, match: dict):
    mxreshape = match['op']
    if not mxreshape.reverse:
        return

    shape_node = Shape(graph, dict(name=mxreshape.id + '/Shape')).create_node()
    forward_reverse_unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                                      dict(name=str(mxreshape.id) + '/ForwardUnsqueeze'))
    forward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/ForwardReverse', axis=1)).create_node()
    forward_reverse_squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                                    dict(name=str(mxreshape.id) + '/ForwardSqueeze'))
    reshape_node = Reshape(graph, dict(name=mxreshape.id + '/Reshape')).create_node()

    shape_node.in_port(0).connect(mxreshape.in_port(0).get_source())
    mxreshape.in_port(0).get_connection().set_destination(reshape_node.in_port(0))

    forward_reverse_unsqueeze_node.in_port(0).connect(shape_node.out_port(0))
    forward_reverse_node.in_port(0).connect(forward_reverse_unsqueeze_node.out_port(0))
    forward_reverse_squeeze_node.in_port(0).connect(forward_reverse_node.out_port(0))
    reshape_node.in_port(1).connect(forward_reverse_squeeze_node.out_port(0))

    reshape_shape_node = create_op_node_with_second_input(graph, Reshape, int64_array(np.flip(mxreshape.dim, 0)),
                                                          dict(name=str(mxreshape.id) + '/ReshapeShape'))
    if np.sum(np.in1d([-2, -3, -4], mxreshape.dim), axis=0):
        reshape_shape_node = MXReshape(graph, dict(name=mxreshape.id + '/Reshape',
                                                   dim=int64_array(np.flip(mxreshape.dim, 0)))).create_node()

    reshape_shape_node.in_port(0).connect(reshape_node.out_port(0))

    backward_shape_node = Shape(graph, dict(name=mxreshape.id + '/BackwardShape')).create_node()
    backward_reverse_unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                                       dict(name=str(mxreshape.id) + '/BackwardUnsqueeze'))
    backward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/BackwardReverse', axis=1)).create_node()
    backward_reverse_squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                                     dict(name=str(mxreshape.id) + '/BackwardSqueeze'))
    backward_reshape_node = Reshape(graph, dict(name=mxreshape.id + '/BackwardReshape')).create_node()

    backward_shape_node.in_port(0).connect(reshape_shape_node.out_port(0))
    backward_reverse_unsqueeze_node.in_port(0).connect(backward_shape_node.out_port(0))
    backward_reverse_node.in_port(0).connect(backward_reverse_unsqueeze_node.out_port(0))
    backward_reverse_squeeze_node.in_port(0).connect(backward_reverse_node.out_port(0))
    backward_reshape_node.in_port(0).connect(reshape_shape_node.out_port(0))
    backward_reshape_node.in_port(1).connect(backward_reverse_squeeze_node.out_port(0))

    mxreshape.out_port(0).get_connection().set_source(backward_reshape_node.out_port(0))
def append_variances(priors_scale_node: Node, variance: list):
    graph = priors_scale_node.graph
    name = priors_scale_node.name

    sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
    priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

    begin = Const(graph, {'value': int64_array([-2])}).create_node()
    end = Const(graph, {'value': int64_array([-1])}).create_node()
    stride = Const(graph, {'value': int64_array([1])}).create_node()
    shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim',
                                                 'begin_mask': int64_array([1]),
                                                 'end_mask': int64_array([1]),
                                                 'new_axis_mask': int64_array([0]),
                                                 'shrink_axis_mask': int64_array([0]),
                                                 'ellipsis_mask': int64_array([0])}).create_node()

    sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
    begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
    end.out_port(0).connect(shape_part_for_tiling.in_port(2))
    stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

    shape_concat = create_op_node_with_second_input(graph, Concat, int64_array([4]),
                                                    {'name': name + '/shape_for_tiling',
                                                     'in_ports_count': 2,
                                                     'axis': int64_array(0)},
                                                    shape_part_for_tiling)

    variance = Const(graph, {'name': name + '/variance', 'value': float32_array(variance)}).create_node()
    tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
    variance.out_port(0).connect(tile.in_port(0))
    shape_concat.out_port(0).connect(tile.in_port(1))

    reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
    sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
    sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
    sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

    concat = Concat(graph, {'name': name + '/priors_concat',
                            'axis': int64_array(0),
                            'in_ports_count': 2}).create_node()
    sp_reshape.out_port(0).connect(concat.in_port(0))
    tile.out_port(0).connect(concat.in_port(1))

    output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
    output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
    concat.out_port(0).connect(output_node.in_port(0))
    output_dims.out_port(0).connect(output_node.in_port(1))

    return output_node
def placeholder_scales(self, placeholder: Node):
    """
    Helper function to get scales for prior boxes out of input image size:
            [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]
    """
    graph = placeholder.graph
    name = placeholder.soft_get('name', placeholder.id)

    shape_value = placeholder.soft_get('shape', None)
    assert shape_value is not None, \
        "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
    assert isinstance(shape_value, np.ndarray), \
        "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
    assert shape_value.size == 4, \
        "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

    shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
    shape.in_port(0).connect(placeholder.out_port(0))

    begin = Const(graph, {'value': int64_array([1])}).create_node()
    end = Const(graph, {'value': int64_array([3])}).create_node()
    stride = Const(graph, {'value': int64_array([1])}).create_node()
    spatial = StridedSlice(graph, {'name': name + '/get_h_w',
                                   'begin_mask': int64_array([1]),
                                   'end_mask': int64_array([1]),
                                   'new_axis_mask': int64_array([0]),
                                   'shrink_axis_mask': int64_array([0]),
                                   'ellipsis_mask': int64_array([0])}).create_node()

    spatial.in_port(0).connect(shape.out_port(0))
    spatial.in_port(1).connect(begin.out_port(0))
    spatial.in_port(2).connect(end.out_port(0))
    spatial.in_port(3).connect(stride.out_port(0))

    power = Const(graph, {'value': float32_array([-1.])}).create_node()
    spatial_scale = Pow(graph, {}).create_node()

    spatial_scale.in_port(0).connect(spatial.out_port(0))
    spatial_scale.in_port(1).connect(power.out_port(0))

    # Power `type_infer` requires inputs to have equal data type
    convert_to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
    spatial_scale.in_port(0).get_connection().insert_node(convert_to_fp32)

    order = Const(graph, {'value': int64_array([1, 0])}).create_node()
    axis_const = Const(graph, {'value': int64_array(0)}).create_node()
    reverse = Gather(graph, {}).create_node()

    reverse.in_port(0).connect(spatial_scale.out_port(0))
    reverse.in_port(1).connect(order.out_port(0))
    axis_const.out_port(0).connect(reverse.in_port(2))

    priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
    priors_scale_node.add_input_port(0, skip_if_exist=True)
    priors_scale_node.add_input_port(1, skip_if_exist=True)

    priors_scale_node.in_port(0).connect(reverse.out_port(0))
    priors_scale_node.in_port(1).connect(reverse.out_port(0))

    return priors_scale_node
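# --- Illustration (editor's sketch, not part of the transformation above) ---
# What the sub-graph above computes, in plain NumPy with a hypothetical NHWC
# input shape: slice out H and W, invert them via Pow(-1), reverse the pair
# with Gather([1, 0]), and concatenate the result with itself.
import numpy as np

input_shape = np.array([1, 300, 400, 3])              # [N, H, W, C]
spatial = input_shape[1:3].astype(np.float32)         # StridedSlice -> [H, W]
inverted = spatial ** -1.0                            # Pow with exponent -1
reversed_hw = inverted[[1, 0]]                        # Gather, order [1, 0] -> [1/W, 1/H]
scales = np.concatenate([reversed_hw, reversed_hw])   # Concat -> [1/W, 1/H, 1/W, 1/H]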
def replace_pattern(graph: Graph, match: Dict[str, Node]):
    node = match['op']
    name = node.soft_get('name', node.id)

    input_shape = node.in_port(0).data.get_shape()
    second_input_shape = node.in_port(1).data.get_shape()

    begin_mask = np.zeros(len(input_shape), dtype=np.int64)
    end_mask = np.zeros(len(input_shape), dtype=np.int64)
    for i in node.axes:
        end_mask[i] = np.int64(1)

    new_axis_mask = np.zeros(len(input_shape), dtype=np.int64)
    shrink_axis_mask = np.zeros(len(input_shape), dtype=np.int64)
    ellipsis_mask = np.zeros(len(input_shape), dtype=np.int64)

    ss = create_op_with_const_inputs(graph, StridedSlice,
                                     port_value_dict={1: np.zeros(len(input_shape), dtype=np.int64)},
                                     op_attrs={'name': 'StridedSlice',
                                               'begin_mask': begin_mask,
                                               'end_mask': end_mask,
                                               'new_axis_mask': new_axis_mask,
                                               'shrink_axis_mask': shrink_axis_mask,
                                               'ellipsis_mask': ellipsis_mask})

    if input_shape.size == second_input_shape.size:
        end = Shape(graph, dict(name=name + '/End')).create_node()
        end.in_port(0).connect(node.in_port(1).get_source())
        ss.in_port(2).connect(end.out_port(0))
    else:
        shape_like, rank_like = get_shape_and_rank_nodes_by_port(node.in_port(1).get_source())
        end_first_part = get_shape_values_by_range_idxs(shape_like, rank_like, 0, node.axes[-1],
                                                        include_end=True)

        if input_shape.size - 1 == node.axes[-1]:
            ss.in_port(2).connect(end_first_part.out_port(0))
        else:
            shape, rank = get_shape_and_rank_nodes_by_port(node.in_port(0).get_source())
            end_second_part = get_shape_values_by_range_idxs(shape, rank, node.axes[-1], -1,
                                                             include_begin=False, include_end=True)
            end = new_shape_node_from_shape_nodes([end_first_part, end_second_part])
            ss.in_port(2).connect(end.out_port(0))

    node.in_port(0).get_connection().set_destination(ss.in_port(0))
    node.in_port(1).disconnect()
    node.out_port(0).get_connection().set_source(ss.out_port(0))

    rename_nodes([(node, name + '/ShouldBeDeleted'), (ss, name)])
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['op']
    name = node.soft_get('name', node.id)

    assert node.has_valid('output_type'), \
        'Rank node should have the `output_type` attribute, but it is absent for node {}'.format(name)

    shape_of = Shape(graph, {'name': name + '/shape_of',
                             'output_type': node.output_type}).create_node()
    rank_1d = Shape(graph, {'name': name + '/rank_of',
                            'output_type': node.output_type}).create_node()
    rank_0d = create_op_node_with_second_input(graph, Squeeze, int64_array(0),
                                               {'name': name + '/0d_rank_of'}, rank_1d)

    shape_of.out_port(0).connect(rank_1d.in_port(0))
    node.out_port(0).get_connection().set_source(rank_0d.out_port(0))
    node.in_port(0).get_connection().set_destination(shape_of.in_port(0))

    rename_nodes([(node, name + '/ToBeDeleted'), (rank_0d, name)])
def get_shape_and_rank_nodes_by_port(port: Port, return_as_a_scalar: bool = True):
    """
    The function returns nodes producing the shape and the rank of the data coming from the given port,
    so that those operations can be used in the middle/back phase.
    :param port: Port object that specifies node output port
    :param return_as_a_scalar: boolean flag to return a 1D or 0D rank
    :return: shape and rank nodes
    """
    input_node_name = port.node.soft_get('name', port.node.id)
    graph = port.node.graph

    shape = Shape(graph, dict(name=input_node_name + '/ShapeOf')).create_node()
    rank_1_d = Shape(graph, dict(name=input_node_name + '/1dRankOf')).create_node()
    rank_1_d.in_port(0).connect(shape.out_port(0))
    shape.in_port(0).connect(port)

    if not return_as_a_scalar:
        return shape, rank_1_d

    rank = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                            {'name': input_node_name + '/0dRankOf'}, rank_1_d)
    return shape, rank
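# --- Illustration (editor's sketch, not part of the utility above) ---
# The rank-from-shape trick in plain NumPy: the shape of a shape is a
# one-element tensor holding the rank, and Squeeze([0]) turns it into a scalar.
import numpy as np

x = np.zeros((2, 3, 5))
shape = np.array(x.shape)           # ShapeOf       -> [2, 3, 5]
rank_1d = np.array(shape.shape)     # ShapeOf again -> [3]
rank = np.squeeze(rank_1d, axis=0)  # Squeeze([0])  -> 3 (0D)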
def replace_tf_resize(graph: Graph, resize: Node, interpolation_mode: str):
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Converting of {} to Interpolate-4 is triggered for node {}.".format(resize.op, resize_name))

    num_of_inputs = len([port for port in resize.in_ports().values() if not port.disconnected()])
    assert num_of_inputs == 2, \
        "Number of inputs of {} (with name {}) should be equal to 2".format(resize.op, resize_name)

    attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
                "the attribute align_corners must be False"
    assert not resize.half_pixel_centers or (resize.half_pixel_centers and not resize.align_corners), \
        attrs_msg.format(resize_name, resize.op)

    shape = Shape(graph, {'name': resize_name + '/shapeof'}).create_node()

    ss = create_op_with_const_inputs(graph, StridedSlice,
                                     {1: int64_array([1]), 2: int64_array([3]), 3: int64_array([1])},
                                     {'name': resize_name + '/StridedSlice',
                                      'begin_mask': int64_array([1]),
                                      'end_mask': int64_array([1]),
                                      'new_axis_mask': int64_array([0]),
                                      'shrink_axis_mask': int64_array([0]),
                                      'ellipsis_mask': int64_array([0])})

    div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()

    shape_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    size_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()

    size_to_float.out_port(0).connect(div_node.in_port(0))
    shape_to_float.out_port(0).connect(div_node.in_port(1))
    ss.out_port(0).connect(shape_to_float.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))

    align_corners = resize.align_corners
    half_pixel_centers = resize.half_pixel_centers

    nearest_mode = 'floor' if interpolation_mode == 'nearest' else 'round_prefer_floor'
    if align_corners:
        coordinate_transformation_mode = 'align_corners'
        if interpolation_mode == 'nearest':
            nearest_mode = 'round_prefer_ceil'
    elif half_pixel_centers:
        coordinate_transformation_mode = 'tf_half_pixel_for_nn' if interpolation_mode == 'nearest' else 'half_pixel'
    else:
        coordinate_transformation_mode = 'asymmetric'

    interpolate4 = create_op_with_const_inputs(graph, Interpolate,
                                               {3: int64_array([1, 2])},
                                               {'name': resize_name + '/interpolate_4',
                                                'mode': interpolation_mode,
                                                'antialias': False,
                                                'coordinate_transformation_mode': coordinate_transformation_mode,
                                                'pads_begin': int64_array([0]),
                                                'pads_end': int64_array([0]),
                                                'nearest_mode': nearest_mode,
                                                'cube_coeff': -0.75,
                                                'shape_calculation_mode': 'sizes',
                                                'version': 'opset4',
                                                'in_ports_count': 4})

    resize_input_connection = resize.in_port(0).get_connection()
    resize_input_connection.set_destination(interpolate4.in_port(0))
    resize_input_connection.get_source().connect(shape.in_port(0))

    div_node.out_port(0).connect(interpolate4.in_port(2))

    sizes_connection = resize.in_port(1).get_connection()
    sizes_connection.set_destination(interpolate4.in_port(1))
    sizes_connection.get_source().connect(size_to_float.in_port(0))

    resize.out_port(0).get_connection().set_source(interpolate4.out_port(0))
    rename_nodes([(resize, resize_name + '/delete'), (interpolate4, resize_name)])
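# --- Illustration (editor's sketch, not part of the transformation above) ---
# Interpolate-4 in 'sizes' mode still takes a scales input, which the sub-graph
# above derives as float(sizes) / float(input_shape[1:3]). Hypothetical values:
import numpy as np

input_shape = np.array([1, 30, 40, 3])          # NHWC input of the TF resize op
sizes = np.array([60, 80])                      # requested output [H, W]
spatial = input_shape[1:3].astype(np.float32)   # StridedSlice over [1:3]
scales = sizes.astype(np.float32) / spatial     # Div -> [2.0, 2.0]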
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['op']
    name = node.soft_get('name', node.id)
    axis = node.axis

    input_shape_node = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
    range_node = create_op_with_const_inputs(graph, Range,
                                             {0: mo_array(node.start), 2: mo_array(node.step)},
                                             {'name': name + '/Range'})
    node.in_port(0).get_connection().set_destination(input_shape_node.in_port(0))

    if axis is not None:
        '''
        Replace the arange_like op with the sub-graph:
        Shape - Gather - Range
        '''
        gather_node = create_op_with_const_inputs(graph, Gather,
                                                  {1: int64_array([axis]), 2: int64_array(0)},
                                                  {'name': name + '/Gather'})
        input_shape_node.out_port(0).connect(gather_node.in_port(0))
        gather_node.out_port(0).connect(range_node.in_port(1))
        node.out_port(0).get_connection().set_source(range_node.out_port(0))
        rename_nodes([(node, name + '/ShouldBeDeleted'), (range_node, name)])
    else:
        r'''
        Replace the arange_like op with the sub-graph:

            input
              |
           ShapeOf ---------+
              |             |
          ReduceProd        |
              |             |
            Range           |
              |             |
           Reshape <--------+
              |
        '''
        flattened_shape_node = create_op_with_const_inputs(graph, ReduceProd, {1: int64_array([0])},
                                                           {'name': input_shape_node.name + '/ReduceProd',
                                                            'keep_dims': True})
        reshape_backward_node = Reshape(graph, {'name': name + '/Reshape_backward'}).create_node()

        input_shape_node.out_port(0).connect(flattened_shape_node.in_port(0))
        flattened_shape_node.out_port(0).connect(range_node.in_port(1))
        range_node.out_port(0).connect(reshape_backward_node.in_port(0))
        input_shape_node.out_port(0).connect(reshape_backward_node.in_port(1))
        node.out_port(0).get_connection().set_source(reshape_backward_node.out_port(0))
        rename_nodes([(node, name + '/ShouldBeDeleted'), (reshape_backward_node, name)])

    if node.repeat != 1:
        r"""
        First, generate the correct stop value for Range: new_stop_value = stop_value // repeat + 1.
        Then repeat each value of the interval using Tile. The result is a longer interval,
        so it is trimmed with Slice.

        The sub-graph after the Range node looks like:
        Range - Reshape([-1, 1]) - Tile([1, repeat]) - Reshape(-1) - Slice
        """
        if node.repeat < 1:
            raise Error("Unexpected value {} of the attribute 'repeat' for the node {}".format(node.repeat, name))

        div_node = create_op_with_const_inputs(graph, Div, {1: int64_array([node.repeat])},
                                               {'name': name + '/Divide'})
        add_node = create_op_with_const_inputs(graph, Add, {1: int64_array([1])},
                                               {'name': div_node.name + '/Add'})
        cast_node = Cast(graph, {'name': name + '/ConvertToI64', 'dst_type': np.int64}).create_node()

        cast_node.out_port(0).connect(div_node.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        range_node.in_port(1).get_connection().set_destination(cast_node.in_port(0))
        add_node.out_port(0).connect(range_node.in_port(1))

        tile_forward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1, 1])},
                                                           {'name': range_node.name + '/ForwardReshape'})
        tile = create_op_with_const_inputs(graph, Tile, {1: int64_array([1, node.repeat])},
                                           {'name': tile_forward_reshape.name + '/Tile'})
        tile_backward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1])},
                                                            {'name': tile.name + '/BackwardReshape'})
        slice_node = create_op_with_const_inputs(graph, Slice,
                                                 {1: int64_array([0]), 3: int64_array([0]), 4: int64_array([1])},
                                                 {'name': tile_backward_reshape.name + '/Slice'})

        tile_forward_reshape.out_port(0).connect(tile.in_port(0))
        tile.out_port(0).connect(tile_backward_reshape.in_port(0))
        tile_backward_reshape.out_port(0).connect(slice_node.in_port(0))
        slice_node.in_port(2).connect(div_node.in_port(0).get_source())

        range_node.out_port(0).get_connection().set_source(slice_node.out_port(0))
        range_node.out_port(0).connect(tile_forward_reshape.in_port(0))

        if axis is not None:
            rename_nodes([(range_node, name + '/Range'), (slice_node, name)])

    # The MXNet arange_like op has no stop attribute and the result tensor always matches the input shape,
    # so the stop value for the Range node has to be corrected if step != 1 or start != 0
    if node.step != 1:
        # If the step attribute is not an integer, generate an interval with a larger size
        # and then reduce it using Slice
        true_elements_count_port = range_node.in_port(1).get_source()
        mul_value = np.ceil(node.step) if node.step > 0 else np.floor(node.step)
        stop_value = create_op_with_const_inputs(graph, Mul,
                                                 port_value_dict={1: mo_array(np.ceil(mul_value))},
                                                 op_attrs={'name': range_node.name + '/Stop'})
        range_node.in_port(1).get_connection().insert_node(stop_value)

        slice_range_values = create_op_with_const_inputs(graph, Slice,
                                                         {1: int64_array([0]), 3: int64_array([0]),
                                                          4: int64_array([1])},
                                                         {'name': range_node.name + '/Slice'})
        slice_range_values.in_port(2).connect(true_elements_count_port)
        range_node.out_port(0).get_connection().insert_node(slice_range_values)

        if axis is not None and node.repeat == 1:
            rename_nodes([(range_node, name + '/Range'), (slice_range_values, name)])

    if node.start != 0:
        correct_stop_value = create_op_with_const_inputs(graph, Add,
                                                         port_value_dict={1: mo_array(node.start)},
                                                         op_attrs={'name': range_node.name + '/Correct_Stop'})
        range_node.in_port(1).get_connection().insert_node(correct_stop_value)

    # Range node supports only scalar inputs
    squeeze_node = create_op_with_const_inputs(graph, Squeeze,
                                               port_value_dict={1: int64_array(0)},
                                               op_attrs={"name": range_node.name + '/Stop/Squeeze'})
    range_node.in_port(1).get_connection().insert_node(squeeze_node)
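# --- Illustration (editor's sketch, not part of the transformation above) ---
# The repeat handling mirrors a simple NumPy identity. Hypothetical values for the
# Range -> Reshape([-1, 1]) -> Tile([1, repeat]) -> Reshape([-1]) -> Slice chain:
import numpy as np

repeat, total_elements = 3, 8
stop = total_elements // repeat + 1                 # corrected stop value (Div + Add above)
interval = np.arange(0, stop, 1)                    # Range output: [0, 1, 2]
tiled = np.tile(interval.reshape(-1, 1), (1, repeat)).reshape(-1)  # [0 0 0 1 1 1 2 2 2]
result = tiled[:total_elements]                     # Slice trims the overshoot: [0 0 0 1 1 1 2 2]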
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['op']

    if 1 not in node.in_ports() or node.in_port(1).disconnected():

        if node.has_valid('factor') and not node.has_valid('width') and not node.has_valid('height'):
            factor = Const(graph, {'value': np.array(node.factor)}).create_node()

            shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

            begin = Const(graph, {'value': np.array([2])}).create_node()
            end = Const(graph, {'value': np.array([4])}).create_node()
            stride = Const(graph, {'value': np.array([1])}).create_node()
            ss = StridedSlice(graph, {'name': node.name + '/ss_0_port',
                                      'begin_mask': np.array([1]),
                                      'end_mask': np.array([0]),
                                      'new_axis_mask': np.array([0]),
                                      'shrink_axis_mask': np.array([0]),
                                      'ellipsis_mask': np.array([0])}).create_node()

            mul = Mul(graph, {'name': node.name + '/factor_mul_'}).create_node()

            source = node.in_port(0).get_connection().get_source()
            source.connect(shape.in_port(0))
            shape.out_port(0).connect(ss.in_port(0))
            begin.out_port(0).connect(ss.in_port(1))
            end.out_port(0).connect(ss.in_port(2))
            stride.out_port(0).connect(ss.in_port(3))
            ss.out_port(0).connect(mul.in_port(0))
            factor.out_port(0).connect(mul.in_port(1))

            node.add_input_port(1, skip_if_exist=True)
            assert node.in_port(1).disconnected()
            mul.out_port(0).connect(node.in_port(1))

        else:
            shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

            begin = Const(graph, {'value': np.array([2])}).create_node()
            end = Const(graph, {'value': np.array([4])}).create_node()
            stride = Const(graph, {'value': np.array([1])}).create_node()
            ss = StridedSlice(graph, {'name': node.name + '/ss_0_port',
                                      'begin_mask': np.array([1]),
                                      'end_mask': np.array([0]),
                                      'new_axis_mask': np.array([0]),
                                      'shrink_axis_mask': np.array([0]),
                                      'ellipsis_mask': np.array([0])}).create_node()

            source = node.in_port(0).get_connection().get_source()
            source.connect(shape.in_port(0))
            shape.out_port(0).connect(ss.in_port(0))
            begin.out_port(0).connect(ss.in_port(1))
            end.out_port(0).connect(ss.in_port(2))
            stride.out_port(0).connect(ss.in_port(3))

            pads_value = node.pads_begin + node.pads_end
            pads_const = Const(graph, {'value': np.array(pads_value)}).create_node()
            add = Add(graph, {'name': node.name + '/pad_add'}).create_node()
            ss.out_port(0).connect(add.in_port(0))
            add.in_port(1).connect(pads_const.out_port(0))

            if node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') == 1:
                shrink_factor = node.shrink_factor
                if shrink_factor < 1:
                    log.error('Shrink factor should be positive in node {}'.format(node.id))
                    return None

                const = Const(graph, {'name': node.name + '/pre_shrink_sub_const',
                                      'value': np.array(-1)}).create_node()
                sub = Add(graph, {'name': node.name + '/pre_shrink_sub'}).create_node()
                add.out_port(0).connect(sub.in_port(0))
                sub.in_port(1).connect(const.out_port(0))

                const = Const(graph, {'value': np.array(1 / shrink_factor),
                                      'name': node.name + 'shrink_factor_div_const'}).create_node()
                div = Mul(graph, {'name': node.name + 'shrink_factor_div'}).create_node()
                sub.out_port(0).connect(div.in_port(0))
                div.in_port(1).connect(const.out_port(0))

                const = Const(graph, {'name': node.name + '/shrink_factor_add_one_const',
                                      'value': np.array(1)}).create_node()
                add = Add(graph, {'name': node.name + '/shrink_factor_add_one'}).create_node()
                div.out_port(0).connect(add.in_port(0))
                const.out_port(0).connect(add.in_port(1))

                node.add_input_port(1, skip_if_exist=True)
                assert node.in_port(1).disconnected()
                add.out_port(0).connect(node.in_port(1))

            elif node.soft_get('shrink_factor') == 1 and node.soft_get('zoom_factor') != 1:
                zoom_factor = node.zoom_factor
                if zoom_factor < 1:
                    log.error('Zoom factor should be positive in node {}'.format(node.id))
                    return None

                node['debug_message'] = 'Interpolate layer replacer may be wrong, please, try to update it in the' \
                                        ' file (openvino/tools/mo/front/InterpolateNormalizer.py at the line {}).' \
                                        ''.format(inspect.currentframe().f_lineno) + refer_to_faq_msg(100)

                # Reshape methods can be different in some cases.
                # The commented-out section below represents the reshape used in deeplab-caffe.
                # Uncomment the following lines if your model was trained with deeplab-caffe
                # or uses the same reshape method.
                # const = Const(graph, {'value': np.array(-1),
                #                       'name': node.name + 'zoom_factor_deeplab-caffe_sub_const'}).create_node()
                # sub = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sub'}).create_node()
                # add.out_port(0).connect(sub.in_port(0))
                # const.out_port(0).connect(sub.in_port(1))
                #
                # const = Const(graph, {'value': np.array(zoom_factor - 1),
                #                       'name': node.name + 'zoom_factor_deeplab-caffe_mul_const'}).create_node()
                # mul = Mul(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_mul'}).create_node()
                # sub.out_port(0).connect(mul.in_port(0))
                # const.out_port(0).connect(mul.in_port(1))
                #
                # sum_node = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sum'}).create_node()
                # add.out_port(0).connect(sum_node.in_port(0))
                # mul.out_port(0).connect(sum_node.in_port(1))
                #
                # node.add_input_port(1, skip_if_exist=True)
                # assert node.in_port(1).disconnected()
                # sum_node.out_port(0).connect(node.in_port(1))

                # Comment out the following lines if you use the reshape method from the section above.
                const = Const(graph, {'value': np.array(zoom_factor),
                                      'name': node.name + '/zoom_factor_mul_const'}).create_node()
                mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()

                add.out_port(0).connect(mul.in_port(0))
                const.out_port(0).connect(mul.in_port(1))

                node.add_input_port(1, skip_if_exist=True)
                assert node.in_port(1).disconnected()
                mul.out_port(0).connect(node.in_port(1))

            elif node.soft_get('width') != 0 and node.soft_get('height') != 0:
                const = Const(graph, {'value': np.array([node.height, node.width])}).create_node()
                node.add_input_port(1, skip_if_exist=True)
                assert node.in_port(1).disconnected()
                const.out_port(0).connect(node.in_port(1))

            elif node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') != 1:
                shrink_factor = node.shrink_factor
                zoom_factor = node.zoom_factor
                if shrink_factor < 1:
                    log.error('Shrink factor should be positive in node {}'.format(node.id))
                    return None
                if zoom_factor < 1:
                    log.error('Zoom factor should be positive in node {}'.format(node.id))
                    return None

                const = Const(graph, {'value': np.array(-1)}).create_node()
                sub = Add(graph, {'name': node.name + '/shrink_zoom_factor_sub'}).create_node()
                add.out_port(0).connect(sub.in_port(0))
                const.out_port(0).connect(sub.in_port(1))

                const = Const(graph, {'value': np.array(1 / (shrink_factor + 1))}).create_node()
                div = Mul(graph, {'name': node.name + '/shrink_factor_div'}).create_node()
                sub.out_port(0).connect(div.in_port(0))
                const.out_port(0).connect(div.in_port(1))

                const = Const(graph, {'value': np.array(-1),
                                      'name': node.name + 'shrink_zoom_factor_sum_const'}).create_node()
                sum_node = Add(graph, {'name': node.name + '/shrink_zoom_factor_sum'}).create_node()
                div.out_port(0).connect(sum_node.in_port(0))
                const.out_port(0).connect(sum_node.in_port(1))

                const = Const(graph, {'value': np.array(zoom_factor - 1)}).create_node()
                mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()
                sum_node.out_port(0).connect(mul.in_port(0))
                const.out_port(0).connect(mul.in_port(1))

                sum_node = Add(graph, {'name': node.name + '/final_shrink_zoom_factor_sum'}).create_node()
                div.out_port(0).connect(sum_node.in_port(0))
                mul.out_port(0).connect(sum_node.in_port(1))

                node.add_input_port(1, skip_if_exist=True)
                assert node.in_port(1).disconnected()
                sum_node.out_port(0).connect(node.in_port(1))
    else:
        if node.soft_get('fw') == 'caffe':
            shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

            begin = Const(graph, {'value': np.array([2])}).create_node()
            end = Const(graph, {'value': np.array([4])}).create_node()
            stride = Const(graph, {'value': np.array([1])}).create_node()
            ss = StridedSlice(graph, {'name': node.name + '/ss_0_port',
                                      'begin_mask': np.array([1]),
                                      'end_mask': np.array([0]),
                                      'new_axis_mask': np.array([0]),
                                      'shrink_axis_mask': np.array([0]),
                                      'ellipsis_mask': np.array([0])}).create_node()

            source = node.in_port(1).get_connection().get_source()
            node.in_port(1).disconnect()
            source.connect(shape.in_port(0))
            shape.out_port(0).connect(ss.in_port(0))
            begin.out_port(0).connect(ss.in_port(1))
            end.out_port(0).connect(ss.in_port(2))
            stride.out_port(0).connect(ss.in_port(3))
            ss.out_port(0).connect(node.in_port(1))
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    group_norm_node = match['op']
    group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

    # node computing initial GroupNorm input shape
    initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
    initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

    initial_shape_op_node_float = Cast(
        graph, {'name': initial_shape_op_node.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    initial_shape_op_node.out_port(0).connect(initial_shape_op_node_float.in_port(0))

    initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node_float)
    initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node_float)
    initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(initial_shape_op_node)
    initial_spatial_dims_node = Cast(
        graph, {'name': initial_spatial_dims_node_int.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    initial_spatial_dims_node_int.out_port(0).connect(initial_spatial_dims_node.in_port(0))

    group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                    'name': group_norm_node.name + '/GroupSize'}).create_node()

    # calculate the "features // group_size" value
    reciprocal_group_size_node = Const(graph, {'value': np.array([1.0 / group_norm_node.num_groups]),
                                               'name': group_norm_node.name + '/ReciprocalGroupSize'}).create_node()

    c_div_g_node = Mul(graph, {}).create_node()
    c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
    c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

    batch_mul_group_size_node = Mul(graph, {}).create_node()
    batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
    batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

    # create a new node which concatenates several dims to one
    new_shape_node_float = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
                                                            initial_spatial_dims_node])
    new_shape_node = Cast(graph, {'name': new_shape_node_float.name + '/to_int64',
                                  'dst_type': np.int64}).create_node()
    new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))

    reshape_for_mvn_node = Reshape(graph, {}).create_node()

    group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
    reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

    # reshape the gamma and beta constants to the correct layout: from [C] to [1, C], [1, C, 1], [1, C, 1, 1] etc.
    gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
    gamma_beta_shape[1] = -1

    gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
    beta_value = group_norm_node.in_port(2).get_source().data.get_value()
    assert gamma_value is not None, 'The gamma should be constant'
    assert beta_value is not None, 'The beta should be constant'
    gamma_value = np.reshape(gamma_value, gamma_beta_shape)
    group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
    beta_value = np.reshape(beta_value, gamma_beta_shape)
    group_norm_node.in_port(2).get_source().data.set_value(beta_value)

    # MVN
    mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                           'normalize_variance': 1,
                           'eps': group_norm_node.eps,
                           'eps_mode': 'inside_sqrt'}).create_node()
    mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

    # MVN axes
    _, rank = get_shape_and_rank_nodes_by_port(mvn_node.in_port(0).get_connection().get_source(),
                                               return_as_a_scalar=True)
    rng = create_op_with_const_inputs(graph, Range, {0: int64_array(1), 2: int64_array(1)},
                                      {'name': group_norm_node.name + '/Range', 'output_type': np.int64})
    mvn_node.in_port(1).connect(rng.out_port(0))
    rng.in_port(1).connect(rank.out_port(0))

    # reshape to the initial shape before multiplying with gamma and adding beta
    reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
    reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
    reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

    mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
    mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
    group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

    add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
    add_node.in_port(0).connect(mul_node.out_port(0))
    group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

    group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))
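# --- Illustration (editor's sketch, not part of the transformation above) ---
# A minimal NumPy sketch, with hypothetical shapes, of the decomposition:
# GroupNorm over [N, C, H, W] equals MVN over a [N * G, C // G, H, W] view
# with axes 1..rank-1, followed by the per-channel scale and shift.
import numpy as np

N, C, H, W, G, eps = 2, 8, 4, 4, 4, 1e-5
x = np.random.randn(N, C, H, W).astype(np.float32)
gamma = np.random.randn(C).astype(np.float32).reshape(1, -1, 1, 1)  # reshaped from [C]
beta = np.random.randn(C).astype(np.float32).reshape(1, -1, 1, 1)

g = x.reshape(N * G, C // G, H, W)                  # Reshape for MVN
mean = g.mean(axis=(1, 2, 3), keepdims=True)        # MVN axes = range(1, rank)
var = g.var(axis=(1, 2, 3), keepdims=True)
normalized = ((g - mean) / np.sqrt(var + eps)).reshape(N, C, H, W)
out = normalized * gamma + beta                     # Mul, then Add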
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    log.debug('UpsampleToResample is triggered')
    upsample = match['upsample']
    upsample_name = upsample.soft_get('name', upsample.id)
    input_shape = upsample.in_port(0).data.get_shape()
    input_shape_rank = len(input_shape)
    if input_shape_rank not in [4, 5]:
        log.warning('The input shape is not 4D or 5D for op {}'.format(upsample.soft_get('name')))
        return

    depth_scale = None
    layout = graph.graph['layout']

    if len(upsample.in_nodes()) == 2:
        if upsample.in_node(1).value is None:
            return
        scales = upsample.in_node(1).value
        assert len(scales) in (4, 5), 'Supported scales rank is 4 or 5, but it is {} for node {}'.format(
            len(scales), upsample_name)
        if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
            return
        height_scale = scales[get_height_dim(layout, input_shape_rank)]
        width_scale = scales[get_width_dim(layout, input_shape_rank)]
        if len(scales) == 5:
            depth_scale = scales[get_depth_dim(layout, input_shape_rank)]
    else:
        height_scale = upsample['height_scale']
        width_scale = upsample['width_scale']

    if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
        upsample.in_port(1).disconnect()

    upsample_name = upsample.soft_get('name', upsample.id)
    shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()

    layout = graph.graph['layout']

    if input_shape_rank == 4:
        begin_value = int64_array([get_height_dim(layout, input_shape_rank)])
        factor_value = float32_array([height_scale, width_scale])
    else:
        begin_value = int64_array([get_depth_dim(layout, input_shape_rank)])
        factor_value = float32_array([depth_scale, height_scale, width_scale])

    ss = create_op_with_const_inputs(graph, StridedSlice,
                                     {1: begin_value,
                                      2: int64_array([get_width_dim(layout, input_shape_rank) + 1]),
                                      3: int64_array([1])},
                                     {'name': upsample_name + '/ss_0_port',
                                      'begin_mask': int64_array([1]),
                                      'end_mask': int64_array([1]),
                                      'new_axis_mask': int64_array([0]),
                                      'shrink_axis_mask': int64_array([0]),
                                      'ellipsis_mask': int64_array([0])})

    mul = create_op_node_with_second_input(graph, Mul, factor_value,
                                           {'name': upsample_name + '/factor_mul'})

    source = upsample.in_port(0).get_connection().get_source()
    source.connect(shape.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))
    ss.out_port(0).connect(mul.in_port(0))

    # Create Interpolate operation
    if input_shape_rank == 4:
        axes = int64_array([get_height_dim(layout, input_shape_rank),
                            get_width_dim(layout, input_shape_rank)])
    else:
        axes = int64_array([get_depth_dim(layout, input_shape_rank),
                            get_height_dim(layout, input_shape_rank),
                            get_width_dim(layout, input_shape_rank)])

    axes_node = Const(graph, {'name': upsample_name + '/axis', 'value': axes}).create_node()

    interpolate = Interpolate(graph, {'mode': upsample.attrs()['mode'], 'antialias': 0,
                                      'pads_begin': int64_array([0]), 'pads_end': int64_array([0]),
                                      'coordinate_transformation_mode': 'half_pixel',
                                      'nearest_mode': 'round_prefer_floor', 'cube_coeff': -0.75,
                                      'shape_calculation_mode': 'scales',
                                      'version': 'opset4', 'in_ports_count': 4}).create_node()

    upsample.add_input_port(1, skip_if_exist=True)
    assert upsample.in_port(1).disconnected()
    mul.out_port(0).connect(interpolate.in_port(1))
    axes_node.out_port(0).connect(interpolate.in_port(3))

    scales_node = Const(graph, {'name': upsample_name + '/scales',
                                'value': factor_value}).create_node()
    scales_node.out_port(0).connect(interpolate.in_port(2))

    upsample.in_port(0).get_connection().set_destination(interpolate.in_port(0))
    upsample.out_port(0).get_connection().set_source(interpolate.out_port(0))

    rename_nodes([(upsample, upsample_name + '/delete'), (interpolate, upsample_name)])

    convert_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    convert_to_int = Cast(graph, dict(dst_type=np.int64)).create_node()

    mul.in_port(0).get_connection().insert_node(convert_to_float)
    mul.out_port(0).get_connection().insert_node(convert_to_int)
def replace_resize(graph: Graph, resize: Node): log.debug("Converting of ONNX Resize-10 to Interpolate-4 " "is triggered for node {}.".format( resize.soft_get('name', resize.id))) resize_name = resize.soft_get('name', resize.id) rank_node = Rank(graph, {'name': resize_name + '/max_axes'}).create_node() range_node = create_op_with_const_inputs(graph, Range, { 0: int64_array(2), 2: int64_array(1) }, {'name': resize_name + '/axes'}) sizes_ss = create_op_with_const_inputs(graph, StridedSlice, { 1: int64_array([2]), 2: int64_array([0]), 3: int64_array([1]) }, { 'name': resize_name + '/sizes_ss', 'begin_mask': int64_array([1]), 'end_mask': int64_array([0]), 'new_axis_mask': int64_array([0]), 'shrink_axis_mask': int64_array([0]), 'ellipsis_mask': int64_array([0]) }) scales_ss = create_op_with_const_inputs( graph, StridedSlice, { 1: int64_array([2]), 2: int64_array([0]), 3: int64_array([1]) }, { 'name': resize_name + '/scales_ss', 'begin_mask': int64_array([1]), 'end_mask': int64_array([0]), 'new_axis_mask': int64_array([0]), 'shrink_axis_mask': int64_array([0]), 'ellipsis_mask': int64_array([0]) }) rank_node.out_port(0).connect(range_node.in_port(1)) interpolate_node = Interpolate( graph, { 'version': 'opset4', 'mode': 'linear_onnx' if resize.mode == 'linear' else 'nearest', 'coordinate_transformation_mode': 'asymmetric', 'cube_coeff': -0.75, 'nearest_mode': 'simple', 'pads_begin': int64_array([0]), 'pads_end': int64_array([0]), 'antialias': 0, 'shape_calculation_mode': 'scales', 'in_ports_count': 4 }).create_node() range_node.out_port(0).connect(interpolate_node.in_port(3)) shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node() # When we calculate 'sizes' input as floor(input_shape * scales), we can get incorrect 'sizes' if, e.g., # scales = [1.0, 1.0, 1.33333, 2.0], input_shape = [1, 3, 30, 200], because # input_shape * scales = [1, 3, 39.9999, 400], and floor(input_shape * scales)[2] == 39, not 40. # Maybe we need to calculate 'sizes' input as floor(input_shape * scales + eps), where eps is some small # floating point number, e.g. 1.0e-5. But, in this case, if scales = [1.0, 1.0, 1.333333, 2.0], # input_shape = [1, 3, 30, 200], floor(input_shape * scales + eps) = 39, not 40, because # input_shape[2] * scales[2] + 1.0e-5 = 39.99991. # Hence, we need to calculate 'sizes' as floor(input_shape * (scales + eps)). 
add_node = create_op_with_const_inputs(graph, Add, {1: float_array([1.0e-5])}, {'name': resize_name + '/Add'}) dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node() shape_of.out_port(0).connect(cast_shape_to_float.in_port(0)) mul_node = Mul(graph, { 'name': resize_name + '/Mul' }).create_node([cast_shape_to_float, add_node]) floor_node = Floor(graph, { 'name': resize_name + '/Floor' }).create_node([mul_node]) cast_mul_result_to_int = Cast(graph, { 'dst_type': np.int64 }).create_node([floor_node]) cast_mul_result_to_int.out_port(0).connect(sizes_ss.in_port(0)) sizes_ss.out_port(0).connect(interpolate_node.in_port(1)) scales_ss.out_port(0).connect(interpolate_node.in_port(2)) connection_of_resize_input = resize.in_port(0).get_connection() connection_of_resize_input.set_destination(interpolate_node.in_port(0)) connection_of_scales = resize.in_port(1).get_connection() connection_of_scales.set_destination(scales_ss.in_port(0)) connection_of_resize_input.get_source().connect(shape_of.in_port(0)) connection_of_resize_input.get_source().connect(rank_node.in_port(0)) connection_of_scales.get_source().connect(add_node.in_port(0)) rename_nodes([(resize, resize_name + '/delete'), (interpolate_node, resize_name)]) resize.out_port(0).get_connection().set_source( interpolate_node.out_port(0))
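# --- Illustration (editor's sketch, not part of the transformation above) ---
# The comment above in numbers: with float32 inputs, floor(input_shape * scales)
# loses a pixel, while floor(input_shape * (scales + 1e-5)) does not.
import numpy as np

input_shape = np.array([1, 3, 30, 200], dtype=np.float32)
scales = np.array([1.0, 1.0, 1.33333, 2.0], dtype=np.float32)

print(np.floor(input_shape * scales))           # [  1.   3.  39. 400.] -- 39, not 40
print(np.floor(input_shape * (scales + 1e-5)))  # [  1.   3.  40. 400.]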
def replace_pattern(graph: Graph, match: dict):
    node = match['pool']
    node_name = node.soft_get('name', node.id)

    if node.pool_step is None:
        node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

    # create Reshape before the pooling
    # shape = [in_shape[0], pool_stride, 1, in_shape[1] / pool_stride]
    i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()

    dst_dtype = np.float32  # even if data_type=FP16, use float32 for shape values

    shape = Cast(graph, {'name': node_name + '/to_float', 'dst_type': dst_dtype}).create_node()
    i_shape.in_port(0).connect(node.in_port(0).get_source())
    shape.in_port(0).connect(i_shape.out_port(0))

    N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])

    div = create_op_with_const_inputs(graph, Div, {1: float32_array([node.pool_stride])},
                                      {'name': node_name + '/div_stride_h'})
    div.in_port(0).connect(H.out_port(0))

    concat = create_op_with_const_inputs(graph, Concat,
                                         {1: float32_array([node.pool_stride]), 2: float32_array([1])},
                                         {'name': node_name + '/concat_all_dims',
                                          'in_ports_count': 4, 'axis': 0})
    concat.in_port(0).connect(N.out_port(0))
    concat.in_port(3).connect(div.out_port(0))

    reshape_pattern = Cast(graph, {'name': node_name + '/to_int', 'dst_type': np.int64}).create_node()
    concat.out_port(0).connect(reshape_pattern.in_port(0))

    reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
    reshape_in.in_port(1).connect(reshape_pattern.out_port(0))

    # create Reshape after the pooling
    reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                   {'name': node_name + '/reshape_out'})

    # connect input_reshape_node
    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(source)
    # connect output_reshape_node
    node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    node.out_port(0).connect(reshape_out.in_port(0))
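# --- Illustration (editor's sketch, not part of the transformation above) ---
# The shape assembled for the input Reshape, with hypothetical values: dims are
# computed in float32 (Div on H) and cast back to int64 before the Reshape.
import numpy as np

in_shape = np.array([8, 64], dtype=np.float32)  # [N, H], Kaldi-style 2D input
pool_stride = 4.0
new_shape = np.array([in_shape[0], pool_stride, 1.0, in_shape[1] / pool_stride])
print(new_shape.astype(np.int64))               # [ 8  4  1 16]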
def replace_resize(graph: Graph, resize: Node): log.debug("Converting of ONNX Resize-11 to Interpolate-4 " "is triggered for node {}.".format( resize.soft_get('name', resize.id))) input_shape = resize.in_port(0).data.get_shape() input_rank = len(input_shape) resize_name = resize.soft_get('name', resize.id) if input_rank not in {4, 5}: log.warning( 'The input shape is not 4D or 5D for op with name {}'.format( resize_name)) return assert (resize.is_in_port_connected(0) and (resize.is_in_port_connected(2) or resize.is_in_port_connected(3))), \ "Scales or sizes inputs must be connected to Node {} with op {}.".format(resize.soft_get("name", resize.id), resize.op) assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \ 'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize.soft_get("name", resize.id)) layout = graph.graph['layout'] if input_rank == 4: begin_dim = get_height_dim(layout, input_rank) end_dim = get_width_dim(layout, input_rank) + 1 else: begin_dim = get_depth_dim(layout, input_rank) end_dim = get_width_dim(layout, input_rank) + 1 sizes_ss = create_op_with_const_inputs( graph, StridedSlice, { 1: int64_array([begin_dim]), 2: int64_array([end_dim]), 3: int64_array([1]) }, { 'name': resize_name + '/StridedSlice_sizes', 'begin_mask': int64_array([1]), 'end_mask': int64_array([1]), 'new_axis_mask': int64_array([0]), 'shrink_axis_mask': int64_array([0]), 'ellipsis_mask': int64_array([0]) }) scales_ss = create_op_with_const_inputs( graph, StridedSlice, { 1: int64_array([begin_dim]), 2: int64_array([end_dim]), 3: int64_array([1]) }, { 'name': resize_name + '/StridedSlice_scales', 'begin_mask': int64_array([1]), 'end_mask': int64_array([1]), 'new_axis_mask': int64_array([0]), 'shrink_axis_mask': int64_array([0]), 'ellipsis_mask': int64_array([0]) }) axes_node = Const( graph, { 'name': resize_name + '/axis', 'value': int64_array(np.arange(begin_dim, end_dim)) }).create_node() shape_calculation_mode = 'sizes' if resize.is_in_port_connected( 3) else 'scales' interpolate_node = Interpolate( graph, { 'version': 'opset4', 'mode': convert_mode(resize.mode), 'coordinate_transformation_mode': resize.coordinate_transformation_mode, 'cube_coeff': resize.cube_coeff, 'nearest_mode': resize.nearest_mode, 'pads_begin': int64_array([0]), 'pads_end': int64_array([0]), 'antialias': 0, 'shape_calculation_mode': shape_calculation_mode, 'in_ports_count': 4 }).create_node() axes_node.out_port(0).connect(interpolate_node.in_port(3)) shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node() add_node = create_op_with_const_inputs(graph, Add, {1: float_array([1.0e-5])}, {'name': resize_name + '/Add'}) dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values if not resize.is_in_port_connected(3): cast_shape_to_float = Cast(graph, { 'dst_type': dst_dtype }).create_node() mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node() shape_of.out_port(0).connect(cast_shape_to_float.in_port(0)) cast_shape_to_float.out_port(0).connect(mul_node.in_port(0)) cast_add_result_to_int = Cast(graph, { 'dst_type': np.int64 }).create_node() floor_node = Floor(graph, { 'name': resize_name + '/Floor' }).create_node() mul_node.out_port(0).connect(add_node.in_port(0)) add_node.out_port(0).connect(floor_node.in_port(0)) floor_node.out_port(0).connect(cast_add_result_to_int.in_port(0)) cast_add_result_to_int.out_port(0).connect(sizes_ss.in_port(0)) sizes_ss.out_port(0).connect(interpolate_node.in_port(1)) 
scales_ss.out_port(0).connect(interpolate_node.in_port(2)) connection_of_resize_input = resize.in_port(0).get_connection() connection_of_resize_input.set_destination(interpolate_node.in_port(0)) connection_of_scales = resize.in_port(2).get_connection() connection_of_scales.set_destination(scales_ss.in_port(0)) connection_of_resize_input.get_source().connect(shape_of.in_port(0)) connection_of_scales.get_source().connect(mul_node.in_port(1)) else: cast_shape_to_float = Cast(graph, { 'dst_type': dst_dtype }).create_node() cast_sizes_to_float = Cast(graph, { 'dst_type': dst_dtype }).create_node() div_node = Div(graph, {'name': resize_name + '/Div'}).create_node() cast_sizes_to_float.out_port(0).connect(div_node.in_port(0)) cast_shape_to_float.out_port(0).connect(div_node.in_port(1)) shape_of.out_port(0).connect(cast_shape_to_float.in_port(0)) div_node.out_port(0).connect(add_node.in_port(0)) add_node.out_port(0).connect(scales_ss.in_port(0)) scales_ss.out_port(0).connect(interpolate_node.in_port(2)) sizes_ss.out_port(0).connect(interpolate_node.in_port(1)) connection_of_resize_input = resize.in_port(0).get_connection() connection_of_resize_input.set_destination(interpolate_node.in_port(0)) connection_of_sizes = resize.in_port(3).get_connection() connection_of_sizes.set_destination(sizes_ss.in_port(0)) connection_of_resize_input.get_source().connect(shape_of.in_port(0)) connection_of_sizes.get_source().connect( cast_sizes_to_float.in_port(0)) rename_nodes([(resize, resize_name + '/delete'), (interpolate_node, resize_name)]) resize.out_port(0).get_connection().set_source( interpolate_node.out_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    node = match['op']
    node_name = node.soft_get('name', node.id)
    node.is_training = False

    shape = node.in_port(1).data.get_shape()
    assert shape is not None, 'The shape of scale input of the BatchNorm node {} is not defined'.format(node.name)

    bn_mean = Const(graph, {'name': node_name + '/mean',
                            'value': np.zeros(shape, dtype=np.float32),
                            'override_output_shape': True}).create_node()
    bn_std = Const(graph, {'name': node_name + '/std',
                           'value': np.ones(shape, dtype=np.float32),
                           'override_output_shape': True}).create_node()
    node.in_port(3).get_connection().set_source(bn_mean.out_port(0))
    node.in_port(4).get_connection().set_source(bn_std.out_port(0))

    # save the original shape
    original_shape = Shape(graph, {'name': node.in_port(0).get_source().node.soft_get('name')}).create_node()
    original_shape.in_port(0).connect(node.in_port(0).get_source())

    input_rank = len(node.in_port(0).data.get_shape())
    rng = create_op_with_const_inputs(graph, Range,
                                      {0: int64_array(1), 1: int64_array(input_rank - 1), 2: int64_array(1)},
                                      {'name': node_name + '/Range', 'output_type': np.int64})
    mvn = MVN(graph, {'name': node_name + '/mvn_',
                      'eps': node.soft_get('eps', 1e-6),
                      'eps_mode': 'inside_sqrt',
                      'normalize_variance': 1,
                      'override_output_shape': True}).create_node()
    node.in_port(0).get_connection().insert_node(mvn)
    mvn.in_port(1).connect(rng.out_port(0))

    reshape_4d = create_op_node_with_second_input(graph, Reshape, int64_array([1, -1, 0, 0]),
                                                  {'override_output_shape': True,
                                                   'name': node_name + '/fused_batch_and_channels'})
    mvn.in_port(0).get_connection().insert_node(reshape_4d)

    # restore the original shape
    reshape_back = Reshape(graph, {'name': node_name + '/restore_shape',
                                   'override_output_shape': True}).create_node()
    reshape_back.in_port(1).connect(original_shape.out_port(0))
    mvn.out_port(0).get_connection().insert_node(reshape_back)
def replace_sub_graph(self, graph: Graph, match: dict):
    if not check_applicability(match):
        return

    reshape = match['reshape']
    div_name = match['division'].name

    input_shape = Shape(graph, dict(name=div_name + '/shape/MVN_T_')).create_node()

    shape_of_reshape = reshape.in_port(1).get_connection().get_source().node.value
    c1, c2 = shape_of_reshape[1], shape_of_reshape[2]
    c = c1 * c2

    new_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([0, 0, 0, c1, c2]),
                                                   dict(name=div_name + '/first_reshape/MVN_T_'))
    permute_order = int64_array([0, 1, 2, 4, 3])
    first_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                     dict(name=div_name + '/first_permute/MVN_T_'), new_reshape)

    add = match['add']
    variance = match['variance']
    eps_port_num = 0 if add.in_port(0).get_connection().get_source().node.id != variance.id else 1
    eps = add.in_port(eps_port_num).get_connection().get_source().node
    mvn_node = create_op_with_const_inputs(graph, MVN, {1: int64_array([1, 2, 3])},
                                           dict(name=div_name + '/MVN/MVN_T_',
                                                eps=eps.value, normalize_variance=1,
                                                eps_mode='inside_sqrt'))
    first_permute.out_port(0).connect(mvn_node.in_port(0))

    second_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                      dict(name=div_name + '/second_permute/MVN_T_'), mvn_node)
    new_reshape2 = Reshape(graph, dict(name=div_name + '/second_reshape/MVN_T_')).create_node()
    second_permute.out_port(0).connect(new_reshape2.in_port(0))

    gamma_val = np.reshape(match['gamma_identity'].in_port(0).get_connection().get_source().node.value,
                           int64_array([1, 1, 1, c]))
    new_mul = create_op_node_with_second_input(graph, Mul, gamma_val,
                                               dict(name=match['mul'].name + '/MVN_T_'), new_reshape2)
    beta_val = np.reshape(match['beta_identity'].in_port(0).get_connection().get_source().node.value,
                          int64_array([1, 1, 1, c]))
    new_add2 = create_op_node_with_second_input(graph, Add, beta_val,
                                                dict(name=match['add2'].name + '/MVN_T_'), new_mul)

    transpose_connection = match['transpose'].in_port(0).get_connection()
    before_transpose = transpose_connection.get_source().node
    transpose_connection.set_destination(new_reshape.in_port(0))

    input_shape.out_port(0).connect(new_reshape2.in_port(1))
    before_transpose.out_port(0).connect(input_shape.in_port(0))

    match['transpose2'].out_port(0).get_connection().set_source(new_add2.out_port(0))
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
    reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                            dict(name='do_reshape_classes'),
                                                            match.single_input_node(1)[0])

    initial_priors_node = match.single_input_node(2)[0]
    priors_name = initial_priors_node.soft_get('name', initial_priors_node.id)
    # the model calculates identical prior boxes for each batch, so we take the first slice of them
    begin = Const(graph, {'value': mo_array([0, 0, 0], dtype=np.int32)}).create_node()
    end = Const(graph, {'value': mo_array([1, 0, 0], dtype=np.int32)}).create_node()
    stride = Const(graph, {'value': mo_array([1, 1, 1], dtype=np.int32)}).create_node()

    priors_node = StridedSlice(graph, {'name': priors_name + '/0_batch_slice',
                                       'begin_mask': int64_array([1, 1, 1]),
                                       'end_mask': int64_array([1, 0, 0]),
                                       'new_axis_mask': int64_array([0]),
                                       'shrink_axis_mask': int64_array([0]),
                                       'ellipsis_mask': int64_array([0])}).create_node()

    initial_priors_node.out_port(0).connect(priors_node.in_port(0))
    begin.out_port(0).connect(priors_node.in_port(1))
    end.out_port(0).connect(priors_node.in_port(2))
    stride.out_port(0).connect(priors_node.in_port(3))

    placeholders = graph.get_op_nodes(type='Parameter')
    assert len(placeholders) == 1, "{} replacer requires the model to have one Placeholder, but the current " \
                                   "model has {} placeholders".format(self.replacement_id, len(placeholders))
    placeholder = placeholders[0]

    # scale prior boxes to the [0, 1] interval
    node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
    priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

    broadcast = Broadcast(graph, {'name': 'scales_broadcast'}).create_node()
    shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
    priors_node.out_port(0).connect(shape_of_priors.in_port(0))
    broadcast.in_port(1).connect(shape_of_priors.out_port(0))
    broadcast.in_port(0).connect(node_with_scales_for_prior_boxes.out_port(0))

    priors_scale_node.in_port(0).connect(priors_node.out_port(0))
    priors_scale_node.in_port(1).connect(broadcast.out_port(0))

    try:
        variance = match.custom_replacement_desc.custom_attributes['variance']
    except KeyError:
        raise Error('There is no variance attribute in the {} replacement config file `custom_attributes`'
                    ''.format(self.replacement_id))

    priors = self.append_variances(priors_scale_node, variance)

    # calculate prior boxes widths and heights
    split_node = create_op_with_const_inputs(graph, VariadicSplit,
                                             {1: int64_array(2), 2: int64_array([1, 1, 1, 1])},
                                             {'out_ports_count': 4}, priors_scale_node)
    priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')
                            ).create_node([(split_node, 2), (split_node, 0)])
    priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')
                             ).create_node([(split_node, 3), (split_node, 1)])

    # concat widths and heights into a single tensor and multiply with the box coordinate regression values
    # WA with 3 Concats instead of 1 for keeping the model reshapable
    # concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
    #                                           'in_ports_count': 4}).create_node(
    #     [priors_width_node, priors_height_node, priors_width_node, priors_height_node])
    concat_1 = Concat(graph, {'name': 'concat_width_height', 'axis': -1,
                              'in_ports_count': 2}).create_node([priors_width_node, priors_height_node])
    concat_2 = Concat(graph, {'name': 'concat_width_height_width', 'axis': -1,
                              'in_ports_count': 2}).create_node([concat_1, priors_width_node])
    concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
                                              'in_ports_count': 2}).create_node([concat_2, priors_height_node])

    applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
        [concat_width_height_node, match.single_input_node(0)[0]])

    # reshape to a 2D tensor, as the Inference Engine DetectionOutput layer expects
    reshape_regression_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                               dict(name='reshape_regression'),
                                                               applied_width_height_regressions_node)

    detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
    # get the NMS threshold from the original network
    iou_threshold = None
    nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
    if len(nms_nodes) > 0:
        # it is highly unlikely that NMS uses different thresholds for different classes; moreover,
        # DetectionOutput accepts only scalar values for iou_threshold (nms_threshold)
        iou_threshold = nms_nodes[0].in_node(3).value
    if iou_threshold is None:
        raise Error('During {} `iou_threshold` was not retrieved from the RetinaNet graph'.format(
            self.replacement_id))

    detection_output_node = detection_output_op.create_node(
        [reshape_regression_node, reshape_classes_node, priors],
        dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
             variance_encoded_in_target=0, background_label_id=1000))

    # As the outputs are replaced with a post-processing node, the outgoing tensor names no longer
    # correspond to the original tensors and should be removed from output->Result edges
    out_nodes = []
    for out in range(match.outputs_count()):
        out_nodes.append(match.output_node(out)[0])
    clear_tensor_names_info(out_nodes)

    return {'detection_output_node': detection_output_node}
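
# Illustrative only: the VariadicSplit/Sub/Concat chain above written in numpy,
# assuming priors are laid out as [x_min, y_min, x_max, y_max] along the last axis
# (the coordinate order is an assumption; only the 2-0 and 3-1 subtractions are given).
def _priors_wh_reference(priors):
    x_min, y_min, x_max, y_max = np.split(priors, 4, axis=-1)  # VariadicSplit into 4 slices
    w, h = x_max - x_min, y_max - y_min                        # the two Sub nodes
    return np.concatenate([w, h, w, h], axis=-1)               # the three chained Concats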
def replace_pattern(self, graph: Graph, match: dict):
    node = match['crop']
    assert node.has_valid('axis')
    node_axis = self.list_to_ndarray(node.axis)

    in_shape = node.in_port(0).data.get_shape()
    shape_rank = in_shape.size
    axis_mask = int64_array([1 if i in node_axis else 0 for i in range(shape_rank)])
    begin_mask = axis_mask.copy()
    end_mask = axis_mask.copy()

    ss = StridedSlice(graph, {'name': node.soft_get('name', node.id) + '/strided_slice',
                              'begin_mask': begin_mask,
                              'end_mask': end_mask,
                              'new_axis_mask': np.zeros(len(end_mask)),
                              'shrink_axis_mask': np.zeros(len(end_mask)),
                              'ellipsis_mask': np.zeros(len(end_mask))}).create_node()

    if len(node.in_nodes()) == 2 and node.has_valid('offset'):
        # Crop Type 1: begin = offset, end = offset + shape of the second input
        begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node.offset),
                              'name': ss.name + '/begin'}).create_node()
        shape = Shape(graph, {'name': ss.name + '/shape_of_crop'}).create_node()
        end = Add(graph, {'name': ss.name + '/end'}).create_node()
        node.in_port(1).get_connection().get_source().connect(shape.in_port(0))
        node.in_port(1).disconnect()
        shape.out_port(0).connect(end.in_port(0))
        begin.out_port(0).connect(end.in_port(1))
    elif node.has_valid('dim') and node.has_valid('offset'):
        # Crop Type 2: begin = offset, end = offset + dim
        node_dim = self.list_to_ndarray(node.dim)
        node_offset = self.list_to_ndarray(node.offset)
        assert node_dim.size == node_offset.size == node_axis.size
        begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_offset),
                              'name': ss.name + '/begin'}).create_node()
        end_values = mo_array([node_offset[i] + node_dim[i] for i in range(len(node_dim))])
        end = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, end_values),
                            'name': ss.name + '/end'}).create_node()
    elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
        # Crop Type 3: begin = crop_begin, end = input shape - crop_end
        node_crop_begin = self.list_to_ndarray(node.crop_begin)
        node_crop_end = self.list_to_ndarray(node.crop_end)
        assert len(node_crop_begin) == len(node_crop_end) == len(node_axis)
        begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_crop_begin),
                              'name': ss.name + '/begin'}).create_node()
        shape = Shape(graph, {'name': ss.name + '/shape'}).create_node()
        end = Add(graph, {'name': ss.name + '/end'}).create_node()
        const = Const(graph, {'value': -1 * self.mask_normalizer(shape_rank, node_axis, node_crop_end),
                              'name': ss.name + '/const'}).create_node()
        node.in_port(0).get_connection().get_source().connect(shape.in_port(0))
        shape.out_port(0).connect(end.in_port(0))
        const.out_port(0).connect(end.in_port(1))
    else:
        raise Exception("Unknown type of Crop")

    source = node.in_port(0).get_connection().get_source()
    stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64),
                           'name': ss.name + '/stride'}).create_node()

    source.connect(ss.in_port(0))
    begin.out_port(0).connect(ss.in_port(1))
    end.out_port(0).connect(ss.in_port(2))
    stride.out_port(0).connect(ss.in_port(3))

    node.in_port(0).disconnect()
    node.out_port(0).get_connection().set_source(ss.out_port(0))

    ss['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}
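
# Illustrative only: Crop Type 2 expressed as plain numpy slicing; the equivalent
# StridedSlice uses begin = offset and end = offset + dim on each cropped axis.
def _crop_type2_reference(x, axes, offsets, dims):
    slices = [slice(None)] * x.ndim
    for ax, off, d in zip(axes, offsets, dims):
        slices[ax] = slice(off, off + d)
    return x[tuple(slices)]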
def create_const_with_batch_from_input(producer_port: Port, second_dim, value=0, precision=np.float32):
    """
    Create a constant whose batch dimension is taken from producer_port and whose second dimension is second_dim.
    :param producer_port: take batch from this port
    :param second_dim: second dimension for the created constant
    :param value: value to initialize the constant with
    :param precision: precision for the constant
    :return: created constant node
    """
    graph = producer_port.node.graph
    input_name = producer_port.node.soft_get('name', producer_port.node.id)

    # reuse an existing ShapeOf consumer of the producer port, if any
    shape_of_input = None
    for dest in producer_port.get_destinations():
        if dest.node.soft_get('op') == "ShapeOf":
            shape_of_input = dest.node
            break

    if shape_of_input is None:
        shape_of_input = Shape(graph, {'name': input_name + '/Shape'}).create_node()
        shape_of_input.in_port(0).connect(producer_port)

    # reuse an existing Crop that already extracts the batch dimension, if any
    get_batch = None
    for dest in shape_of_input.out_port(0).get_destinations():
        if dest.node.soft_get('op') == "Crop" and \
                np.array_equal(dest.node.in_port(1).get_source().node.soft_get('value', []), int64_array([1])):
            get_batch = dest.node
            break

    if get_batch is None:
        get_batch = create_op_node_with_second_input(graph, Crop, int64_array([1]),
                                                     {'name': shape_of_input.name + '/Crop',
                                                      'axis': int64_array([0]),
                                                      'offset': int64_array([0])},
                                                     shape_of_input)

    # reuse an existing Concat that already appends second_dim, if any
    mem_shape = None
    for dest in get_batch.out_port(0).get_destinations():
        if dest.node.soft_get('op') == "Concat" and \
                np.array_equal(dest.node.in_port(1).get_source().node.soft_get('value', []),
                               int64_array([second_dim])):
            mem_shape = dest.node
            break

    if mem_shape is None:
        mem_shape = create_op_node_with_second_input(graph, Concat, int64_array([second_dim]),
                                                     {'name': get_batch.name + '/Concat',
                                                      'axis': 0,
                                                      'in_ports_count': 2}, get_batch)

    # reuse an existing Broadcast with the same fill value, if any (port 0 of Broadcast holds the value)
    init_value_prev_lstm_output = None
    for dest in mem_shape.out_port(0).get_destinations():
        if dest.node.soft_get('op') == "Broadcast" and \
                np.array_equal(dest.node.in_port(0).get_source().node.soft_get('value', []),
                               mo_array([value], dtype=precision)):
            init_value_prev_lstm_output = dest.node
            break

    if init_value_prev_lstm_output is None:
        init_value_prev_lstm_output = create_op_with_const_inputs(graph, Broadcast,
                                                                  {0: mo_array([value], dtype=precision)},
                                                                  {'name': mem_shape.name + '/Broadcast'})
        init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))

    return init_value_prev_lstm_output
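
# Illustrative only: what the ShapeOf -> Crop -> Concat -> Broadcast chain built by
# create_const_with_batch_from_input evaluates to for a concrete input tensor `x`.
def _const_with_batch_reference(x, second_dim, value=0, precision=np.float32):
    batch = np.array(x.shape[:1], dtype=np.int64)          # ShapeOf + Crop(axis=0, dim=1)
    mem_shape = np.concatenate([batch, [second_dim]])      # Concat along axis 0
    return np.broadcast_to(mo_array([value], dtype=precision), mem_shape).copy()  # Broadcast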
def find_and_replace_pattern(self, graph: Graph):
    for node in graph.get_op_nodes(op='SpaceToBatch') + graph.get_op_nodes(op='BatchToSpace'):
        node.add_input_port(3, skip_if_exist=True)

        # convert the TF representation of the pads/crops as [N, 2] to the IE representation: [N] and [N]
        transposed_pads = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])})
        node.in_port(2).get_connection().set_destination(transposed_pads.in_port(0))
        split_pads = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
        transposed_pads.out_port(0).connect(split_pads.in_port(0))
        for port_ind in range(2):
            node.in_port(port_ind + 2).connect(split_pads.out_port(port_ind))
            node.in_port(port_ind + 2).get_connection().insert_node(
                create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])}))

        # add zeros/ones to the related inputs to align them with the data input
        in0_rank = Rank(graph, {'name': node.name + '/rank_0'}).create_node()
        in1_shape = Shape(graph, {'name': node.name + '/rank_1'}).create_node()

        diff_size = Sub(graph, {'name': node.name + '/sub_0'}).create_node()
        diff = Sub(graph, {'name': node.name + '/sub_1'}).create_node()
        const_begin = Const(graph, {'value': int64_array([1])}).create_node()
        const_pad_val = Const(graph, {'value': int64_array(1)}).create_node()

        block_shape = Pad(graph, {'name': node.name + '/aligned_block_shape',
                                  'mode': 'constant'}).create_node()

        # in case of SpaceToBatch: begin = pads_begin, end = pads_end
        # in case of BatchToSpace: begin = crops_begin, end = crops_end
        new_begin_name = '/aligned_pads_begin'
        new_end_name = '/aligned_pads_end'
        if node.type == 'BatchToSpace':
            new_begin_name = '/aligned_crops_begin'
            new_end_name = '/aligned_crops_end'

        begin = Pad(graph, {'name': node.name + new_begin_name, 'mode': 'constant'}).create_node()
        end = Pad(graph, {'name': node.name + new_end_name, 'mode': 'constant'}).create_node()

        in0_rank_1d = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                       {'name': node.name + '/1d_rank_of_0'}, in0_rank)

        node.in_port(0).get_source().connect(in0_rank.in_port(0))
        node.in_port(1).get_source().connect(in1_shape.in_port(0))
        in0_rank_1d.out_port(0).connect(diff_size.in_port(0))
        in1_shape.out_port(0).connect(diff_size.in_port(1))
        diff_size.out_port(0).connect(diff.in_port(0))
        const_begin.out_port(0).connect(diff.in_port(1))
        const_pad_val.out_port(0).connect(block_shape.in_port(3))

        inputs_array = [block_shape, begin, end]
        for idx, input_to_node in enumerate(inputs_array):
            name_of_input_to_node = input_to_node.name
            node.in_port(idx + 1).get_connection().set_destination(input_to_node.in_port(0))
            const_begin.out_port(0).connect(input_to_node.in_port(1))
            diff.out_port(0).connect(input_to_node.in_port(2))
            input_to_node.out_port(0).connect(node.in_port(idx + 1))
            convert = Cast(graph, {'name': name_of_input_to_node + '/i64',
                                   'dst_type': np.int64}).create_node()
            input_to_node.in_port(0).get_connection().insert_node(convert)
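
# Illustrative only: the pads layout conversion above in numpy terms. TF stores
# pads/crops as an [N, 2] matrix, while the IE opset expects two separate [N] vectors.
def _split_pads_reference(pads_tf):
    # pads_tf has shape [N, 2]; rows are (begin, end) pairs per axis
    pads_begin, pads_end = pads_tf.T           # Transpose + Split(axis=0) + Squeeze
    return pads_begin, pads_end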
def replace_pattern(self, graph: Graph, match: dict):
    node = match['pb']
    name = node.soft_get('name', node.id)

    graph.graph['cmd_params'].static_shape = False

    assert len(node.in_ports()) == 2

    begin = Const(graph, {'value': mo_array([2], dtype=np.int32), 'name': name + '/ss_begin'}).create_node()
    end = Const(graph, {'value': mo_array([4], dtype=np.int32), 'name': name + '/ss_end'}).create_node()
    stride = Const(graph, {'value': mo_array([1], dtype=np.int32), 'name': name + '/ss_stride'}).create_node()

    shape_0 = Shape(graph, {'name': name + '/0_port'}).create_node()
    ss_0 = StridedSlice(graph, {'name': name + '/ss_0_port',
                                'begin_mask': mo_array([1], dtype=np.int32),
                                'end_mask': mo_array([0], dtype=np.int32),
                                'new_axis_mask': mo_array([0], dtype=np.int32),
                                'shrink_axis_mask': mo_array([0], dtype=np.int32),
                                'ellipsis_mask': mo_array([0], dtype=np.int32)}).create_node()

    shape_0.out_port(0).connect(ss_0.in_port(0))
    begin.out_port(0).connect(ss_0.in_port(1))
    end.out_port(0).connect(ss_0.in_port(2))
    stride.out_port(0).connect(ss_0.in_port(3))

    source = node.in_port(0).get_connection().get_source()
    node.in_port(0).disconnect()
    source.connect(shape_0.in_port(0))
    ss_0.out_port(0).connect(node.in_port(0))

    shape_1 = Shape(graph, {'name': name + '/1_port'}).create_node()
    ss_1 = StridedSlice(graph, {'name': name + '/ss_1_port',
                                'begin_mask': mo_array([1], dtype=np.int32),
                                'end_mask': mo_array([0], dtype=np.int32),
                                'new_axis_mask': mo_array([0], dtype=np.int32),
                                'shrink_axis_mask': mo_array([0], dtype=np.int32),
                                'ellipsis_mask': mo_array([0], dtype=np.int32)}).create_node()

    shape_1.out_port(0).connect(ss_1.in_port(0))
    begin.out_port(0).connect(ss_1.in_port(1))
    end.out_port(0).connect(ss_1.in_port(2))
    stride.out_port(0).connect(ss_1.in_port(3))

    source = node.in_port(1).get_connection().get_source()
    node.in_port(1).disconnect()
    source.connect(shape_1.in_port(0))
    ss_1.out_port(0).connect(node.in_port(1))

    ss_0['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}
    ss_1['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}

    node['need_shape_inference'] = True
    node['override_output_shape'] = True
    node['V10_infer'] = True

    unsqueeze = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]), {'name': name + '/unsqueeze'})
    naked_priorbox_name = name + '/naked_not_unsqueezed'
    rename_nodes([(node, naked_priorbox_name), (unsqueeze, name)])

    node.out_port(0).get_connection().set_source(unsqueeze.out_port(0))
    node.out_port(0).connect(unsqueeze.in_port(0))
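
# Illustrative only: the StridedSlice above takes elements [2:4] of a shape vector,
# which for an (assumed) NCHW layout are the spatial dims H and W; values are made up.
def _spatial_dims_reference(input_shape_nchw):
    return input_shape_nchw[2:4]               # begin=[2], end=[4], stride=[1]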