Code example #1
File: transfer.py  Project: ghazi-f/CP-VAE
def main(args):
    conf = config.CONFIG[args.data_name]
    data_pth = "data/%s" % args.data_name
    train_data_pth = os.path.join(data_pth, "train_data.txt")
    train_feat_pth = os.path.join(data_pth, "train_%s.npy" % args.feat)
    train_data = MonoTextData(train_data_pth, True)
    train_feat = np.load(train_feat_pth)
    vocab = train_data.vocab
    dev_data_pth = os.path.join(data_pth, "dev_data.txt")
    dev_feat_pth = os.path.join(data_pth, "dev_%s.npy" % args.feat)
    dev_data = MonoTextData(dev_data_pth, True, vocab=vocab)
    dev_feat = np.load(dev_feat_pth)
    test_data_pth = os.path.join(data_pth, "test_data.txt")
    test_feat_pth = os.path.join(data_pth, "test_%s.npy" % args.feat)
    test_data = MonoTextData(test_data_pth, True, vocab=vocab)
    test_feat = np.load(test_feat_pth)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    kwargs = {
        "train": ([1], None),
        "valid": (None, None),
        "test": (None, None),
        "feat": None,
        "bsz": 32,
        "save_path": args.load_path,
        "logging": None,
        "text_only": args.text_only,
    }
    params = conf["params"]
    params["vae_params"]["vocab"] = vocab
    params["vae_params"]["device"] = device
    params["vae_params"]["text_only"] = args.text_only
    params["vae_params"]["mlp_ni"] = dev_feat.shape[1]
    kwargs = dict(kwargs, **params)

    model = DecomposedVAE(**kwargs)
    model.load(args.load_path)
    model.vae.eval()

    train_data, train_feat = train_data.create_data_batch_feats(
        32, train_feat, device)
    print("Collecting training distributions...")
    mus, logvars = [], []
    step = 0
    for batch_data, batch_feat in zip(train_data, train_feat):
        mu1, logvar1 = model.vae.lstm_encoder(batch_data)
        mu2, logvar2 = model.vae.mlp_encoder(batch_feat)
        r, _ = model.vae.mlp_encoder(batch_feat, True)
        p = model.vae.get_var_prob(r)
        mu = torch.cat([mu1, mu2], -1)
        logvar = torch.cat([logvar1, logvar2], -1)
        mus.append(mu.detach().cpu())
        logvars.append(logvar.detach().cpu())
        step += 1
        if step % 100 == 0:
            torch.cuda.empty_cache()
    mus = torch.cat(mus, 0)
    logvars = torch.cat(logvars, 0)

    # Identify the variational-embedding index most probable for the first
    # few dev samples (treated as the negative class).
    if args.text_only:
        neg_sample = dev_data.data[:10]
        neg_inputs, _ = dev_data._to_tensor(neg_sample,
                                            batch_first=False,
                                            device=device)
    else:
        neg_sample = dev_feat[:10]
        neg_inputs = torch.tensor(neg_sample,
                                  dtype=torch.float,
                                  requires_grad=False,
                                  device=device)
    r, _ = model.vae.mlp_encoder(neg_inputs, True)
    p = model.vae.get_var_prob(r).mean(0, keepdim=True)
    neg_idx = torch.max(p, 1)[1].item()

    # Do the same for the last few dev samples (treated as the positive class).
    if args.text_only:
        pos_sample = dev_data.data[-10:]
        pos_inputs, _ = dev_data._to_tensor(pos_sample,
                                            batch_first=False,
                                            device=device)
    else:
        pos_sample = dev_feat[-10:]
        pos_inputs = torch.tensor(pos_sample,
                                  dtype=torch.float,
                                  requires_grad=False,
                                  device=device)
    r, _ = model.vae.mlp_encoder(pos_inputs, True)
    p = model.vae.get_var_prob(r).mean(0, keepdim=True)
    top2 = torch.topk(p, 2, 1)[1].squeeze()
    if top2[0].item() == neg_idx:
        print("Collision!!! Use second most as postive.")
        pos_idx = top2[1].item()
    else:
        pos_idx = top2[0].item()
    other_idx = -1
    for i in range(3):
        if i not in [pos_idx, neg_idx]:
            other_idx = i
            break

    print("Negative: %d" % neg_idx)
    print("Positive: %d" % pos_idx)

    # Index of the first test example whose label is 1; used below to keep
    # batches from crossing the label boundary.
    sep_id = -1
    for idx, x in enumerate(test_data.labels):
        if x == 1:
            sep_id = idx
            break

    # Decode each test sentence with its content code z1 and the opposite
    # style's basis vector, and score both the original and the transferred
    # latent codes under the aggregate training posterior.
    bsz = 64
    ori_logps = []
    tra_logps = []
    pos_z2 = model.vae.mlp_encoder.var_embedding[pos_idx:pos_idx + 1]
    neg_z2 = model.vae.mlp_encoder.var_embedding[neg_idx:neg_idx + 1]
    other_z2 = model.vae.mlp_encoder.var_embedding[other_idx:other_idx + 1]
    _, d0 = get_coordinates(pos_z2[0], neg_z2[0], other_z2[0])
    ori_obs = []
    tra_obs = []
    with open(os.path.join(args.load_path, 'generated_results.txt'), "w") as f:
        idx = 0
        step = 0
        n_samples = len(test_data.labels)
        while idx < n_samples:
            label = test_data.labels[idx]
            _idx = idx + bsz if label else min(idx + bsz, sep_id)
            _idx = min(_idx, n_samples)
            var_id = neg_idx if label else pos_idx
            text, _ = test_data._to_tensor(test_data.data[idx:_idx],
                                           batch_first=False,
                                           device=device)
            feat = torch.tensor(test_feat[idx:_idx],
                                dtype=torch.float,
                                requires_grad=False,
                                device=device)
            z1, _ = model.vae.lstm_encoder(text[:min(text.shape[0], 10)])
            ori_z2, _ = model.vae.mlp_encoder(feat)
            tra_z2 = model.vae.mlp_encoder.var_embedding[var_id:var_id +
                                                         1, :].expand(
                                                             _idx - idx, -1)
            texts = model.vae.decoder.beam_search_decode(z1, tra_z2)
            for text in texts:
                f.write("%d\t%s\n" % (1 - label, " ".join(text[1:-1])))

            ori_z = torch.cat([z1, ori_z2], -1)
            tra_z = torch.cat([z1, tra_z2], -1)
            for i in range(_idx - idx):
                ori_logps.append(
                    cal_log_density(mus, logvars, ori_z[i:i + 1].cpu()))
                tra_logps.append(
                    cal_log_density(mus, logvars, tra_z[i:i + 1].cpu()))

            idx = _idx
            step += 1
            if step % 100 == 0:
                print(step, idx)

    with open(os.path.join(args.load_path, 'nll.txt'), "w") as f:
        for x, y in zip(ori_logps, tra_logps):
            f.write("%f\t%f\n" % (x, y))
