Пример #1
0
def read(path, parser_configs=None):
    # type: (str, typing.Optional[typing.List[NNEFParserConfig]])->NNEFGraph
    """Load an NNEF model from disk and convert it with ``_read``.

    Accepted kinds of ``path``:

    * a directory: loaded directly, with weights;
    * a ``.tgz`` archive: extracted into a fresh temporary directory (removed
      again before this function returns), with weights;
    * a ``.nnef`` or ``.txt`` graph file: loaded without weights.

    :param path: location of the model.
    :param parser_configs: optional list of parser configs, merged with
        ``NNEFParserConfig.combine_configs`` (an empty list when omitted).
    :return: whatever ``_read`` builds from the shape-inferred parser graph.
    :raises utils.NNEFToolsException: if ``path`` is none of the kinds above.
    """
    looks_supported = (path.endswith(('.tgz', '.nnef', '.txt'))
                       or os.path.isdir(path))
    if not looks_supported:
        raise utils.NNEFToolsException(
            "Only .tgz or .nnef or .txt files or directories are supported")

    parser_config = NNEFParserConfig.combine_configs(parser_configs or [])

    # Set only once an archive has been unpacked; deleted in the finally.
    extracted_dir = None
    try:
        if os.path.isdir(path):
            model_location, with_weights = path, True
        elif path.endswith('.tgz'):
            extracted_dir = tempfile.mkdtemp(prefix="nnef_")
            utils.tgz_extract(path, extracted_dir)
            model_location, with_weights = extracted_dir, True
        elif path.endswith(('.nnef', '.txt')):
            model_location, with_weights = path, False
        else:
            assert False  # unreachable: rejected by the check above

        if parser_config.fragments:
            # If the graph file declares fragments of its own (before its
            # ``graph`` header line), fall back to the standard fragments only,
            # so no fragment ends up being defined twice.
            fragment_header = re.compile(r"^fragment\s|\sfragment\s")
            graph_header = re.compile(r"^graph\s|\sgraph\s")
            if os.path.isdir(model_location):
                graph_file = os.path.join(model_location, 'graph.nnef')
            else:
                graph_file = model_location
            with open(graph_file, 'r') as source_file:
                for text_line in source_file:
                    if fragment_header.search(text_line):
                        parser_config.fragments = NNEFParserConfig.STANDARD_CONFIG.fragments
                        break
                    if graph_header.search(text_line):
                        break

        loaded = parser_config.load_graph(model_location)
        return _read(parser_graph=parser_config.infer_shapes(loaded),
                     with_weights=with_weights)
    finally:
        if extracted_dir:
            shutil.rmtree(extracted_dir)
