def concat_outputs(bi_rnn, forward_outputs, reverse_outputs, final_outputs):
    """ Concatenates the outputs of forward and reverse bidirectional RNNSequence nodes.

    The data output (index 0) is concatenated along axis 1, while the state outputs
    (indices 1 and 2 — hidden and cell state) are concatenated along axis 0.
    Only the output indices present in `final_outputs` are processed.
    """
    concat_ops = [
        Concat(bi_rnn.graph, {
            'name': bi_rnn.name + '/FinalConcat/Data',
            'axis': 1,
            'in_ports_count': 2,
        }),
        Concat(bi_rnn.graph, {
            'name': bi_rnn.name + '/FinalConcat/HiddenState',
            'axis': 0,
            'in_ports_count': 2,
        }),
        Concat(bi_rnn.graph, {
            'name': bi_rnn.name + '/FinalConcat/CellState',
            'axis': 0,
            'in_ports_count': 2,
        })
    ]

    # The bidirectional node is removed first; its former output data nodes are then
    # attached as the outputs of the new Concat nodes below.
    bi_rnn.graph.remove_node(bi_rnn.id)

    # NOTE(review): iteration over `final_outputs` yields output indices used to key
    # all three containers — presumably a dict {index: data_node}; verify against callers.
    for i in final_outputs:
        concat_ops[i].create_node_with_data(
            [forward_outputs[i], reverse_outputs[i]],
            data_nodes=[final_outputs[i]]
        )
