def make_interpolate_reshape_able(self, interpolate: Node, concat: Node):
    """Drive Interpolate input port 1 from the shape of a sibling Concat input.

    A ShapeOf -> Gather(interpolate axes) chain is built on the first source returned
    by ``self.get_non_interpolate_concat_sources`` and becomes the new producer of
    Interpolate input port 1.  The graph is left untouched when either axis
    description is unknown, any axis is negative, the Concat axis is one of the
    interpolated axes, or there is no suitable Concat source.

    :param interpolate: node whose ``type`` is 'Interpolate'
    :param concat: node whose ``type`` is 'Concat'
    """
    assert interpolate.soft_get('type') == 'Interpolate'
    assert concat.soft_get('type') == 'Concat'

    interp_axes = Interpolate.get_axes(interpolate)
    concat_axis = self.get_concat_axis(concat)

    # Both axis descriptions must be known ...
    if concat_axis is None or interp_axes is None:
        return
    # ... be non-negative ...
    if np.any(interp_axes < 0) or concat_axis < 0:
        return
    # ... and must not intersect.
    if concat_axis in interp_axes:
        return

    candidates = self.get_non_interpolate_concat_sources(concat)
    if len(candidates) == 0:
        # there is no Concat input to take the shape from
        return

    graph = interpolate.graph
    donor = candidates[0]
    donor_name = donor.node.soft_get('name', donor.node.id)

    shape = Shape(graph, {'name': donor_name + '/Shape'}).create_node()
    shape.in_port(0).connect(donor)

    gather_inputs = {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)}
    gather_attrs = {'name': shape.name + '/Gathered'}
    gather = create_op_with_const_inputs(graph, Gather, gather_inputs, gather_attrs,
                                         input_node=shape)

    interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
    """Give every axes-less Roll node an explicit axes input.

    For each 'Roll' node whose input port 2 is disconnected, the data input is
    reshaped to 1D, a constant ``[0]`` is connected as the axes input, and the Roll
    output is reshaped back to the shape of the original data.  Roll nodes that
    already have input port 2 connected are left untouched.

    :param graph: graph that is transformed in place
    """
    for roll_node in graph.get_op_nodes(op='Roll'):
        if not roll_node.in_port(2).disconnected():
            # Axes are already provided for this Roll.  Skip only this node: a
            # `return` here would leave every remaining Roll node unprocessed.
            continue

        node_name = roll_node.soft_get('name', roll_node.id)

        # reshape the data input to a 1D tensor
        reshape_to_1d = create_op_node_with_second_input(
            graph, Reshape, int64_array([-1]), {'name': node_name + '/reshape'})
        roll_node.in_port(0).get_connection().insert_node(reshape_to_1d)

        # add a zero constant as the axes input of Roll
        const_zero = Const(graph, {
            'value': int64_array([0]),
            'name': node_name + '/axes'
        }).create_node()
        const_zero.out_port(0).connect(roll_node.in_port(2))

        # reshape the result back to the original shape, which is read from the
        # producer that feeds the 1D reshape
        shape_of = Shape(graph, {'name': node_name + '/shape_of'}).create_node()
        reshape_to_1d.in_port(0).get_connection().add_destination(shape_of.in_port(0))

        reshape_to_orig_shape = Reshape(graph, {}).create_node()
        # the final Reshape takes over the original Roll name
        rename_nodes([(roll_node, node_name + '/roll'),
                      (reshape_to_orig_shape, node_name)])
        shape_of.out_port(0).connect(reshape_to_orig_shape.in_port(1))
        roll_node.out_port(0).get_connection().insert_node(reshape_to_orig_shape)
def extract(cls, node: Node):
    """Store the ``output_type`` attribute derived from the TF ``out_type`` attr.

    ``np.int32`` is handed to ``tf_dtype_extractor`` as its second argument
    (presumably the fallback type -- confirm against the helper).

    :param node: node whose protobuf (``node.pb``) carries the ``out_type`` attribute
    :return: ``cls.enabled``
    """
    output_type = tf_dtype_extractor(node.pb.attr['out_type'].type, np.int32)
    Shape.update_node_stat(node, {'output_type': output_type})
    return cls.enabled
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched node with ``data / (last_dim ** 0.5)``.

    The last dimension of the data input is read through ShapeOf, converted to
    float32, raised to the power 0.5 and used as the divisor of a new Div node that
    takes over the inputs, outputs and name of the matched node.

    :param graph: graph that is transformed in place
    :param match: pattern match; ``match['op']`` is the node to replace
    """
    matched = match['op']
    base_name = matched.soft_get('name', matched.id)

    shape_of = Shape(graph, {'name': base_name + '/Shape'}).create_node()
    data_port = matched.in_port(0).get_source()
    shape_of.in_port(0).connect(data_port)

    last_dim = node_to_get_shape_value_of_indices(shape_node=shape_of, indices=[-1])

    sqrt = AttributedPower(graph, {'name': base_name + '/Sqrt',
                                   'power': mo_array(0.5)}).create_node()

    # Due to specification, Power must have inputs with the same data type.
    to_fp32 = Cast(graph, {'dst_type': np.float32,
                           'name': last_dim.name + '/ConvertToFP32'}).create_node()

    div = Div(graph, {'name': "Div"}).create_node()

    # last_dim -> Cast -> Power
    last_dim.out_port(0).connect(to_fp32.in_port(0))
    to_fp32.out_port(0).connect(sqrt.in_port(0))

    # data -> Div <- Power, Div replaces the matched node for all consumers
    matched.in_port(0).get_connection().set_destination(div.in_port(0))
    div.in_port(1).connect(sqrt.out_port(0))
    matched.out_port(0).get_connection().set_source(div.out_port(0))

    rename_nodes([(matched, base_name + '/ShouldBeDeleted'), (div, base_name)])
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched 'mxreshape' node with Reshape fed by a computed target shape.

    ``self.resolve`` is invoked repeatedly to turn the entries of ``node.dim`` into a
    list of nodes; their outputs are concatenated along axis 0 and connected to
    input port 1 of a new Reshape that takes over the data input and the output of
    the matched node.

    :param graph: graph that is transformed in place
    :param match: pattern match; ``match['mxreshape']`` is the node to replace
    """
    node = match['mxreshape']

    shape_node = Shape(graph, {'name': node.id + '/ShapeMXReshape'}).create_node()
    shape_node.in_port(0).connect(node.in_port(0).get_source())

    input_index = 0
    reshape_index = 0
    output_dims_nodes = []
    # At most one `resolve` call per entry of `dim`; calls are skipped once the
    # reshape index has run past the end of `dim`.
    for _ in node.dim:
        if reshape_index >= len(node.dim):
            continue
        input_index, reshape_index, output_dims_nodes = self.resolve(
            input_index, reshape_index, node.dim, shape_node, output_dims_nodes)

    concat_attrs = {'name': shape_node.id + '/ConcatMXReshape_',
                    'axis': 0,
                    'in_ports_count': len(output_dims_nodes)}
    concat_node = Concat(shape_node.graph, concat_attrs).create_node()
    for port_idx, dim_node in enumerate(output_dims_nodes):
        concat_node.in_port(port_idx).connect(dim_node.out_port(0))

    reshape_node = Reshape(graph, {'name': node.id + '/Reshape_'}).create_node()
    reshape_node.in_port(1).connect(concat_node.out_port(0))

    # the new Reshape takes over the data input and all consumers
    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
def squeeze_initial_states(graph: Graph, match: dict):
    """
    Squeeze input initial states of recurrent node to 2-D shape.

    A Reshape to ``[batch, hidden_size]`` is inserted in front of the hidden state
    input (port 5) and, for 'LSTM' layers, in front of the cell state input (port 6).
    ``batch`` is read at run time from the shape of input 0 at ``batch_dim``.
    """
    hidden_init_port, cell_init_port = 5, 6

    rnn_layer = match['rnn_layer']
    # Make sure the layer exposes input ports 0..6
    rnn_layer.add_sequence_of_ports(type='in', rng=range(7))
    rnn_layer_name = rnn_layer.soft_get('name', rnn_layer.id)

    assert hidden_init_port in rnn_layer.in_nodes()
    hidden_size = rnn_layer.hidden_size

    shape = Shape(graph, {'name': rnn_layer_name + '/ShapeOf'}).create_node()
    rnn_layer.in_port(0).get_source().connect(shape.in_port(0))
    batch = node_to_get_shape_value_of_indices(shape, int64_array([rnn_layer.batch_dim]))

    # [batch] ++ [hidden_size]
    new_dim = create_op_node_with_second_input(
        graph, Concat,
        second_input_value=int64_array([hidden_size]),
        op_attrs={'name': rnn_layer_name + '/HiddenStateResizeDim',
                  'in_ports_count': 2,
                  'axis': 0},
        input_node=batch)

    def insert_state_reshape(port_idx, name_suffix):
        # Reshape the state arriving at `port_idx` to the shape produced by `new_dim`
        reshape = Reshape(graph, {'name': rnn_layer_name + name_suffix,
                                  'override_output_shape': True}).create_node()
        new_dim.out_port(0).connect(reshape.in_port(1))
        rnn_layer.in_port(port_idx).get_connection().insert_node(reshape)

    insert_state_reshape(hidden_init_port, '/HiddenStateResize')

    if rnn_layer.op != 'LSTM':
        return

    assert cell_init_port in rnn_layer.in_nodes()
    insert_state_reshape(cell_init_port, '/CellStateResize')
