Exemplo n.º 1
0
 def test_remove_out_data_for_memory(self):
     # Chain under test: input_node -> memory_node -> output_node -> op_output.
     edges = [
         ('input_node', 'memory_node'),
         ('memory_node', 'output_node'),
         ('output_node', 'op_output'),
     ]
     graph = build_graph(self.nodes, edges)

     pattern = KaldiRemoveMemoryOutputBackReplacementPattern()
     pattern.find_and_replace_pattern(graph)

     # After the replacement the node 'output_node' must be gone from the graph.
     self.assertNotIn('output_node', graph.node)
Exemplo n.º 2
0
 def test_do_not_remove_out_data_for_memory(self):
     # Only input_node -> memory_node is created through build_graph; the
     # output node and its incoming edge are attached to the graph by hand.
     graph = build_graph(self.nodes, [('input_node', 'memory_node')])
     output_attrs = self.nodes['output_node']
     graph.add_node('output_node', **output_attrs)
     graph.add_edge('memory_node', 'output_node', out=0)

     pattern = KaldiRemoveMemoryOutputBackReplacementPattern()
     pattern.find_and_replace_pattern(graph)

     # In this setup the replacement is expected to leave 'output_node' in place.
     self.assertIn('output_node', graph.node)
Exemplo n.º 3
0
 def test_remove_out_data_for_memory(self):
     graph = build_graph(self.nodes, [('input_node', 'memory_node')])

     # The pattern only matches when the memory_node -> output_node edge carries
     # nothing but the attribute 'out' = 0. build_graph would create that edge
     # with both 'in' and 'out' attributes, so the node and the edge are added
     # manually instead.
     output_attrs = self.nodes['output_node']
     graph.add_node('output_node', is_output=True, **output_attrs)
     graph.add_edge('memory_node', 'output_node', out=0)

     pattern = KaldiRemoveMemoryOutputBackReplacementPattern()
     pattern.find_and_replace_pattern(graph)

     # The matched output node must have been removed from the graph.
     self.assertNotIn('output_node', graph.node)
Exemplo n.º 4
0
def driver(argv, input_model, output_model_name, output_dir):
    """Run the Kaldi conversion pipeline for ``input_model`` and emit the IR.

    The steps run in a fixed order: load the model, extract node attributes,
    apply the registered front replacers, run ``partial_infer``, apply the
    reshape clean-up patterns, optionally apply the counts file and
    ``RemoveLastSoftMaxPattern``, then run the final transformations and
    call ``prepare_emit_ir``.

    Args:
        argv: Parsed command-line parameters. This function reads
            ``generate_deprecated_IR_V2``, ``counts``,
            ``remove_output_softmax`` and ``data_type``; the whole object is
            passed to ``get_meta_info`` and stored in
            ``graph.graph['cmd_params']``.
        input_model: Kaldi model handed to ``load_kaldi_model``.
        output_model_name: Forwarded to ``prepare_emit_ir``.
        output_dir: Forwarded to ``prepare_emit_ir``.

    Returns:
        Always ``0``.

    Raises:
        Error: If ``load_kaldi_model`` or ``read_counts_file`` raises any
            ``Exception``; the original exception is chained as the cause.
    """
    meta_info = get_meta_info(argv)

    # Switch the EltwiseChecker off for this pipeline.
    EltwiseChecker.enabled = False

    try:
        # The loader result unpacks into two values; only the graph is needed here.
        graph, _ = load_kaldi_model(input_model)
    except Exception as e:
        raise Error('Model Optimizer is not able to read Kaldi model {}. '.format(input_model) +
                    refer_to_faq_msg(91)) from e
    graph.check_empty_graph('load_kaldi_nnet_model')
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'kaldi'
    # IR version 2 is produced only when the deprecated-IR flag is set, otherwise 5.
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
    update_extractors_with_extensions(kaldi_type_extractors)
    extract_node_attrs(graph, lambda node: kaldi_extractor(node))

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)

    graph = partial_infer(graph)

    # The order is intentional, firstly eliminate repeated, then remove redundant
    FuseRepeatedReshapes().find_and_replace_pattern(graph)
    EliminateRedundantReshape().find_and_replace_pattern(graph)
    graph.check_empty_graph('partial_infer')
    if argv.counts:
        try:
            counts = read_counts_file(argv.counts)
        except Exception as e:
            raise Error('Model Optimizer is not able to read counts file {}'.format(argv.counts) +
                        refer_to_faq_msg(92)) from e

        apply_biases_to_last_layer(graph, counts)

    if argv.remove_output_softmax:
        RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
        graph_clean_up(graph)
        log.debug("After removing softmax")
        graph.print_graph_stat()

    # Intentionally after all transformations
    KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)

    remove_const_ops(graph)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    remove_output_ops(graph)

    prepare_emit_ir(graph, argv.data_type, output_dir, output_model_name, meta_info=meta_info)
    return 0
