def iw_log_p_x_dataset(data_loader,
                       model=None,
                       path=None,
                       n_samples=600,
                       n_chunks=3,
                       verbose=False,
                       ddp=False,
                       device_name="cuda:0",
                       max_batches=-1):
    """Estimate the importance-weighted log p(x) of every sequence in a data loader.

    Args:
        data_loader: iterable of batches; each batch is a dict that contains at
            least "attention_mask" plus whatever ``iw_log_p_x`` consumes.
        model: model to evaluate. Replaced by a freshly loaded model when
            ``path`` is given.
        path: optional checkpoint path to load the model from.
        n_samples: number of importance samples, forwarded to ``iw_log_p_x``.
        n_chunks: number of chunks, forwarded to ``iw_log_p_x``.
        verbose: print per-batch progress; also forwarded to ``iw_log_p_x``.
        ddp: forwarded to ``load_from_checkpoint``.
        device_name: device the batches (and a loaded model) are placed on.
        max_batches: maximum number of batches to evaluate; ``<= 0`` means all.

    Returns:
        Tuple ``(log_likelihood, log_likelihood_per_word, sent_lens)`` of CPU
        tensors with one entry per sequence. ``log_likelihood_per_word`` is
        ``log_likelihood / sent_lens``, where ``sent_lens`` is the row sum of
        the attention mask. All three are empty if no batch was processed.

    Raises:
        ValueError: if neither ``model`` nor ``path`` is provided.
    """
    if model is None and path is None:
        # Raise instead of quit(): quit() is an interactive-shell helper and
        # would terminate the whole interpreter from library code.
        raise ValueError(
            "Either provide a model, or a checkpoint path. Not neither. Aborting."
        )

    if path is not None:
        print("Loading a model because provided a path {}.".format(path))
        model = load_from_checkpoint(path,
                                     world_master=True,
                                     ddp=ddp,
                                     device_name=device_name,
                                     evaluation=True,
                                     return_loss_term_manager=False)

    # Same convention as the other evaluation helpers: max_batches <= 0 means "all batches".
    N = max_batches if max_batches > 0 else len(data_loader)
    print("N", N)

    log_p_xs = []
    sent_lens = []  # handy for perplexity
    for batch_i, batch in enumerate(data_loader):
        if verbose:
            print("*" * 40)
            print(f"{batch_i + 1:3d}/{N}")
            print("*" * 40)
        batch = transfer_batch_to_device(batch, device_name=device_name)

        with torch.no_grad():
            log_p_x = iw_log_p_x(model,
                                 batch,
                                 verbose=verbose,
                                 n_chunks=n_chunks,
                                 n_samples=n_samples)
            sent_lens.append(batch["attention_mask"].sum(dim=1))
            log_p_xs.append(log_p_x)

        if batch_i + 1 == N:
            break

    if not log_p_xs:
        # Nothing was evaluated (e.g. empty loader); torch.cat would fail on [].
        return torch.empty(0), torch.empty(0), torch.empty(0, dtype=torch.long)

    log_likelihood = torch.cat(log_p_xs, dim=0).cpu()
    sent_lens = torch.cat(sent_lens, dim=0).cpu()
    log_likelihood_p_w = log_likelihood / sent_lens

    return log_likelihood, log_likelihood_p_w, sent_lens
