# --- MiniImagenet SSGAN pre-training -------------------------------------
# Build the dataset, parser, and the two GAN networks, train the GAN, then
# pick a random labeled subset of the training classes for the semi-
# supervised stage.
mini_imagenet_database = MiniImagenetDatabase()

# MiniImagenet images are 84x84 RGB; the GAN samples from a 512-d latent.
shape = (84, 84, 3)
latent_dim = 512

mini_imagenet_generator = get_generator(latent_dim)
mini_imagenet_discriminator = get_discriminator()
mini_imagenet_parser = MiniImagenetParser(shape=shape)

# Experiment name encodes the fraction of labeled data used.
experiment_name = f"{prefix}{labeled_percentage}"

# NOTE: for the SSGAN the labels, L, must be fed in at initialization.
gan = GAN(
    'mini_imagenet',
    image_shape=shape,
    latent_dim=latent_dim,
    database=mini_imagenet_database,
    parser=mini_imagenet_parser,
    generator=mini_imagenet_generator,
    discriminator=mini_imagenet_discriminator,
    visualization_freq=1,
    d_learning_rate=0.0003,
    g_learning_rate=0.0003,
)

# Train, then restore the most recent checkpoint before downstream use.
gan.perform_training(epochs=GAN_EPOCHS, checkpoint_freq=50)
gan.load_latest_checkpoint()
print("training GAN is done")
time.sleep(1)

# Split the training classes into labeled and unlabeled pools: draw
# `labeled_percentage` of the class keys uniformly without replacement.
train_folders = mini_imagenet_database.train_folders
keys = list(train_folders.keys())
num_labeled = int(len(keys) * labeled_percentage)
labeled_keys = np.random.choice(keys, num_labeled, replace=False)
if __name__ == '__main__': omniglot_database = OmniglotDatabase(random_seed=47, num_train_classes=1200, num_val_classes=100) shape = (28, 28, 1) latent_dim = 128 omniglot_generator = get_generator(latent_dim) omniglot_discriminator = get_discriminator() omniglot_parser = OmniglotParser(shape=shape) gan = GAN( 'omniglot', image_shape=shape, latent_dim=latent_dim, database=omniglot_database, parser=omniglot_parser, generator=omniglot_generator, discriminator=omniglot_discriminator, visualization_freq=50, d_learning_rate=0.0003, g_learning_rate=0.0003, ) gan.perform_training(epochs=500, checkpoint_freq=50) gan.load_latest_checkpoint() maml_gan = MAMLGAN(gan=gan, latent_dim=latent_dim, generated_image_shape=shape, database=omniglot_database, network_cls=SimpleModel, n=5, k=1,