def add_reshapes_for_tf_subgraph_calls(graph: Graph):
    """
    Make every tensor entering or leaving a TFCustomSubgraphCall node 4D by inserting Reshape
    operations, because the IE layer only accepts and produces 4D tensors.

    :param graph: graph to operate on; modified in place.
    :return: None.
    """
    # Inputs: a data tensor with a known, non-4D shape that feeds a TFCustomSubgraphCall op.
    # The edge list is snapshotted because reshapes are inserted while iterating.
    for src_name, dst_name, attrs in list(graph.edges(data=True)):
        producer = Node(graph, src_name)
        consumer = Node(graph, dst_name)
        is_subgraph_call = (consumer.kind == 'op' and consumer.has_valid('type')
                            and consumer.type == 'TFCustomSubgraphCall')
        if not is_subgraph_call:
            continue
        if not producer.has_valid('shape') or len(producer.shape) == 4:
            continue
        log.info("There is an data tensor of shape '{}' which goes into '{}' node".format(
            producer.shape, consumer.type))
        CustomSubgraphCall.add_reshape_before_op_node(graph, src_name, dst_name, attrs)

    # Outputs: reshape each non-4D result and store the 4D shape on the data node itself so that
    # IE generates an XML with 4D tensors.
    for call_node in graph.get_op_nodes(op='TFCustomSubgraphCall'):
        for data_node in call_node.out_nodes().values():
            dims = len(data_node.shape)
            if dims == 4:
                continue
            log.info("There is an data tensor of shape '{}' with real dims count '{}' which goes out of '{}' "
                     "node".format(data_node.shape, dims, call_node.name))
            CustomSubgraphCall.add_reshape_after_data_node(graph, data_node.id)
            data_node['shape'] = CustomSubgraphCall.make_shape_4d(data_node['shape'])
def find_and_replace_pattern(self, graph: Graph):
    """
    Register input/output ports on every node from the graph edges and switch the graph to the
    'front' stage.

    ``_in_ports`` / ``_out_ports`` map a port id to ``{'control_flow': bool}``. Ports of control-flow
    edges are keyed as ``'control_flow_<index>'`` so they do not clash with data ports of the same index.

    :param graph: graph to operate on; modified in place.
    :return: None
    """
    for _, attrs in graph.nodes(data=True):
        # Dicts, not sets: set.update() with a dict keeps only its keys and would silently drop the
        # {'control_flow': ...} descriptor stored for each port below.
        if '_in_ports' not in attrs:
            attrs['_in_ports'] = {}
        if '_out_ports' not in attrs:
            attrs['_out_ports'] = {}

    for u, v, _, d in graph.edges(data=True, keys=True):
        from_node_attrs = graph.node[u]
        to_node_attrs = graph.node[v]
        is_control_flow = 'control_flow_edge' in d and d['control_flow_edge'] is True
        in_port_id = d['in'] if not is_control_flow else 'control_flow_' + str(d['in'])
        out_port_id = d['out'] if not is_control_flow else 'control_flow_' + str(d['out'])
        to_node_attrs['_in_ports'].update({in_port_id: {'control_flow': is_control_flow}})
        from_node_attrs['_out_ports'].update({out_port_id: {'control_flow': is_control_flow}})

    graph.stage = 'front'
def restore_correct_ports(graph: Graph):
    """
    Renumber edge output ports from IE to MO numbering and register ports on all nodes in the graph.

    Every node gets ``_in_ports`` / ``_out_ports`` dictionaries (created empty when absent) that map a
    port id to ``{'control_flow': bool}``. Ports of control-flow edges are keyed as
    ``'control_flow_<index>'``.

    The ``'out'`` attribute of every edge is rewritten in place. The number of input nodes of the
    producer is subtracted from it, because IE numbering is assumed to place output ports after input
    ports. The exception is output port 0 of Const nodes (see the comment in the loop).

    :param graph: graph restored from IR; modified in place.
    :return: None
    """
    for _, attrs in graph.nodes(data=True):
        # Dicts, not sets: set.update() with a dict keeps only its keys and would silently drop the
        # {'control_flow': ...} descriptor stored for each port below.
        if '_in_ports' not in attrs:
            attrs['_in_ports'] = {}
        if '_out_ports' not in attrs:
            attrs['_out_ports'] = {}

    for u, v, _, d in graph.edges(data=True, keys=True):
        from_node_attrs = graph.node[u]
        to_node_attrs = graph.node[v]
        is_control_flow = 'control_flow_edge' in d and d['control_flow_edge'] is True

        if 'in' in d:
            in_port_id = d['in'] if not is_control_flow else 'control_flow_' + str(d['in'])
            to_node_attrs['_in_ports'].update({in_port_id: {'control_flow': is_control_flow}})

        if 'out' in d:
            node = Node(graph, u)
            num_of_in_nodes = len(node.in_nodes())
            decremented_number = d['out'] - num_of_in_nodes
            # Initially Const operation in IR has output port with number 1. But later the behaviour was changed
            # so the output port become 0. This change was made to be consistent with the IR serializer in the IE
            # which generates Const with output port 0. For the backward compatibility reason we need to decrement
            # the Const output port number but for current version this number shouldn't be changed during
            # reading the IR.
            if node.type == 'Const' and d['out'] == 0:
                decremented_number = d['out']
            out_port_id = decremented_number if not is_control_flow else 'control_flow_' + str(decremented_number)
            from_node_attrs['_out_ports'].update({out_port_id: {'control_flow': is_control_flow}})
            d['out'] = decremented_number