def evaluation_function(device_rank, run_name, model_path, max_batches,
                        result_dir_path, batch_size, dataset_name, objective,
                        world_size, num_workers):
    """Evaluate a VAE checkpoint on this rank's shard of the validation set.

    The outputs of the model forward pass are collected per batch into a dict
    that maps each output key to a list of per-batch values. Tensor values are
    converted with ``.item()``. The dict is pickled once to
    ``<result_dir_path>/cuda:<rank>_<run_name>_max_batches_<max_batches>.pickle``.
    If that file already exists, the evaluation is skipped.

    Args:
        device_rank: GPU index; used as ``cuda:<device_rank>`` and as the
            process-group rank.
        run_name: run identifier used in the result file name.
        model_path: checkpoint path. Latent size 32 is used when the path
            contains "latent32", otherwise 64.
        max_batches: maximum number of batches to evaluate; ``<= 0`` means all.
        result_dir_path: directory the result pickle is written to.
        batch_size: data-loader batch size.
        dataset_name: dataset passed to the distributed validation loader.
        objective: objective passed to ``load_from_checkpoint``.
        world_size: number of processes in the process group.
        num_workers: data-loader worker count.
    """
    # Prepare some variables & result directory
    device_name = f"cuda:{device_rank}"
    latent_size = 32 if "latent32" in model_path else 64
    result_dir = Path(result_dir_path)
    os.makedirs(result_dir, exist_ok=True)

    result_file = result_dir / f"{device_name}_{run_name}_max_batches_{max_batches}.pickle"

    if os.path.isfile(result_file):
        print('_' * 80)
        print('_' * 80)
        print("Have done this one already!")
        print('_' * 80)
        print('_' * 80)
        return

    print("-" * 30)
    print("run_name:", run_name)
    print("batch size:", batch_size)
    print("max_batches:", max_batches)
    print("latent size:", latent_size)
    print("device name:", device_name)
    print("-" * 30)

    # Get model
    vae_model = load_from_checkpoint(
        path=model_path,
        device_name=device_name,
        latent_size=latent_size,
        do_tie_embedding_spaces=True,
        add_decoder_output_embedding_bias=False,
        do_tie_weights=True,
        add_latent_via_embeddings=False,
        add_latent_via_memory=True,
        objective=objective,
        evaluation=True)
    vae_model = vae_model.to(device_name)

    # Get distributed validation data loader of PTB data set
    loader = get_dist_validation_loader(batch_size=batch_size,
                                        num_workers=num_workers,
                                        max_seq_len=64,
                                        world_size=world_size,
                                        dataset_name=dataset_name,
                                        tokenizer_name="roberta",
                                        device_name=device_name,
                                        gpu_rank=device_rank)

    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=world_size,
                            rank=device_rank)

    # Seed everything
    seed_everything(0)

    print(f"Len data loader on device {device_name}: {len(loader)}")
    N = max_batches if max_batches > 0 else len(loader)

    results = {}

    for batch_i, batch in enumerate(loader):
        print(f"{batch_i:3d}/{N} - {device_name}")
        batch = transfer_batch_to_device(batch, device_name=device_name)

        with torch.no_grad():
            out = vae_model(input_ids=batch["input_ids"],
                            attention_mask=batch["attention_mask"],
                            auto_regressive=False,
                            return_latents=False,
                            return_mu_logvar=False,
                            return_exact_match=True,
                            return_cross_entropy=True,
                            return_reconstruction_loss=True,
                            return_posterior_stats=True,
                            reduce_seq_dim_ce="mean",
                            reduce_seq_dim_exact_match="mean",
                            reduce_batch_dim_exact_match="mean",
                            reduce_batch_dim_ce="mean")

        for k, v in out.items():
            results.setdefault(k, []).append(v.item() if torch.is_tensor(v) else v)

        if batch_i + 1 == max_batches:
            break

    # Dump the results for this device once, after the loop. Previously this
    # ran inside the loop, so the last batch was dropped whenever the
    # max_batches break fired, and file handles were never closed.
    with open(result_file, "wb") as f:
        pickle.dump(results, f)