def make_interpolate_reshapeable(interpolate, concat):
    """Drive Interpolate input port 1 from the shape of a non-Interpolate Concat input.

    Axes are canonicalized against the Interpolate output shape.  When the Concat
    axis is not one of the interpolated axes and Concat has at least one connected
    input that is not produced by an Interpolate, a ShapeOf -> Gather(axes) chain on
    the first such input becomes the new producer of Interpolate input port 1.

    :param interpolate: node whose ``type`` is 'Interpolate'
    :param concat: node whose ``type`` is 'Concat'
    """
    assert interpolate.soft_get('type') == 'Interpolate'
    assert concat.soft_get('type') == 'Concat'

    output_shape = interpolate.out_port(0).data.get_shape()

    interp_axes = []
    for raw_axis in Interpolate.get_axes(interpolate):
        interp_axes.append(get_canonical_axis_index(output_shape, raw_axis))
    concat_axis = get_canonical_axis_index(output_shape, concat.axis)
    if concat_axis in interp_axes:
        return

    # sources of all connected Concat inputs ...
    connected_sources = []
    for in_port in concat.in_ports().values():
        if not in_port.disconnected():
            connected_sources.append(in_port.get_source())
    # ... minus those produced by an Interpolate
    donors = []
    for source in connected_sources:
        if source.node.soft_get('type') != 'Interpolate':
            donors.append(source)
    if len(donors) == 0:
        return

    graph = interpolate.graph
    donor = donors[0]
    donor_name = donor.node.soft_get('name', donor.node.id)

    shape = Shape(graph, {'name': donor_name + '/Shape'}).create_node()
    shape.in_port(0).connect(donor)

    gather_inputs = {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)}
    gather_attrs = {'name': shape.name + '/Gathered'}
    gather = create_op_with_const_inputs(graph, Gather, gather_inputs, gather_attrs, shape)

    interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched 'flatten' node with a Reshape to a 2D shape.

    * ``axis == 0`` -> constant target shape ``[1, -1]``
    * ``axis == 1`` -> constant target shape ``[0, -1]``
    * otherwise     -> the product of the dimensions selected by ``axis`` is computed
      from ShapeOf and combined with ``-1``; the product comes first for a positive
      axis and second for a negative one.

    :param graph: graph that is transformed in place
    :param match: pattern match; ``match['flatten']`` is the node to replace
    """
    node = match['flatten']
    name = node.soft_get('name', node.id)

    assert node.has_valid('axis'), \
        "Flatten {} should have `axis` attribute extracted, but it's not".format(name)
    axis = node.axis

    reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()

    if axis in (0, 1):
        leading = 1 if axis == 0 else 0
        dim = Const(graph, {'value': int64_array([leading, -1]),
                            'name': reshape_node.name + '/shape'}).create_node()
    else:
        shape = Shape(graph, {'name': name + '/input_shape'}).create_node()

        # indices of the dimensions that are multiplied together
        if axis > 0:
            idxs = list(range(axis))
        else:
            idxs = list(range(axis, 0))
        axis_shape_portion = node_to_get_shape_value_of_indices(shape, idxs)

        first_dims = create_op_node_with_second_input(
            graph, ReduceProd, int64_array([0]),
            {'name': name + '/first_dims', 'keep_dims': True})
        second_dims = Const(graph, {'value': int64_array([-1]),
                                    'name': name + '/second_dims'}).create_node()

        node.in_port(0).get_source().connect(shape.in_port(0))
        axis_shape_portion.out_port(0).connect(first_dims.in_port(0))

        if axis > 0:
            order_of_dims = [first_dims, second_dims]
        else:
            order_of_dims = [second_dims, first_dims]
        dim = new_shape_node_from_shape_nodes(order_of_dims)

    reshape_node.in_port(1).connect(dim.out_port(0))
    # Reshape takes over the consumers first, then the data input
    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
def placeholder_scales(self, placeholder: Node):
    """
    Build a sub-graph that yields prior-box scales from the input image size:
    [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]

    Elements [1:3) of the placeholder shape are sliced out, converted to float32,
    raised to the power -1, swapped with Gather (order [1, 0]) and concatenated
    with themselves along axis 0.

    :param placeholder: node with a 4-element ``np.ndarray`` ``shape`` attribute
    :return: the Concat node that produces the four scales
    """
    graph = placeholder.graph
    name = placeholder.soft_get('name', placeholder.id)

    shape_value = placeholder.soft_get('shape', None)
    assert shape_value is not None, \
        "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
    assert isinstance(shape_value, np.ndarray), \
        "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
    assert shape_value.size == 4, \
        "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

    shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
    shape.in_port(0).connect(placeholder.out_port(0))

    # shape[1:3:1]
    begin = Const(graph, {'value': int64_array([1])}).create_node()
    end = Const(graph, {'value': int64_array([3])}).create_node()
    stride = Const(graph, {'value': int64_array([1])}).create_node()
    slice_attrs = {'name': name + '/get_h_w',
                   'begin_mask': int64_array([1]),
                   'end_mask': int64_array([1]),
                   'new_axis_mask': int64_array([0]),
                   'shrink_axis_mask': int64_array([0]),
                   'ellipsis_mask': int64_array([0])}
    spatial = StridedSlice(graph, slice_attrs).create_node()
    for port_idx, producer in enumerate((shape, begin, end, stride)):
        spatial.in_port(port_idx).connect(producer.out_port(0))

    # x ** -1
    power = Const(graph, {'value': float32_array([-1.])}).create_node()
    spatial_scale = Pow(graph, {}).create_node()
    spatial_scale.in_port(0).connect(spatial.out_port(0))
    spatial_scale.in_port(1).connect(power.out_port(0))

    # Power `type_infer` requires inputs to have equal data type
    convert_to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
    spatial_scale.in_port(0).get_connection().insert_node(convert_to_fp32)

    # swap the two values
    order = Const(graph, {'value': int64_array([1, 0])}).create_node()
    axis_const = Const(graph, {'value': int64_array(0)}).create_node()
    reverse = Gather(graph, {}).create_node()
    reverse.in_port(0).connect(spatial_scale.out_port(0))
    reverse.in_port(1).connect(order.out_port(0))
    axis_const.out_port(0).connect(reverse.in_port(2))

    # duplicate the swapped pair
    priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
    for port_idx in (0, 1):
        priors_scale_node.add_input_port(port_idx, skip_if_exist=True)
    for port_idx in (0, 1):
        priors_scale_node.in_port(port_idx).connect(reverse.out_port(0))

    return priors_scale_node
def find_and_replace_pattern(self, graph: Graph):
    """Replace every 'Reverse' node with an equivalent ReverseSequence sub-graph.

    Each Reverse node must have input port 1 disconnected and a valid ``axis``
    attribute.  All processed Reverse nodes are removed from the graph at the end.

    :param graph: graph that is transformed in place
    """
    reverse_nodes = graph.get_op_nodes(op='Reverse')
    for reverse in reverse_nodes:
        reverse_name = reverse.soft_get('name', reverse.id)

        assert reverse.in_port(1).disconnected()
        assert reverse.has_valid('axis')

        is_rank_one = len(reverse.in_port(0).data.get_shape()) == 1

        # 1. Add new dimension as batch for rank = 1 to have batch != seq_axis
        if is_rank_one:
            unsq_node = create_op_node_with_second_input(
                graph, Unsqueeze, int64_array([0]), {'name': reverse_name + "/Unsqueeze"})
            reverse.in_port(0).get_source().connect(unsq_node.in_port(0))
            new_in = unsq_node.out_port(0)
            batch_axis, seq_axis = 0, 1
        else:
            new_in = reverse.in_port(0).get_source()
            seq_axis = reverse['axis']
            batch_axis = 1 if seq_axis == 0 else 0

        # 2. For ReverseSequence 1-port input is seq_lengths => create this input node as
        #    shape[seq_axis] broadcasted to shape[batch_axis]
        #    in ---> ShapeOf ----> Gather(seq_axis)  ----> Broadcast ----->
        #                     |                                |
        #                     ---> Gather(batch_axis) ---------|
        shape_node = Shape(graph, {'name': reverse_name + "/Shape"}).create_node()
        new_in.connect(shape_node.in_port(0))
        seq_axis_node = node_to_get_shape_value_of_indices(shape_node, [seq_axis])
        batch_node = node_to_get_shape_value_of_indices(shape_node, [batch_axis])
        broadcast_node = Broadcast(graph, {'name': reverse_name + "/Broadcast"}).create_node()
        broadcast_node.in_port(0).connect(seq_axis_node.out_port(0))
        broadcast_node.in_port(1).connect(batch_node.out_port(0))

        # 3. Create new ReverseSequence node and reconnect all inputs/outputs to it
        rename_node(reverse, reverse_name + '/to_delete')
        reverse_sequence = ReverseSequence(graph, {'name': reverse_name,
                                                   'seq_axis': seq_axis,
                                                   'batch_axis': batch_axis}).create_node()
        reverse_sequence.in_port(0).connect(new_in)
        reverse_sequence.in_port(1).connect(broadcast_node.out_port(0))

        # 4. remove added dimension for rank = 1
        if is_rank_one:
            rename_node(reverse_sequence, reverse_name + '/ReverseSequence')
            squeeze_node = create_op_node_with_second_input(
                graph, Squeeze, int64_array([0]), {'name': reverse_name})
            squeeze_node.in_port(0).connect(reverse_sequence.out_port(0))
            result_port = squeeze_node.out_port(0)
        else:
            result_port = reverse_sequence.out_port(0)
        reverse.out_port(0).get_connection().set_source(result_port)

    # 5. Delete old Reverse node
    graph.remove_nodes_from([old_node.id for old_node in reverse_nodes])
def replace_pattern(self, graph: Graph, match: dict):
    """Replace a matched Unsqueeze/Tile/Reshape pattern with a nearest Interpolate.

    The tile factor at the unsqueezed position is used both as the integer
    multiplier of the axis length (Interpolate input 1) and as the float scale
    (Interpolate input 2); the interpolated axis is the unsqueeze index minus one.
    Nothing is done when ``self.is_applicable(match)`` is false.

    :param graph: graph that is transformed in place
    :param match: pattern match with the keys 'unsqueeze', 'tile' and 'reshape'
    """
    if not self.is_applicable(match):
        return

    unsqueeze_node = match['unsqueeze']
    unsqueeze_name = unsqueeze_node.soft_get('name', unsqueeze_node.id)

    unsqueeze_dims_node = unsqueeze_node.in_port(1).get_connection().get_source().node
    d_idx = int(unsqueeze_dims_node.value)
    axis = d_idx - 1

    shape_node = Shape(graph, {'name': unsqueeze_name + '/Shape'}).create_node()
    axis_len_node = node_to_get_shape_value_of_indices(shape_node, [axis])

    tile_repeats_node = match['tile'].in_port(1).get_connection().get_source().node
    scale = int64_array([tile_repeats_node.value[d_idx]])
    float_scale = float32_array([tile_repeats_node.value[d_idx]])

    # new axis length = old axis length * tile factor
    mul_node = create_op_with_const_inputs(graph, Mul, {1: scale},
                                           {'name': unsqueeze_name + '/Mul'})
    axis_len_node.out_port(0).connect(mul_node.in_port(0))

    interp_attrs = dict(mode='nearest',
                        antialias=0,
                        pads_begin=int64_array([0]),
                        pads_end=int64_array([0]),
                        coordinate_transformation_mode='half_pixel',
                        nearest_mode='round_prefer_floor',
                        cube_coeff=-0.75,
                        version='opset4',
                        shape_calculation_mode='scales',
                        in_ports_count=4,
                        maybe_part_of_sequence=True)
    interp_node = create_op_with_const_inputs(
        graph, Interpolate, {2: float_scale, 3: int64_array([axis])}, interp_attrs)
    mul_node.out_port(0).connect(interp_node.in_port(1))

    # Interpolate takes over the consumers and the name of the matched Reshape
    reshape_node = match['reshape']
    reshape_node.out_port(0).get_connection().set_source(interp_node.out_port(0))
    reshape_name = reshape_node.soft_get('name', reshape_node.id)
    rename_nodes([(reshape_node, reshape_name + '/delete'), (interp_node, reshape_name)])

    # the former Unsqueeze input now feeds both Interpolate and ShapeOf
    data_connection = unsqueeze_node.in_port(0).get_connection()
    data_connection.set_destination(interp_node.in_port(0))
    data_connection.get_source().connect(shape_node.in_port(0))
def replace_interpolate_pattern(graph: Graph, match: dict):
    """Replace a matched Split ... Concat pattern with a nearest Interpolate.

    The scale comes from ``get_split_scale(split)`` and the axis from the constant
    on Split input 1.  Interpolate input 1 is ``int64(floor(float32(shape[axis]) *
    scale))``, input 2 is the scale and input 3 is the axis.

    :param graph: graph that is transformed in place
    :param match: pattern match with the keys 'split' and 'concat'
    """
    split = match['split']
    scale = float32_array([get_split_scale(split)])
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    base_name = split.name

    axis_node = Const(graph, dict(name=base_name + '/axis',
                                  value=int64_array([axis]))).create_node()
    shape_node = Shape(graph, {'name': base_name + '/Shape'}).create_node()
    scales_node = Const(graph, {'name': base_name + '/scales', 'value': scale}).create_node()
    mul_node = Mul(graph, {'name': base_name + '/Mul'}).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    # shape[axis:axis + 1]
    slice_attrs = dict(name=base_name + '/StridedSlice',
                       begin_mask=int64_array([1]),
                       end_mask=int64_array([1]),
                       new_axis_mask=int64_array([0]),
                       shrink_axis_mask=int64_array([0]),
                       ellipsis_mask=int64_array([0]))
    strided_slice_node = create_op_with_const_inputs(
        graph, StridedSlice,
        {1: int64_array([axis]), 2: int64_array([axis + 1])},
        slice_attrs)
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    # the axis length is multiplied by the scale in float32
    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()
    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_attrs = {'name': base_name + '/Interpolate',
                    'mode': 'nearest',
                    'antialias': 0,
                    'pads_begin': int64_array([0]),
                    'pads_end': int64_array([0]),
                    'coordinate_transformation_mode': 'half_pixel',
                    'nearest_mode': 'round_prefer_floor',
                    'cube_coeff': -0.75,
                    'version': 'opset4',
                    'shape_calculation_mode': 'scales',
                    'in_ports_count': 4,
                    'maybe_part_of_sequence': True}
    interp_node = Interpolate(graph, interp_attrs).create_node()

    # ... and floored and converted back to int64 before it reaches Interpolate
    floor_node = Floor(graph, {'name': base_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()
    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    # Interpolate takes over the Concat consumers and the Split data input
    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))
    data_connection = split.in_port(0).get_connection()
    data_connection.set_destination(interp_node.in_port(0))
    data_connection.get_source().connect(shape_node.in_port(0))
def replace(node: Node, const: Node):
    """Replace an arange-like constant on input 0 of ``node`` with a Range sub-graph.

    The replacement happens only when ``const`` has exactly one dimension different
    from 1, its element count lies strictly between 5 and 500, and its value equals
    ``arange(element count)`` reshaped to its shape.  The Range limit is selected at
    run time: the size of that dimension of the constant if the corresponding
    dimension of the tensor on input 1 of ``node`` equals 1, otherwise the dimension
    of that tensor.

    :param node: consumer whose input 0 is fed by ``const``
    :param const: constant node that is a candidate for replacement
    """
    graph = node.graph
    shape = const.shape
    const_name = const.soft_get('name', const.id)

    non_one_dims = np.argwhere(shape != 1).flatten()
    one_dims = np.argwhere(shape == 1).flatten()

    if non_one_dims.size != 1:
        return
    element_count = np.prod(shape)
    if not 5 < element_count < 500:
        # (5;500) range is deduced to affect less models
        return

    value = const.value
    expected = np.arange(0, element_count, 1).reshape(shape)
    if not np.array_equal(expected, value):
        return

    positive_idx = non_one_dims.item(0)
    negative_idx = positive_idx - len(shape)

    node_name = node.soft_get('name', node.id)

    # dimension of the other input and of the constant at the broadcasting axis
    gather = create_op_with_const_inputs(
        graph, Gather, {1: int64_array(negative_idx), 2: int64_array(0)},
        {'name': node_name + '/BroadcastingDim'})
    gather_for_const = create_op_with_const_inputs(
        graph, Gather, {1: int64_array(negative_idx), 2: int64_array(0)},
        {'name': const_name + '/BroadcastingDim'})
    shapeof_node = Shape(graph, {'name': const_name + '/ShapeOf'}).create_node()
    shapeof_node.out_port(0).connect(gather_for_const.in_port(0))

    # other_dim == 1 ? const_dim : other_dim
    equal_node = create_op_with_const_inputs(
        graph, Equal, {1: int64_array(1)}, {'name': node_name + '/ConstOne'})
    gather.out_port(0).connect(equal_node.in_port(0))
    select_node = Select(graph, {'name': node_name + '/Select',
                                 'auto_broadcast': 'numpy'}
                         ).create_node([equal_node, gather_for_const, gather])

    const.out_port(0).connect(shapeof_node.in_port(0))

    range_inputs = {0: mo_array(0, dtype=value.dtype), 2: mo_array(1, dtype=value.dtype)}
    range_node = create_op_with_const_inputs(
        graph, Range, range_inputs,
        {'name': const_name + '/Range', 'dtype': value.dtype})
    select_node.out_port(0).connect(range_node.in_port(1))

    node.in_port(1).get_connection().add_destination(gather.in_port(0))
    node.in_port(0).get_connection().set_source(range_node.out_port(0))

    # the node that inherits the name of the constant
    successor = range_node
    if one_dims.size:
        # restore the dimensions of size 1
        unsqueeze = create_op_node_with_second_input(
            graph, Unsqueeze, one_dims, {'name': const_name + '/KeepShape'})
        range_node.out_port(0).get_connection().insert_node(unsqueeze)
        successor = unsqueeze

    rename_nodes([(const, const_name + '/ToBeDeleted'), (successor, const_name)])
def find_and_replace_pattern(self, graph: Graph):
    """Replace ATen ``embedding_bag`` nodes with EmbeddingBag*Sum nodes.

    * input 2 missing or disconnected -> EmbeddingBagPackedSum
    * otherwise                       -> EmbeddingBagOffsetsSum (input 2 is moved)

    When a fourth input is connected it is moved to port 2 of the packed variant or
    to port 4 of the offsets variant.  In the offsets case the embedding table is
    additionally extended with a row of zeros and the index of that row (the
    original first dimension of the table) is connected to port 3.

    :param graph: graph that is transformed in place
    """
    for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
        assert node.soft_get('mode') == 0, \
            'ATen::embedding_bag has unsupported mode, only "sum" mode is supported for node {}.'.format(node.id)

        node_name = node.soft_get('name', node.id)
        rename_node(node, node_name + '/TBR')

        is_packed = len(node.in_ports()) < 3 or node.in_port(2).disconnected()
        if is_packed:
            embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
        else:
            embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
            node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
        rename_node(embedding_bag, node_name)

        for port_idx in (0, 1):
            node.in_port(port_idx).get_connection().set_destination(
                embedding_bag.in_port(port_idx))
        node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))

        has_fourth_input = len(node.in_ports()) == 4 and not node.in_port(3).disconnected()
        if not has_fourth_input:
            continue

        if is_packed:
            node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
            continue

        # connect per_sample_weights
        node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))

        weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node()
        weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node()
        last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1)
        weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node)
        weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0])

        # zeros of length "last dimension of the table", unsqueezed at axis 0
        zero_col_node = create_op_with_const_inputs(
            graph, Broadcast, {0: int64_array([0])}, {'name': node_name + '/Broadcast'})
        zero_col_node.in_port(1).connect(weights_last_dim.out_port(0))

        default_embeddings_node = create_op_with_const_inputs(
            graph, Unsqueeze, {1: int64_array(0)}, {'name': node_name + '/Unsqueeze'})
        default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0))

        # expand embedding table with zeros
        weights_concat = Concat(graph, {'axis': 0,
                                        'in_ports_count': 2,
                                        'name': node_name + '/Concat'}).create_node()
        embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
        weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0))
        weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0))
        weights_concat.in_port(1).connect(default_embeddings_node.out_port(0))
        weights_concat.out_port(0).connect(embedding_bag.in_port(0))

        # point default index to expanded part of embedding table
        weights_first_dim.out_port(0).connect(embedding_bag.in_port(3))
def resolve_minus2(self, shape_node, input_index, reshape_index, dims):
    """Build the node that yields the shape values from ``input_index`` to the end.

    A second ShapeOf is applied to the output of ``shape_node`` and passed as
    ``rank`` to ``get_shape_values_by_range_idxs``, which is asked for the range
    ``[input_index, -1]`` with both ends included.

    :param shape_node: node that produces the input shape
    :param input_index: first index of the input shape to take
    :param reshape_index: current position in the reshape pattern
    :param dims: passed through unchanged
    :return: ``(None, reshape_index + 1, dims, shape_values_node)``
    """
    rank_attrs = {'name': shape_node.id + '/RankShapeMXReshapeMinus2'}
    rank_node = Shape(shape_node.graph, rank_attrs).create_node()
    rank_node.in_port(0).connect(shape_node.out_port(0))

    shape_values_node = get_shape_values_by_range_idxs(
        shape=shape_node,
        rank=rank_node,
        begin=input_index,
        end=-1,
        include_begin=True,
        include_end=True)

    # the input index is returned as None
    return None, reshape_index + 1, dims, shape_values_node
def replace_sub_graph(self, graph: Graph, match: dict):
    """Decompose a matched node that has a truthy ``reverse`` attribute.

    The sub-graph built here is:
      1. Reshape the data to its own shape in reversed order
         (ShapeOf -> Unsqueeze -> Reverse(axis=1) -> Squeeze feeds the pattern);
      2. apply the flipped ``dim`` pattern, with MXReshape when ``dim`` contains any
         of -2, -3, -4 and with a plain Reshape otherwise;
      3. Reshape the result to its own shape in reversed order again.

    :param graph: graph that is transformed in place
    :param match: pattern match; ``match['op']`` is the node to replace
    """
    mxreshape = match['op']
    if not mxreshape.reverse:
        return

    node_id = mxreshape.id

    # --- forward part: nodes -------------------------------------------------
    shape_node = Shape(graph, {'name': node_id + '/Shape'}).create_node()
    forward_unsqueeze = create_op_node_with_second_input(
        graph, Unsqueeze, int64_array([0]), {'name': str(node_id) + '/ForwardUnsqueeze'})
    forward_reverse = Reverse(graph, {'name': node_id + '/ForwardReverse',
                                      'axis': 1}).create_node()
    forward_squeeze = create_op_node_with_second_input(
        graph, Squeeze, int64_array([0]), {'name': str(node_id) + '/ForwardSqueeze'})
    reshape_node = Reshape(graph, {'name': node_id + '/Reshape'}).create_node()

    # --- forward part: connections -------------------------------------------
    shape_node.in_port(0).connect(mxreshape.in_port(0).get_source())
    mxreshape.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
    forward_unsqueeze.in_port(0).connect(shape_node.out_port(0))
    forward_reverse.in_port(0).connect(forward_unsqueeze.out_port(0))
    forward_squeeze.in_port(0).connect(forward_reverse.out_port(0))
    reshape_node.in_port(1).connect(forward_squeeze.out_port(0))

    # --- reshape with the flipped pattern --------------------------------------
    reshape_shape_node = create_op_node_with_second_input(
        graph, Reshape, int64_array(np.flip(mxreshape.dim, 0)),
        {'name': str(node_id) + '/ReshapeShape'})
    has_special_codes = np.sum(np.in1d([-2, -3, -4], mxreshape.dim), axis=0)
    if has_special_codes:
        # -2 / -3 / -4 are present in `dim`: use MXReshape instead of Reshape
        reshape_shape_node = MXReshape(
            graph, {'name': node_id + '/Reshape',
                    'dim': int64_array(np.flip(mxreshape.dim, 0))}).create_node()
    reshape_shape_node.in_port(0).connect(reshape_node.out_port(0))

    # --- backward part: nodes --------------------------------------------------
    backward_shape = Shape(graph, {'name': node_id + '/BackwardShape'}).create_node()
    backward_unsqueeze = create_op_node_with_second_input(
        graph, Unsqueeze, int64_array([0]), {'name': str(node_id) + '/BackwardUnsqueeze'})
    backward_reverse = Reverse(graph, {'name': node_id + '/BackwardReverse',
                                       'axis': 1}).create_node()
    backward_squeeze = create_op_node_with_second_input(
        graph, Squeeze, int64_array([0]), {'name': str(node_id) + '/BackwardSqueeze'})
    backward_reshape = Reshape(graph, {'name': node_id + '/BackwardReshape'}).create_node()

    # --- backward part: connections ---------------------------------------------
    backward_shape.in_port(0).connect(reshape_shape_node.out_port(0))
    backward_unsqueeze.in_port(0).connect(backward_shape.out_port(0))
    backward_reverse.in_port(0).connect(backward_unsqueeze.out_port(0))
    backward_squeeze.in_port(0).connect(backward_reverse.out_port(0))
    backward_reshape.in_port(0).connect(reshape_shape_node.out_port(0))
    backward_reshape.in_port(1).connect(backward_squeeze.out_port(0))

    mxreshape.out_port(0).get_connection().set_source(backward_reshape.out_port(0))
def append_variances(priors_scale_node: Node, variance: list):
    """Append broadcast variances to the priors and reshape the result to 3D.

    The priors are reshaped to ``[-1, 4]``; the variances are broadcast to
    ``[shape[-2], 4]`` (``shape`` being the run-time shape of the priors) and
    concatenated after them along axis 0; the result is reshaped to ``[1, 2, -1]``.

    :param priors_scale_node: node that produces the priors
    :param variance: variance values stored in a float32 constant
    :return: the final Reshape node
    """
    graph = priors_scale_node.graph
    name = priors_scale_node.name

    sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
    priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

    # shape[-2:-1:1]
    begin = Const(graph, {'value': int64_array([-2])}).create_node()
    end = Const(graph, {'value': int64_array([-1])}).create_node()
    stride = Const(graph, {'value': int64_array([1])}).create_node()
    slice_attrs = {'name': name + '/get_-2_dim',
                   'begin_mask': int64_array([1]),
                   'end_mask': int64_array([1]),
                   'new_axis_mask': int64_array([0]),
                   'shrink_axis_mask': int64_array([0]),
                   'ellipsis_mask': int64_array([0])}
    shape_part_for_tiling = StridedSlice(graph, slice_attrs).create_node()
    for port_idx, producer in enumerate((sp_shape, begin, end, stride)):
        producer.out_port(0).connect(shape_part_for_tiling.in_port(port_idx))

    # target shape of the broadcast: [shape[-2], 4]
    shape_concat = create_op_node_with_second_input(
        graph, Concat, int64_array([4]),
        {'name': name + '/shape_for_tiling', 'in_ports_count': 2, 'axis': int64_array(0)},
        shape_part_for_tiling)

    variance_const = Const(graph, {'name': name + '/variance',
                                   'value': float32_array(variance)}).create_node()
    tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
    variance_const.out_port(0).connect(tile.in_port(0))
    shape_concat.out_port(0).connect(tile.in_port(1))

    # priors -> [-1, 4]
    reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
    sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
    sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
    sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

    # priors followed by variances
    concat = Concat(graph, {'name': name + '/priors_concat',
                            'axis': int64_array(0),
                            'in_ports_count': 2}).create_node()
    sp_reshape.out_port(0).connect(concat.in_port(0))
    tile.out_port(0).connect(concat.in_port(1))

    output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
    output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
    concat.out_port(0).connect(output_node.in_port(0))
    output_dims.out_port(0).connect(output_node.in_port(1))

    return output_node
def make_interpolate_reshapeable(interpolate):
    """Compute Interpolate input port 1 from the run-time input shape.

    When the output shape is an element-wise multiple of the input shape, or the
    other way round, input port 1 is re-sourced to
    ``ShapeOf(input)[axes] * (output_shape[axes] / input_shape[axes])``.
    Otherwise the graph is left untouched.

    :param interpolate: node whose ``type`` is 'Interpolate'
    """
    assert interpolate.soft_get('type') == 'Interpolate'

    axes = Interpolate.get_axes(interpolate)
    input_shape = interpolate.in_port(0).data.get_shape()
    output_shape = interpolate.out_port(0).data.get_shape()

    # neither "output divisible by input" nor "input divisible by output"
    if not (np.all(np.remainder(output_shape, input_shape) == 0)
            or np.all(np.remainder(input_shape, output_shape) == 0)):
        return

    graph = interpolate.graph
    name = interpolate.soft_get('name', interpolate.id)

    shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
    shape.in_port(0).connect(interpolate.in_port(0).get_source())

    gather_inputs = {1: np.array(axes, dtype=np.int32), 2: int64_array(0)}
    gather = create_op_with_const_inputs(
        graph, Gather, gather_inputs, {'name': shape.name + '/Gathered'}, shape)

    # per-axis ratio between the static output and input shapes
    multipliers = output_shape[axes] / input_shape[axes]
    mul = create_op_node_with_second_input(
        graph, Mul, multipliers, {'name': gather.name + '/Multiplied'}, gather)

    interpolate.in_port(1).get_connection().set_source(mul.out_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched node (inputs: data, begin, size) with Slice(data, begin, ends).

    ``ends`` is ``begin + size``, except for the positions where ``size`` equals -1:
    there the corresponding value of ShapeOf(data) is selected instead.

    :param graph: graph that is transformed in place
    :param match: pattern match; ``match['op']`` is the node to replace
    """
    tf_slice = match['op']
    slice_name = tf_slice.soft_get('name', tf_slice.id)

    new_slice = Slice(graph).create_node()
    rename_nodes([(tf_slice, slice_name + '/to_be_removed'), (new_slice, slice_name)])

    ends = Add(graph, {'name': slice_name + '/ends'}).create_node()

    # reconnect input, begin, and size from the matched node to the new sub-graph
    tf_slice.in_port(0).get_connection().set_destination(new_slice.in_port(0))
    tf_slice.in_port(1).get_connection().set_destination(new_slice.in_port(1))
    tf_slice.in_port(2).get_connection().set_destination(ends.in_port(0))
    # ends = size + begin
    new_slice.in_port(1).get_connection().add_destination(ends.in_port(1))

    shape_of_data = Shape(graph, {'name': slice_name + '/ShapeOf'}).create_node()
    new_slice.in_port(0).get_connection().add_destination(shape_of_data.in_port(0))

    # check if size[i] == -1, will be applied elementwisely:
    # len(size) = len(begin) = input_rank
    size_is_minus_one = create_op_with_const_inputs(
        graph, Equal, {0: int64_array(-1)},
        {'name': slice_name + '/where_max_ends_is_needed'})
    ends.in_port(0).get_connection().add_destination(size_is_minus_one.in_port(1))

    # select requires equal dtypes, need to convert ends to I64
    ends_i64 = Cast(graph, {'name': slice_name + '/CastToI64',
                            'dst_type': np.int64}).create_node([ends])

    # if size[i] == -1 then take the value from the data shape
    chosen_ends = Select(graph, {'name': slice_name + '/chosen_ends'}).create_node(
        [size_is_minus_one, shape_of_data, ends_i64])
    chosen_ends.out_port(0).connect(new_slice.in_port(2))

    tf_slice.out_port(0).get_connection().set_source(new_slice.out_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched 'flatten' node (``axis`` .. ``end_axis``) with a Reshape.

    The Reshape pattern is the concatenation ``[begin_dims, middle_dim, end_dims]``:

    * ``end_axis == -1`` and ``axis >= 0``: constants ``[0] * axis``, ``[-1]`` and
      ``[]`` are used, so no shape sub-graph is needed;
    * otherwise the three parts are cut out of ShapeOf(input) and the middle part is
      collapsed with ReduceProd.

    The new Reshape takes over the name of the flatten node.  The name computed at
    the top of the function (with the ``node.id`` fallback) is used for the
    renaming too, so a node without a ``name`` attribute is handled the same way in
    the assertion messages, the helper node names and the final renaming.

    :param graph: graph that is transformed in place
    :param match: pattern match; ``match['flatten']`` is the node to replace
    """
    node = match['flatten']
    name = node.soft_get('name', node.id)

    assert node.has_valid('axis'), 'Flatten {} has no mandatory `axis` attribute'.format(name)
    assert node.has_valid('end_axis'), 'Flatten {} has no mandatory `end_axis` attribute'.format(name)

    axis = node.axis
    end_axis = node.end_axis

    if end_axis == -1 and axis >= 0:
        # static pattern: axis zeros, then -1, then nothing
        begin_dims = Const(graph, {'value': int64_array([0] * axis)}).create_node()
        middle_dim = Const(graph, {'value': int64_array([-1])}).create_node()
        end_dims = Const(graph, {'value': int64_array([])}).create_node()
    else:
        rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
        node.in_port(0).get_source().connect(rank.in_port(0))

        shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
        node.in_port(0).get_source().connect(shape.in_port(0))

        begin_dims = get_shape_values_by_range_idxs(
            shape=shape, rank=rank, begin=0, end=axis)
        middle_dims = get_shape_values_by_range_idxs(
            shape=shape, rank=rank, begin=axis, end=end_axis, include_end=True)
        end_dims = get_shape_values_by_range_idxs(
            shape=shape, rank=rank, begin=end_axis, end=-1, include_begin=False, include_end=True)

        # collapse the flattened dimensions into a single value
        middle_dim = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), {'keep_dims': True})
        middle_dims.out_port(0).connect(middle_dim.in_port(0))

    dim = new_shape_node_from_shape_nodes([begin_dims, middle_dim, end_dims])

    reshape_node = Reshape(graph, {}).create_node()
    # Keep node with the same name to avoid confuse with renaming
    rename_nodes([(node, name + '/ShouldBeDeleted'), (reshape_node, name)])

    reshape_node.in_port(1).connect(dim.out_port(0))
    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
def replace_op(self, graph: Graph, node: Node):
    """
    Build the chain FlattenONNX -> LogSoftmax(axis=1) -> Reshape for the given node.

    The Reshape restores the shape of the original input, which is read by a ShapeOf node, and takes over
    the name of the replaced node.

    :param graph: graph that is being transformed
    :param node: node to replace, must have a valid `axis` attribute
    :return: list with the id of the final Reshape node
    :raises AssertionError: if `axis` is missing
    """
    base_name = node.soft_get('name', node.id)
    assert node.has_valid(
        'axis'
    ), 'The node "{}" does not have mandatory attribute "axis"'.format(
        base_name)

    flatten = FlattenONNX(
        graph, dict(name=base_name + '/FlattenONNX_', axis=node.axis)).create_node()
    input_shape = Shape(graph, dict(name=base_name + '/ShapeOf_')).create_node()
    log_softmax = LogSoftmax(
        graph, dict(name=base_name + '/LogSoftmax_', axis=1)).create_node()
    restore_shape = Reshape(graph, {}).create_node()

    rename_nodes([
        (node, base_name + '/delete'),
        (restore_shape, base_name),
    ])

    # wire the chain from its end towards its beginning
    input_shape.out_port(0).connect(restore_shape.in_port(1))
    log_softmax.out_port(0).connect(restore_shape.in_port(0))
    flatten.out_port(0).connect(log_softmax.in_port(0))

    # both the flatten and the shape reader consume the original data source
    data_source = node.in_port(0).get_source()
    flatten.in_port(0).connect(data_source)
    input_shape.in_port(0).connect(data_source)

    return [restore_shape.id]
def replace_pattern(self, graph: Graph, match: dict):
    """
    Replace the constant 2-element pattern of a Reshape that feeds a MatMul with a pattern computed at
    runtime: one element is read from the shape of the other MatMul input, the other element is -1.

    Nothing is changed when the current reshape pattern does not have exactly two elements.

    :param graph: graph that is being transformed
    :param match: matched sub-graph with the 'matmul', 'reshape' and 'other_input' nodes
    """
    matmul = match['matmul']
    reshape = match['reshape']

    # the MatMul input that does not come from the Reshape provides the shape to read from
    other_input_port_idx = 0 if matmul.in_port(0).get_source().node.id == match['other_input'].id else 1
    shape_source = matmul.in_port(other_input_port_idx).get_source()

    initial_reshape_pattern = reshape.in_port(1).data.get_value()
    if len(initial_reshape_pattern) != 2:
        return

    reshape_is_A_input = matmul.in_port(0).get_source().node.id == reshape.id
    # index (in the other input's shape) of the dimension selected according to the transpose flags
    if reshape_is_A_input:
        idx = -1 if matmul.transpose_b else -2
    else:
        idx = -2 if matmul.transpose_a else -1
    idx = get_canonical_axis_index(initial_reshape_pattern, idx)

    shape_name = shape_source.node.soft_get('name', shape_source.node.id)
    shape = Shape(graph, {'name': shape_name + '/Shape'}).create_node()
    shape.in_port(0).connect(shape_source)
    C = node_to_get_shape_value_of_indices(shape, [idx])
    N = Const(graph, {
        'name': shape_name + '/MinusOne',
        'value': int64_array([-1])
    }).create_node()

    # The pattern length was already checked above, so no second length test is needed here.
    if reshape_is_A_input:
        reshape_pattern = [C, N] if matmul.transpose_a else [N, C]
    else:
        reshape_pattern = [N, C] if matmul.transpose_b else [C, N]

    new_reshape_pattern = new_shape_node_from_shape_nodes(reshape_pattern)
    reshape.in_port(1).get_connection().set_source(new_reshape_pattern.out_port(0))
def get_shape_and_rank_nodes_by_port(port: Port, return_as_a_scalar: bool = True):
    """
    Create the nodes that produce the shape and the rank of the data coming out of the given port.

    The rank is obtained as the ShapeOf of the shape, i.e. as a 1D tensor; on request it is additionally
    squeezed along axis 0.

    :param port: Port object that specifies node output port
    :param return_as_a_scalar: True to return the squeezed rank node, False to return the 1D rank node
    :return: tuple (shape node, rank node)
    """
    producer = port.node
    base_name = producer.soft_get('name', producer.id)
    graph = producer.graph

    shape_node = Shape(graph, {'name': base_name + '/ShapeOf'}).create_node()
    rank_1d_node = Shape(graph, {'name': base_name + '/1dRankOf'}).create_node()
    rank_1d_node.in_port(0).connect(shape_node.out_port(0))
    shape_node.in_port(0).connect(port)

    if return_as_a_scalar:
        rank_0d_node = create_op_node_with_second_input(
            graph, Squeeze, int64_array([0]), {'name': base_name + '/0dRankOf'}, rank_1d_node)
        return shape_node, rank_0d_node

    return shape_node, rank_1d_node
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Decompose the matched node into ShapeOf -> ShapeOf -> Squeeze(axis 0).

    Both ShapeOf nodes inherit the `output_type` of the replaced node, and the final Squeeze takes over
    its name.

    :param graph: graph that is being transformed
    :param match: matched sub-graph, match['op'] is the node to decompose
    :raises AssertionError: if the node has no valid `output_type` attribute
    """
    rank_node = match['op']
    base_name = rank_node.soft_get('name', rank_node.id)

    assert rank_node.has_valid('output_type'), \
        'Rank node should have `output_type` attribute, but it`s not for node {}'.format(base_name)

    out_type = rank_node.output_type
    shape_reader = Shape(
        graph, dict(name=base_name + '/shape_of', output_type=out_type)).create_node()
    rank_as_1d = Shape(
        graph, dict(name=base_name + '/rank_of', output_type=out_type)).create_node()
    rank_as_0d = create_op_node_with_second_input(
        graph, Squeeze, int64_array(0), dict(name=base_name + '/0d_rank_of'), rank_as_1d)

    shape_reader.out_port(0).connect(rank_as_1d.in_port(0))

    # re-route the consumers first, then the producer
    rank_node.out_port(0).get_connection().set_source(rank_as_0d.out_port(0))
    rank_node.in_port(0).get_connection().set_destination(shape_reader.in_port(0))

    rename_nodes([
        (rank_node, base_name + '/ToBeDeleted'),
        (rank_as_0d, base_name),
    ])
def replace_pattern(graph: Graph, match: Dict[str, Node]):
    """
    Replace the matched two-input node with a StridedSlice that starts at zero on every axis and whose
    `end` input is derived from the shape of the second input.

    The end mask is set to 1 only on the axes listed in `node.axes`. When both inputs have the same rank,
    `end` is the shape of the second input. Otherwise it is the second input's shape up to and including
    the last listed axis, followed (if any dimensions remain) by the rest of the first input's shape.

    :param graph: graph that is being transformed
    :param match: matched sub-graph, match['op'] is the node to replace
    """
    node = match['op']
    name = node.soft_get('name', node.id)
    input_shape = node.in_port(0).data.get_shape()
    second_input_shape = node.in_port(1).data.get_shape()
    rank = len(input_shape)

    def zero_vector():
        # a fresh array per call so that no two attributes share the same buffer
        return np.zeros(rank, dtype=np.int64)

    begin_mask = zero_vector()
    end_mask = zero_vector()
    for axis in node.axes:
        end_mask[axis] = np.int64(1)

    slice_attrs = {
        'name': 'StridedSlice',
        'begin_mask': begin_mask,
        'end_mask': end_mask,
        'new_axis_mask': zero_vector(),
        'shrink_axis_mask': zero_vector(),
        'ellipsis_mask': zero_vector(),
    }
    ss = create_op_with_const_inputs(graph, StridedSlice,
                                     port_value_dict={1: zero_vector()},
                                     op_attrs=slice_attrs)

    if input_shape.size == second_input_shape.size:
        # equal ranks: the whole shape of the second input is the end
        end = Shape(graph, {'name': name + '/End'}).create_node()
        end.in_port(0).connect(node.in_port(1).get_source())
        ss.in_port(2).connect(end.out_port(0))
    else:
        shape_like, rank_like = get_shape_and_rank_nodes_by_port(node.in_port(1).get_source())
        last_axis = node.axes[-1]
        end_head = get_shape_values_by_range_idxs(shape_like, rank_like, 0, last_axis, include_end=True)
        if input_shape.size - 1 == last_axis:
            # the listed axes reach the last dimension, nothing has to be appended
            ss.in_port(2).connect(end_head.out_port(0))
        else:
            shape, rank_of_input = get_shape_and_rank_nodes_by_port(node.in_port(0).get_source())
            end_tail = get_shape_values_by_range_idxs(shape, rank_of_input, last_axis, -1,
                                                      include_begin=False, include_end=True)
            end = new_shape_node_from_shape_nodes([end_head, end_tail])
            ss.in_port(2).connect(end.out_port(0))

    node.in_port(0).get_connection().set_destination(ss.in_port(0))
    node.in_port(1).disconnect()
    node.out_port(0).get_connection().set_source(ss.out_port(0))

    rename_nodes([(node, name + '/ShouldBeDeleted'), (ss, name)])
def find_and_replace_pattern(self, graph: Graph):
    """
    Normalize the inputs of every SpaceToBatch and BatchToSpace node in the graph.

    For each node the input on port 2 is transposed, split in two and squeezed so that ports 2 and 3
    each receive one half. Afterwards the inputs on ports 1, 2 and 3 are each routed through a Cast to
    int64 followed by a Pad node in 'constant' mode.

    :param graph: graph that is being transformed
    """
    for node in graph.get_op_nodes(op='SpaceToBatch') + graph.get_op_nodes(
            op='BatchToSpace'):
        # port 3 receives the second half of the split pads/crops input
        node.add_input_port(3, skip_if_exist=True)

        # convert TF representation of the pads/crops as [N, 2] to IE representation: [N] and [N]
        transposed_pads = create_op_with_const_inputs(
            graph, Transpose, {1: int64_array([1, 0])})
        node.in_port(2).get_connection().set_destination(
            transposed_pads.in_port(0))
        split_pads = create_op_with_const_inputs(graph, Split,
                                                 {1: int64_array(0)},
                                                 {'num_splits': 2})
        transposed_pads.out_port(0).connect(split_pads.in_port(0))
        for port_ind in range(2):
            # output 0 of the split goes to port 2, output 1 to port 3
            node.in_port(port_ind + 2).connect(
                split_pads.out_port(port_ind))
            # a Squeeze over axis 0 is put between the split output and the node input
            node.in_port(port_ind + 2).get_connection().insert_node(
                create_op_with_const_inputs(graph, Squeeze,
                                            {1: int64_array([0])}))

        # add zeros/ones to related inputs to align it with data input
        in0_rank = Rank(graph, {
            'name': node.name + '/rank_0'
        }).create_node()
        in1_shape = Shape(graph, {
            'name': node.name + '/rank_1'
        }).create_node()

        diff_size = Sub(graph, {
            'name': node.name + '/sub_0'
        }).create_node()
        diff = Sub(graph, {'name': node.name + '/sub_1'}).create_node()
        const_begin = Const(graph, {
            'value': int64_array([1])
        }).create_node()
        const_pad_val = Const(graph, {
            'value': int64_array(1)
        }).create_node()

        block_shape = Pad(graph, {
            'name': node.name + '/aligned_block_shape',
            'mode': 'constant'
        }).create_node()

        # in case of SpaceToBatch begin = pads_begin, end = pads_end
        # in case of BatchToSpace begin = crops_begin, end = crops_end
        new_begin_name = '/aligned_pads_begin'
        new_end_name = '/aligned_pads_end'
        if node.type == 'BatchToSpace':
            new_begin_name = '/aligned_crops_begin'
            new_end_name = '/aligned_crops_end'

        begin = Pad(graph, {
            'name': node.name + new_begin_name,
            'mode': 'constant'
        }).create_node()
        end = Pad(graph, {
            'name': node.name + new_end_name,
            'mode': 'constant'
        }).create_node()

        in0_rank_1d = create_op_node_with_second_input(
            graph, Unsqueeze, int64_array([0]), {'name': node.name + '/1d_rank_of_0'}, in0_rank)

        # diff_size = rank(input 0) as a 1D tensor - shape(input 1); diff = diff_size - [1]
        node.in_port(0).get_source().connect(in0_rank.in_port(0))
        node.in_port(1).get_source().connect(in1_shape.in_port(0))
        in0_rank_1d.out_port(0).connect(diff_size.in_port(0))
        in1_shape.out_port(0).connect(diff_size.in_port(1))
        diff_size.out_port(0).connect(diff.in_port(0))
        const_begin.out_port(0).connect(diff.in_port(1))
        # only the block_shape Pad gets the constant 1 on port 3
        # (presumably the Pad fill value, leaving the default for the other two -- TODO confirm)
        const_pad_val.out_port(0).connect(block_shape.in_port(3))

        inputs_array = [block_shape, begin, end]
        for idx, input_to_node in enumerate(inputs_array):
            name_of_input_to_node = input_to_node.name
            # the current producer of node input idx + 1 becomes the data input of the Pad
            node.in_port(idx + 1).get_connection().set_destination(
                input_to_node.in_port(0))
            # every Pad gets the constant [1] on port 1 and `diff` on port 2
            # (presumably pads_begin and pads_end of the Pad op -- TODO confirm)
            const_begin.out_port(0).connect(input_to_node.in_port(1))
            diff.out_port(0).connect(input_to_node.in_port(2))
            input_to_node.out_port(0).connect(node.in_port(idx + 1))
            # the data entering the Pad is cast to int64 first
            convert = Cast(graph, {
                'name': name_of_input_to_node + '/i64',
                'dst_type': np.int64
            }).create_node()
            input_to_node.in_port(0).get_connection().insert_node(convert)
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Make sure the matched node receives its second input (port 1) from a sub-graph built here.

    When port 1 is absent or disconnected, the value for it is computed from a slice of the shape of
    input 0 (StridedSlice with begin [2], end [4], stride [1]) and from the node attributes `factor`,
    `pads_begin`/`pads_end`, `shrink_factor`, `zoom_factor`, `width` and `height`. When port 1 is already
    connected and the node attribute `fw` is 'caffe', the producer on port 1 is replaced by the same kind
    of slice taken from the shape of that producer.

    :param graph: graph that is being transformed
    :param match: matched sub-graph, match['op'] is the node to normalize
    :return: None; in the shrink/zoom branches the function returns early after logging an error when a
             factor is less than 1
    """
    node = match['op']

    if 1 not in node.in_ports() or node.in_port(1).disconnected():

        if node.has_valid('factor') and not node.has_valid('width') and not node.has_valid('height'):
            # port 1 = shape(input 0)[2:4] * factor
            # (presumably the spatial dimensions of an NCHW tensor -- TODO confirm)
            factor = Const(graph, {'value': np.array(node.factor)}).create_node()

            shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

            begin = Const(graph, {'value': np.array([2])}).create_node()
            end = Const(graph, {'value': np.array([4])}).create_node()
            stride = Const(graph, {'value': np.array([1])}).create_node()
            ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                      'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                      'shrink_axis_mask': np.array([0]),
                                      'ellipsis_mask': np.array([0])}).create_node()

            mul = Mul(graph, {'name': node.name + '/factor_mul_'}).create_node()

            source = node.in_port(0).get_connection().get_source()
            source.connect(shape.in_port(0))
            shape.out_port(0).connect(ss.in_port(0))
            begin.out_port(0).connect(ss.in_port(1))
            end.out_port(0).connect(ss.in_port(2))
            stride.out_port(0).connect(ss.in_port(3))
            ss.out_port(0).connect(mul.in_port(0))
            factor.out_port(0).connect(mul.in_port(1))

            node.add_input_port(1, skip_if_exist=True)
            assert node.in_port(1).disconnected()
            mul.out_port(0).connect(node.in_port(1))

        else:
            # same shape slice as above; it is combined with the pads and the shrink/zoom factors below
            shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

            begin = Const(graph, {'value': np.array([2])}).create_node()
            end = Const(graph, {'value': np.array([4])}).create_node()
            stride = Const(graph, {'value': np.array([1])}).create_node()
            ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                      'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                      'shrink_axis_mask': np.array([0]),
                                      'ellipsis_mask': np.array([0])}).create_node()

            source = node.in_port(0).get_connection().get_source()
            source.connect(shape.in_port(0))
            shape.out_port(0).connect(ss.in_port(0))
            begin.out_port(0).connect(ss.in_port(1))
            end.out_port(0).connect(ss.in_port(2))
            stride.out_port(0).connect(ss.in_port(3))

            # add = sliced shape + (pads_begin + pads_end)
            pads_value = node.pads_begin + node.pads_end
            pads_const = Const(graph, {'value': np.array(pads_value)}).create_node()
            add = Add(graph, {'name': node.name + '/pad_add'}).create_node()
            ss.out_port(0).connect(add.in_port(0))
            add.in_port(1).connect(pads_const.out_port(0))

            if node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') == 1:
                # shrink only: port 1 = (add - 1) * (1 / shrink_factor) + 1
                shrink_factor = node.shrink_factor
                if shrink_factor < 1:
                    # nodes created so far are left in the graph; port 1 stays unconnected
                    log.error('Shrink factor should be positive in node {}'.format(node.id))
                    return None

                # the subtraction is expressed as an Add with the constant -1
                const = Const(graph, {'name': node.name + '/pre_shrink_sub_const',
                                      'value': np.array(-1)}).create_node()
                sub = Add(graph, {'name': node.name + '/pre_shrink_sub'}).create_node()
                add.out_port(0).connect(sub.in_port(0))
                sub.in_port(1).connect(const.out_port(0))

                # the division is expressed as a Mul with the reciprocal
                const = Const(graph, {'value': np.array(1 / shrink_factor),
                                      'name': node.name + 'shrink_factor_div_const'}).create_node()
                div = Mul(graph, {'name': node.name + 'shrink_factor_div'}).create_node()
                sub.out_port(0).connect(div.in_port(0))
                div.in_port(1).connect(const.out_port(0))

                const = Const(graph, {'name': node.name + '/shrink_factor_add_one_const', 'value': np.array(1)
                                      }).create_node()
                add = Add(graph, {'name': node.name + '/shrink_factor_add_one'}).create_node()
                div.out_port(0).connect(add.in_port(0))
                const.out_port(0).connect(add.in_port(1))

                node.add_input_port(1, skip_if_exist=True)
                assert node.in_port(1).disconnected()
                add.out_port(0).connect(node.in_port(1))

            elif node.soft_get('shrink_factor') == 1 and node.soft_get('zoom_factor') != 1:
                # zoom only: port 1 = add * zoom_factor
                zoom_factor = node.zoom_factor
                if zoom_factor < 1:
                    log.error('Zoom factor should be positive in node {}'.format(node.id))
                    return None

                # the message records the source line of this statement so a user can locate the code
                node['debug_message'] = 'Interpolate layer replacer may be wrong, please, try to update it in the' \
                                        ' file (openvino/tools/mo/front/InterpolateNormalizer.py at the line {}).' \
                                        ''.format(inspect.currentframe().f_lineno) + refer_to_faq_msg(100)

                # Reshape methods can be different in some cases
                # Commented out section represents reshape that used in deeplab-caffe
                # Uncomment the following lines, if your model was trained with deeplab-caffe
                # or have the same reshape method
                # const = Const(graph, {'value': np.array(-1),
                #                       'name': node.name + 'zoom_factor_deeplab-caffe_sub_const'}).create_node()
                # sub = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sub'}).create_node()
                # add.out_port(0).connect(sub.in_port(0))
                # const.out_port(0).connect(sub.in_port(1))
                #
                # const = Const(graph, {'value': np.array(zoom_factor - 1),
                #                       'name': node.name + 'zoom_factor_deeplab-caffe_mul_const'}).create_node()
                # mul = Mul(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_mul'}).create_node()
                # sub.out_port(0).connect(mul.in_port(0))
                # const.out_port(0).connect(mul.in_port(1))
                #
                # sum = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sum'}).create_node()
                # add.out_port(0).connect(sum.in_port(0))
                # mul.out_port(0).connect(sum.in_port(1))
                #
                # node.add_input_port(1, skip_if_exist=True)
                # assert node.in_port(1).disconnected()
                # sum.out_port(0).connect(node.in_port(1))

                # Comment out the following lines if you use the reshape method from previous section
                const = Const(graph, {'value': np.array(zoom_factor),
                                      'name': node.name + '/zoom_factor_mul_const'}).create_node()
                mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()

                add.out_port(0).connect(mul.in_port(0))
                const.out_port(0).connect(mul.in_port(1))

                node.add_input_port(1, skip_if_exist=True)
                assert node.in_port(1).disconnected()
                mul.out_port(0).connect(node.in_port(1))

            elif node.soft_get('width') != 0 and node.soft_get('height') != 0:
                # explicit target size: port 1 = [height, width]; this branch is tested before the
                # "both factors" branch below, so it wins when all four attributes are set
                const = Const(graph, {'value': np.array([node.height, node.width])}).create_node()
                node.add_input_port(1, skip_if_exist=True)
                assert node.in_port(1).disconnected()
                const.out_port(0).connect(node.in_port(1))

            elif node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') != 1:
                # both factors: div = (add - 1) * (1 / (shrink_factor + 1));
                #               port 1 = div + (div - 1) * (zoom_factor - 1)
                shrink_factor = node.shrink_factor
                zoom_factor = node.zoom_factor
                if shrink_factor < 1:
                    log.error('Shrink factor should be positive in node {}'.format(node.id))
                    return None
                if zoom_factor < 1:
                    log.error('Zoom factor should be positive in node {}'.format(node.id))
                    return None

                const = Const(graph, {'value': np.array(-1)}).create_node()
                sub = Add(graph, {'name': node.name + '/shrink_zoom_factor_sub'}).create_node()
                add.out_port(0).connect(sub.in_port(0))
                const.out_port(0).connect(sub.in_port(1))

                # NOTE(review): the shrink-only branch multiplies by 1 / shrink_factor and then adds 1,
                # while this branch multiplies by 1 / (shrink_factor + 1) and adds nothing -- confirm
                # that the difference is intended
                const = Const(graph, {'value': np.array(1 / (shrink_factor + 1))}).create_node()
                div = Mul(graph, {'name': node.name + '/shrink_factor_div'}).create_node()
                sub.out_port(0).connect(div.in_port(0))
                const.out_port(0).connect(div.in_port(1))

                const = Const(graph, {'value': np.array(-1),
                                      'name': node.name + 'shrink_zoom_factor_sum_const'}).create_node()
                sum = Add(graph, {'name': node.name + '/shrink_zoom_factor_sum'}).create_node()
                div.out_port(0).connect(sum.in_port(0))
                const.out_port(0).connect(sum.in_port(1))

                const = Const(graph, {'value': np.array(zoom_factor - 1)}).create_node()
                mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()
                sum.out_port(0).connect(mul.in_port(0))
                const.out_port(0).connect(mul.in_port(1))

                # `sum` is rebound here: the final Add combines `div` with the scaled (div - 1)
                sum = Add(graph, {'name': node.name + '/final_shrink_zoom_factor_sum'}).create_node()
                div.out_port(0).connect(sum.in_port(0))
                mul.out_port(0).connect(sum.in_port(1))

                node.add_input_port(1, skip_if_exist=True)
                assert node.in_port(1).disconnected()
                sum.out_port(0).connect(node.in_port(1))
    else:
        # port 1 is already connected: only nodes with fw == 'caffe' are rewritten
        if node.soft_get('fw') == 'caffe':
            shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

            begin = Const(graph, {'value': np.array([2])}).create_node()
            end = Const(graph, {'value': np.array([4])}).create_node()
            stride = Const(graph, {'value': np.array([1])}).create_node()
            ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                      'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                      'shrink_axis_mask': np.array([0]),
                                      'ellipsis_mask': np.array([0])}).create_node()

            # the producer currently on port 1 is detached and only its sliced shape is fed back in
            source = node.in_port(1).get_connection().get_source()
            node.in_port(1).disconnect()
            source.connect(shape.in_port(0))
            shape.out_port(0).connect(ss.in_port(0))
            begin.out_port(0).connect(ss.in_port(1))
            end.out_port(0).connect(ss.in_port(2))
            stride.out_port(0).connect(ss.in_port(3))
            ss.out_port(0).connect(node.in_port(1))
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
    """
    Build the DetectionOutput based post-processing sub-graph that replaces the matched sub-graph.

    Inputs of the matched sub-graph are used as follows: input 0 is multiplied with the prior box
    widths/heights, input 1 is reshaped to 2D and used as class predictions, input 2 provides the prior
    boxes.

    :param graph: graph that is being transformed
    :param match: matched sub-graph together with the replacement description
    :return: dictionary with the created DetectionOutput node under the 'detection_output_node' key
    :raises AssertionError: if the graph does not contain exactly one Parameter node
    :raises Error: if `variance` cannot be read from the replacement `custom_attributes`, or if the NMS
                   threshold cannot be retrieved from the graph
    """
    reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                            dict(name='do_reshape_classes'),
                                                            match.single_input_node(1)[0])

    initial_priors_node = match.single_input_node(2)[0]
    priors_name = initial_priors_node.soft_get('name', initial_priors_node.id)
    # model calculates identical prior boxes for each batch, so we take first slice of them
    begin = Const(graph, {'value': mo_array([0, 0, 0], dtype=np.int32)}).create_node()
    end = Const(graph, {'value': mo_array([1, 0, 0], dtype=np.int32)}).create_node()
    stride = Const(graph, {'value': mo_array([1, 1, 1], dtype=np.int32)}).create_node()
    priors_node = StridedSlice(graph, {'name': priors_name + '/0_batch_slice',
                                       'begin_mask': int64_array([1, 1, 1]),
                                       'end_mask': int64_array([1, 0, 0]),
                                       'new_axis_mask': int64_array([0]),
                                       'shrink_axis_mask': int64_array([0]),
                                       'ellipsis_mask': int64_array([0])}).create_node()

    initial_priors_node.out_port(0).connect(priors_node.in_port(0))
    begin.out_port(0).connect(priors_node.in_port(1))
    end.out_port(0).connect(priors_node.in_port(2))
    stride.out_port(0).connect(priors_node.in_port(3))

    placeholders = graph.get_op_nodes(type='Parameter')
    assert len(placeholders) == 1, "{} replacer requires model to have one Placeholder, but current model has " \
                                   "{} placeholders".format(self.replacement_id, len(placeholders))
    placeholder = placeholders[0]

    # scale prior boxes to the [0, 1] interval
    node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
    priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

    # the scales are broadcast to the shape of the sliced priors before the multiplication
    broadcast = Broadcast(graph, {'name': 'scales_broadcast'}).create_node()
    shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
    priors_node.out_port(0).connect(shape_of_priors.in_port(0))
    broadcast.in_port(1).connect(shape_of_priors.out_port(0))
    broadcast.in_port(0).connect(node_with_scales_for_prior_boxes.out_port(0))

    priors_scale_node.in_port(0).connect(priors_node.out_port(0))
    priors_scale_node.in_port(1).connect(broadcast.out_port(0))

    # Any ordinary failure of the lookup is still reported as Error, but KeyboardInterrupt/SystemExit
    # are no longer intercepted and the original exception is kept as the cause.
    try:
        variance = match.custom_replacement_desc.custom_attributes['variance']
    except Exception as err:
        raise Error('There is no variance attribute in {} replacement config file `custom_attributes`'
                    ''.format(self.replacement_id)) from err

    priors = self.append_variances(priors_scale_node, variance)

    # calculate prior boxes widths and heights
    split_node = create_op_with_const_inputs(
        graph, VariadicSplit, {1: int64_array(2), 2: int64_array([1, 1, 1, 1])}, {'out_ports_count': 4},
        priors_scale_node)

    priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')
                            ).create_node([(split_node, 2), (split_node, 0)])
    priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')
                             ).create_node([(split_node, 3), (split_node, 1)])

    # concat widths and heights into a single tensor and multiply it with the box coordinates
    # regression values
    # WA with 3 Concats instead of 1 for keeping model reshapable
    concat_1 = Concat(graph, {'name': 'concat_width_height', 'axis': -1,
                              'in_ports_count': 2}).create_node([priors_width_node, priors_height_node])
    concat_2 = Concat(graph, {'name': 'concat_width_height_width', 'axis': -1,
                              'in_ports_count': 2}).create_node([concat_1, priors_width_node])
    concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
                                              'in_ports_count': 2}
                                      ).create_node([concat_2, priors_height_node])

    applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
        [concat_width_height_node, match.single_input_node(0)[0]])

    # reshape to 2D tensor as Inference Engine Detection Output layer expects
    reshape_regression_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                               dict(name='reshape_regression'),
                                                               applied_width_height_regressions_node)

    detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)

    # get nms from the original network
    iou_threshold = None
    nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
    if len(nms_nodes) > 0:
        # it is highly unlikely that NMS uses different thresholds for different classes;
        # moreover DetectionOutput accepts only scalar values for iou_threshold (nms_threshold)
        iou_threshold = nms_nodes[0].in_node(3).value
    if iou_threshold is None:
        raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id))

    detection_output_node = detection_output_op.create_node(
        [reshape_regression_node, reshape_classes_node, priors],
        dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
             variance_encoded_in_target=0, background_label_id=1000))

    # As outputs are replaced with a postprocessing node, outgoing tensor names no longer
    # correspond to original tensors and should be removed from output->Result edges
    out_nodes = [match.output_node(out)[0] for out in range(match.outputs_count())]
    clear_tensor_names_info(out_nodes)

    return {'detection_output_node': detection_output_node}
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace the matched arange_like node with a sub-graph built around a Range node.

    The Range start and step come from the node attributes `start` and `step`; the stop value is
    derived from the shape of the node input. Additional nodes are inserted when `repeat` != 1,
    `step` != 1 or `start` != 0, and the stop input of Range is always squeezed at the end.

    :param graph: graph that is being transformed
    :param match: matched sub-graph, match['op'] is the arange_like node
    :raises Error: if the `repeat` attribute is less than 1
    """
    node = match['op']
    name = node.soft_get('name', node.id)
    axis = node.axis
    input_shape_node = Shape(graph, {
        'name': name + '/ShapeOf'
    }).create_node()
    # only ports 0 (start) and 2 (step) are constants, port 1 (stop) is connected below
    range_node = create_op_with_const_inputs(graph, Range, {
        0: mo_array(node.start),
        2: mo_array(node.step)
    }, {'name': name + '/Range'})
    node.in_port(0).get_connection().set_destination(
        input_shape_node.in_port(0))

    if axis is not None:
        ''' Replace arange_like op to subgraph: Shape - Gather - Range '''
        # stop = the input shape value at index `axis`
        gather_node = create_op_with_const_inputs(graph, Gather, {
            1: int64_array([axis]),
            2: int64_array(0)
        }, {'name': name + '/Gather'})
        input_shape_node.out_port(0).connect(gather_node.in_port(0))
        gather_node.out_port(0).connect(range_node.in_port(1))
        node.out_port(0).get_connection().set_source(
            range_node.out_port(0))
        rename_nodes([(node, name + '/ShouldBeDeleted'),
                      (range_node, name)])
    else:
        r''' Replace arange_like op to subgraph: | ShapeOf ----------- | | | ReduceProd | | | Range | | | Reshape ----------- | | '''
        # stop = product of all input dimensions; the Range output is reshaped to the input shape
        flattened_shape_node = create_op_with_const_inputs(
            graph, ReduceProd, {1: int64_array([0])}, {
                'name': input_shape_node.name + '/ReduceProd',
                'keep_dims': True
            })
        reshape_backward_node = Reshape(graph, {
            'name': name + '/Reshape_backward'
        }).create_node()

        input_shape_node.out_port(0).connect(
            flattened_shape_node.in_port(0))
        flattened_shape_node.out_port(0).connect(range_node.in_port(1))
        range_node.out_port(0).connect(reshape_backward_node.in_port(0))
        input_shape_node.out_port(0).connect(
            reshape_backward_node.in_port(1))
        node.out_port(0).get_connection().set_source(
            reshape_backward_node.out_port(0))
        rename_nodes([(node, name + '/ShouldBeDeleted'),
                      (reshape_backward_node, name)])

    if node.repeat != 1:
        r""" First, we generate the correct stop value for Range like new_stop_value = stop_value // repeat + 1. Then repeats each value of the interval using Tile. After that we can get a longer interval so we reduce it with Slice. 
Sub-graph after Range node will be look like Range - Reshape([-1, 1]) - Tile([1, repeat]) - Reshape(-1) - Slice """

        if node.repeat < 1:
            raise Error(
                "Unexpected value {} of the attribute 'repeat' for the node {}"
                .format(node.repeat, name))

        div_node = create_op_with_const_inputs(
            graph, Div, {1: int64_array([node.repeat])},
            {'name': name + '/Divide'})
        add_node = create_op_with_const_inputs(
            graph, Add, {1: int64_array([1])},
            {'name': div_node.name + '/Add'})
        cast_node = Cast(graph, {
            'name': name + '/ConvertToI64',
            'dst_type': np.int64
        }).create_node()

        # the current stop producer is redirected to Cast -> Div -> Add, which then feeds Range port 1
        cast_node.out_port(0).connect(div_node.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        range_node.in_port(1).get_connection().set_destination(
            cast_node.in_port(0))
        add_node.out_port(0).connect(range_node.in_port(1))

        tile_forward_reshape = create_op_with_const_inputs(
            graph, Reshape, {1: int64_array([-1, 1])},
            {'name': range_node.name + '/ForwardReshape'})
        tile = create_op_with_const_inputs(
            graph, Tile, {1: int64_array([1, node.repeat])},
            {'name': tile_forward_reshape.name + '/Tile'})
        tile_backward_reshape = create_op_with_const_inputs(
            graph, Reshape, {1: int64_array([-1])},
            {'name': tile.name + '/BackwardReshape'})
        slice_node = create_op_with_const_inputs(
            graph, Slice, {
                1: int64_array([0]),
                3: int64_array([0]),
                4: int64_array([1])
            }, {'name': tile_backward_reshape.name + '/Slice'})

        tile_forward_reshape.out_port(0).connect(tile.in_port(0))
        tile.out_port(0).connect(tile_backward_reshape.in_port(0))
        tile_backward_reshape.out_port(0).connect(slice_node.in_port(0))
        # Slice port 2 receives the stop value as it was before the division (the Cast output)
        slice_node.in_port(2).connect(div_node.in_port(0).get_source())

        # consumers of Range are moved to the Slice output before Range is wired to the Tile chain
        range_node.out_port(0).get_connection().set_source(
            slice_node.out_port(0))
        range_node.out_port(0).connect(tile_forward_reshape.in_port(0))

        if axis is not None:
            # in the axis branch Range took the original name above; hand it over to the Slice
            rename_nodes([(range_node, name + '/Range'),
                          (slice_node, name)])

    # MXNet arange_like op has no stop attribute and the result tensor always matches the input shape, so
    # we have to correct the stop value for the Range node if step != 1 or start != 0

    if node.step != 1:
        # If step attribute is not integer, we will generate an interval with a larger size and then reduce it
        # using Slice
        # remember the producer of the unscaled stop value before a Mul is inserted after it
        true_elements_count_port = range_node.in_port(1).get_source()
        mul_value = np.ceil(node.step) if node.step > 0 else np.floor(
            node.step)
        stop_value = create_op_with_const_inputs(
            graph,
            Mul,
            port_value_dict={1: mo_array(np.ceil(mul_value))},
            op_attrs={'name': range_node.name + '/Stop'})
        range_node.in_port(1).get_connection().insert_node(stop_value)

        slice_range_values = create_op_with_const_inputs(
            graph, Slice, {
                1: int64_array([0]),
                3: int64_array([0]),
                4: int64_array([1])
            }, {'name': range_node.name + '/Slice'})
        slice_range_values.in_port(2).connect(true_elements_count_port)
        range_node.out_port(0).get_connection().insert_node(
            slice_range_values)

        if axis is not None and node.repeat == 1:
            # Range still carries the original name only when the repeat branch did not rename it
            rename_nodes([(range_node, name + '/Range'),
                          (slice_range_values, name)])

    if node.start != 0:
        # stop = stop + start
        correct_stop_value = create_op_with_const_inputs(
            graph,
            Add,
            port_value_dict={1: mo_array(node.start)},
            op_attrs={'name': range_node.name + '/Correct_Stop'})
        range_node.in_port(1).get_connection().insert_node(
            correct_stop_value)

    # Range node supports only scalar inputs
    squeeze_node = create_op_with_const_inputs(
        graph,
        Squeeze,
        port_value_dict={1: int64_array(0)},
        op_attrs={"name": range_node.name + '/Stop/Squeeze'})
    range_node.in_port(1).get_connection().insert_node(squeeze_node)
def replace_tf_resize(graph: Graph, resize: Node, interpolation_mode: str):
    """
    Replace a two-input resize node with an opset4 Interpolate over axes [1, 2].

    The target sizes come from the second input of the resize node. The scales input is computed as
    sizes / shape(input)[1:3], with both operands cast to float32.

    :param graph: graph that is being transformed
    :param resize: resize node to replace
    :param interpolation_mode: value for the `mode` attribute of Interpolate; 'nearest' additionally
                               influences the nearest mode and the coordinate transformation mode
    :raises AssertionError: if the node does not have exactly two connected inputs, or if both
                            `half_pixel_centers` and `align_corners` are set
    """
    node_name = resize.soft_get('name', resize.id)
    log.debug("Converting of {} to Interpolate-4 is triggered for node {}.".format(resize.op, node_name))

    connected_inputs = sum(1 for port in resize.in_ports().values() if not port.disconnected())
    assert connected_inputs == 2, \
        "Number of inputs of {} (with name {}) should be equal to 2".format(resize.op, node_name)

    attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
                "the attribute align_corners must be False"
    # the two flags must not be set at the same time
    assert not (resize.half_pixel_centers and resize.align_corners), \
        attrs_msg.format(node_name, resize.op)

    input_shape = Shape(graph, dict(name=node_name + '/shapeof')).create_node()

    slice_attrs = dict(
        name=node_name + '/StridedSlice',
        begin_mask=int64_array([1]),
        end_mask=int64_array([1]),
        new_axis_mask=int64_array([0]),
        shrink_axis_mask=int64_array([0]),
        ellipsis_mask=int64_array([0]),
    )
    # shape[1:3] of the resized input
    spatial_dims = create_op_with_const_inputs(
        graph, StridedSlice,
        {1: int64_array([1]), 2: int64_array([3]), 3: int64_array([1])},
        slice_attrs)

    scales = Div(graph, dict(name=node_name + '/Div')).create_node()

    dims_as_float = Cast(graph, {'dst_type': np.float32}).create_node()
    sizes_as_float = Cast(graph, {'dst_type': np.float32}).create_node()

    sizes_as_float.out_port(0).connect(scales.in_port(0))
    dims_as_float.out_port(0).connect(scales.in_port(1))
    spatial_dims.out_port(0).connect(dims_as_float.in_port(0))
    input_shape.out_port(0).connect(spatial_dims.in_port(0))

    align_corners = resize.align_corners
    half_pixel_centers = resize.half_pixel_centers
    is_nearest = interpolation_mode == 'nearest'

    if is_nearest:
        nearest_mode = 'floor'
    else:
        nearest_mode = 'round_prefer_floor'

    if align_corners:
        coordinate_transformation_mode = 'align_corners'
        if is_nearest:
            nearest_mode = 'round_prefer_ceil'
    elif half_pixel_centers:
        if is_nearest:
            coordinate_transformation_mode = 'tf_half_pixel_for_nn'
        else:
            coordinate_transformation_mode = 'half_pixel'
    else:
        coordinate_transformation_mode = 'asymmetric'

    interpolate_attrs = {
        'name': node_name + '/interpolate_4',
        'mode': interpolation_mode,
        'antialias': False,
        'coordinate_transformation_mode': coordinate_transformation_mode,
        'pads_begin': int64_array([0]),
        'pads_end': int64_array([0]),
        'nearest_mode': nearest_mode,
        'cube_coeff': -0.75,
        'shape_calculation_mode': 'sizes',
        'version': 'opset4',
        'in_ports_count': 4,
    }
    # port 3 holds the interpolated axes
    interpolate = create_op_with_const_inputs(graph, Interpolate, {3: int64_array([1, 2])}, interpolate_attrs)

    # data: moved from the resize node to Interpolate and also read by ShapeOf
    data_connection = resize.in_port(0).get_connection()
    data_connection.set_destination(interpolate.in_port(0))
    data_connection.get_source().connect(input_shape.in_port(0))

    scales.out_port(0).connect(interpolate.in_port(2))

    # sizes: moved from the resize node to Interpolate and also used to compute the scales
    sizes_connection = resize.in_port(1).get_connection()
    sizes_connection.set_destination(interpolate.in_port(1))
    sizes_connection.get_source().connect(sizes_as_float.in_port(0))

    resize.out_port(0).get_connection().set_source(interpolate.out_port(0))
    rename_nodes([(resize, node_name + '/delete'), (interpolate, node_name)])