Beispiel #1
0
def write(
        nnef_graph,  # type: NNEFGraph
        tgz_or_dir_path,  # type: str
        write_weights=True,  # type: bool
        raise_on_missing_weight=True,  # type: bool
        extensions=None,  # type: typing.Optional[typing.List[str]]
        fragments=None,  # type: typing.Optional[str]
        only_print_used_fragments=False,  # type: bool
        compression_level=0,  # type: int
):
    # type: (...) -> None
    """Save an NNEF graph either into a directory or into a ``.tgz`` archive.

    The textual graph is written to ``graph.nnef``. A ``graph.quant`` file is
    added only when at least one tensor has non-None ``quantization``. Weights
    are written only when ``write_weights`` is true.

    When the destination ends in ``.tgz``, everything is first written to a
    fresh temporary directory. That directory is packed into the archive and
    always removed afterwards, even if writing fails.

    Args:
        nnef_graph: the graph to serialize.
        tgz_or_dir_path: destination path. A ``.tgz`` suffix selects archive
            output. Anything else is treated as a directory, created if missing.
        write_weights: whether to write the weight files.
        raise_on_missing_weight: forwarded to ``_write_weights``.
        extensions: forwarded to ``_print``.
        fragments: forwarded to ``_print``.
        only_print_used_fragments: forwarded to ``_print``.
        compression_level: forwarded to ``utils.tgz_compress``.
    """
    to_archive = tgz_or_dir_path.endswith('.tgz')
    work_dir = None

    try:
        if not to_archive:
            work_dir = tgz_or_dir_path
            if not os.path.exists(work_dir):
                os.makedirs(work_dir)
        else:
            # Stage into a scratch directory; it is packed and removed below.
            work_dir = tempfile.mkdtemp(prefix="nnef_")

        graph_file = os.path.join(work_dir, "graph.nnef")
        with open(graph_file, "w") as handle:
            _print(nnef_graph,
                   file_handle=handle,
                   extensions=extensions,
                   fragments=fragments,
                   only_print_used_fragments=only_print_used_fragments)

        has_quantization = any(tensor.quantization is not None
                               for tensor in nnef_graph.tensors)
        if has_quantization:
            quant_file = os.path.join(work_dir, "graph.quant")
            with open(quant_file, "w") as handle:
                _print_quantization(nnef_graph, file_handle=handle)

        if write_weights:
            _write_weights(nnef_graph,
                           dir_path=work_dir,
                           raise_on_missing_weight=raise_on_missing_weight)

        if to_archive:
            utils.tgz_compress(work_dir,
                               tgz_or_dir_path,
                               compression_level=compression_level)
    finally:
        # Only the scratch directory created for archive output is ours to delete.
        if to_archive and work_dir:
            shutil.rmtree(work_dir)
Beispiel #2
0
def _compress_with_backup(dir_path, file_path):
    # type: (str, str) -> None
    """Pack ``dir_path`` into the archive ``file_path`` without risking an existing file.

    An existing ``file_path`` is first moved to ``file_path + '.nnef-tools-backup'``.
    If compression fails, any partial archive is deleted and the backup is moved
    back, so the original survives the failure; the exception is then re-raised.
    On success the backup file, if present, is deleted.
    """
    backup_path = file_path + '.nnef-tools-backup'
    moved_original = os.path.exists(file_path)
    if moved_original:
        shutil.move(file_path, backup_path)
    try:
        utils.tgz_compress(dir_path=dir_path, file_path=file_path)
    except Exception:
        if moved_original:
            # Drop the (possibly partial) new archive before restoring the original;
            # os.rename cannot overwrite an existing file on every platform.
            if os.path.exists(file_path):
                os.remove(file_path)
            shutil.move(backup_path, file_path)
        raise
    if os.path.exists(backup_path):
        os.remove(backup_path)


