Exemplo n.º 1
0
    def add_reshapes_for_tf_subgraph_calls(graph: Graph):
        """
        Make sure every tensor entering or leaving a TFCustomSubgraphCall node is 4D.

        The IE layer for TFCustomSubgraphCall accepts and produces only 4D tensors, so this function:
          * inserts a reshape in front of the call for every incoming data tensor whose shape is known
            and is not 4D;
          * inserts a reshape after every non-4D output data node of the call and rewrites that data
            node's 'shape' to its 4D form, so that IE generates XML with 4D tensors.

        :param graph: graph to operate on; it is modified in place.
        :return: None.
        """
        # Iterate over a snapshot of the edges: inserting reshapes rewires the graph.
        for producer_name, consumer_name, attrs in list(graph.edges(data=True)):
            producer = Node(graph, producer_name)
            consumer = Node(graph, consumer_name)

            is_subgraph_call = consumer.kind == 'op' and consumer.has_valid('type') and \
                consumer.type == 'TFCustomSubgraphCall'
            if not is_subgraph_call:
                continue
            # Tensors without a valid shape, or that are already 4D, need no reshape.
            if not producer.has_valid('shape') or len(producer.shape) == 4:
                continue

            log.info("There is an data tensor of shape '{}' which goes into '{}' node".format(
                producer.shape, consumer.type))
            CustomSubgraphCall.add_reshape_before_op_node(graph, producer_name, consumer_name, attrs)

        for call_node in graph.get_op_nodes(op='TFCustomSubgraphCall'):
            for _, out_data in call_node.out_nodes().items():
                rank = len(out_data.shape)
                if rank == 4:
                    continue

                log.info("There is an data tensor of shape '{}' with real dims count '{}' which goes out of '{}' "
                         "node".format(out_data.shape, rank, call_node.name))
                CustomSubgraphCall.add_reshape_after_data_node(graph, out_data.id)

                # The call's own output data node must be described as 4D too, so IE emits 4D tensors in the XML.
                out_data['shape'] = CustomSubgraphCall.make_shape_4d(out_data['shape'])
