示例#1
0
def get_reader(input_format,
               output_format,
               input_shape=None,
               permissive=False,
               with_weights=True,
               custom_converters=None):
    """Create the graph reader for ``input_format``.

    Args:
        input_format: Source format: 'nnef', 'tensorflow-pb', 'tensorflow-py',
            'tensorflow-lite', 'onnx' or 'caffe'.
        output_format: Target format of the conversion. Only used when
            ``input_format`` is 'nnef': it selects the converter-specific NNEF
            parser config and must be 'tensorflow-pb', 'tensorflow-py',
            'tensorflow-lite', 'onnx' or 'caffe'.
        input_shape: Passed to the 'tensorflow-pb' and 'onnx' readers; ignored
            by the others.
        permissive: Not used by this function.
        with_weights: Not used by this function.
        custom_converters: For 'nnef', loaded as additional parser configs;
            for 'tensorflow-py', turned into custom traceable functions.
            Ignored by the other readers.

    Returns:
        A ``Reader`` from the ``nnef_tools.io`` module matching
        ``input_format``.

    Raises:
        ValueError: If ``input_format`` is unsupported, or if it is 'nnef' and
            ``output_format`` is unsupported.
    """
    if input_format == 'nnef':
        # Select the converter first: an unsupported target must fail loudly
        # (an ``assert`` would be stripped under ``python -O``) and before any
        # reader module is imported.
        if output_format in ('tensorflow-pb', 'tensorflow-py',
                             'tensorflow-lite'):
            from nnef_tools.conversion.tensorflow import nnef_to_tf as converter
        elif output_format == 'onnx':
            from nnef_tools.conversion.onnx import nnef_to_onnx as converter
        elif output_format == 'caffe':
            from nnef_tools.conversion.caffe import nnef_to_caffe as converter
        else:
            raise ValueError("Unsupported output format for NNEF input: "
                             "{!r}".format(output_format))

        from nnef_tools.io.nnef.nnef_io import Reader

        configs = [NNEFParserConfig.STANDARD_CONFIG, converter.ParserConfig]
        configs += NNEFParserConfig.load_configs(custom_converters,
                                                 load_standard=False)

        # Only the caffe target reads the graph with ``unify`` enabled.
        return Reader(parser_configs=configs,
                      unify=(output_format == 'caffe'))
    elif input_format == 'tensorflow-pb':
        # TODO: custom converters are not passed to this reader yet.
        from nnef_tools.io.tensorflow.tf_pb_io import Reader
        return Reader(convert_to_tf_py=True, input_shape=input_shape)
    elif input_format == 'tensorflow-py':
        from nnef_tools.io.tensorflow.tf_py_io import Reader
        if custom_converters:
            custom_traceable_functions = get_tf_py_custom_traceable_functions(
                custom_converters)
        else:
            custom_traceable_functions = None
        return Reader(expand_gradients=True,
                      custom_traceable_functions=custom_traceable_functions)
    elif input_format == 'tensorflow-lite':
        # TODO: custom converters are not passed to this reader yet.
        from nnef_tools.io.tensorflow.tflite_io import Reader
        return Reader(convert_to_tf_py=True)
    elif input_format == 'onnx':
        # TODO: custom converters are not passed to this reader yet.
        from nnef_tools.io.onnx.onnx_io import Reader
        return Reader(propagate_shapes=True, input_shape=input_shape)
    elif input_format == 'caffe':
        # TODO: custom converters are not passed to this reader yet.
        from nnef_tools.io.caffe.caffe_io import Reader
        return Reader()

    raise ValueError("Unsupported input format: {!r}".format(input_format))