def main():
    """Command-line entry point that runs ``generate_weights`` on an NNEF network.

    Arguments come from ``get_args(sys.argv)``. ``args.params`` is wrapped in
    ``InputSources``. The numpy RNG is seeded unless ``args.seed`` is -1.

    ``args.network`` may be any of the following:

    * a ``.tgz`` archive: it is extracted into a temporary directory next to
      the input. If weights were generated, the archive is re-packed in place.
    * a directory: it is passed to ``generate_weights`` as both the input and
      the output path.
    * a ``.nnef`` / ``.txt`` file: output goes to a temporary directory next to
      the input. If weights were generated, that directory is packed into
      ``<network without extension>.nnef.tgz``.

    Temporary directories are always removed. Any exception is reported on
    stderr as ``Error: ...`` and the process exits with status 1.
    """
    try:
        args = get_args(sys.argv)

        args.params = InputSources(args.params)

        if args.seed != -1:
            np.random.seed(args.seed)

        parent_dir_of_input_model = os.path.dirname(
            utils.path_without_trailing_separator(args.network))

        tmp_dir = None
        if args.network.endswith('.tgz'):
            nnef_path = tmp_dir = tempfile.mkdtemp(
                prefix="nnef_", dir=parent_dir_of_input_model)
            utils.tgz_extract(args.network, nnef_path)
        else:
            nnef_path = args.network

        try:
            parser_configs = NNEFParserConfig.load_configs(
                args.custom_operations, load_standard=True)
            reader = nnef_io.Reader(parser_configs=parser_configs)

            # read without weights
            graph = reader(
                os.path.join(nnef_path, 'graph.nnef')
                if os.path.isdir(nnef_path) else nnef_path)

            if os.path.isdir(nnef_path):
                output_path = nnef_path
            elif nnef_path.endswith(('.nnef', '.txt')):
                output_path = tmp_dir = tempfile.mkdtemp(
                    prefix="nnef_", dir=parent_dir_of_input_model)
            else:
                # A real exception rather than ``assert False``: asserts are
                # stripped under ``python -O``, which would leave output_path unbound.
                raise ValueError(
                    "Unsupported network path '{}': expected a directory, "
                    "a .tgz archive or a .nnef/.txt file".format(nnef_path))

            did_generate_weights = generate_weights(graph,
                                                    nnef_path,
                                                    output_path,
                                                    input_sources=args.params)
            nnef_path = output_path

            # Results that live in a temporary directory must be packed up
            # before the ``finally`` below deletes that directory.
            if tmp_dir and did_generate_weights:
                if args.network.endswith('.tgz'):
                    print("Info: Changing input archive", file=sys.stderr)
                    archive_path = args.network
                else:
                    archive_path = args.network.rsplit('.', 1)[0] + '.nnef.tgz'
                _compress_with_backup(nnef_path, archive_path)
        finally:
            if tmp_dir:
                shutil.rmtree(tmp_dir)
    except Exception as e:
        print('Error: {}'.format(e), file=sys.stderr)
        sys.exit(1)
