Example #1
    def load(self, graph: Graph):
        argv = graph.graph['cmd_params']
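        # Load the MXNet symbol definition and pretrained parameters; loader
        # failures are re-raised as FrameworkError with a pointer to FAQ #53.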
        try:
            model_nodes, model_params, model_name, iteration_number = load_symbol_def(
                argv.input_model, argv.input_symbol, argv.input,
                argv.nd_prefix_name, argv.pretrained_model_name,
                argv.legacy_mxnet_model)
        except (ValueError, mxnet.base.MXNetError) as e:
            raise FrameworkError(
                'The following error happened while loading mxnet model {}: {}. '
                + refer_to_faq_msg(53), argv.input_model, str(e)) from e

        if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
            save_params_file(model_name, model_params._arg_params,
                             model_params._aux_params, iteration_number)

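        # Register extractor extensions, then convert the parsed MXNet symbol
        # into the Model Optimizer's internal networkx-based graph.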
        update_extractors_with_extensions(mxnet_op_extractors)
        symbol2nx(graph, model_nodes, model_params, argv.input)
        graph.check_empty_graph(
            'symbol2nx. It may happen due to problems with loaded model')

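        # MXNet models use NCHW layout, so the feature (channel) dim is axis 1.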
        graph.graph['layout'] = 'NCHW'
        graph.graph['fw'] = 'mxnet'
        graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3

        extract_node_attrs(graph, mxnet_op_extractor)
Example #2
def driver(argv: argparse.Namespace, input_model: str, output_model_name: str,
           output_dir: str):
    meta_info = get_meta_info(argv)

    try:
        model_nodes, model_params, model_name, iteration_number = load_symbol_def(
            input_model, argv.input_symbol, argv.input, argv.nd_prefix_name,
            argv.pretrained_model_name, argv.legacy_mxnet_model)
    except (ValueError, mxnet.base.MXNetError) as e:
        raise FrameworkError(
            'The following error happened while loading mxnet model {}: {}. ' +
            refer_to_faq_msg(53), input_model, str(e)) from e

    if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
        save_params_file(model_name, model_params._arg_params,
                         model_params._aux_params, iteration_number)

    update_extractors_with_extensions(mxnet_op_extractors)
    graph = symbol2nx(model_nodes, model_params, argv.input)
    graph.check_empty_graph(
        'symbol2nx. It may happen due to problems with loaded model')

    graph.name = output_model_name
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'mxnet'
    graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3

    # Pick the IR version: deprecated V2 takes precedence, otherwise the
    # experimental V10 IR when requested, falling back to the default V6.
    version = 10 if argv.generate_experimental_IR_V10 else 6
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else version

    extract_node_attrs(graph, mxnet_op_extractor)

    # --------------------------------- LOAD END ------------------------------------------------------

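    # Apply all registered graph transformations in one pass: front, middle,
    # and back replacers, in that order.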
    class_registration.apply_replacements(graph, [
        class_registration.ClassType.FRONT_REPLACER,
        class_registration.ClassType.MIDDLE_REPLACER,
        class_registration.ClassType.BACK_REPLACER
    ])

    prepare_emit_ir(graph=graph,
                    data_type=argv.data_type,
                    output_dir=output_dir,
                    output_model_name=output_model_name,
                    meta_info=meta_info)
    return 0
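
A minimal sketch of how this driver might be invoked; the Namespace attributes
mirror the CLI flags used in the function body, but the concrete file names and
values below are hypothetical:

import argparse

# Hypothetical arguments; a real run would come from the Model Optimizer CLI.
args = argparse.Namespace(
    input_symbol='model-symbol.json', input=None,
    nd_prefix_name=None, pretrained_model_name=None,
    legacy_mxnet_model=False, save_params_from_nd=False,
    data_type='FP32', generate_experimental_IR_V10=False,
    generate_deprecated_IR_V2=False)
driver(args, 'model-0000.params', 'converted_model', './ir_output')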
Example #3
def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, outputs: list, output_dir: str,
           scale: float,
           placeholder_shapes: [None, list, np.array] = None,
           mean_scale_values: [dict, list] = ()):
    meta_info = get_meta_info(argv)

    try:
        model_nodes, model_params, model_name, iteration_number = load_symbol_def(input_model, argv.input_symbol,
                                                                                  argv.input,
                                                                                  argv.nd_prefix_name,
                                                                                  argv.pretrained_model_name,
                                                                                  argv.legacy_mxnet_model)
    except (ValueError, mxnet.base.MXNetError) as e:
        raise FrameworkError(
            'The following error happened while loading mxnet model {}: {}. ' +
            refer_to_faq_msg(53),
            input_model,
            str(e)
        ) from e

    if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
        save_params_file(model_name, model_params._arg_params, model_params._aux_params, iteration_number)

    update_extractors_with_extensions(mxnet_op_extractors)
    graph = symbol2nx(model_nodes, model_params, argv.input)
    check_empty_graph(graph, 'symbol2nx. It may happen due to problems with loaded model')

    graph.name = output_model_name
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'mxnet'
    graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 4
    graph = extract_node_attrs(graph, mxnet_op_extractor)
    check_softmax_node_inputs(graph)

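    # Repack user-specified input shapes and output names, then insert explicit
    # input/output ops into the graph.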
    user_shapes, packed_outputs, _ = user_data_repack(graph, placeholder_shapes, outputs, None)
    output_op_nodes = add_output_ops(graph, packed_outputs)
    input_op_nodes = add_input_ops(graph, user_shapes, True)

    try:
        override_placeholder_shapes(graph, user_shapes, argv.batch)
    except ValueError as err:
        raise Error(
            'The following error happened while processing input shapes: {}. ' +
            refer_to_faq_msg(54),
            str(err)
        ) from err
    check_empty_graph(graph, 'add_output_ops and add_input_ops')

    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    add_input_data_to_prior_boxes(graph, argv.input)

    graph = create_tensor_nodes(graph)

    graph_clean_up(graph)
    remove_output_ops(graph)
    mark_outputs(graph)
    remove_output_ops(graph)

    graph_clean_up(graph)

    log.debug("After removing specific nodes for output")

    print_graph_stat(graph)

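    # Propagate shapes and values through the graph (partial inference).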
    graph = partial_infer(graph)
    graph_clean_up(graph)
    check_empty_graph(graph, 'partial_infer')

    duplicate_shared_weights(graph)

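    # Embed the user-supplied scale and mean/scale values into the model inputs.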
    scale_input(graph, scale)
    add_mean_scale_values(graph, mean_scale_values)

    remove_op_nodes(graph, {'identity': True})

    graph_clean_up(graph)

    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)
    fuse_pad(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    convert_batch_norm(graph)
    graph_clean_up(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        convert_scale_shift_to_mul_add(graph)
        graph_clean_up(graph)

        # Fusing the sequences of Mul/Add operations
        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        # Fusing linear operation to Convolution
        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    fuse_pad(graph)

    # Converting Mul->Add to ScaleShift node
    convert_muladd_to_scaleshift_or_power(graph)
    graph_clean_up(graph)

    convert_mul_add_to_power(graph)
    convert_add_to_scaleshift(graph)  # scale = 1
    convert_mul_to_scaleshift(graph)  # biases = 0

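    # Optional, user-requested input preprocessing transforms.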
    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    meta_info=meta_info)
    return 0
Example #4
def driver(argv: argparse.Namespace, input_model: str, output_model_name: str,
           output_dir: str):
    meta_info = get_meta_info(argv)

    try:
        model_nodes, model_params, model_name, iteration_number = load_symbol_def(
            input_model, argv.input_symbol, argv.input, argv.nd_prefix_name,
            argv.pretrained_model_name, argv.legacy_mxnet_model)
    except (ValueError, mxnet.base.MXNetError) as e:
        raise FrameworkError(
            'The following error happened while loading mxnet model {}: {}. ' +
            refer_to_faq_msg(53), input_model, str(e)) from e

    if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
        save_params_file(model_name, model_params._arg_params,
                         model_params._aux_params, iteration_number)

    update_extractors_with_extensions(mxnet_op_extractors)
    graph = symbol2nx(model_nodes, model_params, argv.input)
    graph.check_empty_graph(
        'symbol2nx. It may happen due to problems with loaded model')

    graph.name = output_model_name
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'mxnet'
    graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
    extract_node_attrs(graph, mxnet_op_extractor)

    # --------------------------------- LOAD END ------------------------------------------------------
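    # Apply the registered front- and middle-phase graph transformations.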
    class_registration.apply_replacements(
        graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(
        graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    convert_batch_norm(graph)
    graph_clean_up(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        convert_scale_shift_to_mul_add(graph)
        graph_clean_up(graph)

        # Fusing the sequences of Mul/Add operations
        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        # Fusing linear operation to Convolution
        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    fuse_pad(graph)

    # Converting Mul->Add to ScaleShift node
    convert_muladd_to_scaleshift_or_power(graph)
    graph_clean_up(graph)

    convert_mul_add_to_power(graph)
    graph_clean_up(graph)
    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up(graph)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    class_registration.apply_replacements(
        graph, class_registration.ClassType.BACK_REPLACER)

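    # Remove Const and output ops (recursively over all sub-graphs) and
    # materialize constant nodes before emitting the IR.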
    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    prepare_emit_ir(graph=graph,
                    data_type=argv.data_type,
                    output_dir=output_dir,
                    output_model_name=output_model_name,
                    meta_info=meta_info)
    return 0