Example #1
0
    # --- Omniglot ---------------------------------------------------------
    # Build the Omniglot database, encoder/decoder and parser, construct a VAE
    # from them, train it, then reload its latest checkpoint.
    # NOTE(review): what random_seed / num_train_classes / num_val_classes
    # control is defined inside OmniglotDatabase (not visible here) —
    # presumably a seeded train/val class split; confirm there.
    omniglot_database = OmniglotDatabase(random_seed=47,
                                         num_train_classes=1200,
                                         num_val_classes=100)
    # Input shape given to both the VAE (as image_shape) and the parser;
    # presumably (height, width, channels) for 28x28 grayscale — TODO confirm.
    shape = (28, 28, 1)
    # Latent size shared by the encoder, the decoder and the VAE.
    latent_dim = 20
    # NOTE(review): the encoder/decoder builders receive only latent_dim, not
    # `shape`; confirm the get_encoder/get_decoder in scope here fit (28, 28, 1).
    omniglot_encoder = get_encoder(latent_dim)
    omniglot_decoder = get_decoder(latent_dim)
    omniglot_parser = OmniglotParser(shape=shape)

    # The first positional argument 'omniglot' is presumably a run/checkpoint
    # name — confirm in VAE.
    vae = VAE(
        'omniglot',
        image_shape=shape,
        latent_dim=latent_dim,
        database=omniglot_database,
        parser=omniglot_parser,
        encoder=omniglot_encoder,
        decoder=omniglot_decoder,
        visualization_freq=5,
        learning_rate=0.001,
    )
    # Train for 1000 epochs (checkpoint_freq=100), then restore the most recent
    # checkpoint before `vae` is passed on to MAML_VAE.
    vae.perform_training(epochs=1000, checkpoint_freq=100)
    vae.load_latest_checkpoint()
    # Meta-learning task visualization is currently disabled.
    # vae.visualize_meta_learning_task()

    maml_vae = MAML_VAE(vae=vae,
                        database=omniglot_database,
                        latent_algorithm='p1',
                        network_cls=SimpleModel,
                        n=5,
                        k=1,
    # NOTE(review): the MAML_VAE(...) call just above this point is never
    # closed (no `)` after `k=1,`), so this code does not parse as shown;
    # restore its remaining arguments and closing parenthesis from upstream.
    # Debug toggle (disabled): forces tf.function-wrapped code to run eagerly.
    # Newer TF 2.x releases deprecate this in favour of
    # tf.config.run_functions_eagerly.
    # import tensorflow as tf
    # tf.config.experimental_run_functions_eagerly(True)

    # --- VoxCeleb ---------------------------------------------------------
    # Build the VoxCeleb database, encoder/decoder and parser, construct a VAE
    # from them, train it, then reload its latest checkpoint.
    voxceleb_database = VoxCelebDatabase()
    # Audio input shape, passed to the VAE through its `image_shape` argument;
    # presumably 16000 samples x 1 channel — TODO confirm with VoxCelebParser.
    shape = (16000, 1)
    # Latent size shared by the encoder, the decoder and the VAE.
    latent_dim = 20
    # NOTE(review): the encoder/decoder builders receive only latent_dim, not
    # `shape`; confirm the get_encoder/get_decoder in scope here fit (16000, 1).
    voxceleb_encoder = get_encoder(latent_dim)
    voxceleb_decoder = get_decoder(latent_dim)
    voxceleb_parser = VoxCelebParser(shape=shape)

    vae = VAE(
        'voxceleb',
        image_shape=shape,
        latent_dim=latent_dim,
        database=voxceleb_database,
        parser=voxceleb_parser,
        encoder=voxceleb_encoder,
        decoder=voxceleb_decoder,
        visualization_freq=1,
        learning_rate=0.001,
    )
    # AudioCallback is supplied as the visualization callback class for this
    # audio dataset; then the most recent checkpoint is restored.
    vae.perform_training(epochs=1000, checkpoint_freq=100, vis_callback_cls=AudioCallback)
    vae.load_latest_checkpoint()
    # Meta-learning task visualization is currently disabled.
    # vae.visualize_meta_learning_task()

    maml_vae = MAML_VAE(
        vae=vae,
        database=voxceleb_database,
        network_cls=SimpleModel,
        n=5,
        k=1,
    # NOTE(review): the MAML_VAE(...) call just above this point is never
    # closed (no `)` after `k=1,`), so this code does not parse as shown;
    # restore its remaining arguments and closing parenthesis from upstream.
    # Debug toggle (disabled): forces tf.function-wrapped code to run eagerly.
    # Newer TF 2.x releases deprecate this in favour of
    # tf.config.run_functions_eagerly.
    # import tensorflow as tf
    # tf.config.experimental_run_functions_eagerly(True)

    # --- CelebA -----------------------------------------------------------
    # Build the CelebA database, encoder/decoder and parser, construct a VAE
    # from them and load its latest checkpoint (training is commented out).
    celebalot_database = CelebADatabase()
    # Presumably (height, width, channels) for 84x84 RGB — TODO confirm with
    # CelebAParser.
    shape = (84, 84, 3)
    # Latent size shared by the encoder, the decoder and the VAE.
    latent_dim = 500
    # NOTE(review): the encoder/decoder builders receive only latent_dim, not
    # `shape`; confirm the get_encoder/get_decoder in scope here fit (84, 84, 3).
    celebalot_encoder = get_encoder(latent_dim)
    celebalot_decoder = get_decoder(latent_dim)
    celebalot_parser = CelebAParser(shape=shape)

    vae = VAE(
        'celeba',
        image_shape=shape,
        latent_dim=latent_dim,
        database=celebalot_database,
        parser=celebalot_parser,
        encoder=celebalot_encoder,
        decoder=celebalot_decoder,
        visualization_freq=1,
        learning_rate=0.001,
    )
    # Training is disabled, so load_latest_checkpoint() presumably relies on a
    # checkpoint saved by an earlier run — confirm what VAE does if none exists.
    # NOTE(review): if re-enabled, checkpoint_freq=100 exceeds epochs=20, so a
    # periodic checkpoint may never be written; check how VAE handles this.
    # vae.perform_training(epochs=20, checkpoint_freq=100)
    vae.load_latest_checkpoint()
    # Meta-learning task visualization is currently disabled.
    # vae.visualize_meta_learning_task()

    maml_vae = MAMLVAECelebA(vae=vae,
                             latent_algorithm='p3',
                             database=celebalot_database,
                             network_cls=MiniImagenetModel,
                             n=2,
                             k=1,