Beispiel #3
0
def run_using_argv(argv):
    """Command-line entry point: execute an NNEF network on input tensors.

    Arguments are parsed with ``get_args(argv)``.

    Inputs are read from one of three sources:

    * stdin, when no ``--input`` is given. stdin must not be a terminal.
    * a single directory, holding one ``<input name>.dat`` per graph input.
    * an explicit list of tensor files.

    The graph is read once without shape inference to enumerate its inputs.
    It is then re-read with the concrete input shapes and run by
    ``backend.run``.

    Outputs are handled as follows:

    * without ``output_names``, the graph outputs go to stdout or to
      ``<output>/<tensor name>.dat``.
    * with ``output_names``, the named tensors are exported through an
      ``ActivationExportHook``. ``'*'`` selects every tensor that is neither
      constant nor variable.
    * an empty ``output_names`` list suppresses output entirely.

    With ``args.stats``, statistics are saved relative to the network path. A
    path ending in a separator gets ``graph.stats`` appended. For a ``.tgz``
    network, statistics landing inside the extracted copy cause the archive to
    be re-packed. The original is kept as a backup until the re-pack succeeds.

    ``utils.NNEFToolsException`` and ``nnef.Error`` are reported on stderr and
    the process exits with status 1. Other exceptions propagate.
    """
    try:
        args = get_args(argv)
        # Falsy only when output_names was given as an empty list.
        write_outputs = args.output_names is None or args.output_names

        if args.input is None:
            if sys.stdin.isatty():
                raise utils.NNEFToolsException("No input provided!")
            utils.set_stdin_to_binary()

        if write_outputs and args.output is None:
            if sys.stdout.isatty():
                raise utils.NNEFToolsException("No output provided!")
            utils.set_stdout_to_binary()

        parent_dir_of_input_model = os.path.dirname(
            utils.path_without_trailing_separator(args.network))
        # Set only for .tgz input (the extraction directory); the re-pack step relies on this.
        tmp_dir = None

        if args.network.endswith('.tgz'):
            nnef_path = tmp_dir = tempfile.mkdtemp(
                prefix="nnef_", dir=parent_dir_of_input_model)
            utils.tgz_extract(args.network, nnef_path)
        else:
            nnef_path = args.network

        try:
            parser_configs = NNEFParserConfig.load_configs(
                args.custom_operations, load_standard=True)

            # First pass: read without weights and without shape inference;
            # this graph is only used to enumerate graph.inputs below.
            reader = nnef_io.Reader(parser_configs=parser_configs,
                                    infer_shapes=False)
            graph = reader(
                os.path.join(nnef_path, 'graph.nnef')
                if os.path.isdir(nnef_path) else nnef_path)

            if args.input is None:
                inputs = tuple(
                    nnef.read_tensor(sys.stdin)
                    for _ in range(len(graph.inputs)))
            elif len(args.input) == 1 and os.path.isdir(args.input[0]):
                inputs = tuple(
                    nnef_io.read_nnef_tensor(
                        os.path.join(args.input[0], tensor.name + '.dat'))
                    for tensor in graph.inputs)
            else:
                inputs = tuple(
                    nnef_io.read_nnef_tensor(path) for path in args.input)

            # Second pass: re-read nnef_path, this time with the actual input shapes.
            reader = nnef_io.Reader(parser_configs=parser_configs,
                                    input_shape=tuple(
                                        list(array.shape) for array in inputs))

            graph = reader(nnef_path)

            tensor_hooks = []

            stats_hook = None
            if args.stats:
                stats_hook = backend.StatisticsHook()
                tensor_hooks.append(stats_hook)

            if write_outputs and args.output_names is not None:
                if '*' in args.output_names:
                    export_names = [
                        t.name for t in graph.tensors
                        if not t.is_constant and not t.is_variable
                    ]
                else:
                    export_names = args.output_names
                tensor_hooks.append(
                    backend.ActivationExportHook(
                        tensor_names=export_names,
                        output_directory=args.output))

            if args.permissive:
                backend.try_to_fix_unsupported_attributes(graph)

            outputs = backend.run(nnef_graph=graph,
                                  inputs=inputs,
                                  device=args.device,
                                  custom_operations=get_custom_runners(
                                      args.custom_operations),
                                  tensor_hooks=tensor_hooks)

            if write_outputs and args.output_names is None:
                if args.output is None:
                    for array in outputs:
                        nnef.write_tensor(sys.stdout, array)
                else:
                    for tensor, array in zip(graph.outputs, outputs):
                        nnef_io.write_nnef_tensor(
                            os.path.join(args.output, tensor.name + '.dat'),
                            array)

            if stats_hook:
                if args.stats.endswith(('/', '\\')):
                    stats_path = os.path.join(nnef_path, args.stats,
                                              'graph.stats')
                else:
                    stats_path = os.path.join(nnef_path, args.stats)
                stats_hook.save_statistics(stats_path)

            # tmp_dir is only ever set for a .tgz network, so statistics saved
            # inside the extracted copy (as judged by _is_inside) must be packed
            # back into that archive before the directory is deleted.
            if tmp_dir and args.stats and _is_inside(nnef_path, args.stats):
                print("Info: Changing input archive", file=sys.stderr)
                backup_path = args.network + '.nnef-tools-backup'
                shutil.move(args.network, backup_path)
                try:
                    utils.tgz_compress(dir_path=nnef_path,
                                       file_path=args.network)
                except Exception:
                    # Restore the untouched original so a failed re-pack loses nothing.
                    if os.path.exists(args.network):
                        os.remove(args.network)
                    shutil.move(backup_path, args.network)
                    raise
                os.remove(backup_path)
        finally:
            if tmp_dir:
                shutil.rmtree(tmp_dir)
    except (utils.NNEFToolsException, nnef.Error) as e:
        print("Error: " + str(e), file=sys.stderr)
        sys.exit(1)