def update_custom_replacement_attributes(self, graph: Graph):
    """
    Fill in the 'inputs' and 'outputs' of a points-based custom replacement description when they
    are not defined.

    The matched sub-graph consists of the nodes between the configured start and end points;
    control-flow edges are not followed. Its inputs are the tensors entering the sub-graph from
    outside. Its outputs are the tensors leaving it, plus matched non-Const nodes that have no
    consumers. Entries already present in the description are left untouched.

    :param graph: graph to search the sub-graph in.
    :raises Error: if 'instances' is missing or is not a single dictionary.
    """
    if not self.has('instances'):
        raise Error("No instance(s) is(are) defined for the custom replacement '{}'. ".format(self.replacement_id) +
                    refer_to_faq_msg(66))
    if not isinstance(self.instances, dict):
        raise Error("The instance must be a single dictionary for the custom replacement with id '{}'. ".format(
            self.replacement_id) + refer_to_faq_msg(67))

    start_points = self.get_internal_input_nodes(graph)
    end_points = self.get_internal_output_nodes(graph)

    # Convert to a set once so the per-edge membership tests below are O(1).
    matched_nodes = set(sub_graph_between_nodes(graph, start_points, end_points, include_control_flow=False))

    output_tensors = set()  # pairs: (producer node name regex, output port)
    # key is the producer tensor name "<node>:<out_port>",
    # value is a list of pairs: (consumer node name regex, input port)
    input_nodes_mapping = dict()
    for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
        src_inside = src_node_name in matched_nodes
        dst_inside = dst_node_name in matched_nodes
        if not src_inside and dst_inside:
            # edge from outside the sub-graph into it
            tensor_name = src_node_name + ":" + str(edge_attrs['out'])
            input_nodes_mapping.setdefault(tensor_name, []).append(('^' + dst_node_name + '$', edge_attrs['in']))
        elif src_inside and not dst_inside:
            # Edge from inside the sub-graph to outside: the producer is the sub-graph output.
            # Use the producer's own node name. The consumer's TF NodeDef input string is "name:port"
            # for non-zero ports and "^name" for control inputs, so a regex built from it would not
            # match any node.
            output_tensors.add(('^' + src_node_name + '$', edge_attrs['out']))

    for node_name in graph.nodes():
        node = Node(graph, node_name)
        if node_name in matched_nodes and len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
            log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
            output_tensors.add(('^' + node_name + '$', 0))

    if not self.has('inputs'):
        self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
                                            for inp in sorted(input_nodes_mapping.values())]
        log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

    if not self.has('outputs'):
        self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
        log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))
def update_custom_replacement_attributes(self, graph: Graph):
    """
    Fill in the 'inputs' and 'outputs' of a scope-based custom replacement description when they are
    missing or empty.

    The sub-graph is the set of nodes matching the first instance pattern. Node references are stored
    as patterns produced by ``generate_pattern_for_node`` relative to that instance pattern.

    :param graph: graph to search the sub-graph in.
    :raises Error: if no instances are defined.
    """
    if not self.has('instances') or len(self.instances) == 0:
        raise Error("No instances are defined for replacement with id '{}'. ".format(self.replacement_id) +
                    refer_to_faq_msg(68))

    pattern = self.instances[0]  # use the first instance pattern to find input/output nodes patterns
    # TODO verify that all instances will produce the same sub-graph
    # Convert to a set once so the per-edge membership tests below are O(1).
    matched_nodes = set(nodes_matching_name_pattern(graph, pattern))

    output_tensors = set()  # pairs: (producer node pattern, output port)
    # key is the producer tensor name "<node>:<out_port>",
    # value is a list of pairs: (consumer node pattern, input port)
    input_nodes_mapping = dict()
    for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
        src_inside = src_node_name in matched_nodes
        dst_inside = dst_node_name in matched_nodes
        if not src_inside and dst_inside:
            # edge from outside the sub-graph into it
            tensor_name = src_node_name + ":" + str(edge_attrs['out'])
            input_nodes_mapping.setdefault(tensor_name, []).append(
                (generate_pattern_for_node(graph, pattern, dst_node_name), edge_attrs['in']))
        elif src_inside and not dst_inside:
            # Edge from inside the sub-graph to outside: the producer is the sub-graph output.
            # Use the producer's node name, which is known to match the pattern. The consumer's TF
            # NodeDef input string is "name:port" for non-zero ports and "^name" for control inputs.
            output_tensors.add((generate_pattern_for_node(graph, pattern, src_node_name), edge_attrs['out']))

    for node_name in graph.nodes():
        node = Node(graph, node_name)
        if node_name in matched_nodes and len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
            log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
            output_tensors.add((generate_pattern_for_node(graph, pattern, node_name), 0))

    if not self.has('inputs') or len(self._replacement_desc['inputs']) == 0:
        self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
                                            for inp in sorted(input_nodes_mapping.values())]
        log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

    if not self.has('outputs') or len(self._replacement_desc['outputs']) == 0:
        self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
        log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))