def summary_statistics(path,
                       run_name,
                       data_loader,
                       max_batches=-1,
                       device="cuda:0",
                       result_folder="result-files",
                       result_file=None):
    """Gather per-batch VAE summary statistics for a checkpoint and pickle them.

    A loss term manager (which wraps the model) is loaded from ``path``, and
    its objective is switched to "vae". The manager is then run over
    ``data_loader``, at most ``max_batches`` batches when ``max_batches > 0``.
    Scalar tensor outputs are stored as Python numbers. Other tensor outputs
    are concatenated along dim 0 at the end. Non-tensor outputs are kept as
    lists.

    The result dict is written with ``dump_pickle`` to ``result_file``. The
    default is ``<result_folder>/<run_name>/sum_stats_<run_name>.pth``.
    """
    run_folder = f"{result_folder}/{run_name}"
    os.makedirs(run_folder, exist_ok=True)

    if result_file is None:
        result_file = f"{run_folder}/sum_stats_{run_name}.pth"

    # The loss term manager bundles the model with the loss terms.
    loss_term_manager = load_from_checkpoint(path,
                                             world_master=True,
                                             ddp=False,
                                             dataset_size=len(data_loader),
                                             device_name=device,
                                             evaluation=True,
                                             return_loss_term_manager=True)

    # Evaluate everything under the standard VAE objective.
    loss_term_manager.objective = "vae"

    collected = {}
    total = max_batches if max_batches > 0 else len(data_loader)

    for batch_i, batch in enumerate(data_loader):
        print("Batch {:3d}/{:3d}".format(batch_i + 1, total), end="\r")

        with torch.no_grad():
            batch = transfer_batch_to_device(batch, device)

            out = loss_term_manager(input_ids=batch["input_ids"],
                                    attention_mask=batch["attention_mask"],
                                    return_exact_match=True,
                                    return_reconstruction_loss=True,
                                    decoder_only="decoderOnly" in path,
                                    return_posterior_stats=True,
                                    device_name=device,
                                    return_cross_entropy=False,
                                    reduce_seq_dim_ce="mean",
                                    reduce_batch_dim_ce="mean",
                                    reduce_seq_dim_exact_match="mean",
                                    reduce_batch_dim_exact_match="mean",
                                    train=False)

            for key, value in out.items():
                is_scalar_tensor = torch.is_tensor(value) and value.dim() == 0
                collected.setdefault(key, []).append(value.item() if is_scalar_tensor else value)

        if batch_i + 1 == max_batches:
            break

    results_cat = {
        key: (torch.cat(values, dim=0) if torch.is_tensor(values[0]) else values)
        for key, values in collected.items()
    }

    dump_pickle(results_cat, result_file)
def iw_log_p_x_generated(model=None,
                         path=None,
                         n_batches=10,
                         batch_size=64,
                         n_samples=600,
                         n_chunks=3,
                         verbose=False,
                         ddp=False,
                         device_name="cuda:0",
                         max_seq_len_gen=64):
    """Estimate the importance-weighted log p(x) of sequences sampled from the model.

    For each of ``n_batches`` batches, ``batch_size`` sequences are decoded
    auto-regressively from prior samples, using nucleus sampling with
    ``top_k=0`` and ``top_p=1.0``. The samples are padded into a batch and
    scored with ``iw_log_p_x``.

    Args:
        model: model to sample from and evaluate. Replaced by a freshly loaded
            model when ``path`` is given.
        path: optional checkpoint path to load the model from.
        n_batches: number of generated batches.
        batch_size: number of prior samples per batch.
        n_samples: number of importance samples, forwarded to ``iw_log_p_x``.
        n_chunks: number of chunks, forwarded to ``iw_log_p_x``.
        verbose: print per-batch progress; also forwarded to ``iw_log_p_x``.
        ddp: forwarded to ``load_from_checkpoint``.
        device_name: device the generated batches are placed on.
        max_seq_len_gen: maximum length of the generated sequences.

    Returns:
        Tuple ``(log_p_xs, log_p_x_ws, lens)`` of CPU tensors. Each has one entry
        per generated sequence over all batches. ``log_p_x_ws`` is
        ``log_p_xs / lens``. All three are empty when ``n_batches == 0``.

    Raises:
        ValueError: if neither ``model`` nor ``path`` is provided.
    """
    if model is None and path is None:
        raise ValueError(
            "Either provide a model, or a checkpoint path. Not neither. Aborting."
        )

    if path is not None:
        print("Loading a model because provided a path {}.".format(path))
        model = load_from_checkpoint(path,
                                     world_master=True,
                                     ddp=ddp,
                                     device_name=device_name,
                                     evaluation=True,
                                     return_loss_term_manager=False)

    log_p_xs, log_p_x_ws, all_lens = [], [], []

    for batch_i in range(n_batches):
        if verbose:
            print(f"Batch {batch_i}/{n_batches}")

        with torch.no_grad():
            # Sample from the model by decoding from prior auto-regressively with sampling
            out = model(
                return_reconstruction_loss=False,
                return_posterior_stats=False,
                auto_regressive=True,
                max_seq_len=max_seq_len_gen,
                return_predictions=True,
                nucleus_sampling=True,
                top_k=0,  # no filtering
                top_p=1.0,  # no filtering
                decode_sample_from_prior=True,
                n_prior_samples=batch_size,
                device_name=device_name)

            padded_predictions, mask, lens = make_batch_from_model_samples(
                out["predictions"])
            lens = torch.as_tensor(lens).cpu()

            batch = dict(input_ids=padded_predictions.to(device_name),
                         attention_mask=mask.to(device_name))

            log_p_x = iw_log_p_x(model,
                                 batch,
                                 n_samples=n_samples,
                                 n_chunks=n_chunks,
                                 verbose=verbose).cpu()

            log_p_xs.append(log_p_x)
            log_p_x_ws.append(log_p_x / lens)
            # Keep the lengths of every batch so they line up with log_p_xs
            # (previously only the last batch's lens was returned).
            all_lens.append(lens)

    if not log_p_xs:
        return torch.empty(0), torch.empty(0), torch.empty(0, dtype=torch.long)

    return torch.cat(log_p_xs), torch.cat(log_p_x_ws), torch.cat(all_lens)
