示例#1
0
    def __call__(self, path, input_shapes=None):
        """Load an NNEF model from *path* and convert it with ``_build_graph``.

        *path* may be a model folder or a ``.tgz`` / ``.gz`` archive of one.
        An archive is extracted into a fresh temporary folder (prefix
        ``nnef_``), which is removed again in the ``finally`` clause, whether
        loading succeeds or fails.

        :param path: model folder, or compressed model archive.
        :param input_shapes: optional mapping forwarded to
            ``nnef.infer_shapes`` as ``external_shapes``; only used when
            ``self._infer_shapes`` is truthy. ``None`` is treated as ``{}``.
        :return: whatever ``_build_graph`` returns for the loaded NNEF graph.
        :raises IOError: if *path* is neither a folder nor a recognised
            compressed archive.
        """
        # os.path.splitext returns a (root, ext) tuple, and ext keeps its
        # leading dot ('.tgz'). Compare only the extension, case-insensitively.
        # 'model.tar.gz' yields '.gz', so it is covered as well.
        extension = os.path.splitext(path)[1].lower()
        compressed = extension in ('.tgz', '.gz') and not os.path.isdir(path)

        folder = None
        try:
            if compressed:
                folder = tempfile.mkdtemp(prefix="nnef_")
                tgz_extract(path, folder)
                path = folder

            if not os.path.isdir(path):
                raise IOError(
                    "NNEF model must be a (compressed) folder, but an uncompressed file was provided"
                )

            nnef_graph = nnef.load_graph(path,
                                         stdlib=self._stdlib,
                                         lowered=self._decomposed)
            if self._infer_shapes:
                nnef.infer_shapes(nnef_graph,
                                  external_shapes=input_shapes or {},
                                  custom_shapes=self._custom_shapes or {})

            return _build_graph(nnef_graph)
        finally:
            # Only a folder we created ourselves is removed. A user-supplied
            # folder path is never touched.
            if folder is not None:
                shutil.rmtree(folder)
示例#2
0
def nnef2ir(inputFolder, outputFolder):
    """Convert the NNEF model in *inputFolder* to an IR graph saved at *outputFolder*.

    The model is loaded with ``nnef.load_graph`` and shape inference is run on
    it with ``nnef.infer_shapes``. It is then converted with
    ``nnef_graph_to_ir_graph``, and the resulting graph's ``toFile`` method
    writes it out. Nothing is returned.
    """
    source_graph = nnef.load_graph(inputFolder)
    nnef.infer_shapes(source_graph)
    nnef_graph_to_ir_graph(source_graph).toFile(outputFolder)
示例#3
0
                    action="store_true",
                    help='perform shape validation as well')
    args = ap.parse_args()

    # Optional custom standard library: when args.stdlib names a file, its
    # text is passed to nnef.load_graph as `stdlib`. Otherwise the empty
    # string is passed.
    stdlib = ''
    if args.stdlib:
        try:
            with open(args.stdlib) as file:
                stdlib = file.read()
        except FileNotFoundError as e:
            # NOTE(review): `e` is unused, and the message goes to stdout, not
            # stderr. exit() is the helper injected by the `site` module;
            # sys.exit(-1) is the robust choice in a script.
            print('Could not open file: ' + args.stdlib)
            exit(-1)

    # Parse the model at args.path. args.lower is split on commas and passed
    # as `lowered` (presumably the names of fragments to decompose; confirm
    # against the nnef package docs).
    # NOTE(review): if args.lower can be None (its default is not visible
    # here), .split(',') raises AttributeError. TODO confirm the default.
    try:
        graph = nnef.load_graph(args.path,
                                stdlib=stdlib,
                                lowered=args.lower.split(','))
    except nnef.Error as err:
        print('Parse error: ' + str(err))
        exit(-1)

    # Shape inference runs only when args.shapes is truthy. Failures are
    # reported as shape errors and terminate with status -1.
    if args.shapes:
        try:
            nnef.infer_shapes(graph)
        except nnef.Error as err:
            print('Shape error: ' + str(err))
            exit(-1)

    # Print the string built by nnef.format_graph from the graph's name,
    # inputs, outputs and operations (presumably the textual NNEF form).
    print(
        nnef.format_graph(graph.name, graph.inputs, graph.outputs,
                          graph.operations))
示例#4
0
 def load_graph(self, path):
     """Load the NNEF graph at *path* with this object's settings.

     ``self._source`` is forwarded as ``stdlib`` and ``self._expand`` as
     ``lowered``. The result of ``nnef.load_graph`` is returned unchanged.
     """
     options = dict(stdlib=self._source, lowered=self._expand)
     return nnef.load_graph(path=path, **options)
示例#5
0
 def load_graph(self, path):
     """Load the NNEF graph at *path* with this object's settings.

     ``self.fragments`` is forwarded as ``stdlib`` and ``self.lowered`` as
     ``lowered``. The result of ``nnef.load_graph`` is returned unchanged.
     """
     options = dict(stdlib=self.fragments, lowered=self.lowered)
     return nnef.load_graph(path=path, **options)
示例#6
0
 def load_graph(self, path):
     """Load the NNEF graph at *path* with this object's settings.

     ``self.custom_ops`` is forwarded as ``stdlib`` and ``self.expand`` as
     ``lowered``. The result of ``nnef.load_graph`` is returned unchanged.
     """
     options = dict(stdlib=self.custom_ops, lowered=self.expand)
     return nnef.load_graph(path=path, **options)