def find_and_replace_pattern(self, graph: Graph):
        """Transpose the weights of every FullyConnected node in ``graph``.

        For each op node of type ``'FullyConnected'``:

        * the input edge on port 1 is tagged with ``'bin': 'weights'``;
        * unless the weights node on port 1 already carries the
          ``transposed_for_IE`` flag, its ``value`` is replaced by
          ``np.transpose(value)``, the flag is set, and its ``shape``
          attribute is recomputed from the new value.

        Afterwards the same method is applied to the sub-graphs of ``graph``
        via ``for_each_sub_graph_recursively``.

        :param graph: graph modified in place.
        """
        transposed_for_IE_flag = 'transposed_for_IE'
        for fc_node in graph.get_op_nodes(type='FullyConnected'):
            weights_node = fc_node.in_node(1)
            # The edge is tagged before the flag check, so it is marked even
            # when the weights were transposed on an earlier visit.
            fc_node.in_edge(1)['bin'] = 'weights'
            # The flag guards against transposing the same weights node twice
            # (presumably shared weights or repeated/recursive runs — verify).
            if weights_node.has_and_set(transposed_for_IE_flag):
                continue
            weights_node.value = np.transpose(weights_node.value)
            weights_node[transposed_for_IE_flag] = True
            # Refresh the shape BEFORE logging so the message reports the
            # shape of the transposed weights, not the stale original one.
            weights_node.shape = np.array(weights_node.value.shape)
            log.debug(
                "Transposed weights {} for FC node {}; weights.shape = {}"
                "".format(weights_node.name, fc_node.name, weights_node.shape))

        # FIXME remove this line and make transformation run recursively when recursive transformation run feature is
        # implemented
        for_each_sub_graph_recursively(graph, self.find_and_replace_pattern)
示例#2
0
def emit_ir(graph: Graph, argv: argparse.Namespace):
    """Run the final clean-up passes on ``graph`` and emit it as IR.

    Steps, in order:

    1. ``NormalizeTI`` on ``graph``.
    2. ``RemoveConstOps`` and ``CreateConstNodesReplacement`` on ``graph``
       and every sub-graph.
    3. Unless ``cmd_params.generate_experimental_IR_V10`` is set,
       ``RemoveOutputOps`` on ``graph`` and every sub-graph.
    4. ``prepare_emit_ir`` with the data type from ``cmd_params``, the output
       location/name from ``argv``, and the optional ``'mf'`` (mean data) and
       ``'input_names'`` entries of ``graph.graph``.
    5. Unless this is a TensorFlow run with
       ``tensorflow_custom_operations_config_update``, print the IR version
       and the paths of the generated ``.xml`` and ``.bin`` files.

    :param graph: graph to serialize; modified in place by the passes above.
    :param argv: parsed command-line arguments.
    :return: always 0.
    """
    NormalizeTI().find_and_replace_pattern(graph)
    for_graph_and_each_sub_graph_recursively(
        graph,
        RemoveConstOps().find_and_replace_pattern)
    for_graph_and_each_sub_graph_recursively(
        graph,
        CreateConstNodesReplacement().find_and_replace_pattern)

    cmd_params = graph.graph['cmd_params']
    if not cmd_params.generate_experimental_IR_V10:
        # A single pass over the graph and all its sub-graphs. A separate
        # sub-graph-only pass under the same condition would only revisit the
        # same sub-graphs. NOTE(review): relies on RemoveOutputOps being
        # idempotent — confirm if its pattern changes.
        for_graph_and_each_sub_graph_recursively(
            graph,
            RemoveOutputOps().find_and_replace_pattern)

    prepare_emit_ir(
        graph=graph,
        data_type=cmd_params.data_type,
        output_dir=argv.output_dir,
        output_model_name=argv.model_name,
        mean_data=graph.graph.get('mf'),
        input_names=graph.graph.get('input_names', []),
        meta_info=get_meta_info(argv))

    if not (argv.framework == 'tf'
            and argv.tensorflow_custom_operations_config_update):
        # '.' is expanded so the printed paths are absolute.
        output_dir = argv.output_dir if argv.output_dir != '.' else os.getcwd()
        model_path = os.path.join(output_dir, argv.model_name)
        print('\n[ SUCCESS ] Generated IR version {} model.'.format(
            get_ir_version(argv)))
        print('[ SUCCESS ] XML file: {}.xml'.format(model_path))
        print('[ SUCCESS ] BIN file: {}.bin'.format(model_path))

    return 0
示例#3
0
 def find_and_replace_pattern(self, graph: Graph):
     """Apply ``reshapes_with_two_inputs_to_reshape_with_dim`` via ``for_each_sub_graph_recursively``.

     The bound method ``self.reshapes_with_two_inputs_to_reshape_with_dim``
     is passed as the callback to ``for_each_sub_graph_recursively`` together
     with ``graph``.

     NOTE(review): in the visible code, ``graph`` itself is never passed to
     ``reshapes_with_two_inputs_to_reshape_with_dim`` directly. Judging by the
     helper's name, only sub-graphs are processed. Confirm whether the
     top-level graph is handled elsewhere, or whether
     ``for_graph_and_each_sub_graph_recursively`` was intended.

     :param graph: graph whose sub-graphs are processed.
     """
     for_each_sub_graph_recursively(
         graph, self.reshapes_with_two_inputs_to_reshape_with_dim)