Example #1
def interpolate_sentences(num=10):
    # NOTE: vocab is built from the dataset splits
    train_iter, test_iter, valid_iter, vocab = get_gyafc(conf)

    ckpt = torch.load(save_path)
    vae, vae_trainer = create_vae(conf, vocab)
    vae.load_state_dict(ckpt['vae_dict'])
    vae.eval()
    del ckpt

    z1 = on_cuda(torch.randn([1, conf.n_z]))
    # z2 = on_cuda(torch.randn([1, conf.n_z]))
    z2 = z1 + on_cuda(0.3 * torch.ones(z1.size()))

    int_z = torch.lerp(z1, z2,
                       on_cuda(torch.linspace(0.0, 1.0, num).unsqueeze(1)))
    # zs to strings
    for i in range(int_z.size()[0]):
        z = int_z[i, :].unsqueeze(0)
        h_0 = on_cuda(torch.zeros(2 * conf.n_layers_E, 1, conf.n_hidden_G))
        c_0 = on_cuda(torch.zeros(2 * conf.n_layers_E, 1, conf.n_hidden_G))
        G_hidden = (h_0, c_0)
        G_inp = torch.LongTensor(1, 1).fill_(vocab.stoi[conf.start_token])
        G_inp = on_cuda(G_inp)
        string = conf.start_token + ' '
        while G_inp[0][0].item() != vocab.stoi[conf.end_token]:
            with torch.autograd.no_grad():
                logit, G_hidden, _ = vae(None, G_inp, z, G_hidden)
            probs = F.softmax(logit[0], dim=1)
            G_inp = torch.multinomial(probs, 1)
            string += (vocab.itos[G_inp[0][0].item()] + ' ')
        print('----------------------------')
        print(string.encode('utf-8'))
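For reference, torch.lerp(z1, z2, w) broadcasts the [num, 1] weight column against the [1, n_z] endpoints and yields a [num, n_z] batch of evenly spaced latents. A minimal, self-contained sketch of that broadcasting (the sizes below are illustrative only, not taken from conf):

import torch

n_z, num = 32, 10                                 # illustrative sizes only
z1 = torch.randn(1, n_z)                          # interpolation start point
z2 = z1 + 0.3                                     # interpolation end point
w = torch.linspace(0.0, 1.0, num).unsqueeze(1)    # [num, 1] weights

int_z = torch.lerp(z1, z2, w)                     # broadcasts to [num, n_z]
assert int_z.shape == (num, n_z)
# first row equals z1, last row equals z2, rows in between are evenly spaced
assert torch.allclose(int_z[0], z1[0]) and torch.allclose(int_z[-1], z2[0])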
Example #2
def sampling_around_existing_sentence(s1, num=10):
    # NOTE: vocab is built from the dataset splits
    train_iter, test_iter, valid_iter, vocab = get_gyafc(conf)

    ckpt = torch.load(save_path)
    vae, vae_trainer = create_vae(conf, vocab)
    vae.load_state_dict(ckpt['vae_dict'])
    vae.eval()
    del ckpt

    # string to tensor
    s1_tensor = str_to_tensor(s1, vocab, conf)
    s1_tensor = on_cuda(s1_tensor.unsqueeze(0))

    mu, logvar = vae.encode(s1_tensor)
    # logvar is log(sigma^2), so the diagonal scale (std) is exp(0.5 * logvar)
    mvn = MultivariateNormal(mu, scale_tril=torch.diag(torch.exp(0.5 * logvar[0])))

    for i in range(num):
        z = mvn.sample()
        h_0 = on_cuda(torch.zeros(2 * conf.n_layers_E, 1, conf.n_hidden_G))
        c_0 = on_cuda(torch.zeros(2 * conf.n_layers_E, 1, conf.n_hidden_G))
        G_hidden = (h_0, c_0)
        G_inp = torch.LongTensor(1, 1).fill_(vocab.stoi[conf.start_token])
        G_inp = on_cuda(G_inp)
        string = conf.start_token + ' '
        while G_inp[0][0].item() != vocab.stoi[conf.end_token]:
            with torch.autograd.no_grad():
                logit, G_hidden, _ = vae(None, G_inp, z, G_hidden)
            probs = F.softmax(logit[0], dim=1)
            G_inp = torch.multinomial(probs, 1)
            string += (vocab.itos[G_inp[0][0].item()] + ' ')
        print('----------------------------')
        print(string.encode('utf-8'))
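Because the posterior here is diagonal, sampling from the MultivariateNormal with a diagonal scale_tril is equivalent to the usual reparameterization mu + std * eps. A minimal sketch of that equivalence (the size n_z = 32 is illustrative; mu and logvar are random stand-ins for the encoder outputs above):

import torch
from torch.distributions import MultivariateNormal

n_z = 32
mu = torch.randn(1, n_z)                # stand-in for the encoder mean
logvar = torch.randn(1, n_z)            # stand-in for the encoder log-variance
std = torch.exp(0.5 * logvar[0])        # per-dimension standard deviation

# sampling from a diagonal MultivariateNormal ...
mvn = MultivariateNormal(mu, scale_tril=torch.diag(std))
z_mvn = mvn.sample()                    # shape [1, n_z]

# ... draws from the same distribution as the reparameterization mu + std * eps
eps = torch.randn(1, n_z)
z_reparam = mu + std * eps              # shape [1, n_z]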
Example #3
def generate_sentences(n_examples):
    # NOTE: vocab is built from the dataset splits
    train_iter, test_iter, valid_iter, vocab = get_gyafc(conf)

    ckpt = torch.load(save_path)
    vae, vae_trainer = create_vae(conf, vocab)
    vae.load_state_dict(ckpt['vae_dict'])
    vae.eval()
    del ckpt

    for i in range(n_examples):
        z = on_cuda(torch.randn([1, conf.n_z]))
        h_0 = on_cuda(torch.zeros(2 * conf.n_layers_E, 1, conf.n_hidden_G))
        c_0 = on_cuda(torch.zeros(2 * conf.n_layers_E, 1, conf.n_hidden_G))
        G_hidden = (h_0, c_0)
        # 2 is the index of start token in vocab stoi
        G_inp = torch.LongTensor(1, 1).fill_(vocab.stoi[conf.start_token])
        G_inp = on_cuda(G_inp)
        string = conf.start_token + ' '
        # until we hit end token (index 3 in vocab stoi)
        while G_inp[0][0].item() != vocab.stoi[conf.end_token]:
            with torch.autograd.no_grad():
                logit, G_hidden, _ = vae(None, G_inp, z, G_hidden)
            probs = F.softmax(logit[0], dim=1)
            G_inp = torch.multinomial(probs, 1)
            string += (vocab.itos[G_inp[0][0].item()] + ' ')
        # print(string.encode('utf-8'))
        print(string)
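The token-by-token decoding loop above is identical across Examples #1, #2, and #3. A minimal sketch of how it could be factored into a shared helper, reusing on_cuda, F, and the vae call signature from this module (the name decode_from_z is an assumption, not part of the original code):

def decode_from_z(vae, z, vocab, conf):
    # sample one sentence from latent z, token by token, until the end token
    h_0 = on_cuda(torch.zeros(2 * conf.n_layers_E, 1, conf.n_hidden_G))
    c_0 = on_cuda(torch.zeros(2 * conf.n_layers_E, 1, conf.n_hidden_G))
    G_hidden = (h_0, c_0)
    G_inp = on_cuda(torch.LongTensor(1, 1).fill_(vocab.stoi[conf.start_token]))
    tokens = [conf.start_token]
    while G_inp[0][0].item() != vocab.stoi[conf.end_token]:
        with torch.no_grad():
            logit, G_hidden, _ = vae(None, G_inp, z, G_hidden)
        probs = F.softmax(logit[0], dim=1)
        G_inp = torch.multinomial(probs, 1)
        tokens.append(vocab.itos[G_inp[0][0].item()])
    return ' '.join(tokens)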
Example #4
    def __init__(self, conf):
        # create vae, load weights
        _, _, _, self.vocab = get_gyafc(conf)
        self.vae, _ = create_vae(conf, self.vocab)
        ckpt = torch.load(conf.vae_model_path)
        self.vae.load_state_dict(ckpt['vae_dict'])
        self.vae.eval()
        del ckpt

        # create linear shift
        self.linear_shift = on_cuda(LinearShift(conf))

        # save conf
        self.conf = conf
        # init
        self.score = 0
        self.eval_done = False

        # load dataset
        self.test = get_formality_set(conf, self.vocab)

        # scoring
        self.extractor = FeatureExtractor(conf.w2v_path, conf.corpus_dict_path)
        self.pt16_ridge = pickle.load(open(conf.pt16_path, 'rb'))
Example #5
if __name__ == '__main__':
    with open('configs/default.yaml') as file:
        conf_dict = yaml.load(file, Loader=yaml.FullLoader)
    conf = Namespace(**conf_dict)
    print(conf)
    np.random.seed(conf.seed)
    torch.manual_seed(conf.seed)

    best_linear_shift = on_cuda(LinearShift(conf))
    linear_ckpt = torch.load(conf.linear_model_save_path)
    best_linear_shift.load_state_dict(linear_ckpt)
    best_linear_shift.eval()

    _, _, _, vocab = get_gyafc(conf)
    ckpt = torch.load(conf.vae_model_path)
    vae, _ = create_vae(conf, vocab)
    vae.load_state_dict(ckpt['vae_dict'])
    vae.eval()
    del ckpt, linear_ckpt

    test = get_informal_test_set(conf, vocab)

    all_strings = []
    for batch in test:
        print('New Batch')
        batch = on_cuda(batch.T)
        mu, logvar = vae.encode(batch)
        new_mu, new_logvar = best_linear_shift(mu, logvar)
Example #6
def train():
    # data loading
    # train_iter, test_iter, valid_iter, vocab = get_wiki2(conf)
    train_iter, test_iter, valid_iter, vocab = get_gyafc(conf)

    # create model, load weights if necessary
    if args.resume_training:
        step, start_epoch, vae, trainer_vae = load_ckpt(conf, save_path, vocab)
    else:
        start_epoch = 0
        step = 0
        vae, trainer_vae = create_vae(conf, vocab)

    all_t_rec_loss = []
    all_t_kl_loss = []
    all_t_loss = []
    all_v_rec_loss = []
    all_v_kl_loss = []
    all_v_loss = []

    # training epochs
    for epoch in tqdm.tqdm(range(start_epoch, conf.epochs), desc='Epochs'):
        vae.train()
        # logging
        train_rec_loss = []
        train_kl_loss = []
        train_loss = []

        for batch in train_iter:
            # batch is the encoder input and the target output for the generator
            batch = on_cuda(batch.T)
            G_inp = create_g_input(batch, True, vocab, conf)
            rec_loss, kl_loss, elbo, kld_coef = train_batch(vae,
                                                            trainer_vae,
                                                            batch,
                                                            G_inp,
                                                            step,
                                                            conf,
                                                            train=True)
            train_rec_loss.append(rec_loss)
            train_kl_loss.append(kl_loss)
            train_loss.append(elbo)

            # log
            if args.to_train:
                writer.add_scalar('ELBO', elbo, step)
                writer.add_scalar('Cross Entropy', rec_loss, step)
                writer.add_scalar('KL Divergence Raw', kl_loss, step)
                writer.add_scalar('KL Annealed Weight', kld_coef, step)
                writer.add_scalar('KL Divergence Weighted', kl_loss * kld_coef,
                                  step)

            # increment step
            step += 1

        # valid
        vae.eval()
        valid_rec_loss = []
        valid_kl_loss = []
        valid_loss = []

        for valid_batch in valid_iter:
            valid_batch = on_cuda(valid_batch.T)
            G_inp = create_g_input(valid_batch, True, vocab, conf)
            with torch.autograd.no_grad():
                rec_loss, kl_loss, elbo, kld_coef = train_batch(vae,
                                                                trainer_vae,
                                                                valid_batch,
                                                                G_inp,
                                                                step,
                                                                conf,
                                                                train=False)
            valid_rec_loss.append(rec_loss)
            valid_kl_loss.append(kl_loss)
            valid_loss.append(elbo)

        all_t_rec_loss.append(train_rec_loss)
        all_t_kl_loss.append(train_kl_loss)
        all_t_loss.append(train_loss)
        all_v_rec_loss.append(valid_rec_loss)
        all_v_kl_loss.append(valid_kl_loss)
        all_v_loss.append(valid_loss)
        mean_t_rec_loss = np.mean(train_rec_loss)
        mean_t_kl_loss = np.mean(train_kl_loss)
        mean_t_loss = np.mean(train_loss)
        mean_v_rec_loss = np.mean(valid_rec_loss)
        mean_v_kl_loss = np.mean(valid_kl_loss)
        mean_v_loss = np.mean(valid_loss)

        # loss_log.set_description_str(f'T_rec: ' + '%.2f'%mean_t_rec_loss +
        #     ' T_kld: ' + '%.2f'%mean_t_kl_loss + ' V_rec: ' +
        #     '%.2f'%mean_v_rec_loss + ' V_kld: ' + '%.2f'%mean_v_kl_loss)
        tqdm.tqdm.write(
            f'T_rec: {mean_t_rec_loss:.2f} T_kld: {mean_t_kl_loss:.2f} '
            f'T_ELBO: {mean_t_loss:.2f} V_rec: {mean_v_rec_loss:.2f} '
            f'V_kld: {mean_v_kl_loss:.2f} V_ELBO: {mean_v_loss:.2f} '
            f'kld_coef: {kld_coef:.2f}')

        if epoch % 5 == 0:
            torch.save(
                {
                    'epoch': epoch + 1,
                    'vae_dict': vae.state_dict(),
                    'vae_trainer': trainer_vae.state_dict(),
                    'step': step
                }, save_path)

            # NOTE: npz logging is still messed up: when training is resumed, the file is overwritten with only the losses collected since the resume
            # np.savez_compressed('data/losses_log/losses_wiki2_fixed.npz',
            #                     t_rec=np.array(all_t_rec_loss),
            #                     t_kl=np.array(all_t_kl_loss),
            #                     v_rec=np.array(all_v_rec_loss),
            #                     v_kl=np.array(all_v_kl_loss))

            np.savez_compressed(
                'data/losses_log/losses_gyafc_weightfix3_nodropout_25000crossover_long_0.0005k.npz',
                t_rec=np.array(all_t_rec_loss),
                t_kl=np.array(all_t_kl_loss),
                t_elbo=np.array(all_t_loss),
                v_rec=np.array(all_v_rec_loss),
                v_kl=np.array(all_v_kl_loss),
                v_elbo=np.array(all_v_loss))
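
The KL annealing weight kld_coef is computed inside train_batch, which is not shown here. Judging by the npz filename above (25000 crossover, 0.0005), a sigmoid schedule centred at step 25000 with slope 0.0005 appears to be intended; the sketch below is only an assumption about that implementation:

import numpy as np

def kld_coef(step, crossover=25000, k=0.0005):
    # sigmoid KL-annealing weight in (0, 1), reaching 0.5 at the crossover step
    return float(1.0 / (1.0 + np.exp(-k * (step - crossover))))

# e.g. kld_coef(0) ~ 3.7e-06, kld_coef(25000) = 0.5, kld_coef(50000) ~ 0.999996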