def write(nnef_graph,  # type: NNEFGraph
          tgz_or_dir_path,  # type: str
          write_weights=True,  # type: bool
          raise_on_missing_weight=True,  # type: bool
          extensions=None,  # type: typing.Optional[typing.List[str]]
          fragments=None,  # type: typing.Optional[str]
          only_print_used_fragments=False,  # type: bool
          compression_level=0,  # type: int
          ):
    # type: (...) -> None
    """Serialize an NNEF graph into a directory or a ``.tgz`` archive.

    If ``tgz_or_dir_path`` ends with ``.tgz``, the files are first written to a
    fresh temporary directory (prefix ``nnef_``), packed with
    ``utils.tgz_compress`` at ``compression_level``, and the temporary
    directory is deleted afterwards (also on error). Otherwise
    ``tgz_or_dir_path`` is used as the output directory and is created if it
    does not exist yet.

    Files produced:
      * ``graph.nnef`` - always, via ``_print`` with ``extensions``,
        ``fragments`` and ``only_print_used_fragments`` forwarded.
      * ``graph.quant`` - only if some tensor has ``quantization`` set.
      * weight files - only if ``write_weights``; ``raise_on_missing_weight``
        is forwarded to ``_write_weights``.
    """
    as_archive = tgz_or_dir_path.endswith('.tgz')
    target_dir = None
    try:
        if not as_archive:
            target_dir = tgz_or_dir_path
            if not os.path.exists(target_dir):
                os.makedirs(target_dir)
        else:
            # Scratch directory that is packed into the archive and removed below.
            target_dir = tempfile.mkdtemp(prefix="nnef_")

        graph_path = os.path.join(target_dir, "graph.nnef")
        with open(graph_path, "w") as graph_file:
            _print(nnef_graph,
                   file_handle=graph_file,
                   extensions=extensions,
                   fragments=fragments,
                   only_print_used_fragments=only_print_used_fragments)

        # graph.quant is emitted only when at least one tensor carries quantization info.
        is_quantized = any(tensor.quantization is not None for tensor in nnef_graph.tensors)
        if is_quantized:
            quant_path = os.path.join(target_dir, "graph.quant")
            with open(quant_path, "w") as quant_file:
                _print_quantization(nnef_graph, file_handle=quant_file)

        if write_weights:
            _write_weights(nnef_graph,
                           dir_path=target_dir,
                           raise_on_missing_weight=raise_on_missing_weight)

        if as_archive:
            utils.tgz_compress(target_dir, tgz_or_dir_path, compression_level=compression_level)
    finally:
        # Only the scratch directory is ours to delete; a caller-supplied directory is kept.
        if as_archive and target_dir:
            shutil.rmtree(target_dir)
def _compress_keeping_backup(dir_path, archive_path):
    # type: (str, str) -> None
    """Pack ``dir_path`` into ``archive_path`` without losing an existing archive.

    An existing ``archive_path`` is first renamed to
    ``archive_path + '.nnef-tools-backup'`` and deleted only after
    ``utils.tgz_compress`` succeeds. If compression raises, any partially
    written archive is removed, the backup (if one was made) is moved back to
    ``archive_path``, and the exception propagates.
    """
    backup_path = archive_path + '.nnef-tools-backup'
    had_archive = os.path.exists(archive_path)
    if had_archive:
        shutil.move(archive_path, backup_path)
    try:
        utils.tgz_compress(dir_path=dir_path, file_path=archive_path)
    except Exception:
        if os.path.exists(archive_path):
            os.remove(archive_path)
        if had_archive:
            shutil.move(backup_path, archive_path)
        raise
    if had_archive:
        os.remove(backup_path)


