def darknet_parse(architecture_name, test_input_path):
    """Convert a Darknet test model to IR and return its reference prediction.

    Steps, in order:
      1. make sure ``./data/`` exists;
      2. download the model files for ``architecture_name`` into
         ``TestModels.cachedir``;
      3. run the original Darknet model on ``test_input_path`` to obtain the
         reference prediction;
      4. convert the model to IR, writing it under
         ``TestModels.tmpdir + 'darknet_<architecture_name>_converted'``.

    Returns:
        Whatever ``darknet_extractor.inference`` returned for the test input.
    """
    ensure_dir("./data/")

    from mmdnn.conversion.examples.darknet.extractor import darknet_extractor
    from mmdnn.conversion.darknet.darknet_parser import DarknetParser

    model_files = darknet_extractor.download(architecture_name, TestModels.cachedir)
    reference_prediction = darknet_extractor.inference(
        architecture_name, model_files, TestModels.cachedir, test_input_path)
    # Drop the extractor reference as soon as it is no longer needed
    # (presumably to release resources before conversion -- TODO confirm).
    del darknet_extractor

    ir_path = TestModels.tmpdir + 'darknet_' + architecture_name + "_converted"
    # yolov3 is converted with start index "1", every other model with "0";
    # the meaning of this argument is defined by DarknetParser.
    start_index = "1" if architecture_name == "yolov3" else "0"

    ir_parser = DarknetParser(model_files[0], model_files[1], start_index)
    ir_parser.run(ir_path)
    del ir_parser
    del DarknetParser

    return reference_prediction
def _convert(args):
    """Convert a model from ``args.srcFramework`` into MMdnn IR.

    ``args`` is the parsed command-line namespace. The attributes read here
    are ``srcFramework``, ``network``, ``weights``, ``dstPath``,
    ``inputShape``, ``inNodeName``, ``dstNodeName``, ``caffePhase`` and
    ``darknetStart``.

    Each entry of ``args.inputShape`` is a comma-separated string of integers
    (e.g. ``"224,224,3"``). For frozen TensorFlow graphs the number of entries
    must match the number of input node names.

    For Caffe the IR is written directly as ``<dstPath>.json``,
    ``<dstPath>.pb`` and ``<dstPath>.npy``. Every other framework builds a
    framework-specific parser and delegates to ``parser.run(args.dstPath)``.

    Returns:
        0 on success.

    Raises:
        NotImplementedError: for ``caffe2``.
        ValueError: for an unknown framework, or when a required output node,
            input node, or input shape is missing.
        AssertionError: from the remaining sanity checks (e.g. neither a
            network nor a weights file was supplied).
    """
    if args.inputShape:
        inputshape = [[int(dim) for dim in shape.split(',')]
                      for shape in args.inputShape]
    else:
        # No shape supplied: keep a one-element placeholder so ``inputshape[0]``
        # is always valid and evaluates to ``None``. Note that ``inputshape``
        # itself is therefore never ``None`` -- check ``inputshape[0]``.
        inputshape = [None]

    framework = args.srcFramework

    if framework == 'caffe':
        from mmdnn.conversion.caffe.transformer import CaffeTransformer
        transformer = CaffeTransformer(args.network, args.weights, "tensorflow",
                                       inputshape[0], phase=args.caffePhase)
        graph = transformer.transform_graph()
        data = transformer.transform_data()

        from mmdnn.conversion.caffe.writer import JsonFormatter
        JsonFormatter(graph).dump(args.dstPath + ".json")
        print("IR network structure is saved as [{}.json].".format(args.dstPath))

        serialized_graph = graph.as_graph_def().SerializeToString()
        with open(args.dstPath + ".pb", 'wb') as of:
            of.write(serialized_graph)
        print("IR network structure is saved as [{}.pb].".format(args.dstPath))

        import numpy as np
        with open(args.dstPath + ".npy", 'wb') as of:
            np.save(of, data)
        print("IR weights are saved as [{}.npy].".format(args.dstPath))
        return 0

    elif framework == 'caffe2':
        raise NotImplementedError("Caffe2 is not supported yet.")

    elif framework == 'keras':
        if args.network is not None:
            model = (args.network, args.weights)
        else:
            model = args.weights
        from mmdnn.conversion.keras.keras2_parser import Keras2Parser
        parser = Keras2Parser(model)

    elif framework in ('tensorflow', 'tf'):
        if args.dstNodeName is None:
            raise ValueError("Need to provide the output node of Tensorflow model.")
        assert args.network or args.weights
        if not args.network:
            # Weights-only (frozen graph) path: input nodes and their shapes
            # must be supplied explicitly.
            if args.inNodeName is None:
                raise ValueError("Need to provide the input node of Tensorflow model.")
            if inputshape[0] is None:
                raise ValueError("Need to provide the input node shape of Tensorflow model.")
            assert len(args.inNodeName) == len(inputshape)
            from mmdnn.conversion.tensorflow.tensorflow_frozenparser import TensorflowParser2
            parser = TensorflowParser2(args.weights, inputshape, args.inNodeName, args.dstNodeName)
        else:
            from mmdnn.conversion.tensorflow.tensorflow_parser import TensorflowParser
            if args.inNodeName and inputshape[0]:
                parser = TensorflowParser(args.network, args.weights, args.dstNodeName,
                                          inputshape[0], args.inNodeName)
            else:
                parser = TensorflowParser(args.network, args.weights, args.dstNodeName)

    elif framework == 'mxnet':
        if inputshape[0] is None:
            raise ValueError("Need to provide the input shape of MXNet model.")
        if args.weights is None:
            model = (args.network, inputshape[0])
        else:
            # Weights are given as ``<prefix>-<epoch>`` optionally followed by
            # ``.params``; strip only that exact extension (the previous
            # ``re.search('.', ...)`` matched any character and always chopped
            # 7 chars), then split off the epoch.
            weights = args.weights
            if weights.endswith('.params'):
                weights = weights[:-len('.params')]
            prefix, epoch = weights.rsplit('-', 1)
            model = (args.network, prefix, epoch, inputshape[0])
        from mmdnn.conversion.mxnet.mxnet_parser import MXNetParser
        parser = MXNetParser(model)

    elif framework == 'cntk':
        from mmdnn.conversion.cntk.cntk_parser import CntkParser
        model = args.network or args.weights
        parser = CntkParser(model)

    elif framework == 'pytorch':
        if inputshape[0] is None:
            raise ValueError("Need to provide the input shape of PyTorch model.")
        from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser
        model = args.network or args.weights
        assert model is not None
        parser = PytorchParser(model, inputshape[0])

    elif framework in ('torch', 'torch7'):
        from mmdnn.conversion.torch.torch_parser import TorchParser
        model = args.network or args.weights
        assert model is not None
        parser = TorchParser(model, inputshape[0])

    elif framework == 'onnx':
        from mmdnn.conversion.onnx.onnx_parser import ONNXParser
        parser = ONNXParser(args.network)

    elif framework == 'darknet':
        from mmdnn.conversion.darknet.darknet_parser import DarknetParser
        parser = DarknetParser(args.network, args.weights, args.darknetStart)

    elif framework == 'coreml':
        from mmdnn.conversion.coreml.coreml_parser import CoremlParser
        parser = CoremlParser(args.network)

    else:
        raise ValueError("Unknown framework [{}].".format(framework))

    parser.run(args.dstPath)
    return 0