def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
    """
    Read a Kaldi ParallelComponent and splice its nested networks into ``graph``.

    A VariadicSplit node cuts the component input into one chunk per nested network. Each nested
    network is loaded into its own sub-graph and merged into ``graph``. Their outputs are joined by a
    Concat node.

    :param file_descr: descriptor of the model file
    :param graph: graph with the topology
    :param prev_layer_id: id of the input layer of the parallel component
    :return: id of the Concat node - the last node of the parallel component
    """
    nested_count = read_token_value(file_descr, b'<NestedNnetCount>')
    log.debug('Model contains parallel component with {} nested networks'.format(nested_count))

    split_sizes = []
    nested_outputs = []
    nested_inputs = []
    for net_idx in range(nested_count):
        read_token_value(file_descr, b'<NestedNnet>')
        collect_until_token(file_descr, b'<Nnet>')
        sub_graph = Graph()
        load_kalid_nnet1_model(sub_graph, file_descr, 'Nested_net_{}'.format(net_idx))

        # The nested model input is rank 1 plus a batch axis inserted at axis 0, so axis 1 holds the
        # input size of this sub-network. That size becomes its chunk of the split. The Parameter node
        # is dropped because the split output feeds the sub-network instead.
        parameter = Node(sub_graph, 'Parameter')
        split_sizes.append(parameter['shape'][1])
        sub_graph.remove_node(parameter.id)

        # Give fresh ids to sub-graph nodes whose ids are already taken in the main graph.
        renames = {}
        for node_id in sub_graph.nodes(data=False):
            if node_id in graph:
                renames[node_id] = graph.unique_id(node_id)
        sub_graph = nx.relabel_nodes(sub_graph, renames)
        for new_id in renames.values():
            sub_graph.node[new_id]['name'] = new_id

        graph.add_nodes_from(sub_graph.nodes(data=True))
        graph.add_edges_from(sub_graph.edges(data=True))

        # First/last nodes in topological order are taken as the sub-network entry/exit.
        ordered = tuple(nx.topological_sort(sub_graph))
        nested_outputs.append(Node(graph, ordered[-1]))
        nested_inputs.append(Node(graph, ordered[0]))

    split_id = graph.unique_id(prefix='NestedNets/VariadicSplit')
    split_node = AttributedVariadicSplit(graph, {
        'out_ports_count': nested_count,
        'size_splits': split_sizes,
        'axis': 1,
        'name': split_id,
    }).create_node()

    prev_node = Node(graph, prev_layer_id)
    prev_node.add_output_port(0)
    graph.create_edge(prev_node, split_node, 0, 0,
                      create_edge_attrs(prev_layer_id, split_node.id, prev_layer_id))

    concat_id = graph.unique_id(prefix='Concat')
    graph.add_node(concat_id, parameters=None, op='concat', kind='op')
    concat_node = Node(graph, concat_id)

    # Split output k feeds nested network k; the output of nested network k feeds Concat input k.
    for port, (net_in, net_out) in enumerate(zip(nested_inputs, nested_outputs)):
        net_out.add_output_port(0)
        concat_node.add_input_port(port)
        graph.create_edge(net_out, concat_node, 0, port,
                          create_edge_attrs(net_out.id, concat_id, net_out.id, port, 0))
        graph.create_edge(split_node, net_in, port, 0,
                          create_edge_attrs(split_node.id, net_in.id, split_node.id, 0, port))

    return concat_id
def find_and_replace_pattern(self, graph: Graph):
    """
    Remove every edge whose ``control_flow_edge`` attribute is truthy.

    :param graph: graph to operate on; modified in place.
    :return: None
    """
    # Snapshot the matching edges first so the graph is not mutated while its edge view is iterated.
    control_edges = [(src, dst, key)
                     for src, dst, key, edge_data in graph.edges(keys=True, data=True)
                     if edge_data.get('control_flow_edge')]
    for src, dst, key in control_edges:
        graph.remove_edge(src, dst, key)
        log.debug('Removing control flow edge from {} to {}'.format(src, dst))