示例#1
0
# Input size (height, width) that every imported ONNX module is bound to;
# ONNX variants always report img_size = _ONNX_IMAGE_SIZE[1].
_ONNX_IMAGE_SIZE = (224, 224)

# model_id -> (name prefix, zero-argument Keras constructor, img_size returned).
# Constructors are wrapped in lambdas so the Keras application classes are
# only looked up when that particular model is requested.
# model_id 4 (InceptionResNetV2, img_size 90) is currently disabled.
_KERAS_MODELS = {
    0: ('ResNet50',
        lambda: ResNet50(weights='imagenet', include_top=False), 224),
    1: ('VGG16',
        lambda: VGG16(weights='imagenet', include_top=False), 60),
    2: ('VGG19',
        lambda: VGG19(weights='imagenet', include_top=False), 60),
    3: ('InceptionV3',
        lambda: InceptionV3(weights='imagenet', include_top=False), 90),
    # NOTE(review): the Keras docs list 71x71 as Xception's minimum input
    # size; an img_size of 30 may be too small if callers resize to it.
    5: ('Xception',
        lambda: Xception(weights='imagenet', include_top=False), 30),
}

# model_id -> (name prefix, ONNX file path, input data name,
#              internal output used when output_layer == 2).
# Entries commented "not good" carry the author's note that the chosen
# internal layer is a poor feature layer.
_ONNX_MODELS = {
    0: ('ResNet50', '../models/resnet50.onnx', 'gpu_0/data_0',
        'flatten0_output'),
    1: ('VGG16', '../models/vgg16.onnx', 'data', 'flatten2_output'),
    2: ('VGG19', '../models/vgg19.onnx', 'data_0', 'flatten2_output'),
    6: ('MobileNet', '../models/mobilenet.onnx', 'data',
        'mobilenetv20_features_pool0_fwd_output'),  # not good
    7: ('SqueezeNet', '../models/squeezenet.onnx', 'data',
        'pooling3_output'),  # not good
    8: ('AlexNet', '../models/alexnet.onnx', 'data_0', 'flatten2_output'),
    9: ('GoogleNet', '../models/googlenet.onnx', 'data_0',
        'flatten0_output'),  # not good
    10: ('ShuffleNet', '../models/shufflenet.onnx', 'gpu_0/data_0',
         'flatten0_output'),  # not good
    11: ('DenseNet121', '../models/densenet121.onnx', 'data_0',
         'pad124_output'),  # not good
    12: ('ZfNet512', '../models/zfnet512.onnx', 'gpu_0/data_0',
         'flatten2_output'),  # not good
    13: ('RCNN_ILSVRC13', '../models/rcnn_ilsvrc13.onnx', 'data_0',
         'flatten2_output'),  # not good
}


def _load_onnx_module(model_path, datain, layerout, image_size):
    """Import an ONNX file into MXNet and bind it for single-image inference.

    Args:
        model_path: path of the ``.onnx`` file passed to ``import_model``.
        datain: name of the network's input data node.
        layerout: name of an internal output to truncate the network at, or
            None to keep the full network (its last layer).
        image_size: (height, width) used in the bound input shape.

    Returns:
        An ``mx.mod.Module`` bound to input shape
        ``(1, 3, image_size[0], image_size[1])`` with the imported
        parameters set.
    """
    print('loading model ' + model_path + '...')
    sym, arg_params, aux_params = import_model(model_path)
    # Run on the first GPU when MXNet reports one, otherwise on the CPU.
    if len(mx.test_utils.list_gpus()) == 0:
        ctx = mx.cpu()
    else:
        ctx = mx.gpu(0)

    all_layers = sym.get_internals()
    # Dump all internal output names (handy when choosing a truncation layer).
    print(all_layers.list_outputs())
    out_sym = sym if layerout is None else all_layers[layerout]

    module = mx.mod.Module(symbol=out_sym,
                           context=ctx,
                           label_names=None,
                           data_names=[datain])
    # Bind for a batch of one 3-channel image of the requested size.
    module.bind(data_shapes=[(datain, (1, 3, image_size[0],
                                       image_size[1]))])
    module.set_params(arg_params, aux_params)
    return module


def prt_model(model_id, output_layer):
    """Load a pretrained network for feature extraction.

    Args:
        model_id: architecture index -- 0 ResNet50, 1 VGG16, 2 VGG19,
            3 InceptionV3, 5 Xception, 6 MobileNet, 7 SqueezeNet,
            8 AlexNet, 9 GoogleNet, 10 ShuffleNet, 11 DenseNet121,
            12 ZfNet512, 13 RCNN_ILSVRC13 (4, InceptionResNetV2, is
            disabled). Ids 0-3 and 5 have a Keras variant; ids 0-2 and 6-13
            have an ONNX variant.
        output_layer: 0 loads the Keras model with ``include_top=False``;
            1 loads the full ONNX network through MXNet; 2 loads the ONNX
            network truncated at the internal layer listed in the table.

    Returns:
        ``(model, img_size, model_name)``: a Keras model (output_layer 0) or
        a bound ``mx.mod.Module`` (1/2); the image side length associated
        with it (always 224 for ONNX variants); and a name such as
        ``'VGG16_Keras'``, ``'VGG16_ONNX_L'`` or ``'VGG16_ONNX_P'``.

    Raises:
        ValueError: if ``output_layer`` is not 0, 1 or 2, or the requested
            ``model_id`` has no variant for that ``output_layer``. (The
            previous implementation crashed with UnboundLocalError here.)
    """
    print('loading model...')

    if output_layer == 0:
        entry = _KERAS_MODELS.get(model_id)
        if entry is None:
            raise ValueError(
                'model_id {!r} has no Keras variant (output_layer=0)'
                .format(model_id))
        prefix, build, img_size = entry
        model_name = prefix + '_Keras'
        model = build()
    elif output_layer in (1, 2):
        entry = _ONNX_MODELS.get(model_id)
        if entry is None:
            raise ValueError(
                'model_id {!r} has no ONNX variant (output_layer={!r})'
                .format(model_id, output_layer))
        prefix, model_path, datain, prev_layer = entry
        if output_layer == 1:  # ONNX - last layer
            model_name = prefix + '_ONNX_L'
            layerout = None
        else:  # ONNX - earlier internal layer
            model_name = prefix + '_ONNX_P'
            layerout = prev_layer
        model = _load_onnx_module(model_path, datain, layerout,
                                  _ONNX_IMAGE_SIZE)
        img_size = _ONNX_IMAGE_SIZE[1]
    else:
        raise ValueError(
            'output_layer must be 0, 1 or 2, got {!r}'.format(output_layer))

    print('model ' + model_name + ' loaded.')
    return model, img_size, model_name