def read(path, parser_configs=None):
    # type: (str, typing.Optional[typing.List[NNEFParserConfig]])->NNEFGraph
    """Load an NNEF model from a directory, a ``.tgz`` archive or a graph file.

    Args:
        path: A directory, a ``.tgz`` archive, or a ``.nnef``/``.txt`` graph
            file. Directories and archives are read with weights; a bare
            graph file is read without weights.
        parser_configs: Optional list of parser configs; they are merged with
            ``NNEFParserConfig.combine_configs`` before parsing.

    Returns:
        Whatever ``_read`` produces for the parsed, shape-inferred graph
        (an ``NNEFGraph`` according to the type comment).

    Raises:
        utils.NNEFToolsException: If ``path`` is neither a directory nor has
            one of the supported extensions.
    """
    if not (path.endswith(('.tgz', '.nnef', '.txt')) or os.path.isdir(path)):
        raise utils.NNEFToolsException(
            "Only .tgz or .nnef or .txt files or directories are supported")

    config = NNEFParserConfig.combine_configs(parser_configs or [])

    load_path = None
    extracted = False

    try:
        # A directory wins over the extension checks, exactly as before.
        if os.path.isdir(path):
            load_path, with_weights = path, True
        elif path.endswith('.tgz'):
            # Flag first, so cleanup only depends on load_path being set.
            extracted = True
            with_weights = True
            load_path = tempfile.mkdtemp(prefix="nnef_")
            utils.tgz_extract(path, load_path)
        elif path.endswith(('.nnef', '.txt')):
            load_path, with_weights = path, False
        else:
            assert False

        # If there are fragments in the graph and also in the parser config,
        # drop the non-standard fragments from the config to avoid duplicate
        # fragment definitions.
        if config.fragments:
            graph_pattern = re.compile(r"^graph\s|\sgraph\s")
            fragment_pattern = re.compile(r"^fragment\s|\sfragment\s")

            if os.path.isdir(load_path):
                graph_file = os.path.join(load_path, 'graph.nnef')
            else:
                graph_file = load_path

            with open(graph_file, 'r') as stream:
                for line in stream:
                    if fragment_pattern.search(line):
                        config.fragments = NNEFParserConfig.STANDARD_CONFIG.fragments
                        break
                    # Stop scanning once the graph definition is reached.
                    if graph_pattern.search(line):
                        break

        loaded = config.load_graph(load_path)
        inferred = config.infer_shapes(loaded)
        return _read(parser_graph=inferred, with_weights=with_weights)
    finally:
        # Only the temporary extraction directory is ours to delete.
        if extracted and load_path:
            shutil.rmtree(load_path)
def get_writer(input_format, output_format, compress, with_weights=True, custom_converters=None):
    """Create the writer object for the requested output format.

    The matching ``Writer`` class is imported lazily, so only the backend
    for the selected format has to be importable.

    Args:
        input_format: Not used by this function; kept for interface
            compatibility.
        output_format: One of ``'nnef'``, ``'tensorflow-py'``,
            ``'tensorflow-pb'``, ``'tensorflow-lite'``, ``'onnx'``,
            ``'caffe'`` or ``'caffe2'``.
        compress: Compression level for the NNEF writer; negative values are
            passed on as ``0``. Ignored by the other formats.
        with_weights: Passed as ``write_weights`` to the NNEF and
            tensorflow-py writers.
        custom_converters: Custom converter specification. It is used to load
            extra fragments for NNEF, and custom imports / op protos for
            tensorflow-py.

    Returns:
        A ``Writer`` instance for ``output_format``.

    Raises:
        AssertionError: If ``output_format`` is not supported. It is raised
            explicitly (instead of ``assert False``) so the check survives
            ``python -O`` and never silently returns ``None``.
    """
    if output_format == 'nnef':
        from nnef_tools.io.nnef.nnef_io import Writer
        fragments = NNEFParserConfig.combine_configs(
            NNEFParserConfig.load_configs(custom_converters, load_standard=False)).fragments
        return Writer(write_weights=with_weights,
                      fragments=fragments,
                      only_print_used_fragments=True,
                      compression_level=compress if compress >= 0 else 0)

    if output_format == 'tensorflow-py':
        from nnef_tools.io.tensorflow.tf_py_io import Writer
        if custom_converters:
            custom_imports, custom_op_protos = get_tf_py_imports_and_op_protos(custom_converters)
        else:
            custom_imports, custom_op_protos = None, None
        return Writer(write_weights=with_weights,
                      custom_op_protos=custom_op_protos,
                      custom_imports=custom_imports)

    if output_format == 'tensorflow-pb':
        from nnef_tools.io.tensorflow.tf_pb_io import Writer
        return Writer(convert_from_tf_py=True)

    if output_format == 'tensorflow-lite':
        from nnef_tools.io.tensorflow.tflite_io import Writer
        return Writer(convert_from_tf_py=True)

    if output_format == 'onnx':
        from nnef_tools.io.onnx.onnx_io import Writer
        return Writer()

    if output_format == 'caffe':
        from nnef_tools.io.caffe.caffe_io import Writer
        return Writer()

    if output_format == 'caffe2':
        from nnef_tools.io.caffe2.caffe2_io import Writer
        return Writer()

    # Same exception type as the former `assert False`, but not stripped by -O.
    raise AssertionError("Unsupported output format: {!r}".format(output_format))