def extract(cls, node):
    """Read the concatenation axis from the Caffe protobuf layer and update the node attrs."""
    attrs = {'axis': node.pb.concat_param.axis}
    Concat.update_node_stat(node, attrs)
    return cls.enabled
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace an MXNet Reshape node with a Shape -> per-dimension nodes -> Concat -> Reshape subgraph.

    Special values in `node.dim` are resolved by `self.resolve()`, which appends nodes
    computing the corresponding output dimensions to `output_dims_nodes`; those are
    concatenated into the target-shape tensor fed to a standard Reshape.
    """
    node = match['mxreshape']
    input_index = 0
    reshape_index = 0

    # Runtime shape of the input drives resolution of the special dim values.
    shape_node = Shape(graph, dict(name=node.id + '/ShapeMXReshape')).create_node()
    shape_node.in_port(0).connect(node.in_port(0).get_source())

    output_dims_nodes = []
    # NOTE(review): the loop variable is unused — self.resolve() advances `reshape_index`
    # itself; the loop only bounds the number of iterations by len(node.dim).
    for d in node.dim:
        if reshape_index < len(node.dim):
            input_index, reshape_index, output_dims_nodes = self.resolve(
                input_index, reshape_index, node.dim, shape_node, output_dims_nodes)

    # Join all resolved dimension nodes into a single 1D shape tensor.
    concat_node = Concat(
        shape_node.graph,
        dict(name=shape_node.id + '/ConcatMXReshape_',
             axis=0,
             in_ports_count=len(output_dims_nodes))).create_node()
    for in_port_index, dim_node in enumerate(output_dims_nodes):
        concat_node.in_port(in_port_index).connect(dim_node.out_port(0))

    reshape_node = Reshape(graph, dict(name=node.id + '/Reshape_')).create_node()
    reshape_node.in_port(1).connect(concat_node.out_port(0))

    # Re-wire the original node's input and output edges to the new Reshape.
    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
def replace_with_split_concat(node):
    """Replace the node with an equivalent Split -> Concat pair that realizes the permutation in node.order."""
    graph = node.graph
    node_name = node.soft_get('name', node.id)
    axis = node.axis
    order = node.order

    split_node = create_op_with_const_inputs(
        graph, Split, {1: int64_array(axis)},
        {'name': node_name + '/Split', 'num_splits': order.size})
    concat_node = Concat(
        graph,
        {'name': node_name + '/Concat', 'axis': axis, 'in_ports_count': order.size}).create_node()

    # Route every split output to the concat input dictated by the permutation.
    for src_idx, dst_idx in enumerate(order):
        split_node.out_port(src_idx).connect(concat_node.in_port(dst_idx))

    node.out_port(0).get_connection().set_source(concat_node.out_port(0))
    node.in_port(0).get_connection().set_destination(split_node.in_port(0))
    graph.remove_node(node.id)
def concat_output_states(graph: Graph, match: dict, new_states: list):
    """ Concatenates per-layer output states of a multilayer RNN layer back into single state tensors.

    Output port 1 of the RNN layer is the hidden state, port 2 the cell state; either
    may be absent (e.g. no cell state for GRU/RNN cells).
    """
    rnn_layer = match['rnn_layer']

    original_states = [
        rnn_layer.out_node(i) if i in rnn_layer.out_nodes() else None
        for i in [1, 2]
    ]

    concat_ops = [
        Concat(
            rnn_layer.graph, {
                'name': rnn_layer.name + '/FinalLayerSplitConcat/HiddenState',
                'axis': -1
            }),
        Concat(
            rnn_layer.graph, {
                'name': rnn_layer.name + '/FinalLayerSplitConcat/CellState',
                'axis': -1
            })
    ]

    for i in range(len(original_states)):  # [0] or [0, 1]
        if original_states[i] is None:
            continue
        # in_ports_count is set lazily — the number of per-layer states is only known here.
        concat_ops[i].attrs.update({'in_ports_count': len(new_states[i])})
        concat_ops[i].create_node_with_data(
            inputs=new_states[i], data_nodes=[original_states[i]])
def extract(cls, node):
    """Extract the MXNet 'dim' attribute (default 1) and set it as the Concat axis."""
    layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
    Concat.update_node_stat(node, {'axis': layer_attrs.int("dim", 1)})
    return cls.enabled
def replace_op(self, graph: Graph, node: Node):
    """Replace a Pack node with per-input Unsqueeze nodes feeding a single Concat.

    Each input gets a new unit dimension inserted at `node.axis` and the expanded
    tensors are concatenated along that axis, which is equivalent to stacking.

    :param graph: graph to operate on.
    :param node: the Pack node being replaced.
    :return: list with the id of the node that replaces the matched one.
    """
    out_node = Concat(graph, {'axis': node.axis,
                              'in_ports_count': len(node.in_ports())}).create_node()
    pack_name = node.soft_get('name', node.id)

    for ind in node.in_ports():
        # Fix: reuse the already-computed pack_name instead of re-querying
        # node.soft_get('name', node.id) on every iteration.
        unsqueeze_node = create_op_with_const_inputs(
            graph, Unsqueeze, {1: int64_array([node.axis])},
            {'name': pack_name + '/Unsqueeze'})
        node.in_port(ind).get_connection().set_destination(unsqueeze_node.in_port(0))
        unsqueeze_node.out_port(0).connect(out_node.in_port(ind))

    # Transfer the original node's name to the new Concat ('/TBR' = to be removed).
    rename_nodes([(node, pack_name + '/TBR'), (out_node, pack_name)])
    return [out_node.id]
def replace_pattern(graph: Graph, match: dict):
    """Replace a Kaldi MemoryOffset/IfDefined node pair with explicit state nodes.

    Builds a ReadValue/Assign (state read/write) pair holding a rolling buffer of
    [batch, features * |t|] so that a negative time offset t becomes a delay line.
    Only t < 0 is supported.
    """
    node = match['op']
    pair_node = Node(graph, node.pair_name)

    if node.t >= 0:
        raise Error('Does not support IfDefined with t > 0')

    # Pick whichever node of the pair has a connected input; the other node's
    # output port is re-wired to the delayed state readout.
    if node.in_port(0).get_source() is not None:
        input_port = node.in_port(0).get_source()
        op_output_id = node.out_port(0).get_destination().node.id
        out_port = pair_node.out_port(0)
        node_name = node.name
        pair_name = pair_node.name
    else:
        input_port = pair_node.in_port(0).get_source()
        op_output_id = pair_node.out_port(0).get_destination().node.id
        out_port = node.out_port(0)
        node_name = pair_node.name
        pair_name = node.name

    in_shape = input_port.data.get_shape()
    node_t = abs(node.t)

    # Zero-initialized state buffer of shape [batch, features * t].
    init_value_memory_out = Const(graph, {'name': 'init_value_' + pair_name,
                                          'value': np.zeros(int64_array([in_shape[0], in_shape[1]*node_t]), dtype=np.float32),
                                          'shape': int64_array([in_shape[0], in_shape[1]*node_t])}).create_node()
    memory_out = ReadValue(graph, {'name': pair_name,
                                   'variable_id': node_name+pair_name}).create_node()
    init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

    if node_t > 1:
        # Shift the buffer: crop off the oldest chunk, append the new input, store back.
        crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': mo_array([in_shape[1]*(node_t-1)]),
                                   'offset': mo_array([in_shape[1]]), 'axis': mo_array([1])}).create_node()
        memory_out.out_port(0).connect(crop_concat.in_port(0))
        concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
        concat.add_sequence_of_ports('in', range(2))
        crop_concat.out_port(0).connect(concat.in_port(0))
        concat.in_port(1).connect(input_port)

        memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
        concat.out_port(0).connect(memory_in.in_port(0))
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))

        # The delayed output is the oldest chunk of the stored buffer.
        crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': mo_array([in_shape[1]]),
                                'offset': mo_array([0]), 'axis': mo_array([1])}).create_node()
        memory_out.out_port(0).connect(crop_out.in_port(0))
        out_port.get_connection().set_source(crop_out.out_port(0))
    else:
        # t == 1: the whole stored value is the delayed output; just store the new input.
        memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
        memory_in.in_port(0).connect(input_port)
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))
        out_port.get_connection().set_source(memory_out.out_port(0))

    # Remove the original pair and the consumer of the stored-side output.
    graph.remove_node(op_output_id)
    graph.remove_node(node.id)
    graph.remove_node(pair_node.id)
def create_ss_interval_border(graph: Graph, slice_border_port: Port, shape: np.ndarray, axes: np.ndarray, node_name: str):
    """
    This function creates "begin"/"end" parameters for the StridedSlice based on Slice's "starts"/"ends"

    :param graph: graph to operate on.
    :param slice_border_port: node output port that provides "starts"/"ends" values for the Slice.
    :param shape: input shape of the Slice
    :param axes: axes that "starts" and "ends" apply to
    :param node_name: Slice node name
    :return: Concat node that forms "begin"/"end" values for the StridedSlice
    """
    # the value for 'starts' or 'ends' might be maximum/minimum possible value of int64. This
    # value must be converted to maximum/minimum of int32 because such big values do not fit into the int32 which is
    # supported by the StridedSlice layer
    clamp = create_op_with_const_inputs(
        graph, Clamp, port_value_dict={
            1: np.iinfo(np.int32).min,
            2: np.iinfo(np.int32).max
        },
        op_attrs=dict(name=node_name + '/Clamp'))
    clamp.in_port(0).connect(slice_border_port)
    # we have to convert "starts"/"ends" values from the network to one data type with constant values that are created
    # here to prevent type errors in Concat node
    cast = Cast(graph, dict(name=node_name + '/CastToI64', dst_type=np.int64)).create_node()
    cast.in_port(0).connect(clamp.out_port(0))
    concat = Concat(graph, dict(name=node_name + '/Concat', axis=0)).create_node()
    for value_idx, port_idx in enumerate(axes):
        concat.add_input_port(port_idx)
        # "axes" may not be sorted, so we need to split "starts"/"ends" values and connect each value to the correct
        # Concat input port
        value = create_op_with_const_inputs(
            graph, Gather, port_value_dict={
                1: int64_array([value_idx]),
                2: int64_array(0)
            },
            op_attrs={'name': node_name + '/Gather'})
        cast.out_port(0).connect(value.in_port(0))
        value.out_port(0).connect(concat.in_port(port_idx))

    # Fill any axis that "starts"/"ends" do not cover with a placeholder zero.
    for port_idx in range(len(shape)):
        if not concat.is_in_port_connected(port_idx):
            concat.add_input_port(port_idx)
            # This border value would be ignored in StridedSlice because of the begin_mask\end_mask
            const = Const(
                graph, dict(name=node_name + '/Const',
                            value=int64_array([0]))).create_node()
            const.out_port(0).connect(concat.in_port(port_idx))

    return concat
def test_concat_op(self):
    """Check that a freshly created Concat node carries the expected type, op and infer function."""
    edges = [('node_1', 'concat_node'), ('concat_node', 'node_3')]
    graph = build_graph(self.nodes_attributes, edges)
    created_node = Concat(graph, self.nodes_attributes['concat_node']).add_node()
    self.assertEqual(created_node.type, 'Concat')
    self.assertEqual(created_node.op, 'Concat')
    self.assertEqual(created_node.infer, concat_infer)
def add_fake_background_loc(graph: Graph, input_node: Node):
    r"""
    DetectionOutput layer expects that box coordinates contains coordinates of boxes for the "background" class also,
    but in the TensorFlow\* Object Detection API the tensor contains information about real object classes only.
    The function copies a slice of the output data of the node 'input_node' and then concats it to the beginning
    of the data as the "background" class box coordinates. The data in this slice is not used by the Detection
    Output layer so the actual values are not important. This approach allows the model to be reshape-able and
    does not introduce many layers.

    :param graph: graph to operate on.
    :param input_node: node producing the boxes coordinates.
    :return: Concat node that prepends the duplicated slice to the data.
    """
    # Crop the first class slice along axis 1; its values serve only as placeholders.
    crop_op = Crop(graph, dict(axis=mo_array([1]), offset=mo_array([0]), dim=mo_array([1]), nchw_layout=True))
    crop_node = crop_op.create_node([input_node], dict(name='crop_locs'))

    concat_op = Concat(graph, dict(axis=1, in_ports_count=2, nchw_layout=True))
    return concat_op.create_node([crop_node, input_node],
                                 dict(name=input_node.id + '/locs_with_fake_background'))
def replace_op(self, graph: Graph, node: Node):
    """Preprocess CropAndResize inputs to match the IE ROIPooling layer contract.

    Batch indices are reshaped to 2D and cast to float, concatenated with the box
    coordinates, run through a coordinate-swapping convolution (YXYX -> XYXY) and
    reshaped to [-1, 5] before being connected as input 1 of the node.
    Returns [] so that no output edges are replaced.
    """
    if node.has_and_set('inputs_preprocessed'):
        log.debug('Node "{}" has already been preprocessed'.format(
            node.soft_get('name')))
        return []
    # reshape tensor with batch indices to 2d
    unsqueeze_node = create_op_node_with_second_input(
        graph, Unsqueeze, int64_array([1]),
        {'name': node.name + '/Unsqueeze'}, node.in_node(2))

    convert_node = Cast(
        graph, {
            'name': unsqueeze_node.name + '/ToFloat',
            'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
        }).create_node()

    convert_node.in_port(0).connect(unsqueeze_node.out_port(0))

    concat_op = Concat(
        graph, {
            'axis': 1,
            'name': node.name + '/concat_batch_indices_and_boxes',
            'in_ports_count': 2
        })
    concat_node = concat_op.create_node([convert_node, node.in_node(1)])

    # do not remove edge with crop_size because it is needed in the partial infer
    graph.remove_edge(node.in_node(1).id, node.id)

    # input to the CropAndResize contains boxes coordinates in YXYX layout. But IE layer ROIPooling expects
    # coordinates in the XYXY layout, so convolution is added here to swap coordinates
    swapped_box_coordinates_node = add_convolution_to_swap_xy_coordinates(
        graph, concat_node, 5)

    # reshape locations tensor to 2D so it could be passed to Eltwise which will be converted to ScaleShift
    reshape_2d_node = create_op_node_with_second_input(
        graph, Reshape, int64_array([-1, 5]),
        dict(name=swapped_box_coordinates_node.id + '/reshape_2d_'),
        swapped_box_coordinates_node)
    graph.create_edge(reshape_2d_node, node, 0, 1)

    # do not replace any output edge
    return []
def replace_pattern(self, graph: Graph, match: dict):
    """Remove Concat inputs that are Const nodes holding an empty (zero-sized) tensor.

    If only one real input remains, the Concat is bypassed entirely; otherwise a new
    Concat with only the non-empty inputs replaces the original one.
    """
    concat_node = match['concat']
    sources_of_ports = [concat_node.in_port(i).get_connection().get_source() for i in concat_node.in_ports()]
    # If 'concat' is ConcatV2 layer from TF, then this layer initially had input 'axis' as the last input.
    # But then this input was deleted and the attribute 'axis' was added. Hence, the last port source can
    # be None in such case.
    sources_of_ports = [s for s in sources_of_ports if s is not None]

    input_nodes = [s.node for s in sources_of_ports]
    if not all(n.has_valid('type') for n in input_nodes):
        return

    saved_ports = []
    disconnected_ports = []

    # A Const with rank > 1 and any zero dimension contributes nothing to the concat.
    for port_num, node in enumerate(input_nodes):
        if node.soft_get('type') == 'Const' and len(node.shape) > 1 and any(i == 0 for i in node.shape):
            disconnected_ports.append(port_num)
        else:
            saved_ports.append(port_num)

    if not saved_ports or not disconnected_ports:
        return

    if len(saved_ports) == 1:
        # Single real input left: route it straight to the Concat's consumers.
        before_concat = concat_node.in_port(saved_ports[0]).get_connection().get_source()
        concat_node.out_port(0).get_connection().set_source(before_concat)
        return

    new_concat_attrs = concat_node.attrs().copy()
    new_concat_attrs['name'] = concat_node.name + '/Concat_'
    new_concat_attrs['in_ports_count'] = len(saved_ports)
    new_concat_node = Concat(graph, attrs=new_concat_attrs).create_node()

    for new_port_num, old_port_num in enumerate(saved_ports):
        concat_node.in_port(old_port_num).get_connection().set_destination(new_concat_node.in_port(new_port_num))

    for p in disconnected_ports:
        concat_node.in_port(p).disconnect()

    concat_node.out_port(0).get_connection().set_source(new_concat_node.out_port(0))
def replace_tdnn(self, graph: Graph, tdnn_node: Node):
    """Replace a Kaldi tdnncomponent with MemoryOffset nodes + Concat + FullyConnected.

    Each non-zero time offset becomes a MemoryOffset node; a zero offset connects
    the input directly (a 0 delay is meaningless and not allowed in IE). The
    concatenated context is fed to a FullyConnected built from the node's weights
    and optional biases.
    """
    tdnn_name = tdnn_node.soft_get('name', tdnn_node.id)

    concat_node = Concat(graph, {'axis': 1}).create_node()
    rename_nodes([(tdnn_node, tdnn_name + '/to_be_removed'), (concat_node, tdnn_name)])

    for offset_ind, t in enumerate(tdnn_node['time_offsets']):
        concat_node.add_input_port(offset_ind)
        if t != 0:
            memory_name = tdnn_name + '/MemoryOffset/' + str(abs(t))
            memoryoffset_node = MemoryOffset(
                graph, {
                    'name': memory_name,
                    't': t,
                    'pair_name': memory_name + '_out',
                    'has_default': False,
                    'splitted': False
                }).create_node()

            tdnn_node.in_port(0).get_source().connect(
                memoryoffset_node.in_port(0))
            memoryoffset_node.out_port(0).connect(
                concat_node.in_port(offset_ind))
        else:
            # 0 time delay is not allowed in IE, it's meaningless
            # if time offset is 0 then connect input of tdnncomponent directly to Concat without memoryoffset
            tdnn_node.in_port(0).get_source().connect(
                concat_node.in_port(offset_ind))

    weights = tdnn_node['weights']
    fc_inputs = {1: weights}

    bias_term = False
    if tdnn_node.has_valid('biases'):
        assert len(tdnn_node['biases']) == weights.shape[0]
        fc_inputs.update({2: tdnn_node['biases']})
        bias_term = True

    fc_node = create_op_with_const_inputs(
        graph, FullyConnected, fc_inputs, {
            'name': tdnn_name + '/FC',
            'out-size': weights.shape[0],
            'transpose_weights': True,
            'bias_term': bias_term
        })

    concat_node.out_port(0).connect(fc_node.in_port(0))
    tdnn_node.in_port(0).disconnect()
    tdnn_node.out_port(0).get_connection().set_source(fc_node.out_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    """Convert an ArgMax/ArgMin node into a TopK node with k = node.top_k.

    In 'out_max_val' mode the result is a Concat of (values, indices) pairs to
    mimic Caffe's ArgMax tuple output.
    """
    node = match['node']
    node_name = node.soft_get('name', node.id)

    # Axis comes either from the second input (when connected) or from the attribute.
    connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    if len(connected_ports) == 2:
        axis = node.in_port(1).data.get_value()
    else:
        axis = node.axis

    assert axis is not None, 'The "axis" should be defined for node "{}"'.format(node_name)
    assert node.has_and_set('output_type'), 'The data type is not set for node "{}"'.format(node_name)

    topk_mode = 'max' if node.op == 'ArgMax' else 'min'
    topk_node = TopK(graph, {'axis': axis, 'mode': topk_mode, 'sort': 'index',
                             'remove_values_output': node.has_and_set('remove_values_output'),
                             'index_element_type': node.output_type}).create_node()
    node.in_port(0).get_connection().set_destination(topk_node.in_port(0))

    if node.has_and_set('out_max_val'):  # in this mode the ArgMax produces tuples (max_ind, max_value)
        concat_node = Concat(graph, {'axis': 1, 'name': node.name + '/Concat'}).create_node()
        concat_node.add_input_port(0, skip_if_exist=True)
        concat_node.add_input_port(1, skip_if_exist=True)
        topk_node.out_port(0).connect(concat_node.in_port(1))  # indices
        topk_node.out_port(1).connect(concat_node.in_port(0))  # values
        if not node.out_port(0).disconnected():
            node.out_port(0).get_connection().set_source(concat_node.out_port(0))
    else:
        if not node.out_port(0).disconnected():
            node.out_port(0).get_connection().set_source(topk_node.out_port(1))

    topk_node.in_port(1).connect(Const(graph, {'name': node.soft_get('name') + '/TopK',
                                               'value': node.top_k}).create_node().out_port(0))

    graph.remove_nodes_from([node.id, node.out_node(0).id])
def fuse_reduces(first_reduce, second_reduce):
    """Fuse two stacked Reduce nodes of the same type into the second one.

    The axes inputs of both reductions are concatenated and fed to the remaining
    node, the first reduce is removed. Fusion is skipped when the first reduce
    feeds more than one consumer, keep_dims differ, axes are dynamic, or (without
    keep_dims) the second reduce's axes index dimensions already removed by the first.
    """
    first_reduce_name = first_reduce.soft_get('name', first_reduce.id)
    second_reduce_name = second_reduce.soft_get('name', second_reduce.id)
    reduce_type = first_reduce.type

    assert first_reduce.type == second_reduce.type

    if len(first_reduce.out_port(0).get_destinations()) != 1:
        # data dependency
        return

    if first_reduce.keep_dims != second_reduce.keep_dims:
        return

    first_axes = first_reduce.in_port(1).data.get_value()
    second_axes = second_reduce.in_port(1).data.get_value()
    if first_axes is None or second_axes is None:
        # dynamic axes merging is not supported
        return

    if not first_reduce.keep_dims:
        if not np.all(first_axes > second_axes):
            # indexing of upper reduce input dimensions changed
            return

    graph = second_reduce.graph

    new_axes = Concat(
        graph, {
            'name': second_reduce_name + '/Axes',
            'axis': int64_array(0),
            'in_ports_count': 2,
            'override_output_shape': True
        }).create_node()
    new_axes.in_port(0).connect(first_reduce.in_port(1).get_source())
    new_axes.in_port(1).connect(second_reduce.in_port(1).get_source())

    # Shapes downstream of the first reduce's producer must be re-inferred.
    first_reduce.in_port(
        0).get_source().node['need_shape_inference'] = True
    first_reduce.in_port(
        0).get_source().node['override_output_shape'] = True

    second_reduce.in_port(1).get_connection().set_source(
        new_axes.out_port(0))

    # Bypass and remove the first reduce.
    first_reduce.out_port(0).get_connection().set_source(
        first_reduce.in_port(0).get_connection().get_source())
    first_reduce.in_port(1).disconnect()
    graph.remove_node(first_reduce.id)

    log.debug(
        '{0} nodes {1} and {2} were fused to a single {2} node with updated axes input'
        ''.format(reduce_type, first_reduce_name, second_reduce_name))
def placeholder_scales(self, placeholder: Node):
    """
    Helper function to get scales for prior boxes out of input image size:
            [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]
    """
    graph = placeholder.graph
    name = placeholder.soft_get('name', placeholder.id)

    shape_value = placeholder.soft_get('shape', None)
    assert shape_value is not None, \
        "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
    assert isinstance(shape_value, np.ndarray), \
        "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
    assert shape_value.size == 4, \
        "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

    shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
    shape.in_port(0).connect(placeholder.out_port(0))

    # Slice out dimensions 1..2 of the 4D input shape (the spatial H, W pair).
    begin = Const(graph, {'value': int64_array([1])}).create_node()
    end = Const(graph, {'value': int64_array([3])}).create_node()
    stride = Const(graph, {'value': int64_array([1])}).create_node()
    spatial = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': int64_array([1]),
                                   'end_mask': int64_array([1]), 'new_axis_mask': int64_array([0]),
                                   'shrink_axis_mask': int64_array([0]),
                                   'ellipsis_mask': int64_array([0])}).create_node()

    spatial.in_port(0).connect(shape.out_port(0))
    spatial.in_port(1).connect(begin.out_port(0))
    spatial.in_port(2).connect(end.out_port(0))
    spatial.in_port(3).connect(stride.out_port(0))

    # Elementwise reciprocal via Pow(x, -1): [1/H, 1/W].
    power = Const(graph, {'value': float32_array([-1.])}).create_node()
    spatial_scale = Pow(graph, {}).create_node()

    spatial_scale.in_port(0).connect(spatial.out_port(0))
    spatial_scale.in_port(1).connect(power.out_port(0))

    # Power `type_infer` requires inputs to have equal data type
    convert_to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
    spatial_scale.in_port(0).get_connection().insert_node(convert_to_fp32)

    # Reverse [1/H, 1/W] to [1/W, 1/H] using Gather with order [1, 0].
    order = Const(graph, {'value': int64_array([1, 0])}).create_node()
    axis_const = Const(graph, {'value': int64_array(0)}).create_node()
    reverse = Gather(graph, {}).create_node()
    reverse.in_port(0).connect(spatial_scale.out_port(0))
    reverse.in_port(1).connect(order.out_port(0))
    axis_const.out_port(0).connect(reverse.in_port(2))

    # Duplicate the pair: [1/W, 1/H, 1/W, 1/H].
    priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
    priors_scale_node.add_input_port(0, skip_if_exist=True)
    priors_scale_node.add_input_port(1, skip_if_exist=True)
    priors_scale_node.in_port(0).connect(reverse.out_port(0))
    priors_scale_node.in_port(1).connect(reverse.out_port(0))

    return priors_scale_node
def extend_inputs(node: Node, num_insertions: int):
    """Append `num_insertions` blank values to StridedSlice 'begin'/'end'/'strides' inputs.

    Zeros are appended to 'begin'/'end' and ones (the neutral stride) to 'strides'.
    If the input already comes from a Concat, the blanks are added as a new port on
    that Concat; otherwise a new two-input Concat is created in-line.
    """
    graph = node.graph
    node_name = node.soft_get('name', node.id)

    for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
        if i == 3 and not node.is_in_port_connected(3):
            continue  # no need to extend strides if they are not connected

        blank_values_arr = np.zeros(
            num_insertions) if input_name != 'strides' else np.ones(
            num_insertions)
        blank_values_node = Const(
            graph, {
                'name': node_name + '/extend_{}_const'.format(input_name),
                'value': int64_array(blank_values_arr)
            }).create_node()

        if node.in_port(i).get_source().node.soft_get('type') == 'Concat':
            # concat already exists
            concat = node.in_port(i).get_source().node
            # because output data node shape will be changed
            # while shapes will be reinferred no need to check consistency
            concat['override_output_shape'] = True

            last_in_port = max(concat.in_ports().keys())
            assert not concat.in_port(last_in_port).disconnected(), 'The last in_port of Concat node {} ' \
                                                                    'should be connected'. \
                format(concat.soft_get('name', node.id))

            concat.add_input_port(last_in_port + 1)
            concat.in_port(last_in_port + 1).connect(
                blank_values_node.out_port(0))
        else:
            # have to create concat
            concat = Concat(
                graph, {
                    'axis': 0,
                    'name': node_name + '/concat_{}'.format(input_name),
                    'in_ports_count': 2
                }).create_node()
            node.in_port(i).get_connection().set_destination(
                concat.in_port(0))
            concat.in_port(1).connect(blank_values_node.out_port(0))
            concat.out_port(0).get_connection().set_destination(
                node.in_port(i))
def append_variances(priors_scale_node: Node, variance: list):
    """Append tiled variance values to the scaled prior boxes.

    The variances are broadcast to the number of priors (taken from dimension -2
    of the priors tensor), concatenated after the reshaped priors, and the result
    is reshaped to the [1, 2, -1] layout expected by DetectionOutput.
    """
    graph = priors_scale_node.graph
    name = priors_scale_node.name

    sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
    priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

    # Take dimension -2 of the priors shape (the number of prior boxes).
    begin = Const(graph, {'value': int64_array([-2])}).create_node()
    end = Const(graph, {'value': int64_array([-1])}).create_node()
    stride = Const(graph, {'value': int64_array([1])}).create_node()
    shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim',
                                                 'begin_mask': int64_array([1]),
                                                 'end_mask': int64_array([1]),
                                                 'new_axis_mask': int64_array([0]),
                                                 'shrink_axis_mask': int64_array([0]),
                                                 'ellipsis_mask': int64_array([0])}).create_node()

    sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
    begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
    end.out_port(0).connect(shape_part_for_tiling.in_port(2))
    stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

    # Target broadcast shape: [num_priors, 4].
    shape_concat = create_op_node_with_second_input(graph, Concat, int64_array([4]),
                                                    {'name': name + '/shape_for_tiling',
                                                     'in_ports_count': 2,
                                                     'axis': int64_array(0)},
                                                    shape_part_for_tiling)

    variance = Const(graph, {'name': name + '/variance', 'value': float32_array(variance)}).create_node()
    tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
    variance.out_port(0).connect(tile.in_port(0))
    shape_concat.out_port(0).connect(tile.in_port(1))

    # Reshape priors to [-1, 4] so they can be concatenated with the tiled variances.
    reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
    sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
    sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
    sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

    concat = Concat(graph,
                    {'name': name + '/priors_concat', 'axis': int64_array(0),
                     'in_ports_count': 2}).create_node()
    sp_reshape.out_port(0).connect(concat.in_port(0))
    tile.out_port(0).connect(concat.in_port(1))

    output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
    output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
    concat.out_port(0).connect(output_node.in_port(0))
    output_dims.out_port(0).connect(output_node.in_port(1))

    return output_node
def new_shape_node_from_shape_nodes(input_shape_nodes: list):
    """
    Build a node producing a 1D tensor with the concatenated values of the nodes
    from "input_shape_nodes".

    :param input_shape_nodes: list of nodes producing 1D tensors
    :return: the Concat node joining all inputs along axis 0
    """
    assert len(input_shape_nodes) > 0, 'The list of input shape nodes should be non-empty'

    first_node = input_shape_nodes[0]
    concat = Concat(
        first_node.graph,
        {'axis': 0,
         'name': first_node.soft_get('name', first_node.id) + '/shapes_concat'}).create_node()

    for port_idx, shape_node in enumerate(input_shape_nodes):
        concat.add_input_port(port_idx)
        concat.in_port(port_idx).connect(shape_node.out_port(0))

    return concat
def unroll_ellipsis_for_inputs(graph: Graph, node: Node, ellipsis_start: int, num_insertions: int):
    """Unroll an ellipsis in StridedSlice 'begin'/'end'/'strides' inputs.

    Inserts `num_insertions` blank values (zeros for 'begin'/'end', ones for
    'strides') at position `ellipsis_start`. When the ellipsis is not at the
    beginning, the original value is split at that position and the blanks are
    inserted in between; otherwise the blanks are simply prepended.
    """
    node_name = node.soft_get('name', node.id)

    for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
        if i == 3 and not node.is_in_port_connected(3):
            continue  # no need to extend strides if they are not connected

        blank_values_arr = np.zeros(
            num_insertions) if input_name != 'strides' else np.ones(
            num_insertions)
        blank_values_node = Const(
            graph, {
                'name': node_name + '/const_to_unroll_{}_ellipsis'.format(input_name),
                'value': int64_array(blank_values_arr)
            }).create_node()

        concat_in_ports_count = 3 if ellipsis_start != 0 else 2
        concat = Concat(
            graph, {
                'axis': 0,
                'name': node_name + '/concat_{}'.format(input_name),
                'in_ports_count': concat_in_ports_count
            }).create_node()

        if ellipsis_start != 0:
            # Split original value at the ellipsis position: [0:ellipsis_start] and the rest.
            split = create_op_with_const_inputs(graph, VariadicSplit, {
                1: int64_array(0),
                2: int64_array([ellipsis_start, -1])
            }, {
                'name': node_name + '/split_for_{}_ellipsis'.format(input_name),
                'out_ports_count': 2
            })
            node.in_port(i).get_connection().set_destination(
                split.in_port(0))

            concat.in_port(0).connect(split.out_port(0))
            concat.in_port(1).connect(blank_values_node.out_port(0))
            concat.in_port(2).connect(split.out_port(1))
        else:
            # Ellipsis at the very beginning: just prepend the blanks.
            concat.in_port(0).connect(blank_values_node.out_port(0))
            node.in_port(i).get_connection().set_destination(
                concat.in_port(1))

        concat.out_port(0).get_connection().set_destination(
            node.in_port(i))
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
    """Build the DetectionOutput subgraph for the RetinaNet postprocessing replacement.

    Takes the matched regression, classification and prior-box inputs, scales the
    priors to the [0, 1] interval, appends variances, applies the width/height
    regression scaling and wires everything into a DetectionOutput node whose
    nms_threshold is retrieved from the original NonMaxSuppression node.

    :param graph: graph to operate on.
    :param match: subgraph match providing the three inputs and the outputs to replace.
    :return: dict with the created 'detection_output_node'.
    :raises Error: when 'variance' is missing from the replacement config or when
        no iou_threshold can be retrieved from the graph.
    """
    reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                            dict(name='do_reshape_classes'),
                                                            match.single_input_node(1)[0])

    initial_priors_node = match.single_input_node(2)[0]
    priors_name = initial_priors_node.soft_get('name', initial_priors_node.id)
    # model calculates identical prior boxes for each batch, so we take first slice of them
    begin = Const(graph, {'value': mo_array([0, 0, 0], dtype=np.int32)}).create_node()
    end = Const(graph, {'value': mo_array([1, 0, 0], dtype=np.int32)}).create_node()
    stride = Const(graph, {'value': mo_array([1, 1, 1], dtype=np.int32)}).create_node()

    priors_node = StridedSlice(graph, {'name': priors_name + '/0_batch_slice',
                                       'begin_mask': int64_array([1, 1, 1]),
                                       'end_mask': int64_array([1, 0, 0]),
                                       'new_axis_mask': int64_array([0]),
                                       'shrink_axis_mask': int64_array([0]),
                                       'ellipsis_mask': int64_array([0])}).create_node()

    initial_priors_node.out_port(0).connect(priors_node.in_port(0))
    begin.out_port(0).connect(priors_node.in_port(1))
    end.out_port(0).connect(priors_node.in_port(2))
    stride.out_port(0).connect(priors_node.in_port(3))

    placeholders = graph.get_op_nodes(type='Parameter')
    assert len(placeholders) == 1, "{} replacer requires model to have one Placeholder, but current model has " \
                                   "{} placeholders".format(self.replacement_id, len(placeholders))
    placeholder = placeholders[0]

    # scale prior boxes to the [0, 1] interval
    node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
    priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

    broadcast = Broadcast(graph, {'name': 'scales_broadcast'}).create_node()
    shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
    priors_node.out_port(0).connect(shape_of_priors.in_port(0))
    broadcast.in_port(1).connect(shape_of_priors.out_port(0))
    broadcast.in_port(0).connect(node_with_scales_for_prior_boxes.out_port(0))

    priors_scale_node.in_port(0).connect(priors_node.out_port(0))
    priors_scale_node.in_port(1).connect(broadcast.out_port(0))

    try:
        variance = match.custom_replacement_desc.custom_attributes['variance']
    except KeyError:
        # Fix: was a bare `except:` — only a missing config key should map to this
        # error; anything else (e.g. KeyboardInterrupt) must propagate.
        raise Error('There is no variance attribute in {} replacement config file `custom_attributes`'
                    ''.format(self.replacement_id))

    priors = self.append_variances(priors_scale_node, variance)

    # calculate prior boxes widths and heights
    split_node = create_op_with_const_inputs(
        graph, VariadicSplit, {1: int64_array(2), 2: int64_array([1, 1, 1, 1])}, {'out_ports_count': 4},
        priors_scale_node)

    priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')
                            ).create_node([(split_node, 2), (split_node, 0)])
    priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')
                             ).create_node([(split_node, 3), (split_node, 1)])

    # concat weights and heights into a single tensor and multiple with the box coordinates regression values
    # WA with 3 Concats instead of 1 for keeping model reshapable
    # concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
    #                                           'in_ports_count': 4}).create_node(
    # [priors_width_node, priors_height_node, priors_width_node, priors_height_node])
    concat_1 = Concat(graph, {'name': 'concat_width_height',
                              'axis': -1, 'in_ports_count': 2}).create_node([priors_width_node, priors_height_node])
    concat_2 = Concat(graph, {'name': 'concat_width_height_width',
                              'axis': -1, 'in_ports_count': 2}).create_node([concat_1, priors_width_node])
    concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1, 'in_ports_count': 2}
                                      ).create_node([concat_2, priors_height_node])

    applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
        [concat_width_height_node, match.single_input_node(0)[0]])

    # reshape to 2D tensor as Inference Engine Detection Output layer expects
    reshape_regression_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                               dict(name='reshape_regression'),
                                                               applied_width_height_regressions_node)

    detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
    # get nms from the original network
    iou_threshold = None
    nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
    if len(nms_nodes) > 0:
        # it is highly unlikely that for different classes NMS has different
        # moreover DetectionOutput accepts only scalar values for iou_threshold (nms_threshold)
        iou_threshold = nms_nodes[0].in_node(3).value
    if iou_threshold is None:
        raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id))

    detection_output_node = detection_output_op.create_node(
        [reshape_regression_node, reshape_classes_node, priors],
        dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
             variance_encoded_in_target=0, background_label_id=1000))

    # As outputs are replaced with a postprocessing node, outgoing tensor names are no longer
    # correspond to original tensors and should be removed from output->Result edges
    out_nodes = []
    for out in range(match.outputs_count()):
        out_nodes.append(match.output_node(out)[0])
    clear_tensor_names_info(out_nodes)

    return {'detection_output_node': detection_output_node}
def replace_timeheightconv(self, graph: Graph, node: Node):
    """
    Replace a Kaldi TimeHeightConvolution node with MemoryOffset nodes (to gather
    temporal context), a Concat (to stack the contexts), and a regular Convolution.

    The time axis of the original component is emulated through MemoryOffset ops:
    one per distinct time offset, concatenated along the feature axis. The 2D
    convolution kernel is then derived from the (time, height) offset pairs.

    :param graph: graph to modify in place
    :param node: the TimeHeightConvolution node being replaced
    """
    req_time_offsets = node.soft_get('time_offsets')
    # NOTE(review): offsets is indexed as offsets[:, 0]/offsets[:, 1], so it is
    # presumably an Nx2 numpy array of (time, height) pairs — the [[ ]] default
    # would not support that indexing; verify upstream always sets 'offsets'.
    offsets = node.soft_get("offsets", [[]])
    all_time_offsets = list(set(offsets[:, 0]))
    all_time_offsets.sort()
    in_name = node.soft_get('name', node.id)
    # free up the original name so the new Convolution can take it
    rename_node(node, in_name + '/to_delete')

    # create memoryoffsets for context gathering
    # we need concat if time offsets more than 1
    concat = Concat(graph, attrs={'name': in_name + '/Concat',
                                  'in_ports_count': len(all_time_offsets)}).create_node()
    i = 0
    for t in all_time_offsets:
        # if time offset included in required_time_offsets we don't need default value
        has_default = t not in req_time_offsets
        memoff = MemoryOffset(graph, attrs={'name': in_name + '/MemoryOffset_' + str(i),
                                            't': t, 'has_default': has_default, 'splitted': False,
                                            'pair_name': in_name + '/MemoryOffset_pair_' + str(i)}).create_node()
        concat.in_port(i).connect(memoff.out_port(0))
        memoff.in_port(0).connect(node.in_port(0).get_source())
        i = i + 1

    stride = node.soft_get("height_subsample", 1)

    # kernel size along each axis = number of distinct offsets on that axis
    kernel = int64_array([0, 0])
    kernel[0] = len(set(offsets[:, 0]))
    kernel[1] = len(set(offsets[:, 1]))

    # height padding: left pad covers negative offsets, right pad makes the
    # strided output height match node.height_out
    pad_h = int64_array([0, 0])
    pad_h[0] = -min(offsets[:, 1]) if min(offsets[:, 1]) < 0 else 0
    pad_h[1] = stride * node.height_out - (node.height_in - max([max(offsets[:, 1]), 0]))

    # dilation = offset span divided by (kernel - 1); guard against a
    # single-tap kernel on each axis independently
    dilation_t = (max(offsets[:, 0]) - min(offsets[:, 0])) / (kernel[0] - 1) if kernel[0] > 1 else 1
    # FIX: guard must test kernel[1] (height axis), not kernel[0]; the old
    # condition divided by zero (0/0 -> nan) when kernel[1] == 1 but kernel[0] > 1
    dilation_h = (max(offsets[:, 1]) - min(offsets[:, 1])) / (kernel[1] - 1) if kernel[1] > 1 else 1

    conv_attrs = {
        'name': in_name,
        'output': node['out_channels'],
        'height_in': node.height_in,
        'bias_term': None,
        'pad': int64_array([[0, 0], [0, 0], [0, 0], pad_h]),
        'pad_spatial_shape': int64_array([[0, 0], pad_h]),
        'dilation': int64_array([1, 1, dilation_t, dilation_h]),
        'kernel': int64_array([node.out_channels, node.in_channels, kernel[0], kernel[1]]),
        'stride': int64_array([1, 1, 1, stride]),
        'kernel_spatial': kernel,
        'input_feature_channel': 1,
        'output_feature_channel': 0,
        'channel_dims': int64_array([1]),
        'spatial_dims': int64_array([2, 3]),
        'batch_dims': int64_array([0]),
        'kernel_spatial_idx': int64_array([2, 3]),
        'group': 1,
        'reshape_kernel': True,
        'bias_addable': True,
    }
    conv = Convolution(graph, attrs=conv_attrs).create_node()
    conv.in_port(0).connect(concat.out_port(0))
    conv.in_port(1).connect(node.in_port(1).get_source())

    # change layout for weights from OHWI to OIHW
    # in future should be replaced by common Permute mechanics
    weights = conv.in_port(1).get_source().node.value
    weights = weights.reshape(int64_array([node.out_channels, -1, node.in_channels]))
    weights = weights.transpose(int64_array([0, 2, 1]))
    weights = weights.flatten()
    conv.in_port(1).get_source().node.value = weights

    conv.in_port(2).connect(node.in_port(2).get_source())
    node.out_port(0).get_connection().set_source(conv.out_port(0))
    graph.remove_node(node.id)
def transform_graph(self, graph: Graph, replacement_descriptions: dict):
    """
    Rewire an EfficientDet-style graph for Inference Engine:

    1. Strip the model's built-in image pre-processing (padding/resizing) by
       reconnecting the Parameter directly past the nodes named in the JSON
       replacement config.
    2. Generate PriorBoxClustered anchors per detection level and concatenate them.
    3. Replace the original post-processing with a DetectionOutput node fed by
       flattened box deltas, sigmoid class logits, and the anchors.

    :param graph: graph to modify in place
    :param replacement_descriptions: attributes loaded from the JSON config
        (preprocessing node ids, anchor parameters, NMS thresholds, etc.)
    """
    parameter_node = graph.get_op_nodes(op='Parameter')[0]
    parameter_node['data_type'] = data_type_str_to_np(parameter_node.graph.graph['cmd_params'].data_type)

    # remove existing Result operations to remove unsupported sub-graph
    graph.remove_nodes_from([node.id for node in graph.get_op_nodes(op='Result')] + ['detections'])

    # determine if the op which is a input/final result of mean value and scale applying to the input tensor
    # then connect it to the input of the first convolution of the model, so we remove the image pre-processing
    # which includes padding and resizing from the model
    preprocessing_input_node_id = replacement_descriptions['preprocessing_input_node']
    assert preprocessing_input_node_id in graph.nodes, 'The node with name "{}" is not found in the graph. This ' \
                                                       'should be a last node before image normalization and is specified' \
                                                       ' in the json file.'.format(preprocessing_input_node_id)
    preprocessing_input_node = Node(graph, preprocessing_input_node_id)
    # bypass everything before normalization: feed its consumer straight from the Parameter
    consumer_node = preprocessing_input_node.out_port(0).get_connection().get_destination().node
    consumer_node.in_port(0).get_connection().set_source(parameter_node.out_port(0))

    preprocessing_output_node_id = replacement_descriptions['preprocessing_output_node']
    assert preprocessing_output_node_id in graph.nodes, 'The node with name "{}" is not found in the graph. This ' \
                                                        'node should provide scaled image output and is specified' \
                                                        ' in the json file.'.format(preprocessing_output_node_id)
    preprocessing_output_node = Node(graph, preprocessing_output_node_id)
    preprocessing_output_node.out_port(0).disconnect()

    # feed the (normalized) image directly into the first Convolution of the backbone
    convolution_nodes = [n for n in graph.pseudo_topological_sort() if n.soft_get('type') == 'Convolution']
    convolution_nodes[0].in_port(0).get_connection().set_source(preprocessing_output_node.out_port(0))

    # create prior boxes (anchors) generator
    aspect_ratios = replacement_descriptions['aspect_ratios']
    assert len(aspect_ratios) % 2 == 0
    aspect_ratios = list(zip(aspect_ratios[::2], aspect_ratios[1::2]))
    priors_generator = self.AnchorGenerator(min_level=int(replacement_descriptions['min_level']),
                                            aspect_ratios=aspect_ratios,
                                            num_scales=int(replacement_descriptions['num_scales']),
                                            anchor_scale=replacement_descriptions['anchor_scale'])

    prior_boxes = []
    # probe the box-predict heads by their hard-coded TF names; stop at the first
    # missing level (100 is just a safe upper bound on the number of levels)
    for i in range(100):
        inp_name = 'box_net/box-predict{}/BiasAdd'.format('_%d' % i if i else '')
        if inp_name not in graph:
            break
        widths, heights = priors_generator.get(i)
        prior_box_op = PriorBoxClusteredOp(graph, {'width': mo_array(widths),
                                                   'height': mo_array(heights),
                                                   'clip': 0, 'flip': 0,
                                                   'variance': replacement_descriptions['variance'],
                                                   'offset': 0.5})
        prior_boxes.append(prior_box_op.create_node([Node(graph, inp_name), parameter_node]))

    # concatenate prior box operations
    concat_prior_boxes = Concat(graph, {'axis': -1}).create_node()
    for idx, node in enumerate(prior_boxes):
        concat_prior_boxes.add_input_port(idx)
        concat_prior_boxes.in_port(idx).connect(node.out_port(0))

    # class confidences: sigmoid over the 'concat' node, flattened to [N, -1]
    conf = Sigmoid(graph, dict(name='concat/sigmoid')).create_node([Node(graph, 'concat')])
    reshape_size_node = Const(graph, {'value': int64_array([0, -1])}).create_node([])
    logits = Reshape(graph, dict(name=conf.name + '/Flatten')).create_node([conf, reshape_size_node])
    # box regression deltas come from the hard-coded 'concat_1' node
    deltas = Reshape(graph, dict(name='concat_1/Flatten')).create_node([Node(graph, 'concat_1'), reshape_size_node])

    # revert convolution boxes prediction weights from yxYX to xyXY (convolutions share weights and bias)
    weights = Node(graph, 'box_net/box-predict/pointwise_kernel')
    weights.value = weights.value.reshape(-1, 4)[:, [1, 0, 3, 2]].reshape(weights.shape)
    bias = Node(graph, 'box_net/box-predict/bias')
    bias.value = bias.value.reshape(-1, 4)[:, [1, 0, 3, 2]].reshape(bias.shape)

    detection_output_node = DetectionOutput(graph, dict(
        name='detections',
        share_location=1,
        # classes beyond num_classes are treated as background
        background_label_id=int(replacement_descriptions['num_classes']) + 1,
        nms_threshold=replacement_descriptions['nms_threshold'],
        confidence_threshold=replacement_descriptions['confidence_threshold'],
        top_k=100,
        keep_top_k=100,
        code_type='caffe.PriorBoxParameter.CENTER_SIZE',
    )).create_node([deltas, logits, concat_prior_boxes])

    output_op = Result(graph, dict(name='output'))
    output_op.create_node([detection_output_node])
def find_and_replace_pattern(self, graph: Graph):
    """
    Replace every ATen op with operator 'embedding_bag' (sum mode only) with
    EmbeddingBagPackedSum (no offsets input) or EmbeddingBagOffsetsSum.

    When per-sample weights are present in the offsets variant, the embedding
    table is extended with one zero row so that the default index can point at
    it (the IE op requires a valid default-index row in that configuration).

    :param graph: graph to modify in place
    :raises AssertionError: if an embedding_bag node uses a mode other than sum (0)
    """
    for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
        assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
                                           'mode is supported for node {}.'.format(node.id)
        node_name = node.soft_get('name', node.id)
        # temporarily rename the old node so the replacement can take its name
        rename_node(node, node_name + '/TBR')
        is_packed = False
        # no offsets input connected -> "packed" layout (indices are 2D)
        if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
            is_packed = True
            embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
        else:
            embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
            node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
        rename_node(embedding_bag, node_name)
        node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
        node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
        node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
        # optional 4th input: per-sample weights
        if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
            if is_packed:
                # packed variant takes per-sample weights as its 3rd input
                node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
            else:
                # connect per_sample_weights
                node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))

                # build shape sub-graph to compute the embedding row width dynamically
                weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node()
                weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node()
                last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1)
                weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node)

                weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0])

                # zero row of the same width as an embedding vector
                zero_col_node = create_op_with_const_inputs(graph, Broadcast, {0: int64_array([0])},
                                                            {'name': node_name + '/Broadcast'})
                zero_col_node.in_port(1).connect(weights_last_dim.out_port(0))

                default_embeddings_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
                                                                      {'name': node_name + '/Unsqueeze'})
                default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0))

                # expand embedding table with zeros
                weights_concat = Concat(graph, {'axis': 0, 'in_ports_count': 2,
                                                'name': node_name + '/Concat'}).create_node()
                embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
                # shape/rank sub-graph reads the ORIGINAL (unexpanded) table
                weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0))
                weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0))
                weights_concat.in_port(1).connect(default_embeddings_node.out_port(0))
                weights_concat.out_port(0).connect(embedding_bag.in_port(0))

                # point default index to expanded part of embedding table
                weights_first_dim.out_port(0).connect(embedding_bag.in_port(3))
def extract(cls, node):
    """Set up Concat attributes from an ONNX node; 'axis' defaults to 0 when absent."""
    axis = onnx_attr(node, 'axis', 'i', default=0)
    Concat.update_node_stat(node, {'axis': axis})
    return cls.enabled
def extract(cls, node):
    """Set up Concat attributes for this framework op; concatenation axis is fixed to 1."""
    Concat.update_node_stat(node, {'axis': 1})
    return cls.enabled
def insert_select(graph: Graph, node: Node):
    """
    Insert a Select in front of a state-saving node so that garbage produced
    before the full temporal context is gathered is replaced by zeros.

    A small iteration-counter sub-graph (ReadValue -> Crop -> Concat(ones) ->
    Assign) tracks how many frames have been seen; once the counter's first
    element equals 1 the real data is passed through, otherwise the zero
    constant is saved instead. An already-existing counter with matching
    context length is reused if the pattern matcher finds one.

    :param graph: graph to modify in place
    :param node: the node whose input should be guarded by the Select
    """
    context_len = node.frame_time + 1

    # context of 1 means the state is valid from the very first frame - nothing to guard
    if context_len == 1:
        return

    in_node_port = node.in_port(0).get_source()
    in_node_shape = node.in_port(0).data.get_shape()
    node.in_port(0).disconnect()

    # add Select before saving state to avoid saving garbage
    select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
    zero_else = create_const_with_batch_from_input(in_node_port, in_node_shape[1])
    select_node.in_port(1).connect(in_node_port)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # check if we have already appropriate iteration counter
    existing_counters = find_pattern_matches(graph,
                                             nodes=[('mem_in', dict(op='ReadValue')),
                                                    ('mem_in_data', dict(shape=int64_array([context_len]))),
                                                    ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                                                         offset=int64_array([1]),
                                                                         dim=int64_array([context_len - 1]))),
                                                    ('crop_mem_in_data', dict()),
                                                    ('concat', dict(op='Concat', axis=1)),
                                                    ('concat_data', dict()),
                                                    ('const_1', dict(op='Const')),
                                                    ('const_1_data', dict()),
                                                    ('mem_out', dict(op='Assign')),
                                                    ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                                                      offset=int64_array([0]),
                                                                      dim=int64_array([1]))),
                                                    ('crop_out_data', dict()),
                                                    ('select', dict(op='Select'))],
                                             edges=[('mem_in', 'mem_in_data'),
                                                    ('mem_in_data', 'crop_mem_in'),
                                                    ('crop_mem_in', 'crop_mem_in_data'),
                                                    ('crop_mem_in_data', 'concat', {'in': 0}),
                                                    ('const_1', 'const_1_data'),
                                                    ('const_1_data', 'concat', {'in': 1}),
                                                    ('concat', 'concat_data'),
                                                    ('concat_data', 'mem_out'),
                                                    ('concat_data', 'crop_out'),
                                                    ('crop_out', 'crop_out_data'),
                                                    ('crop_out_data', 'select')])
    counter_match = next(existing_counters, None)
    if counter_match is not None:
        # reuse the matched counter's 'ones' const and its Crop output
        ones = Node(graph, inverse_dict(counter_match)['const_1'])
        input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
    else:
        # build a new iteration counter: memory holds a vector of length context_len;
        # each step drops its first element and appends 1, so after context_len steps
        # the first element becomes 1 and stays 1
        init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
        mem_out = ReadValue(graph, {'name': 'iteration_number',
                                    'variable_id': 'iteration_' + node.name}).create_node()
        mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
        cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                 'offset': int64_array([1]),
                                 'dim': int64_array([context_len - 1])}).create_node()
        cut_first.in_port(0).connect(mem_out.out_port(0))
        ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
        concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2,
                                'axis': 1}).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))
        mem_in = Assign(graph, {'name': 'iteration_number_out',
                                'variable_id': 'iteration_' + node.name}).create_node()
        mem_in.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        mem_in.out_port(0).connect(res.in_port(0))
        # the counter's first element is what the Select condition reads
        cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                'offset': int64_array([0]),
                                'dim': int64_array([1])}).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        input_port = cut_last.out_port(0)

    # Check if data from memory is 1
    # if it is True, we have correct data and should proceed with saving it to memory
    # else we have not gathered context and have garbage here, shouldn't change initial state of memory
    cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
    cast_in.in_port(0).connect(ones.out_port(0))
    cast_in.in_port(1).connect(input_port)
    select_node.in_port(0).connect(cast_in.out_port(0))
    select_node.out_port(0).connect(node.in_port(0))
    select_node.out_port(0).data.set_shape(in_node_shape)
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """
    Rebuild the prior-box branch feeding DetectionOutput's third input.

    The matched 'concat' output is reshaped to [1, 2, -1] and split into the
    box part and the variance part. The box part is re-encoded per coordinate
    via calculate_prior_box_value (xmin/ymin/xmax/ymax), re-concatenated, and
    joined back with the untouched variance part before being connected to
    DetectionOutput input port 2.

    :param graph: graph to modify in place
    :param match: matched sub-graph with 'concat' and 'detection_output' nodes
    """
    concat_node = match['concat']
    concat_node['axis'] = 1
    concat_name = concat_node.soft_get('name', concat_node.id)

    # [priors | variances] -> [1, 2, -1] so the two halves can be split apart
    concat_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([1, 2, -1]),
                                                      op_attrs=dict(name=concat_name + '/Reshape'))
    split_node = create_op_node_with_second_input(graph, Split, int64_array(1),
                                                  op_attrs=dict(name=concat_name + '/Split', num_splits=2),
                                                  input_node=concat_reshape)
    # box half -> [-1, 4] rows of (x, y, w, h)-like values, then split per coordinate
    split_node_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 4]),
                                                          op_attrs=dict(name=split_node.name + '/Reshape'))
    split_node.out_port(0).connect(split_node_reshape.in_port(0))
    value = create_op_node_with_second_input(graph, Split, int64_array(1),
                                             op_attrs=dict(name=split_node_reshape.name + '/Split', num_splits=4),
                                             input_node=split_node_reshape)

    # recompute corner coordinates from the per-coordinate splits
    xmin, xmax = calculate_prior_box_value(value, value_to_div=value.out_port(2),
                                           value_to_add=value.out_port(0))
    ymin, ymax = calculate_prior_box_value(value, value_to_div=value.out_port(3),
                                           value_to_add=value.out_port(1))

    concat_slice_value = Concat(graph, dict(name=value.name + '/Concat', in_ports_count=4,
                                            axis=1)).create_node()
    for ind, node in enumerate([xmin, ymin, xmax, ymax]):
        concat_slice_value.in_port(ind).connect(node.out_port(0))

    # back to [1, 1, -1] and rejoin with the untouched variance half (split port 1)
    reshape_concat_values = create_op_node_with_second_input(graph, Reshape, int64_array([1, 1, -1]),
                                                             op_attrs=dict(name=concat_slice_value.name + '/Reshape'),
                                                             input_node=concat_slice_value)
    concat = Concat(graph, dict(name=reshape_concat_values.name + '/Concat', in_ports_count=2,
                                axis=1)).create_node()
    concat.in_port(0).connect(reshape_concat_values.out_port(0))
    concat.in_port(1).connect(split_node.out_port(1))

    match['detection_output'].in_port(2).get_connection().set_source(concat.out_port(0))
    # finally plug the original concat output into the head of the new sub-graph
    concat_node.out_port(0).get_connection().set_destination(concat_reshape.in_port(0))
def replace_pattern(graph: Graph, match: dict):
    """
    Decompose a Splice node into an explicit sliding-window memory:
    ReadValue -> Crop (drop oldest frame) -> Concat (append newest frame) -> Assign.

    When the node has a const_dim part (features that must not be spliced over
    time), the input is first VariadicSplit into the spliced part and the const
    part; the const part gets its own, parallel memory chain and only its most
    recent frame is concatenated back onto the spliced output.

    :param graph: graph to modify in place
    :param match: matched sub-graph; match['op'] is the Splice node
    """
    node = match['op']
    in_shape = node.in_port(0).data.get_shape().copy()
    # memory holds len(context) frames of the spliced feature part
    memory_element = in_shape[1] - node.const_dim
    memory_size = memory_element * len(node.context)

    memory_pair_id = unique_id('id')
    # Memory(in)
    input_memory = ReadValue(graph, {'name': 'prev_splice_memory',
                                     'variable_id': memory_pair_id}).create_node()

    # Memory(in)  \
    #              Crop
    # Input(temp) /
    # drop the oldest frame, keeping memory_size - memory_element values
    crop = Crop(graph, {'name': 'Splice_Crop', 'axis': int64_array([1]),
                        'offset': int64_array([memory_element]),
                        'dim': int64_array([memory_size - memory_element])}).create_node()
    crop.in_port(0).connect(input_memory.out_port(0))

    # Crop  \
    #        Concat
    # Input /
    concat_node = Concat(graph, {'name': 'Splice_Concat', 'in_ports_count': 2,
                                 'axis': 1}).create_node()
    concat_node.in_port(0).connect(crop.out_port(0))

    # Concat -> Memory(out)
    mem_out = Assign(graph, {'name': 'out_splice_memory',
                             'variable_id': memory_pair_id}).create_node()
    mem_out.in_port(0).connect(concat_node.out_port(0))
    # Assign must be terminated by a Result to stay alive in the graph
    Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))

    if node.const_dim != 0:
        memory_element_constdim = node.const_dim
        memory_size_constdim = memory_element_constdim * len(node.context)

        # split the input into the spliced part (port 0) and const part (port 1)
        split = create_op_with_const_inputs(
            graph, VariadicSplit,
            {1: int64_array(1), 2: int64_array([memory_element, memory_element_constdim])},
            {'name': node.id + '_split_const', 'out_ports_count': 2})

        split.out_port(0).connect(concat_node.in_port(1))

        # create separate splice construction for const_dim
        memory_pair_id = unique_id('memory_for_const_dim')
        init_value_input_memory_const_dim = Const(
            graph, {'name': 'init_value_const_dim_in_memory',
                    'value': np.zeros(int64_array([in_shape[0], memory_size_constdim]), dtype=np.float32),
                    'shape': int64_array([in_shape[0], memory_size_constdim])}).create_node()
        input_memory_const_dim = ReadValue(graph, {'name': 'const_dim_in_memory',
                                                   'variable_id': memory_pair_id}).create_node()
        init_value_input_memory_const_dim.out_port(0).connect(input_memory_const_dim.in_port(0))

        crop_const_dim = Crop(graph, {'name': 'const_dim_crop', 'axis': int64_array([1]),
                                      'offset': int64_array([memory_element_constdim]),
                                      'dim': int64_array([memory_size_constdim - memory_element_constdim])}).create_node()
        crop_const_dim.in_port(0).connect(input_memory_const_dim.out_port(0))

        concat_node_const_dim = Concat(graph, {'name': 'const_dim_concat', 'in_ports_count': 2,
                                               'axis': 1}).create_node()
        concat_node_const_dim.in_port(0).connect(crop_const_dim.out_port(0))

        mem_out_const_dim = Assign(graph, {'name': 'const_dim_out_memory',
                                           'variable_id': memory_pair_id}).create_node()
        mem_out_const_dim.in_port(0).connect(concat_node_const_dim.out_port(0))
        Result(graph).create_node().in_port(0).connect(mem_out_const_dim.out_port(0))

        # connect splice to Split as begin and Concat as the end
        split.out_port(1).connect(concat_node_const_dim.in_port(1))
        # only the most recent const_dim frame is appended to the spliced output
        crop_first = Crop(graph, {'name': 'const_dim_crop_first', 'axis': int64_array([1]),
                                  'offset': int64_array([0]),
                                  'dim': int64_array([memory_element_constdim])}).create_node()
        crop_first.in_port(0).connect(concat_node_const_dim.out_port(0))

        concat_const = Concat(graph, {'name': node.id + '_concat_const', 'axis': 1,
                                      'in_ports_count': 2}).create_node()
        concat_const.in_port(1).connect(crop_first.out_port(0))
        concat_const.in_port(0).connect(concat_node.out_port(0))

        init_value_input_memory = Const(
            graph, {'name': 'init_value_' + node.name,
                    'value': np.zeros(int64_array([in_shape[0], memory_size]), dtype=np.float32),
                    'shape': int64_array([in_shape[0], memory_size])}).create_node()
        init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
        node.in_port(0).get_connection().set_destination(split.in_port(0))
        node.out_port(0).get_connection().set_source(concat_const.out_port(0))
    else:
        init_value_input_memory = Const(
            graph, {'name': 'init_value_' + node.name,
                    'value': np.zeros(int64_array([in_shape[0], memory_size]), dtype=np.float32),
                    'shape': int64_array([in_shape[0], memory_size])}).create_node()
        init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
        node.in_port(0).get_connection().set_destination(concat_node.in_port(1))
        node.out_port(0).get_connection().set_source(concat_node.out_port(0))

    # to avoid re-inference of shape and touching in next replacements
    graph.remove_node(node.id)