Exemplo n.º 1
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Apply the ScaleShift conversion passes to ``graph`` and to each of its sub-graphs.

        The sequence is: clean up every (sub-)graph, run ``EltwiseChecker`` on ``graph``,
        then apply ``convert_muladd_to_scaleshift`` followed by
        ``convert_add_or_mul_to_scaleshift`` recursively, cleaning up every (sub-)graph
        after each of the two conversions.
        """
        def clean_up_recursively():
            # Calls clean_up() on the graph itself and on every nested sub-graph
            for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

        clean_up_recursively()

        EltwiseChecker().find_and_replace_pattern(graph)

        for conversion in (convert_muladd_to_scaleshift, convert_add_or_mul_to_scaleshift):
            for_graph_and_each_sub_graph_recursively(graph, conversion)
            clean_up_recursively()
Exemplo n.º 2
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Apply the ScaleShift conversion passes to the top-level ``graph`` only.

        Every (sub-)graph is cleaned up first and ``EltwiseChecker`` is run on ``graph``.
        After that ``convert_muladd_to_scaleshift`` and ``convert_add_or_mul_to_scaleshift``
        are called directly on ``graph``, each followed by ``graph.clean_up()``.
        """
        for_graph_and_each_sub_graph_recursively(graph, lambda sub_graph: sub_graph.clean_up())

        EltwiseChecker().find_and_replace_pattern(graph)

        # Both conversions are deliberately not run for the body of TI:
        # that was turned off due to an error on the TF_Multicell topology.
        for conversion in (convert_muladd_to_scaleshift, convert_add_or_mul_to_scaleshift):
            conversion(graph)
            graph.clean_up()