def main():
    """Command-line entry point: generate weights for an NNEF model.

    Arguments come from ``get_args(sys.argv)``. ``args.network`` may be a model
    directory, a ``.tgz`` archive (extracted to a temporary directory next to
    it) or a ``.nnef``/``.txt`` graph file. The graph is read without weights
    and handed to ``generate_weights``. If a temporary directory was used and
    weights were generated, the result is packed into a ``.tgz``: over the
    input archive for ``.tgz`` input, otherwise into
    ``<network without extension>.nnef.tgz``. Directory input is updated in
    place by ``generate_weights``.

    Any ``Exception`` is reported on stderr as ``Error: ...`` and the process
    exits with status 1.
    """
    try:
        args = get_args(sys.argv)
        args.params = InputSources(args.params)
        if args.seed != -1:
            np.random.seed(args.seed)

        parent_dir_of_input_model = os.path.dirname(
            utils.path_without_trailing_separator(args.network))
        tmp_dir = None
        if args.network.endswith('.tgz'):
            # Unpack the archive into a temporary directory next to the input model.
            nnef_path = tmp_dir = tempfile.mkdtemp(prefix="nnef_", dir=parent_dir_of_input_model)
            utils.tgz_extract(args.network, nnef_path)
        else:
            nnef_path = args.network

        try:
            parser_configs = NNEFParserConfig.load_configs(args.custom_operations, load_standard=True)
            reader = nnef_io.Reader(parser_configs=parser_configs)
            # read without weights
            graph = reader(os.path.join(nnef_path, 'graph.nnef')
                           if os.path.isdir(nnef_path) else nnef_path)

            if os.path.isdir(nnef_path):
                output_path = nnef_path
            elif nnef_path.endswith(('.nnef', '.txt')):
                # Graph-file input: generated weights go to a new temporary directory.
                output_path = tmp_dir = tempfile.mkdtemp(prefix="nnef_", dir=parent_dir_of_input_model)
            else:
                # Explicit error instead of `assert False`: asserts are stripped under
                # `python -O`, which would turn this into a NameError on `output_path`.
                raise ValueError("Unsupported network path (expected a directory, a .tgz archive "
                                 "or a .nnef/.txt file): {}".format(args.network))

            did_generate_weights = generate_weights(graph, nnef_path, output_path,
                                                    input_sources=args.params)
            nnef_path = output_path

            if tmp_dir and did_generate_weights:
                if args.network.endswith('.tgz'):
                    print("Info: Changing input archive", file=sys.stderr)
                    _compress_keeping_backup(nnef_path, args.network)
                else:
                    _compress_keeping_backup(nnef_path, args.network.rsplit('.', 1)[0] + '.nnef.tgz')
        finally:
            if tmp_dir:
                shutil.rmtree(tmp_dir)
    except Exception as e:
        print('Error: {}'.format(e), file=sys.stderr)
        sys.exit(1)