def acc_drop_over_relative_seq_len(data_loader,
                                   model=None,
                                   path=None,
                                   device="cuda:0",
                                   max_batches=-1,
                                   N_bins=30):
    """Bin the per-token accuracy drop (posterior minus prior decoding) by relative position.

    Every batch is decoded twice, non-autoregressively: first from prior
    samples, then from the posterior. The per-token exact-match values of both
    runs are restricted to non-padding label positions. Labels are
    ``input_ids[:, 1:]``, and token id 1 counts as padding. The difference
    ``posterior - prior`` is then averaged over ``N_bins`` bins of
    (token position / sequence length).

    Returns:
        dict with ``bin_means`` and ``bin_edges`` from ``stats.binned_statistic``,
        plus the flattened masked tensors ``acc_drops``, ``prior_accs`` and
        ``posterior_accs``.
    """
    n_total = max_batches if max_batches > 0 else len(data_loader)
    assert not (model is None
                and path is None), "Either supply model or a path. Aborting."

    if model is None and path is not None:
        model = load_from_checkpoint(path,
                                     world_master=True,
                                     ddp=False,
                                     device_name=device,
                                     evaluation=True)

    # True -> decoded from prior samples, False -> decoded from the posterior
    accs_by_source = {True: [], False: []}
    label_masks = []

    for batch_i, batch in enumerate(data_loader):
        print(f"Batch {batch_i + 1:3d}/{n_total:3d}", end="\r")

        # Label padding mask, built from the CPU batch: drop <s>, pad token id is 1
        label_ids = batch["input_ids"][:, 1:].contiguous()
        label_masks.append((label_ids != 1).float())

        batch = transfer_batch_to_device(batch, device)

        for from_prior in (True, False):
            with torch.no_grad():
                preds = model(input_ids=batch["input_ids"],
                              attention_mask=batch["attention_mask"],
                              auto_regressive=False,
                              max_seq_len=64,
                              return_exact_match=True,
                              return_cross_entropy=False,
                              return_reconstruction_loss=False,
                              return_posterior_stats=False,
                              reduce_seq_dim_ce="mean",
                              reduce_seq_dim_exact_match="none",
                              reduce_batch_dim_exact_match="none",
                              reduce_batch_dim_ce="none",
                              nucleus_sampling=False,
                              top_k=0,
                              top_p=1.0,
                              decode_sample_from_prior=from_prior,
                              n_prior_samples=batch["input_ids"].shape[0],
                              device_name=device)

            accs_by_source[from_prior].append(preds["exact_match"].cpu())

        if (batch_i + 1) == max_batches:
            break

    prior_accs = cat_pad_uneven(accs_by_source[True], pad_value=0)
    post_accs = cat_pad_uneven(accs_by_source[False], pad_value=0)
    masks = cat_pad_uneven(label_masks, pad_value=0)
    seq_lens = masks.sum(dim=1)

    n_rows, max_len = prior_accs.shape
    positions = torch.arange(1, max_len + 1).unsqueeze(0).repeat(n_rows, 1)
    relative_positions = positions / seq_lens.unsqueeze(1)

    keep = masks == 1.0
    prior_selected = torch.masked_select(prior_accs, keep)
    post_selected = torch.masked_select(post_accs, keep)
    acc_drops = post_selected - prior_selected
    rel_pos_selected = torch.masked_select(relative_positions, keep)

    bin_means, bin_edges, _ = stats.binned_statistic(
        rel_pos_selected.tolist(),
        acc_drops.tolist(),
        statistic='mean',
        bins=N_bins)

    return {
        "bin_means": bin_means,
        "bin_edges": bin_edges,
        "acc_drops": acc_drops,
        "prior_accs": prior_selected,
        "posterior_accs": post_selected,
    }
