def collect_models():
    models = dict()
    models["MobileNetV2"] = tfapp.MobileNetV2(input_shape=IMG_SHAPE,
                                              include_top=False,
                                              weights='imagenet')
    models["NASNetMobile"] = tfapp.NASNetMobile(input_shape=(130, 386, 3),
                                                include_top=False,
                                                weights='imagenet')
    models["DenseNet121"] = tfapp.DenseNet121(input_shape=IMG_SHAPE,
                                              include_top=False,
                                              weights='imagenet')
    models["VGG16"] = tfapp.VGG16(input_shape=IMG_SHAPE,
                                  include_top=False,
                                  weights='imagenet')
    models["Xception"] = tfapp.Xception(input_shape=(134, 390, 3),
                                        include_top=False,
                                        weights='imagenet')
    models["ResNet50V2"] = tfapp.ResNet50V2(input_shape=IMG_SHAPE,
                                            include_top=False,
                                            weights='imagenet')
    models["NASNetLarge"] = tfapp.NASNetLarge(input_shape=(130, 386, 3),
                                              include_top=False,
                                              weights='imagenet')

    # omitted: these models do not fit the 2^n-aligned input shapes used here
    # models["InceptionV3"] = tfapp.InceptionV3(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
    # models["InceptionResNetV2"] = \
    #     tfapp.InceptionResNetV2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
    return models
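
# Usage sketch, assuming tfapp is tensorflow.keras.applications and IMG_SHAPE
# is a module-level constant such as (128, 384, 3) (the NASNet and Xception
# entries above use slightly padded variants of that shape):
import tensorflow.keras.applications as tfapp

IMG_SHAPE = (128, 384, 3)
for name, net in collect_models().items():
    # headless backbones return a convolutional feature map
    print(name, net.output_shape)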

Example #2
def compute_mean_and_std(model_name, X, input_shape):
    if model_name == 'Xception':
        model = applications.Xception(weights='imagenet',
                                      include_top=False,
                                      input_shape=input_shape)
    elif model_name == 'VGG16':
        model = applications.VGG16(weights='imagenet',
                                   include_top=False,
                                   input_shape=input_shape)
    elif model_name == 'VGG19':
        model = applications.VGG19(weights='imagenet',
                                   include_top=False,
                                   input_shape=input_shape)
    elif model_name == 'ResNet50':
        model = applications.ResNet50(weights='imagenet',
                                      include_top=False,
                                      input_shape=input_shape)
    elif model_name == 'InceptionResNetV2':
        model = applications.InceptionResNetV2(weights='imagenet',
                                               include_top=False,
                                               input_shape=input_shape)
    elif model_name == 'InceptionV3':
        model = applications.InceptionV3(weights='imagenet',
                                         include_top=False,
                                         input_shape=input_shape)
    elif model_name == 'MobileNet':
        model = applications.MobileNet(weights='imagenet',
                                       include_top=False,
                                       input_shape=input_shape)
    elif model_name == 'DenseNet121':
        model = applications.DenseNet121(weights='imagenet',
                                         include_top=False,
                                         input_shape=input_shape)
    elif model_name == 'DenseNet169':
        model = applications.DenseNet169(weights='imagenet',
                                         include_top=False,
                                         input_shape=input_shape)
    elif model_name == 'DenseNet201':
        model = applications.DenseNet201(weights='imagenet',
                                         include_top=False,
                                         input_shape=input_shape)
    elif model_name == 'NASNetMobile':
        model = applications.NASNetMobile(weights='imagenet',
                                          include_top=False,
                                          input_shape=input_shape)
    elif model_name == 'NASNetLarge':
        model = applications.NASNetLarge(weights='imagenet',
                                         include_top=False,
                                         input_shape=input_shape)
    else:
        raise ValueError("Specified base model '%s' is not available!" % model_name)

    # take the channel vector at spatial position (0, 0) for each sample
    features = model.predict(X)[:, 0, 0, :]

    return features.mean(axis=0), features.std(axis=0)
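
# Usage sketch, assuming `applications` is tensorflow.keras.applications and
# the input batch is already preprocessed for the chosen backbone:
import numpy as np

X_demo = np.random.rand(4, 224, 224, 3).astype('float32')
mean, std = compute_mean_and_std('MobileNet', X_demo, (224, 224, 3))
print(mean.shape, std.shape)  # both (n_features,)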
Example #3
def Xception(num_classes: int, image_size=(256, 256)):
    base_model = applications.Xception(weights=None,
                                       include_top=False,
                                       input_shape=(*image_size, 3))
    x = base_model.output
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(num_classes)(x)
    predictions = layers.Activation('softmax')(x)

    model = Model(inputs=base_model.input, outputs=predictions)

    return model
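
# Usage sketch, assuming applications, layers and Model come from
# tensorflow.keras as in the surrounding snippets:
clf = Xception(num_classes=10, image_size=(256, 256))
clf.compile(optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy'])
print(clf.output_shape)  # (None, 10)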
Example #4
    def models_as_layers():
        """
        Similar to the previous example, we can use entire models an arbitrary number of times in a single model. This
        can be quite useful in certain applications, for example here we use a Siamese model to process 2 camera inputs
        using pre-learned representations. From there we could add a dense head to also calculate the distance to any
        detected object

        :return: None
        """
        xception_base = applications.Xception(weights=None, include_top=False)

        left_input = Input(shape=(255, 255, 3))
        right_input = Input(shape=(255, 255, 3))

        left_features = xception_base(left_input)
        right_features = xception_base(right_input)

        merged_features = layers.concatenate([left_features, right_features],
                                             axis=-1)
        # build the two-input model on the shared base, mirroring the
        # standalone snippet in Example #9 below
        siamese = Model([left_input, right_input], merged_features)
        siamese.summary()
Example #5
def model(input_size=(1024, 1024, 3),
          num_classes=20,
          depthwise=False,
          output_stride=16,
          backbone='xception'):
    if backbone == 'modified_xception':
        # Input Layer
        inputs = keras.Input(input_size, name='xception_input')
        # Do the entry block, also returns the low layer for ASPP and Decoder
        xception, low_layer = entry_block(inputs)
        # Now do the middle block
        for i in range(16):
            blockname = 'middle_block{}'.format(i)
            xception = middle_block(xception, blockname)
        # Activate the last add
        xception = keras.layers.Activation(
            'relu', name='exit_block_relu_add')(xception)
        # Do the exit block
        xception = exit_block(xception)
    elif backbone == 'xception':
        xception = applications.Xception(weights='imagenet',
                                         include_top=False,
                                         input_shape=input_size)
        # freeze the pretrained backbone; `trainable` is a layer property,
        # so it is set on layers rather than on output tensors
        for layer in xception.layers:
            layer.trainable = False
        inputs = xception.input
        low_layer = xception.get_layer('add_1').output
        previous = xception.layers[-1].output
    elif backbone == 'mobilenetv2':
        mobilenet = applications.MobileNetV2(weights='imagenet',
                                             include_top=False,
                                             input_shape=input_size)
        inputs = mobilenet.input
        low_layer = mobilenet.get_layer('block_3_depthwise').output
        previous = mobilenet.layers[-1].output
    # ASPP (Atrous Spatial Pyramid Pooling) on the backbone output
    aspp = ASPP(previous, depthwise=depthwise, output_stride=output_stride)
    aspp_up = keras.layers.UpSampling2D(size=(4, 4),
                                        interpolation='bilinear',
                                        name='decoder_ASPP_upsample')(aspp)
    # Decoder Begins Here
    conv1x1 = keras.layers.Conv2D(48,
                                  1,
                                  strides=1,
                                  padding='same',
                                  use_bias=False,
                                  name="decoder_conv1x1")(low_layer)
    norm1x1 = keras.layers.BatchNormalization(
        name='decoder_conv1x1_batch_norm')(conv1x1)
    relu1x1 = keras.layers.Activation('relu',
                                      name='decoder_conv1x1_relu')(norm1x1)

    # Concatenate ASPP and the 1x1 Convolution
    decode_concat = keras.layers.Concatenate(name='decoder_concat')(
        [aspp_up, relu1x1])

    # Do some Convolutions
    if depthwise:
        conv1_decoder = keras.layers.SeparableConv2D(
            256, 3, strides=1, padding='same',
            name='decoder_conv1')(decode_concat)
    else:
        conv1_decoder = keras.layers.Conv2D(
            256, 3, strides=1, padding='same',
            name='decoder_conv1')(decode_concat)
    norm1_decoder = keras.layers.BatchNormalization(
        name='decoder_conv1_batch_norm')(conv1_decoder)
    relu1_decoder = keras.layers.Activation(
        'relu', name='decoder_conv1_relu')(norm1_decoder)

    if depthwise:
        conv2_decoder = keras.layers.SeparableConv2D(
            256, 3, strides=1, padding='same',
            name='decoder_conv2')(relu1_decoder)
    else:
        conv2_decoder = keras.layers.Conv2D(
            256, 3, strides=1, padding='same',
            name='decoder_conv2')(relu1_decoder)
    norm2_decoder = keras.layers.BatchNormalization(
        name='decoder_conv2_batch_norm')(conv2_decoder)
    relu2_decoder = keras.layers.Activation(
        'relu', name='decoder_conv2_relu')(norm2_decoder)

    # Do classification 1x1 layer
    classification = keras.layers.Conv2D(num_classes,
                                         kernel_size=(1, 1),
                                         strides=(1, 1),
                                         activation='softmax',
                                         name='Classification',
                                         dtype=tf.float32)(relu2_decoder)
    classification_up = keras.layers.UpSampling2D(
        size=(8, 8),
        interpolation='bilinear',
        name="Classificaton_Upsample",
        dtype=tf.float32)(classification)

    # Create the model
    model = keras.Model(inputs=inputs, outputs=classification_up)
    # Return the model
    return model
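
# Usage sketch, assuming the ASPP / entry_block / middle_block / exit_block
# helpers referenced above are defined elsewhere in this project:
#
#     seg_model = model(input_size=(1024, 1024, 3), num_classes=20,
#                       backbone='xception')
#     seg_model.compile(optimizer='adam', loss='categorical_crossentropy')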
Example #6
def train_model(model_name=None):
    train_gen = create_data_augmentations()
    validation_gen = create_data()

    train_generator = train_gen.flow_from_directory(train_data_dir,
                                                    target_size=image_size,
                                                    color_mode='rgb',
                                                    batch_size=batch_size,
                                                    class_mode='categorical',
                                                    shuffle=True)

    validation_generator = validation_gen.flow_from_directory(
        validation_data_dir,
        target_size=image_size,
        color_mode='rgb',
        class_mode="categorical")

    if model_name:
        model_final = load_model(model_name)
    else:
        base_model = applications.Xception(weights=None,
                                           include_top=False,
                                           input_shape=(image_size[0],
                                                        image_size[1], 3))

        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        fully_connected = Dense(class_count, activation='softmax')(x)
        model_final = Model(inputs=base_model.input, outputs=fully_connected)

    # Compile the final model with loss and optimizer.
    model_final.compile(loss="categorical_crossentropy",
                        optimizer=optimizers.SGD(learning_rate=lr,
                                                 momentum=momentum),
                        metrics=["accuracy"])

    model_final.summary()

    step_size_train = train_generator.n // train_generator.batch_size

    # Initializing monitoring params for training.
    history = LossAccHistory()
    early = EarlyStopping(monitor='val_accuracy',
                          min_delta=0,
                          patience=2,
                          verbose=1,
                          mode='auto')
    network_file_name = main_dir + train_dir + model_file
    checkpoint = ModelCheckpoint(network_file_name,
                                 monitor='val_accuracy',
                                 verbose=1,
                                 save_best_only=True,
                                 save_weights_only=False,
                                 mode='auto',
                                 save_freq='epoch')

    print("Train...")

    train_start_time = time.time()

    model_history = model_final.fit(
        train_generator,
        steps_per_epoch=step_size_train,
        validation_data=validation_generator,
        epochs=epochs,
        shuffle=True,
        callbacks=[history, checkpoint, early])

    train_end_time = time.time()

    train_time_in_minutes = int((train_end_time - train_start_time) / 60)

    print("-------Done-------")

    print_acc(model_history.history)
    print_loss(model_history.history)
    print_time_log(train_time_in_minutes)
    calculate_test_accuracy(load_model(main_dir + train_dir + model_file))
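
# Usage sketch, assuming the module-level config used above (train_data_dir,
# validation_data_dir, image_size, batch_size, class_count, lr, momentum,
# epochs, main_dir, train_dir, model_file) is defined elsewhere:
#
#     train_model()                     # train a fresh Xception-based model
#     train_model('previous_model.h5')  # resume from a saved model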
Example #7
def estimate(X_train, y_train, back_bone):
    IMAGE_WIDTH = 224                               # Image width
    IMAGE_HEIGHT = 224                              # Image height
    input_shape = (IMAGE_HEIGHT, IMAGE_WIDTH, 3)    # (height, width, channels); 3 ---> RGB
    batch_size = 8
    epochs = 1                                      # Number of epochs
    ntrain = int(0.8 * len(X_train))                # 80/20 train/validation split
    nval = int(0.2 * len(X_train))
    back_bone = str(back_bone)
    print("Using " + back_bone + "...")
    X = []
    X_train = np.reshape(np.array(X_train), [len(X_train), ])

    for img in range(len(X_train)):

        if X_train[img].ndim >= 3:
            X.append(cv2.resize(
                X_train[img][:, :, :3], (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_CUBIC))
        else:
            smimg = cv2.cvtColor(X_train[img][0], cv2.COLOR_GRAY2RGB)
            X.append(cv2.resize(smimg, (IMAGE_WIDTH, IMAGE_HEIGHT),
                                interpolation=cv2.INTER_CUBIC))

        if y_train[img] == 'COVID':
            y_train[img] = 1
        elif y_train[img] == 'NonCOVID':
            y_train[img] = 0
        else:
            continue

    x = np.array(X)
    X_train, X_val, y_train, y_val = train_test_split(          # 20% validation set
        x, y_train, test_size=0.20, random_state=2)

    # data generator
    if back_bone == 'ResNet50V2' or back_bone == '1':
        train_datagen = ImageDataGenerator(
            preprocessing_function=resnet_preprocess,
            rotation_range=15,
            shear_range=0.1,
            zoom_range=0.2,
            horizontal_flip=True,
            width_shift_range=0.1,
            height_shift_range=0.1
        )

        val_datagen = ImageDataGenerator(preprocessing_function=resnet_preprocess)

    elif back_bone == 'Xception' or back_bone == '2':
        train_datagen = ImageDataGenerator(
            preprocessing_function=xception_preprocess,
            rotation_range=15,
            shear_range=0.1,
            zoom_range=0.2,
            horizontal_flip=True,
            width_shift_range=0.1,
            height_shift_range=0.1
        )

        val_datagen = ImageDataGenerator(preprocessing_function=xception_preprocess)
    elif back_bone == "DenseNet201" or back_bone =='3':
        train_datagen = ImageDataGenerator(
            preprocessing_function=denset_preprocess,
            rotation_range=15,
            shear_range=0.1,
            zoom_range=0.2,
            horizontal_flip=True,
            width_shift_range=0.1,
            height_shift_range=0.1
        )

        val_datagen = ImageDataGenerator(preprocessing_function=denset_preprocess)
    elif back_bone == "MobileNetV2" or back_bone == '4':
        train_datagen = ImageDataGenerator(
            preprocessing_function=mobile_preprocess,
            rotation_range=15,
            shear_range=0.1,
            zoom_range=0.2,
            horizontal_flip=True,
            width_shift_range=0.1,
            height_shift_range=0.1
        )

        val_datagen = ImageDataGenerator(preprocessing_function=mobile_preprocess)
    else:
        raise ValueError('Please select transfer learning model!')
    train_generator = train_datagen.flow(
        X_train, y_train, batch_size=batch_size, shuffle=True)
    val_generator = val_datagen.flow(
        X_val, y_val, batch_size=batch_size, shuffle=True)

    # model
    model = Sequential()
    if back_bone == 'ResNet50V2' or back_bone == '1':
        base_model = applications.resnet_v2.ResNet50V2(
            include_top=False, pooling='avg', weights='imagenet', input_shape=input_shape)
    elif back_bone == 'Xception' or back_bone == '2':
        base_model = applications.Xception(
            include_top=False, pooling='avg', weights='imagenet', input_shape=input_shape)
    elif back_bone == 'DenseNet201' or back_bone == '3':
        base_model = applications.DenseNet201(
            include_top=False, pooling='avg', weights='imagenet', input_shape=input_shape)
    elif back_bone == 'MobileNetV2' or back_bone == '4':
        base_model = applications.MobileNetV2(
            include_top=False, pooling='avg', weights='imagenet', input_shape=input_shape)
    else:
        raise ValueError('Please select transfer learning model!')

    base_model.trainable = False

    model.add(base_model)

    model.add(BatchNormalization())
    model.add(Flatten())

    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.1))

    model.add(Dense(64, activation='relu'))
    model.add(Dropout(0.2))

    model.add(Dense(1, activation='sigmoid'))

    model.compile(loss='binary_crossentropy',
                  optimizer=optimizers.Adam(learning_rate=1e-4),
                  metrics=['acc'])

    # callbacks
    lr_reducer = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.8,
        patience=5,
        verbose=1,
        mode='auto',
        min_delta=0.0001,
        cooldown=3,
        min_lr=1e-14)

    best_loss_path = "Model.h5"
    ck_loss_model = ModelCheckpoint(
        best_loss_path, monitor='val_loss', verbose=1, save_best_only=True, mode='min')

    callbacks = [ck_loss_model, lr_reducer]

    history = model.fit(train_generator,
                        steps_per_epoch=ntrain // batch_size,
                        epochs=epochs,
                        validation_data=val_generator,
                        validation_steps=nval // batch_size,
                        callbacks=callbacks)

    return model
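
# Usage sketch, assuming X_train holds raw images and y_train the matching
# 'COVID' / 'NonCOVID' labels; back_bone accepts a name or a menu index '1'-'4':
#
#     trained = estimate(X_train, y_train, back_bone='Xception')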
Example #8
def build_autoencoder(base_model_name, input_shape, imagenet_mean,
                      imagenet_std, hidden_layer_size, n_classes,
                      weight_decay):
    if base_model_name == 'Xception':
        base_model = applications.Xception(weights='imagenet',
                                           include_top=False,
                                           input_shape=input_shape)
    elif base_model_name == 'VGG16':
        base_model = applications.VGG16(weights='imagenet',
                                        include_top=False,
                                        input_shape=input_shape)
    elif base_model_name == 'VGG19':
        base_model = applications.VGG19(weights='imagenet',
                                        include_top=False,
                                        input_shape=input_shape)
    elif base_model_name == 'ResNet50':
        base_model = applications.ResNet50(weights='imagenet',
                                           include_top=False,
                                           input_shape=input_shape)
    elif base_model_name == 'InceptionResNetV2':
        base_model = applications.InceptionResNetV2(weights='imagenet',
                                                    include_top=False,
                                                    input_shape=input_shape)
    elif base_model_name == 'InceptionV3':
        base_model = applications.InceptionV3(weights='imagenet',
                                              include_top=False,
                                              input_shape=input_shape)
    elif base_model_name == 'MobileNet':
        base_model = applications.MobileNet(weights='imagenet',
                                            include_top=False,
                                            input_shape=input_shape)
    elif base_model_name == 'DenseNet121':
        base_model = applications.DenseNet121(weights='imagenet',
                                              include_top=False,
                                              input_shape=input_shape)
    elif base_model_name == 'DenseNet169':
        base_model = applications.DenseNet169(weights='imagenet',
                                              include_top=False,
                                              input_shape=input_shape)
    elif base_model_name == 'DenseNet201':
        base_model = applications.DenseNet201(weights='imagenet',
                                              include_top=False,
                                              input_shape=input_shape)
    elif base_model_name == 'NASNetMobile':
        base_model = applications.NASNetMobile(weights='imagenet',
                                               include_top=False,
                                               input_shape=input_shape)
    elif base_model_name == 'NASNetLarge':
        base_model = applications.NASNetLarge(weights='imagenet',
                                              include_top=False,
                                              input_shape=input_shape)
    else:
        raise ValueError("Specified base model '%s' is not available!" % base_model_name)

    n_features = base_model.output.shape[-1]

    x = base_model.output
    x = tf.keras.layers.Lambda(lambda x: (x - imagenet_mean) / imagenet_std)(
        x)  # normalization
    x = tf.keras.layers.Activation(activation='sigmoid',
                                   name='encoder')(x)  # sigmoid
    x = tf.keras.layers.Dense(units=hidden_layer_size,
                              activation=None)(x)  # encoding
    x = tf.keras.layers.Activation(activation='relu')(x)  # relu
    x = tf.keras.layers.Dense(units=n_features,
                              activation=None,
                              name='decoder')(x)  # decoding
    x = tf.keras.layers.Dense(units=n_classes,
                              activation='sigmoid')(x)  # classification head

    model = tf.keras.Model(inputs=base_model.input, outputs=x[:, 0, 0, :])

    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Conv2D) or isinstance(
                layer, tf.keras.layers.Dense):
            layer.add_loss(
                tf.keras.regularizers.l2(weight_decay)(layer.kernel))
        if getattr(layer, 'use_bias', False):
            layer.add_loss(tf.keras.regularizers.l2(weight_decay)(layer.bias))

    return model
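
# Usage sketch pairing this with compute_mean_and_std above; hidden_layer_size,
# n_classes and weight_decay values here are illustrative only:
#
#     mean, std = compute_mean_and_std('VGG16', X, (224, 224, 3))
#     net = build_autoencoder('VGG16', (224, 224, 3), mean, std,
#                             hidden_layer_size=64, n_classes=10,
#                             weight_decay=1e-4)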
Example #9
from tensorflow.keras import layers
from tensorflow.keras import applications
from tensorflow.keras import Input, Model

xception_base = applications.Xception(weights=None, include_top=False)

left_input = Input(shape=(250, 250, 3))
right_input = Input(shape=(250, 250, 3))

left_features = xception_base(left_input)
right_features = xception_base(right_input)
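# the two calls above run the same xception_base instance, so both branches
# share one set of weights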

merged_features = layers.concatenate([left_features, right_features], axis=-1)

model = Model([left_input, right_input], merged_features)

model.summary()
Example #10
# assumed imports for this snippet
import numpy as np
import wandb
from sklearn.preprocessing import MultiLabelBinarizer
from tensorflow.keras import applications, layers, models

# load X, y data (IMAGE_SIZE is assumed to be defined, e.g. 150 to match
# the model input below)
X = np.load(f'./imgs_{IMAGE_SIZE}_3.npy', allow_pickle=True)
y = np.load(f'./labels_{IMAGE_SIZE}_3.npy', allow_pickle=True)
print(X.shape)
print(y.shape)

# transform the labels to binary representation so that we can train on the data
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(y)

print(list(mlb.classes_))

X = applications.xception.preprocess_input(X)

base_model = applications.Xception(weights='imagenet',
                                   input_shape=(150, 150, 3),
                                   include_top=False)
base_model.trainable = False

inputs = layers.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Flatten()(x)  # already 2-D after global average pooling, so this is a no-op
x = layers.Dense(2048, activation='relu')(x)
outputs = layers.Dense(len(mlb.classes_), activation='sigmoid')(x)
model = models.Model(inputs, outputs)
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

wandb.init(project="texture-classification")
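
# Training sketch, assuming the wandb run above; wandb.keras.WandbCallback
# streams Keras metrics to the wandb dashboard:
#
#     from wandb.keras import WandbCallback
#     model.fit(X, y, validation_split=0.1, epochs=10,
#               callbacks=[WandbCallback()])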
Example #11
    def encode(self, input_image):
        """
        :param input_image: (batch, height, width, channel)
        """
        input_shape = input_image.get_shape().as_list()[1:]  # drop the batch dim
        height = input_shape[0]
        net_name = self.net_name
        weights = "imagenet" if self.pretrained_weight else None

        jsonfile = op.join(opts.PROJECT_ROOT, "model", "build_model",
                           "scaled_layers.json")
        output_layers = self.read_output_layers(jsonfile)
        out_layer_names = output_layers[net_name]
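        # each entry in out_layer_names is an (index, name) pair identifying
        # a feature-map layer to extract from the pretrained backbone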

        if net_name == "MobileNetV2":
            from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
            pproc_img = layers.Lambda(lambda x: preprocess_input(x),
                                      name="preprocess_mobilenet")(input_image)
            ptmodel = tfapp.MobileNetV2(input_shape=input_shape,
                                        include_top=False,
                                        weights=weights)

        elif net_name == "NASNetMobile":
            from tensorflow.keras.applications.nasnet import preprocess_input
            assert height == 128

            def preprocess_layer(x):
                x = preprocess_input(x)
                x = tf.image.resize(x,
                                    size=NASNET_SHAPE[:2],
                                    method="bilinear")
                return x

            pproc_img = layers.Lambda(lambda x: preprocess_layer(x),
                                      name="preprocess_nasnet")(input_image)
            ptmodel = tfapp.NASNetMobile(input_shape=NASNET_SHAPE,
                                         include_top=False,
                                         weights=weights)

        elif net_name == "DenseNet121":
            from tensorflow.keras.applications.densenet import preprocess_input
            pproc_img = layers.Lambda(lambda x: preprocess_input(x),
                                      name="preprocess_densenet")(input_image)
            ptmodel = tfapp.DenseNet121(input_shape=input_shape,
                                        include_top=False,
                                        weights=weights)

        elif net_name == "VGG16":
            from tensorflow.keras.applications.vgg16 import preprocess_input
            pproc_img = layers.Lambda(lambda x: preprocess_input(x),
                                      name="preprocess_vgg16")(input_image)
            ptmodel = tfapp.VGG16(input_shape=input_shape,
                                  include_top=False,
                                  weights=weights)

        elif net_name == "Xception":
            from tensorflow.keras.applications.xception import preprocess_input
            assert height == 128

            def preprocess_layer(x):
                x = preprocess_input(x)
                x = tf.image.resize(x,
                                    size=XCEPTION_SHAPE[:2],
                                    method="bilinear")
                return x

            pproc_img = layers.Lambda(lambda x: preprocess_layer(x),
                                      name="preprocess_xception")(input_image)
            ptmodel = tfapp.Xception(input_shape=XCEPTION_SHAPE,
                                     include_top=False,
                                     weights=weights)

        elif net_name == "ResNet50V2":
            from tensorflow.keras.applications.resnet_v2 import preprocess_input
            pproc_img = layers.Lambda(lambda x: preprocess_input(x),
                                      name="preprocess_resnet")(input_image)
            ptmodel = tfapp.ResNet50V2(input_shape=input_shape,
                                       include_top=False,
                                       weights=weights)

        elif net_name == "NASNetLarge":
            from tensorflow.keras.applications.nasnet import preprocess_input
            assert height == 128

            def preprocess_layer(x):
                x = preprocess_input(x)
                x = tf.image.resize(x,
                                    size=NASNET_SHAPE[:2],
                                    method="bilinear")
                return x

            pproc_img = layers.Lambda(lambda x: preprocess_layer(x),
                                      name="preprocess_nasnet")(input_image)
            ptmodel = tfapp.NASNetLarge(input_shape=NASNET_SHAPE,
                                        include_top=False,
                                        weights=weights)
        else:
            raise WrongInputException("Wrong pretrained model name: " +
                                      net_name)

        # collect multi scale convolutional features
        layer_outs = []
        for _, layer_name in out_layer_names:
            # get_layer() accepts a layer name or an index, not both
            layer = ptmodel.get_layer(name=layer_name)
            # print("extract feature layers:", layer.name, layer.get_input_shape_at(0), layer.get_output_shape_at(0))
            layer_outs.append(layer.output)

        # create model with multi scale features
        multi_scale_model = tf.keras.Model(ptmodel.input,
                                           layer_outs,
                                           name=net_name + "_base")
        features_ms = multi_scale_model(pproc_img)
        return features_ms