示例#2
0
def main():
    """Generate input tensors for an NNEF network and write them out.

    Arguments come from ``get_args(sys.argv)``. The graph is read with the
    standard parser configs plus ``args.custom_operations`` and with
    ``input_shape=args.shape``. One value is created per graph input via
    ``args.params.create_input`` (numpy RNG seeded with ``args.seed`` unless it
    is -1). Values are written to ``<args.output>/<input name>.dat`` when an
    output is given; otherwise they are streamed to stdout in binary mode,
    which is refused when stdout is a terminal.

    Any exception is reported on stderr as ``Error: <message>`` and the
    process exits with status 1.
    """
    try:
        args = get_args(sys.argv)

        if not args.output:
            # Binary tensor data must not be dumped onto an interactive terminal.
            if sys.stdout.isatty():
                raise utils.NNEFToolsException("No output provided.")
            utils.set_stdout_to_binary()

        args.params = InputSources(args.params)

        # -1 leaves numpy's global random state unseeded.
        if args.seed != -1:
            np.random.seed(args.seed)

        parser_configs = NNEFParserConfig.load_configs(args.custom_operations,
                                                       load_standard=True)
        reader = nnef_io.Reader(parser_configs=parser_configs,
                                input_shape=args.shape)

        # ``args.network`` may be a model directory or the graph file itself.
        if os.path.isdir(args.network):
            graph_path = os.path.join(args.network, 'graph.nnef')
        else:
            graph_path = args.network

        # read without weights
        graph = reader(graph_path)

        inputs = tuple(
            args.params.create_input(name=tensor.name,
                                     np_dtype=tensor.get_numpy_dtype(),
                                     shape=tensor.shape,
                                     allow_bigger_batch=True)
            for tensor in graph.inputs)

        if args.output:
            for tensor, array in zip(graph.inputs, inputs):
                nnef_io.write_nnef_tensor(
                    os.path.join(args.output, tensor.name + '.dat'), array)
        else:
            for array in inputs:
                nnef.write_tensor(sys.stdout, array)
    except Exception as e:
        print('Error: {}'.format(e), file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is intended for
        # the interactive interpreter and may be missing (python -S, frozen apps).
        sys.exit(1)
示例#3
0
def run_using_argv(argv):
    """Run an NNEF network on input tensors, configured by command-line ``argv``.

    ``args.network`` may be a model directory, a graph file, or a ``.tgz``
    archive. An archive is extracted into a temporary directory next to it,
    and that directory is always removed afterwards.

    Inputs are read in one of three ways:

    * from stdin (binary) when ``args.input`` is not given;
    * from ``<dir>/<input name>.dat`` when ``args.input`` is a single
      directory;
    * from the listed tensor files otherwise.

    Output handling depends on ``args.output_names``:

    * ``None``: the network outputs are written to stdout, or to
      ``<args.output>/<output name>.dat``.
    * A non-empty list: the named activations are exported through
      ``backend.ActivationExportHook`` instead. ``'*'`` selects every
      non-constant, non-variable tensor.
    * An empty list: nothing is written.

    When ``args.stats`` is set, statistics are saved relative to the network
    path. If ``args.stats`` ends with a path separator they go to
    ``graph.stats`` inside that directory. If the network was a ``.tgz``
    archive and the statistics landed inside the extracted tree, the archive
    is rebuilt so that it contains them.

    ``utils.NNEFToolsException`` and ``nnef.Error`` are reported on stderr as
    ``Error: <message>`` and the process exits with status 1. Other exceptions
    propagate.
    """
    try:
        args = get_args(argv)
        # None -> write the network outputs; non-empty list -> export those
        # activations through a hook instead; empty list -> write nothing.
        write_outputs = args.output_names is None or bool(args.output_names)

        if args.input is None:
            # Inputs will be read from stdin as binary data; refuse a terminal.
            if sys.stdin.isatty():
                raise utils.NNEFToolsException("No input provided!")
            utils.set_stdin_to_binary()

        if write_outputs and args.output is None:
            if sys.stdout.isatty():
                raise utils.NNEFToolsException("No output provided!")
            utils.set_stdout_to_binary()

        parent_dir_of_input_model = os.path.dirname(
            utils.path_without_trailing_separator(args.network))
        tmp_dir = None

        try:
            # Create and extract inside the try so that the temporary directory
            # is removed even when the archive cannot be extracted.
            if args.network.endswith('.tgz'):
                nnef_path = tmp_dir = tempfile.mkdtemp(
                    prefix="nnef_", dir=parent_dir_of_input_model)
                utils.tgz_extract(args.network, nnef_path)
            else:
                nnef_path = args.network

            parser_configs = NNEFParserConfig.load_configs(
                args.custom_operations, load_standard=True)

            # First pass (infer_shapes=False; only graph.nnef when given a
            # directory — "read without weights") just to learn the inputs.
            reader = nnef_io.Reader(parser_configs=parser_configs,
                                    infer_shapes=False)
            graph = reader(
                os.path.join(nnef_path, 'graph.nnef')
                if os.path.isdir(nnef_path) else nnef_path)

            inputs = _read_network_inputs(args.input, graph.inputs)

            # Second pass: re-read the network with the actual input shapes.
            reader = nnef_io.Reader(
                parser_configs=parser_configs,
                input_shape=tuple(list(value.shape) for value in inputs))
            graph = reader(nnef_path)

            tensor_hooks = []

            stats_hook = None
            if args.stats:
                stats_hook = backend.StatisticsHook()
                tensor_hooks.append(stats_hook)

            if write_outputs and args.output_names is not None:
                if '*' in args.output_names:
                    export_names = [
                        t.name for t in graph.tensors
                        if not t.is_constant and not t.is_variable
                    ]
                else:
                    export_names = args.output_names
                tensor_hooks.append(
                    backend.ActivationExportHook(
                        tensor_names=export_names,
                        output_directory=args.output))

            if args.permissive:
                backend.try_to_fix_unsupported_attributes(graph)

            outputs = backend.run(nnef_graph=graph,
                                  inputs=inputs,
                                  device=args.device,
                                  custom_operations=get_custom_runners(
                                      args.custom_operations),
                                  tensor_hooks=tensor_hooks)

            if write_outputs and args.output_names is None:
                if args.output is None:
                    for array in outputs:
                        nnef.write_tensor(sys.stdout, array)
                else:
                    for tensor, array in zip(graph.outputs, outputs):
                        nnef_io.write_nnef_tensor(
                            os.path.join(args.output, tensor.name + '.dat'),
                            array)

            # Explicit None check: the hook object's truthiness is irrelevant.
            if stats_hook is not None:
                stats_hook.save_statistics(
                    _resolve_stats_path(nnef_path, args.stats))

            # tmp_dir is only set for a '.tgz' network, so this rebuilds the
            # input archive when the statistics were written into its tree.
            if tmp_dir and args.stats and _is_inside(nnef_path, args.stats):
                print("Info: Changing input archive", file=sys.stderr)
                _rebuild_network_archive(args.network, nnef_path)
        finally:
            if tmp_dir:
                shutil.rmtree(tmp_dir)
    except (utils.NNEFToolsException, nnef.Error) as e:
        print("Error: " + str(e), file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is intended for
        # the interactive interpreter only.
        sys.exit(1)


def _read_network_inputs(input_paths, graph_inputs):
    """Read one value per graph input.

    Reads from stdin when ``input_paths`` is None, from
    ``<dir>/<input name>.dat`` when it holds a single directory, and otherwise
    from the listed files in order.
    """
    if input_paths is None:
        return tuple(nnef.read_tensor(sys.stdin) for _ in graph_inputs)
    if len(input_paths) == 1 and os.path.isdir(input_paths[0]):
        input_dir = input_paths[0]
        return tuple(
            nnef_io.read_nnef_tensor(
                os.path.join(input_dir, tensor.name + '.dat'))
            for tensor in graph_inputs)
    return tuple(nnef_io.read_nnef_tensor(path) for path in input_paths)


def _resolve_stats_path(nnef_path, stats):
    """Return ``<nnef_path>/<stats>``, or ``<nnef_path>/<stats>/graph.stats``
    when ``stats`` ends with a path separator."""
    if stats.endswith(('/', '\\')):
        return os.path.join(nnef_path, stats, 'graph.stats')
    return os.path.join(nnef_path, stats)


def _rebuild_network_archive(archive_path, dir_path):
    """Re-create ``archive_path`` from ``dir_path``.

    The original archive is kept under a backup name until the new one has
    been written, and is put back if compression fails.
    """
    backup_path = archive_path + '.nnef-tools-backup'
    shutil.move(archive_path, backup_path)
    try:
        utils.tgz_compress(dir_path=dir_path, file_path=archive_path)
    except BaseException:
        # Drop any partial archive and restore the original before re-raising.
        if os.path.exists(archive_path):
            os.remove(archive_path)
        shutil.move(backup_path, archive_path)
        raise
    os.remove(backup_path)