num_train_classes=1200,
                                         num_val_classes=100)
    shape = (28, 28, 1)
    latent_dim = 128
    omniglot_generator = get_generator(latent_dim)
    omniglot_discriminator = get_discriminator()
    omniglot_parser = OmniglotParser(shape=shape)

    experiment_name = prefix + str(labeled_percentage)

    gan = GAN(
        'omniglot',
        image_shape=shape,
        latent_dim=latent_dim,
        database=omniglot_database,
        parser=omniglot_parser,
        generator=omniglot_generator,
        discriminator=omniglot_discriminator,
        visualization_freq=50,
        d_learning_rate=0.0003,
        g_learning_rate=0.0003,
    )
    # gan.perform_training(epochs=GAN_EPOCHS, checkpoint_freq=50)
    gan.load_latest_checkpoint()

    print("GAN training finished")
    time.sleep(1)

    train_folders = omniglot_database.train_folders
    keys = list(train_folders.keys())
    sample_size = np.max([
        N_WAY * META_BATCH_SIZE,
Exemplo n.º 2
0
    mini_imagenet_database = MiniImagenetDatabase()
    shape = (84, 84, 3)
    latent_dim = 512
    mini_imagenet_generator = get_generator(latent_dim)
    mini_imagenet_discriminator = get_discriminator()
    mini_imagenet_parser = MiniImagenetParser(shape=shape)

    experiment_name = prefix+str(labeled_percentage)

    # for the SSGAN we need to feed the labels, L, when initializing
    gan = GAN(
        'mini_imagenet',
        image_shape=shape,
        latent_dim=latent_dim,
        database=mini_imagenet_database,
        parser=mini_imagenet_parser,
        generator=mini_imagenet_generator,
        discriminator=mini_imagenet_discriminator,
        visualization_freq=1,
        d_learning_rate=0.0003,
        g_learning_rate=0.0003,
    )
    gan.perform_training(epochs=GAN_EPOCHS, checkpoint_freq=50)
    gan.load_latest_checkpoint()

    print("training GAN is done")
    time.sleep(1)

    # Split labeled and not labeled
    train_folders = mini_imagenet_database.train_folders
    keys = list(train_folders.keys())
    labeled_keys = np.random.choice(keys, int(len(train_folders.keys())*labeled_percentage), replace=False)

if __name__ == '__main__':
    fungi_database = FungiDatabase()
    shape = (84, 84, 3)
    latent_dim = 512
    mini_imagenet_generator = get_generator(latent_dim)
    mini_imagenet_discriminator = get_discriminator()
    mini_imagenet_parser = MiniImagenetParser(shape=shape)

    gan = GAN(
        'fungi',
        image_shape=shape,
        latent_dim=latent_dim,
        database=fungi_database,
        parser=mini_imagenet_parser,
        generator=mini_imagenet_generator,
        discriminator=mini_imagenet_discriminator,
        visualization_freq=1,
        d_learning_rate=0.0003,
        g_learning_rate=0.0003,
    )
    # gan.perform_training(epochs=1000, checkpoint_freq=5)
    gan.load_latest_checkpoint()

    maml_gan = MAMLGANFungi(gan=gan,
                            latent_dim=latent_dim,
                            generated_image_shape=shape,
                            database=fungi_database,
                            network_cls=MiniImagenetModel,
                            n=5,
                            k_ml=1,