def run_using_argv(argv):
    """Run an NNEF model on the backend, driven by command-line arguments.

    ``argv`` is parsed with ``get_args``. ``args.network`` may be a model
    directory, a ``.tgz`` archive (extracted to a temporary directory next to
    it) or a graph file.

    Inputs are read, one per graph input:
      * from stdin when ``args.input`` is None (stdin must not be a terminal);
      * from ``<dir>/<input name>.dat`` when ``args.input`` is a single directory;
      * otherwise from the listed files, in order.

    Outputs, unless ``args.output_names`` is given but empty:
      * without ``args.output_names``: the graph outputs go to stdout (stdout
        must not be a terminal) or to ``<args.output>/<output name>.dat``;
      * with ``args.output_names``: the named tensors (``'*'`` = every
        non-constant, non-variable tensor) are exported through
        ``backend.ActivationExportHook`` into ``args.output``.

    If ``args.stats`` is set, statistics are saved relative to the model
    directory (a trailing path separator means ``<dir>/graph.stats``). When the
    stats file was written inside an extracted ``.tgz`` (per ``_is_inside``),
    the archive is repacked so the statistics are kept.

    ``utils.NNEFToolsException`` and ``nnef.Error`` are reported on stderr as
    ``Error: ...`` and the process exits with status 1.
    """
    try:
        args = get_args(argv)
        # Outputs are written unless output_names was given but is empty.
        write_outputs = args.output_names is None or args.output_names

        if args.input is None:
            if sys.stdin.isatty():
                raise utils.NNEFToolsException("No input provided!")
            utils.set_stdin_to_binary()
        if write_outputs and args.output is None:
            if sys.stdout.isatty():
                raise utils.NNEFToolsException("No output provided!")
            utils.set_stdout_to_binary()

        parent_dir_of_input_model = os.path.dirname(
            utils.path_without_trailing_separator(args.network))
        tmp_dir = None
        if args.network.endswith('.tgz'):
            nnef_path = tmp_dir = tempfile.mkdtemp(prefix="nnef_", dir=parent_dir_of_input_model)
            utils.tgz_extract(args.network, nnef_path)
        else:
            nnef_path = args.network

        try:
            parser_configs = NNEFParserConfig.load_configs(args.custom_operations, load_standard=True)

            # read without weights
            reader = nnef_io.Reader(parser_configs=parser_configs, infer_shapes=False)
            graph = reader(os.path.join(nnef_path, 'graph.nnef')
                           if os.path.isdir(nnef_path) else nnef_path)

            if args.input is None:
                inputs = tuple(nnef.read_tensor(sys.stdin) for _ in range(len(graph.inputs)))
            elif len(args.input) == 1 and os.path.isdir(args.input[0]):
                inputs = tuple(nnef_io.read_nnef_tensor(os.path.join(args.input[0], tensor.name + '.dat'))
                               for tensor in graph.inputs)
            else:
                inputs = tuple(nnef_io.read_nnef_tensor(path) for path in args.input)

            # Re-read the graph, now passing the shapes of the loaded inputs.
            reader = nnef_io.Reader(parser_configs=parser_configs,
                                    input_shape=tuple(list(array.shape) for array in inputs))
            graph = reader(nnef_path)

            tensor_hooks = []
            stats_hook = None
            if args.stats:
                stats_hook = backend.StatisticsHook()
                tensor_hooks.append(stats_hook)
            if write_outputs and args.output_names is not None:
                if '*' in args.output_names:
                    export_names = [t.name for t in graph.tensors
                                    if not t.is_constant and not t.is_variable]
                else:
                    export_names = args.output_names
                tensor_hooks.append(backend.ActivationExportHook(tensor_names=export_names,
                                                                 output_directory=args.output))

            if args.permissive:
                backend.try_to_fix_unsupported_attributes(graph)

            outputs = backend.run(nnef_graph=graph,
                                  inputs=inputs,
                                  device=args.device,
                                  custom_operations=get_custom_runners(args.custom_operations),
                                  tensor_hooks=tensor_hooks)

            if write_outputs and args.output_names is None:
                if args.output is None:
                    for array in outputs:
                        nnef.write_tensor(sys.stdout, array)
                else:
                    for tensor, array in zip(graph.outputs, outputs):
                        nnef_io.write_nnef_tensor(os.path.join(args.output, tensor.name + '.dat'), array)

            if stats_hook is not None:
                # Resolve relative stats paths against the model directory. For a bare
                # graph-file network that is the file's parent; joining onto the file
                # path itself would name a path inside a regular file.
                model_dir = nnef_path if os.path.isdir(nnef_path) else os.path.dirname(nnef_path)
                if args.stats.endswith(('/', '\\')):
                    stats_path = os.path.join(model_dir, args.stats, 'graph.stats')
                else:
                    stats_path = os.path.join(model_dir, args.stats)
                stats_hook.save_statistics(stats_path)

            # tmp_dir is only created for .tgz input, so this always rewrites the input archive.
            if tmp_dir and args.stats and _is_inside(nnef_path, args.stats):
                print("Info: Changing input archive", file=sys.stderr)
                backup_path = args.network + '.nnef-tools-backup'
                shutil.move(args.network, backup_path)
                try:
                    utils.tgz_compress(dir_path=nnef_path, file_path=args.network)
                except Exception:
                    # Restore the original archive rather than leaving only the backup behind.
                    if os.path.exists(args.network):
                        os.remove(args.network)
                    shutil.move(backup_path, args.network)
                    raise
                os.remove(backup_path)
        finally:
            if tmp_dir:
                shutil.rmtree(tmp_dir)
    except (utils.NNEFToolsException, nnef.Error) as e:
        print("Error: " + str(e), file=sys.stderr)
        sys.exit(1)