Пример #2
0
def read(
    path,  # type: str
    parser_configs=None,  # type: typing.Optional[typing.List[NNEFParserConfig]]
    input_shape=None,  # type: typing.Union[None, typing.List[int], typing.Tuple[typing.List[int], ...], typing.Dict[str, typing.List[int]]]
):
    # type: (...)->NNEFGraph
    """Load an NNEF model, optionally override its input shapes, and convert it.

    :param path: a model directory, a ``.tgz`` archive (extracted into a
        temporary directory that is removed before returning) or a bare
        ``.nnef``/``.txt`` graph file. Weights are loaded only for the first
        two kinds.
    :param parser_configs: optional parser configs, merged with
        ``NNEFParserConfig.combine_configs``.
    :param input_shape: optional override of the ``shape`` attribute of the
        ``external`` operations, applied before shape inference:

        * list: the same shape is set on every ``external``;
        * dict: maps tensor name to shape; externals whose output name is not
          a key are left unchanged;
        * tuple: one shape per graph input, in the order of the parsed graph's
          ``inputs``.
    :return: the result of ``_read`` on the shape-inferred parser graph.
    :raises utils.NNEFToolsException: for an unsupported ``path``, an
        ``input_shape`` that is not a list, tuple or dict, or a tuple whose
        length differs from the number of graph inputs.
    """

    if not (path.endswith(('.tgz', '.nnef', '.txt')) or os.path.isdir(path)):
        raise utils.NNEFToolsException(
            "Only .tgz or .nnef or .txt files or directories are supported")

    # Reject a bad override type before doing any work (e.g. extracting an archive).
    if input_shape is not None and not isinstance(input_shape, (list, tuple, dict)):
        raise utils.NNEFToolsException(
            "input_shape must be list, tuple or dict")

    parser_config = NNEFParserConfig.combine_configs(
        parser_configs if parser_configs else [])

    path_to_load = None
    compressed = False  # True once an archive is being extracted into a temp dir

    try:
        if os.path.isdir(path):
            with_weights = True
            path_to_load = path
        elif path.endswith('.tgz'):
            compressed = True
            with_weights = True
            path_to_load = tempfile.mkdtemp(prefix="nnef_")
            utils.tgz_extract(path, path_to_load)
        elif path.endswith(('.nnef', '.txt')):
            with_weights = False
            path_to_load = path
        else:
            assert False  # unreachable: rejected by the check above

        # If there are fragments in the graph and also in parser_config
        # we remove the non-standard fragments from parser_config to avoid duplicate fragment definition.
        # Only lines before the ``graph`` header are scanned.
        if parser_config.fragments:
            re_graph = re.compile(r"^graph\s|\sgraph\s")
            re_fragment = re.compile(r"^fragment\s|\sfragment\s")
            graph_nnef_path = os.path.join(
                path_to_load,
                'graph.nnef') if os.path.isdir(path_to_load) else path_to_load
            with open(graph_nnef_path, 'r') as f:
                for line in f:
                    if re_fragment.search(line):
                        parser_config.fragments = NNEFParserConfig.STANDARD_CONFIG.fragments
                        break
                    if re_graph.search(line):
                        break

        parser_graph = parser_config.load_graph(path_to_load)

        if input_shape is not None:
            if isinstance(input_shape, tuple):
                # Positional form: one shape per graph input, in declaration order.
                if len(input_shape) != len(parser_graph.inputs):
                    raise utils.NNEFToolsException(
                        "input_shape has {} shapes, but the graph has {} inputs".format(
                            len(input_shape), len(parser_graph.inputs)))
                # From here on it is handled exactly like the name->shape dict form.
                input_shape = dict(zip(parser_graph.inputs, input_shape))

            for op in parser_graph.operations:
                if op.name != 'external':
                    continue
                if isinstance(input_shape, dict):
                    name = op.outputs['output']
                    if name in input_shape:
                        op.attribs['shape'] = input_shape[name]
                else:
                    op.attribs['shape'] = input_shape

        parser_config.infer_shapes(parser_graph)
        return _read(parser_graph=parser_graph, with_weights=with_weights)

    finally:
        if compressed and path_to_load:
            shutil.rmtree(path_to_load)
