def find_and_replace_pattern(self, graph: Graph):
    for offset_node in graph.get_op_nodes(op='MemoryOffset', splitted=False):
        paired_node = MemoryOffset(graph, {'name': offset_node.pair_name,
                                           'splitted': True,
                                           'pair_name': offset_node.id,
                                           't': offset_node.t,
                                           'has_default': offset_node.has_default}).create_node()
        offset_node['splitted'] = True
        offset_node.out_port(0).get_connection().set_source(paired_node.out_port(0))
        res_node = Result(graph, {'name': offset_node.id + "_output"}).create_node()
        offset_node.out_port(0).connect(res_node.in_port(0))

        # 'element_size' may have been copied earlier from a Parameter or from a node with a defined dim
        if offset_node.has_valid('element_size'):
            paired_node['element_size'] = offset_node['element_size']
        # otherwise copy the shape from the previous node; typically (but not always) this is the case for TDNN blocks
        else:
            paired_node['element_size'] = offset_node.in_port(0).data.get_shape()
def add_output_in_body(node, port_num, cur_graph, cur_max_layer_id, tracks, track_index, add_unsqueeze=True):
    port = node.out_port(port_num)
    if add_unsqueeze:
        unsq_name = port.node.soft_get('name', port.node.id) + "/Unsqueeze"
        unsq_node = create_op_node_with_second_input(cur_graph, Unsqueeze, int64_array([0]),
                                                     {'name': unsq_name})
        port.connect(unsq_node.in_port(0))
        unsq_node['internal_layer_id'] = cur_max_layer_id + 1
        cur_max_layer_id += 1
        tracks.insert(track_index, {'node': unsq_node, 'graph': cur_graph})
        port = unsq_node.out_port(0)

    out_name = port.node.soft_get('name', port.node.id) + ":" + str(port_num)
    res_node = Result(cur_graph, {'name': out_name}).create_node()
    port.connect(res_node.in_port(0))
    res_node['internal_layer_id'] = cur_max_layer_id + 1
    cur_max_layer_id += 1
    tracks.insert(track_index, {'node': res_node, 'graph': cur_graph})

    # callers (see add_output_for_path) unpack the updated tracks and layer id as well
    return res_node, tracks, cur_max_layer_id
def replace_sub_graph(self, graph: Graph, match: dict):
    # Node that is used to identify this pattern application instance for switching between supported
    # and not supported LSTMCell sub-graphs; this value will be searched in __class__.instances_supported_by_IE.
    anchor_node = match[__class__.anchor()]
    assert anchor_node.has_valid('name'), \
        'LSTMCell anchor node {} doesn\'t have the attribute "name"; such nodes are not supported.'.format(
            anchor_node.id)

    match['input_op'] = match['concat'].in_node(0)
    match['input_hidden_state'] = match['concat'].in_node(1)
    match['input_cell_state'] = match['mul_0'].in_node(0) \
        if match['mul_0'].in_node(0).id != match['sigmoid_0'].id else match['mul_0'].in_node(1)

    pattern_edges = self.pattern()['edges']
    pattern_edges.extend([('input_op', 'concat'), ('input_cell_state', 'mul_0'), ('input_hidden_state', 'concat')])
    inputs = graph.get_inputs_with_ports(match, pattern_edges, __class__.inputs + __class__.extra_inputs)

    lstm_op = LSTMCell(graph, dict(
        name=match['concat'].name + '/LSTMCell',
        activations=None,
    ))
    lstm_node = lstm_op.create_node(inputs)
    lstm_node['old_infer'] = lstm_node.infer
    lstm_node.infer = __class__.infer

    # this node consumes one of the resulting LSTMCell outputs;
    # it should be removed before reconnecting the nodes,
    # otherwise it will be reconnected to the new cell output
    graph.remove_node(match['tanh_1'].id)

    for i, output in enumerate(__class__.outputs):
        match[output].replace_node(lstm_node, i)

    # The LSTMCell specification requires this layer to have 2 outputs,
    # so create fake consumers (Result nodes) for any output the node does not have.
    for i in [0, 1]:
        if i not in lstm_node.out_nodes():
            fake_output_node = Result(graph, dict(name=lstm_node.name + "/Output_{}".format(i)))
            fake_output_node.create_node(inputs=[lstm_node], edge_attrs={'out': i, 'in': 0})

    lstm_node['tf'] = True
    lstm_node['extra_inputs'] = {name: match[name].id for name in __class__.extra_inputs}
    lstm_node['inputs'] = {name: match[name].id for name in __class__.inputs}
def split_offset(offset_node: Node):
    paired_node = MemoryOffset(offset_node.graph, {'name': offset_node.pair_name,
                                                   'splitted': True,
                                                   'pair_name': offset_node.id,
                                                   'element_size': offset_node['element_size'],
                                                   't': offset_node.t,
                                                   'has_default': offset_node.has_default}).create_node()
    offset_node['splitted'] = True
    offset_node.out_port(0).get_connection().set_source(paired_node.out_port(0))
    res_node = Result(offset_node.graph, {'name': offset_node.id + '_output'}).create_node()
    offset_node.out_port(0).connect(res_node.in_port(0))
def normalize_outputs(node: Node):
    if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
        from openvino.tools.mo.ops.result import Result  # import is here to avoid a circular import error

        for p in range(node.out_ports_count):
            if p not in node.out_ports():
                node.add_output_port(p)
            if node.out_port(p).disconnected():
                res_node = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(p),
                                               'keep_output_port': True}).create_node()
                node.out_port(p).connect(res_node.in_port(0))
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
    # The IE DetectionOutput layer consumes flattened confidences and locations tensors.
    # That is why we add reshapes before them.
    locs_node = match.single_input_node(0)
    conf_node = match.single_input_node(1)
    prior_boxes_node = match.single_input_node(2)

    locs_out_nodes = locs_node[0].out_nodes()
    assert len(locs_out_nodes) == 1
    locs_out_node = locs_out_nodes[list(locs_out_nodes.keys())[0]]
    assert locs_out_node.op == "Result", locs_out_node.op
    graph.remove_node(locs_out_node.id)

    conf_out_nodes = conf_node[0].out_nodes()
    assert len(conf_out_nodes) == 1
    conf_out_node = conf_out_nodes[list(conf_out_nodes.keys())[0]]
    assert conf_out_node.op == "Result", conf_out_node.op
    graph.remove_node(conf_out_node.id)

    # reshape operation to flatten the locations tensor
    const = Const(graph, {'value': int64_array([0, -1])}).create_node()
    reshape_loc_node = Reshape(graph, {}).create_node([locs_node, const],
                                                      dict(name='DetectionOutput_Reshape_loc_'))

    # reshape operation to flatten the confidence tensor
    reshape_conf_node = Reshape(graph, {}).create_node([conf_node, const],
                                                       dict(name='DetectionOutput_Reshape_conf_'))

    # remove the Result node after the priors node
    assert prior_boxes_node[0].out_node().op == "Result"
    graph.remove_node(prior_boxes_node[0].out_node().id)

    # reshape operation for the prior boxes tensor
    const = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
    reshape_priors_node = Reshape(graph, {}).create_node([prior_boxes_node, const],
                                                         dict(name='DetectionOutput_Reshape_priors_'))

    # create a DetectionOutput node with three inputs: locations, confidences and prior boxes
    detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
    detection_output_node = detection_output_op.create_node(
        [reshape_loc_node, reshape_conf_node, reshape_priors_node],
        dict(name=detection_output_op.attrs['type'] + '_'))
    PermuteAttrs.set_permutation(reshape_priors_node, detection_output_node, None)

    # create a Result node to mark DetectionOutput as a graph output operation
    output_op = Result(graph)
    output_op.create_node([detection_output_node], dict(name='sink_'))
    return {}
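# A minimal numpy illustration (not the MO API) of the flattening performed by the Reshape above,
# assuming the Reshape here uses special_zero semantics, i.e. a 0 in the target shape keeps the
# corresponding input dimension; shapes and names below are hypothetical.
import numpy as np

locs = np.random.rand(2, 19, 19, 12).astype(np.float32)  # hypothetical [N, H, W, A*4] locations tensor
flat = locs.reshape(locs.shape[0], -1)                    # effect of Reshape([0, -1]): [N, H*W*A*4]
print(flat.shape)                                         # (2, 4332)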
def transform_graph(self, graph: Graph, replacement_descriptions):
    graph.remove_nodes_from(graph.get_nodes_with_attributes(op='Result'))
    for i, input_node_name in enumerate(replacement_descriptions['entry_points']):
        if input_node_name not in graph.nodes():
            raise Error('TensorFlow YOLO V3 conversion mechanism was enabled. '
                        'Entry points "{}" were provided in the configuration file. '
                        'Entry points are nodes that feed YOLO Region layers. '
                        'Node with name {} doesn\'t exist in the graph. '
                        'Refer to documentation about converting YOLO models for more information.'.format(
                            ', '.join(replacement_descriptions['entry_points']), input_node_name))
        last_node = Node(graph, input_node_name).in_node(0)
        op_params = dict(name=last_node.id + '/YoloRegion', axis=1, end_axis=-1, do_softmax=0, nchw_layout=True)
        op_params.update(replacement_descriptions)
        if 'masks' in op_params:
            op_params['mask'] = op_params['masks'][i]
            del op_params['masks']
        region_layer_node = RegionYoloOp(graph, op_params).create_node([last_node])
        # TODO: do we need to change the axis for the further permutation?
        region_layer_node.dim_attrs.remove('axis')
        Result(graph, {'name': region_layer_node.id + '/Result'}).create_node([region_layer_node])
def test_leaky_relu_mul_multiple_consumers(self):
    # multiple consumers of the Mul operation
    graph = build_graph_with_edge_attrs(nodes, edges, {})
    additional_result = Result(graph, {'name': 'result_2'}).create_node()
    Node(graph, 'mul').out_port(0).connect(additional_result.in_port(0))

    ref_nodes = {**regular_op_with_shaped_data('input', shape, {'type': 'Parameter', 'op': 'Parameter'}),
                 **regular_op_with_shaped_data('mul', shape, {'type': 'Multiply', 'name': 'mul'}),
                 **regular_op_with_shaped_data('max', shape, {'type': 'Maximum', 'name': 'final_max'}),
                 **valued_const_with_data('const', float_array([0.5])),
                 **regular_op_with_shaped_data('leaky_relu', shape, {'type': 'LeakyReLU', 'name': 'max_final',
                                                                     'negative_slope': None}),
                 **result('result'),
                 **result('result_2')
                 }
    ref_edges = [*connect('input:0', '0:mul'),
                 *connect('const', '1:mul'),
                 *connect('max:0', 'result'),
                 *connect('mul:0', 'result_2'),
                 *connect_data('input', 'leaky_relu'),
                 *connect('leaky_relu', 'result')
                 ]
    graph_ref = build_graph_with_edge_attrs(ref_nodes, ref_edges)

    LeakyReLUFusion().find_and_replace_pattern(graph)
    graph.clean_up()

    (flag, resp) = compare_graphs(graph, graph_ref, 'result')
    self.assertTrue(flag, resp)

    (flag, resp) = compare_graphs(graph, graph_ref, 'result_2')
    self.assertTrue(flag, resp)
def find_and_replace_pattern(self, graph: Graph):
    for ctc_greedy_decoder_tf in graph.get_op_nodes(op='CTCGreedyDecoderSeqLen', output_sparse_format=True):
        ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get('name', ctc_greedy_decoder_tf.id)

        # TF CTCGreedyDecoder has 4 output tensors. If any of them is connected to a non-Result operation,
        # then the transformation is not applicable.
        for port_num in ctc_greedy_decoder_tf.out_ports():
            if not ctc_greedy_decoder_tf.out_port(port_num).disconnected() \
                    and ctc_greedy_decoder_tf.out_port(port_num).get_destination().node.soft_get('op') != 'Result':
                return

        # If the first or the second output is not connected to a Result operation,
        # create a Result operation and connect it to the corresponding output.
        if ctc_greedy_decoder_tf.out_port(0).disconnected():
            first_result = Result(graph,
                                  {'name': ctc_greedy_decoder_tf_name + '/decoded_classes'}).create_node()
            ctc_greedy_decoder_tf.out_port(0).connect(first_result.in_port(0))

        if ctc_greedy_decoder_tf.out_port(1).disconnected():
            second_result = Result(graph,
                                   {'name': ctc_greedy_decoder_tf_name + '/seq_lengths_output'}).create_node()
            ctc_greedy_decoder_tf.out_port(1).connect(second_result.in_port(0))

        # To normalize the input channel, the input data needs to be transposed from [T, N, C] to [N, T, C],
        # which is the layout supported by the CTCGreedyDecoderSeqLen op.
        log.warning('Found TF CTCGreedyDecoder operation at the end of network. '
                    'PLEASE NOTE, appropriate network output operation CTCGreedyDecoderSeqLen {} '
                    'will have dense format, not sparse format!'.format(ctc_greedy_decoder_tf_name))
        ctc_data_permute = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0, 2])},
                                                       {'name': ctc_greedy_decoder_tf_name + '/ctc_data_permute'})

        assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
            'The CTCGreedyDecoderSeqLen node "{}" misses the "merge_repeated" attribute'.format(
                ctc_greedy_decoder_tf_name)

        ctc_greedy_decoder_tf.in_port(0).get_source().connect(ctc_data_permute.in_port(0))
        ctc_greedy_decoder_tf.in_port(0).disconnect()
        ctc_data_permute.out_port(0).connect(ctc_greedy_decoder_tf.in_port(0))

        del ctc_greedy_decoder_tf['output_sparse_format']

        for port_num in [2, 3]:  # MO CTCGreedyDecoderSeqLen may have 2 outputs
            if port_num in ctc_greedy_decoder_tf.out_ports():
                if not ctc_greedy_decoder_tf.out_port(port_num).disconnected():
                    ctc_greedy_decoder_tf.out_port(port_num).disconnect()
def replace_pattern(self, graph: Graph, match: dict):
    node = match['pooling']
    node.type = 'MaxPool'
    del node['pool_method']
    if 'exclude_pad' in node:
        del node['exclude_pad']

    # add missing outputs for the MaxPool node
    if node.out_port(0).disconnected():
        output = Result(node.graph, {'name': node.name + '/Result_port_0/',
                                     'keep_output_port': node.has_and_set('remove_values_output')}).create_node()
        node.out_port(0).get_connection().set_destination(output.in_port(0))

    if node.out_port(1).disconnected():
        output = Result(node.graph, {'name': node.name + '/Result_port_1/',
                                     'keep_output_port': node.has_and_set('remove_values_output')}).create_node()
        node.out_port(1).get_connection().set_destination(output.in_port(0))
def normalize_outputs(node: Node):
    """This function adds missing outputs for a TopK node."""
    if node.out_port(0).disconnected():
        output = Result(node.graph, {'name': node.name + '/Result_port_0/',
                                     'keep_output_port': node.has_and_set('remove_values_output')}).create_node()
        node.out_port(0).get_connection().set_destination(output.in_port(0))

    if node.out_port(1).disconnected():
        output = Result(node.graph, {'name': node.name + '/Result_port_1/'}).create_node()
        node.out_port(1).get_connection().set_destination(output.in_port(0))
def transform_graph(self, graph: Graph, replacement_descriptions):
    op_outputs = [n for n, d in graph.nodes(data=True) if 'op' in d and d['op'] == 'Result']
    for op_output in op_outputs:
        last_node = Node(graph, op_output).in_node(0)
        op_params = dict(name=last_node.id + '/YoloRegion', axis=1, end_axis=-1)
        op_params.update(replacement_descriptions)
        region_layer = RegionYoloOp(graph, op_params)
        region_layer_node = region_layer.create_node([last_node])
        # here we remove 'axis' from 'dim_attrs' to avoid permutation from axis = 1 to axis = 2
        region_layer_node.dim_attrs.remove('axis')
        Result(graph).create_node([region_layer_node])
        graph.remove_node(op_output)
def normalize_outputs(node: Node):
    if node.out_port(0).disconnected():
        output = Result(node.graph, {'name': node.name + '/Result_port_0/',
                                     'keep_output_port': node.has_and_set('remove_values_output')}).create_node()
        node.out_port(0).get_connection().set_destination(output.in_port(0))

    # check that the port exists to support MaxPool_1 with only 1 output port and MaxPool_8 with 2 output ports
    if node.has_port('out', 1) and node.out_port(1).disconnected():
        output = Result(node.graph, {'name': node.name + '/Result_port_1/',
                                     'keep_output_port': node.has_and_set('remove_values_output')}).create_node()
        node.out_port(1).get_connection().set_destination(output.in_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    box_nms = match['box_nms']
    top_k = box_nms.topk
    nms_threshold = box_nms.overlap_thresh

    ssd_concats = {}
    concat_names = ['ssd_concat1', 'ssd_concat0', 'ssd_concat2']

    for i, concat_match in enumerate(self.concats_pattern):
        for matches in find_pattern_matches(graph, concat_match['nodes'], concat_match['edges'], None, None):
            for match in matches:
                if graph.has_node(match):
                    n = Node(graph, match)
                    if n.op == 'Concat':
                        ssd_concats.update({concat_names[i]: n})
                        break

    assert concat_names[0] in ssd_concats
    assert concat_names[1] in ssd_concats
    assert concat_names[2] in ssd_concats

    graph.remove_nodes_from(graph.get_nodes_with_attributes(op='Result'))
    detection_output_node = DetectionOutput(graph, dict(name=graph.unique_id() + '/DetectionOutput_',
                                                        top_k=top_k, keep_top_k=top_k,
                                                        nms_threshold=nms_threshold,
                                                        background_label_id=0, clip=0, decrease_label_id=1,
                                                        code_type="caffe.PriorBoxParameter.CENTER_SIZE",
                                                        confidence_threshold=0.01, share_location=1,
                                                        variance_encoded_in_target=0, normalized=1)).create_node()

    reshape_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                    dict(name=graph.unique_id() + '/DetectionOutput_'))

    ssd_softmax_node = ssd_concats['ssd_concat0'].out_node().out_node()
    ssd_softmax_node.out_port(0).disconnect()
    ssd_softmax_node.out_port(0).connect(reshape_node.in_port(0))
    reshape_node.out_port(0).connect(detection_output_node.in_port(1))

    ssd_concats['ssd_concat2'].axis = 2
    self.reshape_priorboxes(ssd_concats['ssd_concat2'])

    ssd_concats['ssd_concat1'].out_port(0).get_connection().set_destination(detection_output_node.in_port(0))
    ssd_concats['ssd_concat2'].out_port(0).get_connection().set_destination(detection_output_node.in_port(2))

    Result(graph, {'name': detection_output_node.id + '/Result'}).create_node([detection_output_node])
def replace_op(self, graph: Graph, node: Node):
    input_out_port = node.in_port(0).get_source()

    memory_pair_input = unique_id('id')
    memory_pair_output = unique_id('id')

    # Input -> FullyConnected
    fc_layer_after_input_attrs = {'name': 'input_fullyconnected',
                                  'out-size': node.gifo_x_weights_shape[0],
                                  'transpose_weights': True,
                                  'bias_term': True,
                                  }

    fc_layer_after_input = FullyConnected(graph, fc_layer_after_input_attrs).create_node()
    fc_layer_after_input.in_port(0).connect(input_out_port)
    input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1, 'weights', node.gifo_x_weights)
    input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)

    init_value_prev_lstm_output = create_const_with_batch_from_input(input_out_port,
                                                                     node.gifo_r_weights_shape[1])
    prev_lstm_output = ReadValue(graph, {'name': 'prev_memory_output',
                                         'variable_id': memory_pair_input}).create_node()
    prev_lstm_output.in_port(0).connect(init_value_prev_lstm_output.out_port(0))

    # *Memory(output) -> FullyConnected
    fc_layer_from_prev_state_attrs = {'name': 'prev_memory_output_fullyconnected',
                                      'out-size': node.gifo_r_weights_shape[0],
                                      'transpose_weights': True,
                                      'bias_term': False,
                                      }

    fc_layer_from_prev_state = FullyConnected(graph, fc_layer_from_prev_state_attrs).create_node()
    fc_layer_from_prev_state.in_port(0).connect(prev_lstm_output.out_port(0))
    input_as_const(fc_layer_from_prev_state, fc_layer_from_prev_state_attrs, 1, 'weights', node.gifo_r_weights)

    # Memory -> FullyConnected  \
    #                            *Eltwise(sum)
    # Input -> FullyConnected   /
    join_input_prev_state_sum = Add(graph, {'name': 'join_input_eltwise'}).create_node()
    join_input_prev_state_sum.in_port(0).connect(fc_layer_from_prev_state.out_port(0))
    join_input_prev_state_sum.in_port(1).connect(fc_layer_after_input.out_port(0))

    # *Eltwise(sum) -> Split
    # it is split into 4 nodes: Act, Eltw*3
    # the following order is mandatory
    #        ___Tanh
    #       /
    # Split ---(2)Eltwise(sum)
    #      |\
    #      | \__(3)Eltwise(sum)
    #      |____(4)Eltwise(sum)
    split_joined_input_axis = Const(graph, {'value': np.int64(1)}).create_node()
    split_joined_input = Split(graph, {'name': 'join_input_split',
                                       'num_splits': 4,
                                       'out_ports_count': 4}).create_node()
    split_joined_input.in_port(0).connect(join_input_prev_state_sum.out_port(0))
    split_joined_input.in_port(1).connect(split_joined_input_axis.out_port(0))

    init_value_prev_lstm_state = create_const_with_batch_from_input(split_joined_input.out_port(0),
                                                                    node.input_gate_weights.shape[0])
    prev_lstm_state = ReadValue(graph, {'name': 'prev_memory_state',
                                        'variable_id': memory_pair_output}).create_node()
    prev_lstm_state.in_port(0).connect(init_value_prev_lstm_state.out_port(0))

    # *Memory(state) -> *ScaleShift(input)
    state_input_scaleshift_attrs = {'name': 'input_scaleshift',
                                    'bias_term': False}
    state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node()
    state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
    input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1, 'weights', node.input_gate_weights)

    # *Memory(state) -> *ScaleShift(forget)
    state_forget_scaleshift_attrs = {'name': 'forget_scaleshift',
                                     'bias_term': False}
    state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node()
    state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
    input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs, 1, 'weights', node.forget_gate_weights)

    # Split                                \
    #                                       (2)Eltwise(sum)
    # Memory(state) -> *ScaleShift(input)  /
    join_prev_lstm_input_joined_input_sum = Add(graph, {
        'name': 'join_prev_lstm_input_joined_input_eltwise'}).create_node()
    join_prev_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(1))
    join_prev_lstm_input_joined_input_sum.in_port(1).connect(state_input_scaleshift.out_port(0))

    # Split                                 \
    #                                        (3)Eltwise(sum)
    # Memory(state) -> *ScaleShift(forget)  /
    join_prev_lstm_input_joined_forget_sum = Add(graph, {
        'name': 'join_prev_lstm_input_joined_forget_sum'}).create_node()
    join_prev_lstm_input_joined_forget_sum.in_port(0).connect(split_joined_input.out_port(2))
    join_prev_lstm_input_joined_forget_sum.in_port(1).connect(state_forget_scaleshift.out_port(0))

    # Split -> Tanh
    remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node()
    remember_tahn.in_port(0).connect(split_joined_input.out_port(0))

    # Split -> (2)Eltwise(sum) -> *Sigmoid
    remember_sigmoid = Sigmoid(graph, {'name': 'remember_sigmoid'}).create_node()
    remember_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_input_sum.out_port(0))

    # Split -> (3)Eltwise(sum) -> **Sigmoid
    forget_sigmoid = Sigmoid(graph, {'name': 'forget_sigmoid'}).create_node()
    forget_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_forget_sum.out_port(0))

    # *Memory(state)                         \
    #                                         (6)Eltwise(mul)
    # Split -> (3)Eltwise(sum) -> **Sigmoid  /
    join_forget_prev_state_mul = Mul(graph, {'name': 'join_forget_prev_state_mul'}).create_node()
    join_forget_prev_state_mul.in_port(0).connect(forget_sigmoid.out_port(0))
    join_forget_prev_state_mul.in_port(1).connect(prev_lstm_state.out_port(0))

    # Split -> Tanh                         \
    #                                        (5)Eltwise(mul)
    # Split -> (2)Eltwise(sum) -> *Sigmoid  /
    join_remember_candidates_mul = Mul(graph, {'name': 'join_remember_candidates_mul'}).create_node()
    join_remember_candidates_mul.in_port(0).connect(remember_tahn.out_port(0))
    join_remember_candidates_mul.in_port(1).connect(remember_sigmoid.out_port(0))

    # (5)Eltwise(mul)  \
    #                   (7)Eltwise(sum)
    # (6)Eltwise(mul)  /
    join_forget_remember_sum = Add(graph, {'name': 'join_forget_remember_sum'}).create_node()
    join_forget_remember_sum.in_port(0).connect(join_forget_prev_state_mul.out_port(0))
    join_forget_remember_sum.in_port(1).connect(join_remember_candidates_mul.out_port(0))

    # (7)Eltwise(sum) -> Clamp
    join_forget_clamp = create_op_with_const_inputs(graph, Clamp,
                                                    {1: np.array(-node.clip_value, dtype=np.float32),
                                                     2: np.array(node.clip_value, dtype=np.float32)},
                                                    {'name': 'join_forget_clamp'},
                                                    join_forget_remember_sum)

    # Clamp -> (2)Memory(state)
    next_lstm_state = Assign(graph, {'name': 'next_lstm_state',
                                     'variable_id': memory_pair_output}).create_node()
    next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))

    res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
    res_node.in_port(0).connect(next_lstm_state.out_port(0))

    # Clamp -> (2)Tanh
    state_filtered_tahn = Tanh(graph, {'name': 'state_filtered_tahn'}).create_node()
    state_filtered_tahn.in_port(0).connect(join_forget_clamp.out_port(0))

    # Clamp -> (2)ScaleShift
    clamp_scaleshift_attrs = {'name': 'clamp_scaleshift',
                              'bias_term': False}
    clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node()
    clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
    input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights', node.output_gate_weights)

    # Split                   \
    #                          (4)Eltwise(sum)
    # Clamp -> (2)ScaleShift  /
    join_next_lstm_input_joined_input_sum = Add(graph, {
        'name': 'join_next_lstm_input_joined_input_sum'}).create_node()
    join_next_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(3))
    join_next_lstm_input_joined_input_sum.in_port(1).connect(clamp_scaleshift.out_port(0))

    # (4)Eltwise(sum) -> (3)Sigmoid
    output_sigmoid = Sigmoid(graph, {'name': 'output_sigmoid'}).create_node()
    output_sigmoid.in_port(0).connect(join_next_lstm_input_joined_input_sum.out_port(0))

    # (4)Eltwise(sum) -> (3)Sigmoid  \
    #                                 (5)Eltwise(mul)
    # Clamp -> (2)Tanh               /
    joined_output_mul = Mul(graph, {'name': 'joined_output_mul'}).create_node()
    joined_output_mul.in_port(0).connect(state_filtered_tahn.out_port(0))
    joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))

    # (5)Eltwise(mul) -> (3)FullyConnected
    fc_output_attrs = {'name': 'FullyConnected',
                       'out-size': node.projection_weights_shape[0],
                       'transpose_weights': True,
                       'bias_term': False}
    fc_output = FullyConnected(graph, fc_output_attrs).create_node()
    fc_output.in_port(0).connect(joined_output_mul.out_port(0))
    input_as_const(fc_output, fc_output_attrs, 1, 'weights', node.projection_weights)

    #                    / (2)Memory(output)
    # (3)FullyConnected
    #                    \ Output (any next node) (edge created automatically after replacement)
    next_lstm_output = Assign(graph, {'name': 'next_lstm_output',
                                      'variable_id': memory_pair_input}).create_node()
    next_lstm_output.in_port(0).connect(fc_output.out_port(0))

    res_node_lstm_output = Result(graph, {'name': 'next_lstm_output_out'}).create_node()
    res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))

    return [fc_output.id]
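# A sketch of the cell equations that the subgraph built above realizes, derived from the node wiring.
# The symbols are hypothetical labels, not names from the source: x_t is the input, r_{t-1}/c_{t-1} are the
# previous output/state memories, W_x/W_r are the gifo_x/gifo_r weights, w_i/w_f/w_o are the per-gate
# ScaleShift (diagonal) weights, and W_p is the projection matrix.
#
#   z = W_x x_t + b + W_r r_{t-1}                 # input_fullyconnected + prev_memory_output_fullyconnected + Add
#   g, i_in, f_in, o_in = split(z, 4, axis=1)     # join_input_split
#   i   = sigmoid(i_in + w_i * c_{t-1})           # remember_sigmoid
#   f   = sigmoid(f_in + w_f * c_{t-1})           # forget_sigmoid
#   c_t = clamp(f * c_{t-1} + i * tanh(g))        # join_forget_clamp -> next_lstm_state (Assign)
#   o   = sigmoid(o_in + w_o * c_t)               # output_sigmoid
#   r_t = W_p (o * tanh(c_t))                     # final FullyConnected -> next_lstm_output (Assign)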
def insert_select(graph: Graph, node: Node):
    context_len = node.frame_time + 1

    if context_len == 1:
        return

    in_node_port = node.in_port(0).get_source()
    in_node_shape = node.in_port(0).data.get_shape()
    node.in_port(0).disconnect()

    # add Select before saving the state to avoid saving garbage
    select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
    zero_else = create_const_with_batch_from_input(in_node_port, in_node_shape[1])
    select_node.in_port(1).connect(in_node_port)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # check if an appropriate iteration counter already exists
    existing_counters = find_pattern_matches(graph,
                                             nodes=[('mem_in', dict(op='ReadValue')),
                                                    ('mem_in_data', dict(shape=int64_array([context_len]))),
                                                    ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                                                         offset=int64_array([1]),
                                                                         dim=int64_array([context_len - 1]))),
                                                    ('crop_mem_in_data', dict()),
                                                    ('concat', dict(op='Concat', axis=1)),
                                                    ('concat_data', dict()),
                                                    ('const_1', dict(op='Const')),
                                                    ('const_1_data', dict()),
                                                    ('mem_out', dict(op='Assign')),
                                                    ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                                                      offset=int64_array([0]),
                                                                      dim=int64_array([1]))),
                                                    ('crop_out_data', dict()),
                                                    ('select', dict(op='Select'))
                                                    ],
                                             edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                                                    ('crop_mem_in', 'crop_mem_in_data'),
                                                    ('crop_mem_in_data', 'concat', {'in': 0}),
                                                    ('const_1', 'const_1_data'),
                                                    ('const_1_data', 'concat', {'in': 1}),
                                                    ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                                                    ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                                                    ('crop_out_data', 'select')])
    counter_match = next(existing_counters, None)
    if counter_match is not None:
        ones = Node(graph, inverse_dict(counter_match)['const_1'])
        input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
    else:
        init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
        mem_out = ReadValue(graph, {'name': 'iteration_number',
                                    'variable_id': 'iteration_' + node.name}).create_node()
        mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
        cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                 'offset': int64_array([1]),
                                 'dim': int64_array([context_len - 1])}).create_node()
        cut_first.in_port(0).connect(mem_out.out_port(0))
        ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
        concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))
        mem_in = Assign(graph, {'name': 'iteration_number_out',
                                'variable_id': 'iteration_' + node.name}).create_node()
        mem_in.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        mem_in.out_port(0).connect(res.in_port(0))
        cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                'offset': int64_array([0]),
                                'dim': int64_array([1])}).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        input_port = cut_last.out_port(0)

    # Check whether the value read from memory is 1:
    # if True, the context has been gathered and the data should be saved to memory;
    # otherwise the context is not gathered yet, the data is garbage and the initial state
    # of the memory should not be changed.
    cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
    cast_in.in_port(0).connect(ones.out_port(0))
    cast_in.in_port(1).connect(input_port)
    select_node.in_port(0).connect(cast_in.out_port(0))
    select_node.out_port(0).connect(node.in_port(0))
    select_node.out_port(0).data.set_shape(in_node_shape)
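# A standalone numpy sketch (illustration only, not the MO API) of the iteration counter that insert_select
# builds: the ReadValue/Assign pair keeps a zero-initialized buffer of length context_len that is shifted
# left and padded with 1 on every step; the element cropped at offset 0 becomes 1 only after context_len
# iterations, and that flag drives the Select between the real data and the zero constant.
import numpy as np

batch, context_len = 1, 3                                    # hypothetical sizes
counter = np.zeros((batch, context_len), dtype=np.int32)     # initial ReadValue content
for step in range(5):
    # Crop(offset=1, dim=context_len - 1) + Concat with ones -> value stored by Assign
    counter = np.concatenate([counter[:, 1:], np.ones((batch, 1), np.int32)], axis=1)
    flag = counter[:, 0:1] == 1                              # Crop(offset=0, dim=1) + Equal with ones
    print(step, bool(flag))                                  # False, False, True, True, True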
def add_output_for_path(graphs_nodes_path):
    # add outputs to the nodes along the path
    step_node = graphs_nodes_path[-1]['node']
    cur_graph = graphs_nodes_path[-1]['graph']

    ports_to_add_nodes = []
    for o_p in step_node.out_ports():
        ports_to_add_nodes.append(o_p)

    # update internal_layer_id for the new Results
    for i in range(len(graphs_nodes_path) - 1, 0, -1):
        cur_max_layer_id = max_internal_layer_id(cur_graph) + 1
        cur_loop_node = graphs_nodes_path[i - 1]['node']
        new_out_ports = []
        if cur_loop_node.op != 'If':
            # add Unsqueeze and Result for TensorIterator and Loop and update output_port_map
            for p_num in ports_to_add_nodes:
                res_node, graphs_nodes_path, cur_max_layer_id = add_output_in_body(step_node, p_num, cur_graph,
                                                                                   cur_max_layer_id,
                                                                                   graphs_nodes_path, i)

                # the IR reader fixes the output port map for Loop but does not change it for TensorIterator
                new_port_id = len(cur_loop_node.out_ports())
                if cur_loop_node.op == 'TensorIterator':
                    new_port_id = new_port_id + len(cur_loop_node.in_ports())
                cur_loop_node.output_port_map.append({'axis': 0, 'stride': 1, 'part_size': 1, 'start': 0,
                                                      'end': -1, 'external_port_id': new_port_id,
                                                      'internal_layer_id': res_node['internal_layer_id']})

                port_id = new_port_id
                if cur_loop_node.op == 'TensorIterator':
                    port_id = port_id - len(cur_loop_node.in_ports())

                new_out_ports.append(port_id)
                cur_loop_node.add_output_port(port_id)
        else:
            # add Result nodes for If and update output_id
            for p_num in ports_to_add_nodes:
                res_node, graphs_nodes_path, cur_max_layer_id = add_output_in_body(step_node, p_num, cur_graph,
                                                                                   cur_max_layer_id,
                                                                                   graphs_nodes_path, i,
                                                                                   add_unsqueeze=False)

                if cur_loop_node.then_graph == cur_graph:
                    new_port_id = len(cur_loop_node.out_ports())
                    res_node['output_id'] = new_port_id
                    cur_loop_node.add_output_port(new_port_id)
                    new_out_ports.append(new_port_id)
                else:
                    res_node['output_id'] = list(cur_loop_node.out_ports().keys())[-1]
        ports_to_add_nodes = new_out_ports
        step_node = cur_loop_node
        cur_graph = graphs_nodes_path[i - 1]['graph']

    i = 0
    for p_num in ports_to_add_nodes:
        port = step_node.out_port(p_num)
        out_name = step_node.soft_get('name', step_node.id) + "." + str(p_num)
        res_node = Result(cur_graph, {'name': out_name}).create_node()
        port.connect(res_node.in_port(0))
        # add the name of the Result to fw_tensor_debug_info to avoid renaming
        if step_node.out_nodes()[p_num].has_and_set('fw_tensor_debug_info'):
            step_node.out_nodes()[p_num]['fw_tensor_debug_info'].append(out_name)
        else:
            step_node.out_nodes()[p_num]['fw_tensor_debug_info'] = [[out_name, out_name]]
        if step_node.op == 'TensorIterator':
            step_node.out_edges()[len(step_node.out_edges()) - 1]['external_port_id'] = p_num + \
                                                                                        len(step_node.in_ports())
        graphs_nodes_path.insert(0, {'node': res_node, 'graph': cur_graph})
        i += 1
    return graphs_nodes_path
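# A small numpy illustration (not the MO API) of why add_output_in_body inserts Unsqueeze([0]) before each
# new body Result: the output_port_map entry appended above uses axis=0 / stride=1, so the per-iteration
# [1, ...] slices are expected to stack back along axis 0 into a [num_iterations, ...] external output.
# The tensors and their shapes below are hypothetical.
import numpy as np

per_iteration = [np.full((2, 3), t, dtype=np.float32) for t in range(4)]          # hypothetical body outputs
stacked = np.concatenate([np.expand_dims(t, 0) for t in per_iteration], axis=0)   # Unsqueeze([0]) + axis-0 concat
print(stacked.shape)                                                              # (4, 2, 3)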
def replace_pattern(graph: Graph, match: dict):
    node = match['op']
    in_shape = node.in_port(0).data.get_shape().copy()
    memory_element = in_shape[1] - node.const_dim
    memory_size = memory_element * len(node.context)

    memory_pair_id = unique_id('id')
    # Memory(in)
    input_memory = ReadValue(graph, {'name': 'prev_splice_memory',
                                     'variable_id': memory_pair_id}).create_node()

    # Memory(in)  \
    #              Crop
    # Input(temp) /
    crop = Crop(graph, {'name': 'Splice_Crop',
                        'axis': int64_array([1]),
                        'offset': int64_array([memory_element]),
                        'dim': int64_array([memory_size - memory_element])}).create_node()
    crop.in_port(0).connect(input_memory.out_port(0))

    # Crop  \
    #        Concat
    # Input /
    concat_node = Concat(graph, {'name': 'Splice_Concat',
                                 'in_ports_count': 2,
                                 'axis': 1}).create_node()
    concat_node.in_port(0).connect(crop.out_port(0))

    # Concat -> Memory(out)
    mem_out = Assign(graph, {'name': 'out_splice_memory',
                             'variable_id': memory_pair_id}).create_node()
    mem_out.in_port(0).connect(concat_node.out_port(0))
    Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))

    if node.const_dim != 0:
        memory_element_constdim = node.const_dim
        memory_size_constdim = memory_element_constdim * len(node.context)

        split = create_op_with_const_inputs(
            graph, VariadicSplit, {1: int64_array(1),
                                   2: int64_array([memory_element, memory_element_constdim])},
            {'name': node.id + '_split_const', 'out_ports_count': 2})

        split.out_port(0).connect(concat_node.in_port(1))

        # create a separate splice construction for const_dim
        memory_pair_id = unique_id('memory_for_const_dim')
        init_value_input_memory_const_dim = Const(
            graph, {'name': 'init_value_const_dim_in_memory',
                    'value': np.zeros(int64_array([in_shape[0], memory_size_constdim]), dtype=np.float32),
                    'shape': int64_array([in_shape[0], memory_size_constdim])}).create_node()
        input_memory_const_dim = ReadValue(graph, {'name': 'const_dim_in_memory',
                                                   'variable_id': memory_pair_id}).create_node()
        init_value_input_memory_const_dim.out_port(0).connect(input_memory_const_dim.in_port(0))

        crop_const_dim = Crop(graph, {'name': 'const_dim_crop',
                                      'axis': int64_array([1]),
                                      'offset': int64_array([memory_element_constdim]),
                                      'dim': int64_array([memory_size_constdim -
                                                          memory_element_constdim])}).create_node()
        crop_const_dim.in_port(0).connect(input_memory_const_dim.out_port(0))

        concat_node_const_dim = Concat(graph, {'name': 'const_dim_concat',
                                               'in_ports_count': 2,
                                               'axis': 1}).create_node()
        concat_node_const_dim.in_port(0).connect(crop_const_dim.out_port(0))

        mem_out_const_dim = Assign(graph, {'name': 'const_dim_out_memory',
                                           'variable_id': memory_pair_id}).create_node()
        mem_out_const_dim.in_port(0).connect(concat_node_const_dim.out_port(0))
        Result(graph).create_node().in_port(0).connect(mem_out_const_dim.out_port(0))

        # connect the splice to the Split at the beginning and to the Concat at the end
        split.out_port(1).connect(concat_node_const_dim.in_port(1))

        crop_first = Crop(graph, {'name': 'const_dim_crop_first',
                                  'axis': int64_array([1]),
                                  'offset': int64_array([0]),
                                  'dim': int64_array([memory_element_constdim])}).create_node()
        crop_first.in_port(0).connect(concat_node_const_dim.out_port(0))

        concat_const = Concat(graph, {'name': node.id + '_concat_const',
                                      'axis': 1,
                                      'in_ports_count': 2}).create_node()
        concat_const.in_port(1).connect(crop_first.out_port(0))
        concat_const.in_port(0).connect(concat_node.out_port(0))

        init_value_input_memory = Const(
            graph, {'name': 'init_value_' + node.name,
                    'value': np.zeros(int64_array([in_shape[0], memory_size]), dtype=np.float32),
                    'shape': int64_array([in_shape[0], memory_size])}).create_node()
        init_value_input_memory.out_port(0).connect(input_memory.in_port(0))

        node.in_port(0).get_connection().set_destination(split.in_port(0))
        node.out_port(0).get_connection().set_source(concat_const.out_port(0))
    else:
        init_value_input_memory = Const(
            graph, {'name': 'init_value_' + node.name,
                    'value': np.zeros(int64_array([in_shape[0], memory_size]), dtype=np.float32),
                    'shape': int64_array([in_shape[0], memory_size])}).create_node()
        init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
        node.in_port(0).get_connection().set_destination(concat_node.in_port(1))
        node.out_port(0).get_connection().set_source(concat_node.out_port(0))

    # remove the original node to avoid re-inference of its shape and touching it in the next replacements
    graph.remove_node(node.id)
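# A minimal numpy model (illustration only, not the MO API) of the ReadValue/Crop/Concat/Assign ring buffer
# that this transformation builds for Splice: each step the Crop drops the oldest memory_element columns and
# the Concat appends the newest frame on axis 1, so the memory always holds the last len(context) frames.
# The sizes and the context below are hypothetical.
import numpy as np

batch, memory_element, context = 1, 4, [-2, -1, 0, 1, 2]
memory_size = memory_element * len(context)
memory = np.zeros((batch, memory_size), dtype=np.float32)   # initial ReadValue content

def splice_step(memory, frame):
    # Crop(offset=memory_element, dim=memory_size - memory_element) followed by Concat(axis=1)
    return np.concatenate([memory[:, memory_element:], frame], axis=1)

frame = np.ones((batch, memory_element), dtype=np.float32)
memory = splice_step(memory, frame)                          # what Assign stores for the next step
print(memory.shape)                                          # (1, 20)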