Code example #2
File: run.py  Project: AkshayaRaju/CP-VAE
def main(args):
    conf = config.CONFIG[args.data_name]
    data_pth = "data/%s" % args.data_name

    train_sentiment_data_pth = os.path.join(data_pth,
                                            "train_sentiment_data.txt")
    train_sentiment_feat_pth = os.path.join(
        data_pth, "train_sentiment_%s.npy" % args.feat)
    train_sentiment_data = MonoTextData(train_sentiment_data_pth, True)
    train_sentiment_feat = np.load(train_sentiment_feat_pth)

    train_tense_data_pth = os.path.join(data_pth, "train_tense_data.txt")
    train_tense_feat_pth = os.path.join(data_pth,
                                        "train_tense_%s.npy" % args.feat)
    train_tense_data = MonoTextData(train_tense_data_pth, True)
    train_tense_feat = np.load(train_tense_feat_pth)

    sentiment_vocab = train_sentiment_data.vocab
    print('Sentiment Vocabulary size: %d' % len(sentiment_vocab))

    tense_vocab = train_tense_data.vocab
    print('Tense Vocabulary size: %d' % len(tense_vocab))

    dev_sentiment_data_pth = os.path.join(data_pth, "dev_sentiment_data.txt")
    dev_sentiment_feat_pth = os.path.join(data_pth,
                                          "dev_sentiment_%s.npy" % args.feat)
    dev_sentiment_data = MonoTextData(dev_sentiment_data_pth,
                                      True,
                                      vocab=sentiment_vocab)
    dev_sentiment_feat = np.load(dev_sentiment_feat_pth)

    dev_tense_data_pth = os.path.join(data_pth, "dev_tense_data.txt")
    dev_tense_feat_pth = os.path.join(data_pth, "dev_tense_%s.npy" % args.feat)
    dev_tense_data = MonoTextData(dev_tense_data_pth, True, vocab=tense_vocab)
    dev_tense_feat = np.load(dev_tense_feat_pth)

    test_sentiment_data_pth = os.path.join(data_pth, "test_sentiment_data.txt")
    test_sentiment_feat_pth = os.path.join(data_pth,
                                           "test_sentiment_%s.npy" % args.feat)
    test_sentiment_data = MonoTextData(test_sentiment_data_pth,
                                       True,
                                       vocab=sentiment_vocab)
    test_sentiment_feat = np.load(test_sentiment_feat_pth)

    test_tense_data_pth = os.path.join(data_pth, "test_tense_data.txt")
    test_tense_feat_pth = os.path.join(data_pth,
                                       "test_tense_%s.npy" % args.feat)
    test_tense_data = MonoTextData(test_tense_data_pth,
                                   True,
                                   vocab=tense_vocab)
    test_tense_feat = np.load(test_tense_feat_pth)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    save_path0 = 'sentiment-{}-{}-{}'.format(args.save, args.data_name,
                                             args.feat)
    save_path0 = os.path.join(save_path0, time.strftime("%Y%m%d-%H%M%S"))
    save_path1 = 'tense-{}-{}-{}'.format(args.save, args.data_name, args.feat)
    save_path1 = os.path.join(save_path1, time.strftime("%Y%m%d-%H%M%S"))

    scripts_to_save = [
        'run.py', 'models/decomposed_vae.py', 'models/vae.py',
        'models/base_network.py', 'config.py'
    ]
    logging0 = create_exp_dir(save_path0,
                              scripts_to_save=scripts_to_save,
                              debug=args.debug)
    logging1 = create_exp_dir(save_path1,
                              scripts_to_save=scripts_to_save,
                              debug=args.debug)

    if args.text_only:
        train_sentiment = train_sentiment_data.create_data_batch(
            args.bsz, device)
        dev_sentiment = dev_sentiment_data.create_data_batch(args.bsz, device)
        test_sentiment = test_sentiment_data.create_data_batch(
            args.bsz, device)
        feat_sentiment = train_sentiment

        train_tense = train_tense_data.create_data_batch(args.bsz, device)
        test_tense = test_tense_data.create_data_batch(args.bsz, device)
        feat_tense = train_tense
    else:
        train_sentiment = train_sentiment_data.create_data_batch_feats(
            args.bsz, train_sentiment_feat, device)
        dev_sentiment = dev_sentiment_data.create_data_batch_feats(
            args.bsz, dev_sentiment_feat, device)
        test_sentiment = test_sentiment_data.create_data_batch_feats(
            args.bsz, test_sentiment_feat, device)
        feat_sentiment = train_sentiment_feat
        train_tense = train_tense_data.create_data_batch_feats(
            args.bsz, train_tense_feat, device)
        test_tense = test_tense_data.create_data_batch_feats(
            args.bsz, test_tense_feat, device)
        feat_tense = train_tense_feat

    #VAE training on sentiment data
    # kwargs0 = {
    #     "train": train_sentiment,
    #     "valid": dev_sentiment,
    #     "test": test_sentiment,
    #     "feat": feat_sentiment,
    #     "bsz": args.bsz,
    #     "save_path": save_path0,
    #     "logging": logging0,
    #     "text_only": args.text_only,
    # }
    # params = conf["params"]
    # params["vae_params"]["vocab"] = sentiment_vocab
    # params["vae_params"]["device"] = device
    # params["vae_params"]["text_only"] = args.text_only
    # params["vae_params"]["mlp_ni"] = train_sentiment_feat.shape[1]
    # kwargs0 = dict(kwargs0, **params)

    # sentiment_model = DecomposedVAE(**kwargs0)
    # try:
    #     valid_loss = sentiment_model.fit()
    #     logging("sentiment val loss : {}".format(valid_loss))
    # except KeyboardInterrupt:
    #     logging("Exiting from training early")

    # sentiment_model.load(save_path0)
    # test_loss = model.evaluate(sentiment_model.test_data, sentiment_model.test_feat)
    # logging("sentiment test loss: {}".format(test_loss[0]))
    # logging("sentiment test recon: {}".format(test_loss[1]))
    # logging("sentiment test kl1: {}".format(test_loss[2]))
    # logging("sentiment test kl2: {}".format(test_loss[3]))
    # logging("sentiment test mi1: {}".format(test_loss[4]))
    # logging("sentiment test mi2: {}".format(test_loss[5]))

    #VAE training on tense data
    kwargs1 = {
        "train": train_tense,
        "valid": test_tense,
        "test": test_tense,
        "feat": feat_tense,
        "bsz": args.bsz,
        "save_path": save_path1,
        "logging": logging1,
        "text_only": args.text_only,
    }
    params = conf["params"]
    params["vae_params"]["vocab"] = tense_vocab
    params["vae_params"]["device"] = device
    params["vae_params"]["text_only"] = args.text_only
    params["vae_params"]["mlp_ni"] = train_tense_feat.shape[1]
    kwargs1 = dict(kwargs1, **params)

    tense_model = DecomposedVAE(**kwargs1)
    try:
        valid_loss = tense_model.fit()
        logging("tense val loss : {}".format(valid_loss))
    except KeyboardInterrupt:
        logging("Exiting from training early")

    tense_model.load(save_path1)
    test_loss = model.evaluate(tense_model.test_data, tense_model.test_feat)
    logging("tense test loss: {}".format(test_loss[0]))
    logging("tense test recon: {}".format(test_loss[1]))
    logging("tense test kl1: {}".format(test_loss[2]))
    logging("tense test kl2: {}".format(test_loss[3]))
    logging("tense test mi1: {}".format(test_loss[4]))
    logging("tense test mi2: {}".format(test_loss[5]))
