Example #1
def update_body_graph(body_graph: Graph, subgraph_proto: dict,
                      body_parameter_names: list, body_results: list):
    """
    Updates the loop body graph with a sub-graph (for body or condition functions)
    :param body_graph: a loop body graph to be updated
    :param subgraph_proto: a sub-graph in a protobuf format to be added into the loop body graph
    :param body_parameter_names: an (unchanged) list of parameter names in the loop body graph
    :param body_results: a list of Result nodes that is extended with results from the sub-graph
    """
    # create a map from a node name in the original model to a name in the loop body graph,
    # assuming that names in the original model are unique.
    # Initially, the map contains names of parameters that are common to the body and condition graphs
    map_original_name = {}
    for idx, pb_node in enumerate(subgraph_proto['input_arg']):
        map_original_name[pb_node.name] = body_parameter_names[idx]

    # walk through all nodes (non-parameter and non-result nodes) and add into the loop body graph
    for pb_node in subgraph_proto['node_def']:
        # create an NX node; avoid shadowing the built-in id()
        node_id = body_graph.unique_id(pb_node.name)
        map_original_name[pb_node.name] = node_id
        body_graph.add_node(node_id, pb=pb_node, kind='op')
        if hasattr(body_graph, 'op_names_statistic') and hasattr(pb_node, 'op'):
            body_graph.op_names_statistic[pb_node.op] += 1

        # add incoming edges based on data_nodes_map
        for dst_port, inp in enumerate(pb_node.input):
            orig_src_id = inp.split(":")[0]

            # TODO: avoid this temporary workaround for TF 2.4 or higher RNN layers:
            #  skip control flow dependency
            if orig_src_id[0] == '^':
                continue

            src_id = map_original_name[orig_src_id]
            src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
            assert body_graph.has_node(src_id)

            body_graph.add_edges_from(
                [create_tf_edge(src_id + ":" + str(src_port), node_id, dst_port)])

    # create Result nodes in the loop body graph
    for output in subgraph_proto['output_arg']:
        output_name = subgraph_proto['ret'][output.name]
        orig_src_id = output_name.split(":")[0]
        src_id = map_original_name[orig_src_id]
        src_port = 0 if len(output_name.split(":")) == 1 else int(output_name.split(":")[-1])
        assert body_graph.has_node(src_id), \
            'The body graph does not contain output with name "{}"'.format(src_id)
        body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False)))

    return True
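For context, here is a minimal sketch of how update_body_graph might be called when converting a TF While operation. The variables body_graph, body_proto, cond_proto and the parameter names are hypothetical placeholders assumed to be prepared by the caller; only update_body_graph itself comes from the example above.

# Hypothetical call site (not from the original example): merge the 'body' and
# 'cond' sub-graphs of a While node into a single loop body graph.
body_parameter_names = ['while/Parameter_0', 'while/Parameter_1']  # parameters already added to body_graph
body_results = []

update_body_graph(body_graph, body_proto, body_parameter_names, body_results)  # body function
update_body_graph(body_graph, cond_proto, body_parameter_names, body_results)  # condition function

# body_results now holds the Result nodes for the loop outputs and the loop condition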
Example #2
    def load(self, graph: Graph):
        argv = graph.graph['cmd_params']
        if argv.tensorflow_custom_layer_libraries:
            libraries = argv.tensorflow_custom_layer_libraries.split(',')
            for library in libraries:
                log.info('Loading library "{}" with custom operations'.format(
                    library))
                tf_v1.load_op_library(library)

        graph_def, variables_values = load_tf_graph_def(
            graph_file_name=argv.input_model,
            is_binary=not argv.input_model_is_text,
            checkpoint=argv.input_checkpoint,
            user_output_node_names_list=argv.output,
            model_dir=argv.saved_model_dir,
            meta_graph_file=argv.input_meta_graph,
            saved_model_tags=argv.saved_model_tags)

        try:
            tf_v1.import_graph_def(graph_def, name='')
        except Exception:
            log.warning(
                "TensorFlow post-processing of loaded model was unsuccessful. "
                "This is an optional step that Model Optimizer performs for any input model "
                "but it is not usually required for all models. "
                "It likely means that the original model is ill-formed. "
                "Model Optimizer will continue converting this model.")

        log.debug("Number of nodes in graph_def: {}".format(len(
            graph_def.node)))  # pylint: disable=no-member

        if argv.tensorboard_logdir:
            tensorboard_util.dump_for_tensorboard(graph_def,
                                                  argv.tensorboard_logdir)

        update_extractors_with_extensions(tf_op_extractors)

        try:
            protobuf2nx(graph, graph_def)
        except Exception as e:
            raise Error(
                'Cannot pre-process TensorFlow graph after reading from model file "{}". '
                'File is corrupt or has unsupported format. Details: {}. ' +
                refer_to_faq_msg(44),
                argv.model_name,
                str(e)
            ) from e

        graph.name = argv.model_name
        # changing the 'layout' parameter may cause an issue in the EltwiseInputReshape replacer
        # and in convert_nhwc_to_nchw(graph)
        graph.graph['layout'] = 'NCHW' if argv.disable_nhwc_to_nchw else 'NHWC'
        graph.graph['fw'] = 'tf'

        graph.graph['variables_values'] = variables_values
        del variables_values

        used_tensors = restore_edges(graph, get_tf_edges)

        # Tensor name information for a node is stored on its outgoing edges.
        # Since output nodes have no outgoing edges, fake outputs are required. The code below
        # adds an Identity node for each output and keeps the tensor name of the output
        # on the (output, fake output) edge. After the transformation that adds Result nodes,
        # the fake outputs are deleted from the graph.
        add_outputs_identity(
            graph, graph.nodes - used_tensors,
            lambda g, output, fake_node_name: g.add_edges_from(
                [create_tf_edge(output, fake_node_name, 0)]))

        remove_control_dependency_inputs(graph)

        graph.check_empty_graph(
            'protobuf2nx. It may happen due to problems with loaded model')
        extract_node_attrs(
            graph, lambda node: tf_op_extractor(
                node, check_for_duplicates(tf_op_extractors)))
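As a standalone illustration of the fake-output trick described in the comment before add_outputs_identity, the following sketch uses plain networkx rather than Model Optimizer code; the node names and the tensor_name edge attribute are made up for the example.

# Illustration only: tensor names are stored on outgoing edges, so a terminal node
# gets a fake Identity successor purely to have an edge that can carry its name.
import networkx as nx

g = nx.MultiDiGraph()
g.add_node('conv', kind='op')
g.add_node('softmax', kind='op')
g.add_edge('conv', 'softmax', tensor_name='conv:0')  # name of the tensor produced by 'conv'

# 'softmax' is an output node with no outgoing edge, so its tensor name has nowhere to live.
# Add a fake Identity successor and keep the name on the (output, fake output) edge.
fake_node_name = 'softmax/fake_output'
g.add_node(fake_node_name, kind='op', op='Identity')
g.add_edge('softmax', fake_node_name, tensor_name='softmax:0')  # removed once Result nodes are added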