Exemplo n.º 3
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Set the ``fuse_up_to_quantize_ports`` attribute on eltwise and Slice nodes.

        Mul/Sub/Add nodes that have ``can_be_fused=True`` and a known value on input
        port 0 or 1 get ``[3, 4]``; every Slice node gets ``[0]``.
        """
        # EltwiseChecker marks nodes with the can_be_fused attribute, which prevents
        # fusing of non per-channel linear ops.
        EltwiseChecker().find_and_replace_pattern(graph)

        mul_nodes = graph.get_op_nodes(op='Mul', can_be_fused=True)
        sub_nodes = graph.get_op_nodes(op='Sub', can_be_fused=True)
        add_nodes = graph.get_op_nodes(op='Add', can_be_fused=True)

        for eltwise in mul_nodes + sub_nodes + add_nodes:
            # A value that is not None on either of the first two inputs qualifies the node
            if any(eltwise.in_port(idx).data.get_value() is not None for idx in (0, 1)):
                eltwise['fuse_up_to_quantize_ports'] = [3, 4]

        for slice_node in graph.get_op_nodes(op='Slice'):
            slice_node['fuse_up_to_quantize_ports'] = [0]
 def mark_fusable_muls_on_weights(graph):
     """
     Mark Mul nodes that feed a single Convolution/Deconvolution/MatMul as fusable.

     A Mul node is marked (``can_be_fused = True`` plus ``EltwiseChecker().mark_eltwise_node``)
     only when all of the following hold:
       * output port 0 has exactly one destination and that destination node has
         type 'Convolution', 'Deconvolution' or 'MatMul';
       * ``get_value_in_port(node)`` returns a port (not None);
       * the shape read from that port has exactly one axis whose size is not 1.

     Mul nodes without any destination are skipped instead of raising IndexError.

     :param graph: graph object providing ``get_op_nodes(op=...)``
     """
     for node in graph.get_op_nodes(op='Mul'):
         children = node.out_port(0).get_destinations()
         # Exactly one consumer is required; checking the length first also protects
         # the children[0] access below when the output is not connected at all.
         if len(children) != 1:
             continue
         if children[0].node.soft_get('type') not in ('Convolution', 'Deconvolution', 'MatMul'):
             continue
         value_in_port = get_value_in_port(node)
         if value_in_port is None:
             continue
         value_shape = value_in_port.data.get_shape()
         # Indices of the axes whose size differs from 1
         non_one_axis = np.argwhere(value_shape != 1)
         if non_one_axis.size != 1:
             continue
         non_one_axis = non_one_axis.item(0)
         node['can_be_fused'] = True
         EltwiseChecker().mark_eltwise_node(node, non_one_axis)
Exemplo n.º 5
0
def tf2nx(argv: argparse.Namespace, model_file_name: str, output_model_name: str, output_dir: str, is_binary: bool):
    """
    Convert TF GraphDef object to NetworkX representation and run the conversion pipeline on it.

    The graph produced right after loading is still TF-specific and needs normalization passes
    to be applied. The specific TF structure assumes each GraphDef node is converted to a single
    NetworkX node, node id is an original TF node name, and edges go directly from one op to
    another op.

    The function goes through the steps reported via ``log_step``: LOAD, FRONT, MIDDLE, BACK
    and EMIT. The order of the passes below is significant and must be preserved.

    :param argv: parsed command line parameters; stored in ``graph.graph['cmd_params']`` and used
                 to switch individual passes on and off
    :param model_file_name: model file, passed to ``load_tf_graph_def`` as ``graph_file_name``
    :param output_model_name: name assigned to the graph and passed to ``prepare_emit_ir``
    :param output_dir: output directory passed to ``prepare_emit_ir``
    :param is_binary: passed to ``load_tf_graph_def`` as ``is_binary``
    :return: 0
    :raises Error: when building or pre-processing the graph right after reading the model fails
    """
    log_step(argv.steps, 'LOAD')
    meta_info = get_meta_info(argv)

    if argv.tensorflow_custom_layer_libraries:
        libraries = argv.tensorflow_custom_layer_libraries.split(',')
        for library in libraries:
            log.info('Loading library "{}" with custom operations'.format(library))
            tf.load_op_library(library)

    graph_def, variables_values = load_tf_graph_def(graph_file_name=model_file_name, is_binary=is_binary,
                                                    checkpoint=argv.input_checkpoint,
                                                    user_output_node_names_list=argv.output,
                                                    model_dir=argv.saved_model_dir,
                                                    meta_graph_file=argv.input_meta_graph,
                                                    saved_model_tags=argv.saved_model_tags)

    # Optional, best-effort step: a failure is only reported as a warning. Exception (rather than
    # a bare except) is caught so that KeyboardInterrupt and SystemExit are not swallowed.
    try:
        tf.import_graph_def(graph_def, name='')
    except Exception:
        log.warning("TensorFlow post-processing of loaded model was unsuccessful. "
                    "This is an optional step that Model Optimizer performs for any input model but it is not usually "
                    "required for all models. "
                    "It likely means that the original model is ill-formed. "
                    "Model Optimizer will continue converting this model.")

    log.debug("Number of nodes in graph_def: {}".format(len(graph_def.node)))  # pylint: disable=no-member

    if argv.tensorboard_logdir:
        tensorboard.dump_for_tensorboard(graph_def, argv.tensorboard_logdir)

    update_extractors_with_extensions(tf_op_extractors)

    try:
        graph = protobuf2nx(graph_def)
        graph.name = output_model_name
        # 'layout' parameter change may cause an issue in EltwiseInputReshape replacer
        # and convert_nhwc_to_nchw(graph)
        graph.graph['layout'] = 'NCHW' if argv.disable_nhwc_to_nchw else 'NHWC'
        graph.graph['cmd_params'] = argv
        graph.graph['fw'] = 'tf'

        # IR version: 2 when the deprecated V2 IR is requested, otherwise 10 for the
        # experimental V10 IR and 6 in all other cases
        if graph.graph['cmd_params'].generate_experimental_IR_V10:
            version = 10
        else:
            version = 6
        graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else version

        graph.graph['variables_values'] = variables_values
        # Drop the local reference; the values stay reachable through graph.graph
        del variables_values

        graph = restore_edges(graph, get_tf_edges)
        graph = remove_control_dependency_inputs(graph)
    except Exception as e:
        raise Error(
            'Cannot pre-process TensorFlow graph after reading from model file "{}". ' \
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44),
            model_file_name,
            str(e)
        ) from e

    graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
    extract_node_attrs(graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))

    # --------------------------------- LOAD END ------------------------------------------------------
    log_step(argv.steps, 'FRONT')
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    log_step(argv.steps, 'MIDDLE')
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)
    graph_clean_up_tf(graph)

    for_graph_and_each_sub_graph_recursively(graph, convert_matmul_to_fully_connected)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    for_graph_and_each_sub_graph_recursively(graph, lambda graph: mark_unfused_nodes(graph, argv.finegrain_fusing))

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    # IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift
    convert_batch_norm(graph)
    graph_clean_up_tf(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        for_graph_and_each_sub_graph_recursively(graph, convert_scale_shift_to_mul_add)
        for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

        # Fusing the sequences of Mul/Add operations
        for_graph_and_each_sub_graph_recursively(graph, fuse_mul_add_sequence)
        for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

        # Fusing linear operation to Convolution
        for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops)
        for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    if not argv.disable_gfusing:
        for_graph_and_each_sub_graph_recursively(graph, grouped_convolutions_fusing)
        graph_clean_up_tf(graph)
        if not argv.disable_fusing:
            # Linear ops are fused once more after the grouped convolutions fusing (top-level graph only)
            fuse_linear_ops(graph)
            graph_clean_up_tf(graph)

    for_graph_and_each_sub_graph_recursively(graph, EltwiseChecker().find_and_replace_pattern)

    # Converting Mul->Add to ScaleShift node
    for_graph_and_each_sub_graph_recursively(graph, convert_muladd_to_scaleshift)

    # Need to eliminate dead nodes before doing update_fully_connected_shapes
    # because update_fully_connected_shapes does partial inference and dead
    # nodes will lead to sporadic failures.
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
    for_graph_and_each_sub_graph_recursively(graph, update_fully_connected_shapes)

    for_graph_and_each_sub_graph_recursively(graph, convert_mul_eltwise_to_leaky_relu)
    graph_clean_up_tf(graph)
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    for_graph_and_each_sub_graph_recursively(graph, fuse_pad)
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    for_graph_and_each_sub_graph_recursively(graph, convert_add_or_mul_to_scaleshift)  # scale = 1
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        for_graph_and_each_sub_graph_recursively(graph, move_scaleshift_to_preprocess)
        graph_clean_up_tf(graph)

    FuseReshapesSequence().find_and_replace_pattern(graph)
    RemoveRedundantReshapes().find_and_replace_pattern(graph)

    EltwiseInputNormalize().find_and_replace_pattern(graph)

    if argv.enable_concat_optimization:
        ConcatOptimization().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, InsertLayoutPropagationTranspose().find_and_replace_pattern)

    LayoutChangeForConstantShapePaths().find_and_replace_pattern(graph)
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    # Permutation passes; once they are done the layout attribute is switched to 'NCHW' below
    for_graph_and_each_sub_graph_recursively(graph, apply_nhwc_to_nchw_permutation)
    for_graph_and_each_sub_graph_recursively(graph, merge_nodes_permutations)
    for_graph_and_each_sub_graph_recursively(graph, permute_data_nodes_attrs)
    for_graph_and_each_sub_graph_recursively(graph, permute_op_nodes_attrs)
    for_graph_and_each_sub_graph_recursively(graph, permute_input_data)

    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    graph.graph['layout'] = 'NCHW'

    log_step(argv.steps, 'BACK')
    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
    for_graph_and_each_sub_graph_recursively(graph, CreateConstNodesReplacement().find_and_replace_pattern)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    log_step(argv.steps, 'EMIT')
    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    meta_info=meta_info)

    return 0