Exemplo n.º 2
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Build the '_in_ports' / '_out_ports' attributes of every node from the graph edges, then switch the
        graph to the 'front' stage.

        Nodes without port attributes get empty ones; existing port attributes are kept and extended.
        For every edge, the destination node receives the edge's 'in' port and the source node its 'out'
        port, each mapped to {'control_flow': <bool>}. Ports of edges whose 'control_flow_edge' attribute
        is True are registered under 'control_flow_<port>' ids instead of the plain port number.

        :param graph: graph to operate on; it is modified in place.
        :raises KeyError: if an edge lacks its 'in' or 'out' attribute.
        """
        for _, node_attrs in graph.nodes(data=True):
            # Port containers must be dicts: each port id is mapped to its metadata below, and updating a
            # set with a dict would silently keep only the ids and drop the {'control_flow': ...} values.
            if '_in_ports' not in node_attrs:
                node_attrs['_in_ports'] = {}
            if '_out_ports' not in node_attrs:
                node_attrs['_out_ports'] = {}

        for u, v, _, edge_attrs in graph.edges(data=True, keys=True):
            from_node_attrs = graph.node[u]
            to_node_attrs = graph.node[v]
            is_control_flow = edge_attrs.get('control_flow_edge') is True

            # Control-flow ports get a distinct id so they cannot clash with data ports of the same number.
            in_port_id = 'control_flow_' + str(edge_attrs['in']) if is_control_flow else edge_attrs['in']
            out_port_id = 'control_flow_' + str(edge_attrs['out']) if is_control_flow else edge_attrs['out']

            to_node_attrs['_in_ports'].update({in_port_id: {'control_flow': is_control_flow}})
            from_node_attrs['_out_ports'].update({out_port_id: {'control_flow': is_control_flow}})

        graph.stage = 'front'
Exemplo n.º 3
0
def restore_correct_ports(graph: Graph):
    """
    Renumber edge ports from IE to MO port numbering and add port attributes to all nodes in the graph.

    Every node gets '_in_ports' / '_out_ports' dicts (existing attributes are kept and extended) that map
    port ids to {'control_flow': <bool>}; ports of edges whose 'control_flow_edge' attribute is True use
    'control_flow_<port>' ids. Edges without an 'in' / 'out' attribute are skipped for that side.
    Each edge's 'out' attribute is rewritten in place: the number of input nodes of the source node is
    subtracted from it, except for Const nodes whose output port is already 0.

    :param graph: graph to operate on; it is modified in place.
    :return: None.
    """
    for _, node_attrs in graph.nodes(data=True):
        # Port containers must be dicts: each port id is mapped to its metadata below, and updating a set
        # with a dict would silently keep only the ids and drop the {'control_flow': ...} values.
        if '_in_ports' not in node_attrs:
            node_attrs['_in_ports'] = {}
        if '_out_ports' not in node_attrs:
            node_attrs['_out_ports'] = {}

    for u, v, _, edge_attrs in graph.edges(data=True, keys=True):
        from_node_attrs = graph.node[u]
        to_node_attrs = graph.node[v]
        is_control_flow = edge_attrs.get('control_flow_edge') is True

        if 'in' in edge_attrs:
            in_port_id = 'control_flow_' + str(edge_attrs['in']) if is_control_flow else edge_attrs['in']
            to_node_attrs['_in_ports'].update({in_port_id: {'control_flow': is_control_flow}})

        if 'out' in edge_attrs:
            node = Node(graph, u)
            # NOTE(review): presumably IE numbers output ports after the input ports, so subtracting the
            # number of input nodes yields MO's output port number — confirm against the IR spec.
            decremented_number = edge_attrs['out'] - len(node.in_nodes())
            # Initially Const operation in IR has output port with number 1. But later the behaviour was changed
            # so the output port become 0. This change was made to be consistent with the IR serializer in the IE which
            # generates Const with output port 0. For the backward compatibility reason we need to decrement the Const
            # output port number but for current version this number shouldn't be changed during reading the IR.
            if node.type == 'Const' and edge_attrs['out'] == 0:
                decremented_number = edge_attrs['out']
            out_port_id = 'control_flow_' + str(decremented_number) if is_control_flow else decremented_number
            from_node_attrs['_out_ports'].update({out_port_id: {'control_flow': is_control_flow}})
            edge_attrs['out'] = decremented_number
Exemplo n.º 4
0
    def update_custom_replacement_attributes(self, graph: Graph):
        """
        Derive the 'inputs' and 'outputs' of this replacement description from the matched sub-graph.

        The sub-graph consists of the nodes between the internal input and internal output nodes of the
        single instance; control flow edges are not followed. Each edge entering the sub-graph contributes a
        ('^<dst node>$', <in port>) entry, grouped by the producing tensor '<src node>:<out port>'. Each edge
        leaving the sub-graph contributes an output ('^<consumer's 'pb' input entry for that port>$', <out port>).
        Matched nodes without consumers are outputs too, unless their 'pb' op is 'Const'. 'inputs' / 'outputs'
        already present in the description are left unchanged.

        :param graph: graph the replacement is matched against.
        :raises Error: if no instance is defined or the instance is not a single dictionary.
        """
        if not self.has('instances'):
            raise Error("No instance(s) is(are) defined for the custom replacement '{}'. ".format(self.replacement_id) +
                        refer_to_faq_msg(66))
        if not isinstance(self.instances, dict):
            raise Error("The instance must be a single dictionary for the custom replacement with id '{}'. ".format(
                self.replacement_id) +
                        refer_to_faq_msg(67))

        start_points = self.get_internal_input_nodes(graph)
        end_points = self.get_internal_output_nodes(graph)

        # Kept as a set: it is probed for every edge and node below, and list probes would make this O(E * M).
        matched_nodes = set(sub_graph_between_nodes(graph, start_points, end_points, include_control_flow=False))
        output_tensors = set()
        # key is the input tensor name '<src node>:<out port>', value is the list of
        # (consumer node pattern, input port) pairs inside the sub-graph that read this tensor
        input_nodes_mapping = dict()
        for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
            src_inside = src_node_name in matched_nodes
            dst_inside = dst_node_name in matched_nodes

            # edge outside sub-graph into sub-graph
            if not src_inside and dst_inside:
                tensor_name = src_node_name + ":" + str(edge_attrs['out'])
                input_nodes_mapping.setdefault(tensor_name, []).append(
                    ('^' + dst_node_name + '$', edge_attrs['in']))

            # edge from inside sub-graph to outside sub-graph
            if src_inside and not dst_inside:
                dst_node = graph.node[dst_node_name]
                output_tensors.add(('^' + dst_node['pb'].input[edge_attrs['in']] + '$', edge_attrs['out']))

        for node_name in graph.nodes():
            if node_name not in matched_nodes:
                continue
            node = Node(graph, node_name)
            if len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
                log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
                output_tensors.add(('^' + node_name + '$', 0))

        if not self.has('inputs'):
            self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
                                                for inp in sorted(input_nodes_mapping.values())]
            log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

        if not self.has('outputs'):
            self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
            log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))
Exemplo n.º 5
0
    def update_custom_replacement_attributes(self, graph: Graph):
        """
        Derive the 'inputs' and 'outputs' of this replacement description from the matched sub-graph.

        The sub-graph is made of the nodes matching the first instance pattern. Node patterns are produced by
        generate_pattern_for_node relative to that instance pattern. Each edge entering the sub-graph
        contributes a (<dst node pattern>, <in port>) entry, grouped by the producing tensor
        '<src node>:<out port>'. Each edge leaving it contributes an output (<pattern of the consumer's 'pb'
        input entry for that port>, <out port>). Matched nodes without consumers are outputs too, unless their
        'pb' op is 'Const'. 'inputs' / 'outputs' are only filled when missing or empty in the description.

        :param graph: graph the replacement is matched against.
        :raises Error: if no instances are defined.
        """
        if not self.has('instances') or len(self.instances) == 0:
            raise Error("No instances are defined for replacement with id '{}'. ".format(self.replacement_id) +
                        refer_to_faq_msg(68))

        pattern = self.instances[0]  # use the first instance pattern to find input/output nodes patterns
        # TODO verify that all instances will produce the same sub-graph
        # Kept as a set: it is probed for every edge and node below, and list probes would make this O(E * M).
        matched_nodes = set(nodes_matching_name_pattern(graph, pattern))

        output_tensors = set()
        # key is the input tensor name '<src node>:<out port>', value is the list of
        # (consumer node pattern, input port) pairs inside the sub-graph that read this tensor
        input_nodes_mapping = dict()
        for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
            src_inside = src_node_name in matched_nodes
            dst_inside = dst_node_name in matched_nodes

            # edge outside sub-graph into sub-graph
            if not src_inside and dst_inside:
                tensor_name = src_node_name + ":" + str(edge_attrs['out'])
                input_nodes_mapping.setdefault(tensor_name, []).append(
                    (generate_pattern_for_node(graph, pattern, dst_node_name), edge_attrs['in']))

            # edge from inside sub-graph to outside sub-graph
            if src_inside and not dst_inside:
                dst_node = graph.node[dst_node_name]
                output_tensors.add(
                    (generate_pattern_for_node(graph, pattern, dst_node['pb'].input[edge_attrs['in']]),
                     edge_attrs['out']))

        for node_name in graph.nodes():
            if node_name not in matched_nodes:
                continue
            node = Node(graph, node_name)
            if len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
                log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
                output_tensors.add((generate_pattern_for_node(graph, pattern, node_name), 0))

        if not self.has('inputs') or len(self._replacement_desc['inputs']) == 0:
            self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
                                                for inp in sorted(input_nodes_mapping.values())]
            log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

        if not self.has('outputs') or len(self._replacement_desc['outputs']) == 0:
            self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
            log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))
Exemplo n.º 6
0
def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
    """
    Load ParallelComponent of the Kaldi model.
    ParallelComponent contains parallel nested networks.
    VariadicSplit is inserted before nested networks.
    Outputs of nested networks concatenate with layer Concat.

    The function reads the '<NestedNnetCount>' token. Then, for every nested network, it reads the
    '<NestedNnet>' token and the network body that starts at '<Nnet>'. Each nested network is loaded into
    its own graph, its 'Parameter' node is dropped, and its remaining nodes and edges are merged into `graph`.
    The output of `prev_layer_id` is split along axis 1 between the nested networks, and their outputs are
    joined by a new 'concat' node.

    :param file_descr: descriptor of the model file
    :param graph: graph with the topology.
    :param prev_layer_id: id of the input layers for parallel component layer
    :return: id of the concat layer - last layer of the parallel component layers
    """
    nnet_count = read_token_value(file_descr, b'<NestedNnetCount>')
    log.debug(
        'Model contains parallel component with {} nested networks'.format(
            nnet_count))

    split_points = []  # axis-1 size of each nested network's 'Parameter' shape
    outputs = []  # last node (in topological order) of each nested network
    inputs = []  # first node (in topological order) of each nested network

    for i in range(nnet_count):
        read_token_value(file_descr, b'<NestedNnet>')
        collect_until_token(file_descr, b'<Nnet>')
        g = Graph()
        load_kalid_nnet1_model(g, file_descr, 'Nested_net_{}'.format(i))

        # input to nnet1 models is of rank 1 but we also insert batch_size to 0th axis
        # 1st axis contains input_size of the nested subnetwork
        # we split input from the main network to subnetworks
        input_node = Node(g, 'Parameter')
        split_points.append(input_node['shape'][1])
        g.remove_node(input_node.id)

        # Give fresh ids to nested-network nodes whose ids already exist in the main graph, so merging
        # the nested graph into `graph` does not collide with existing nodes.
        mapping = {
            node: graph.unique_id(node)
            for node in g.nodes(data=False) if node in graph
        }
        g = nx.relabel_nodes(g, mapping)
        # Keep the 'name' attribute of renamed nodes in sync with their new ids.
        for val in mapping.values():
            g.node[val]['name'] = val
        graph.add_nodes_from(g.nodes(data=True))
        graph.add_edges_from(g.edges(data=True))
        # NOTE(review): the first/last nodes in topological order are taken as the nested network's single
        # entry/exit; this assumes a chain-like nested network — confirm for branching topologies.
        sorted_nodes = tuple(nx.topological_sort(g))

        outputs.append(Node(graph, sorted_nodes[-1]))
        inputs.append(Node(graph, sorted_nodes[0]))

    # A single VariadicSplit feeds all nested networks: axis 1 of the previous layer's output is cut into
    # pieces with the sizes collected in split_points, one output port per nested network.
    split_id = graph.unique_id(prefix='NestedNets/VariadicSplit')
    attrs = {
        'out_ports_count': nnet_count,
        'size_splits': split_points,
        'axis': 1,
        'name': split_id
    }
    variadic_split_node = AttributedVariadicSplit(graph, attrs).create_node()
    prev_layer_node = Node(graph, prev_layer_id)
    prev_layer_node.add_output_port(0)
    graph.create_edge(
        prev_layer_node, variadic_split_node, 0, 0,
        create_edge_attrs(prev_layer_id, variadic_split_node.id,
                          prev_layer_id))

    # Node that joins the outputs of all nested networks; it becomes the last layer of the component.
    concat_id = graph.unique_id(prefix='Concat')
    graph.add_node(concat_id, parameters=None, op='concat', kind='op')
    concat_node = Node(graph, concat_id)

    # Connect each output of variadic_split_node to each subnetwork's inputs in ParallelComponent
    # and each subnetwork's output to concat_node
    for i, (input_node, output_node) in enumerate(zip(inputs, outputs)):
        output_node.add_output_port(0)
        concat_node.add_input_port(i)
        # nested network i: output port 0 -> concat input port i
        graph.create_edge(
            output_node, concat_node, 0, i,
            create_edge_attrs(output_node.id, concat_id, output_node.id, i, 0))
        # split output port i -> nested network i input port 0
        graph.create_edge(
            variadic_split_node, input_node, i, 0,
            create_edge_attrs(variadic_split_node.id, input_node.id,
                              variadic_split_node.id, 0, i))
    return concat_id
Exemplo n.º 7
0
 def find_and_replace_pattern(self, graph: Graph):
     """
     Remove every control flow edge from the graph.

     An edge counts as control flow when its 'control_flow_edge' attribute is present and truthy.
     Each removal is logged at debug level.

     :param graph: graph to operate on; it is modified in place.
     """
     # Collect first, then remove, so the graph is not mutated while its edge view is being iterated.
     control_flow_edges = [
         (src, dst, key)
         for src, dst, key, edge_attrs in graph.edges(keys=True, data=True)
         if edge_attrs.get('control_flow_edge')
     ]
     for src, dst, key in control_flow_edges:
         graph.remove_edge(src, dst, key)
         log.debug('Removing control flow edge from {} to {}'.format(src, dst))