Example #1
def get_vae(model_encoder,
            model_decoder,
            tokenizer_encoder,
            tokenizer_decoder,
            beta=1):
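    # Minimal stand-in for the argparse namespace that the VAE constructor expects.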
    ArgsObj = namedtuple("Args", ["latent_size", "device", "fb_mode", "beta"])
    args = ArgsObj(latent_size=LATENT_SIZE_LARGE,
                   device=get_device(),
                   fb_mode=0,
                   beta=beta)

    checkpoint_full_dir = os.path.join(OUTPUT_DIR, "checkpoint-full-31250")
    if not torch.cuda.is_available():
        checkpoint = torch.load(os.path.join(checkpoint_full_dir,
                                             "training.bin"),
                                map_location="cpu")
    else:
        checkpoint = torch.load(
            os.path.join(checkpoint_full_dir, "training.bin"))

    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder,
                    tokenizer_decoder, args)
    model_vae.load_state_dict(checkpoint["model_state_dict"])
    # logger.info("Pre-trained Optimus is successfully loaded")
    model_vae.to(args.device)
    return model_vae
Example #2
def main():
    # Load MNIST image dataset
    mnist_train_data = datasets.MNIST(
        '/home/ajays/Downloads/', download=True, transform=transforms.ToTensor()
    )
    mnist_test_data = datasets.MNIST('/home/ajays/Downloads/', train=False, download=True)

    train_loader = torch.utils.data.DataLoader(
        mnist_train_data, batch_size=batch_size, shuffle=True
    )

    # Instantiation
    vae = VAE(n_inputs=32)

    # *********************
    # IMAGE VAE TRAINING
    # *********************
    # plot before training
    # o_before, mu, logvar = vae(mnist_train_data[0][0].reshape((1,1,28,28)))
    # plt.imshow(o_before.detach().numpy().reshape((28,28)))
    # plt.show()

    # train
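    # load pre-trained weights instead of retraining; the full training call below is left commented out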
    vae.load_state_dict(torch.load(LOAD_PATH))
    #vae = train_image_vae(vae, train_loader)

    # After training
    # o_after, mu, logvar = vae(example[0].reshape((1,1,28,28)))
    o_after = vae.decode(torch.randn((128)))
    plt.imshow(o_after.detach().numpy().reshape((28,28)))
    plt.show()
Example #3
def get_vae_recons(loader, hidden_size=256):

    model = VAE(3, hidden_size, hidden_size).to(DEVICE)

    ckpt = torch.load("./models/imagenet_hs_128_256_vae.pt")
    model.load_state_dict(ckpt)
    args = type('', (), {})()
    args.device = DEVICE
    gen_img, _ = next(iter(loader))
    # grid = make_grid(gen_img.cpu(), nrow=8)
    # torchvision.utils.save_image(grid, "hs_{}_recons.png".format(hidden_size))
    #exit()

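    # note: vae below presumably refers to an imported module that provides
    # generate_samples; the freshly loaded model and args are passed in explicitly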
    reconstruction = vae.generate_samples(gen_img, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8)

    return grid
Example #4
def main(args):
    class uniform_initializer(object):
        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    class xavier_normal_initializer(object):
        def __call__(self, tensor):
            nn.init.xavier_normal_(tensor)

    if args.cuda:
        print('using cuda')

    print(args)

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data, label=args.label)

    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    if args.enc_type == 'lstm':
        encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device
    vae = VAE(encoder, decoder, args).to(device)

    if args.eval:
        print('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)

            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            print("%d active units" % au)
            # print(au_var)

            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)
            calc_iwnll(vae, test_data_batch, args)

        return

    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                              lr=1.0,
                              momentum=args.momentum)
    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                              lr=1.0,
                              momentum=args.momentum)
    opt_dict['lr'] = 1.0

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    aggressive_flag = True if args.aggressive else False
    vae.train()
    start = time.time()

    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up *
                                           (len(train_data) / args.batch_size))
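    # linear KL annealing: kl_weight grows from kl_start to 1.0 over roughly warm_up epochs of updates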

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)

    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)
    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_words = report_num_sents = 0
        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size

            report_num_sents += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_words = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
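            # aggressive (encoder-only) training: keep updating just the encoder on fresh
            # batches until the running loss stops improving (checked every 15 sub-iterations,
            # capped at 100)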
            while aggressive_flag and sub_iter < 100:

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_batch_size, burn_sents_len = batch_data_enc.size()
                burn_num_words += (burn_sents_len - 1) * burn_batch_size

                loss, loss_rc, loss_kl = vae.loss(batch_data_enc,
                                                  kl_weight,
                                                  nsamples=args.nsamples)

                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                enc_optimizer.step()

                id_ = np.random.random_integers(0, len(train_data_batch) - 1)

                batch_data_enc = train_data_batch[id_]

                if sub_iter % 15 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_words
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_words = 0

                sub_iter += 1

                # if sub_iter >= 30:
                #     break

            # print(sub_iter)

            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data,
                                              kl_weight,
                                              nsamples=args.nsamples)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

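            # in aggressive mode the encoder was already updated in the inner loop above,
            # so only the decoder steps here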
            if not aggressive_flag:
                enc_optimizer.step()

            dec_optimizer.step()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss +
                              report_kl_loss) / report_num_sents
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_data_batch)
                        au, _ = calc_au(vae, val_data_batch)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
                           'au %d, time elapsed %.2fs' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, au, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                           'time elapsed %.2fs' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start))

                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_words = report_num_sents = 0

            iter_ += 1

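            # once per pass over the training batches, stop aggressive training as soon as
            # the mutual information on the validation set stops increasing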
            if aggressive_flag and (iter_ % len(train_data_batch)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_data_batch)
                vae.train()
                print("pre mi:%.4f. cur mi:%.4f" % (pre_mi, cur_mi))
                if cur_mi - pre_mi < 0:
                    aggressive_flag = False
                    print("STOP BURNING")

                pre_mi = cur_mi

        print('kl weight %.4f' % kl_weight)

        vae.eval()
        with torch.no_grad():
            loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
            au, au_var = calc_au(vae, val_data_batch)
            print("%d active units" % au)
            # print(au_var)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            best_ppl = ppl
            torch.save(vae.state_dict(), args.save_path)

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch and epoch >= 15:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
                dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)

        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST",
                                             args)

        vae.train()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))

    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        print("%d active units" % au)
        # print(au_var)

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
Example #5
class VAESampler:
    def __init__(self, decode_from, params, cuda=False):
        self.decode_from = decode_from
        self.params = params
        params.enc_nh = params.dec_nh  # not sure why this is necessary...

        self.train_data = MonoTextData(params.train_data, label=False)
        self.vocab = self.train_data.vocab
        self.vocab_size = len(self.vocab)

        # do I need these?
        model_init = uniform_initializer(0.01)
        emb_init = uniform_initializer(0.1)

        params.device = self.device = torch.device("cuda" if cuda else "cpu")

        self.encoder = LSTMEncoder(params, self.vocab_size, model_init,
                                   emb_init)
        self.decoder = LSTMDecoder(params, self.vocab, model_init, emb_init)

        self.vae = VAE(self.encoder, self.decoder, params).to(params.device)

        # assuming models were trained on a gpu...
        if cuda:
            self.vae.load_state_dict(torch.load(self.decode_from))
        else:
            self.vae.load_state_dict(
                torch.load(self.decode_from, map_location='cpu'))

    def to_s(self, decoded):
        return [' '.join(item) for item in decoded]

    def beam(self, z, K=5):
        decoded_batch = self.vae.decoder.beam_search_decode(z, K)
        return self.to_s(decoded_batch)

    def sample(self, z, temperature=1.0):
        decoded_batch = self.vae.decoder.sample_decode(z, temperature)
        return self.to_s(decoded_batch)

    def greedy(self, z):
        decoded_batch = self.vae.decoder.greedy_decode(z)
        return self.to_s(decoded_batch)

    def str2ids(self, s):
        "encode string s as list of word ids"
        raise NotImplementedError

    def encode(self, t):
        """
        Returns (z, mu, log_var) from encoder given list of strings.

        z is a sample from gaussian specified with (mu, log_var)
        """
        str_ids = []
        for s in t:
            ids = self.str2ids(s)
            str_ids.append(ids)
        tensor = self.train_data._to_tensor(str_ids, True, self.device)[0]
        z, (mu, log_var) = self.vae.encoder.sample(tensor, 1)
        return z, mu, log_var

    def z(self, t):
        "return sampled latent zs for list of strings t"
        z, mu, logvar = self.encode(t)
        return z.squeeze(1)

    def mu(self, t):
        "return mean of latent gaussian for list of strings t"
        z, mu, logvar = self.encode(t)
        return mu.squeeze(1)
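
A minimal usage sketch, not part of the original example; the checkpoint path and the fields of params (including nz for the latent size) are assumptions:

# Hypothetical usage of VAESampler; torch, a populated params namespace, and a
# trained checkpoint are assumed.
sampler = VAESampler(decode_from="models/yelp_vae.pt", params=params, cuda=False)
z = torch.randn(5, params.nz)               # latent codes drawn from the prior
print(sampler.greedy(z))                    # greedy decoding
print(sampler.sample(z, temperature=0.8))   # temperature sampling
print(sampler.beam(z, K=5))                 # beam search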
Example #6
def main(args):
    if args.save_path == '':
        make_savepath(args)
        seed(args)

    if args.cuda:
        print('using cuda')

    print(args)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    all_data = torch.load(args.data_file)
    x_train, x_val, x_test = all_data

    x_train = x_train.to(device)
    x_val = x_val.to(device)
    x_test = x_test.to(device)
    y_size = 1
    y_train = x_train.new_zeros(x_train.size(0), y_size)
    y_val = x_train.new_zeros(x_val.size(0), y_size)
    y_test = x_train.new_zeros(x_test.size(0), y_size)
    print(torch.__version__)
    train_data = torch.utils.data.TensorDataset(x_train, y_train)
    val_data = torch.utils.data.TensorDataset(x_val, y_val)
    test_data = torch.utils.data.TensorDataset(x_test, y_test)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=True)
    print('Train data: %d batches' % len(train_loader))
    print('Val data: %d batches' % len(val_loader))
    print('Test data: %d batches' % len(test_loader))
    sys.stdout.flush()

    log_niter = len(train_loader) // 5

    encoder = ResNetEncoderV2(args)
    decoder = PixelCNNDecoderV2(args)

    vae = VAE(encoder, decoder, args).to(device)

    if args.sample_from != '':
        save_dir = "samples/%s" % args.dataset
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        vae.load_state_dict(torch.load(args.sample_from))
        vae.eval()
        with torch.no_grad():
            sample_z = vae.sample_from_prior(400).to(device)
            sample_x, sample_probs = vae.decode(sample_z, False)
        image_file = 'sample_binary_from_%s.png' % (
            args.sample_from.split('/')[-1][:-3])
        save_image(sample_x.data.cpu(),
                   os.path.join(save_dir, image_file),
                   nrow=20)
        image_file = 'sample_cont_from_%s.png' % (
            args.sample_from.split('/')[-1][:-3])
        save_image(sample_probs.data.cpu(),
                   os.path.join(save_dir, image_file),
                   nrow=20)

        return

    if args.eval:
        print('begin evaluation')
        test_loader = torch.utils.data.DataLoader(test_data,
                                                  batch_size=50,
                                                  shuffle=True)
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test(vae, test_loader, "TEST", args)
            au, au_var = calc_au(vae, test_loader)
            print("%d active units" % au)
            # print(au_var)

            calc_iwnll(vae, test_loader, args)

        return

    enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001)
    dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001)
    opt_dict['lr'] = 0.001

    iter_ = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    decay_cnt = pre_mi = best_mi = mi_not_improved = 0
    aggressive_flag = True if args.aggressive else False
    vae.train()
    start = time.time()

    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up * len(train_loader))

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_examples = 0
        for datum in train_loader:
            batch_data, _ = datum
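            # dynamically binarize the images: sample each pixel from a Bernoulli
            # with the pixel intensity as its mean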
            batch_data = torch.bernoulli(batch_data)
            batch_size = batch_data.size(0)

            report_num_examples += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_examples = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_num_examples += batch_data_enc.size(0)
                loss, loss_rc, loss_kl = vae.loss(batch_data_enc,
                                                  kl_weight,
                                                  nsamples=args.nsamples)

                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                enc_optimizer.step()

                id_ = np.random.choice(x_train.size(0),
                                       args.batch_size,
                                       replace=False)

                batch_data_enc = torch.bernoulli(x_train[id_])

                if sub_iter % 10 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_examples
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_examples = 0

                sub_iter += 1

            # print(sub_iter)

            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data,
                                              kl_weight,
                                              nsamples=args.nsamples)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            if not aggressive_flag:
                enc_optimizer.step()

            dec_optimizer.step()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss +
                              report_kl_loss) / report_num_examples
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_loader)
                        au, _ = calc_au(vae, val_loader)

                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
                           'au %d, time elapsed %.2fs' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_examples, mi,
                           report_rec_loss / report_num_examples, au, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_examples,
                          report_rec_loss / report_num_examples, time.time() - start))
                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_examples = 0

            iter_ += 1

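            # unlike the text script, aggressive training stops only after the validation
            # MI has failed to improve on five consecutive end-of-epoch checks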
            if aggressive_flag and (iter_ % len(train_loader)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_loader)
                vae.train()
                if cur_mi - best_mi < 0:
                    mi_not_improved += 1
                    if mi_not_improved == 5:
                        aggressive_flag = False
                        print("STOP BURNING")

                else:
                    best_mi = cur_mi

                pre_mi = cur_mi

        print('kl weight %.4f' % kl_weight)
        print('epoch: %d, VAL' % epoch)

        vae.eval()

        with torch.no_grad():
            loss, nll, kl = test(vae, val_loader, "VAL", args)
            au, au_var = calc_au(vae, val_loader)
            print("%d active units" % au)
            # print(au_var)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            torch.save(vae.state_dict(), args.save_path)

        if loss > best_loss:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                decay_cnt += 1
                print('new lr: %f' % opt_dict["lr"])
                enc_optimizer = optim.Adam(vae.encoder.parameters(),
                                           lr=opt_dict["lr"])
                dec_optimizer = optim.Adam(vae.decoder.parameters(),
                                           lr=opt_dict["lr"])
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl = test(vae, test_loader, "TEST", args)

        vae.train()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl = test(vae, test_loader, "TEST", args)
        au, au_var = calc_au(vae, test_loader)
        print("%d active units" % au)
        # print(au_var)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=50,
                                              shuffle=True)

    with torch.no_grad():
        calc_iwnll(vae, test_loader, args)
Example #7
def main(args):
    train_data = MonoTextData(args.train_data, label=args.label)
    vocab = train_data.vocab
    vocab_size = len(vocab)
    
    vocab_path = os.path.join("/".join(args.train_data.split("/")[:-1]), "vocab.txt")
    with open(vocab_path, "w") as fout:
        for i in range(vocab_size):
            fout.write("{}\n".format(vocab.id2word(i)))
        #return

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    log_niter = (len(train_data)//args.batch_size)//10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    print('begin evaluation')
    vae.load_state_dict(torch.load(args.load_path))
    vae.eval()
    with torch.no_grad():
        test_data_batch, test_batch_labels = test_data.create_data_batch_labels(batch_size=args.batch_size,
                                                      device=device,
                                                      batch_first=True)

        # test(vae, test_data_batch, "TEST", args)
        # au, au_var = calc_au(vae, test_data_batch)
        # print("%d active units" % au)

        train_data_batch, train_batch_labels = train_data.create_data_batch_labels(batch_size=args.batch_size,
                                                        device=device,
                                                        batch_first=True)

        val_data_batch, val_batch_labels = val_data.create_data_batch_labels(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

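        # encode each split with the frozen VAE and save the latent vectors
        # (together with their labels) under args.save_dir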
        print("getting  vectors for training")
        print(args.save_dir)
        save_latents(args, vae, train_data_batch, train_batch_labels, "train")
        print("getting  vectors for validating")
        save_latents(args, vae, val_data_batch, val_batch_labels, "val")
        print("getting  vectors for testing")
        save_latents(args, vae, test_data_batch, test_batch_labels, "test")
Example #8
def main(args):
    class uniform_initializer(object):
        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    class xavier_normal_initializer(object):
        def __call__(self, tensor):
            nn.init.xavier_normal_(tensor)

    if args.cuda:
        print('using cuda')

    print(args)

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data)

    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, vocab=vocab)
    test_data = MonoTextData(args.test_data, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
    args.enc_nh = args.dec_nh

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    vae = VAE(encoder, decoder, args).to(device)

    if args.optim == 'sgd':
        enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=1.0)
        dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=1.0)
        opt_dict['lr'] = 1.0
    else:
        enc_optimizer = optim.Adam(vae.encoder.parameters(),
                                   lr=0.001,
                                   betas=(0.9, 0.999))
        dec_optimizer = optim.Adam(vae.decoder.parameters(),
                                   lr=0.001,
                                   betas=(0.9, 0.999))
        opt_dict['lr'] = 0.001

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = -1
    aggressive_flag = True if args.aggressive else False
    vae.train()
    start = time.time()

    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up *
                                           (len(train_data) / args.batch_size))

    plot_data = train_data.data_sample(nsample=args.num_plot,
                                       device=device,
                                       batch_first=True)

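    # 'multiple' plots the approximate posteriors of several sentences over a 1-D latent
    # grid as training proceeds; 'single' tracks the model posterior mean and the
    # inference-network mean of one plotting batch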
    if args.plot_mode == 'multiple':
        grid_z = generate_grid(args.zmin, args.zmax, args.dz, device, ndim=1)
        plot_fn = plot_multiple

    elif args.plot_mode == 'single':
        grid_z = generate_grid(args.zmin, args.zmax, args.dz, device, ndim=1)
        plot_fn = plot_single
        posterior_mean = []
        infer_mean = []

        posterior_mean.append(
            vae.calc_model_posterior_mean(plot_data[0], grid_z))
        infer_mean.append(vae.calc_infer_mean(plot_data[0]))

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)

    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_words = report_num_sents = 0
        for i in np.random.permutation(len(train_data_batch)):
            if args.plot_mode == "single":
                batch_data, _ = plot_data

            else:
                batch_data = train_data_batch[i]
            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size

            report_num_sents += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_words = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_batch_size, burn_sents_len = batch_data_enc.size()
                burn_num_words += (burn_sents_len - 1) * burn_batch_size

                loss, loss_rc, loss_kl = vae.loss(batch_data_enc,
                                                  kl_weight,
                                                  nsamples=args.nsamples)

                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                enc_optimizer.step()

                if args.plot_mode == "single":
                    batch_data_enc, _ = plot_data

                else:
                    id_ = np.random.random_integers(0,
                                                    len(train_data_batch) - 1)

                    batch_data_enc = train_data_batch[id_]

                if sub_iter % 15 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_words
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_words = 0

                sub_iter += 1

            if args.plot_mode == 'single' and epoch == 0 and aggressive_flag:
                vae.eval()
                with torch.no_grad():
                    posterior_mean.append(posterior_mean[-1])
                    infer_mean.append(vae.calc_infer_mean(plot_data[0]))
                vae.train()

            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data,
                                              kl_weight,
                                              nsamples=args.nsamples)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            if not aggressive_flag:
                enc_optimizer.step()

            dec_optimizer.step()
            if args.plot_mode == 'single' and epoch == 0:
                vae.eval()
                with torch.no_grad():
                    posterior_mean.append(
                        vae.calc_model_posterior_mean(plot_data[0], grid_z))

                    if aggressive_flag:
                        infer_mean.append(infer_mean[-1])
                    else:
                        infer_mean.append(vae.calc_infer_mean(plot_data[0]))
                vae.train()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss +
                              report_kl_loss) / report_num_sents
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    mi = calc_mi(vae, val_data_batch)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
                           'time elapsed %.2fs' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                           'time elapsed %.2fs' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start))

                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_words = report_num_sents = 0

            if iter_ % args.plot_niter == 0 and epoch == 0:
                vae.eval()
                with torch.no_grad():
                    if args.plot_mode == 'single' and iter_ != 0:
                        plot_fn(infer_mean, posterior_mean, args)
                        return
                    elif args.plot_mode == "multiple":
                        plot_fn(vae, plot_data, grid_z, iter_, args)
                vae.train()

            iter_ += 1

            if aggressive_flag and (iter_ % len(train_data_batch)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_data_batch)
                vae.train()
                if cur_mi - pre_mi < 0:
                    aggressive_flag = False
                    print("STOP BURNING")

                pre_mi = cur_mi

                # return

        print('kl weight %.4f' % kl_weight)
        print('epoch: %d, VAL' % epoch)

        with torch.no_grad():
            plot_fn(vae, plot_data, grid_z, iter_, args)

        vae.eval()
        with torch.no_grad():
            loss, nll, kl, ppl = test(vae, val_data_batch, "VAL", args)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            best_ppl = ppl
            torch.save(vae.state_dict(), args.save_path)

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.optim == 'sgd':
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"])
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"])
                else:
                    enc_optimizer = optim.Adam(vae.encoder.parameters(),
                                               lr=opt_dict["lr"],
                                               betas=(0.5, 0.999))
                    dec_optimizer = optim.Adam(vae.decoder.parameters(),
                                               lr=opt_dict["lr"],
                                               betas=(0.5, 0.999))
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl = test(vae, test_data_batch, "TEST", args)

        vae.train()

    print('best_loss: %.4f, kl: %.4f, nll: %.4f, ppl: %.4f' \
          % (best_loss, best_kl, best_nll, best_ppl))

    sys.stdout.flush()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
Example #9
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_data_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input training data file (a text file).")
    parser.add_argument(
        "--eval_data_file",
        default=None,
        type=str,
        help=
        "An input evaluation data file to evaluate the perplexity on (a text file)."
    )
    parser.add_argument("--checkpoint_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The directory where checkpoints are saved.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument("--dataset",
                        default='Snli',
                        type=str,
                        help="The dataset.")

    ## Variational auto-encoder
    parser.add_argument("--latent_size",
                        default=32,
                        type=int,
                        help="Latent space dimension.")
    parser.add_argument("--total_sents",
                        default=10,
                        type=int,
                        help="Total sentences to test recontruction.")
    parser.add_argument("--num_interpolation_steps",
                        default=10,
                        type=int,
                        help="Total sentences to test recontruction.")
    parser.add_argument("--play_mode",
                        default="interpolation",
                        type=str,
                        help="interpolation or reconstruction.")

    ## Encoder options
    parser.add_argument(
        "--encoder_model_type",
        default="bert",
        type=str,
        help="The encoder model architecture to be fine-tuned.")
    parser.add_argument(
        "--encoder_model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The encoder model checkpoint for weights initialization.")
    parser.add_argument(
        "--encoder_config_name",
        default="",
        type=str,
        help=
        "Optional pretrained config name or path if not the same as model_name_or_path"
    )
    parser.add_argument(
        "--encoder_tokenizer_name",
        default="",
        type=str,
        help=
        "Optional pretrained tokenizer name or path if not the same as model_name_or_path"
    )

    ## Decoder options
    parser.add_argument(
        "--decoder_model_type",
        default="gpt2",
        type=str,
        help="The decoder model architecture to be fine-tuned.")
    parser.add_argument(
        "--decoder_model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument(
        "--decoder_config_name",
        default="",
        type=str,
        help=
        "Optional pretrained config name or path if not the same as model_name_or_path"
    )
    parser.add_argument(
        "--decoder_tokenizer_name",
        default="",
        type=str,
        help=
        "Optional pretrained tokenizer name or path if not the same as model_name_or_path"
    )

    parser.add_argument("--per_gpu_train_batch_size",
                        default=1,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=1,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gloabl_step_eval',
                        type=int,
                        default=661,
                        help="Evaluate the results at the given global step")

    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length"
    )

    # Interact with users
    parser.add_argument("--interact_with_user_input",
                        action='store_true',
                        help="Use user input to interact_with.")
    parser.add_argument("--sent_source", type=str, default="")
    parser.add_argument("--sent_target", type=str, default="")
    parser.add_argument("--sent_input", type=str, default="")
    parser.add_argument("--degree_to_target", type=float, default="1.0")

    ## Variational auto-encoder
    parser.add_argument("--nz",
                        default=32,
                        type=int,
                        help="Latent space dimension.")

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--padding_text", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument(
        "--block_size",
        default=-1,
        type=int,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training."
        "Default to the model max input length for single sentence inputs (take into account special tokens)."
    )
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--use_philly",
                        action='store_true',
                        help="Use Philly for computing.")

    args = parser.parse_args()

    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()

    set_seed(args)

    args.encoder_model_type = args.encoder_model_type.lower()
    args.decoder_model_type = args.decoder_model_type.lower()

    global_step = args.gloabl_step_eval

    output_encoder_dir = os.path.join(
        args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
    output_decoder_dir = os.path.join(
        args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
    checkpoints = [[output_encoder_dir, output_decoder_dir]]
    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    # Load a trained Encoder model and vocabulary that you have fine-tuned
    encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[
        args.encoder_model_type]
    model_encoder = encoder_model_class.from_pretrained(
        output_encoder_dir, latent_size=args.latent_size)
    tokenizer_encoder = encoder_tokenizer_class.from_pretrained(
        args.encoder_tokenizer_name
        if args.encoder_tokenizer_name else args.encoder_model_name_or_path,
        do_lower_case=args.do_lower_case)

    model_encoder.to(args.device)
    if args.block_size <= 0:
        args.block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size,
                          tokenizer_encoder.max_len_single_sentence)

    # Load a trained Decoder model and vocabulary that you have fine-tuned
    decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[
        args.decoder_model_type]
    model_decoder = decoder_model_class.from_pretrained(
        output_decoder_dir, latent_size=args.latent_size)
    tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
        args.decoder_tokenizer_name
        if args.decoder_tokenizer_name else args.decoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    model_decoder.to(args.device)
    if args.block_size <= 0:
        args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size,
                          tokenizer_decoder.max_len_single_sentence)

    # Load full model
    output_full_dir = os.path.join(args.checkpoint_dir,
                                   'checkpoint-full-{}'.format(global_step))
    checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'),
                            map_location=torch.device('cpu'))

    # Chunyuan: Add Padding token to GPT2
    special_tokens_dict = {
        'pad_token': '<PAD>',
        'bos_token': '<BOS>',
        'eos_token': '<EOS>'
    }
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    model_decoder.resize_token_embeddings(
        len(tokenizer_decoder)
    )  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
    assert tokenizer_decoder.pad_token == '<PAD>'

    # Evaluation
    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder,
                    tokenizer_decoder, args)
    model_vae.load_state_dict(checkpoint['model_state_dict'])
    logger.info("Pre-trained Optimus is successfully loaded")
    model_vae.to(args.device)

    if args.interact_with_user_input:

        if args.play_mode == 'interpolation':
            if len(args.sent_source) > 0 and len(args.sent_target) > 0:
                result = interpolate(model_vae, tokenizer_encoder,
                                     tokenizer_decoder, args)
            else:
                print('Please check: specify the source and target sentences!')

        if args.play_mode == 'analogy':
            if len(args.sent_source) > 0 and len(args.sent_target) > 0 and len(
                    args.sent_input) > 0:
                result = analogy(model_vae, tokenizer_encoder,
                                 tokenizer_decoder, args)
            else:
                print(
                    'Please check: specify the source, target and input analogy sentences!'
                )

    else:
        result = evaluate_latent_space(args,
                                       model_vae,
                                       tokenizer_encoder,
                                       tokenizer_decoder,
                                       prefix=global_step)
Example #10
def main(args):
    global logging
    debug = (args.reconstruct_from != ""
             or args.eval == True)  # don't make exp dir for reconstruction
    logging = create_exp_dir(args.exp_dir, scripts_to_save=None, debug=debug)

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data, label=args.label)

    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    #sys.stdout.flush()

    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        #curr_state_dict = vae.state_dict()
        #curr_state_dict.update(loaded_state_dict)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

        if args.reset_dec:
            vae.decoder.reset_parameters(model_init, emb_init)

    if args.eval:
        logging('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)

            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            logging("%d active units" % au)
            # print(au_var)

            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)

            nll, ppl = calc_iwnll(vae, test_data_batch, args)
            logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))

        return

    if args.reconstruct_from != "":
        print("begin decoding")
        sys.stdout.flush()

        vae.load_state_dict(torch.load(args.reconstruct_from))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            # test(vae, test_data_batch, "TEST", args)
            reconstruct(vae, test_data_batch, vocab, args.decoding_strategy,
                        args.reconstruct_to)

        return

    if args.opt == "sgd":
        enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum)
        dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001)
        dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    vae.train()
    start = time.time()

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)

    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in range(args.epochs):
            report_kl_loss = report_rec_loss = report_loss = 0
            report_num_words = report_num_sents = 0

            for i in np.random.permutation(len(train_data_batch)):

                batch_data = train_data_batch[i]
                batch_size, sent_len = batch_data.size()

                # not predict start symbol
                report_num_words += (sent_len - 1) * batch_size
                report_num_sents += batch_size

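                # fixed KL weight (beta): this script does not anneal the KL term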
                kl_weight = args.beta

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

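                # optionally train with the importance-weighted bound (loss_iw)
                # when iw_train_nsamples is set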
                if args.iw_train_nsamples < 0:
                    loss, loss_rc, loss_kl = vae.loss(batch_data,
                                                      kl_weight,
                                                      nsamples=args.nsamples)
                else:
                    loss, loss_rc, loss_kl = vae.loss_iw(
                        batch_data,
                        kl_weight,
                        nsamples=args.iw_train_nsamples,
                        ns=ns)
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                loss_rc = loss_rc.sum()
                loss_kl = loss_kl.sum()

                enc_optimizer.step()
                dec_optimizer.step()

                report_rec_loss += loss_rc.item()
                report_kl_loss += loss_kl.item()
                report_loss += loss.item() * batch_size

                if iter_ % log_niter == 0:
                    #train_loss = (report_rec_loss  + report_kl_loss) / report_num_sents
                    train_loss = report_loss / report_num_sents
                    logging('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f, ' \
                           'time elapsed %.2fs, kl_weight %.4f' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start, kl_weight))

                    #sys.stdout.flush()

                    report_rec_loss = report_kl_loss = report_loss = 0
                    report_num_words = report_num_sents = 0

                iter_ += 1

            logging('kl weight %.4f' % kl_weight)

            vae.eval()
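            # per-epoch validation: test() reports loss/NLL/KL/perplexity and a
            # mutual information estimate on the validation set, and calc_au
            # counts the active latent units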
            with torch.no_grad():
                loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
                au, au_var = calc_au(vae, val_data_batch)
                logging("%d active units" % au)
                # print(au_var)

            if args.save_ckpt > 0 and epoch <= args.save_ckpt:
                logging('save checkpoint')
                torch.save(
                    vae.state_dict(),
                    os.path.join(args.exp_dir, f'model_ckpt_{epoch}.pt'))

            if loss < best_loss:
                logging('update best loss')
                best_loss = loss
                best_nll = nll
                best_kl = kl
                best_ppl = ppl
                torch.save(vae.state_dict(), args.save_path)

            if loss > opt_dict["best_loss"]:
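                # plateau schedule: after decay_epoch epochs without
                # improvement (and once past load_best_epoch), reload the best
                # checkpoint, decay the learning rate, and rebuild the
                # optimizers (note: they are recreated as SGD here even when
                # args.opt == "adam")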
                opt_dict["not_improved"] += 1
                if (opt_dict["not_improved"] >= decay_epoch
                        and epoch >= args.load_best_epoch):
                    opt_dict["best_loss"] = loss
                    opt_dict["not_improved"] = 0
                    opt_dict["lr"] = opt_dict["lr"] * lr_decay
                    vae.load_state_dict(torch.load(args.save_path))
                    logging('new lr: %f' % opt_dict["lr"])
                    decay_cnt += 1
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)

            else:
                opt_dict["not_improved"] = 0
                opt_dict["best_loss"] = loss

            if decay_cnt == max_decay:
                break

            if epoch % args.test_nepoch == 0:
                with torch.no_grad():
                    loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST",
                                                 args)

            if args.save_latent > 0 and epoch <= args.save_latent:
                visualize_latent(args, epoch, vae, device, test_data)

            vae.train()

    except KeyboardInterrupt:
        logging('-' * 100)
        logging('Exiting from training early')

    # reload the best checkpoint and compute final test metrics, including an
    # importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))

    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        logging("%d active units" % au)
        # print(au_var)

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
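    # the importance weighted NLL/PPL is estimated one sentence at a time
    # (batch_size=1 above), presumably because calc_iwnll draws many latent
    # samples per sentence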
    with torch.no_grad():
        nll, ppl = calc_iwnll(vae, test_data_batch, args)
        logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
Example #11
def main(args):
    global logging
    logging = create_exp_dir(args.exp_dir, scripts_to_save=[])

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    vocab = {}
    with open(args.vocab_file) as fvocab:
        for i, line in enumerate(fvocab):
            vocab[line.strip()] = i

    vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)

    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    #sys.stdout.flush()

    log_niter = max(1, (len(train_data) //
                        (args.batch_size * args.update_every)) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.fb == 3:
        encoder = DeltaGaussianLSTMEncoder(args, vocab_size, model_init,
                                           emb_init)
        args.enc_nh = args.dec_nh
    elif args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        #curr_state_dict = vae.state_dict()
        #curr_state_dict.update(loaded_state_dict)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    # if args.eval:
    #     logging('begin evaluation')
    #     vae.load_state_dict(torch.load(args.load_path))
    #     vae.eval()
    #     with torch.no_grad():
    #         test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
    #                                                       device=device,
    #                                                       batch_first=True)

    #         test(vae, test_data_batch, test_labels_batch, "TEST", args)
    #         au, au_var = calc_au(vae, test_data_batch)
    #         logging("%d active units" % au)
    #         # print(au_var)

    #         test_data_batch = test_data.create_data_batch(batch_size=1,
    #                                                       device=device,
    #                                                       batch_first=True)
    #         calc_iwnll(vae, test_data_batch, args)

    #     return

    if args.discriminator == "linear":
        discriminator = LinearDiscriminator(args, vae.encoder).to(device)
    elif args.discriminator == "mlp":
        discriminator = MLPDiscriminator(args, vae.encoder).to(device)
    else:
        raise ValueError("the specified discriminator is not supported")

    if args.opt == "sgd":
        optimizer = optim.SGD(discriminator.parameters(),
                              lr=args.lr,
                              momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    # best_kl = best_nll = best_ppl = 0
    # pre_mi = 0
    discriminator.train()
    start = time.time()

    # kl_weight = args.kl_start
    # if args.warm_up > 0:
    #     anneal_rate = (1.0 - args.kl_start) / (args.warm_up * (len(train_data) / args.batch_size))
    # else:
    #     anneal_rate = 0

    # dim_target_kl = args.target_kl / float(args.nz)

    train_data_batch, train_labels_batch = train_data.create_data_batch_labels(
        batch_size=args.batch_size, device=device, batch_first=True)

    val_data_batch, val_labels_batch = val_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    test_data_batch, test_labels_batch = test_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    acc_cnt = 1
    acc_loss = 0.
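    # gradient accumulation: classifier losses are summed over
    # args.update_every mini-batches, normalised by the number of sentences in
    # that window, and only then backpropagated and applied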
    for epoch in range(args.epochs):
        report_loss = 0
        report_correct = report_num_words = report_num_sents = 0
        acc_batch_size = 0
        optimizer.zero_grad()
        for i in np.random.permutation(len(train_data_batch)):

            batch_data = train_data_batch[i]
            if batch_data.size(0) < 2:
                continue
            batch_labels = train_labels_batch[i]
            batch_labels = [int(x) for x in batch_labels]

            batch_labels = torch.tensor(batch_labels,
                                        dtype=torch.long,
                                        requires_grad=False,
                                        device=device)

            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size
            acc_batch_size += batch_size

            # (batch_size)
            loss, correct = discriminator.get_performance(
                batch_data, batch_labels)

            acc_loss = acc_loss + loss.sum()

            if acc_cnt % args.update_every == 0:
                acc_loss = acc_loss / acc_batch_size
                acc_loss.backward()

                torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                               clip_grad)

                optimizer.step()
                optimizer.zero_grad()

                acc_cnt = 0
                acc_loss = 0
                acc_batch_size = 0

            acc_cnt += 1
            report_loss += loss.sum().item()
            report_correct += correct

            if iter_ % log_niter == 0:
                #train_loss = (report_rec_loss  + report_kl_loss) / report_num_sents
                train_loss = report_loss / report_num_sents


                logging('epoch: %d, iter: %d, avg_loss: %.4f, acc %.4f, ' \
                       'time %.2fs' %
                       (epoch, iter_, train_loss, report_correct / report_num_sents,
                        time.time() - start))

                #sys.stdout.flush()

            iter_ += 1

        logging('lr {}'.format(opt_dict["lr"]))
        logging('%d training sentences this epoch' % report_num_sents)
        discriminator.eval()

        with torch.no_grad():
            loss, acc = test(discriminator, val_data_batch, val_labels_batch,
                             "VAL", args)
            # print(au_var)

        if loss < best_loss:
            logging('update best loss')
            best_loss = loss
            best_acc = acc
            logging('saving to %s' % args.save_path)
            torch.save(discriminator.state_dict(), args.save_path)

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if (opt_dict["not_improved"] >= decay_epoch
                    and epoch >= args.load_best_epoch):
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                discriminator.load_state_dict(torch.load(args.save_path))
                logging('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.opt == "sgd":
                    optimizer = optim.SGD(discriminator.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
                    opt_dict['lr'] = opt_dict["lr"]
                elif args.opt == "adam":
                    optimizer = optim.Adam(discriminator.parameters(),
                                           lr=opt_dict["lr"])
                    opt_dict['lr'] = opt_dict["lr"]
                else:
                    raise ValueError("optimizer not supported")

        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, acc = test(discriminator, test_data_batch,
                                 test_labels_batch, "TEST", args)

        discriminator.train()

    # reload the best discriminator checkpoint and evaluate it on the test set
    discriminator.load_state_dict(torch.load(args.save_path))
    discriminator.eval()

    with torch.no_grad():
        loss, acc = test(discriminator, test_data_batch, test_labels_batch,
                         "TEST", args)
Example #12
def main(args, args_model):
    global logging
    logging = get_logger_existing_dir(os.path.dirname(args.load_path),
                                      'log_classifier.txt')

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    vocab = {}
    if getattr(args, 'vocab_file', None) is not None:
        with open(args.vocab_file) as fvocab:
            for i, line in enumerate(fvocab):
                vocab[line.strip()] = i

        vocab = VocabEntry(vocab)

    filename_glob = args.train_data + '.seed_*.n_' + str(
        args.num_label_per_class)
    train_sets = glob.glob(filename_glob)
    print("Train sets:", train_sets)

    main_train_data = MonoTextData(args.train_data,
                                   label=args.label,
                                   vocab=vocab)
    vocab = main_train_data.vocab
    vocab_size = len(vocab)

    logging('finish reading datasets, vocab size is %d' % len(vocab))
    #sys.stdout.flush()

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args_model.device = device
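    # rebuild the encoder/decoder exactly as configured for the saved model;
    # the getattr(...) fallbacks below presumably keep this compatible with
    # older checkpoints that predate those options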

    if args_model.enc_type == 'lstm':
        args_model.pooling = getattr(args_model, 'pooling', None)
        encoder = GaussianLSTMEncoder(
            args_model,
            vocab_size,
            model_init,
            emb_init,
            pooling=args_model.pooling,
        )

    elif args_model.enc_type in ['max_avg_pool', 'max_pool', 'avg_pool']:
        args_model.skip_first_word = getattr(args_model, 'skip_first_word',
                                             None)
        encoder = GaussianPoolEncoder(
            args_model,
            vocab_size,
            model_init,
            emb_init,
            enc_type=args_model.enc_type,
            skip_first_word=args_model.skip_first_word)
        #args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    args_model.encode_length = getattr(args_model, 'encode_length', None)
    if args_model.dec_type == 'lstm':
        decoder = LSTMDecoder(args_model, vocab, model_init, emb_init,
                              args_model.encode_length)
    elif args_model.dec_type == 'unigram':
        decoder = UnigramDecoder(args_model, vocab, model_init, emb_init)
    else:
        raise ValueError("the specified decoder type is not supported")

    vae = VAE(encoder, decoder, args_model,
              args_model.encode_length).to(device)

    if args.load_path:
        print("load args!")
        print(vae)
        loaded_state_dict = torch.load(args.load_path)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    vae.eval()
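    # the helper below encodes a labelled file with the frozen VAE:
    # read_dataset returns latent codes and labels; when classifying from
    # samples, codes.shape[1] == nz * 2 indicates the encoder emitted both
    # mean and log-variance, and a single sample per sentence is drawn for
    # evaluation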

    def preprocess(data_fn):
        codes, labels = read_dataset(data_fn, vocab, device, vae,
                                     args.classify_using_samples)
        if args.classify_using_samples:
            is_gaussian_enc = codes.shape[1] == (vae.encoder.nz * 2)
            codes = augment_dataset(codes, 1, is_gaussian_enc,
                                    vae)  # use only 1 sample for test
        codes = codes.cpu().numpy()
        labels = labels.cpu().numpy()
        return codes, labels

    test_codes, test_labels = preprocess(args.test_data)

    test_f1_scores = []
    average_f1 = 'macro'
    f1_scorer = make_scorer(f1_score,
                            average=average_f1,
                            labels=np.unique(test_labels),
                            greater_is_better=True)
    # log loss: negative log likelihood. We should minimize that, so greater_is_better=False
    log_loss_scorer = make_scorer(log_loss,
                                  needs_proba=True,
                                  greater_is_better=False)
    warnings.filterwarnings('ignore')
    results = {
        'n_samples_per_class': args.num_label_per_class,
    }
    n_repeats = args.n_repeats

    n_splits = min(args.num_label_per_class, 5)
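    # folds are capped by the number of labelled examples per class so that
    # stratified K-fold always has at least one sample of each class per fold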
    for i, fn in enumerate(train_sets):
        codes, labels = preprocess(fn)
        if args.resample > 1:
            # going to augment the training set by sampling
            # then create a new cross validation function to get the correct indices
            cross_val = augment_cross_val(labels, args.resample, n_splits,
                                          n_repeats)
            labels = np.repeat(labels, args.resample)
        else:
            cross_val = RepeatedStratifiedKFold(n_splits=n_splits,
                                                n_repeats=n_repeats)

        scaler = StandardScaler()
        codes = scaler.fit_transform(codes)
        scaled_test_codes = scaler.transform(test_codes)
        gridsearch = GridSearchCV(
            LogisticRegression(solver='sag', multi_class='auto'),
            {
                "penalty": ['l2'],
                "C": [0.01, 0.1, 1, 10, 100],
            },
            cv=cross_val,
            scoring={
                "f1": f1_scorer,
                "log": log_loss_scorer,
            },
            refit=False,
        )
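        # refit=False because two scorers are tracked (macro F1 and log loss);
        # refit_and_eval below presumably refits the best configuration per
        # metric and scores it on the held-out, rescaled test codes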
        clf = gridsearch
        clf.fit(codes, labels)
        crossval_f1, test_f1 = refit_and_eval(
            'f1',
            clf,
            clf.cv_results_,
            codes,
            labels,
            scaled_test_codes,
            test_labels,
            f1_scorer,
        )
        crossval_log, test_log_loss = refit_and_eval(
            'log',
            clf,
            clf.cv_results_,
            codes,
            labels,
            scaled_test_codes,
            test_labels,
            log_loss_scorer,
        )
        results[i] = {
            "F1": {
                'crossval': crossval_f1,
                'test': test_f1
            },
            "log": {
                'crossval': crossval_log,
                'test': test_log_loss
            },
        }
        print(results[i])

    n_per_class = str(args.num_label_per_class)
    if args.classify_using_samples:
        resample = 1 if args.resample == -1 else args.resample
        output_fn = os.path.join(
            args.exp_dir,
            'results_sample_' + str(resample) + '_' + n_per_class + '.json')
    else:
        output_fn = os.path.join(args.exp_dir,
                                 'results_' + n_per_class + '.json')
    with open(output_fn, 'w') as f:
        json.dump(results, f)