Пример #3
0
def main():
    """Command-line entry point: run ``generate_weights`` on an NNEF model.

    ``args.network`` may be a directory, a ``.tgz`` archive (extracted into a
    temporary directory next to it) or a bare ``.nnef``/``.txt`` graph file
    (for which a temporary output directory is created next to it). When
    ``generate_weights`` reports that it generated weights and a temporary
    directory was used, the result is packed with ``utils.tgz_compress``.
    The packed result either replaces the input archive or is written to
    ``<network without extension>.nnef.tgz``. All temporary directories are
    removed afterwards. Any exception is printed to stderr, and the process
    exits with status 1.
    """
    try:
        args = get_args(sys.argv)

        args.params = InputSources(args.params)

        if args.seed != -1:
            np.random.seed(args.seed)

        parent_dir_of_input_model = os.path.dirname(
            utils.path_without_trailing_separator(args.network))

        tmp_dir = None
        try:
            if args.network.endswith('.tgz'):
                # Extract inside the try so the finally below removes the
                # temporary directory even when extraction itself fails.
                nnef_path = tmp_dir = tempfile.mkdtemp(
                    prefix="nnef_", dir=parent_dir_of_input_model)
                utils.tgz_extract(args.network, nnef_path)
            else:
                nnef_path = args.network

            parser_configs = NNEFParserConfig.load_configs(
                args.custom_operations, load_standard=True)
            reader = nnef_io.Reader(parser_configs=parser_configs)

            # read without weights
            graph = reader(
                os.path.join(nnef_path, 'graph.nnef')
                if os.path.isdir(nnef_path) else nnef_path)

            # output_path is handed to generate_weights as its output location.
            if os.path.isdir(nnef_path):
                output_path = nnef_path
            elif nnef_path.endswith(('.nnef', '.txt')):
                output_path = tmp_dir = tempfile.mkdtemp(
                    prefix="nnef_", dir=parent_dir_of_input_model)
            else:
                # Explicit error instead of `assert False`, which -O strips.
                raise utils.NNEFToolsException(
                    "Only .tgz or .nnef or .txt files or directories are supported")

            did_generate_weights = generate_weights(graph,
                                                    nnef_path,
                                                    output_path,
                                                    input_sources=args.params)
            nnef_path = output_path

            if tmp_dir and did_generate_weights:
                if args.network.endswith('.tgz'):
                    print("Info: Changing input archive", file=sys.stderr)
                    shutil.move(args.network,
                                args.network + '.nnef-tools-backup')
                    utils.tgz_compress(dir_path=nnef_path,
                                       file_path=args.network)
                    os.remove(args.network + '.nnef-tools-backup')
                else:
                    output_path = args.network.rsplit('.', 1)[0] + '.nnef.tgz'
                    backup_path = output_path + '.nnef-tools-backup'
                    if os.path.exists(output_path):
                        shutil.move(output_path, backup_path)
                    utils.tgz_compress(dir_path=nnef_path,
                                       file_path=output_path)
                    if os.path.exists(backup_path):
                        os.remove(backup_path)
        finally:
            if tmp_dir:
                shutil.rmtree(tmp_dir)
    except Exception as e:
        print('Error: {}'.format(e), file=sys.stderr)
        # sys.exit rather than the site-module `exit`, which may be absent.
        sys.exit(1)
