예제 #1
0
    def darknet_parse(architecture_name, test_input_path):
        """Run a Darknet model, convert it to IR, and return its original output.

        The model named ``architecture_name`` is fetched via
        ``darknet_extractor.download`` (cached under ``TestModels.cachedir``)
        and evaluated on ``test_input_path`` with the extractor. It is then
        parsed by ``DarknetParser``, whose ``run`` receives the output path
        ``TestModels.tmpdir + 'darknet_<architecture_name>_converted'``.

        Returns:
            The prediction produced by the original (unconverted) model.
        """
        ensure_dir("./data/")
        from mmdnn.conversion.examples.darknet.extractor import darknet_extractor
        from mmdnn.conversion.darknet.darknet_parser import DarknetParser

        # Download the model; entries [0] and [1] are handed to the parser
        # below (presumably the cfg and weights files -- TODO confirm).
        model_files = darknet_extractor.download(architecture_name,
                                                 TestModels.cachedir)

        # Reference output of the unconverted model on the test input.
        reference_output = darknet_extractor.inference(
            architecture_name, model_files, TestModels.cachedir,
            test_input_path)
        del darknet_extractor

        ir_path = ''.join(
            [TestModels.tmpdir, 'darknet_', architecture_name, "_converted"])

        # yolov3 is parsed with start flag "1"; every other model uses "0".
        start_flag = "1" if architecture_name == "yolov3" else "0"

        darknet_parser = DarknetParser(model_files[0], model_files[1],
                                       start_flag)
        darknet_parser.run(ir_path)
        del darknet_parser
        del DarknetParser
        return reference_output
