Пример #1
0
    # Build the Mini-ImageNet database, parser, generator and discriminator,
    # train a GAN on them, reload its latest checkpoint, and then randomly pick
    # a fraction (`labeled_percentage`) of the training-folder keys to treat as
    # labeled. `prefix`, `labeled_percentage` and `GAN_EPOCHS` come from scope
    # outside this excerpt.
    mini_imagenet_database = MiniImagenetDatabase()
    # Image shape handed to both the parser and the GAN; presumably
    # (height, width, channels) — confirm against MiniImagenetParser.
    shape = (84, 84, 3)
    # Latent dimension shared by the generator factory and the GAN wrapper;
    # the two must agree.
    latent_dim = 512
    mini_imagenet_generator = get_generator(latent_dim)
    mini_imagenet_discriminator = get_discriminator()
    mini_imagenet_parser = MiniImagenetParser(shape=shape)

    # NOTE(review): not used within this excerpt; presumably consumed further
    # down the (unseen) enclosing function — confirm, otherwise it is dead code.
    experiment_name = prefix+str(labeled_percentage)

    # for the SSGAN we need to feed the labels, L, when initializing
    # NOTE(review): no labels are actually passed to GAN(...) below, so this
    # remark looks stale (or labels are supplied elsewhere) — confirm.
    gan = GAN(
        'mini_imagenet',  # presumably a dataset/run name — verify in GAN
        image_shape=shape,
        latent_dim=latent_dim,
        database=mini_imagenet_database,
        parser=mini_imagenet_parser,
        generator=mini_imagenet_generator,
        discriminator=mini_imagenet_discriminator,
        visualization_freq=1,
        d_learning_rate=0.0003,
        g_learning_rate=0.0003,
    )
    # GAN_EPOCHS is defined outside this excerpt.
    gan.perform_training(epochs=GAN_EPOCHS, checkpoint_freq=50)
    # Presumably restores the most recently saved checkpoint so later steps use
    # the trained weights — verify against GAN.load_latest_checkpoint.
    gan.load_latest_checkpoint()

    print("training GAN is done")
    # NOTE(review): the purpose of this 1-second pause is not evident from this
    # code; confirm it is needed or remove it.
    time.sleep(1)

    # Split labeled and not labeled
    # Draw int(len(keys) * labeled_percentage) distinct keys uniformly at random
    # (replace=False); int() truncates toward zero. This uses NumPy's global
    # RNG, so the split is only reproducible if a seed is set elsewhere. The
    # complementary (unlabeled) keys are not computed within this excerpt.
    train_folders = mini_imagenet_database.train_folders
    keys = list(train_folders.keys())
    labeled_keys = np.random.choice(keys, int(len(train_folders.keys())*labeled_percentage), replace=False)
Пример #2
0
if __name__ == '__main__':
    omniglot_database = OmniglotDatabase(random_seed=47,
                                         num_train_classes=1200,
                                         num_val_classes=100)
    shape = (28, 28, 1)
    latent_dim = 128
    omniglot_generator = get_generator(latent_dim)
    omniglot_discriminator = get_discriminator()
    omniglot_parser = OmniglotParser(shape=shape)

    gan = GAN(
        'omniglot',
        image_shape=shape,
        latent_dim=latent_dim,
        database=omniglot_database,
        parser=omniglot_parser,
        generator=omniglot_generator,
        discriminator=omniglot_discriminator,
        visualization_freq=50,
        d_learning_rate=0.0003,
        g_learning_rate=0.0003,
    )
    gan.perform_training(epochs=500, checkpoint_freq=50)
    gan.load_latest_checkpoint()

    maml_gan = MAMLGAN(gan=gan,
                       latent_dim=latent_dim,
                       generated_image_shape=shape,
                       database=omniglot_database,
                       network_cls=SimpleModel,
                       n=5,
                       k=1,