# Example #6
# 0
def dist_iw_log_likelihood_x_obs_x_gen(device_rank, run_name, model_path,
                                       max_batches, result_dir_path,
                                       batch_size, dataset_name, world_size,
                                       num_workers, n_samples, n_chunks,
                                       max_seq_len_gen):
    """Compute importance-weighted log-likelihoods of observed and generated data on one rank.

    On ``cuda:<device_rank>``, this estimates log p(x) for:
    - this rank's validation split,
    - this rank's train split,
    - ``N = max(N_valid, N_train)`` batches of sequences sampled from the model.

    The tensors, moved to CPU, are pickled to
    ``<result_dir_path>/<run_name>/cuda:<rank>_<run_name>_max_batches_<max_batches>.pickle``.
    Nothing is computed if that per-device file, or the combined all-GPU
    result file, already exists.

    Args:
        device_rank: GPU index; also the process-group rank.
        run_name: run identifier, used as sub-directory and in file names.
        model_path: checkpoint to evaluate.
        max_batches: maximum batches per split; ``<= 0`` means all.
        result_dir_path: base directory for result files.
        batch_size, dataset_name, world_size, num_workers: data-loader settings.
        n_samples, n_chunks: importance-sampling settings for ``iw_log_p_x``.
        max_seq_len_gen: maximum length of generated sequences.
    """
    # Prepare some variables & result directory
    device_name = f"cuda:{device_rank}"

    result_dir = Path(result_dir_path) / run_name
    os.makedirs(result_dir, exist_ok=True)

    # single GPU file
    result_file = result_dir / f"{device_name}_{run_name}_max_batches_{max_batches}.pickle"

    # combined for all GPUs
    full_result_file = result_dir / f"{run_name}_world_size_{world_size}_max_batches_{max_batches}_" \
                                    f"batch_size_{batch_size}_n_samples_{n_samples}.pickle"

    if os.path.isfile(result_file) or os.path.isfile(full_result_file):
        print('_' * 80)
        print('_' * 80)
        print("Have done this one already!")
        print('_' * 80)
        print('_' * 80)
        return

    print("-" * 30)
    print("run_name:", run_name)
    print("batch size:", batch_size)
    print("max_batches:", max_batches)
    print("device name:", device_name)
    print("-" * 30)

    # Get model
    vae_model = load_from_checkpoint(model_path,
                                     world_master=True,
                                     ddp=False,
                                     device_name=device_name,
                                     evaluation=True,
                                     return_loss_term_manager=False)

    # Get distributed validation / train data loaders
    valid_loader = get_dist_validation_loader(
        batch_size=batch_size,
        num_workers=num_workers,
        max_seq_len=64,
        world_size=world_size,
        dataset_name=dataset_name,
        tokenizer_name="roberta",
        device_name=device_name,
        gpu_rank=device_rank,
        train_validation="validation")

    train_loader = get_dist_validation_loader(batch_size=batch_size,
                                              num_workers=num_workers,
                                              max_seq_len=64,
                                              world_size=world_size,
                                              dataset_name=dataset_name,
                                              tokenizer_name="roberta",
                                              device_name=device_name,
                                              gpu_rank=device_rank,
                                              train_validation="train")

    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=world_size,
                            rank=device_rank)

    # Seed everything
    seed_everything(0)

    N_valid = max_batches if max_batches > 0 else len(valid_loader)
    N_train = max_batches if max_batches > 0 else len(train_loader)
    N = max(N_valid, N_train)

    print(f"N_valid {N_valid} N_train {N_train} N {N}")

    with torch.no_grad():

        log_p_x_obs_valid, log_p_x_w_obs_valid, lens_obs_valid = iw_log_p_x_dataset(
            valid_loader,
            model=vae_model,
            path=None,
            n_samples=n_samples,
            n_chunks=n_chunks,
            verbose=True,
            ddp=False,
            device_name=device_name,
            max_batches=N_valid)

        log_p_x_obs_train, log_p_x_w_obs_train, lens_obs_train = iw_log_p_x_dataset(
            train_loader,
            model=vae_model,
            path=None,
            n_samples=n_samples,
            n_chunks=n_chunks,
            verbose=True,
            ddp=False,
            device_name=device_name,
            max_batches=N_train)

        log_p_x_gen, log_p_x_w_gen, lens_gen = iw_log_p_x_generated(
            model=vae_model,
            path=None,
            n_batches=N,
            batch_size=batch_size,
            n_samples=n_samples,
            n_chunks=n_chunks,
            verbose=True,
            ddp=False,
            device_name=device_name,
            max_seq_len_gen=max_seq_len_gen)

    results = dict(log_p_x_obs_valid=log_p_x_obs_valid.cpu(),
                   log_p_x_w_obs_valid=log_p_x_w_obs_valid.cpu(),
                   lens_obs_valid=lens_obs_valid.cpu(),
                   log_p_x_obs_train=log_p_x_obs_train.cpu(),
                   log_p_x_w_obs_train=log_p_x_w_obs_train.cpu(),
                   lens_obs_train=lens_obs_train.cpu(),
                   log_p_x_gen=log_p_x_gen.cpu(),
                   log_p_x_w_gen=log_p_x_w_gen.cpu(),
                   lens_gen=lens_gen.cpu())

    # Dump the results for this device (context manager closes the file)
    with open(result_file, "wb") as f:
        pickle.dump(results, f)