Code example #3
def main(args):
    conf = config.CONFIG[args.data_name]
    data_pth = "data/%s" % args.data_name
    train_data_pth = os.path.join(data_pth, "train_input_data.csv")
    train_feat_pth = os.path.join(data_pth, "train_%s.npy" % args.feat)
    train_data = MonoTextData(train_data_pth, True)
    train_feat = np.load(train_feat_pth)

    vocab = train_data.vocab
    print('Vocabulary size: %d' % len(vocab))

    dev_data_pth = os.path.join(data_pth, "dev_input_data.csv")
    dev_feat_pth = os.path.join(data_pth, "dev_%s.npy" % args.feat)
    dev_data = MonoTextData(dev_data_pth, True, vocab=vocab)
    dev_feat = np.load(dev_feat_pth)
    test_data_pth = os.path.join(data_pth, "test_input_data.csv")
    test_feat_pth = os.path.join(data_pth, "test_%s.npy" % args.feat)
    test_data = MonoTextData(test_data_pth, True, vocab=vocab)
    test_feat = np.load(test_feat_pth)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    save_path = '{}-{}-{}'.format(args.save, args.data_name, args.feat)
    save_path = os.path.join(save_path, time.strftime("%Y%m%d-%H%M%S"))
    scripts_to_save = [
        'run.py', 'models/decomposed_vae.py', 'models/vae.py',
        'models/base_network.py', 'config.py'
    ]
    logging = create_exp_dir(save_path,
                             scripts_to_save=scripts_to_save,
                             debug=args.debug)

    if args.text_only:
        train, train_sentiments, train_tenses = train_data.create_data_batch_labels(
            args.bsz, device)
        dev, dev_sentiments, dev_tenses = dev_data.create_data_batch_labels(
            args.bsz, device)
        test, test_sentiments, test_tenses = test_data.create_data_batch_labels(
            args.bsz, device)
        feat = train
    else:
        train = train_data.create_data_batch_feats(args.bsz, train_feat,
                                                   device)
        dev = dev_data.create_data_batch_feats(args.bsz, dev_feat, device)
        test = test_data.create_data_batch_feats(args.bsz, test_feat, device)
        feat = train_feat
        # Sentiment/tense label batches are only produced by
        # create_data_batch_labels on the text-only path; define placeholders
        # here so the kwargs below remain valid.
        train_sentiments = train_tenses = None
        dev_sentiments = dev_tenses = None
        test_sentiments = test_tenses = None

    print("data done.")

    kwargs = {
        "train": train,
        "valid": dev,
        "test": test,
        "train_sentiments": train_sentiments,
        "train_tenses": train_tenses,
        "dev_sentiments": dev_sentiments,
        "dev_tenses": dev_tenses,
        "test_sentiments": test_sentiments,
        "test_tenses": test_tenses,
        "feat": feat,
        "bsz": args.bsz,
        "save_path": save_path,
        "logging": logging,
        "text_only": args.text_only,
    }
    params = conf["params"]
    params["vae_params"]["vocab"] = vocab
    params["vae_params"]["device"] = device
    params["vae_params"]["text_only"] = args.text_only
    params["vae_params"]["mlp_ni"] = train_feat.shape[1]
    kwargs = dict(kwargs, **params)

    model = DecomposedVAE(**kwargs)
    try:
        valid_loss = model.fit()
        logging("val loss : {}".format(valid_loss))
    except KeyboardInterrupt:
        logging("Exiting from training early")

    model.load(save_path)
    test_loss = model.evaluate(model.test_data, model.test_feat)
    logging("test loss: {}".format(test_loss[0]))
    logging("test recon: {}".format(test_loss[1]))
    logging("test kl1: {}".format(test_loss[2]))
    logging("test kl2: {}".format(test_loss[3]))
    logging("test mi1: {}".format(test_loss[4]))
    logging("test mi2: {}".format(test_loss[5]))