Пример #4
0
def run_using_argv(argv):
    """Run an NNEF model as described by the command-line arguments ``argv``.

    Where the data comes from:

    * Input tensors are read from stdin (one per graph input, in order), from
      ``<dir>/<input name>.dat`` when a single directory is given, or from the
      listed tensor files.
    * The graph is read once without shape inference to learn its inputs.
      It is then re-read with ``input_shape`` set to the shapes of the loaded
      tensors.

    Where the results go:

    * Outputs are written to stdout or as ``<output dir>/<output name>.dat``
      when no explicit ``output_names`` are requested. Otherwise the named
      (or, for ``'*'``, all non-constant, non-variable) tensors are exported
      through ``backend.ActivationExportHook``.
    * Statistics are saved when ``args.stats`` is set. If the network was a
      ``.tgz`` archive and the statistics landed inside the extracted copy,
      the archive is re-packed.

    ``utils.NNEFToolsException`` and ``nnef.Error`` are printed to stderr,
    and the process exits with status 1. Other exceptions propagate.
    """
    try:
        args = get_args(argv)
        write_outputs = args.output_names is None or args.output_names

        if args.input is None:
            if sys.stdin.isatty():
                raise utils.NNEFToolsException("No input provided!")
            utils.set_stdin_to_binary()

        if write_outputs:
            if args.output is None:
                if sys.stdout.isatty():
                    raise utils.NNEFToolsException("No output provided!")
                utils.set_stdout_to_binary()

        parent_dir_of_input_model = os.path.dirname(
            utils.path_without_trailing_separator(args.network))
        tmp_dir = None  # only ever set for a .tgz network

        try:
            if args.network.endswith('.tgz'):
                # Extract inside the try so the finally below removes the
                # temporary directory even when extraction itself fails.
                nnef_path = tmp_dir = tempfile.mkdtemp(
                    prefix="nnef_", dir=parent_dir_of_input_model)
                utils.tgz_extract(args.network, nnef_path)
            else:
                nnef_path = args.network

            parser_configs = NNEFParserConfig.load_configs(
                args.custom_operations, load_standard=True)

            # First pass: read without weights and without shape inference,
            # only to learn the graph inputs.
            reader = nnef_io.Reader(parser_configs=parser_configs,
                                    infer_shapes=False)
            graph = reader(
                os.path.join(nnef_path, 'graph.nnef')
                if os.path.isdir(nnef_path) else nnef_path)

            if args.input is None:
                inputs = tuple(
                    nnef.read_tensor(sys.stdin)
                    for _ in range(len(graph.inputs)))
            elif len(args.input) == 1 and os.path.isdir(args.input[0]):
                inputs = tuple(
                    nnef_io.read_nnef_tensor(
                        os.path.join(args.input[0], tensor.name + '.dat'))
                    for tensor in graph.inputs)
            else:
                inputs = tuple(
                    nnef_io.read_nnef_tensor(path) for path in args.input)

            # Second pass: re-read with the actual input shapes.
            reader = nnef_io.Reader(parser_configs=parser_configs,
                                    input_shape=tuple(
                                        list(array.shape) for array in inputs))

            graph = reader(nnef_path)

            tensor_hooks = []

            stats_hook = None
            if args.stats:
                stats_hook = backend.StatisticsHook()
                tensor_hooks.append(stats_hook)

            if write_outputs and args.output_names is not None:
                if '*' in args.output_names:
                    export_names = [
                        t.name for t in graph.tensors
                        if not t.is_constant and not t.is_variable
                    ]
                else:
                    export_names = args.output_names
                tensor_hooks.append(
                    backend.ActivationExportHook(
                        tensor_names=export_names,
                        output_directory=args.output))

            if args.permissive:
                backend.try_to_fix_unsupported_attributes(graph)

            outputs = backend.run(nnef_graph=graph,
                                  inputs=inputs,
                                  device=args.device,
                                  custom_operations=get_custom_runners(
                                      args.custom_operations),
                                  tensor_hooks=tensor_hooks)

            if write_outputs and args.output_names is None:
                if args.output is None:
                    for array in outputs:
                        nnef.write_tensor(sys.stdout, array)
                else:
                    for tensor, array in zip(graph.outputs, outputs):
                        nnef_io.write_nnef_tensor(
                            os.path.join(args.output, tensor.name + '.dat'),
                            array)

            if stats_hook:
                # Relative stats paths are resolved against the model
                # directory. For a bare graph file they are resolved against
                # the directory containing that file, because joining onto
                # the file path itself yields an unwritable location.
                if os.path.isdir(nnef_path):
                    stats_base = nnef_path
                else:
                    stats_base = os.path.dirname(nnef_path)
                if args.stats.endswith(('/', '\\')):
                    stats_path = os.path.join(stats_base, args.stats,
                                              'graph.stats')
                else:
                    stats_path = os.path.join(stats_base, args.stats)
                stats_hook.save_statistics(stats_path)

            # tmp_dir is only set when the network is a .tgz archive, so the
            # statistics were written into the extracted copy: re-pack it.
            if tmp_dir and args.stats and _is_inside(nnef_path, args.stats):
                print("Info: Changing input archive", file=sys.stderr)
                backup_path = args.network + '.nnef-tools-backup'
                shutil.move(args.network, backup_path)
                utils.tgz_compress(dir_path=nnef_path,
                                   file_path=args.network)
                os.remove(backup_path)
        finally:
            if tmp_dir:
                shutil.rmtree(tmp_dir)
    except (utils.NNEFToolsException, nnef.Error) as e:
        print("Error: " + str(e), file=sys.stderr)
        # sys.exit rather than the site-module `exit`, which may be absent.
        sys.exit(1)