num_train_classes=1200, num_val_classes=100) shape = (28, 28, 1) latent_dim = 128 omniglot_generator = get_generator(latent_dim) omniglot_discriminator = get_discriminator() omniglot_parser = OmniglotParser(shape=shape) experiment_name = prefix + str(labeled_percentage) gan = GAN( 'omniglot', image_shape=shape, latent_dim=latent_dim, database=omniglot_database, parser=omniglot_parser, generator=omniglot_generator, discriminator=omniglot_discriminator, visualization_freq=50, d_learning_rate=0.0003, g_learning_rate=0.0003, ) # gan.perform_training(epochs=GAN_EPOCHS, checkpoint_freq=50) gan.load_latest_checkpoint() print("GAN training finished") time.sleep(1) train_folders = omniglot_database.train_folders keys = list(train_folders.keys()) sample_size = np.max([ N_WAY * META_BATCH_SIZE,
# Build the Mini-ImageNet GAN (database, generator, discriminator, parser),
# train it, reload its latest checkpoint, then randomly choose a subset of the
# training-folder keys to treat as "labeled".
# NOTE(review): `prefix`, `labeled_percentage` and `GAN_EPOCHS` are free names in
# this snippet — presumably parameters/constants of an enclosing scope that is
# not visible here; confirm.
mini_imagenet_database = MiniImagenetDatabase()
# Image shape passed to both GAN(image_shape=...) and MiniImagenetParser;
# presumably (height, width, channels) — confirm against the parser.
shape = (84, 84, 3)
# Latent dimensionality passed to both get_generator() and GAN(latent_dim=...).
latent_dim = 512
mini_imagenet_generator = get_generator(latent_dim)
mini_imagenet_discriminator = get_discriminator()
mini_imagenet_parser = MiniImagenetParser(shape=shape)
# Not used within this snippet; presumably consumed by later code — verify.
experiment_name = prefix+str(labeled_percentage)
# NOTE(review): an earlier comment here read "for the SSGAN we need to feed the
# labels, L, when initializing", but no labels are passed to GAN(...) below —
# the labeled split is only computed after training. Confirm whether that is
# still intended.
gan = GAN(
    'mini_imagenet',
    image_shape=shape,
    latent_dim=latent_dim,
    database=mini_imagenet_database,
    parser=mini_imagenet_parser,
    generator=mini_imagenet_generator,
    discriminator=mini_imagenet_discriminator,
    visualization_freq=1,
    d_learning_rate=0.0003,
    g_learning_rate=0.0003,
)
# Train for GAN_EPOCHS epochs (checkpoint_freq=50), then reload the latest
# checkpoint before the labeled split below.
gan.perform_training(epochs=GAN_EPOCHS, checkpoint_freq=50)
gan.load_latest_checkpoint()
print("training GAN is done")
time.sleep(1)
# Split labeled and not labeled
# `train_folders` is dict-like (it has .keys()); its keys are presumably one
# per training class — confirm against MiniImagenetDatabase.
train_folders = mini_imagenet_database.train_folders
keys = list(train_folders.keys())
# Draw int(len(keys) * labeled_percentage) distinct keys (replace=False).
# No random seed is set in this snippet, so the split can differ between runs.
labeled_keys = np.random.choice(keys, int(len(train_folders.keys())*labeled_percentage), replace=False)
if __name__ == '__main__': fungi_database = FungiDatabase() shape = (84, 84, 3) latent_dim = 512 mini_imagenet_generator = get_generator(latent_dim) mini_imagenet_discriminator = get_discriminator() mini_imagenet_parser = MiniImagenetParser(shape=shape) gan = GAN( 'fungi', image_shape=shape, latent_dim=latent_dim, database=fungi_database, parser=mini_imagenet_parser, generator=mini_imagenet_generator, discriminator=mini_imagenet_discriminator, visualization_freq=1, d_learning_rate=0.0003, g_learning_rate=0.0003, ) # gan.perform_training(epochs=1000, checkpoint_freq=5) gan.load_latest_checkpoint() maml_gan = MAMLGANFungi(gan=gan, latent_dim=latent_dim, generated_image_shape=shape, database=fungi_database, network_cls=MiniImagenetModel, n=5, k_ml=1,