# Example #8
# 0
def train(device_rank, config, run_name):
    """Train, and validate every epoch, a VAE on one device, optionally under DDP.

    Train and validation phases alternate until one of these happens:
    - ``global_max_steps`` is reached,
    - ``max_epochs`` is reached (``config.max_epochs`` if > 0, else 100),
    - early stopping triggers.

    On the world master, each validation phase updates the Pareto-efficiency
    bookkeeping, may checkpoint the model, and logs epoch stats when logging
    is enabled. At the end, ``stats`` and the Pareto dict are pickled to
    ``<config.code_dir_path>/<run_name>/stats.pickle`` and ``pareto_dict.pickle``.

    Args:
        device_rank: rank / GPU index of this process.
        config: run configuration; many attributes are read, and
            ``max_*_steps_epoch_per_rank`` are overwritten in place.
        run_name: run identifier used for logging, checkpoints and output paths.
    """
    print("**** DEVICE: ", device_rank)

    # Device
    device_name = utils_train.set_device(device_rank)

    # Determine world size and whether this device is world master
    world_master, world_size = utils_train.get_world_specs(config.n_gpus, config.n_nodes, device_name)

    # Determine the maximum number of steps for this device
    global_max_steps, global_max_grad_steps = utils_train.determine_global_max_steps(config.max_global_train_steps,
                                                                                     config.batch_size, world_size,
                                                                                     config.accumulate_n_batches_grad)

    # Initiate process group and specify backend configurations
    if config.ddp:
        if world_master:
            print("Init process group...")
            print(f"--> CPU count {multiprocessing.cpu_count()}")
        dist.init_process_group(backend='nccl', init_method='env://',
                                world_size=int(config.n_gpus * config.n_nodes), rank=device_rank)

    # Seed everything
    seed_everything(config.seed)

    # Data loaders / data set / samplers (if ddp)
    data_loaders, data, samplers = utils_train.get_dataloader(["train", "validation"], ddp=config.ddp,
                                                              batch_size=config.batch_size,
                                                              num_workers=config.num_workers,
                                                              max_seq_len=config.max_seq_len,
                                                              world_size=world_size, dataset_name=config.dataset_name,
                                                              tokenizer_name=config.tokenizer_name,
                                                              device_name=device_name, world_master=world_master,
                                                              gpu_rank=device_rank)

    # These are actual steps, not gradient steps, so they work in combination with global step
    max_train_steps_epoch_per_rank, max_valid_steps_epoch_per_rank = utils_train.determine_max_epoch_steps_per_rank(
        config.max_train_steps_epoch_per_rank, config.max_valid_steps_epoch_per_rank, data.datasets,
        config.batch_size, world_size=world_size, world_master=world_master)
    max_epochs = config.max_epochs if config.max_epochs > 0 else 100
    config.max_train_steps_epoch_per_rank = max_train_steps_epoch_per_rank  # overwrite this
    config.max_valid_steps_epoch_per_rank = max_valid_steps_epoch_per_rank  # overwrite this

    print("*" * 80)
    print("config.max_train_steps_epoch_per_rank", config.max_train_steps_epoch_per_rank)
    print("*" * 80)

    # Get model and loss term manager
    dataset_size = data.datasets['train'].shape[0]
    if config.load_from_checkpoint:
        assert os.path.isfile(config.checkpoint_file), f"checkpoint file does not exists: {config.checkpoint_file}"
        loss_term_manager = utils_train.load_from_checkpoint(config.checkpoint_file, world_master=world_master,
                                                             ddp=config.ddp, device_name=device_name,
                                                             evaluation=False, return_loss_term_manager=True,
                                                             loss_term_manager_config=config)
    else:
        loss_term_manager = vae.get_loss_term_manager_with_model(config, world_master=world_master,
                                                                 dataset_size=dataset_size, device_name=device_name)

    autoencoder = False
    if config.objective == "beta-vae" and config.b_vae_beta_constant_linear_lagrangian == "constant" and config.b_vae_beta == 0.0:
        print("** AUTO ENCODER OBJECTIVE!!")
        autoencoder = True

    # Initialise logging
    if config.logging and world_master:
        utils_train.init_logging(loss_term_manager.vae_model, run_name, config.code_dir_path,
                                 config.wandb_project, config, config.run_dir_name)

    # Set-up DDP
    if config.ddp:
        # Wrap both the model and constraints etc in a loss_term_manager nn.Module as suggested here:
        # https://discuss.pytorch.org/t/multiple-modules-with-distributed-data-parallel/115621
        loss_term_manager = torch.nn.parallel.DistributedDataParallel(loss_term_manager,
                                                                      device_ids=[device_rank],
                                                                      find_unused_parameters=False)  # not needed to check
        print(f"-> Turned on DDP for device rank {device_rank}")

    # TODO: zero grads of loss_term_manager here once that is fixed.

    # Initialise the stats to keep track of
    stats = utils_train.make_nested_dict()
    finished_training = False

    epoch, global_step, global_grad_step, not_improved_epochs = 0, 0, 0, 0
    # NB, I am not using D_ks for pareto checkpointing anymore.

    epoch_pareto_effiency_dict = utils_train.prepare_pareto_dict(config=config)
    current_efficient_epochs = []

    if world_master:
        print("Start or resume training!")

    # ----------------------------------------------------------------------------------------------------
    # TRAINING!
    # ----------------------------------------------------------------------------------------------------
    while not finished_training:

        print("finished_training", finished_training)

        # TRAIN, VALID
        for phase in data_loaders.keys():

            if finished_training:
                break

            if config.ddp:
                print(f"-> Setting epoch explicitly to {epoch} on device {device_name}")
                samplers[phase].set_epoch(epoch)  # needed to explicitly shuffle

            max_steps = max_train_steps_epoch_per_rank if phase == 'train' else max_valid_steps_epoch_per_rank
            atts_to_latent, masks = [], []

            for batch_i, batch in enumerate(data_loaders[phase]):
                # ----------------------------------------------------------------------------------------------------
                # TRAIN / VALIDATION STEPS
                # ----------------------------------------------------------------------------------------------------

                # SET DEVICE
                batch = utils_train.transfer_batch_to_device(batch, device_name)

                # PERFORM TRAIN / VALIDATION STEP
                if phase == 'train':
                    loss_term_manager, losses = do_train_step(
                        loss_term_manager,
                        batch, global_step,
                        use_amp=config.use_amp,
                        accumulate_n_batches_grad=config.accumulate_n_batches_grad,
                        device_name=device_name,
                        gradient_clipping=config.gradient_clipping,
                        decoder_only=config.decoder_only,
                        ddp=config.ddp)
                else:
                    # save_latents happens now outside the train loop
                    losses = do_valid_step(loss_term_manager, batch,
                                           device_name=device_name, ddp=config.ddp, decoder_only=config.decoder_only,
                                           iw_ll_n_samples=config.iw_ll_n_samples, eval_iw_ll_x_gen=config.eval_iw_ll_x_gen,
                                           max_seq_len_x_gen=config.max_seq_len_x_gen, save_latents=False)

                    if "attention_to_latent" in losses:
                        atts_to_latent.append(losses["attention_to_latent"].cpu())
                        masks.append(batch["attention_mask"][:, 1:].cpu())
                        del losses["attention_to_latent"]

                # ----------------------------------------------------------------------------------------------------
                # INSERT STATISTICS, PRINT, LOG, CHECKPOINT
                # ----------------------------------------------------------------------------------------------------

                # INSERT STATISTICS
                stats = utils_train.insert_stats(stats, losses, epoch, phase)

                # PRINT
                if world_master and global_step % config.print_every_n_steps == 0 and config.print_stats:
                    utils_train.print_stats(stats, epoch, phase, global_step, global_max_steps,
                                            global_grad_step, global_max_grad_steps, batch_i, max_steps,
                                            config.objective)

                # LOG STEP (only if world master)
                if batch_i % config.log_every_n_steps == 0 and config.logging and world_master and phase == 'train':
                    if config.add_latent_w_matrix_influence:
                        utils_train.add_matrix_influence_weight_to_loss(loss_term_manager, global_step,
                                                                        global_grad_step, ddp=config.ddp)
                    utils_train.log_losses_step(losses, phase, epoch, config.log_every_n_steps, global_step,
                                                global_grad_step)

                # Analyse and save latents for runs with save_latents == True
                if global_step % config.save_latents_every_x_steps == 0 and config.save_latents:
                    # Under DDP the manager is wrapped, so the model lives on .module
                    latents_model = loss_term_manager.vae_model if config.ddp is False \
                        else loss_term_manager.module.vae_model
                    utils_train.analyse_save_latents(data_loaders["validation"], latents_model, stats,
                                                     config.code_dir_path, config.run_dir_name, run_name,
                                                     global_step, epoch, device_name=device_name)

                # ----------------------------------------------------------------------------------------------------
                # KEEP TRACK OF STEPS (IN PHASE AND GLOBALLY)
                # ----------------------------------------------------------------------------------------------------

                # ADVANCE STEP if in train mode
                if phase == "train":
                    global_step += 1
                    if global_step % config.accumulate_n_batches_grad == 0:
                        global_grad_step += 1

                # CHECK IF EPOCH PHASE IS OVER (after advancing one)
                # NOTE(review): batch_i is 0-based, so this processes max_steps + 1 batches
                # per phase; confirm whether that is intended.
                if batch_i >= max_steps:
                    break
                if global_step >= global_max_steps or epoch >= max_epochs:
                    finished_training = True
                    break

            # ----------------------------------------------------------------------------------------------------
            # END OF TRAIN / VALID PHASE
            # ----------------------------------------------------------------------------------------------------

            # BEST MODEL CHECKPOINT
            if phase == 'validation' and world_master:
                val_epoch_stats = stats[epoch]["validation"]

                # Update the epoch_pareto_effiency_dict and determine efficient_epochs
                epoch_pareto_effiency_dict, efficient_epochs = utils_train.determine_pareto_checkpoint(
                    val_epoch_stats, epoch_pareto_effiency_dict, epoch, logging=config.logging,
                    decoder_only=config.decoder_only or autoencoder)  # if AE, also evaluate based on -D

                # Check if anything changed, if not keep count of not improved epochs
                if efficient_epochs == current_efficient_epochs:
                    not_improved_epochs += 1
                else:
                    not_improved_epochs = 0

                current_efficient_epochs = efficient_epochs

                # Early stopping
                if (not_improved_epochs >= config.early_stop_epochs) and config.early_stopping:
                    print("*" * 50)
                    print("EARLY STOPPING!")
                    print("*" * 50)
                    finished_training = True

                # Checkpoint according to efficient_epochs, save the data
                if config.checkpoint:
                    vae_model = loss_term_manager.vae_model if config.ddp is False else loss_term_manager.module.vae_model
                    utils_train.save_checkpoint_model(vae_model, run_name, config.code_dir_path, global_step,
                                                      epoch, config, efficient_epochs, epoch_pareto_effiency_dict,
                                                      config.run_dir_name)

        # ----------------------------------------------------------------------------------------------------
        # END OF EPOCH
        # ----------------------------------------------------------------------------------------------------

        # LOG EPOCH STATS (if world master)
        if config.logging and world_master:
            print("LOG EPOCH STATS")

            utils_train.log_stats_epoch(stats, epoch, global_step, global_grad_step, atts_to_latent, masks)

        epoch += 1

    # Dump train stats and pareto stats.
    # NOTE(review): this runs on every rank and they all write to the same paths;
    # confirm whether only the world master should write.
    path = config.code_dir_path + "/" + run_name
    with open(path + "/stats.pickle", "wb") as f:
        pickle.dump(stats, f)
    with open(path + "/pareto_dict.pickle", "wb") as f:
        pickle.dump(epoch_pareto_effiency_dict, f)