Code example #1
File: main.py Project: zizai/deep-generative-lm
def main():
    opt, parser = parse_arguments()
    opt = predefined(opt)
    print_flags(opt)

    if not osp.isdir(opt.out_folder):
        os.makedirs(opt.out_folder)

    if opt.script == 'generative':
        if opt.mode == 'train':
            train(opt)
        elif opt.mode == 'test':
            test(opt)
        elif opt.mode == 'generate':
            generate_data(opt)
        elif opt.mode == 'novelty':
            novelty(opt)
        elif opt.mode == 'qualitative':
            qualitative(opt)
        else:
            raise UnknownArgumentError(
                "--mode not recognized, please choose: [train, test, generate, qualitative, novelty]."
            )
    elif opt.script == 'bayesopt':
        optimize_bayesian(opt, parser)
    elif opt.script == 'grid':
        run_grid(opt, parser)
    else:
        raise UnknownArgumentError(
            "--script not recognized, please choose: [generative, bayesopt, grid]."
        )
Code example #2
def get_model_path(opt, suffix=""):
    if opt.script == "generative":
        return osp.join(
            opt.out_folder,
            "{}_checkpoint_{}{}.pt".format(opt.model, opt.save_suffix, suffix))
    else:
        raise UnknownArgumentError(
            "--script not recognized, please choose: [generative].")
Code example #3
File: bayesopt.py Project: zizai/deep-generative-lm
def get_parameters(opt):
    """Add desired parameter lists and initial search space to this function."""
    Y_init = None
    parameters = list()
    X_init = list()
    if "example" in opt.bayes_mode:
        # We perform Bayesian search over four parameters of the decoder (RNNLM) and specify some initial points
        # to test. Alternatively, one can specify no initial points and choose a 'grid' (e.g. latin) option
        # in GPyOpt to randomly select some starting points. See the GPyOpt docs for more information.
        parameters = [{
            'name': 'layers',
            'type': 'discrete',
            'domain': (1, 2)
        }, {
            'name': 'h_dim',
            'type': 'discrete',
            'domain': list(range(128, 513, 32))
        }, {
            'name': 'p',
            'type': 'continuous',
            'domain': (0., 0.6)
        }, {
            'name': 'cut_off',
            'type': 'continuous',
            'domain': (1, 5)
        }]
        X_init.append([1., 2., 1., 2.])  # Number of layers
        X_init.append([256., 128., 512., 256.])  # Number of hidden units
        X_init.append([0.2, 0.1, 0.3, 0.4])  # Dropout
        X_init.append(
            [1., 2., 4., 3.]
        )  # Number of standard deviations above mean sentence length to truncate

        # for param in X_init:
        #     shuffle(param)
        X_init = np.array(X_init).T
        print(X_init)
    else:
        raise UnknownArgumentError(
            "Uknown bayes mode: {}. Please choose another or specify this one yourself."
        )

    if opt.bayes_load:
        # Load previously stored results as initialization for Bayesian search
        with open(osp.join(opt.out_folder, "bayesian",
                           "bayesian_X_{}.pickle".format(opt.bayes_mode)),
                  'rb') as f:
            X_init = pickle.load(f)
        with open(osp.join(opt.out_folder, "bayesian",
                           "bayesian_Y_{}.pickle".format(opt.bayes_mode)),
                  'rb') as f:
            Y_init = pickle.load(f)

    tuning_list = [d['name'] for d in parameters]
    return parameters, tuning_list, X_init, Y_init
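A sketch of how these return values could be handed to GPyOpt, which the comment above references. The `objective` and `evaluate_settings` functions here are hypothetical stand-ins for whatever training-and-evaluation routine is being tuned; the GPyOpt calls themselves follow its documented API:

import GPyOpt
import numpy as np

def objective(X):
    # X is a [batch, num_params] array of candidate settings; return a
    # [batch, 1] array of validation losses. `evaluate_settings` is a
    # hypothetical stand-in for an actual training run.
    return np.array([[evaluate_settings(x)] for x in X])

parameters, tuning_list, X_init, Y_init = get_parameters(opt)
bo = GPyOpt.methods.BayesianOptimization(
    f=objective, domain=parameters, X=X_init, Y=Y_init)
bo.run_optimization(max_iter=20)
print(dict(zip(tuning_list, bo.x_opt)))  # best settings found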
Code example #4
    def constraint(self, vals):
        if not isinstance(vals, list):
            if isinstance(vals, str):
                vals = [vals]
            else:
                raise InvalidArgumentError('constraint should be a list or str')
        for val in vals:
            if val not in ['mdr', 'mmd']:
                raise UnknownArgumentError(
                    'constraint {} unknown. Please choose [mdr, mmd].'.format(val))

        self._constraint = vals
Code example #5
def get_true_data_path(opt, mode):
    """Return the paths containing true data text and indices given user settings and a mode."""
    if mode == 'train':
        indices = osp.join(opt.data_folder, opt.train_file)
    elif mode == 'valid':
        indices = osp.join(opt.data_folder, opt.val_file)
    elif mode == 'test':
        indices = osp.join(opt.data_folder, opt.test_file)
    else:
        raise UnknownArgumentError(
            "mode: {} not recognized. Please choose [train, valid, test]".
            format(mode))

    text = indices.split(".")[0] + ".txt"
    return text, indices
Code example #6
def set_ptb_folders(opt):
    """This function sets default paths given a ptb_type."""
    if opt.ptb_type == "dyer":
        opt.data_folder = osp.join(toplevel_path, "dataset/penn_treebank_dyer")
        opt.out_folder = osp.join(toplevel_path, "out/penn_treebank_dyer")
        opt.v_dim = 25643
    elif opt.ptb_type == "mik":
        opt.data_folder = osp.join(toplevel_path, "dataset/penn_treebank")
        opt.out_folder = osp.join(toplevel_path, "out/penn_treebank")
        opt.v_dim = 10002
    else:
        raise UnknownArgumentError(
            "Unknown ptb_type {}. Please choose [mik, dyer]".format(
                opt.ptb_type))
    return opt
Code example #7
def get_parameters(opt):
    """Add desired parameter lists and initial search space to this function."""
    parameters = list()
    X_init = list()
    if "mdr_example" in opt.bayes_mode:
        # Run grid search over a series of target rates
        parameters.append("min_rate")
        X_init.append([5., 10., 15., 20., 25., 30., 35., 40., 45., 50.])
    else:
        raise UnknownArgumentError(
            "Uknown bayes mode: {}. Please choose another or specify this one yourself.")

    X_init = np.array(X_init).T
    print(X_init)

    return parameters, X_init
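A sketch of how a grid runner might consume these values. The loop is illustrative only; the real `run_grid` signature is not shown in these examples, and setting options via `setattr` is an assumption:

parameters, X_init = get_parameters(opt)
for row in X_init:  # each row of the transposed array is one grid point
    for name, value in zip(parameters, row):
        setattr(opt, name, value)  # e.g. opt.min_rate = 5.0 (hypothetical)
    # ... train and evaluate with these settings ...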
Code example #8
def initialize_dataloader(opt, word_to_idx, collate_fn):
    """Initializes the dataloader with the given user settings and collate function."""
    if opt.mode in ['train', 'qualitative']:
        data_train = DataLoader(TextDataUnPadded(
            get_true_data_path(opt, "train")[1], opt.seq_len,
            word_to_idx[opt.pad_token]),
                                collate_fn=collate_fn,
                                batch_size=opt.batch_size,
                                shuffle=True,
                                num_workers=4,
                                pin_memory=True)
        data_eval = DataLoader(TextDataUnPadded(
            get_true_data_path(opt, "valid")[1], 0,
            word_to_idx[opt.pad_token]),
                               collate_fn=collate_fn,
                               batch_size=opt.batch_size,
                               shuffle=True,
                               num_workers=4,
                               pin_memory=True)
        return data_train, data_eval
    elif opt.mode == 'test':
        # We use the PTB test set only sparingly
        if opt.use_test_set:
            return DataLoader(TextDataUnPadded(
                get_true_data_path(opt, "test")[1], 0,
                word_to_idx[opt.pad_token]),
                              collate_fn=collate_fn,
                              batch_size=opt.batch_size,
                              shuffle=False,
                              num_workers=4,
                              pin_memory=True)
        else:
            return DataLoader(TextDataUnPadded(
                get_true_data_path(opt, "valid")[1], 0,
                word_to_idx[opt.pad_token]),
                              collate_fn=collate_fn,
                              batch_size=opt.batch_size,
                              shuffle=False,
                              num_workers=4,
                              pin_memory=True)

    else:
        raise UnknownArgumentError(
            "Cannot load data for --mode={}. Please choose [train, test, qualitative]"
            .format(opt.mode))
Code example #9
def kde_gauss(samples, method):
    """KDE estimation with a Gaussian kernel for a batch of samples. Returns the log probability of each sample."""
    if method == 'scipy':
        # [num, z_dim] -> [z_dim, num]
        samples_cpu = samples.cpu().numpy().transpose()
        kde = gaussian_kde(samples_cpu)
        # [num] with p(samples) under kernels
        return samples.new_tensor(kde.logpdf(samples_cpu))
    elif method == 'pytorch':
        # [num]
        return torch.log(
            samples.new_tensor(
                compute_kernel(samples.cpu(), samples.cpu()).mean(dim=1) +
                1e-80))
    else:
        raise UnknownArgumentError(
            'KDE method {} is unknown. Please choose [scipy, pytorch]'.format(
                method))
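The 'pytorch' branch relies on a `compute_kernel` helper that is not shown in these examples. A minimal sketch of such a Gaussian (RBF) kernel is below; the bandwidth heuristic is an assumption, not necessarily the project's choice:

import torch

def compute_kernel(x, y):
    """Pairwise Gaussian kernel values between two sample batches."""
    # x: [n, z_dim], y: [m, z_dim] -> [n, m]
    sq_dist = torch.cdist(x, y) ** 2
    bandwidth = float(x.shape[1])  # dimension-scaled bandwidth (assumption)
    return torch.exp(-sq_dist / (2. * bandwidth))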
Code example #10
def compute_mutual_information(samples,
                               log_p_z,
                               avg_H,
                               avg_KL,
                               method,
                               kde_method,
                               log_q_z=None):
    """Computes the mutual information given a batch of samples and their LL under the sampling distribution.

    If log_q_z is not provided, this method uses kernel density estimation on the provided samples to compute this
    quantity. Otherwise, we will use the provided log_q_z to estimate the MI, either through Hoffman's or Zhao's method.

    Args:
        samples(torch.FloatTensor): [N, z_dim] dimensional tensor of samples from q(z|x).
        log_p_z(torch.FloatTensor): [N] dimensional tensor of log-probabilities of the samples under p(z).
        avg_H(torch.FloatTensor): [] dimensional tensor containing the average log q(z|x), i.e. minus the conditional entropy.
        avg_KL(torch.FloatTensor): [] dimensional tensor containing the average KL[q(z|x)||p(z)].
        method: method for estimating the mutual information.
        kde_method: method for obtaining the kde likelihood estimates of the samples under q(z).
        log_q_z: [N] dimensional tensor of log-probabilities of the samples under q(z), or None.
    Returns:
        mi_estimate(float): estimated mutual information between X and Z given N samples from q(z|x).
        marg_KL(float): estimated KL[q(z)||p(z)] given N samples from q(z|x).
    """
    if log_q_z is None:
        log_q_z = kde_gauss(samples, kde_method)

    marg_KL = (log_q_z - log_p_z).mean()

    if method == 'zhao':  # https://arxiv.org/pdf/1806.06514.pdf (Lagrangian VAE)
        mi_estimate = (avg_H - log_q_z.mean()).item()
    elif method == 'hoffman':  # http://approximateinference.org/accepted/HoffmanJohnson2016.pdf (ELBO surgery)
        mi_estimate = (avg_KL - marg_KL).item()
    else:
        raise UnknownArgumentError(
            'MI method {} is unknown. Please choose [zhao, hoffman]'.format(
                method))

    return mi_estimate, marg_KL.item()
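Both branches implement standard identities: Hoffman's ELBO surgery decomposes the average KL, while Zhao's estimator works directly with log-densities. In LaTeX:

\mathbb{E}_x\,\mathrm{KL}\!\left[q(z|x)\,\|\,p(z)\right] = I_q(x;z) + \mathrm{KL}\!\left[q(z)\,\|\,p(z)\right]
\qquad\text{(Hoffman: } \texttt{avg\_KL} - \texttt{marg\_KL}\text{)}

I_q(x;z) = \mathbb{E}\!\left[\log q(z|x)\right] - \mathbb{E}\!\left[\log q(z)\right]
\qquad\text{(Zhao: } \texttt{avg\_H} - \texttt{log\_q\_z.mean()}\text{)}

Note that this only matches the code because `avg_H` holds the mean of log q(z|x), not the (positive) entropy.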
Code example #11
def initialize_model(opt, word_to_idx):
    """Initializes model with the given user settings."""
    if opt.model == "bowman":
        decoder = BowmanDecoder(
            opt.device, opt.seq_len, opt.kl_step, opt.word_p, opt.enc_word_p,
            opt.p, opt.enc_p, opt.drop_type, opt.min_rate,
            word_to_idx[opt.unk_token], opt.css, opt.sparse, opt.N,
            opt.rnn_type, opt.tie_in_out, opt.beta, opt.lamb, opt.mmd,
            opt.ann_mode, opt.rate_mode, opt.posterior, opt.hinge_weight,
            opt.k, opt.ann_word, opt.word_step, opt.v_dim, opt.x_dim,
            opt.h_dim, opt.z_dim, opt.s_dim, opt.layers, opt.enc_h_dim,
            opt.enc_layers, opt.lagrangian, opt.constraint, opt.max_mmd,
            opt.max_elbo, opt.alpha).to(opt.device)
    elif opt.model == "flowbowman":
        decoder = FlowBowmanDecoder(
            opt.device, opt.seq_len, opt.kl_step, opt.word_p, opt.enc_word_p,
            opt.p, opt.enc_p, opt.drop_type, opt.min_rate,
            word_to_idx[opt.unk_token], opt.css, opt.sparse, opt.N,
            opt.rnn_type, opt.tie_in_out, opt.beta, opt.lamb, opt.mmd,
            opt.ann_mode, opt.rate_mode, opt.posterior, opt.hinge_weight,
            opt.k, opt.ann_word, opt.word_step, opt.flow, opt.flow_depth,
            opt.h_depth, opt.prior, opt.num_weights, opt.mean_len, opt.std_len,
            opt.v_dim, opt.x_dim, opt.h_dim, opt.z_dim, opt.s_dim, opt.layers,
            opt.enc_h_dim, opt.enc_layers, opt.c_dim, opt.lagrangian,
            opt.constraint, opt.max_mmd, opt.max_elbo, opt.alpha).to(
                opt.device)
    elif opt.model == "deterministic":
        decoder = DeterministicDecoder(
            opt.device, opt.seq_len, opt.word_p, opt.p, opt.drop_type,
            word_to_idx[opt.unk_token], opt.css, opt.sparse, opt.N,
            opt.rnn_type, opt.tie_in_out, opt.v_dim, opt.x_dim, opt.h_dim,
            opt.s_dim, opt.layers).to(opt.device)
    else:
        raise UnknownArgumentError(
            "--model not recognized, please choose: [deterministic, bowman, flowbowman]."
        )

    return decoder
Code example #12
def file_len(fname):
    """Count the number of lines in a file."""
    i = -1  # Ensures empty files return 0 instead of raising NameError
    with open(fname, 'r') as f:
        for i, l in enumerate(f):
            pass
    return i + 1


if __name__ == "__main__":
    opt = parse_arguments()
    opt = predefined(opt)
    print_flags(opt)

    # Set the script flag so this entry point can be used without specifying it
    opt.script = "generative"

    if not osp.isdir(opt.out_folder):
        os.makedirs(opt.out_folder)

    if opt.mode == 'train':
        train(opt)
    elif opt.mode == 'test':
        test(opt)
    elif opt.mode == 'generate':
        generate_data(opt)
    else:
        raise UnknownArgumentError(
            "--mode not recognized, please choose: [train, test, generate].")
Code example #13
 def posterior(self, val):
     if val not in ["gaussian", "vmf"]:
         raise UnknownArgumentError(
             "Unknown posterior: {}. Please choose [gaussian, vmf].".format(
                 val))
     self._posterior = val
Code example #14
 def ann_mode(self, value):
     if value not in ["linear", "sfb"]:
         raise UnknownArgumentError(
             "Unknown ann_mode {}. Please choose [linear, sfb]".format(
                 value))
     self._ann_mode = value
Code example #15
 def rnn_type(self, value):
     if value not in ["LSTM", "GRU"]:
         raise UnknownArgumentError(
             "Unknown rnn_type {}. Please choose [GRU, LSTM]".format(value))
     self._rnn_type = value
Code example #16
 def drop_type(self, value):
     if value not in ["varied", "shared", "recurrent"]:
         raise UnknownArgumentError(
             "Unknown drop_type: {}. Please choose [varied, shared, recurrent]"
             .format(value))
     self._drop_type = value
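Examples #13 through #16 are fragments of validating property setters. A minimal sketch of the enclosing pattern (the class name and default value are illustrative):

class Options:
    def __init__(self):
        self._rnn_type = "GRU"

    @property
    def rnn_type(self):
        return self._rnn_type

    @rnn_type.setter
    def rnn_type(self, value):
        # Validate on assignment, exactly as in the setters above
        if value not in ["LSTM", "GRU"]:
            raise UnknownArgumentError(
                "Unknown rnn_type {}. Please choose [GRU, LSTM]".format(value))
        self._rnn_type = value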
Code example #17
File: generative.py Project: yyht/deep-generative-lm
def train(opt):
    """Script that trains a generative model of language given various user settings."""
    # Try to load options when we resume
    if opt.resume:
        try:
            opt = load_options(opt)
        except InvalidPathError as e:
            warn("{}\n Starting from scratch...".format(e))
            opt.resume = 0
            epoch = 0
        except Error as e:
            warn(
                "{}\n Make sure all preset arguments coincide with the model you are loading."
                .format(e))
    else:
        epoch = 0

    # Set device so script works on both GPU and CPU
    seed(opt)
    opt.device = torch.device(
        "cuda:{}".format(opt.local_rank) if opt.local_rank >= 0 else "cpu")
    vprint("Using device: {}".format(opt.device), opt.verbosity, 1)

    word_to_idx, idx_to_word = load_word_index_maps(opt)

    # Here we construct all parts of the training ensemble; the model, dataloaders and optimizer
    data_train, data_eval = initialize_dataloader(opt, word_to_idx,
                                                  sort_pad_collate)
    opt.N = (len(data_train) - 1) * opt.batch_size
    decoder = initialize_model(opt, word_to_idx)
    optimizers = []
    if opt.sparse:
        sparse_parameters = [
            p[1] for p in filter(lambda p: p[0] == "emb.weight",
                                 decoder.named_parameters())
        ]
        parameters = [
            p[1] for p in filter(
                lambda p: p[1].requires_grad and p[0] != "emb.weight",
                decoder.named_parameters())
        ]
        optimizers.append(Adam(parameters, opt.lr))
        optimizers.append(SparseAdam(sparse_parameters, opt.lr))
    elif opt.lagrangian:
        lag_parameters = [
            p[1] for p in filter(lambda p: p[0] == "lag_weight",
                                 decoder.named_parameters())
        ]
        parameters = [
            p[1] for p in filter(
                lambda p: p[1].requires_grad and p[0] != "lag_weight.weight",
                decoder.named_parameters())
        ]
        optimizers.append(Adam(parameters, opt.lr))
        optimizers.append(RMSprop(lag_parameters, opt.lr))
    else:
        parameters = filter(lambda p: p.requires_grad, decoder.parameters())
        optimizers.append(Adam(parameters, opt.lr))

    # Load from checkpoint
    if opt.resume:
        decoder, optimizers, epoch = load_checkpoint(opt, decoder, optimizers)

    # The SummaryWriter will log certain values for automatic visualization
    writer = SummaryWriter(
        osp.join(opt.out_folder, opt.model, opt.save_suffix, 'train'))

    # The StatsHandler object stores important statistics during training and provides printing and logging utilities
    stats = StatsHandler(opt)

    # We early stop the network based on user-specified criteria
    early_stopping = False
    stop_ticks = 0
    prev_crit = [np.inf] * len(opt.criteria)

    while not early_stopping:
        # We reset the stats object to collect fresh stats for every epoch
        stats.reset()
        epoch += 1
        stats.epoch = epoch

        start = time.time()
        for data in data_train:
            # We zero the gradients BEFORE the forward pass, instead of before the backward, to save some memory
            for optimizer in optimizers:
                optimizer.zero_grad()

            # We skip the remainder batch
            if data[0].shape[0] != opt.batch_size:
                continue

            # Prepare
            decoder.train()
            decoder.use_prior = False
            data = [d.to(opt.device) for d in data]

            # Forward
            losses, pred = decoder(data)
            loss = sum([
                v for k, v in losses.items()
                if "Lag_Weight" not in k and "Constraint_" not in k
            ])

            # Log the various losses the models can return, and accuracy
            stats.train_loss.append(losses["NLL"].item())
            stats.train_kl.append(losses["KL"].item())
            stats.train_elbo.append(losses["NLL"].item() + losses["KL"].item())
            stats.train_min_rate.append(losses["Hinge"].item())
            stats.train_l2_loss.append(losses["L2"].item())
            stats.train_mmd.append(losses["MMD"].item())
            stats.train_acc.append(compute_accuracy(pred, data).item())
            for i in range(len(opt.constraint)):
                stats.constraints[i].append(
                    losses["Constraint_{}".format(i)].item())
                stats.lambs[i].append(losses["Lag_Weight_{}".format(i)].item())
            del data

            loss.backward()

            # Check for bad gradients
            nan = False
            if opt.grad_check:
                for n, p in decoder.named_parameters():
                    if torch.isnan(p.grad).any():
                        nan = True
                        print("{} Contains nan gradients!".format(n))
            if nan:
                break

            if opt.clip > 0.:
                clip_grad_norm_(decoder.parameters(),
                                opt.clip)  # Might prevent exploding gradients

            if opt.lagrangian:
                # This is equivalent to flipping the sign on the loss and computing its backward
                # So it prevents computation of the backward twice, once for max and once for min
                for group in optimizers[1].param_groups:
                    for p in group['params']:
                        p.grad = -1 * p.grad

            for optimizer in optimizers:
                optimizer.step()
        end = time.time()
        print("Train time: {}s".format(end - start))

        start = time.time()
        # We wrap the entire evaluation in no_grad to save memory
        with torch.no_grad():
            zs = []
            log_q_z_xs = []
            log_p_zs = []
            mus = []
            vars = []
            for data in data_eval:
                # Catch small batches
                if data[0].shape[0] != opt.batch_size:
                    continue

                # Prepare
                decoder.eval()
                data = [d.to(opt.device) for d in data]

                # Sample a number of log-likelihoods to obtain a low-variance estimate of the model perplexity
                # We do this with a single sample during training for speed; at test time we will use more samples
                decoder.use_prior = True
                losses, pred = decoder(data)
                stats.val_loss.append(losses["NLL"].item())
                stats.val_l2_loss.append(losses["L2"].item())
                stats.val_acc.append(compute_accuracy(pred, data).item())
                stats.val_log_loss[0].append(losses["NLL"].item())

                if len(data) > 1:
                    stats.avg_len.append(
                        torch.mean(data[1].float()).item() - 1)
                else:
                    stats.avg_len.append(data[0].shape[1] - 1)

                # Also sample the perplexity for the reconstruction case (using the posterior)
                decoder.use_prior = False
                if opt.mi:
                    losses, pred, var, mu, z, _, log_q_z_x, log_p_z = decoder(
                        data, extensive=True)
                    zs.append(z)
                    log_q_z_xs.append(log_q_z_x)
                    log_p_zs.append(log_p_z)
                    mus.append(mu)
                    vars.append(var)
                else:
                    losses, pred = decoder(data)
                stats.val_rec_loss.append(losses["NLL"].item())
                stats.val_rec_kl.append(losses["KL"].item())
                stats.val_rec_elbo.append(losses["NLL"].item() +
                                          losses["KL"].item())
                stats.val_rec_min_rate.append(losses["Hinge"].item())
                stats.val_rec_l2_loss.append(losses["L2"].item())
                stats.val_rec_mmd.append(losses["MMD"].item())
                stats.val_rec_acc.append(compute_accuracy(pred, data).item())
                stats.val_rec_log_loss[0].append(losses["NLL"].item() +
                                                 losses["KL"].item())

            if opt.mi:
                # Stack the collected samples and parameters
                z = torch.cat(zs, 0)
                log_q_z_x = torch.cat(log_q_z_xs, 0)
                log_p_z = torch.cat(log_p_zs, 0)
                mu = torch.cat(mus, 0)
                var = torch.cat(vars, 0)
                avg_kl = torch.tensor(stats.val_rec_kl,
                                      dtype=torch.float,
                                      device=opt.device).mean()
                avg_h = log_q_z_x.mean()
                log_q_z = decoder.q_z_estimate(z, mu, var)
                stats.val_mi, stats.val_mkl = compute_mutual_information(
                    z, log_p_z, avg_h, avg_kl, opt.mi_method,
                    opt.mi_kde_method, log_q_z)
        end = time.time()
        print("Eval time: {}s".format(end - start))

        # Compute the perplexity and its variance for this batch, sampled N times
        perplexity, variance = compute_perplexity(stats.val_log_loss,
                                                  stats.avg_len)
        rec_ppl, rec_var = compute_perplexity(stats.val_rec_log_loss,
                                              stats.avg_len)

        stats.val_ppl.append(perplexity)
        stats.val_rec_ppl.append(rec_ppl)
        stats.val_ppl_std.append(variance)
        stats.val_rec_ppl_std.append(rec_var)
        stats.kl_scale = decoder._scale

        # Print and log the statistics after every epoch
        # Note that the StatsHandler object automatically prepares the stats, so no more stats can be added
        vprint(stats, opt.verbosity, 0)
        stats.log_stats(writer)

        # We early stop when the model has not improved certain criteria for a given number of epochs.
        stop = [0] * len(opt.criteria)
        i = 0
        # This is the default criterion; we will stop when the ELBO/LL no longer improves
        if 'posterior' in opt.criteria:
            if stats.val_rec_elbo > (prev_crit[i] - opt.min_imp) and epoch > 4:
                stop[i] = 1
            else:
                stop_ticks = 0
            if stats.val_rec_elbo < prev_crit[i]:
                try:
                    save_checkpoint(opt, decoder, optimizers, epoch)
                except InvalidPathError as e:
                    vprint(e, opt.verbosity, 0)
                    vprint("Cannot save model, continuing without saving...",
                           opt.verbosity, 0)
                prev_crit[i] = stats.val_rec_elbo
            i += 1

        # We early stop the model when an estimate of the log-likelihood based on prior samples no longer increases
        # This generally only makes sense when we have a learned prior
        if 'prior' in opt.criteria:
            if stats.val_loss > (prev_crit[i] - opt.min_imp) and epoch > 4:
                stop[i] = 1
            else:
                stop_ticks = 0
            if stats.val_loss < prev_crit[i]:
                try:
                    # For each non-standard criterion we add a suffix to the model name
                    save_checkpoint(opt, decoder, optimizers, epoch, 'prior')
                except InvalidPathError as e:
                    vprint(e, opt.verbosity, 0)
                    vprint("Cannot save model, continuing without saving...",
                           opt.verbosity, 0)
                prev_crit[i] = stats.val_loss

        # So far we can save models based on posterior loss, prior loss, or both
        if 'prior' not in opt.criteria and 'posterior' not in opt.criteria:
            raise UnknownArgumentError(
                "No valid early stopping criteria found, please choose either/both [posterior, prior]"
            )

        # We only increase the stop ticks when none of the criteria improved
        stop_ticks += int(np.all(np.array(stop)))

        # When we reach a user specified amount of stop ticks, we stop training
        if stop_ticks >= opt.stop_ticks:
            writer.close()
            vprint("Early stopping after {} epochs".format(epoch),
                   opt.verbosity, 0)
            early_stopping = True
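The early-stopping scheme at the end of `train` only advances `stop_ticks` when every criterion fails to improve in the same epoch, while any single improvement resets the counter. A distilled sketch of that logic (variable names are illustrative):

import numpy as np

stop = [0] * len(current_values)
for i, value in enumerate(current_values):
    if value > prev_crit[i] - min_imp and epoch > 4:
        stop[i] = 1          # this criterion stalled
    else:
        stop_ticks = 0       # any improvement resets the counter
    prev_crit[i] = min(prev_crit[i], value)

stop_ticks += int(np.all(np.array(stop)))  # all criteria stalled together
early_stopping = stop_ticks >= max_stop_ticks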
Code example #18
    def __init__(self, device, seq_len, kl_step, word_p, word_p_enc,
                 parameter_p, encoder_p, drop_type, min_rate, unk_index, css,
                 sparse, N, rnn_type, tie_in_out, beta, lamb, mmd, ann_mode,
                 rate_mode, posterior, hinge_weight, k, ann_word, word_step,
                 flow, flow_depth, hidden_depth, prior, num_weights, mean_len,
                 std_len, v_dim, x_dim, h_dim, z_dim, s_dim, l_dim, h_dim_enc,
                 l_dim_enc, c_dim, lagrangian, constraint, max_mmd, max_elbo,
                 alpha):
        super(FlowBowmanDecoder,
              self).__init__(device, seq_len, kl_step, word_p, word_p_enc,
                             parameter_p, encoder_p, drop_type, min_rate,
                             unk_index, css, sparse, N, rnn_type, tie_in_out,
                             beta, lamb, mmd, ann_mode, rate_mode, posterior,
                             hinge_weight, k, ann_word, word_step, v_dim,
                             x_dim, h_dim, z_dim, s_dim, l_dim, h_dim_enc,
                             l_dim_enc, lagrangian, constraint, max_mmd,
                             max_elbo, alpha)

        if flow == "diag":
            self.flow = Diag()
        elif flow == "iaf":
            self.flow = IAF(c_dim, z_dim, flow_depth, hidden_depth)
        elif flow == "vpiaf":
            self.flow = IAF(c_dim,
                            z_dim,
                            flow_depth,
                            hidden_depth,
                            scale=False)
        elif flow == 'planar':
            self.flow = Planar(h_dim_enc * 2, z_dim, hidden_depth)
        else:
            raise UnknownArgumentError(
                "Flow type not recognized: {}. Please choose [diag, iaf, vpiaf, planar]"
                .format(flow))

        if prior in ["mog", "vamp", "weak"]:
            self.prior = prior
        else:
            raise NotImplementedError(
                "No implementation for prior: {}".format(prior))

        if self.prior == "mog":
            # Create learnable mixture parameters with fixed weights and initialize
            mixture_weights = torch.Tensor(num_weights).fill_(1. / num_weights)
            self.register_buffer("mixture_weights", mixture_weights)
            self.mixture_mu = Parameter(torch.Tensor(num_weights, self.z_dim))
            if self.posterior == "vmf":
                self.mixture_var = Parameter(torch.Tensor(num_weights, 1))
            else:
                self.mixture_var = Parameter(
                    torch.Tensor(num_weights, self.z_dim))
            xavier_normal_(self.mixture_var)
            xavier_normal_(self.mixture_mu)
        elif self.prior == "vamp":
            # Create learnable pseudoinputs with fixed weights and initialize
            pseudo_weights = torch.Tensor(num_weights).fill_(1. / num_weights)
            self.register_buffer("pseudo_weights", pseudo_weights)
            self.pseudo_inputs = Parameter(
                torch.Tensor(num_weights, seq_len, x_dim))
            xavier_normal_(self.pseudo_inputs)

            # Sample lengths for the pseudo inputs based on database statistics
            pre_lengths = Normal(mean_len,
                                 std_len).sample(torch.Size([num_weights]))
            lengths = torch.clamp(torch.round(pre_lengths), 1, seq_len).long()
            lengths = torch.sort(lengths, descending=True)[0]
            self.register_buffer("pseudo_lengths", lengths)

        # Linear map from the encoder state to the context vector consumed by the IAF flows
        if 'iaf' in flow:
            self.h_to_context = nn.Linear(h_dim_enc * l_dim_enc * 2, c_dim)
        else:
            self.h_to_context = nn.Sequential()