def __call__(self, path, input_shapes=None):
    """Load an NNEF model from a folder or a compressed archive and build a graph.

    Args:
        path: Path of the model folder, or of a ``.tgz``/``.gz`` archive that
            is unpacked with ``tgz_extract`` into a temporary directory first.
        input_shapes: Optional shapes passed to ``nnef.infer_shapes`` as
            ``external_shapes``; an empty dict is used when this is falsy.
            Only consulted when ``self._infer_shapes`` is truthy.

    Returns:
        The result of ``_build_graph`` applied to the loaded NNEF graph.

    Raises:
        IOError: If ``path`` (after extraction, when applicable) is not a
            directory.
    """
    # os.path.splitext returns a (root, ext) pair and ext keeps its leading
    # dot, so the extension itself must be compared against dotted suffixes.
    # Comparing the whole tuple against 'tgz'/'gz' can never match.
    extension = os.path.splitext(path)[1]
    compressed = extension in ('.tgz', '.gz') and not os.path.isdir(path)

    folder = None
    try:
        if compressed:
            folder = tempfile.mkdtemp(prefix="nnef_")
            tgz_extract(path, folder)
            path = folder

        if not os.path.isdir(path):
            raise IOError(
                "NNEF model must be a (compressed) folder, but an uncompressed file was provided"
            )

        nnef_graph = nnef.load_graph(path, stdlib=self._stdlib, lowered=self._decomposed)

        if self._infer_shapes:
            nnef.infer_shapes(nnef_graph,
                              external_shapes=input_shapes or {},
                              custom_shapes=self._custom_shapes or {})

        return _build_graph(nnef_graph)
    finally:
        # Remove the temporary extraction directory on success and on failure.
        if folder is not None:
            shutil.rmtree(folder)
def test_reshape(self):
    """Parse a graph using ``reshape`` with ``axis_start``/``axis_count`` and infer its shapes.

    The visible body makes no explicit assertions; it only runs the parse and
    shape-inference calls.
    """
    source_text = """ version 1.0; graph G( input ) -> ( output ) { input = external(shape = [1,2,3,4]); output = reshape(input, axis_start = 1, axis_count = 2, shape = [6]); } """
    parsed = nnef.parse_string(source_text)
    nnef.infer_shapes(parsed)
def nnef2ir(inputFolder, outputFolder):
    """Convert the NNEF model at ``inputFolder`` to an IR graph written to ``outputFolder``.

    The model is loaded with ``nnef.load_graph``, its shapes are inferred with
    ``nnef.infer_shapes``, it is converted with ``nnef_graph_to_ir_graph`` and
    the result is written via its ``toFile`` method.
    """
    loaded = nnef.load_graph(inputFolder)
    nnef.infer_shapes(loaded)
    nnef_graph_to_ir_graph(loaded).toFile(outputFolder)
import nnef


def shuffle_shape(input, groups):
    """Shape function for the custom ``shuffle`` fragment.

    Asserts that ``input[1]`` is divisible by ``groups`` and returns ``input``
    unchanged as the output shape.
    """
    channels = input[1]
    assert channels % groups == 0, \
        "input channels ({}) is not divisible by groups ({})".format(channels, groups)
    return input


# Parse a network that declares `shuffle` as a fragment without a body; its
# shape function is supplied separately through `custom_shapes` below.
graph = nnef.parse_string(
    """ version 1.0; extension KHR_enable_fragment_definitions; fragment shuffle<?>( input: tensor<?>, groups: integer ) -> ( output: tensor<?> ); graph Net( input ) -> ( output ) { input = external(shape = [1,3,224,224]); filter = variable(shape = [32,3,5,5], label = 'conv/filter'); conv = conv(input, filter); output = shuffle(conv, groups = 4); } """
)

nnef.infer_shapes(graph, custom_shapes={'shuffle': shuffle_shape})

# Print the graph in its formatted textual form.
print(nnef.format_graph(graph.name,
                        graph.inputs,
                        graph.outputs,
                        graph.operations))
args = ap.parse_args()

# Optional custom stdlib source: read it from the given file, or keep the
# empty string when no file was requested.
stdlib = ''
if args.stdlib:
    try:
        with open(args.stdlib) as stdlib_file:
            stdlib = stdlib_file.read()
    except FileNotFoundError:
        print('Could not open file: ' + args.stdlib)
        exit(-1)

# Load the graph; an nnef.Error here is reported as a parse error.
try:
    graph = nnef.load_graph(args.path,
                            stdlib=stdlib,
                            lowered=args.lower.split(','))
except nnef.Error as parse_error:
    print('Parse error: ' + str(parse_error))
    exit(-1)

# Shape inference runs only when requested on the command line.
if args.shapes:
    try:
        nnef.infer_shapes(graph)
    except nnef.Error as shape_error:
        print('Shape error: ' + str(shape_error))
        exit(-1)

formatted = nnef.format_graph(graph.name, graph.inputs, graph.outputs, graph.operations)
print(formatted)
print('Validation succeeded')
def infer_shapes(self, graph):
    """Run ``nnef.infer_shapes`` on ``graph`` with this object's custom shapes.

    ``self._shapes`` is passed as ``custom_shapes``. The same ``graph`` object
    that was passed in is returned.
    """
    custom = self._shapes
    nnef.infer_shapes(graph=graph, custom_shapes=custom)
    return graph