def add_fake_background_loc(graph: Graph, input_node: Node):
    """
    DetectionOutput layer expects that the box coordinates tensor also contains coordinates of boxes for the
    "background" class, but in the TensorFlow* Object Detection API the tensor contains information about real object
    classes only. The function copies a slice of the output data of the node 'input_node' and then concats it to the
    beginning of the data. The data in this slice is not used by the DetectionOutput layer, so the actual values are
    not important. This approach keeps the model reshape-able and does not introduce many layers.
    :param graph: graph to operate on.
    :param input_node: node producing the boxes coordinates.
    :return: Concat node that prepends a slice of data for the "background" class.
    """
    crop_op = Crop(graph, dict(axis=np.array([1]), offset=np.array([0]), dim=np.array([1]), nchw_layout=True))
    crop_node = crop_op.create_node([input_node], dict(name='crop_locs'))

    concat_op = Concat(graph, dict(axis=1, in_ports_count=2, nchw_layout=True))
    return concat_op.create_node([crop_node, input_node], dict(name=input_node.id + '/locs_with_fake_background'))
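# Usage sketch for the helper above (hedged: 'boxes_regression' is an
# illustrative node name, not taken from a real model). The helper is meant to
# be applied to the locations tensor right before it is wired into the
# DetectionOutput layer.
locs_node = Node(graph, 'boxes_regression')
locs_with_bg = add_fake_background_loc(graph, locs_node)
# axis 1 of the result carries one extra leading slice; DetectionOutput treats
# it as the "background" class coordinates and ignores its actual values.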
def test_concat_op(self):
    graph = build_graph(self.nodes_attributes,
                        [('node_1', 'concat_node'),
                         ('concat_node', 'node_3')])
    concat_node = Concat(graph, self.nodes_attributes['concat_node']).add_node()
    self.assertEqual(concat_node.type, 'Concat')
    self.assertEqual(concat_node.op, 'Concat')
    self.assertEqual(concat_node.infer, concat_infer)
def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    input_node = node.in_nodes()[0]
    memory_pair_id = unique_id('id')
    # Memory(in)
    input_memory = Memory(graph, {'name': 'prev_splice_memory',
                                  'id': memory_pair_id,
                                  'index': 1,
                                  'size': 2,
                                  'shape': np.array([input_node.shape[1] * len(node.context)],
                                                    dtype=np.int64)}).create_node()
    # Memory(in)  \
    #             Crop
    # Input(temp) /
    crop = Crop(graph, {'name': 'Splice_Crop',
                        'axis': np.array([1], dtype=np.int64),
                        'offset': np.array([input_node.shape[1]], dtype=np.int64),
                        'dim': np.array([input_node.shape[1] * (len(node.context) - 1)],
                                        dtype=np.int64)}).create_node([input_memory])

    # Crop  \
    #       Concat
    # Input /
    concat_node = Concat(graph, {'name': 'Splice_Concat', 'axis': 1}).create_node([crop, input_node])

    # Concat -> Memory(out)
    Memory(graph, {'name': 'out_splice_memory',
                   'id': memory_pair_id,
                   'index': 0,
                   'size': 2,
                   'shape': np.array([input_node.shape[1] * len(node.context)],
                                     dtype=np.int64)}).create_node([concat_node])
    return [concat_node.id]
def replace_op(self, graph: Graph, node: Node):
    if node.has_and_set('inputs_preprocessed'):
        log.debug('Node "{}" has already been preprocessed'.format(node.soft_get('name')))
        return []
    # reshape the tensor with batch indices to 2D
    unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([1]),
                                                      {'name': node.name + '/Unsqueeze'}, node.in_node(2))

    convert_node = Cast(graph, {'name': unsqueeze_node.name + '/ToFloat',
                                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    convert_node.in_port(0).connect(unsqueeze_node.out_port(0))

    concat_op = Concat(graph, {'axis': 1, 'name': node.name + '/concat_batch_indices_and_boxes',
                               'in_ports_count': 2})
    concat_node = concat_op.create_node([convert_node, node.in_node(1)])

    # do not remove the edge with crop_size because it is needed in the partial infer
    graph.remove_edge(node.in_node(1).id, node.id)

    # the input to the CropAndResize contains boxes coordinates in the YXYX layout, but the IE layer ROIPooling
    # expects coordinates in the XYXY layout, so a convolution is added here to swap coordinates
    swapped_box_coordinates_node = add_convolution_to_swap_xy_coordinates(graph, concat_node, 5)

    # reshape the locations tensor to 2D so it can be passed to Eltwise which will be converted to ScaleShift
    reshape_2d_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 5]),
                                                       dict(name=swapped_box_coordinates_node.id + '/reshape_2d_'),
                                                       swapped_box_coordinates_node)
    graph.create_edge(reshape_2d_node, node, 0, 1)

    # do not replace any output edge
    return []
def extend_inputs(node: Node, num_insertions: int):
    graph = node.graph
    node_name = node.soft_get('name', node.id)

    for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
        if i == 3 and not node.is_in_port_connected(3):
            continue  # no need to extend strides if they are not connected

        blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
        blank_values_node = Const(graph, {'name': node_name + '/extend_{}_const'.format(input_name),
                                          'value': int64_array(blank_values_arr)}).create_node()

        if node.in_port(i).get_source().node.soft_get('type') == 'Concat':
            # the Concat already exists
            concat = node.in_port(i).get_source().node
            last_in_port = max(concat.in_ports().keys())
            assert not concat.in_port(last_in_port).disconnected(), \
                'The last in_port of Concat node {} should be connected'.format(concat.soft_get('name', node.id))

            concat.add_input_port(last_in_port + 1)
            concat.in_port(last_in_port + 1).connect(blank_values_node.out_port(0))
        else:
            # have to create a Concat
            concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
                                    'in_ports_count': 2}).create_node()
            node.in_port(i).get_connection().set_destination(concat.in_port(0))
            concat.in_port(1).connect(blank_values_node.out_port(0))
            concat.out_port(0).get_connection().set_destination(node.in_port(i))
def replace_with_split_concat(node):
    graph = node.graph
    name = node.soft_get('name', node.id)
    axis = node.axis
    order = node.order

    split = create_op_with_const_inputs(graph, Split, {1: int64_array(axis)},
                                        {'name': name + '/Split', 'num_splits': order.size})
    concat = Concat(graph, {'name': name + '/Concat', 'axis': axis, 'in_ports_count': order.size}).create_node()

    for out_port_idx, in_port_idx in enumerate(order):
        split.out_port(out_port_idx).connect(concat.in_port(in_port_idx))

    node.out_port(0).get_connection().set_source(concat.out_port(0))
    node.in_port(0).get_connection().set_destination(split.in_port(0))

    graph.remove_node(node.id)
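# Illustrative NumPy check of the port wiring above (plain math, not MO code):
# Split output i feeds Concat input order[i], so the result is the input with
# its slices along 'axis' permuted by argsort(order).
import numpy as np

x = np.arange(12).reshape(3, 4)
order = np.array([2, 0, 1])
slices = np.split(x, order.size, axis=0)
permuted = [None] * order.size
for out_port_idx, in_port_idx in enumerate(order):
    permuted[in_port_idx] = slices[out_port_idx]
assert (np.concatenate(permuted, axis=0) == x[np.argsort(order)]).all()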
def replace_op(self, graph: Graph, node: Node):
    out_node = Concat(graph, {'axis': node.axis,
                              'in_ports_count': len(node.in_ports()),
                              'name': node.name + '/Concat_'}).create_node()

    for ind in node.in_ports():
        expand_dims_node = ExpandDims(graph, {'expand_axis': int64_array([node.axis]),
                                              'name': node.name + '/ExpandDims_'}).create_node()
        node.in_port(ind).get_connection().set_destination(expand_dims_node.in_port(0))
        expand_dims_node.out_port(0).connect(out_node.in_port(ind))

    # Replace the edge from out port 0 of the matched node with an edge from node out_node.id with port 0.
    # The "explicit" version of the return value is: [(out_node.id, 0)]
    return [out_node.id]
def concat_output_states(graph: Graph, match: dict, new_states: list):
    """ Concatenates the output states of a multilayer RNN layer. """
    rnn_layer = match['rnn_layer']

    original_states = [rnn_layer.out_node(i) if i in rnn_layer.out_nodes() else None for i in [1, 2]]

    concat_ops = [
        Concat(rnn_layer.graph, {'name': rnn_layer.name + '/FinalLayerSplitConcat/HiddenState', 'axis': -1}),
        Concat(rnn_layer.graph, {'name': rnn_layer.name + '/FinalLayerSplitConcat/CellState', 'axis': -1})
    ]

    for i in range(len(original_states)):  # [0] or [0, 1]
        if original_states[i] is None:
            continue
        concat_ops[i].attrs.update({'in_ports_count': len(new_states[i])})
        concat_ops[i].create_node_with_data(inputs=new_states[i], data_nodes=[original_states[i]])
def replace_tdnn(self, graph: Graph, tdnn_node: Node):
    tdnn_name = tdnn_node.soft_get('name', tdnn_node.id)

    concat_node = Concat(graph, {'axis': 1}).create_node()
    rename_nodes([(tdnn_node, tdnn_name + '/to_be_removed'), (concat_node, tdnn_name)])

    for offset_ind, t in enumerate(tdnn_node['time_offsets']):
        concat_node.add_input_port(offset_ind)
        if t != 0:
            memory_name = tdnn_name + '/MemoryOffset/' + str(abs(t))
            memoryoffset_node = MemoryOffset(graph, {'name': memory_name, 't': t,
                                                     'pair_name': memory_name + '_out',
                                                     'has_default': False, 'splitted': False}).create_node()

            tdnn_node.in_port(0).get_source().connect(memoryoffset_node.in_port(0))
            memoryoffset_node.out_port(0).connect(concat_node.in_port(offset_ind))
        else:
            # a 0 time delay is not allowed in IE, it is meaningless;
            # if the time offset is 0, connect the input of the tdnncomponent directly to the Concat
            # without a MemoryOffset
            tdnn_node.in_port(0).get_source().connect(concat_node.in_port(offset_ind))

    weights = tdnn_node['weights']
    fc_inputs = {1: weights}

    bias_term = False
    if tdnn_node.has_valid('biases'):
        assert len(tdnn_node['biases']) == weights.shape[0]
        fc_inputs.update({2: tdnn_node['biases']})
        bias_term = True

    fc_node = create_op_with_const_inputs(graph, FullyConnected, fc_inputs,
                                          {'name': tdnn_name + '/FC', 'out-size': weights.shape[0],
                                           'transpose_weights': True, 'bias_term': bias_term})
    concat_node.out_port(0).connect(fc_node.in_port(0))
    tdnn_node.in_port(0).disconnect()
    tdnn_node.out_port(0).get_connection().set_source(fc_node.out_port(0))
def fuse_reduces(first_reduce, second_reduce):
    first_reduce_name = first_reduce.soft_get('name', first_reduce.id)
    second_reduce_name = second_reduce.soft_get('name', second_reduce.id)
    reduce_type = first_reduce.type

    assert first_reduce.type == second_reduce.type

    if len(first_reduce.out_port(0).get_destinations()) != 1:
        # data dependency
        return

    if first_reduce.keep_dims != second_reduce.keep_dims:
        return

    first_axes = first_reduce.in_port(1).data.get_value()
    second_axes = second_reduce.in_port(1).data.get_value()
    if first_axes is None or second_axes is None:
        # dynamic axes merging is not supported
        return

    if not first_reduce.keep_dims:
        if not np.all(first_axes > second_axes):
            # indexing of the upper reduce input dimensions changed
            return

    graph = second_reduce.graph

    new_axes = Concat(graph, {'name': second_reduce_name + '/Axes', 'axis': int64_array(0), 'in_ports_count': 2,
                              'override_output_shape': True}).create_node()
    new_axes.in_port(0).connect(first_reduce.in_port(1).get_source())
    new_axes.in_port(1).connect(second_reduce.in_port(1).get_source())

    first_reduce.in_port(0).get_source().node['need_shape_inference'] = True
    first_reduce.in_port(0).get_source().node['override_output_shape'] = True

    second_reduce.in_port(1).get_connection().set_source(new_axes.out_port(0))

    first_reduce.out_port(0).get_connection().set_source(first_reduce.in_port(0).get_connection().get_source())
    first_reduce.in_port(1).disconnect()
    graph.remove_node(first_reduce.id)

    log.debug('{0} nodes {1} and {2} were fused to a single {2} node with updated axes input'
              ''.format(reduce_type, first_reduce_name, second_reduce_name))
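# NumPy illustration of the keep_dims=False guard above (plain math, not MO
# code): when every axis removed by the first reduce is larger than every axis
# used by the second reduce, removing the first set of dimensions does not
# shift the indices the second reduce refers to, so the pair is equivalent to
# one reduce over the concatenated axes.
import numpy as np

x = np.random.rand(2, 3, 4)
sequential = x.sum(axis=2).sum(axis=0)  # first_axes=[2], second_axes=[0], 2 > 0
fused = x.sum(axis=(0, 2))              # single reduce with concatenated axes
assert np.allclose(sequential, fused)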
def replace_op(self, graph: Graph, node: Node):
    if node.has_and_set('inputs_preprocessed'):
        log.debug('Node "{}" has already been preprocessed'.format(node.soft_get('name')))
        return []
    # reshape the tensor with batch indices to 2D
    unsqueeze_op = Unsqueeze(graph, {'unsqueeze_dims': np.array([1], dtype=np.int64)})
    unsqueeze_node = unsqueeze_op.create_node([node.in_node(2)])

    concat_op = Concat(graph, {'axis': 1, 'name': node.name + '/concat_batch_indices_and_boxes',
                               'in_ports_count': 2})
    concat_node = concat_op.create_node([unsqueeze_node, node.in_node(1)])

    # do not remove the edge with crop_size because it is needed in the partial infer
    graph.remove_edge(node.in_node(1).id, node.id)

    # the input to the CropAndResize contains boxes coordinates in the YXYX layout, but the IE layer ROIPooling
    # expects coordinates in the XYXY layout, so a convolution is added here to swap coordinates
    swapped_box_coordinates_node = add_convolution_to_swap_xy_coordinates(graph, concat_node, 5)

    # reshape the locations tensor to 2D so it can be passed to Eltwise which will be converted to ScaleShift
    reshape_2d_op = Reshape(graph, dict(dim=np.array([-1, 5])))
    reshape_2d_node = reshape_2d_op.create_node([swapped_box_coordinates_node],
                                                dict(name=swapped_box_coordinates_node.id + '/reshape_2d_',
                                                     nchw_layout=True))
    graph.create_edge(reshape_2d_node, node, 0, 1)

    # do not replace any output edge
    return []
def placeholder_scales(self, placeholder: Node):
    """
    Helper function to get scales for prior boxes out of the input image size:
            [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]
    """
    graph = placeholder.graph
    name = placeholder.soft_get('name', placeholder.id)

    shape_value = placeholder.soft_get('shape', None)
    assert shape_value is not None, \
        "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
    assert isinstance(shape_value, np.ndarray), \
        "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
    assert shape_value.size == 4, \
        "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

    shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
    shape.in_port(0).connect(placeholder.out_port(0))

    begin = Const(graph, {'value': int64_array([1])}).create_node()
    end = Const(graph, {'value': int64_array([3])}).create_node()
    stride = Const(graph, {'value': int64_array([1])}).create_node()
    spatial = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': np.array([1]),
                                   'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
                                   'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()

    spatial.in_port(0).connect(shape.out_port(0))
    spatial.in_port(1).connect(begin.out_port(0))
    spatial.in_port(2).connect(end.out_port(0))
    spatial.in_port(3).connect(stride.out_port(0))

    power = Const(graph, {'value': float32_array([-1.])}).create_node()
    spatial_scale = Pow(graph, {}).create_node()

    spatial_scale.in_port(0).connect(spatial.out_port(0))
    spatial_scale.in_port(1).connect(power.out_port(0))

    # Power `type_infer` requires inputs to have an equal data type
    convert_to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
    spatial_scale.in_port(0).get_connection().insert_node(convert_to_fp32)

    order = Const(graph, {'value': int64_array([1, 0])}).create_node()
    axis_const = Const(graph, {'value': int64_array(0)}).create_node()
    reverse = Gather(graph, {}).create_node()

    reverse.in_port(0).connect(spatial_scale.out_port(0))
    reverse.in_port(1).connect(order.out_port(0))
    axis_const.out_port(0).connect(reverse.in_port(2))

    priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
    priors_scale_node.add_input_port(0, skip_if_exist=True)
    priors_scale_node.add_input_port(1, skip_if_exist=True)

    priors_scale_node.in_port(0).connect(reverse.out_port(0))
    priors_scale_node.in_port(1).connect(reverse.out_port(0))

    return priors_scale_node
def concat(self, bilstm, forward_outputs, reverse_outputs, final_outputs):
    """ Concatenates two sets of outputs from a BiLSTM """
    concat_ops = [
        Concat(bilstm.graph, {'name': bilstm.name + '/FinalConcat/Data', 'axis': 1}),
        Concat(bilstm.graph, {'name': bilstm.name + '/FinalConcat/HiddenState', 'axis': 0}),
        Concat(bilstm.graph, {'name': bilstm.name + '/FinalConcat/CellState', 'axis': 0})
    ]

    bilstm.graph.remove_node(bilstm.id)

    for i in final_outputs:
        concat_ops[i].create_node_with_data([forward_outputs[i], reverse_outputs[i]],
                                            data_nodes=[final_outputs[i]])
def replace_pattern(self, graph: Graph, match: dict):
    node = match['node']
    node_name = node.soft_get('name', node.id)

    connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    if len(connected_ports) == 2:
        axis = node.in_port(1).data.get_value()
    else:
        axis = node.axis

    assert axis is not None, 'The "axis" should be defined for node "{}"'.format(node_name)
    assert node.has_and_set('output_type'), 'The data type is not set for node "{}"'.format(node_name)

    topk_mode = 'max' if node.op == 'ArgMax' else 'min'
    topk_node = TopK(graph, {'axis': axis, 'mode': topk_mode, 'sort': 'index',
                             'remove_values_output': node.has_and_set('remove_values_output'),
                             'index_element_type': node.output_type}).create_node()
    node.in_port(0).get_connection().set_destination(topk_node.in_port(0))

    if node.has_and_set('out_max_val'):  # in this mode the ArgMax produces tuples (max_ind, max_value)
        concat_node = Concat(graph, {'axis': 1, 'name': node.name + '/Concat'}).create_node()
        concat_node.add_input_port(0, skip_if_exist=True)
        concat_node.add_input_port(1, skip_if_exist=True)
        topk_node.out_port(0).connect(concat_node.in_port(1))  # values
        topk_node.out_port(1).connect(concat_node.in_port(0))  # indices
        if not node.out_port(0).disconnected():
            node.out_port(0).get_connection().set_source(concat_node.out_port(0))
    else:
        if not node.out_port(0).disconnected():
            node.out_port(0).get_connection().set_source(topk_node.out_port(1))

    topk_node.in_port(1).connect(Const(graph, {'name': node.soft_get('name') + '/TopK',
                                               'value': node.top_k}).create_node().out_port(0))

    graph.remove_nodes_from([node.id, node.out_node(0).id])
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision=float):
    # create an initialization sub-graph connected to ReadValue
    graph = input_out_port.node.graph
    input_name = input_out_port.node.name

    shape_of_input = Shape(graph, {'name': 'shape/' + input_name}).create_node()
    shape_of_input.in_port(0).connect(input_out_port)

    dim_for_get_batch = Const(graph, {'name': 'dim/crop_batch/' + shape_of_input.name,
                                      'value': int64_array([1]), 'shape': int64_array([1])}).create_node()
    get_batch = Crop(graph, {'name': 'crop_batch/' + shape_of_input.name,
                             'axis': int64_array([0]), 'offset': int64_array([0])}).create_node()
    get_batch.in_port(0).connect(shape_of_input.out_port(0))
    get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))

    mem_shape_2nd_dim = Const(graph, {'name': 'gifo_r_weights_shape/' + input_name,
                                      'value': int64_array([second_dim]), 'shape': int64_array([1])}).create_node()
    mem_shape = Concat(graph, {'name': 'gather_memory_shape/' + input_name,
                               'axis': 0, 'in_ports_count': 2}).create_node()
    mem_shape.in_port(0).connect(get_batch.out_port(0))
    mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))

    fill_value = Const(graph, {'name': 'fill_value/' + input_name,
                               'value': np.array([0.0], precision), 'shape': int64_array([1])}).create_node()
    init_value_prev_lstm_output = Broadcast(graph, {'name': 'init_value/' + input_name}).create_node()
    init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
    init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))

    return init_value_prev_lstm_output
def create_ss_interval_border(graph: Graph, slice_border_port: Port, shape: np.ndarray, axes: np.ndarray,
                              node_name: str):
    """
    This function creates the "begin"/"end" parameters for the StridedSlice based on Slice's "starts"/"ends"

    :param graph: graph to operate on.
    :param slice_border_port: node output port that provides "starts"/"ends" values for the Slice.
    :param shape: input shape of the Slice
    :param axes: axes that "starts" and "ends" apply to
    :param node_name: Slice node name
    :return: Concat node that forms "begin"/"end" values for the StridedSlice
    """
    # the value for 'starts' or 'ends' might be the maximum/minimum possible value of int64. This value must be
    # converted to the maximum/minimum of int32 because such big values do not fit into the int32 which is supported
    # by the StridedSlice layer
    clamp = create_op_with_const_inputs(graph, Clamp,
                                        port_value_dict={1: np.iinfo(np.int32).min, 2: np.iinfo(np.int32).max},
                                        op_attrs=dict(name=node_name + '/Clamp'))
    clamp.in_port(0).connect(slice_border_port)
    # we have to convert the "starts"/"ends" values from the network to the same data type as the constant values
    # created here to prevent type errors in the Concat node
    cast = Cast(graph, dict(name=node_name + '/CastToI64', dst_type=np.int64)).create_node()
    cast.in_port(0).connect(clamp.out_port(0))
    concat = Concat(graph, dict(name=node_name + '/Concat', axis=0)).create_node()
    for value_idx, port_idx in enumerate(axes):
        concat.add_input_port(port_idx)
        # "axes" may not be sorted, so we need to split the "starts"/"ends" values and connect each value to the
        # correct Concat input port
        value = create_op_with_const_inputs(graph, Gather,
                                            port_value_dict={1: int64_array([value_idx]), 2: int64_array(0)},
                                            op_attrs={'name': node_name + '/Gather'})
        cast.out_port(0).connect(value.in_port(0))
        value.out_port(0).connect(concat.in_port(port_idx))

    for port_idx in range(len(shape)):
        if not concat.is_in_port_connected(port_idx):
            concat.add_input_port(port_idx)
            # this border value will be ignored by the StridedSlice because of the begin_mask/end_mask
            const = Const(graph, dict(name=node_name + '/Const', value=int64_array([0]))).create_node()
            const.out_port(0).connect(concat.in_port(port_idx))
    return concat
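# Hedged usage sketch for the helper above: how it would typically be called
# when lowering a Slice to a StridedSlice. 'slice_node' and 'strided_slice' are
# illustrative names; the port indices follow the Slice conventions assumed
# here (port 1 of Slice carries 'starts', port 1 of StridedSlice is 'begin').
starts_port = slice_node.in_port(1).get_source()
input_shape = slice_node.in_port(0).data.get_shape()
axes = int64_array([0, 2])  # example: 'starts' applies only to dims 0 and 2
begin = create_ss_interval_border(graph, starts_port, input_shape, axes,
                                  slice_node.soft_get('name', slice_node.id))
begin.out_port(0).connect(strided_slice.in_port(1))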
def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    expand_dims_nodes = list()
    expand_axis_node = Const(graph, dict(value=node.axis)).create_node([])
    for ind, edge_attrs in node.in_edges().items():
        expand_dims_nodes.append(
            ExpandDims(graph, dict(name=node.name + '/ExpandDims_')).create_node(
                [(node.in_node(ind), edge_attrs['out']), expand_axis_node]))

    out_node = Concat(graph, dict(name=node.name + '/Concat_', axis=node.axis)).create_node(expand_dims_nodes)

    # Replace the edge from out port 0 of the matched node with an edge from node out_node.id with port 0.
    # The "explicit" version of the return value is: [(out_node.id, 0)]
    return [out_node.id]
def new_shape_node_from_shape_nodes(input_shape_nodes: list):
    """
    The function returns a node producing a 1D tensor with the concatenated shapes produced by the nodes from
    "input_shape_nodes"

    :param input_shape_nodes: list of nodes producing 1D tensors
    :return: the node producing concatenated values of the nodes from the "input_shape_nodes"
    """
    assert len(input_shape_nodes) > 0, 'The list of input shape nodes should be non-empty'
    new_shape_node = Concat(input_shape_nodes[0].graph,
                            {'axis': 0,
                             'name': input_shape_nodes[0].soft_get('name',
                                                                   input_shape_nodes[0].id) + '/shapes_concat'}
                            ).create_node()

    for ind, input_node in enumerate(input_shape_nodes):
        new_shape_node.add_input_port(ind)
        new_shape_node.in_port(ind).connect(input_node.out_port(0))
    return new_shape_node
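# Hypothetical usage: assemble a [N, -1] target shape for a Reshape from the
# batch dimension of another tensor. 'shape_node' and 'reshape_node' are
# illustrative names; 'node_to_get_shape_value_of_indices' is the helper used
# elsewhere in this codebase.
batch_dim = node_to_get_shape_value_of_indices(shape_node, [0])
flat_dim = Const(graph, {'value': int64_array([-1])}).create_node()
new_shape = new_shape_node_from_shape_nodes([batch_dim, flat_dim])
new_shape.out_port(0).connect(reshape_node.in_port(1))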
def append_variances(priors_scale_node: Node, variance: list):
    graph = priors_scale_node.graph
    name = priors_scale_node.name

    sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
    priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

    begin = Const(graph, {'value': int64_array([-2])}).create_node()
    end = Const(graph, {'value': int64_array([-1])}).create_node()
    stride = Const(graph, {'value': int64_array([1])}).create_node()
    shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': np.array([1]),
                                                 'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
                                                 'shrink_axis_mask': np.array([0]),
                                                 'ellipsis_mask': np.array([0])}).create_node()

    sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
    begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
    end.out_port(0).connect(shape_part_for_tiling.in_port(2))
    stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

    shape_concat = create_op_node_with_second_input(graph, Concat, int64_array([4]),
                                                    {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
                                                     'axis': int64_array(0)},
                                                    shape_part_for_tiling)

    variance = Const(graph, {'name': name + '/variance', 'value': float32_array(variance)}).create_node()
    tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
    variance.out_port(0).connect(tile.in_port(0))
    shape_concat.out_port(0).connect(tile.in_port(1))

    reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
    sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
    sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
    sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

    concat = Concat(graph, {'name': name + '/priors_concat', 'axis': np.array(0),
                            'in_ports_count': 2}).create_node()
    sp_reshape.out_port(0).connect(concat.in_port(0))
    tile.out_port(0).connect(concat.in_port(1))

    output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
    output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
    concat.out_port(0).connect(output_node.in_port(0))
    output_dims.out_port(0).connect(output_node.in_port(1))

    return output_node
def extract(node):
    attrs = {'N': node.pb.attr["N"].i,
             'simple_concat': True}
    Concat.update_node_stat(node, attrs)
    return __class__.enabled
def find_and_replace_pattern(self, graph: Graph):
    for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
        assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
                                           'mode is supported for node {}.'.format(node.id)
        node_name = node.soft_get('name', node.id)
        rename_node(node, node_name + '/TBR')
        is_packed = False
        if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
            is_packed = True
            embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
        else:
            embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
            node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
        rename_node(embedding_bag, node_name)
        node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
        node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
        node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
        if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
            if is_packed:
                node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
            else:
                # connect per_sample_weights
                node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))

                weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node()

                weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node()
                last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1)
                weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node)

                weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0])

                zero_col_node = create_op_with_const_inputs(graph, Broadcast, {0: int64_array([0])},
                                                            {'name': node_name + '/Broadcast'})
                zero_col_node.in_port(1).connect(weights_last_dim.out_port(0))

                default_embeddings_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
                                                                      {'name': node_name + '/Unsqueeze'})
                default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0))

                # expand the embedding table with zeros
                weights_concat = Concat(graph, {'axis': 0, 'in_ports_count': 2,
                                                'name': node_name + '/Concat'}).create_node()
                embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
                weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0))
                weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0))
                weights_concat.in_port(1).connect(default_embeddings_node.out_port(0))
                weights_concat.out_port(0).connect(embedding_bag.in_port(0))

                # point the default index to the expanded part of the embedding table
                weights_first_dim.out_port(0).connect(embedding_bag.in_port(3))
def replace_op(self, graph: Graph, node: Node):
    # split the input into (i_part, f_part, c_part, o_part, ct_1)
    split_node_axis = Const(graph, {'value': np.int64(1)}).create_node()
    split_node = Split(graph, {'name': 'Split_lstm_input_', 'num_splits': 5}).create_node()
    node.in_port(0).get_connection().set_destination(split_node.in_port(0))
    split_node.in_port(1).connect(split_node_axis.out_port(0))

    # i_t = Sigmoid(i_part + w_ic * ct_1)
    i_scale_attrs = {'name': 'i_scaleshift', 'bias_term': False}
    i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
    input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
    split_node.out_port(4).connect(i_scale.in_port(0))

    sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node()
    split_node.out_port(0).connect(sum_i_c.in_port(0))
    i_scale.out_port(0).connect(sum_i_c.in_port(1))

    i_sigmoid = Sigmoid(graph, {'name': 'i_sigmoid'}).create_node()
    sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))

    # f_t = Sigmoid(f_part + w_fc * ct_1)
    f_scale_attrs = {'name': 'f_scaleshift', 'bias_term': False}
    f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
    input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
    split_node.out_port(4).connect(f_scale.in_port(0))

    sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node()
    split_node.out_port(1).connect(sum_f_c.in_port(0))
    f_scale.out_port(0).connect(sum_f_c.in_port(1))

    f_sigmoid = Sigmoid(graph, {'name': 'f_sigmoid'}).create_node()
    sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))

    # c_t = f_t * ct_1 + i_t * tanh(c_part)
    c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
    split_node.out_port(2).connect(c_tanh.in_port(0))

    prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
    i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
    c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))

    prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
    f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
    split_node.out_port(4).connect(prod_f_ct_1.in_port(1))

    sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node()
    prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
    prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))

    # o_t = Sigmoid(o_part + w_oc * c_t)
    o_scale_attrs = {'name': 'o_scaleshift', 'bias_term': False}
    o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
    input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
    sum_f_i.out_port(0).connect(o_scale.in_port(0))

    sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node()
    split_node.out_port(3).connect(sum_o_c.in_port(0))
    o_scale.out_port(0).connect(sum_o_c.in_port(1))

    o_sigmoid = Sigmoid(graph, {'name': 'o_sigmoid'}).create_node()
    sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))

    # m_t = o_t * Tanh(c_t)
    c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
    sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))

    prod_o_c_t_tanh = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node()
    o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
    c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))

    # add a Concat to create a single output
    concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
    concat.add_sequence_of_ports('in', range(2))
    sum_f_i.out_port(0).connect(concat.in_port(0))
    prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))

    return [concat.id]
def replace_pattern(graph: Graph, match: dict):
    node = match['op']
    pair_node = Node(graph, node.pair_name)

    if node.t >= 0:
        raise Error('Does not support IfDefined with t >= 0')

    if node.in_port(0).get_source() is not None:
        input_port = node.in_port(0).get_source()
        op_output_id = node.out_port(0).get_destination().node.id
        out_port = pair_node.out_port(0)
        node_name = node.name
        pair_name = pair_node.name
    else:
        input_port = pair_node.in_port(0).get_source()
        op_output_id = pair_node.out_port(0).get_destination().node.id
        out_port = node.out_port(0)
        node_name = pair_node.name
        pair_name = node.name

    in_shape = input_port.data.get_shape()
    node_t = abs(node.t)

    init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1] * node_t)
    memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name + pair_name}).create_node()
    init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

    if node_t > 1:
        crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1] * (node_t - 1)]),
                                   'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_concat.in_port(0))
        concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
        concat.add_sequence_of_ports('in', range(2))
        crop_concat.out_port(0).connect(concat.in_port(0))
        concat.in_port(1).connect(input_port)

        memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
        concat.out_port(0).connect(memory_in.in_port(0))
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))

        crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
                                'offset': np.array([0]), 'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_out.in_port(0))
        out_port.get_connection().set_source(crop_out.out_port(0))
    else:
        memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
        memory_in.in_port(0).connect(input_port)
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))
        out_port.get_connection().set_source(memory_out.out_port(0))

    if not graph.graph['cmd_params'].static_shape:
        log.error("Model can not be translated in a reshape-able way.\n"
                  "Model Optimizer key static_shape was turned on to prevent related errors.\n"
                  "There will be no success changing input shapes of the model with the help of "
                  "InferenceEngine reshape method", extra={'is_warning': True})
        graph.graph['cmd_params'].static_shape = True

    graph.remove_node(op_output_id)
    graph.remove_node(node.id)
    graph.remove_node(pair_node.id)
def replace_pattern(self, graph: Graph, match: dict):
    merge = match['merge']
    power = Pow(graph, {'name': merge.name + '/reciprocal_', 'type': 'PNORM'}).create_node()
    const1 = Const(graph, {'value': -1.0, 'name': merge.name + '/negate_const'}).create_node()
    merge.in_port(0).get_connection().set_destination(power.in_port(0))
    const1.out_port(0).connect(power.in_port(1))

    concat_node = Concat(graph, {'axis': 0, 'name': merge.name + '/Concat_',
                                 'override_output_shape': True}).create_node()
    const3 = Const(graph, {'name': merge.name + '/const_reduce', 'value': 0}).create_node()

    for ii, idx in enumerate(range(merge.significant, merge.to_significant + 1, 1)):
        const_node = Const(graph, {'value': float_array(math.pow(10.0, idx)),
                                   'name': merge.name + '/Const_' + ii.__str__()}).create_node()

        mul_node = Mul(graph, {'name': merge.name + '/Mul_' + ii.__str__()}).create_node()
        const_node.out_port(0).connect(mul_node.in_port(0))

        power.out_port(0).connect(mul_node.in_port(1))  # connect to the graph node
        mul_node2 = Mul(graph, {'name': merge.name + '/Mul_Div_' + ii.__str__()}).create_node()

        const_node2 = Const(graph, {'value': float_array(math.pow(10.0, -1 * idx)),
                                    'name': merge.name + '/Const_Pow_' + ii.__str__()}).create_node()
        cast_node = Cast(graph, {'name': merge.name + '/Cast_' + idx.__str__(),
                                 'dst_type': np.float32}).create_node()
        mul_node.out_port(0).connect(cast_node.in_port(0))
        const_node2.out_port(0).connect(mul_node2.in_port(1))
        cast_node.out_port(0).connect(mul_node2.in_port(0))
        concat_node.add_input_port(ii, skip_if_exist=True)
        concat_node.in_port(ii).get_connection().set_source(mul_node2.out_port(0))

    reducesum_node = ReduceMean(graph, {'name': merge.id + '/_pnorm_reduced_sum',
                                        'keep_dims': False, 'in_ports_count': 2,
                                        'need_shape_inference': None, 'infer': reduce_infer}).create_node()
    const3.out_port(0).connect(reducesum_node.in_port(1))
    reducesum_node.in_port(0).get_connection().set_source(concat_node.out_port(0))

    reshape = Reshape(graph, {'name': merge.name + '/Reshape_Node'}).create_node()
    reshape_dim = Const(graph, {'value': np.array([1, 5]), 'name': merge.id + '/Reshape_Dim'}).create_node()
    reducesum_node.out_port(0).connect(reshape.in_port(0))
    reshape.in_port(1).connect(reshape_dim.out_port(0))
    merge.out_port(0).get_connection().set_source(reshape.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    node = match['op']

    in_shape = node.in_port(0).data.get_shape().copy()
    memory_element = in_shape[1] - node.const_dim
    memory_size = memory_element * len(node.context)

    memory_pair_id = unique_id('id')
    # Memory(in)
    input_memory = Memory(graph, {'name': 'prev_splice_memory',
                                  'id': memory_pair_id,
                                  'index': 1,
                                  'size': 2,
                                  'shape': int64_array([memory_size])}).create_node()
    # Memory(in)  \
    #             Crop
    # Input(temp) /
    crop = Crop(graph, {'name': 'Splice_Crop',
                        'axis': int64_array([1]),
                        'offset': int64_array([memory_element]),
                        'dim': int64_array([memory_size - memory_element])}).create_node()
    crop.in_port(0).connect(input_memory.out_port(0))

    # Crop  \
    #       Concat
    # Input /
    concat_node = Concat(graph, {'name': 'Splice_Concat',
                                 'in_ports_count': 2,
                                 'axis': 1}).create_node()
    concat_node.in_port(0).connect(crop.out_port(0))

    # Concat -> Memory(out)
    mem_out = Memory(graph, {'name': 'out_splice_memory',
                             'id': memory_pair_id,
                             'index': 0,
                             'size': 2,
                             'shape': int64_array([memory_size])}).create_node()
    mem_out.in_port(0).connect(concat_node.out_port(0))
    Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))

    if node.const_dim != 0:
        memory_element_constdim = node.const_dim
        memory_size_constdim = memory_element_constdim * len(node.context)

        split = create_op_with_const_inputs(graph, VariadicSplit,
                                            {1: int64_array(1),
                                             2: int64_array([memory_element, memory_element_constdim])},
                                            {'name': node.id + '_split_const', 'out_ports_count': 2})

        split.out_port(0).connect(concat_node.in_port(1))

        # create a separate splice construction for const_dim
        memory_pair_id = unique_id('memory_for_const_dim')
        input_memory_const_dim = Memory(graph, {'name': 'const_dim_in_memory',
                                                'id': memory_pair_id,
                                                'index': 1,
                                                'size': 2,
                                                'shape': int64_array([memory_size_constdim])}).create_node()

        crop_const_dim = Crop(graph, {'name': 'const_dim_crop',
                                      'axis': int64_array([1]),
                                      'offset': int64_array([memory_element_constdim]),
                                      'dim': int64_array([memory_size_constdim -
                                                          memory_element_constdim])}).create_node()
        crop_const_dim.in_port(0).connect(input_memory_const_dim.out_port(0))

        concat_node_const_dim = Concat(graph, {'name': 'const_dim_concat',
                                               'in_ports_count': 2,
                                               'axis': 1}).create_node()
        concat_node_const_dim.in_port(0).connect(crop_const_dim.out_port(0))

        mem_out_const_dim = Memory(graph, {'name': 'const_dim_out_memory',
                                           'id': memory_pair_id,
                                           'index': 0,
                                           'size': 2,
                                           'shape': int64_array([memory_size_constdim])}).create_node()
        mem_out_const_dim.in_port(0).connect(concat_node_const_dim.out_port(0))
        Result(graph).create_node().in_port(0).connect(mem_out_const_dim.out_port(0))

        # connect the splice to the Split at the beginning and to the Concat at the end
        split.out_port(1).connect(concat_node_const_dim.in_port(1))
        crop_first = Crop(graph, {'name': 'const_dim_crop_first',
                                  'axis': int64_array([1]),
                                  'offset': int64_array([0]),
                                  'dim': int64_array([memory_element_constdim])}).create_node()
        crop_first.in_port(0).connect(concat_node_const_dim.out_port(0))

        concat_const = Concat(graph, {'name': node.id + '_concat_const',
                                      'axis': 1,
                                      'in_ports_count': 2}).create_node()
        concat_const.in_port(1).connect(crop_first.out_port(0))
        concat_const.in_port(0).connect(concat_node.out_port(0))

        node.in_port(0).get_connection().set_destination(split.in_port(0))
        node.out_port(0).get_connection().set_source(concat_const.out_port(0))
    else:
        node.in_port(0).get_connection().set_destination(concat_node.in_port(1))
        node.out_port(0).get_connection().set_source(concat_node.out_port(0))

    # remove the node to avoid re-inference of its shape and touching it in the next replacements
    graph.remove_node(node.id)
def replace_pattern(graph: Graph, match: dict):
    node = match['op']
    pair_node = Node(graph, node.pair_name)

    if node.t >= 0:
        raise Error('Does not support IfDefined with t >= 0')

    if node.in_port(0).get_source() is not None:
        input_port = node.in_port(0).get_source()
        op_output_id = node.out_port(0).get_destination().node.id
        out_port = pair_node.out_port(0)
        node_name = node.name
        pair_name = pair_node.name
    else:
        input_port = pair_node.in_port(0).get_source()
        op_output_id = pair_node.out_port(0).get_destination().node.id
        out_port = node.out_port(0)
        node_name = pair_node.name
        pair_name = node.name

    in_shape = input_port.data.get_shape()
    node_t = abs(node.t)

    init_value_memory_out = Const(graph, {'name': 'init_value_' + pair_name,
                                          'value': np.zeros(int64_array([in_shape[0], in_shape[1] * node_t])),
                                          'shape': int64_array([in_shape[0], in_shape[1] * node_t])}).create_node()
    memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name + pair_name}).create_node()
    init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

    if node_t > 1:
        crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1] * (node_t - 1)]),
                                   'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_concat.in_port(0))
        concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
        concat.add_sequence_of_ports('in', range(2))
        crop_concat.out_port(0).connect(concat.in_port(0))
        concat.in_port(1).connect(input_port)

        memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
        concat.out_port(0).connect(memory_in.in_port(0))
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))

        crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
                                'offset': np.array([0]), 'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_out.in_port(0))
        out_port.get_connection().set_source(crop_out.out_port(0))
    else:
        memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
        memory_in.in_port(0).connect(input_port)
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))
        out_port.get_connection().set_source(memory_out.out_port(0))

    graph.remove_node(op_output_id)
    graph.remove_node(node.id)
    graph.remove_node(pair_node.id)
def extract(cls, node):
    attrs = {
        'axis': node.module.dim,
    }
    Concat.update_node_stat(node, attrs)
    return cls.enabled
def replace_pattern(graph: Graph, match: dict):
    node = match['op']

    if node.name == 'iteration_number_out':
        return

    # calculate the length of the context after which the state of the inference becomes meaningful
    inputs = []
    for n in graph.get_op_nodes(**{'op': 'Parameter'}):
        inputs.append(n)

    in_nodes = []
    for inp in inputs:
        for ins in inp.out_port(0).get_destinations():
            in_nodes.append(ins.node.name)

    context_len = 1
    try:
        subgraph = invert_sub_graph_between_nodes(graph, [node.in_port(0).get_source().node.name], in_nodes)
    except Error:
        return

    for n in subgraph:
        n_node = Node(graph, n)
        if n_node.kind == 'op' and n_node.op == 'Splice':
            context_len += len(n_node.context) - 1

    if context_len == 1:
        return

    in_node_port = node.in_port(0).get_source()
    in_node_shape = node.in_port(0).data.get_shape()
    node.in_port(0).disconnect()

    # add a Select before saving the state to avoid saving garbage
    select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
    zero_else = Const(graph, {'name': 'zero_else', 'value': np.zeros(in_node_shape)}).create_node()
    select_node.in_port(1).connect(in_node_port)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # check if we already have an appropriate iteration counter
    existing_counters = find_pattern_matches(
        graph,
        nodes=[('mem_in', dict(op='ReadValue')),
               ('mem_in_data', dict(shape=int64_array([context_len]))),
               ('crop_mem_in', dict(op='Crop', axis=int64_array([1]), offset=int64_array([1]),
                                    dim=int64_array([context_len - 1]))),
               ('crop_mem_in_data', dict()),
               ('concat', dict(op='Concat', axis=1)),
               ('concat_data', dict()),
               ('const_1', dict(op='Const')),
               ('const_1_data', dict()),
               ('mem_out', dict(op='Assign')),
               ('crop_out', dict(op='Crop', axis=int64_array([1]), offset=int64_array([0]),
                                 dim=int64_array([1]))),
               ('crop_out_data', dict()),
               ('select', dict(op='Select'))],
        edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
               ('crop_mem_in', 'crop_mem_in_data'), ('crop_mem_in_data', 'concat', {'in': 0}),
               ('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
               ('concat', 'concat_data'), ('concat_data', 'mem_out'),
               ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
               ('crop_out_data', 'select')])
    counter_match = next(existing_counters, None)
    if counter_match is not None:
        ones = Node(graph, inverse_dict(counter_match)['const_1'])
        input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
    else:
        init_value_mem_out = create_zero_value_with_batch_from_input(in_node_port, context_len, np.int32)
        mem_out = ReadValue(graph, {'name': 'iteration_number',
                                    'variable_id': 'iteration_' + node.name}).create_node()
        mem_out.in_port(0).connect(init_value_mem_out.out_port(0))

        cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                 'offset': int64_array([1]), 'dim': int64_array([context_len - 1])}).create_node()
        cut_first.in_port(0).connect(mem_out.out_port(0))

        ones = Const(graph, {'name': 'ones', 'value': np.ones([1, 1], dtype=np.int32)}).create_node()
        concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))

        mem_in = Assign(graph, {'name': 'iteration_number_out',
                                'variable_id': 'iteration_' + node.name}).create_node()
        mem_in.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        mem_in.out_port(0).connect(res.in_port(0))

        cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                'offset': int64_array([0]), 'dim': int64_array([1])}).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        input_port = cut_last.out_port(0)

    # Check whether the data from the memory is 1.
    # If it is, we have correct data and should proceed with saving it to memory;
    # otherwise, the context has not been gathered yet and we have garbage here, so the initial state of the memory
    # should not be changed.
    cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
    cast_in.in_port(0).connect(ones.out_port(0))
    cast_in.in_port(1).connect(input_port)
    select_node.in_port(0).connect(cast_in.out_port(0))
    select_node.out_port(0).connect(node.in_port(0))
    select_node.out_port(0).data.set_shape(in_node_shape)
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
    reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                            dict(name='do_reshape_classes'),
                                                            match.single_input_node(1)[0])

    priors_node = match.single_input_node(2)[0]

    placeholder = [Node(graph, node_id) for node_id in graph.nodes()
                   if Node(graph, node_id).op == 'Parameter'][0]
    im_height = placeholder.shape[1]
    im_width = placeholder.shape[2]

    # scale prior boxes to the [0, 1] interval
    priors_scale_const_node = Const(graph, {'value': np.array([1 / im_width, 1 / im_height,
                                                               1 / im_width, 1 / im_height])}).create_node([])
    priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node(
        [priors_node, priors_scale_const_node])

    # calculate prior boxes widths and heights
    split_node = SplitV(graph, {'axis': 2, 'size_splits': [1, 1, 1, 1],
                                'out_ports_count': 4}).create_node([priors_scale_node])
    priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')
                            ).create_node([(split_node, 2), (split_node, 0)])
    priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')
                             ).create_node([(split_node, 3), (split_node, 1)])

    # concat the widths and heights into a single tensor and multiply it with the box coordinates regression values
    concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
                                              'in_ports_count': 4}).create_node(
        [priors_width_node, priors_height_node, priors_width_node, priors_height_node])
    applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
        [concat_width_height_node, match.single_input_node(0)[0]])

    # reshape to a 2D tensor as the Inference Engine DetectionOutput layer expects
    reshape_regression_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                               dict(name='reshape_regression'),
                                                               applied_width_height_regressions_node)

    detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
    detection_output_op.attrs['old_infer'] = detection_output_op.attrs['infer']
    detection_output_op.attrs['infer'] = __class__.do_infer
    detection_output_node = detection_output_op.create_node(
        [reshape_regression_node, reshape_classes_node, priors_scale_node],
        dict(name=detection_output_op.attrs['type'], clip=1, normalized=1, variance_encoded_in_target=0))

    return {'detection_output_node': detection_output_node}
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
    reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                            dict(name='do_reshape_classes'),
                                                            match.single_input_node(1)[0])

    initial_priors_node = match.single_input_node(2)[0]
    priors_name = initial_priors_node.soft_get('name', initial_priors_node.id)
    # the model calculates identical prior boxes for each batch, so we take only the first slice of them
    begin = Const(graph, {'value': np.array([0, 0, 0], dtype=np.int32)}).create_node()
    end = Const(graph, {'value': np.array([1, 0, 0], dtype=np.int32)}).create_node()
    stride = Const(graph, {'value': np.array([1, 1, 1], dtype=np.int32)}).create_node()

    priors_node = StridedSlice(graph, {'name': priors_name + '/0_batch_slice',
                                       'begin_mask': np.array([1, 1, 1], dtype=np.int32),
                                       'end_mask': np.array([1, 0, 0], dtype=np.int32),
                                       'new_axis_mask': np.array([0], dtype=np.int32),
                                       'shrink_axis_mask': np.array([0], dtype=np.int32),
                                       'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()

    initial_priors_node.out_port(0).connect(priors_node.in_port(0))
    begin.out_port(0).connect(priors_node.in_port(1))
    end.out_port(0).connect(priors_node.in_port(2))
    stride.out_port(0).connect(priors_node.in_port(3))

    placeholders = graph.get_op_nodes(type='Parameter')
    assert len(placeholders) == 1, "{} replacer requires model to have one Placeholder, but current model has " \
                                   "{} placeholders".format(self.replacement_id, len(placeholders))
    placeholder = placeholders[0]

    # scale prior boxes to the [0, 1] interval
    node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
    priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

    broadcast = Broadcast(graph, {'name': 'scales_broadcast'}).create_node()
    shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
    priors_node.out_port(0).connect(shape_of_priors.in_port(0))
    broadcast.in_port(1).connect(shape_of_priors.out_port(0))
    broadcast.in_port(0).connect(node_with_scales_for_prior_boxes.out_port(0))

    priors_scale_node.in_port(0).connect(priors_node.out_port(0))
    priors_scale_node.in_port(1).connect(broadcast.out_port(0))

    try:
        variance = match.custom_replacement_desc.custom_attributes['variance']
    except KeyError:
        raise Error('There is no variance attribute in {} replacement config file `custom_attributes`'
                    ''.format(self.replacement_id))

    priors = self.append_variances(priors_scale_node, variance)

    # calculate prior boxes widths and heights
    split_node = create_op_with_const_inputs(graph, VariadicSplit,
                                             {1: int64_array(2), 2: int64_array([1, 1, 1, 1])},
                                             {'out_ports_count': 4}, priors_scale_node)

    priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')
                            ).create_node([(split_node, 2), (split_node, 0)])
    priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')
                             ).create_node([(split_node, 3), (split_node, 1)])

    # concat the widths and heights into a single tensor and multiply it with the box coordinates regression values;
    # WA with 3 Concats instead of 1 to keep the model reshape-able
    # concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
    #                                           'in_ports_count': 4}).create_node(
    #     [priors_width_node, priors_height_node, priors_width_node, priors_height_node])
    concat_1 = Concat(graph, {'name': 'concat_width_height', 'axis': -1,
                              'in_ports_count': 2}).create_node([priors_width_node, priors_height_node])
    concat_2 = Concat(graph, {'name': 'concat_width_height_width', 'axis': -1,
                              'in_ports_count': 2}).create_node([concat_1, priors_width_node])
    concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
                                              'in_ports_count': 2}).create_node([concat_2, priors_height_node])

    applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
        [concat_width_height_node, match.single_input_node(0)[0]])

    # reshape to a 2D tensor as the Inference Engine DetectionOutput layer expects
    reshape_regression_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                               dict(name='reshape_regression'),
                                                               applied_width_height_regressions_node)

    detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)

    # get the NMS threshold from the original network
    iou_threshold = None
    nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
    if len(nms_nodes) > 0:
        # it is highly unlikely that NMS uses different thresholds for different classes;
        # moreover, DetectionOutput accepts only scalar values for iou_threshold (nms_threshold)
        iou_threshold = nms_nodes[0].in_node(3).value
    if iou_threshold is None:
        raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id))

    detection_output_node = detection_output_op.create_node(
        [reshape_regression_node, reshape_classes_node, priors],
        dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
             variance_encoded_in_target=0, background_label_id=1000))

    return {'detection_output_node': detection_output_node}