Exemplo n.º 5
0
def driver(argv, input_model, output_model_name, output_dir):
    """Convert a Kaldi model to IR, announcing each stage through ``log_step``.

    Stages are reported as 'LOAD', 'FRONT', 'MIDDLE', 'BACK' and 'EMIT'.
    Every transformation is applied in a fixed order.

    Args:
        argv: Parsed command-line parameters. This function reads ``steps``,
            ``generate_deprecated_IR_V2``, ``generate_experimental_IR_V10``
            (through ``graph.graph['cmd_params']``), ``counts``,
            ``remove_output_softmax``, ``remove_memory`` and ``data_type``.
        input_model: Kaldi model handed to ``load_kaldi_model``.
        output_model_name: Forwarded to ``prepare_emit_ir``.
        output_dir: Forwarded to ``prepare_emit_ir``.

    Returns:
        Always ``0``.

    Raises:
        Error: If ``load_kaldi_model`` or ``read_counts_file`` raises any
            ``Exception``; the original exception is chained as the cause.
    """
    log_step(argv.steps, 'LOAD')
    meta_info = get_meta_info(argv)

    # Switch the EltwiseChecker off for this pipeline.
    EltwiseChecker.enabled = False

    try:
        graph = load_kaldi_model(input_model)
    except Exception as load_error:
        message = ('Model Optimizer is not able to parse Kaldi model {}. '.format(input_model) +
                   refer_to_faq_msg(91))
        raise Error(message) from load_error
    graph.check_empty_graph('load_kaldi_nnet_model')
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'kaldi'

    # Non-deprecated IR version: 10 with the experimental flag, otherwise 6.
    current_version = 10 if graph.graph['cmd_params'].generate_experimental_IR_V10 else 6
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else current_version

    update_extractors_with_extensions(kaldi_type_extractors)
    extract_node_attrs(graph, lambda kaldi_node: kaldi_extractor(kaldi_node))

    # --------------------------------- LOAD END ------------------------------------------------------
    log_step(argv.steps, 'FRONT')
    ReplaceLSTMNodePattern().find_and_replace_pattern(graph)
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)

    log_step(argv.steps, 'MIDDLE')
    graph = partial_infer(graph)

    # Each tuple below lists transformations in the exact order they must run.
    for middle_pattern in (
            ReplacePNormNodePattern,
            ReplaceMemoryOffsetNodePattern,
            ReplaceMemoryOffsetWithMemoryNodePattern,
            RemoveMemoryDuplicationPattern,
            MergeNeighborSplicePattern,
            RemoveUselessCropsPattern,
            RemoveIdentity,
    ):
        middle_pattern().find_and_replace_pattern(graph)
    graph_clean_up(graph)

    for middle_pattern in (AddSelectBeforeMemoryNodePattern, ReplaceSpliceNodePattern):
        middle_pattern().find_and_replace_pattern(graph)
    graph_clean_up(graph)

    # The order is intentional, firstly eliminate repeated, then remove redundant
    for reshape_pattern in (FuseRepeatedReshapes, EliminateRedundantReshape):
        reshape_pattern().find_and_replace_pattern(graph)
    graph_clean_up(graph)
    graph.check_empty_graph('partial_infer')

    if argv.counts:
        try:
            counts = read_counts_file(argv.counts)
        except Exception as counts_error:
            raise Error('Model Optimizer is not able to read counts file {}'.format(argv.counts) +
                        refer_to_faq_msg(92)) from counts_error
        else:
            apply_biases_to_last_layer(graph, counts)

    if argv.remove_output_softmax:
        RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
        graph_clean_up(graph)
        log.debug("After removing softmax")
        graph.print_graph_stat()

    log_step(argv.steps, 'BACK')
    for back_pattern in (
            LeakyReluToReluWithNegativeSlope,
            TransposeToPermute,
            DivideToEltwises,
            SubtractToEltwises,
            SimpleEltwiseToEltwiseOp,
    ):
        back_pattern().find_and_replace_pattern(graph)
    for_graph_and_each_sub_graph_recursively(graph, convert_matmul_to_fully_connected)

    # Intentionally after all transformations
    if argv.remove_memory:
        CutMemory().find_and_replace_pattern(graph)
        graph_clean_up(graph)

    for final_pattern in (
            ParameterToInput,
            KaldiRemoveMemoryOutputBackReplacementPattern,
            ForceStrictPrecision,
    ):
        final_pattern().find_and_replace_pattern(graph)
    remove_const_ops(graph)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    remove_output_ops(graph)
    log_step(argv.steps, 'EMIT')
    prepare_emit_ir(graph, argv.data_type, output_dir, output_model_name, meta_info=meta_info)
    return 0
