# vae.perform_training(epochs=20, checkpoint_freq=100)
    # Call the VAE's checkpoint loader before handing it to MAML_VAE.
    vae.load_latest_checkpoint()
    # vae.visualize_meta_learning_task()

    # Keyword arguments for the mini-imagenet MAML_VAE run, kept in one mapping
    # so the whole experiment configuration can be read in a single place.
    # NOTE(review): the n / k_* names look like per-split task sizes — confirm
    # their exact meaning against MAML_VAE's signature.
    mini_imagenet_settings = {
        # Model and data.
        'vae': vae,
        'latent_algorithm': 'p1',
        'database': mini_imagenet_database,
        'network_cls': MiniImagenetModel,
        # Task sizes.
        'n': 5,
        'k_ml': 1,
        'k_val_ml': 5,
        'k_val': 1,
        'k_val_val': 15,
        'k_test': 1,
        'k_val_test': 15,
        # Step counts, batch size and learning rates.
        'meta_batch_size': 4,
        'num_steps_ml': 5,
        'lr_inner_ml': 0.05,
        'num_steps_validation': 5,
        # Saving / reporting / logging intervals.
        'save_after_iterations': 1000,
        'meta_learning_rate': 0.001,
        'report_validation_frequency': 200,
        'log_train_images_after_iteration': 200,
        'num_tasks_val': 100,
        # Remaining run options.
        'clip_gradients': True,
        'experiment_name': 'mini_imagenet_crop_random_uniform',
        'val_seed': 42,
        'val_test_batch_norm_momentum': 0.0,
    }
    maml_vae = MAML_VAE(**mini_imagenet_settings)

    # `shape` is defined outside the lines visible here.
    maml_vae.visualize_meta_learning_task(shape, num_tasks_to_visualize=2)

    maml_vae.train(iterations=8000)
    # Call the VAE's checkpoint loader before handing it to MAML_VAE.
    vae.load_latest_checkpoint()
    # vae.visualize_meta_learning_task()

    # Keyword arguments for the voxceleb MAML_VAE run, kept in one mapping so
    # the whole experiment configuration can be read in a single place.
    # NOTE(review): this call uses `k`, `number_of_tasks_val` and
    # `number_of_tasks_test`, whereas the other MAML_VAE calls in this file use
    # `k_ml` / `k_val` and `num_tasks_val` — confirm these names match the
    # MAML_VAE signature actually in use.
    voxceleb_settings = {
        # Model and data.
        'vae': vae,
        'database': voxceleb_database,
        'network_cls': SimpleModel,
        # Task sizes.
        'n': 5,
        'k': 1,
        'k_val_ml': 5,
        'k_val_val': 15,
        'k_val_test': 15,
        'k_test': 5,
        # Step counts, batch size and learning rates.
        'meta_batch_size': 4,
        'num_steps_ml': 5,
        'lr_inner_ml': 0.4,
        'num_steps_validation': 5,
        # Saving / reporting / logging intervals.
        'save_after_iterations': 1000,
        'meta_learning_rate': 0.001,
        'report_validation_frequency': 200,
        'log_train_images_after_iteration': 200,
        'number_of_tasks_val': 100,
        'number_of_tasks_test': 1000,
        # Remaining run options.
        'clip_gradients': False,
        'experiment_name': 'voxceleb_std_1.0',
        'val_seed': 42,
        'val_test_batch_norm_momentum': 0.0,
    }
    maml_vae = MAML_VAE(**voxceleb_settings)

    # `shape` is defined outside the lines visible here.
    maml_vae.visualize_meta_learning_task(shape, num_tasks_to_visualize=2)
    # Run the VAE's own training call, then its checkpoint loader, before the
    # VAE is handed to MAML_VAE.
    vae.perform_training(epochs=500, checkpoint_freq=100)
    vae.load_latest_checkpoint()
    # vae.visualize_meta_learning_task()

    # Keyword arguments for the omniglot MAML_VAE run, kept in one mapping so
    # the whole experiment configuration can be read in a single place.
    # NOTE(review): the n / k_* names look like per-split task sizes — confirm
    # their exact meaning against MAML_VAE's signature.
    omniglot_settings = {
        # Model and data.
        'vae': vae,
        'database': omniglot_database,
        'latent_algorithm': 'p1',
        'network_cls': SimpleModel,
        # Task sizes.
        'n': 20,
        'k_ml': 1,
        'k_val_ml': 1,
        'k_val': 1,
        'k_val_val': 1,
        'k_test': 1,
        'k_val_test': 1,
        # Step counts, batch size and learning rates.
        'meta_batch_size': 4,
        'num_steps_ml': 5,
        'lr_inner_ml': 0.4,
        'num_steps_validation': 5,
        # Saving / reporting / logging intervals.
        'save_after_iterations': 1000,
        'meta_learning_rate': 0.001,
        'report_validation_frequency': 200,
        'log_train_images_after_iteration': 200,
        'num_tasks_val': 100,
        # Remaining run options.
        'clip_gradients': False,
        'experiment_name': 'omniglot_vae_0.5_shift',
        'val_seed': 42,
        'val_test_batch_norm_momentum': 0.0,
    }
    maml_vae = MAML_VAE(**omniglot_settings)

    # maml_vae.visualize_meta_learning_task(shape, num_tasks_to_visualize=2)

    maml_vae.train(iterations=1000)