示例#1
0
def get_model(model_type):
    """Build an untrained ShuffleNet-family model selected by name.

    Args:
        model_type: ``'shufflenet'`` (ShuffleNet with ``groups=3``) or
            ``'shufflenet_v2'`` (ShuffleNetV2 with ``bottleneck_ratio=1``).

    Returns:
        The constructed model, always created with ``weights=None``.

    Raises:
        ValueError: If ``model_type`` is neither of the names above.
    """
    # Reject unknown names up front so the happy path below stays flat.
    if model_type not in ('shufflenet', 'shufflenet_v2'):
        raise ValueError('Unsupported model type')
    if model_type == 'shufflenet_v2':
        return ShuffleNetV2(bottleneck_ratio=1, weights=None)
    return ShuffleNet(groups=3, weights=None)
示例#2
0
def get_model(model_type, include_top=True):
    """Build an untrained model by name and report its expected input size.

    Args:
        model_type: One of ``'shufflenet'``, ``'shufflenet_v2'`` or
            ``'nanonet'``.
        include_top: Forwarded unchanged to the model constructor.

    Returns:
        A ``(model, (height, width))`` tuple. Every model offered here is
        reported as ``(224, 224)``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # NOTE(review): the size is only reported, not passed to the constructors
    # below -- confirm they default to 224x224, 3-channel inputs.
    # Lambdas defer the constructor-name lookup, so only the requested
    # architecture is ever resolved (just like separate if/elif branches).
    candidates = (
        ('shufflenet',
         lambda: ShuffleNet(groups=3, weights=None, include_top=include_top)),
        ('shufflenet_v2',
         lambda: ShuffleNetV2(bottleneck_ratio=1, weights=None,
                              include_top=include_top)),
        ('nanonet',
         lambda: NanoNet(weights=None, include_top=include_top)),
    )
    # Linear scan with `==` (not a dict lookup) so unhashable inputs still
    # fall through to the ValueError below.
    for name, build in candidates:
        if model_type == name:
            height, width, _channels = (224, 224, 3)
            return build(), (height, width)
    raise ValueError('Unsupported model type')
def get_model(model_type, include_top=True):
    """Build an untrained backbone by name together with its input size.

    Every architecture is constructed with ``weights=None`` and an explicit
    ``input_shape`` of ``(side, side, 3)``, so the reported size always matches
    what the model was built with.

    Args:
        model_type: One of ``'shufflenet'``, ``'shufflenet_v2'``,
            ``'nanonet'``, ``'darknet53'``, ``'cspdarknet53'``,
            ``'mobilevit_s'``, ``'mobilevit_xs'`` or ``'mobilevit_xxs'``.
        include_top: Forwarded unchanged to the model constructor.

    Returns:
        ``(model, (height, width))``. The size is 224x224 for the ShuffleNet,
        NanoNet and DarkNet variants and 256x256 for the MobileViT variants.

    Raises:
        ValueError: If ``model_type`` is not in the list above.
    """
    # Rows are (name, square input side, builder taking the full input shape).
    # Builders are lambdas so only the selected class name is resolved.
    registry = (
        ('shufflenet', 224,
         lambda shape: ShuffleNet(input_shape=shape, groups=3, weights=None,
                                  include_top=include_top)),
        ('shufflenet_v2', 224,
         lambda shape: ShuffleNetV2(input_shape=shape, bottleneck_ratio=1,
                                    weights=None, include_top=include_top)),
        ('nanonet', 224,
         lambda shape: NanoNet(input_shape=shape, weights=None,
                               include_top=include_top)),
        ('darknet53', 224,
         lambda shape: DarkNet53(input_shape=shape, weights=None,
                                 include_top=include_top)),
        ('cspdarknet53', 224,
         lambda shape: CSPDarkNet53(input_shape=shape, weights=None,
                                    include_top=include_top)),
        ('mobilevit_s', 256,
         lambda shape: MobileViT_S(input_shape=shape, weights=None,
                                   include_top=include_top)),
        ('mobilevit_xs', 256,
         lambda shape: MobileViT_XS(input_shape=shape, weights=None,
                                    include_top=include_top)),
        ('mobilevit_xxs', 256,
         lambda shape: MobileViT_XXS(input_shape=shape, weights=None,
                                     include_top=include_top)),
    )
    # Compare with `==` in table order, mirroring the original if/elif chain
    # (this also keeps unhashable inputs on the ValueError path).
    for name, side, build in registry:
        if model_type == name:
            input_shape = (side, side, 3)
            return build(input_shape), input_shape[:2]
    raise ValueError('Unsupported model type')
示例#4
0
def MyNet(input_shape=(218, 178, 3), n_classes=2360):
    """Build an Inception-ResNet-style stem followed by a ShuffleNet head.

    Args:
        input_shape: Shape of one input image, passed to ``Input``.
        n_classes: Number of output classes, passed to ``ShuffleNet`` as
            ``classes``.

    Returns:
        A ``Model`` named ``'MyNet'`` that maps ``img_input`` to the output
        of the ShuffleNet built on top of the stem features.

    Side effects:
        Calls ``model.summary()`` before returning.
    """
    img_input = Input(shape=input_shape)

    # Stem block: 52 x 42 x 32 for the default 218 x 178 input (the last
    # conv has 32 filters). The spatial sizes assume conv2d_bn pads 'same'
    # when no padding is given -- TODO confirm.
    x = conv2d_bn(img_input, 16, 3, strides=2, padding='valid')
    x = conv2d_bn(x, 16, 3, padding='valid')
    x = conv2d_bn(x, 32, 3)
    x = MaxPooling2D(3, strides=2)(x)

    # Mixed 5b (Inception-A block): 52 x 42 x 160
    # Four parallel branches concatenated: 48 + 32 + 48 + 32 = 160 channels.
    branch_0 = conv2d_bn(x, 48, 1)
    branch_1 = conv2d_bn(x, 24, 1)
    branch_1 = conv2d_bn(branch_1, 32, 5)
    branch_2 = conv2d_bn(x, 32, 1)
    branch_2 = conv2d_bn(branch_2, 48, 3)
    branch_2 = conv2d_bn(branch_2, 48, 3)
    branch_pool = AveragePooling2D(3, strides=1, padding='same')(x)
    branch_pool = conv2d_bn(branch_pool, 32, 1)
    branches = [branch_0, branch_1, branch_2, branch_pool]
    # The concatenation axis follows the backend's image data format.
    channel_axis = 1 if K.image_data_format() == 'channels_first' else 3
    x = Concatenate(axis=channel_axis, name='mixed_5b')(branches)

    # 5x block35 (Inception-ResNet-A block): 52 x 42 x 160
    # (assumes separable_inception_resnet_block preserves its input shape --
    # verify). block_idx runs 1..5.
    for block_idx in range(1, 6):
        x = separable_inception_resnet_block(x,
                                             scale=0.17,
                                             block_type='block35',
                                             block_idx=block_idx)

    # The ShuffleNet is attached to the stem features via input_tensor=x
    # rather than to raw images. num_shuffle_units=[10, 5] presumably sets
    # the number of units per stage -- confirm against ShuffleNet's signature.
    shuffle_model = ShuffleNet(include_top=True,
                               input_tensor=x,
                               scale_factor=1.0,
                               num_shuffle_units=[10, 5],
                               groups=1,
                               pooling='avg',
                               classes=n_classes)

    # Create model
    # Wrap image input -> ShuffleNet output so stem and head form one model.
    model = Model(img_input, shuffle_model.output, name='MyNet')
    model.summary()
    return model
def preprocess(x):
    """Turn one image array into a normalised single-image batch.

    A leading batch axis is added and the result is passed through
    ``preprocess_input``. That output is then mapped with
    ``(v / 255 - 0.5) * 2``, so values in [0, 255] end up in [-1, 1].

    The arithmetic is done in place on whatever ``preprocess_input``
    returns, so that array must be floating point. ``np.expand_dims``
    returns a view, so if ``preprocess_input`` hands that view back, the
    caller's ``x`` is modified as well.

    NOTE(review): this function is passed to ``ImageDataGenerator`` as
    ``preprocessing_function``, whose documented contract is a rank-3 image
    in and the same shape out. Here the result has an extra leading axis;
    confirm that is intended.
    """
    batch = preprocess_input(np.expand_dims(x, axis=0))
    # [0, 255] -> [0, 1] -> [-0.5, 0.5] -> [-1, 1], each step in place.
    batch /= 255.0
    batch -= 0.5
    batch *= 2.0
    return batch


def step_learning_rate(epoch, rates, epochs_per_step=30):
    """Return the learning rate for ``epoch`` from a piecewise-constant schedule.

    ``rates[i]`` applies to epochs ``[i * epochs_per_step, (i + 1) * epochs_per_step)``.
    Epochs past the end of the schedule keep the last rate instead of
    raising ``IndexError``.

    >>> step_learning_rate(0, [0.1, 0.01])
    0.1
    >>> step_learning_rate(30, [0.1, 0.01])
    0.01
    >>> step_learning_rate(500, [0.1, 0.01])
    0.01

    Raises:
        ValueError: If ``rates`` is empty.
    """
    if not rates:
        raise ValueError('rates must contain at least one learning rate')
    return rates[min(epoch // epochs_per_step, len(rates) - 1)]


if __name__ == '__main__':
    groups = 3
    batch_size = 128
    inital_epoch = 0  # epoch to resume from; 0 means a fresh run
    ds = '/mnt/daten/Development/ILSVRC2012_256'

    model = ShuffleNet(groups=groups, pooling='avg')
    plot_model(model, 'model.png', show_shapes=True)
    # Uncomment to resume from the file that `checkpoint` below writes.
    # model.load_weights('%s.hdf5' % model.name, by_name=True)
    # Append to the existing log only when resuming. Use `!=`, not `is not`:
    # identity comparison with an int literal is unreliable and emits a
    # SyntaxWarning on Python 3.8+.
    csv_logger = CSVLogger('%s.log' % model.name, append=(inital_epoch != 0))
    # NOTE(review): newer Keras versions name this metric 'val_accuracy' --
    # confirm against the installed version.
    checkpoint = ModelCheckpoint(filepath='%s.hdf5' % model.name, verbose=1,
                                 save_best_only=True, monitor='val_acc', mode='max')

    # Step schedule: each rate is used for 30 epochs, and the last rate is
    # kept afterwards. The original indexing crashed from epoch 150 onward.
    learn_rates = [0.05, 0.01, 0.005, 0.001, 0.0005]
    lr_scheduler = LearningRateScheduler(
        lambda epoch: step_learning_rate(epoch, learn_rates))

    train_datagen = ImageDataGenerator(preprocessing_function=preprocess,
                                       zoom_range=0.25,
                                       width_shift_range=0.05,
                                       height_shift_range=0.05,
                                       horizontal_flip=True)
示例#6
0
from shufflenet import ShuffleNet

# Build a ShuffleNet using only the constructor's default arguments.
model = ShuffleNet()