Ejemplo n.º 1
0
def extract_model(args):
    """Download a pretrained model and optionally print its top-5 prediction.

    ``args`` is read for ``framework``, ``network``, ``path`` and ``image``.
    The extractor module for the selected framework is imported lazily, inside
    its own branch, so only that module is loaded.

    After downloading, if files were obtained and an image was given, the
    extractor's ``inference`` result is ranked and the five highest
    ``(index, score)`` pairs are printed, highest first.

    Raises:
        NotImplementedError: if ``args.framework`` is ``'caffe2'`` or
            ``'cntk'`` (no extractor is wired up for them here).
        ValueError: if ``args.framework`` is not a recognised name.
    """
    if args.framework == 'caffe':
        from mmdnn.conversion.examples.caffe.extractor import caffe_extractor
        extractor = caffe_extractor()

    elif args.framework == 'caffe2':
        raise NotImplementedError("Caffe2 is not supported yet.")

    elif args.framework == 'keras':
        from mmdnn.conversion.examples.keras.extractor import keras_extractor
        extractor = keras_extractor()

    elif args.framework in ('tensorflow', 'tf'):
        from mmdnn.conversion.examples.tensorflow.extractor import tensorflow_extractor
        extractor = tensorflow_extractor()

    elif args.framework == 'mxnet':
        from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor
        extractor = mxnet_extractor()

    elif args.framework == 'cntk':
        # Previously a bare ``pass``: ``extractor`` stayed unbound and the
        # download call below died with UnboundLocalError. Fail explicitly,
        # the same way the unsupported Caffe2 branch does.
        raise NotImplementedError("CNTK is not supported yet.")

    else:
        raise ValueError("Unknown framework [{}].".format(args.framework))

    files = extractor.download(args.network, args.path)

    if files and args.image:
        predict = extractor.inference(args.network, args.path, args.image)
        # argsort is ascending: take the last five indices and reverse them
        # so the highest score comes first.
        top_indices = predict.argsort()[-5:][::-1]
        result = [(i, predict[i]) for i in top_indices]
        print(result)
Ejemplo n.º 2
0
def extract_model(args):
    """Fetch a pretrained model and, given an image, print its best five scores.

    Attributes read from ``args``: ``framework`` selects the extractor,
    ``network`` and ``path`` are passed to ``download``/``inference``, and
    ``image`` (when truthy) triggers a sample prediction.

    Each extractor module is imported only inside the branch that selects it.

    Raises:
        NotImplementedError: when ``args.framework`` is ``'caffe2'``.
        ValueError: when ``args.framework`` is not a recognised name.
    """
    framework = args.framework

    def _build_extractor():
        # Early return per framework; the import lives next to its use.
        if framework == 'caffe':
            from mmdnn.conversion.examples.caffe.extractor import caffe_extractor
            return caffe_extractor()
        if framework == 'caffe2':
            raise NotImplementedError("Caffe2 is not supported yet.")
        if framework == 'keras':
            from mmdnn.conversion.examples.keras.extractor import keras_extractor
            return keras_extractor()
        if framework in ('tensorflow', 'tf'):
            from mmdnn.conversion.examples.tensorflow.extractor import tensorflow_extractor
            return tensorflow_extractor()
        if framework == 'mxnet':
            from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor
            return mxnet_extractor()
        if framework == 'cntk':
            from mmdnn.conversion.examples.cntk.extractor import cntk_extractor
            return cntk_extractor()
        raise ValueError("Unknown framework [{}].".format(framework))

    model_extractor = _build_extractor()
    downloaded = model_extractor.download(args.network, args.path)

    # Nothing more to do unless the download produced files and an image was given.
    if not (downloaded and args.image):
        return

    scores = model_extractor.inference(args.network, args.path, args.image)
    # Ascending argsort -> last five, reversed, gives the best five first.
    best_five = scores.argsort()[-5:][::-1]
    print([(idx, scores[idx]) for idx in best_five])
Ejemplo n.º 3
0
def extract_model(args):
    """Download a pretrained model and optionally print a sample prediction.

    ``args`` is read for ``framework``, ``network``, ``path``, ``image`` and,
    when the prediction is a 1-D array, ``label``. The extractor module for
    the chosen framework is imported only inside its own branch.

    After downloading, if files were obtained and an image was given, the
    extractor's ``inference`` result is printed:

    * a list is printed as is;
    * a 1-D array is ranked, and its top five ``(index, score)`` pairs are
      passed through ``generate_label`` and printed one line at a time;
    * any other array has its shape and then its values printed.

    Raises:
        ValueError: if ``args.framework`` is not a recognised name.
    """
    if args.framework == 'caffe':
        from mmdnn.conversion.examples.caffe.extractor import caffe_extractor
        extractor = caffe_extractor()

    elif args.framework == 'keras':
        from mmdnn.conversion.examples.keras.extractor import keras_extractor
        extractor = keras_extractor()

    elif args.framework in ('tensorflow', 'tf'):
        from mmdnn.conversion.examples.tensorflow.extractor import tensorflow_extractor
        extractor = tensorflow_extractor()

    elif args.framework == 'mxnet':
        from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor
        extractor = mxnet_extractor()

    elif args.framework == 'cntk':
        from mmdnn.conversion.examples.cntk.extractor import cntk_extractor
        extractor = cntk_extractor()

    elif args.framework == 'pytorch':
        from mmdnn.conversion.examples.pytorch.extractor import pytorch_extractor
        extractor = pytorch_extractor()

    elif args.framework == 'darknet':
        from mmdnn.conversion.examples.darknet.extractor import darknet_extractor
        extractor = darknet_extractor()

    else:
        raise ValueError("Unknown framework [{}].".format(args.framework))

    files = extractor.download(args.network, args.path)
    if not (files and args.image):
        return

    predict = extractor.inference(args.network, files, args.path, args.image)

    # Some extractors return a list of results; print it unchanged.
    # isinstance (rather than ``type(...) == list``) also accepts list
    # subclasses, which have no ``.ndim`` and would otherwise crash below.
    if isinstance(predict, list):
        print(predict)
        return

    if predict.ndim != 1:
        # Not a flat score vector: show the raw output.
        print(predict.shape)
        print(predict)
        return

    # NOTE(review): a 1001-entry output presumably has an extra leading
    # class (e.g. "background"), so labels are shifted by one — confirm
    # against generate_label.
    offset = 1 if predict.shape[0] == 1001 else 0
    # Ascending argsort -> last five, reversed, gives the best five first.
    top_indices = predict.argsort()[-5:][::-1]
    top_scores = [(i, predict[i]) for i in top_indices]
    for line in generate_label(top_scores, args.label, offset):
        print(line)