Example 1
import torch
import torch.nn.functional as F

# Assumes a project-local VAE class taking (input_size, h_dim, z_dim).
def train(args, data_loader):
    model = VAE(input_size=args.input_size, h_dim=args.h_dim, z_dim=args.z_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        for i, (x, _) in enumerate(data_loader):
            x = x.view(-1, args.input_size)
            x_reconst, mu, log_var = model(x)
            # Sum (not average) the BCE over the batch, matching the summed KL term.
            # size_average is deprecated; reduction='sum' is the equivalent.
            reconst_loss = F.binary_cross_entropy(x_reconst,
                                                  x,
                                                  reduction='sum')
            kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

            loss = reconst_loss + kl_div
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 10 == 0:
                print('Epoch {}/{}, Step {}/{}, Loss: {:.4f}'.format(
                    epoch + 1, args.epochs, i + 1, len(data_loader),
                    loss.item()))

    return model
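
A hedged usage sketch of the function above: it expects an args namespace carrying input_size, h_dim, z_dim, lr, and epochs, plus a DataLoader yielding (image, label) pairs. MNIST and the concrete hyperparameter values below are assumptions, not taken from the original.

from types import SimpleNamespace
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hypothetical hyperparameters; 784 = 28*28 assumes flattened MNIST digits.
args = SimpleNamespace(input_size=784, h_dim=400, z_dim=20, lr=1e-3, epochs=10)
dataset = datasets.MNIST('./data', train=True, download=True,
                         transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=128, shuffle=True)
model = train(args, loader)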
Example 2
logging.info(vars(options))

methods = ["c3d", "crnn", "vae", "gan"]
assert options.method in methods, "Not a valid method"

if options.method == "c3d":
    model = C3DModel(options.sequence_size,
                     batch_size=options.batch_size,
                     weight_file=options.weight_file)
elif options.method == "crnn":
    model = CRNN(options.sequence_size,
                 batch_size=options.batch_size,
                 weight_file=options.weight_file)
elif options.method == "vae":
    model = VAE(options.sequence_size,
                batch_size=options.batch_size,
                weight_file=options.weight_file)
else:
    print("{} is not available at the moment".format(options.method))
    exit(1)

db = Database(options.db_path,
              options.sequence_size,
              batch_size=options.batch_size,
              size=model.img_size,
              output_size=model.output_size,
              custom_lenght=options.custom_lenght)

n_epoch = 0
max_epoch = options.n_epochs
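
The options object is assumed to come from a command-line parser. A minimal sketch of how it might be built with argparse, matching the option names used above (the defaults here are assumptions; using choices= would also make the assert above redundant):

import argparse

parser = argparse.ArgumentParser(description='Train a video model')
parser.add_argument('--method', choices=['c3d', 'crnn', 'vae', 'gan'],
                    required=True, help='Model architecture to train')
parser.add_argument('--sequence_size', type=int, default=16)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--weight_file', default=None)
parser.add_argument('--db_path', required=True)
# Spelling matches the Database keyword used above.
parser.add_argument('--custom_lenght', type=int, default=None)
parser.add_argument('--n_epochs', type=int, default=100)
options = parser.parse_args()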
Example 3
def main():
    parser = argparse.ArgumentParser(description='VAE')
    parser.add_argument('--batch_size',
                        type=int,
                        default=100,
                        help='Batch size for training (default=100)')
    parser.add_argument('--n_epochs',
                        type=int,
                        default=100,
                        help='Number of epochs to train (default=100)')
    parser.add_argument('--latent_dim',
                        type=int,
                        default=32,
                        help='Dimension of latent space (default=32)')
    parser.add_argument('--episode_len',
                        type=int,
                        default=1000,
                        help='Length of rollout (default=1000)')
    parser.add_argument(
        '--kl_bound',
        type=float,
        default=0.5,
        help='Clamp the KL loss from below at kl_bound*latent_dim (default=0.5)')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-4,
                        help='Learning rate for optimizer (default=1e-4)')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='enables CUDA training (default=False)')
    parser.add_argument('--dir_name', help='Rollouts directory name')
    parser.add_argument('--log_interval',
                        type=int,
                        default=2,
                        help='After how many batches to log (default=2)')
    args = parser.parse_args()

    # Read in and prepare the data.
    dataset = RolloutDataset(
        path_to_dir=os.path.join(DATA_DIR, 'rollouts', args.dir_name),
        size=int(args.dir_name.split('_')[-1]))  # TODO: hack. fix?
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Use GPU if available.
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Set up the model and the optimizer.
    vae = VAE(latent_dim=args.latent_dim).to(device)
    optimizer = optim.Adam(params=vae.parameters(), lr=args.learning_rate)

    # Training procedure.
    def train(epoch):
        vae.train()
        train_loss = 0
        start_time = datetime.datetime.now()

        for batch_id, batch in enumerate(data_loader):
            batch = batch['obs']
            # Take a random observation from each rollout.
            idx = torch.randint(high=args.episode_len,
                                size=(args.batch_size, ),
                                dtype=torch.long)
            batch = batch[torch.arange(args.batch_size, dtype=torch.long), idx]
            # TODO: use all obs from the rollout (from the randomized start)?
            batch = batch.to(device)

            optimizer.zero_grad()

            recon_batch, mu, logvar = vae(batch)
            rec_loss, kl_loss = vae_loss(recon_batch,
                                         batch,
                                         mu,
                                         logvar,
                                         kl_bound=args.kl_bound)
            loss = rec_loss + kl_loss
            loss.backward()
            train_loss += loss.item()

            optimizer.step()

            if batch_id % args.log_interval == 0:
                print(
                    'Epoch: {0:}\t| Examples: {1:} / {2:}({3:.0f}%)\t| Rec Loss: {4: .4f}\t| KL Loss: {5:.4f}'
                    .format(epoch, (batch_id + 1) * len(batch),
                            len(data_loader.dataset),
                            100. * (batch_id + 1) / len(data_loader),
                            rec_loss.item() / len(batch),
                            kl_loss.item() / len(batch)))

        duration = datetime.datetime.now() - start_time
        print(
            'Epoch {} average train loss was {:.4f} after {}m{}s of training.'.
            format(epoch, train_loss / len(data_loader.dataset),
                   *divmod(int(duration.total_seconds()), 60)))

    # TODO: add test for VAE?

    # Train loop.
    for i in range(1, args.n_epochs + 1):
        train(i)

    # Save the learned model.
    if not os.path.exists(os.path.join(DATA_DIR, 'vae')):
        os.makedirs(os.path.join(DATA_DIR, 'vae'))

    torch.save(
        vae.state_dict(),
        os.path.join(
            DATA_DIR, 'vae',
            datetime.datetime.today().isoformat() + '_' + str(args.n_epochs)))
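
vae_loss is not shown here. A minimal sketch of what it might look like, assuming a reconstruction term summed over the batch and the KL clamp described by --kl_bound (clamping the KL term from below at kl_bound * latent_dim, in the spirit of the "free bits" trick); the choice of reconstruction loss is an assumption:

import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar, kl_bound=0.5):
    """Hypothetical reconstruction + clamped-KL loss."""
    # Reconstruction term, summed over all elements in the batch.
    rec_loss = F.mse_loss(recon_x, x, reduction='sum')
    # Analytic KL divergence between N(mu, sigma^2) and N(0, I).
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Clamp the KL term from below at kl_bound * latent_dim ("free bits").
    kl_loss = torch.clamp(kl_loss, min=kl_bound * mu.size(1))
    return rec_loss, kl_loss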
Example 4
def main():
    parser = argparse.ArgumentParser(
        description='Evolutionary training of controller')
    parser.add_argument('--env_name',
                        nargs='?',
                        default='CarRacing-v0',
                        help='Environment to use (default=CarRacing-v0)')
    parser.add_argument(
        '--n_rollouts',
        type=int,
        default=1,
        help='How many rollouts to perform when evaluating (default=1)')
    parser.add_argument('--n_generations',
                        type=int,
                        default=300,
                        help='Number of generations to train (default=300)')
    parser.add_argument('--latent_dim',
                        type=int,
                        default=32,
                        help='Dimension of latent space (default=32)')
    parser.add_argument('--seq_len',
                        type=int,
                        default=10,
                        help='Length of sequences for learning (default=10)')
    parser.add_argument('--action_dim',
                        type=int,
                        default=3,
                        help='Dimension of action space (default=3)')
    parser.add_argument('--rnn_hidden_dim',
                        nargs='?',
                        type=int,
                        default=256,
                        help='Dimension of RNN hidden state (default=256)')
    parser.add_argument(
        '--n_gaussians',
        type=int,
        default=5,
        help='Number of gaussians for the Mixture Density Network (default=5)')
    parser.add_argument(
        '--pop_size',
        type=int,
        default=64,
        help='Population size for evolutionary search (default=64)')
    parser.add_argument(
        '--n_workers',
        type=int,
        default=32,
        help='Number of workers for parallel processing (default=32)')
    parser.add_argument('--vae_fname', help='VAE model file name')
    parser.add_argument('--rnn_fname', nargs='?', help='RNN model file name')
    parser.add_argument(
        '--eval_interval',
        nargs='?',
        default=15,
        type=int,
        help='After how many generations to evaluate best params (default=15)')
    args = parser.parse_args()

    device = torch.device('cpu')

    # Load the VAE model from file.
    vae = VAE(latent_dim=args.latent_dim)
    vae.load_state_dict(
        torch.load(os.path.join(DATA_DIR, 'vae', args.vae_fname),
                   map_location={'cuda:0': 'cpu'}))  # Previously trained on GPU.
    vae.to(device)

    # TODO: add identity/None RNN for dealing with the below?
    if args.rnn_fname is not None:  # Use memory module.
        # Load the MDNRNN model from file.
        mdnrnn = MDNRNN(action_dim=args.action_dim,
                        hidden_dim=args.rnn_hidden_dim,
                        latent_dim=args.latent_dim,
                        n_gaussians=args.n_gaussians)
        mdnrnn.load_state_dict(
            torch.load(os.path.join(DATA_DIR, 'rnn', args.rnn_fname),
                       map_location={'cuda:0': 'cpu'}))  # Previously trained on GPU.
        mdnrnn.to(device)
    else:  # TODO: hacky, but rnn_hidden_dim needs a value to pass downstream without extra ifs. Fix?
        args.rnn_hidden_dim = 0
        mdnrnn = None

    # Set up controller model.
    agent = ControllerAgent(vae=vae,
                            v_dim=args.latent_dim,
                            action_dim=args.action_dim,
                            rnn=mdnrnn,
                            m_dim=args.rnn_hidden_dim)

    # Set up evolutionary strategy optimizer.
    # Suppress the evolutionary strategy optimizer's creation message.
    with suppress_stdout():
        es = CMAES(
            num_params=param_count(agent.controller),
            sigma_init=0.1,  # initial standard deviation
            popsize=args.pop_size)

    # Set up multiprocessing.
    pool = mp.Pool(processes=args.n_workers)

    # Create results folder.
    dir_name = datetime.datetime.today().isoformat() + '_' + str(
        args.rnn_hidden_dim)
    os.makedirs(os.path.join(RESULTS_DIR, 'controller', dir_name))

    # TODO: add antithetic?
    for i in range(1, args.n_generations + 1):
        start_time = datetime.datetime.now()

        # Create a set of candidate specimens.
        specimens = es.ask()

        # Evaluate the fitness of candidate specimens.
        func = partial(evaluate,
                       env_name=args.env_name,
                       vae=vae,
                       rnn=mdnrnn,
                       v_dim=args.latent_dim,
                       action_dim=args.action_dim,
                       m_dim=args.rnn_hidden_dim,
                       n_rollouts=args.n_rollouts)
        fitness_list = np.array(pool.map(func, specimens))

        # Give list of fitness results back to ES.
        es.tell(fitness_list)

        # Get the best parameters and fitness from the ES.
        es_solution = es.result()
        duration = datetime.datetime.now() - start_time

        history = {
            'best_params': es_solution[0],  # Best historical parameters.
            'best_fitness': es_solution[1],  # Best historical reward.
            'curr_best_fitness': es_solution[2],  # Best fitness of current generation.
            'mean_fitness': fitness_list.mean(),  # Mean fitness of current generation.
            'std_fitness': fitness_list.std()  # Std of fitness of current generation.
        }
        np.savez(os.path.join(RESULTS_DIR, 'controller', dir_name, str(i)),
                 **history)

        print(
            'Gen: {0:}\t| Best fit of gen: {1:.2f}\t| Best fit historical: {2:.2f}\t|'
            ' Mean fit: {3:.2f}\t| Std of fit: {4:.2f}\t| Time: {5:}m {6:}s'.
            format(i, es_solution[2], es_solution[1], fitness_list.mean(),
                   fitness_list.std(),
                   *divmod(int(duration.total_seconds()), 60)))

        if i % args.eval_interval == 0:
            start_time = datetime.datetime.now()
            eval_fitness_list = np.array(
                pool.map(
                    func,
                    np.broadcast_to(es_solution[0], (args.n_workers, ) +
                                    es_solution[0].shape)))
            duration = datetime.datetime.now() - start_time
            print(
                '{0:}-worker average fit of best params after gen {1:}: {2:.2f}. Time: {3:}m {4:}s.'
                .format(args.n_workers, i, eval_fitness_list.mean(),
                        *divmod(int(duration.total_seconds()), 60)))

            np.savez(
                os.path.join(RESULTS_DIR, 'controller', dir_name,
                             str(i) + '_eval'), eval_fitness_list)
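
evaluate is not shown here; pool.map calls it with one parameter vector per specimen, and the remaining arguments come from the partial above. A rough sketch under those assumptions; the agent calls inside (set_params, act) are hypothetical, and the classic gym reset/step interface is assumed:

import gym

def evaluate(params, env_name, vae, rnn, v_dim, action_dim, m_dim, n_rollouts):
    """Hypothetical fitness function: mean cumulative reward over rollouts."""
    agent = ControllerAgent(vae=vae, v_dim=v_dim, action_dim=action_dim,
                            rnn=rnn, m_dim=m_dim)
    agent.controller.set_params(params)  # Hypothetical parameter setter.
    env = gym.make(env_name)
    total_reward = 0.0
    for _ in range(n_rollouts):
        obs, done = env.reset(), False
        while not done:
            action = agent.act(obs)  # Hypothetical policy call.
            obs, reward, done, _ = env.step(action)
            total_reward += reward
    env.close()
    return total_reward / n_rollouts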
Example 5
def run(batch_size, max_batch_steps, epochs, annealing_epochs, temp, min_af, loader_workers, eval_freq, _run: "Run"):
    pyro.clear_param_store()
    _run.add_artifact(_run.config["config_file"])

    # Seed randomness for repeatability
    seed_random()

    # dataset
    wildfire_dataset = WildFireDataset(train=True, config_file="config.ini")
    data_loader = DataLoader(wildfire_dataset, batch_size=batch_size, shuffle=True, num_workers=loader_workers)
    expected_n_batches = np.ceil(len(wildfire_dataset) / batch_size)
    expected_n_batches = max_batch_steps if max_batch_steps > 0 else expected_n_batches
    vae_config = get_vae_config()

    with open(temp / "vae_config.json", "w") as fptr:
        json.dump(vae_config.__dict__, fptr, indent=1)

    _run.add_artifact(temp / "vae_config.json")

    vae = VAE(vae_config)
    svi = SVI(vae.model, vae.guide, vae.optimizer, loss=Trace_ELBO())

    from src.data.dataset import _ct  # Presumably converts arrays to torch tensors.

    for step in trange(epochs, desc="Epoch: ", ascii=False, dynamic_ncols=True,
                       bar_format='{desc:<8.5}{percentage:3.0f}%|{bar:40}{r_bar}'):
        if step < annealing_epochs:
            annealing_factor = min_af + (1.0 - min_af) * step / annealing_epochs
        else:
            annealing_factor = 1.0
        _run.log_scalar("annealing_factor", annealing_factor, step=step)

        epoch_elbo = 0.0
        epoch_time_slices = 0
        for batch_steps_i, d in tqdm(enumerate(data_loader), desc="Batch: ", ascii=False, dynamic_ncols=True,
                                     bar_format='{desc:<8.5}{percentage:3.0f}%|{bar:40}{r_bar}',
                                     total=expected_n_batches,
                                     leave=False):
            epoch_elbo += svi.step(_ct(d.diurnality), _ct(d.viirs), _ct(d.land_cover), _ct(d.latitude),
                                   _ct(d.longitude), _ct(d.meteorology), annealing_factor)
            # Batch size times time slices per sample (time is assumed to be dim 1).
            epoch_time_slices += d.viirs.shape[0] * d.viirs.shape[1]
            if 0 < max_batch_steps == batch_steps_i:
                break
        elbo = -epoch_elbo / epoch_time_slices
        print(f" [{step:05d}] ELBO: {elbo:.3f}", end="")
        _run.log_scalar("elbo", elbo, step=step)
        alpha = pyro.param("alpha").item()
        beta = pyro.param("beta").item()
        _run.log_scalar("alpha", alpha, step=step)
        _run.log_scalar("beta", beta, step=step)
        inferred_mean, inferred_std = beta_to_mean_std(alpha, beta)
        _run.log_scalar("inferred_mean", inferred_mean, step=step)
        _run.log_scalar("inferred_std", inferred_std, step=step)

        if eval_freq > 0 and step > 0 and step % eval_freq == 0:
            logger.info("Evaluating")
            eval_light(Path(_run.observers[0].dir), vae, data_loader, wildfire_dataset, step)
            vae.train()

    torch.save(vae.state_dict(), temp / "model_final.pt")
    _run.add_artifact(temp / "model_final.pt")
    vae.optimizer.save(temp / "optimizer.pt")
    _run.add_artifact(temp / "optimizer.pt")
Example 6
def main():
    parser = argparse.ArgumentParser(description='RNN')
    parser.add_argument('--batch_size',
                        type=int,
                        default=100,
                        help='Input batch size for training (default=100)')
    parser.add_argument('--n_epochs',
                        type=int,
                        default=20,
                        help='Number of epochs to train (default=20)')
    parser.add_argument('--latent_dim',
                        type=int,
                        default=32,
                        help='Dimension of latent space (default=32)')
    parser.add_argument('--seq_len',
                        type=int,
                        default=1000,
                        help='Length of sequences for learning (default=1000)')
    parser.add_argument('--action_dim',
                        type=int,
                        default=3,
                        help='Dimension of action space (default=3)')
    parser.add_argument('--rnn_hidden_dim',
                        type=int,
                        default=256,
                        help='Dimension of RNN hidden state (default=256)')
    parser.add_argument(
        '--n_gaussians',
        type=int,
        default=5,
        help='Number of gaussians for the Mixture Density Network (default=5)')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-3,
                        help='Learning rate for optimizer (default=1e-3)')
    parser.add_argument('--grad_clip',
                        type=float,
                        default=1.0,
                        help='Gradient clipping value (default=1.0)')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='enables CUDA training')
    parser.add_argument('--vae_fname', help='VAE model file name')
    parser.add_argument('--train_dir_name',
                        help='Rollouts directory name for training')
    parser.add_argument('--test_dir_name',
                        help='Rollouts directory name for testing')
    parser.add_argument('--log_interval',
                        type=int,
                        default=2,
                        help='After how many batches to log (default=2)')
    args = parser.parse_args()

    # TODO: is there a better way to do this?
    if not os.path.exists(
            os.path.join(DATA_DIR, 'rollouts', args.train_dir_name)):
        print("Folder {} does not exist.".format(args.train_dir_name))
        return
    if not os.path.exists(
            os.path.join(DATA_DIR, 'rollouts', args.test_dir_name)):
        print("Folder {} does not exist.".format(args.test_dir_name))
        return

    # Read in and prepare the data.
    train_dataset = RolloutDataset(
        path_to_dir=os.path.join(DATA_DIR, 'rollouts', args.train_dir_name),
        size=int(args.train_dir_name.split('_')[-1]))  # TODO: hack. fix?
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)

    test_dataset = RolloutDataset(
        path_to_dir=os.path.join(DATA_DIR, 'rollouts', args.test_dir_name),
        size=int(args.test_dir_name.split('_')[-1]))  # TODO: hack. fix?
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=True)

    # Use GPU if available.
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Load the VAE model from file.
    vae = VAE(latent_dim=args.latent_dim)
    vae.load_state_dict(
        torch.load(os.path.join(DATA_DIR, 'vae', args.vae_fname)))
    vae.to(device)

    # Set up the MDNRNN model and the optimizer.
    mdnrnn = MDNRNN(action_dim=args.action_dim,
                    hidden_dim=args.rnn_hidden_dim,
                    latent_dim=args.latent_dim,
                    n_gaussians=args.n_gaussians).to(device)
    optimizer = optim.Adam(params=mdnrnn.parameters(), lr=args.learning_rate)

    # Train procedure.
    def train(epoch):
        mdnrnn.train()
        train_loss = 0
        start_time = datetime.datetime.now()
        for batch_id, batch in enumerate(train_loader):
            obs_batch = batch['obs'].to(device)
            act_batch = batch['act'].to(device)

            optimizer.zero_grad()

            # Encode obs using VAE.
            vae_obs_batch = obs_batch.view(
                (-1, ) + obs_batch.size()[2:])  # Reshape for VAE.
            z_batch = vae.reparameterize(*vae.encode(vae_obs_batch))
            z_batch = z_batch.view(-1, args.seq_len, args.latent_dim)

            # Predict all but first encoded obs from all but last encoded obs and action.
            targets = z_batch[:, 1:]
            z_batch = z_batch[:, :-1]
            act_batch = act_batch[:, :-1]

            pi, mu, sigma, _ = mdnrnn(act_batch, z_batch)

            loss = nll_gmm_loss(targets, pi, mu, sigma)
            loss.backward()
            train_loss += loss.item()

            torch.nn.utils.clip_grad_value_(mdnrnn.parameters(),
                                            args.grad_clip)
            optimizer.step()

            if batch_id % args.log_interval == 0:
                print(
                    'Epoch: {0:}\t| Examples: {1:}/{2:} ({3:.0f}%)\t| Loss: {4:.2f}\t'
                    .format(epoch, (batch_id + 1) * len(obs_batch),
                            len(train_loader.dataset),
                            100. * (batch_id + 1) / len(train_loader),
                            loss.item() / len(obs_batch)))

        duration = datetime.datetime.now() - start_time
        print(
            'Epoch {} average train loss was {:.4f} after {}m{}s of training.'.
            format(epoch, train_loss / len(train_loader.dataset),
                   *divmod(int(duration.total_seconds()), 60)))

    # Test procedure.
    def test(epoch):
        mdnrnn.eval()
        test_loss = 0
        with torch.no_grad():
            for batch_id, batch in enumerate(test_loader):
                obs_batch = batch['obs'].to(device)
                act_batch = batch['act'].to(device)

                # Encode obs using VAE.
                vae_obs_batch = obs_batch.view(
                    (-1, ) + obs_batch.size()[2:])  # Reshape for VAE.
                z_batch = vae.reparameterize(*vae.encode(vae_obs_batch))
                z_batch = z_batch.view(-1, args.seq_len, args.latent_dim)

                # Predict all but first encoded obs from all but last encoded obs and action.
                targets = z_batch[:, 1:]
                z_batch = z_batch[:, :-1]
                act_batch = act_batch[:, :-1]

                pi, mu, sigma, _ = mdnrnn(act_batch, z_batch)

                test_loss += nll_gmm_loss(targets, pi, mu, sigma).item()
            print('Epoch {} average test loss was {:.4f}.'.format(
                epoch, test_loss / len(test_loader.dataset)))

    # Train/test loop.
    for i in range(1, args.n_epochs + 1):
        train(i)
        test(i)

    # Save the learned model.
    if not os.path.exists(os.path.join(DATA_DIR, 'rnn')):
        os.makedirs(os.path.join(DATA_DIR, 'rnn'))

    torch.save(
        mdnrnn.state_dict(),
        os.path.join(
            DATA_DIR, 'rnn',
            datetime.datetime.today().isoformat() + '_' + str(args.n_epochs)))
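
nll_gmm_loss is not shown here. A minimal sketch of a Gaussian-mixture negative log-likelihood as used in Mixture Density Networks; the tensor layout (a trailing mixture dimension on pi, mu, sigma) is an assumption:

import torch

def nll_gmm_loss(targets, pi, mu, sigma):
    """Hypothetical MDN loss: negative log-likelihood of targets under a
    diagonal Gaussian mixture. Assumes pi/mu/sigma are shaped
    (batch, seq, latent_dim, n_gaussians) and targets (batch, seq, latent_dim)."""
    targets = targets.unsqueeze(-1)  # Broadcast over the mixture dimension.
    normal = torch.distributions.Normal(loc=mu, scale=sigma)
    # log pi_k + log N(x | mu_k, sigma_k), reduced over mixtures via logsumexp.
    log_probs = torch.log(pi) + normal.log_prob(targets)
    return -torch.logsumexp(log_probs, dim=-1).sum()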
Example 7
def eval_dmm(experiment_dir):
    experiment_dir = Path(experiment_dir)
    config = cp.ConfigParser()
    config.read(experiment_dir / "config.ini")
    with open(experiment_dir / "metrics.json", "rb") as fptr:
        metrics = json.load(fptr)

    # load dataset
    logger.info(f"Loading dataset")
    batch_size = config["vae-eval"].getint("batch_size")

    wildfire_dataset = WildFireDataset(train=True,
                                       config_file=experiment_dir / "config.ini")
    from torch.utils.data import DataLoader
    data_loader = DataLoader(wildfire_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=1)

    plot_epoch(experiment_dir,
               np.array(metrics['elbo']['values']),
               "ELBO",
               ylim=(-10, 0))
    for f in ['alpha', 'beta', 'inferred_mean', 'inferred_std']:
        plot_epoch(experiment_dir, metrics[f]['values'], f)

    logger.info(f"Loading model")
    with open(experiment_dir / "vae_config.json", "rb") as fptr:
        vae_config = json.load(
            fptr, object_hook=lambda dct: VAEConfig(**dct))  # type:VAEConfig

    vae = VAE(vae_config)
    vae.load_state_dict(torch.load(experiment_dir / "model_final.pt"))

    z_loc, _ = get_latent(vae, data_loader)
    plot_tsne(z_loc, wildfire_dataset, max_samples=1000)
    plt.savefig(experiment_dir / f"tsne_final_train.png")
    plt.close()
    plot_latent(z_loc, experiment_dir, wildfire_dataset)

    f_12, f_24 = make_forecast(vae, wildfire_dataset, data_loader)
    mse_12, mse_24 = eval_mse(f_12, f_24,
                              wildfire_dataset[:].viirs[:, 5, :, :],
                              wildfire_dataset[:].viirs[:, 6, :, :])

    metrics["mse_12_train"] = float(mse_12)
    metrics["mse_24_train"] = float(mse_24)
    logger.info(f"MSE (+12HR): {mse_12:.3f}")
    logger.info(f"MSE (+24HR): {mse_24:.3f}")
    threshold = 0.5
    iou_12, iou_24 = eval_jaccard(f_12,
                                  f_24,
                                  wildfire_dataset[:].viirs[:, 5, :, :],
                                  wildfire_dataset[:].viirs[:, 6, :, :],
                                  threshold=threshold)
    metrics["iou_12_train"] = float(iou_12)
    metrics["iou_24_train"] = float(iou_24)
    logger.info(f"IOU (+12HR): {iou_12:.3f}")
    logger.info(f"IOU (+24HR): {iou_24:.3f}")

    plot_forecast(f_12, f_24, experiment_dir / "train", wildfire_dataset)

    # on test set
    wildfire_dataset = WildFireDataset(train=False,
                                       config_file=experiment_dir / "config.ini")
    data_loader = DataLoader(wildfire_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=1)

    z_loc, _ = get_latent(vae, data_loader)
    plot_tsne(z_loc, wildfire_dataset, max_samples=1000)
    plt.savefig(experiment_dir / f"tsne_final_test.png")
    plt.close()

    f_12, f_24 = make_forecast(vae, wildfire_dataset, data_loader)
    mse_12, mse_24 = eval_mse(f_12, f_24,
                              wildfire_dataset[:].viirs[:, 5, :, :],
                              wildfire_dataset[:].viirs[:, 6, :, :])
    metrics["mse_12_test"] = float(mse_12)
    metrics["mse_24_test"] = float(mse_24)
    logger.info(f"MSE (+12HR): {mse_12:.3f}")
    logger.info(f"MSE (+24HR): {mse_24:.3f}")

    iou_12, iou_24 = eval_jaccard(f_12,
                                  f_24,
                                  wildfire_dataset[:].viirs[:, 5, :, :],
                                  wildfire_dataset[:].viirs[:, 6, :, :],
                                  threshold=threshold)
    metrics["iou_12_test"] = float(iou_12)
    metrics["iou_24_test"] = float(iou_24)
    logger.info(f"IOU (+12HR): {iou_12:.3f}")
    logger.info(f"IOU (+24HR): {iou_24:.3f}")

    plot_forecast(f_12, f_24, experiment_dir / "test", wildfire_dataset)

    with open(experiment_dir / "metrics.json", "w") as fptr:
        json.dump(metrics, fptr)
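
eval_jaccard is assumed to binarize the forecasts at the given threshold and compute the Jaccard index (IoU) against the observed VIIRS masks. A minimal sketch under that assumption; the argument order mirrors the calls above:

import numpy as np

def eval_jaccard(f_12, f_24, viirs_12, viirs_24, threshold=0.5):
    """Hypothetical IoU metric: binarize at threshold, then intersection/union."""
    def jaccard(pred, target):
        pred = np.asarray(pred) > threshold
        target = np.asarray(target) > threshold
        union = np.logical_or(pred, target).sum()
        if union == 0:
            return 1.0  # Both masks empty: define IoU as perfect agreement.
        return np.logical_and(pred, target).sum() / union
    return jaccard(f_12, viirs_12), jaccard(f_24, viirs_24)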