def read(path, parser_configs=None):
    # type: (str, typing.Optional[typing.List[NNEFParserConfig]])->NNEFGraph
    """Load an NNEF model from a directory, a ``.tgz`` archive or a ``.nnef`` file.

    Args:
        path: A directory, a ``.tgz`` archive or a ``.nnef`` graph file.
            Directories and archives are read with weights; a bare ``.nnef``
            file is read without weights.
        parser_configs: Optional list of parser configs; they are merged with
            ``NNEFParserConfig.combine_configs`` before parsing.

    Returns:
        Whatever ``_read`` produces for the parsed, shape-inferred graph
        (an ``NNEFGraph`` according to the type comment).

    Raises:
        AssertionError: If ``path`` is not a directory, ``.tgz`` or ``.nnef``
            file. It is raised explicitly rather than via ``assert`` so the
            validation is still performed under ``python -O``.
    """
    if not (path.endswith(('.tgz', '.nnef')) or os.path.isdir(path)):
        raise AssertionError("Only .tgz or .nnef files or directories are supported")

    parser_config = NNEFParserConfig.combine_configs(parser_configs if parser_configs else [])

    path_to_load = None
    compressed = False

    try:
        # A directory wins over the extension checks.
        if os.path.isdir(path):
            with_weights = True
            path_to_load = path
        elif path.endswith('.tgz'):
            # Archives are unpacked into a temporary directory, removed below.
            compressed = True
            with_weights = True
            path_to_load = tempfile.mkdtemp(prefix="nnef_")
            _tgz_extract(path, path_to_load)
        elif path.endswith('.nnef'):
            with_weights = False
            path_to_load = path
        else:
            # Only reachable if the path stopped being a directory after the
            # validation above; fail loudly instead of continuing with None.
            raise AssertionError("Only .tgz or .nnef files or directories are supported")

        return _read(parser_graph=parser_config.infer_shapes(parser_config.load_graph(path_to_load)),
                     with_weights=with_weights)
    finally:
        if compressed and path_to_load:
            shutil.rmtree(path_to_load)
def read(
        path,  # type: str
        parser_configs=None,  # type: typing.Optional[typing.List[NNEFParserConfig]]
        input_shape=None,  # type: typing.Union[None, typing.List[int], typing.Dict[str, typing.List[int]]]
):
    # type: (...)->NNEFGraph
    """Load an NNEF model, optionally overriding the shapes of its inputs.

    Args:
        path: A directory, a ``.tgz`` archive, or a ``.nnef``/``.txt`` graph
            file. Directories and archives are read with weights; a bare
            graph file is read without weights.
        parser_configs: Optional list of parser configs; they are merged with
            ``NNEFParserConfig.combine_configs`` before parsing.
        input_shape: Optional shape override for ``external`` operations.
            A list is applied to every ``external`` operation. A dict is
            looked up by the operation's ``'output'`` name, and operations
            whose name is missing from the dict keep their shape.

    Returns:
        Whatever ``_read`` produces for the parsed graph (an ``NNEFGraph``
        according to the type comment).

    Raises:
        utils.NNEFToolsException: If ``path`` is not supported, or if
            ``input_shape`` is given but is neither a list nor a dict.
    """
    if not (path.endswith(('.tgz', '.nnef', '.txt')) or os.path.isdir(path)):
        raise utils.NNEFToolsException(
            "Only .tgz or .nnef or .txt files or directories are supported")

    config = NNEFParserConfig.combine_configs(parser_configs or [])

    load_path = None
    extracted = False

    try:
        # A directory wins over the extension checks, exactly as before.
        if os.path.isdir(path):
            load_path, with_weights = path, True
        elif path.endswith('.tgz'):
            # Flag first, so cleanup only depends on load_path being set.
            extracted = True
            with_weights = True
            load_path = tempfile.mkdtemp(prefix="nnef_")
            utils.tgz_extract(path, load_path)
        elif path.endswith(('.nnef', '.txt')):
            load_path, with_weights = path, False
        else:
            assert False

        # If there are fragments in the graph and also in the parser config,
        # drop the non-standard fragments from the config to avoid duplicate
        # fragment definitions.
        if config.fragments:
            graph_pattern = re.compile(r"^graph\s|\sgraph\s")
            fragment_pattern = re.compile(r"^fragment\s|\sfragment\s")

            if os.path.isdir(load_path):
                graph_file = os.path.join(load_path, 'graph.nnef')
            else:
                graph_file = load_path

            with open(graph_file, 'r') as stream:
                for line in stream:
                    if fragment_pattern.search(line):
                        config.fragments = NNEFParserConfig.STANDARD_CONFIG.fragments
                        break
                    # Stop scanning once the graph definition is reached.
                    if graph_pattern.search(line):
                        break

        graph = config.load_graph(load_path)

        if input_shape is not None:
            if not isinstance(input_shape, (list, dict)):
                raise utils.NNEFToolsException(
                    "input_shape must be list or dict")

            per_tensor = isinstance(input_shape, dict)
            for op in graph.operations:
                if op.name != 'external':
                    continue
                if not per_tensor:
                    # One shape for every external input.
                    op.attribs['shape'] = input_shape
                    continue
                tensor_name = op.outputs['output']
                if tensor_name in input_shape:
                    op.attribs['shape'] = input_shape[tensor_name]

        # Shape inference runs after the overrides; its return value is not used.
        config.infer_shapes(graph)

        return _read(parser_graph=graph, with_weights=with_weights)
    finally:
        # Only the temporary extraction directory is ours to delete.
        if extracted and load_path:
            shutil.rmtree(load_path)