Exemplo n.º 6
0
def driver(argv,
           input_model,
           output_model_name,
           outputs,
           output_dir,
           scale,
           placeholder_shapes=None,
           mean_scale_values=()):
    """Run the Kaldi conversion pipeline for ``input_model`` and emit the IR.

    The steps run in a fixed order: load the model, extract node attributes,
    apply the registered front replacers, add output ops and tensor nodes,
    override placeholder shapes and batch, mark outputs, run
    ``partial_infer``, apply the reshape clean-up patterns, optionally apply
    the counts file and ``RemoveLastSoftMaxPattern``, then call
    ``prepare_emit_ir``.

    Args:
        argv: Parsed command-line parameters. This function reads
            ``generate_deprecated_IR_V2``, ``batch``, ``counts``,
            ``remove_output_softmax`` and ``data_type``; the whole object is
            passed to ``get_meta_info`` and stored in
            ``graph.graph['cmd_params']``.
        input_model: Kaldi model handed to ``load_kaldi_model``.
        output_model_name: Forwarded to ``prepare_emit_ir``.
        outputs: Forwarded to ``add_output_ops``.
        output_dir: Forwarded to ``prepare_emit_ir``.
        scale: Accepted for interface compatibility; not read in this body.
        placeholder_shapes: Forwarded to ``override_placeholder_shapes``.
        mean_scale_values: Accepted for interface compatibility; not read in
            this body.

    Returns:
        Always ``0``.

    Raises:
        Error: If ``load_kaldi_model`` or ``read_counts_file`` raises any
            ``Exception``; the original exception is chained as the cause.
    """
    meta_info = get_meta_info(argv)

    # Switch the EltwiseChecker off for this pipeline.
    EltwiseChecker.enabled = False

    try:
        # The loader result unpacks into two values; only the graph is needed here.
        graph, _ = load_kaldi_model(input_model)
    except Exception as e:
        raise Error('Model Optimizer is not able to read Kaldi model {}. '.
                    format(input_model) + refer_to_faq_msg(91)) from e
    check_empty_graph(graph, 'load_kaldi_nnet_model')
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'kaldi'
    # IR version 2 is produced only when the deprecated-IR flag is set, otherwise 4.
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 4

    update_extractors_with_extensions(kaldi_type_extractors)

    extract_node_attrs(graph, lambda node: kaldi_extractor(node))

    class_registration.apply_replacements(
        graph, class_registration.ClassType.FRONT_REPLACER)

    # The return value of add_output_ops is not needed by the rest of the pipeline.
    add_output_ops(graph, outputs)
    log.debug("After adding specific nodes for outputs")
    print_graph_stat(graph)

    check_empty_graph(graph, 'add_output_ops')
    create_tensor_nodes(graph)

    graph_clean_up(graph)
    log.debug("After removing specific nodes for output")
    print_graph_stat(graph)

    override_placeholder_shapes(graph, placeholder_shapes)
    override_batch(graph, argv.batch)

    graph_clean_up(graph)
    log.debug("After setting input shapes")
    print_graph_stat(graph)
    graph_clean_up(graph)
    remove_output_ops(graph)
    log.debug("After removing specific nodes for output")
    print_graph_stat(graph)

    # You need to pass required network outputs here
    # but we don't have a way yet, so just passing all discovered sinks
    mark_outputs(graph)
    graph_clean_up(graph)
    log.debug("After graph_cleanup")
    print_graph_stat(graph)
    graph = partial_infer(graph)

    # The order is intentional, firstly eliminate repeated, then remove redundant
    FuseRepeatedReshapes().find_and_replace_pattern(graph)
    EliminateRedundantReshape().find_and_replace_pattern(graph)
    check_empty_graph(graph, 'partial_infer')
    if argv.counts:
        try:
            counts = read_counts_file(argv.counts)
        except Exception as e:
            raise Error('Model Optimizer is not able to read counts file {}'.
                        format(argv.counts) + refer_to_faq_msg(92)) from e

        apply_biases_to_last_layer(graph, counts)

    if argv.remove_output_softmax:
        RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
        graph_clean_up(graph)
        log.debug("After removing softmax")
        print_graph_stat(graph)

    # Intentionally after all transformations
    KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(
        graph)
    prepare_emit_ir(graph,
                    argv.data_type,
                    output_dir,
                    output_model_name,
                    meta_info=meta_info)
    return 0