Example #1
def novelty(opt):
    """Script that computes the novelty of generated sentences."""
    # Load options, if they are stored
    try:
        opt = load_options(opt)
    except InvalidPathError as e:
        raise NoModelError(
            "Aborting testing without a valid model to load.") from e
    except Error as e:
        warn(
            "{}\n Make sure all preset arguments coincide with the model you are loading."
            .format(e))

    # Set device so script works on both GPU and CPU
    opt.device = torch.device(
        "cuda:{}".format(opt.local_rank) if opt.local_rank >= 0 else "cpu")
    vprint("Using device: {}".format(opt.device), opt.verbosity, 1)

    word_to_idx, idx_to_word = load_word_index_maps(opt)
    opt.N = 0
    decoder = initialize_model(opt, word_to_idx)

    # Loading a trained model is mandatory for testing; without one the tests cannot run
    decoder = load_model(opt, decoder)

    with torch.no_grad():
        # Novelty is the minimum TER of a generated sentence w.r.t. the training corpus
        novelty, full_scores = compute_novelty(
            get_samples(opt, decoder, idx_to_word, word_to_idx),
            osp.join(opt.data_folder, opt.train_file), opt, idx_to_word)
    vprint("Novelty: {}".format(novelty), opt.verbosity, 0)
    save_novelties(full_scores)
    return novelty
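
Example #1 computes corpus-level novelty. A minimal self-contained sketch of the underlying idea, assuming novelty is the minimum translation edit rate (TER) between a generated sentence and the training corpus; word_edit_rate and min_ter_novelty below are hypothetical helpers, not this project's compute_novelty, and the edit rate here ignores the block shifts that full TER also counts:

def word_edit_rate(hyp, ref):
    """Word-level Levenshtein distance normalized by reference length."""
    h, r = hyp.split(), ref.split()
    d = list(range(len(r) + 1))
    for i, hw in enumerate(h, 1):
        prev, d[0] = d[0], i
        for j, rw in enumerate(r, 1):
            prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (hw != rw))
    return d[len(r)] / max(len(r), 1)

def min_ter_novelty(generated, corpus):
    """Score each generated sentence by its distance to its closest training
    sentence; higher scores mean more novel output."""
    scores = [min(word_edit_rate(g, t) for t in corpus) for g in generated]
    return sum(scores) / len(scores), scores
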
Example #2
def generate_data(opt):
    """Script that generates data from a (trained) generative model of language and writes it to file."""
    # Load options, if they are stored
    try:
        opt = load_options(opt)
    except InvalidPathError as e:
        raise NoModelError(
            "Aborting generating without a valid model to load.") from e
    except Error as e:
        warn(
            "{}\n Make sure all preset arguments coincide with the model you are loading."
            .format(e))

    opt.N = 0

    # Set device so script works on both GPU and CPU
    opt.device = torch.device(
        "cuda:{}".format(opt.local_rank) if opt.local_rank >= 0 else "cpu")
    vprint("Using device: {}".format(opt.device), opt.verbosity, 1)

    word_to_idx, idx_to_word = load_word_index_maps(opt)

    decoder = initialize_model(opt, word_to_idx)

    # Model loading is mandatory when generating
    decoder = load_model(opt, decoder)

    decoder.eval()
    # Remember the requested sample count: the remainder step below overwrites
    # opt.num_samples, so it has to be restored for every data split
    num_samples = opt.num_samples
    for mode in ['train', 'valid', 'test']:
        opt.num_samples = num_samples
        if mode == 'train':
            tot_samples = file_len(osp.join(opt.data_folder, opt.train_file))
        elif mode == 'valid':
            tot_samples = file_len(osp.join(opt.data_folder, opt.val_file))
        elif mode == 'test':
            tot_samples = file_len(osp.join(opt.data_folder, opt.test_file))

        samples = ""
        sample_indices = ""

        # We repeatedly draw opt.num_samples samples until we match the amount of data in the respective file
        for _ in range(int(tot_samples / opt.num_samples)):
            s, si = get_samples(opt, decoder, idx_to_word, word_to_idx)
            samples += s + "\n"
            sample_indices += si + "\n"

        # Here we sample the remainder of the division to exactly match the number of sequences in the data files
        opt.num_samples = int(tot_samples % opt.num_samples)
        s, si = get_samples(opt, decoder, idx_to_word, word_to_idx)
        samples += s
        sample_indices += si

        # Finally, we save the generated data under a filename that encodes its origin (model and settings)
        save_samples(opt, samples, sample_indices, mode)
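
The chunk-plus-remainder pattern above produces exactly tot_samples sequences per split. The same arithmetic in isolation (sample_exactly and sampler are illustrative names, not part of this codebase):

def sample_exactly(sampler, total, batch):
    """Draw `total` items from sampler(n) in chunks of `batch`, then the remainder."""
    out = []
    for _ in range(total // batch):
        out.extend(sampler(batch))
    out.extend(sampler(total % batch))
    assert len(out) == total
    return out

# e.g. sample_exactly(lambda n: ["x"] * n, 10, 3) draws 3 + 3 + 3 + 1 items
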
Example #3
def qualitative(opt):
    """Print some samples using various techniques."""
    # Load options, if they are stored
    try:
        opt = load_options(opt)
    except InvalidPathError as e:
        raise NoModelError(
            "Aborting testing without a valid model to load.") from e
    except Error as e:
        warn(
            "{}\n Make sure all preset arguments coincide with the model you are loading."
            .format(e))

    # Set device so script works on both GPU and CPU
    opt.device = torch.device(
        "cuda:{}".format(opt.local_rank) if opt.local_rank >= 0 else "cpu")
    vprint("Using device: {}".format(opt.device), opt.verbosity, 1)

    # Here we construct the parts we need: the vocabulary maps, dataloader and model
    word_to_idx, idx_to_word = load_word_index_maps(opt)
    # Out-of-vocabulary words fall back to the index of the unk token
    word_to_idx = defaultdict(lambda: word_to_idx[opt.unk_token], word_to_idx)
    data_train, _ = initialize_dataloader(opt, word_to_idx, sort_pad_collate)

    opt.N = 0
    decoder = initialize_model(opt, word_to_idx)

    # Loading a trained model is mandatory for testing; without one the tests cannot run
    decoder = load_model(opt, decoder)

    with torch.no_grad():
        if opt.qual_file:
            with open(opt.qual_file, 'r') as f:
                # splitlines() plus the filter drops empty lines, which would
                # otherwise become spurious single-unk sentences
                sentences = [l for l in f.read().splitlines() if l.strip()]

            # Convert sentences to index tensors
            sentences = [[word_to_idx[word] for word in s.strip().split(' ')]
                         for s in sentences]
            x_len = [len(s) for s in sentences]
            max_len = max(x_len)
            pad_idx = word_to_idx[opt.pad_token]
            for s in sentences:
                s.extend([pad_idx] * (max_len - len(s)))
            sentences = torch.LongTensor(sentences).to(opt.device)
            x_len = torch.LongTensor(x_len)
        else:
            for data in data_train:
                # Select short sentences only
                sentences = data[0][(data[1] < 16) & (data[1] > 4)]
                x_len = data[1][(data[1] < 16) & (data[1] > 4)]
                sentences = sentences.to(opt.device)
                break

        if opt.model != 'deterministic':
            # Print some homotopies
            for i in range(sentences.shape[0] - 1):
                if x_len[i].item() < 2 or x_len[i + 1].item() < 2:
                    continue
                homotopy_idx = decoder.homotopies(
                    sentences[i][:x_len[i].item()].unsqueeze(0),
                    sentences[i + 1][:x_len[i + 1].item()].unsqueeze(0),
                    9,
                    torch.empty([10, 1], device=opt.device,
                                dtype=torch.long).fill_(
                                    word_to_idx[opt.sos_token]),
                    word_to_idx[opt.eos_token],
                    word_to_idx[opt.pad_token],
                    sample_softmax=opt.sample_softmax)
                homotopy = "\n".join([
                    " ".join([
                        idx_to_word[s] for s in hom
                        if s != word_to_idx[opt.pad_token]
                    ]) for hom in homotopy_idx
                ])
                homotopy_idx = "\n".join([
                    " ".join([
                        str(s) for s in hom if s != word_to_idx[opt.pad_token]
                    ]) for hom in homotopy_idx
                ])
                print("Original:\n {} - {}\n\n".format(
                    " ".join([
                        idx_to_word[s]
                        for s in sentences[i][:x_len[i].item()].tolist()
                    ]), " ".join([
                        idx_to_word[s]
                        for s in sentences[i + 1][:x_len[i + 1].item()].tolist()
                    ])))
                print("Homotopies:\n {}\n\n".format(homotopy))
                if opt.ter:
                    print("Novelties:\n")
                    _ = compute_novelty([homotopy, homotopy_idx],
                                        osp.join(opt.data_folder,
                                                 opt.train_file), opt,
                                        idx_to_word)

            # Print some posterior samples
            for i in range(sentences.shape[0]):
                if x_len[i].item() < 2:
                    continue
                pos_samples_idx = decoder.sample_posterior(
                    sentences[i][:x_len[i].item()].unsqueeze(0),
                    3,
                    word_to_idx[opt.eos_token],
                    word_to_idx[opt.pad_token],
                    sample_softmax=opt.sample_softmax)
                pos_samples = "\n".join([
                    " ".join([
                        idx_to_word[s] for s in sample
                        if s != word_to_idx[opt.pad_token]
                    ]) for sample in pos_samples_idx
                ])
                pos_samples_idx = "\n".join([
                    " ".join([
                        str(s) for s in sample
                        if s != word_to_idx[opt.pad_token]
                    ]) for sample in pos_samples_idx
                ])
                print("Original:\n {} \n\n".format(" ".join([
                    idx_to_word[s]
                    for s in sentences[i][:x_len[i].item()].tolist()
                ])))
                print("Samples:\n {} \n\n".format(pos_samples))
                if opt.ter:
                    print("Novelties:\n")
                    _ = compute_novelty([pos_samples, pos_samples_idx],
                                        osp.join(opt.data_folder,
                                                 opt.train_file), opt,
                                        idx_to_word)

            # Print some free posterior samples
            print("Free posterior samples.\n\n")
            for i in range(sentences.shape[0]):
                if x_len[i].item() < 2:
                    continue
                pos_samples_idx = decoder.sample_posterior(
                    sentences[i][:x_len[i].item()].unsqueeze(0),
                    3,
                    word_to_idx[opt.eos_token],
                    word_to_idx[opt.pad_token],
                    torch.empty([3, 1], device=opt.device,
                                dtype=torch.long).fill_(
                                    word_to_idx[opt.sos_token]),
                    "free",
                    sample_softmax=opt.sample_softmax)
                pos_samples = "\n".join([
                    " ".join([
                        idx_to_word[s] for s in sample
                        if s != word_to_idx[opt.pad_token]
                    ]) for sample in pos_samples_idx
                ])
                pos_samples_idx = "\n".join([
                    " ".join([
                        str(s) for s in sample
                        if s != word_to_idx[opt.pad_token]
                    ]) for sample in pos_samples_idx
                ])
                print("Original:\n {} \n\n".format(" ".join([
                    idx_to_word[s]
                    for s in sentences[i][:x_len[i].item()].tolist()
                ])))
                print("Samples:\n {} \n\n".format(pos_samples))
                if opt.ter:
                    print("Novelties:\n")
                    _ = compute_novelty([pos_samples, pos_samples_idx],
                                        osp.join(opt.data_folder,
                                                 opt.train_file), opt,
                                        idx_to_word)

        # Print some free prior samples
        print("Free samples:\n")
        if opt.ter:
            _ = compute_novelty(
                get_samples(opt, decoder, idx_to_word, word_to_idx),
                osp.join(opt.data_folder, opt.train_file), opt, idx_to_word)
        else:
            print(get_samples(opt, decoder, idx_to_word, word_to_idx)[0])
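
The homotopies printed in Example #3 interpolate between two sentences in latent space. A minimal sketch of the idea, assuming the model exposes an encoder and a greedy decoder; encode and decode below are hypothetical stand-ins for what decoder.homotopies does internally, together with the <sos>/<eos>/<pad> bookkeeping seen above:

import torch

def latent_homotopy(encode, decode, s1, s2, steps=9):
    """Decode evenly spaced points on the segment between two latent codes."""
    z1, z2 = encode(s1), encode(s2)
    return [decode((1.0 - t) * z1 + t * z2)
            for t in torch.linspace(0.0, 1.0, steps)]
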
Example #4
def test(opt):
    """Script that tests a generative model of language given various user settings."""
    # Load options, if they are stored
    try:
        opt = load_options(opt)
    except InvalidPathError as e:
        raise NoModelError(
            "Aborting testing without a valid model to load.") from e
    except Error as e:
        warn(
            "{}\n Make sure all preset arguments coincide with the model you are loading."
            .format(e))

    # We test with a batch size of 1 to get exact results
    batch_size = opt.batch_size
    opt.batch_size = 1

    # Set device so script works on both GPU and CPU
    opt.device = torch.device(
        "cuda:{}".format(opt.local_rank) if opt.local_rank >= 0 else "cpu")
    vprint("Using device: {}".format(opt.device), opt.verbosity, 1)

    word_to_idx, idx_to_word = load_word_index_maps(opt)

    data_test = initialize_dataloader(opt, word_to_idx, sort_pad_collate)
    opt.N = (len(data_test) - 1) * opt.batch_size
    decoder = initialize_model(opt, word_to_idx)

    # Loading a trained model is mandatory for testing; without one the tests cannot run
    decoder = load_model(opt, decoder)

    # The summary writer will log certain values for automatic visualization
    writer = SummaryWriter(
        osp.join(opt.out_folder, opt.model, opt.save_suffix, 'test'))

    # The StatsHandler object will store important statistics during testing and provides printing and logging utilities
    stats = StatsHandler(opt)

    with torch.no_grad():
        zs = []
        z_priors = []
        variances = []
        mus = []
        preds = []
        datas = []
        log_q_z_xs = []
        log_p_zs = []
        for data in data_test:
            # Catch small batches
            if data[0].shape[0] != opt.batch_size:
                opt.batch_size = data[0].shape[0]

            # Save data in a list
            datas.extend(data[0][:, 1:].tolist())

            # Prepare
            decoder.eval()
            data = [d.to(opt.device) for d in data]

            # Sample a number of log-likelihoods to obtain a low-variance estimate of the model perplexity
            decoder.use_prior = True
            # Forward pass the evaluation dataset to obtain statistics
            losses, pred = decoder(data)
            stats.val_loss.append(losses["NLL"].item())
            stats.val_l2_loss.append(losses["L2"].item())
            stats.val_acc.append(compute_accuracy(pred, data).item())
            stats.val_log_loss[0].append(losses["NLL"].item())

            # We need the average length of the sequences to get a fair estimate of the perplexity per word
            if len(data) > 1:
                stats.avg_len.extend([torch.mean(data[1].float()).item() - 1])
            else:
                stats.avg_len.extend([data[0].shape[1] - 1])

            # Also sample the perplexity for the reconstruction case
            decoder.use_prior = False
            losses, pred, var, mu, z, z_prior, log_q_z_x, log_p_z = decoder(
                data, extensive=True)
            preds.extend(pred.tolist())
            zs.append(z)
            z_priors.append(z_prior)
            variances.append(var)
            mus.append(mu)
            log_q_z_xs.append(log_q_z_x)
            log_p_zs.append(log_p_z)
            stats.val_rec_loss.append(losses["NLL"].item())
            stats.val_rec_kl.append(losses["KL"].item())
            stats.val_rec_elbo.append(losses["NLL"].item() +
                                      losses["KL"].item())
            stats.val_rec_min_rate.append(losses["Hinge"].item())
            stats.val_rec_l2_loss.append(losses["L2"].item())
            stats.val_rec_mmd.append(losses["MMD"].item())
            stats.val_rec_acc.append(compute_accuracy(pred, data).item())
            stats.val_rec_log_loss[0].append(losses["NLL"].item() +
                                             losses["KL"].item())
            for i in range(len(opt.constraint)):
                stats.constraints[i].append(
                    losses["Constraint_{}".format(i)].item())
                stats.lambs[i].append(losses["Lag_Weight_{}".format(i)].item())

        # Stack the collected samples and parameters
        mu = torch.cat(mus, 0)
        var = torch.cat(variances, 0)
        z_prior = torch.cat(z_priors, 0)
        z = torch.cat(zs, 0)
        log_q_z_x = torch.cat(log_q_z_xs, 0)
        log_p_z = torch.cat(log_p_zs, 0)
        avg_kl = torch.tensor(stats.val_rec_kl,
                              dtype=torch.float,
                              device=opt.device).mean()
        avg_h = log_q_z_x.mean()

        # We compute the MMD over the full validation set
        if opt.mmd:
            stats.val_rec_mmd = [decoder._mmd(z, z_prior).item()]

        if opt.mi:
            # I(x; z) can be estimated as E[log q(z|x)] - E[log q(z)], with
            # log q(z) approximated from the aggregate posterior of the samples
            log_q_z = decoder.q_z_estimate(z, mu, var)
            stats.val_mi, stats.val_mkl = compute_mutual_information(
                z, log_p_z, avg_h, avg_kl, opt.mi_method, opt.mi_kde_method,
                log_q_z)

        # Active units are the latent dimensions whose posterior mean varies across the data by more than opt.delta
        stats.val_au = compute_active_units(mu, opt.delta)

        # BLEU is a measure of the corpus level reconstruction ability of the model
        stats.val_bleu = compute_bleu(preds, datas, word_to_idx[opt.pad_token])

        if opt.ter:
            print("Computing TER....")
            # TER is another measure of the corpus level reconstruction ability of the model
            stats.val_ter = compute_ter(preds, datas,
                                        word_to_idx[opt.pad_token])

            print("Computing Novelty....")
            # Novelty is the minimum TER of a generated sentence compared with the training corpus
            stats.val_novelty, _ = compute_novelty(
                get_samples(opt, decoder, idx_to_word, word_to_idx)[1],
                osp.join(opt.data_folder, opt.train_file), opt, idx_to_word)

        if opt.log_likelihood and opt.model != 'deterministic':
            repeat = max(int(opt.ll_samples / opt.ll_batch), 1)
            opt.ll_batch = min(opt.ll_samples, opt.ll_batch)
            normalizer = torch.log(
                torch.tensor(int(opt.ll_samples / opt.ll_batch) * opt.ll_batch,
                             device=opt.device,
                             dtype=torch.float))
            stats.val_nll = []
            for i, data in enumerate(data_test):
                if i % 100 == 0:
                    vprint(
                        "\r At {:.3f} percent of log likelihood estimation\r".
                        format(float(i) / len(data_test) * 100),
                        opt.verbosity,
                        1,
                        end="")
                nelbo = []
                for r in range(repeat):
                    data = [
                        d.expand(opt.ll_batch,
                                 *d.size()[1:]).to(opt.device) for d in data
                    ]
                    losses, _ = decoder(data, True)
                    nelbo.append(losses["NLL"] + losses["KL"])
                nelbo = torch.cat(nelbo, 0)
                stats.val_nll.append(-(torch.logsumexp(-nelbo -
                                                       normalizer, 0)).item())
            stats.val_nll = np.array(stats.val_nll).mean()
            stats.val_est_ppl = compute_perplexity([stats.val_nll],
                                                   stats.avg_len)[0]

    # Per-word perplexity is computed from the sampled log losses, normalized
    # by the average sequence lengths collected above
    perplexity, variance = compute_perplexity(stats.val_log_loss,
                                              stats.avg_len)
    rec_ppl, rec_var = compute_perplexity(stats.val_rec_log_loss,
                                          stats.avg_len)

    stats.val_ppl.append(perplexity)
    stats.val_rec_ppl.append(rec_ppl)
    stats.val_ppl_std.append(variance)
    stats.val_rec_ppl_std.append(rec_var)

    vprint(stats, opt.verbosity, 0)
    stats.log_stats(writer)
    writer.close()

    opt.batch_size = batch_size

    if opt.log_likelihood:
        return stats.val_nll + compute_free_bits_from_stats(stats, opt), stats
    else:
        return stats.val_rec_elbo + compute_free_bits_from_stats(stats,
                                                                 opt), stats
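
The log-likelihood loop in Example #4 is an importance-sampling estimate: each negative ELBO equals -(log p(x, z_k) - log q(z_k | x)) for a posterior sample z_k, and combining K of them in log space gives log p(x) ~= logsumexp_k(-NELBO_k) - log K, which is what the logsumexp minus the normalizer computes. The same estimator in isolation (importance_sampled_nll is an illustrative helper, not part of this codebase):

import torch

def importance_sampled_nll(neg_elbos):
    """Combine K single-sample negative ELBOs into one estimate of -log p(x).

    neg_elbos: 1-D tensor holding -(log p(x, z_k) - log q(z_k | x)).
    """
    log_k = torch.log(torch.tensor(float(neg_elbos.numel())))
    return -(torch.logsumexp(-neg_elbos, 0) - log_k)

# More samples tighten the bound: the estimate is never larger than the
# plain average of the negative ELBOs.
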
Example #5
def train(opt):
    """Script that trains a generative model of language given various user settings."""
    # Try to load options when we resume
    if opt.resume:
        try:
            opt = load_options(opt)
        except InvalidPathError as e:
            warn("{}\n Starting from scratch...".format(e))
            opt.resume = 0
            epoch = 0
        except Error as e:
            warn(
                "{}\n Make sure all preset arguments coincide with the model you are loading."
                .format(e))
    else:
        epoch = 0

    # Set device so script works on both GPU and CPU
    seed(opt)
    opt.device = torch.device(
        "cuda:{}".format(opt.local_rank) if opt.local_rank >= 0 else "cpu")
    vprint("Using device: {}".format(opt.device), opt.verbosity, 1)

    word_to_idx, idx_to_word = load_word_index_maps(opt)

    # Here we construct all parts of the training setup: the model, dataloaders and optimizers
    data_train, data_eval = initialize_dataloader(opt, word_to_idx,
                                                  sort_pad_collate)
    opt.N = (len(data_train) - 1) * opt.batch_size
    decoder = initialize_model(opt, word_to_idx)
    optimizers = []
    if opt.sparse:
        sparse_parameters = [
            p[1] for p in filter(lambda p: p[0] == "emb.weight",
                                 decoder.named_parameters())
        ]
        parameters = [
            p[1] for p in filter(
                lambda p: p[1].requires_grad and p[0] != "emb.weight",
                decoder.named_parameters())
        ]
        optimizers.append(Adam(parameters, opt.lr))
        optimizers.append(SparseAdam(sparse_parameters, opt.lr))
    elif opt.lagrangian:
        lag_parameters = [
            p[1] for p in filter(lambda p: p[0] == "lag_weight",
                                 decoder.named_parameters())
        ]
        # Exclude the multiplier itself ("lag_weight"); it gets its own optimizer below
        parameters = [
            p[1] for p in filter(
                lambda p: p[1].requires_grad and p[0] != "lag_weight",
                decoder.named_parameters())
        ]
        optimizers.append(Adam(parameters, opt.lr))
        optimizers.append(RMSprop(lag_parameters, opt.lr))
    else:
        parameters = filter(lambda p: p.requires_grad, decoder.parameters())
        optimizers.append(Adam(parameters, opt.lr))

    # Load from checkpoint
    if opt.resume:
        decoder, optimizers, epoch = load_checkpoint(opt, decoder, optimizers)

    # The SummaryWriter will log certain values for automatic visualization
    writer = SummaryWriter(
        osp.join(opt.out_folder, opt.model, opt.save_suffix, 'train'))

    # The StatsHandler object will store important statistics during training and provides printing and logging utilities
    stats = StatsHandler(opt)

    # We will early-stop the network based on user-specified criteria
    early_stopping = False
    stop_ticks = 0
    prev_crit = [np.inf] * len(opt.criteria)

    while not early_stopping:
        # We reset the stats object to collect fresh stats for every epoch
        stats.reset()
        epoch += 1
        stats.epoch = epoch

        start = time.time()
        for data in data_train:
            # We zero the gradients BEFORE the forward pass, instead of before the backward, to save some memory
            for optimizer in optimizers:
                optimizer.zero_grad()

            # We skip the remainder batch
            if data[0].shape[0] != opt.batch_size:
                continue

            # Prepare
            decoder.train()
            decoder.use_prior = False
            data = [d.to(opt.device) for d in data]

            # Forward
            losses, pred = decoder(data)
            # Sum the loss terms that form the objective; the Constraint_* and
            # Lag_Weight_* entries are logged below rather than summed here
            loss = sum([
                v for k, v in losses.items()
                if "Lag_Weight" not in k and "Constraint_" not in k
            ])

            # Log the various losses the models can return, and accuracy
            stats.train_loss.append(losses["NLL"].item())
            stats.train_kl.append(losses["KL"].item())
            stats.train_elbo.append(losses["NLL"].item() + losses["KL"].item())
            stats.train_min_rate.append(losses["Hinge"].item())
            stats.train_l2_loss.append(losses["L2"].item())
            stats.train_mmd.append(losses["MMD"].item())
            stats.train_acc.append(compute_accuracy(pred, data).item())
            for i in range(len(opt.constraint)):
                stats.constraints[i].append(
                    losses["Constraint_{}".format(i)].item())
                stats.lambs[i].append(losses["Lag_Weight_{}".format(i)].item())
            del data

            loss.backward()

            # Check for bad gradients
            nan = False
            if opt.grad_check:
                for n, p in decoder.named_parameters():
                    # Parameters that did not take part in the loss have no gradient
                    if p.grad is not None and torch.isnan(p.grad).any():
                        nan = True
                        print("{} contains NaN gradients!".format(n))
            if nan:
                break

            if opt.clip > 0.:
                clip_grad_norm_(decoder.parameters(),
                                opt.clip)  # Might prevent exploding gradients

            if opt.lagrangian:
                # Negating the multiplier gradients turns their update into
                # gradient ascent, so a single backward pass serves both the
                # minimization and the maximization step
                for group in optimizers[1].param_groups:
                    for p in group['params']:
                        p.grad = -1 * p.grad

            for optimizer in optimizers:
                optimizer.step()
        end = time.time()
        print("Train time: {}s".format(end - start))

        start = time.time()
        # We wrap the entire evaluation in no_grad to save memory
        with torch.no_grad():
            zs = []
            log_q_z_xs = []
            log_p_zs = []
            mus = []
            variances = []
            for data in data_eval:
                # Catch small batches
                if data[0].shape[0] != opt.batch_size:
                    continue

                # Prepare
                decoder.eval()
                data = [d.to(opt.device) for d in data]

                # Sample a number of log-likelihoods to obtain a low-variance estimate of the model perplexity
                # We do this with a single sample during training for speed; at test time we use more samples
                decoder.use_prior = True
                losses, pred = decoder(data)
                stats.val_loss.append(losses["NLL"].item())
                stats.val_l2_loss.append(losses["L2"].item())
                stats.val_acc.append(compute_accuracy(pred, data).item())
                stats.val_log_loss[0].append(losses["NLL"].item())

                if len(data) > 1:
                    stats.avg_len.append(
                        torch.mean(data[1].float()).item() - 1)
                else:
                    stats.avg_len.append(data[0].shape[1] - 1)

                # Also sample the perplexity for the reconstruction case (using the posterior)
                decoder.use_prior = False
                if opt.mi:
                    losses, pred, var, mu, z, _, log_q_z_x, log_p_z = decoder(
                        data, extensive=True)
                    zs.append(z)
                    log_q_z_xs.append(log_q_z_x)
                    log_p_zs.append(log_p_z)
                    mus.append(mu)
                    variances.append(var)
                else:
                    losses, pred = decoder(data)
                stats.val_rec_loss.append(losses["NLL"].item())
                stats.val_rec_kl.append(losses["KL"].item())
                stats.val_rec_elbo.append(losses["NLL"].item() +
                                          losses["KL"].item())
                stats.val_rec_min_rate.append(losses["Hinge"].item())
                stats.val_rec_l2_loss.append(losses["L2"].item())
                stats.val_rec_mmd.append(losses["MMD"].item())
                stats.val_rec_acc.append(compute_accuracy(pred, data).item())
                stats.val_rec_log_loss[0].append(losses["NLL"].item() +
                                                 losses["KL"].item())

            if opt.mi:
                # Stack the collected samples and parameters
                z = torch.cat(zs, 0)
                log_q_z_x = torch.cat(log_q_z_xs, 0)
                log_p_z = torch.cat(log_p_zs, 0)
                mu = torch.cat(mus, 0)
                var = torch.cat(variances, 0)
                avg_kl = torch.tensor(stats.val_rec_kl,
                                      dtype=torch.float,
                                      device=opt.device).mean()
                avg_h = log_q_z_x.mean()
                log_q_z = decoder.q_z_estimate(z, mu, var)
                stats.val_mi, stats.val_mkl = compute_mutual_information(
                    z, log_p_z, avg_h, avg_kl, opt.mi_method,
                    opt.mi_kde_method, log_q_z)
        end = time.time()
        print("Eval time: {}s".format(end - start))

        # Compute the per-word perplexity and its variance for this epoch from the sampled log losses and average lengths
        perplexity, variance = compute_perplexity(stats.val_log_loss,
                                                  stats.avg_len)
        rec_ppl, rec_var = compute_perplexity(stats.val_rec_log_loss,
                                              stats.avg_len)

        stats.val_ppl.append(perplexity)
        stats.val_rec_ppl.append(rec_ppl)
        stats.val_ppl_std.append(variance)
        stats.val_rec_ppl_std.append(rec_var)
        stats.kl_scale = decoder._scale

        # Print and log the statistics after every epoch
        # Note that the StatsHandler object automatically prepares the stats, so no more stats can be added
        vprint(stats, opt.verbosity, 0)
        stats.log_stats(writer)

        # We stop early when the model has not improved on the chosen criteria for a given number of epochs.
        stop = [0] * len(opt.criteria)
        i = 0
        # This is the default criterion; we will stop when the ELBO/LL no longer improves
        if 'posterior' in opt.criteria:
            if stats.val_rec_elbo > (prev_crit[i] - opt.min_imp) and epoch > 4:
                stop[i] = 1
            else:
                stop_ticks = 0
            if stats.val_rec_elbo < prev_crit[i]:
                try:
                    save_checkpoint(opt, decoder, optimizers, epoch)
                except InvalidPathError as e:
                    vprint(e, opt.verbosity, 0)
                    vprint("Cannot save model, continuing without saving...",
                           opt.verbosity, 0)
                prev_crit[i] = stats.val_rec_elbo
            i += 1

        # We early stop the model when an estimate of the log-likelihood based on prior samples no longer increases
        # This generally only makes sense when we have a learned prior
        if 'prior' in opt.criteria:
            if stats.val_loss > (prev_crit[i] - opt.min_imp) and epoch > 4:
                stop[i] = 1
            else:
                stop_ticks = 0
            if stats.val_loss < prev_crit[i]:
                try:
                    # For each non-default criterion we add a suffix to the model name
                    save_checkpoint(opt, decoder, optimizers, epoch, 'prior')
                except InvalidPathError as e:
                    vprint(e, opt.verbosity, 0)
                    vprint("Cannot save model, continuing without saving...",
                           opt.verbosity, 0)
                prev_crit[i] = stats.val_loss

        # So far, models can be saved based on the prior loss, the posterior loss, or both
        if 'prior' not in opt.criteria and 'posterior' not in opt.criteria:
            raise UnknownArgumentError(
                "No valid early stopping criteria found, please choose either/both [posterior, prior]"
            )

        # We only increase the stop ticks when every criterion failed to improve
        stop_ticks += int(np.all(np.array(stop)))

        # When we reach a user specified amount of stop ticks, we stop training
        if stop_ticks >= opt.stop_ticks:
            writer.close()
            vprint("Early stopping after {} epochs".format(epoch),
                   opt.verbosity, 0)
            early_stopping = True
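
The gradient sign flip in the Lagrangian branch of Example #5 implements a min-max update with ordinary minimizing optimizers: one backward pass, gradient descent on the model parameters, gradient ascent on the multipliers. A toy sketch of the same trick (the objective and all names here are illustrative):

import torch
from torch.optim import SGD

# Solve min_x max_lam x**2 + lam * (0.3 - x), i.e. minimize x**2 subject to x >= 0.3
x = torch.tensor(1.0, requires_grad=True)
lam = torch.tensor(0.0, requires_grad=True)
opt_min, opt_max = SGD([x], lr=0.1), SGD([lam], lr=0.1)

for _ in range(200):
    opt_min.zero_grad()
    opt_max.zero_grad()
    loss = x ** 2 + lam * (0.3 - x)
    loss.backward()           # a single backward pass serves both parameter groups
    lam.grad = -lam.grad      # flip the sign: ascend in lam instead of descending
    opt_min.step()
    opt_max.step()
    lam.data.clamp_(min=0.0)  # keep the multiplier nonnegative (a projection the
                              # training loop above may handle inside the model)

# x approaches the constrained optimum 0.3, with lam near 0.6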