Example #1
    def __init__(self, decode_from, params, cuda=False):
        self.decode_from = decode_from
        self.params = params
        params.enc_nh = params.dec_nh  # the training scripts tie enc_nh to dec_nh, so do the same here

        self.train_data = MonoTextData(params.train_data, label=False)
        self.vocab = self.train_data.vocab
        self.vocab_size = len(self.vocab)

        # initializers required by the encoder/decoder constructors below
        model_init = uniform_initializer(0.01)
        emb_init = uniform_initializer(0.1)

        params.device = self.device = torch.device("cuda" if cuda else "cpu")

        self.encoder = LSTMEncoder(params, self.vocab_size, model_init,
                                   emb_init)
        self.decoder = LSTMDecoder(params, self.vocab, model_init, emb_init)

        self.vae = VAE(self.encoder, self.decoder, params).to(params.device)

        # checkpoints saved on GPU must be mapped to CPU when CUDA is unavailable
        if cuda:
            self.vae.load_state_dict(torch.load(self.decode_from))
        else:
            self.vae.load_state_dict(
                torch.load(self.decode_from, map_location='cpu'))
Example #2
def create_model(args, vocab):
    # build initializers
    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    # build encoder
    if args.enc_type == 'lstm':
        encoder = LSTMEncoder(args, args.vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")
    
    # build decoder
    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    vae = VAE(encoder, decoder, args).to(args.device)
    
    return vae
Example #3
def main(args):
    class uniform_initializer(object):
        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    class xavier_normal_initializer(object):
        def __call__(self, tensor):
            nn.init.xavier_normal_(tensor)

    if args.cuda:
        print('using cuda')

    print(args)

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data, label=args.label)

    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    if args.enc_type == 'lstm':
        encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device
    vae = VAE(encoder, decoder, args).to(device)

    if args.eval:
        print('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)

            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            print("%d active units" % au)
            # print(au_var)

            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)
            calc_iwnll(vae, test_data_batch, args)

        return

    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                              lr=1.0,
                              momentum=args.momentum)
    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                              lr=1.0,
                              momentum=args.momentum)
    opt_dict['lr'] = 1.0

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    aggressive_flag = bool(args.aggressive)
    vae.train()
    start = time.time()

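    # KL-weight annealing: start at kl_start and ramp linearly to 1.0 over
    # warm_up epochs' worth of minibatches.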
    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up *
                                           (len(train_data) / args.batch_size))

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)

    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)
    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_words = report_num_sents = 0
        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            batch_size, sent_len = batch_data.size()

            # the start symbol is not a prediction target
            report_num_words += (sent_len - 1) * batch_size

            report_num_sents += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_words = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
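            # Aggressive inference-network training ("burn-in"): update only the
            # encoder on freshly sampled batches, stopping once the running loss
            # per word stops improving (checked every 15 sub-iterations).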
            while aggressive_flag and sub_iter < 100:

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_batch_size, burn_sents_len = batch_data_enc.size()
                burn_num_words += (burn_sents_len - 1) * burn_batch_size

                loss, loss_rc, loss_kl = vae.loss(batch_data_enc,
                                                  kl_weight,
                                                  nsamples=args.nsamples)

                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                enc_optimizer.step()

                id_ = np.random.randint(0, len(train_data_batch))

                batch_data_enc = train_data_batch[id_]

                if sub_iter % 15 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_words
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_words = 0

                sub_iter += 1

                # if sub_iter >= 30:
                #     break

            # print(sub_iter)

            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data,
                                              kl_weight,
                                              nsamples=args.nsamples)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            if not aggressive_flag:
                enc_optimizer.step()

            dec_optimizer.step()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss +
                              report_kl_loss) / report_num_sents
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_data_batch)
                        au, _ = calc_au(vae, val_data_batch)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, ' \
                           'au %d, time elapsed %.2fs' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, au, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f, ' \
                           'time elapsed %.2fs' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start))

                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_words = report_num_sents = 0

            iter_ += 1

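            # Turn off aggressive training once mutual information on the
            # validation set stops increasing.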
            if aggressive_flag and (iter_ % len(train_data_batch)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_data_batch)
                vae.train()
                print("pre mi:%.4f. cur mi:%.4f" % (pre_mi, cur_mi))
                if cur_mi - pre_mi < 0:
                    aggressive_flag = False
                    print("STOP BURNING")

                pre_mi = cur_mi

        print('kl weight %.4f' % kl_weight)

        vae.eval()
        with torch.no_grad():
            loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
            au, au_var = calc_au(vae, val_data_batch)
            print("%d active units" % au)
            # print(au_var)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            best_ppl = ppl
            torch.save(vae.state_dict(), args.save_path)

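        # If validation loss has not improved for decay_epoch epochs (and we are
        # past epoch 15), reload the best checkpoint and decay the learning rate.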
        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch and epoch >= 15:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
                dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)

        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST",
                                             args)

        vae.train()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))

    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        print("%d active units" % au)
        # print(au_var)

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
Example #4
def main(args):
    train_data = MonoTextData(args.train_data, label=args.label)
    vocab = train_data.vocab
    vocab_size = len(vocab)
    
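    # Dump the vocabulary (one token per line) to vocab.txt next to the training data.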
    vocab_path = os.path.join(os.path.dirname(args.train_data), "vocab.txt")
    with open(vocab_path, "w") as fout:
        for i in range(vocab_size):
            fout.write("{}\n".format(vocab.id2word(i)))
        #return

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    log_niter = (len(train_data)//args.batch_size)//10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    print('begin evaluation')
    vae.load_state_dict(torch.load(args.load_path))
    vae.eval()
    with torch.no_grad():
        test_data_batch, test_batch_labels = test_data.create_data_batch_labels(batch_size=args.batch_size,
                                                      device=device,
                                                      batch_first=True)

        # test(vae, test_data_batch, "TEST", args)
        # au, au_var = calc_au(vae, test_data_batch)
        # print("%d active units" % au)

        train_data_batch, train_batch_labels = train_data.create_data_batch_labels(batch_size=args.batch_size,
                                                        device=device,
                                                        batch_first=True)

        val_data_batch, val_batch_labels = val_data.create_data_batch_labels(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

        print("getting  vectors for training")
        print(args.save_dir)
        save_latents(args, vae, train_data_batch, train_batch_labels, "train")
        print("getting  vectors for validating")
        save_latents(args, vae, val_data_batch, val_batch_labels, "val")
        print("getting  vectors for testing")
        save_latents(args, vae, test_data_batch, test_batch_labels, "test")
Example #5
def main(args):
    class uniform_initializer(object):
        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    class xavier_normal_initializer(object):
        def __call__(self, tensor):
            nn.init.xavier_normal_(tensor)

    if args.cuda:
        print('using cuda')

    print(args)

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data)

    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, vocab=vocab)
    test_data = MonoTextData(args.test_data, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
    args.enc_nh = args.dec_nh

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    vae = VAE(encoder, decoder, args).to(device)

    if args.optim == 'sgd':
        enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=1.0)
        dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=1.0)
        opt_dict['lr'] = 1.0
    else:
        enc_optimizer = optim.Adam(vae.encoder.parameters(),
                                   lr=0.001,
                                   betas=(0.9, 0.999))
        dec_optimizer = optim.Adam(vae.decoder.parameters(),
                                   lr=0.001,
                                   betas=(0.9, 0.999))
        opt_dict['lr'] = 0.001

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = -1
    aggressive_flag = bool(args.aggressive)
    vae.train()
    start = time.time()

    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up *
                                           (len(train_data) / args.batch_size))

    plot_data = train_data.data_sample(nsample=args.num_plot,
                                       device=device,
                                       batch_first=True)

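    # Visualization set-up: 'multiple' plots posteriors over a 1-D grid of z for
    # several sampled sentences, while 'single' tracks the model posterior mean
    # and the inference network's mean for one batch throughout training.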
    if args.plot_mode == 'multiple':
        grid_z = generate_grid(args.zmin, args.zmax, args.dz, device, ndim=1)
        plot_fn = plot_multiple

    elif args.plot_mode == 'single':
        grid_z = generate_grid(args.zmin, args.zmax, args.dz, device, ndim=1)
        plot_fn = plot_single
        posterior_mean = []
        infer_mean = []

        posterior_mean.append(
            vae.calc_model_posterior_mean(plot_data[0], grid_z))
        infer_mean.append(vae.calc_infer_mean(plot_data[0]))

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)

    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_words = report_num_sents = 0
        for i in np.random.permutation(len(train_data_batch)):
            if args.plot_mode == "single":
                batch_data, _ = plot_data

            else:
                batch_data = train_data_batch[i]
            batch_size, sent_len = batch_data.size()

            # the start symbol is not a prediction target
            report_num_words += (sent_len - 1) * batch_size

            report_num_sents += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_words = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_batch_size, burn_sents_len = batch_data_enc.size()
                burn_num_words += (burn_sents_len - 1) * burn_batch_size

                loss, loss_rc, loss_kl = vae.loss(batch_data_enc,
                                                  kl_weight,
                                                  nsamples=args.nsamples)

                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                enc_optimizer.step()

                if args.plot_mode == "single":
                    batch_data_enc, _ = plot_data

                else:
                    id_ = np.random.randint(0, len(train_data_batch))

                    batch_data_enc = train_data_batch[id_]

                if sub_iter % 15 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_words
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_words = 0

                sub_iter += 1

            if args.plot_mode == 'single' and epoch == 0 and aggressive_flag:
                vae.eval()
                with torch.no_grad():
                    posterior_mean.append(posterior_mean[-1])
                    infer_mean.append(vae.calc_infer_mean(plot_data[0]))
                vae.train()

            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data,
                                              kl_weight,
                                              nsamples=args.nsamples)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            if not aggressive_flag:
                enc_optimizer.step()

            dec_optimizer.step()
            if args.plot_mode == 'single' and epoch == 0:
                vae.eval()
                with torch.no_grad():
                    posterior_mean.append(
                        vae.calc_model_posterior_mean(plot_data[0], grid_z))

                    if aggressive_flag:
                        infer_mean.append(infer_mean[-1])
                    else:
                        infer_mean.append(vae.calc_infer_mean(plot_data[0]))
                vae.train()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss +
                              report_kl_loss) / report_num_sents
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    mi = calc_mi(vae, val_data_batch)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, ' \
                           'time elapsed %.2fs' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f, ' \
                           'time elapsed %.2fs' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start))

                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_words = report_num_sents = 0

            if iter_ % args.plot_niter == 0 and epoch == 0:
                vae.eval()
                with torch.no_grad():
                    if args.plot_mode == 'single' and iter_ != 0:
                        plot_fn(infer_mean, posterior_mean, args)
                        return
                    elif args.plot_mode == "multiple":
                        plot_fn(vae, plot_data, grid_z, iter_, args)
                vae.train()

            iter_ += 1

            if aggressive_flag and (iter_ % len(train_data_batch)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_data_batch)
                vae.train()
                if cur_mi - pre_mi < 0:
                    aggressive_flag = False
                    print("STOP BURNING")

                pre_mi = cur_mi

                # return

        print('kl weight %.4f' % kl_weight)
        print('epoch: %d, VAL' % epoch)

        with torch.no_grad():
            plot_fn(vae, plot_data, grid_z, iter_, args)

        vae.eval()
        with torch.no_grad():
            loss, nll, kl, ppl = test(vae, val_data_batch, "VAL", args)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            best_ppl = ppl
            torch.save(vae.state_dict(), args.save_path)

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.optim == 'sgd':
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"])
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"])
                else:
                    enc_optimizer = optim.Adam(vae.encoder.parameters(),
                                               lr=opt_dict["lr"],
                                               betas=(0.5, 0.999))
                    dec_optimizer = optim.Adam(vae.decoder.parameters(),
                                               lr=opt_dict["lr"],
                                               betas=(0.5, 0.999))
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl = test(vae, test_data_batch, "TEST", args)

        vae.train()

    print('best_loss: %.4f, kl: %.4f, nll: %.4f, ppl: %.4f' \
          % (best_loss, best_kl, best_nll, best_ppl))

    sys.stdout.flush()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
Example #6
def main(args):
    global logging
    # don't create an experiment dir for reconstruction or evaluation runs
    debug = (args.reconstruct_from != "" or args.eval)
    logging = create_exp_dir(args.exp_dir, scripts_to_save=None, debug=debug)

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data, label=args.label)

    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    #sys.stdout.flush()

    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        #curr_state_dict = vae.state_dict()
        #curr_state_dict.update(loaded_state_dict)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

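        # Optionally re-initialize the decoder parameters while keeping the
        # loaded encoder weights.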
        if args.reset_dec:
            vae.decoder.reset_parameters(model_init, emb_init)

    if args.eval:
        logging('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)

            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            logging("%d active units" % au)
            # print(au_var)

            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)

            nll, ppl = calc_iwnll(vae, test_data_batch, args)
            logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))

        return

    if args.reconstruct_from != "":
        print("begin decoding")
        sys.stdout.flush()

        vae.load_state_dict(torch.load(args.reconstruct_from))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            # test(vae, test_data_batch, "TEST", args)
            reconstruct(vae, test_data_batch, vocab, args.decoding_strategy,
                        args.reconstruct_to)

        return

    if args.opt == "sgd":
        enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum)
        dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001)
        dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    vae.train()
    start = time.time()

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)

    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in range(args.epochs):
            report_kl_loss = report_rec_loss = report_loss = 0
            report_num_words = report_num_sents = 0

            for i in np.random.permutation(len(train_data_batch)):

                batch_data = train_data_batch[i]
                batch_size, sent_len = batch_data.size()

                # the start symbol is not a prediction target
                report_num_words += (sent_len - 1) * batch_size
                report_num_sents += batch_size

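                # Fixed KL weight (beta-VAE style); this script does not anneal it.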
                kl_weight = args.beta

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

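                # Train with the standard ELBO, or with the importance-weighted
                # bound when iw_train_nsamples is set.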
                if args.iw_train_nsamples < 0:
                    loss, loss_rc, loss_kl = vae.loss(batch_data,
                                                      kl_weight,
                                                      nsamples=args.nsamples)
                else:
                    loss, loss_rc, loss_kl = vae.loss_iw(
                        batch_data,
                        kl_weight,
                        nsamples=args.iw_train_nsamples,
                        ns=ns)
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                loss_rc = loss_rc.sum()
                loss_kl = loss_kl.sum()

                enc_optimizer.step()
                dec_optimizer.step()

                report_rec_loss += loss_rc.item()
                report_kl_loss += loss_kl.item()
                report_loss += loss.item() * batch_size

                if iter_ % log_niter == 0:
                    #train_loss = (report_rec_loss  + report_kl_loss) / report_num_sents
                    train_loss = report_loss / report_num_sents
                    logging('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f, ' \
                           'time elapsed %.2fs, kl_weight %.4f' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start, kl_weight))

                    #sys.stdout.flush()

                    report_rec_loss = report_kl_loss = report_loss = 0
                    report_num_words = report_num_sents = 0

                iter_ += 1

            logging('kl weight %.4f' % kl_weight)

            vae.eval()
            with torch.no_grad():
                loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
                au, au_var = calc_au(vae, val_data_batch)
                logging("%d active units" % au)
                # print(au_var)

            if args.save_ckpt > 0 and epoch <= args.save_ckpt:
                logging('save checkpoint')
                torch.save(
                    vae.state_dict(),
                    os.path.join(args.exp_dir, f'model_ckpt_{epoch}.pt'))

            if loss < best_loss:
                logging('update best loss')
                best_loss = loss
                best_nll = nll
                best_kl = kl
                best_ppl = ppl
                torch.save(vae.state_dict(), args.save_path)

            if loss > opt_dict["best_loss"]:
                opt_dict["not_improved"] += 1
                if opt_dict[
                        "not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                    opt_dict["best_loss"] = loss
                    opt_dict["not_improved"] = 0
                    opt_dict["lr"] = opt_dict["lr"] * lr_decay
                    vae.load_state_dict(torch.load(args.save_path))
                    logging('new lr: %f' % opt_dict["lr"])
                    decay_cnt += 1
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)

            else:
                opt_dict["not_improved"] = 0
                opt_dict["best_loss"] = loss

            if decay_cnt == max_decay:
                break

            if epoch % args.test_nepoch == 0:
                with torch.no_grad():
                    loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST",
                                                 args)

            if args.save_latent > 0 and epoch <= args.save_latent:
                visualize_latent(args, epoch, vae, device, test_data)

            vae.train()

    except KeyboardInterrupt:
        logging('-' * 100)
        logging('Exiting from training early')

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))

    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        logging("%d active units" % au)
        # print(au_var)

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        nll, ppl = calc_iwnll(vae, test_data_batch, args)
        logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
Example #7
def main(args):
    global logging
    logging = create_exp_dir(args.exp_dir, scripts_to_save=[])

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    vocab = {}
    with open(args.vocab_file) as fvocab:
        for i, line in enumerate(fvocab):
            vocab[line.strip()] = i

    vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)

    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    #sys.stdout.flush()

    log_niter = max(1, (len(train_data) //
                        (args.batch_size * args.update_every)) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.fb == 3:
        encoder = DeltaGaussianLSTMEncoder(args, vocab_size, model_init,
                                           emb_init)
        args.enc_nh = args.dec_nh
    elif args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        #curr_state_dict = vae.state_dict()
        #curr_state_dict.update(loaded_state_dict)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    # if args.eval:
    #     logging('begin evaluation')
    #     vae.load_state_dict(torch.load(args.load_path))
    #     vae.eval()
    #     with torch.no_grad():
    #         test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
    #                                                       device=device,
    #                                                       batch_first=True)

    #         test(vae, test_data_batch, test_labels_batch, "TEST", args)
    #         au, au_var = calc_au(vae, test_data_batch)
    #         logging("%d active units" % au)
    #         # print(au_var)

    #         test_data_batch = test_data.create_data_batch(batch_size=1,
    #                                                       device=device,
    #                                                       batch_first=True)
    #         calc_iwnll(vae, test_data_batch, args)

    #     return

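    # Build the latent-space classifier (linear or MLP) on top of the VAE encoder.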
    if args.discriminator == "linear":
        discriminator = LinearDiscriminator(args, vae.encoder).to(device)
    elif args.discriminator == "mlp":
        discriminator = MLPDiscriminator(args, vae.encoder).to(device)
    else:
        raise ValueError("the specified discriminator type is not supported")

    if args.opt == "sgd":
        optimizer = optim.SGD(discriminator.parameters(),
                              lr=args.lr,
                              momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    # best_kl = best_nll = best_ppl = 0
    # pre_mi = 0
    discriminator.train()
    start = time.time()

    # kl_weight = args.kl_start
    # if args.warm_up > 0:
    #     anneal_rate = (1.0 - args.kl_start) / (args.warm_up * (len(train_data) / args.batch_size))
    # else:
    #     anneal_rate = 0

    # dim_target_kl = args.target_kl / float(args.nz)

    train_data_batch, train_labels_batch = train_data.create_data_batch_labels(
        batch_size=args.batch_size, device=device, batch_first=True)

    val_data_batch, val_labels_batch = val_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    test_data_batch, test_labels_batch = test_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    acc_cnt = 1
    acc_loss = 0.
    for epoch in range(args.epochs):
        report_loss = 0
        report_correct = report_num_words = report_num_sents = 0
        acc_batch_size = 0
        optimizer.zero_grad()
        for i in np.random.permutation(len(train_data_batch)):

            batch_data = train_data_batch[i]
            if batch_data.size(0) < 2:
                continue
            batch_labels = train_labels_batch[i]
            batch_labels = [int(x) for x in batch_labels]

            batch_labels = torch.tensor(batch_labels,
                                        dtype=torch.long,
                                        requires_grad=False,
                                        device=device)

            batch_size, sent_len = batch_data.size()

            # the start symbol is not a prediction target
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size
            acc_batch_size += batch_size

            # (batch_size)
            loss, correct = discriminator.get_performance(
                batch_data, batch_labels)

            acc_loss = acc_loss + loss.sum()

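            # Gradient accumulation: sum the loss over update_every minibatches,
            # normalize by the number of examples, then take one optimizer step.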
            if acc_cnt % args.update_every == 0:
                acc_loss = acc_loss / acc_batch_size
                acc_loss.backward()

                torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                               clip_grad)

                optimizer.step()
                optimizer.zero_grad()

                acc_cnt = 0
                acc_loss = 0
                acc_batch_size = 0

            acc_cnt += 1
            report_loss += loss.sum().item()
            report_correct += correct

            if iter_ % log_niter == 0:
                #train_loss = (report_rec_loss  + report_kl_loss) / report_num_sents
                train_loss = report_loss / report_num_sents


                logging('epoch: %d, iter: %d, avg_loss: %.4f, acc %.4f, ' \
                       'time %.2fs' %
                       (epoch, iter_, train_loss, report_correct / report_num_sents,
                        time.time() - start))

                #sys.stdout.flush()

            iter_ += 1

        logging('lr {}'.format(opt_dict["lr"]))
        print(report_num_sents)
        discriminator.eval()

        with torch.no_grad():
            loss, acc = test(discriminator, val_data_batch, val_labels_batch,
                             "VAL", args)
            # print(au_var)

        if loss < best_loss:
            logging('update best loss')
            best_loss = loss
            best_acc = acc
            print(args.save_path)
            torch.save(discriminator.state_dict(), args.save_path)

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict[
                    "not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                discriminator.load_state_dict(torch.load(args.save_path))
                logging('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.opt == "sgd":
                    optimizer = optim.SGD(discriminator.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
                    opt_dict['lr'] = opt_dict["lr"]
                elif args.opt == "adam":
                    optimizer = optim.Adam(discriminator.parameters(),
                                           lr=opt_dict["lr"])
                    opt_dict['lr'] = opt_dict["lr"]
                else:
                    raise ValueError("optimizer not supported")

        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, acc = test(discriminator, test_data_batch,
                                 test_labels_batch, "TEST", args)

        discriminator.train()

    # compute importance weighted estimate of log p(x)
    discriminator.load_state_dict(torch.load(args.save_path))
    discriminator.eval()

    with torch.no_grad():
        loss, acc = test(discriminator, test_data_batch, test_labels_batch,
                         "TEST", args)
Example #8
def main(args, args_model):
    global logging
    logging = get_logger_existing_dir(os.path.dirname(args.load_path),
                                      'log_classifier.txt')

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    vocab = {}
    if getattr(args, 'vocab_file', None) is not None:
        with open(args.vocab_file) as fvocab:
            for i, line in enumerate(fvocab):
                vocab[line.strip()] = i

        vocab = VocabEntry(vocab)

    filename_glob = args.train_data + '.seed_*.n_' + str(
        args.num_label_per_class)
    train_sets = glob.glob(filename_glob)
    print("Train sets:", train_sets)

    main_train_data = MonoTextData(args.train_data,
                                   label=args.label,
                                   vocab=vocab)
    vocab = main_train_data.vocab
    vocab_size = len(vocab)

    logging('finish reading datasets, vocab size is %d' % len(vocab))
    #sys.stdout.flush()

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args_model.device = device

    if args_model.enc_type == 'lstm':
        args_model.pooling = getattr(args_model, 'pooling', None)
        encoder = GaussianLSTMEncoder(
            args_model,
            vocab_size,
            model_init,
            emb_init,
            pooling=args_model.pooling,
        )

    elif args_model.enc_type in ['max_avg_pool', 'max_pool', 'avg_pool']:
        args_model.skip_first_word = getattr(args_model, 'skip_first_word',
                                             None)
        encoder = GaussianPoolEncoder(
            args_model,
            vocab_size,
            model_init,
            emb_init,
            enc_type=args_model.enc_type,
            skip_first_word=args_model.skip_first_word)
        #args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    args_model.encode_length = getattr(args_model, 'encode_length', None)
    if args_model.dec_type == 'lstm':
        decoder = LSTMDecoder(args_model, vocab, model_init, emb_init,
                              args_model.encode_length)
    elif args_model.dec_type == 'unigram':
        decoder = UnigramDecoder(args_model, vocab, model_init, emb_init)
    else:
        raise ValueError("the specified decoder type is not supported")

    vae = VAE(encoder, decoder, args_model,
              args_model.encode_length).to(device)

    if args.load_path:
        print("load args!")
        print(vae)
        loaded_state_dict = torch.load(args.load_path)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    vae.eval()

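    # Encode a labelled data file with the trained VAE and return latent codes and
    # labels as numpy arrays (one posterior sample per example when
    # classify_using_samples is set).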
    def preprocess(data_fn):
        codes, labels = read_dataset(data_fn, vocab, device, vae,
                                     args.classify_using_samples)
        if args.classify_using_samples:
            is_gaussian_enc = codes.shape[1] == (vae.encoder.nz * 2)
            codes = augment_dataset(codes, 1, is_gaussian_enc,
                                    vae)  # use only 1 sample for test
        codes = codes.cpu().numpy()
        labels = labels.cpu().numpy()
        return codes, labels

    test_codes, test_labels = preprocess(args.test_data)

    test_f1_scores = []
    average_f1 = 'macro'
    f1_scorer = make_scorer(f1_score,
                            average=average_f1,
                            labels=np.unique(test_labels),
                            greater_is_better=True)
    # log loss: negative log likelihood. We should minimize that, so greater_is_better=False
    log_loss_scorer = make_scorer(log_loss,
                                  needs_proba=True,
                                  greater_is_better=False)
    warnings.filterwarnings('ignore')
    results = {
        'n_samples_per_class': args.num_label_per_class,
    }
    n_repeats = args.n_repeats

    n_splits = min(args.num_label_per_class, 5)
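    # For each labelled training subset: standardize the codes, then cross-validate
    # a logistic regression, tuning C against both F1 and log loss.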
    for i, fn in enumerate(train_sets):
        codes, labels = preprocess(fn)
        if args.resample > 1:
            # going to augment the training set by sampling
            # then create a new cross validation function to get the correct indices
            cross_val = augment_cross_val(labels, args.resample, n_splits,
                                          n_repeats)
            labels = np.repeat(labels, args.resample)
        else:
            cross_val = RepeatedStratifiedKFold(n_splits=n_splits,
                                                n_repeats=n_repeats)

        scaler = StandardScaler()
        codes = scaler.fit_transform(codes)
        scaled_test_codes = scaler.transform(test_codes)
        gridsearch = GridSearchCV(
            LogisticRegression(solver='sag', multi_class='auto'),
            {
                "penalty": ['l2'],
                "C": [0.01, 0.1, 1, 10, 100],
            },
            cv=cross_val,
            scoring={
                "f1": f1_scorer,
                "log": log_loss_scorer,
            },
            refit=False,
        )
        clf = gridsearch
        clf.fit(codes, labels)
        crossval_f1, test_f1 = refit_and_eval(
            'f1',
            clf,
            clf.cv_results_,
            codes,
            labels,
            scaled_test_codes,
            test_labels,
            f1_scorer,
        )
        crossval_log, test_log_loss = refit_and_eval(
            'log',
            clf,
            clf.cv_results_,
            codes,
            labels,
            scaled_test_codes,
            test_labels,
            log_loss_scorer,
        )
        results[i] = {
            "F1": {
                'crossval': crossval_f1,
                'test': test_f1
            },
            "log": {
                'crossval': crossval_log,
                'test': test_log_loss
            },
        }
        print(results[i])

    n_per_class = str(args.num_label_per_class)
    if args.classify_using_samples:
        resample = 1 if args.resample == -1 else args.resample
        output_fn = os.path.join(
            args.exp_dir,
            'results_sample_' + str(resample) + '_' + n_per_class + '.json')
    else:
        output_fn = os.path.join(args.exp_dir,
                                 'results_' + n_per_class + '.json')
    with open(output_fn, 'w') as f:
        json.dump(results, f)