예제 #2
0
def _convert(args):
    """Convert a model from ``args.srcFramework`` into MMdnn IR.

    ``args`` is an argparse-style namespace. Attributes read here are:

    * ``srcFramework``, ``network``, ``weights``, ``dstPath`` and
      ``inputShape`` for all frameworks;
    * ``caffePhase`` for caffe;
    * ``inNodeName`` and ``dstNodeName`` for tensorflow;
    * ``darknetStart`` for darknet.

    ``args.inputShape`` is either None/empty or a sequence of comma-separated
    strings, one per input node (e.g. ``["3,224,224"]``). Each string is
    parsed into a list of ints.

    Caffe models are written directly to ``<dstPath>.json``, ``<dstPath>.pb``
    and ``<dstPath>.npy``. Every other framework builds a parser and passes
    ``args.dstPath`` to its ``run`` method.

    Returns:
        0 once conversion has finished.

    Raises:
        NotImplementedError: if ``srcFramework`` is ``caffe2``.
        ValueError: if the framework is unknown, or an argument the chosen
            framework needs is missing or inconsistent.
    """
    # Parse each "d0,d1,..." string into a list of ints. When no shape is
    # given, keep a single None so ``inputshape[0]`` is always indexable.
    if args.inputShape:
        inputshape = [[int(dim) for dim in shape.split(',')]
                      for shape in args.inputShape]
    else:
        inputshape = [None]

    if args.srcFramework == 'caffe':
        from mmdnn.conversion.caffe.transformer import CaffeTransformer
        transformer = CaffeTransformer(args.network,
                                       args.weights,
                                       "tensorflow",
                                       inputshape[0],
                                       phase=args.caffePhase)
        graph = transformer.transform_graph()
        data = transformer.transform_data()

        from mmdnn.conversion.caffe.writer import JsonFormatter
        JsonFormatter(graph).dump(args.dstPath + ".json")
        print("IR network structure is saved as [{}.json].".format(
            args.dstPath))

        prototxt = graph.as_graph_def().SerializeToString()
        with open(args.dstPath + ".pb", 'wb') as of:
            of.write(prototxt)
        print("IR network structure is saved as [{}.pb].".format(args.dstPath))

        import numpy as np
        with open(args.dstPath + ".npy", 'wb') as of:
            np.save(of, data)
        print("IR weights are saved as [{}.npy].".format(args.dstPath))

        return 0

    elif args.srcFramework == 'caffe2':
        raise NotImplementedError("Caffe2 is not supported yet.")

    elif args.srcFramework == 'keras':
        if args.network is not None:
            model = (args.network, args.weights)
        else:
            model = args.weights

        from mmdnn.conversion.keras.keras2_parser import Keras2Parser
        parser = Keras2Parser(model)

    elif args.srcFramework in ('tensorflow', 'tf'):
        if args.dstNodeName is None:
            raise ValueError(
                "Need to provide the output node of Tensorflow model.")

        if not (args.network or args.weights):
            raise ValueError(
                "Need to provide the network or weights file of Tensorflow "
                "model.")

        if not args.network:
            # Frozen graph: input node names and shapes must be explicit.
            if args.inNodeName is None:
                raise ValueError(
                    "Need to provide the input node of Tensorflow model.")
            # inputshape is never None itself; [None] means "not provided".
            if inputshape[0] is None:
                raise ValueError(
                    "Need to provide the input node shape of Tensorflow model."
                )
            if len(args.inNodeName) != len(inputshape):
                raise ValueError(
                    "Number of input nodes [{}] does not match number of "
                    "input shapes [{}].".format(len(args.inNodeName),
                                                len(inputshape)))
            from mmdnn.conversion.tensorflow.tensorflow_frozenparser import TensorflowParser2
            parser = TensorflowParser2(args.weights, inputshape,
                                       args.inNodeName, args.dstNodeName)

        else:
            from mmdnn.conversion.tensorflow.tensorflow_parser import TensorflowParser
            if args.inNodeName and inputshape[0]:
                parser = TensorflowParser(args.network, args.weights,
                                          args.dstNodeName, inputshape[0],
                                          args.inNodeName)
            else:
                parser = TensorflowParser(args.network, args.weights,
                                          args.dstNodeName)

    elif args.srcFramework == 'mxnet':
        if inputshape[0] is None:
            raise ValueError("Need to provide the input shape of MXNet model.")
        if args.weights is None:
            model = (args.network, inputshape[0])
        else:
            # Checkpoints are named "<prefix>-<epoch>.params". Accept the name
            # with or without the extension, and do not mutate args.
            checkpoint = args.weights
            if checkpoint.endswith('.params'):
                checkpoint = checkpoint[:-len('.params')]
            prefix, sep, epoch = checkpoint.rpartition('-')
            if not sep:
                raise ValueError(
                    "MXNet weights should be named '<prefix>-<epoch>.params', "
                    "got [{}].".format(args.weights))
            model = (args.network, prefix, epoch, inputshape[0])

        from mmdnn.conversion.mxnet.mxnet_parser import MXNetParser
        parser = MXNetParser(model)

    elif args.srcFramework == 'cntk':
        from mmdnn.conversion.cntk.cntk_parser import CntkParser
        model = args.network or args.weights
        parser = CntkParser(model)

    elif args.srcFramework == 'pytorch':
        if inputshape[0] is None:
            raise ValueError(
                "Need to provide the input shape of PyTorch model.")
        model = args.network or args.weights
        if not model:
            raise ValueError("Need to provide the PyTorch model file.")
        from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser
        parser = PytorchParser(model, inputshape[0])

    elif args.srcFramework in ('torch', 'torch7'):
        model = args.network or args.weights
        if not model:
            raise ValueError("Need to provide the Torch7 model file.")
        from mmdnn.conversion.torch.torch_parser import TorchParser
        parser = TorchParser(model, inputshape[0])

    elif args.srcFramework == 'onnx':
        from mmdnn.conversion.onnx.onnx_parser import ONNXParser
        parser = ONNXParser(args.network)

    elif args.srcFramework == 'darknet':
        from mmdnn.conversion.darknet.darknet_parser import DarknetParser
        parser = DarknetParser(args.network, args.weights, args.darknetStart)

    elif args.srcFramework == 'coreml':
        from mmdnn.conversion.coreml.coreml_parser import CoremlParser
        parser = CoremlParser(args.network)

    else:
        raise ValueError("Unknown framework [{}].".format(args.srcFramework))

    parser.run(args.dstPath)

    return 0