Code example #1
File: vae.py  Project: eracah/supervise-thyself
    def __init__(self, embed_len=32, **kwargs):
        super(VAE, self).__init__()
        self.embed_len = embed_len
        self.encoder = Encoder(embed_len=embed_len, **kwargs)

        self.logvar_fc = nn.Linear(in_features=self.encoder.enc_out_shape,
                                   out_features=self.embed_len)

        self.decoder = Decoder(self.encoder)
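
For reference, a reparameterized forward pass over the layers above might look like the sketch below; the mu_fc mean head and the exact encoder/decoder call signatures are assumptions for illustration and do not appear in the original file.

import torch

def vae_forward(vae, x):
    # Sketch only: assumes the encoder returns a flat [batch, enc_out_shape]
    # tensor and that a hypothetical `mu_fc` mean head exists next to `logvar_fc`.
    feats = vae.encoder(x)
    mu = vae.mu_fc(feats)                  # hypothetical mean head
    logvar = vae.logvar_fc(feats)          # log-variance head from the constructor
    std = torch.exp(0.5 * logvar)
    z = mu + std * torch.randn_like(std)   # reparameterization trick
    return vae.decoder(z), mu, logvar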
Code example #2
File: tdc.py  Project: eracah/supervise-thyself
    def __init__(self, embed_len=32, **kwargs):
        super(TDC, self).__init__()
        self.args = kwargs["args"]
        self.embed_len = embed_len
        self.encoder = Encoder(embed_len=embed_len, **kwargs)
        self.interval_choices = [[0], [1], [2], [3, 4], list(range(5, 10))]
        self.num_buckets = len(self.interval_choices)
        self.temp_dist_predictor = LinearModel(in_feat=2 * self.embed_len,
                                               out_feat=self.num_buckets)
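
The in_feat=2 * self.embed_len input of temp_dist_predictor suggests that the embeddings of two frames are concatenated before the temporal-distance bucket is predicted; below is a minimal sketch under that assumption (the function name and signature are hypothetical).

import torch

def tdc_forward(model, frame_a, frame_b):
    # Sketch only: assumes the encoder maps a batch of frames to [batch, embed_len].
    emb_a = model.encoder(frame_a)
    emb_b = model.encoder(frame_b)
    pair = torch.cat([emb_a, emb_b], dim=1)   # matches in_feat=2 * embed_len
    return model.temp_dist_predictor(pair)    # logits over the num_buckets intervals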
Code example #3
    def __init__(self, num_frames=3, embed_len=32, **kwargs):
        super(ShuffleNLearn, self).__init__()
        self.embed_len = embed_len
        self.encoder = Encoder(embed_len=embed_len, **kwargs)
        self.bin_clsf = LinearModel(in_feat=num_frames * self.embed_len,
                                    out_feat=2)
        self.args = kwargs["args"]
        self.stride = self.args.stride
        self.num_frames = self.args.frames_per_example
Code example #4
def get_encodings(model: Encoder, data_loader: BaseDataLoader,
                  load_encoding: bool, save_file, is_src, is_mid,
                  encoding_num):
    if not load_encoding:
        batches = data_loader.create_batches("test",
                                             is_src=is_src,
                                             is_mid=is_mid)
        # encodings = np.empty((0, encoder.hidden_size*2))
        encodings = [[] for _ in range(encoding_num)]
        start_time = time.time()
        for idx, batch in enumerate(batches):
            if (idx + 1) % 10000 == 0:
                print("[INFO] processed {} batches, using {:.2f} seconds".format(
                    idx + 1,
                    time.time() - start_time))
            cur_encodings = np.array(
                model.calc_encode(batch, is_src=is_src, is_mid=is_mid).cpu())
            append_multiple_encodings(encodings, cur_encodings, encoding_num)
        encodings = list2nparr(encodings, model.hidden_size, merge=True)
        print("[INFO] encoding shape: {}".format(str(encodings.shape)))
        print("[INFO] finished all {} batches, using {:.2f} seconds".format(
            len(batches),
            time.time() - start_time))
        # np.save(save_file, encodings)
        # print("[INFO] save test encodings!")
    else:
        encodings = np.load(save_file)
        print("[INFO] load test encodings!")
        print(encodings.shape)

    if is_mid:
        kb_ids, data_plain = get_kb_id(data_loader.test_file.mid_file_name,
                                       data_loader.test_file.mid_str_idx,
                                       data_loader.test_file.mid_id_idx)
    else:
        if is_src:
            kb_ids, data_plain = get_kb_id(data_loader.test_file.src_file_name,
                                           data_loader.test_file.src_str_idx,
                                           data_loader.test_file.src_id_idx)
        else:
            kb_ids, data_plain = get_kb_id(data_loader.test_file.trg_file_name,
                                           data_loader.test_file.trg_str_idx,
                                           data_loader.test_file.trg_id_idx)

    assert kb_ids.shape[0] == int(encodings.shape[0] / encoding_num) \
           and len(data_plain) == int(encodings.shape[0] / encoding_num), \
        (kb_ids.shape[0], int(encodings.shape[0] / encoding_num), len(data_plain))

    return encodings, kb_ids, data_plain
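
get_encodings relies on two helpers that are not shown in this excerpt, append_multiple_encodings and list2nparr. Judging only from the call sites, the first distributes each batch's rows across encoding_num per-variant lists and the second stacks those lists into NumPy matrices; the sketch below is a guess at that behavior, not the project's actual implementation.

import numpy as np

def append_multiple_encodings(encodings, cur_encodings, encoding_num):
    # Hypothetical: route the batch rows round-robin into the encoding_num buckets.
    for i, row in enumerate(cur_encodings):
        encodings[i % encoding_num].append(row)

def list2nparr(encodings, hidden_size, merge=False):
    # Hypothetical: turn each bucket into an [n, hidden_size] array and, when
    # merge=True, concatenate the buckets into a single matrix.
    arrays = [np.asarray(bucket).reshape(-1, hidden_size) for bucket in encodings]
    return np.concatenate(arrays, axis=0) if merge else arrays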
Code example #5
def eval_data(model: Encoder, train_batches: List[BaseBatch], dev_batches: List[BaseBatch],
              similarity_measure: Similarity, args_dict: dict):
    use_mid = args_dict["use_mid"]
    topk = args_dict["topk"]
    trg_encoding_num = args_dict["trg_encoding_num"]
    mid_encoding_num = args_dict["mid_encoding_num"]
    # treat train target strings as the KB
    recall = 0
    tot = 0
    KB_encodings = [[] for _ in range(trg_encoding_num)]
    KB_ids = []
    for batch in train_batches:
        cur_encodings = np.array(model.calc_encode(batch, is_src=False).cpu())
        append_multiple_encodings(KB_encodings, cur_encodings, trg_encoding_num)
        KB_ids += batch.trg_kb_ids
    assert len(KB_encodings[0]) == len(train_batches)
    KB_encodings = list2nparr(KB_encodings, model.hidden_size)

    src_encodings = []
    trg_encodings = [[] for _ in range(trg_encoding_num)]
    trg_kb_ids = []
    for batch in dev_batches:
        src_encodings.append(np.array(model.calc_encode(batch, is_src=True).cpu()))
        cur_encodings = np.array(model.calc_encode(batch, is_src=False).cpu())
        append_multiple_encodings(trg_encodings, cur_encodings, trg_encoding_num)
        trg_kb_ids += batch.trg_kb_ids
    assert len(src_encodings) == len(dev_batches)
    assert len(trg_encodings[0]) == len(dev_batches)

    src_encodings = list2nparr([src_encodings], model.hidden_size, merge=True)
    trg_encodings = list2nparr(trg_encodings, model.hidden_size)


    # TODO might need it in the future
    # prune KB_encodings so that all entities are unique
    unique_kb_idx = get_unique_kb_idx(KB_ids)
    KB_encodings = [x[unique_kb_idx] for x in KB_encodings]

    all_trg_encodings = merge_encodings(trg_encodings, KB_encodings)
    # cap the number of candidate target encodings at 160000
    n = min(all_trg_encodings.shape[0], 160000)
    all_trg_encodings = all_trg_encodings[:n]
    # calculate similarity
    # [dev_size, dev_size + kb_size]
    scores = similarity_measure(src_encodings, all_trg_encodings, is_src_trg=True, split=True,
                                pieces=10, negative_sample=None, encoding_num=trg_encoding_num)
    encoding_scores = np.copy(scores)
    if use_mid:
        mid_KB_encodings = [[] for _ in range(mid_encoding_num)]
        for batch in train_batches:
            cur_encodings = np.array(model.calc_encode(batch, is_src=False, is_mid=True).cpu())
            append_multiple_encodings(mid_KB_encodings, cur_encodings, mid_encoding_num)
            KB_ids += batch.trg_kb_ids
        assert len(mid_KB_encodings[0]) == len(train_batches)
        mid_KB_encodings = list2nparr(mid_KB_encodings, model.hidden_size)

        mid_encodings = [[] for _ in range(mid_encoding_num)]
        for batch in dev_batches:
            cur_encodings = np.array(model.calc_encode(batch, is_src=False, is_mid=True).cpu())
            append_multiple_encodings(mid_encodings, cur_encodings, mid_encoding_num)
        assert len(mid_encodings[0]) == len(dev_batches)
        mid_encodings = list2nparr(mid_encodings, model.hidden_size)

        # TODO might need it in the future
        # mid_KB_encodings = mid_KB_encodings[unique_kb_idx]
        all_mid_encodings = merge_encodings(mid_encodings, mid_KB_encodings)
        all_mid_encodings = all_mid_encodings[:n]

        mid_scores = similarity_measure(src_encodings, all_mid_encodings,
                                        is_src_trg=False, split=True, pieces=10, negative_sample=None, encoding_num=mid_encoding_num)
        scores = np.maximum(scores, mid_scores)
    for entry_idx, entry_scores in enumerate(scores):
        ranked_idxes = entry_scores.argsort()[::-1]
        # the correct index is entry_idx
        if entry_idx in ranked_idxes[:topk]:
            recall += 1
        tot += 1

    recall_2 = 0
    for entry_idx, entry_scores in enumerate(encoding_scores):
        ranked_idxes = entry_scores.argsort()[::-1]
        # the correct index is entry_idx
        if entry_idx in ranked_idxes[:topk]:
            recall_2 += 1

    return [recall, recall_2], tot
Code example #6
def run(data_loader: BaseDataLoader, encoder: Encoder, criterion, optimizer: optim,
        scheduler: optim.lr_scheduler, similarity_measure: Similarity, save_model,
        args: argparse.Namespace):
    encoder.to(device)
    best_accs = {"encode_acc": float('-inf'), "pivot_acc": float('-inf')}
    last_update = 0
    dev_arg_dict = {
        "use_mid": args.use_mid,
        "topk": args.val_topk,
        "trg_encoding_num": args.trg_encoding_num,
        "mid_encoding_num": args.mid_encoding_num
    }
    # lr_decay = scheduler is not None
    # if lr_decay:
    #     print("[INFO] using learning rate decay")
    for ep in range(args.max_epoch):
        encoder.train()
        train_loss = 0.0
        start_time = time.time()
        # if not args.mega:
        train_batches = data_loader.create_batches("train")
        # else:
        #     if ep <= 30:
        #         train_batches = data_loader.create_batches("train")
        #     else:
        #         train_batches = data_loader.create_megabatch(encoder)
        batch_num = 0
        t = 0
        for idx, batch in enumerate(train_batches):
            optimizer.zero_grad()
            cur_loss = calc_batch_loss(encoder, criterion, batch, args.mid_proportion, args.trg_encoding_num, args.mid_encoding_num)
            train_loss += cur_loss.item()
            cur_loss.backward()
            # optimizer.step()

            for p in list(filter(lambda p: p.grad is not None, encoder.parameters())):
                t += p.grad.data.norm(2).item()

            torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=5)
            optimizer.step()

            if encoder.name == "bilstm":
                # set all but forget gate bias to 0
                reset_bias(encoder.src_lstm)
                reset_bias(encoder.trg_lstm)
                # pass
            batch_num += 1
        print("[INFO] epoch {:d}: train loss={:.8f}, time={:.2f}".format(ep, train_loss / batch_num,
                                                                         time.time()-start_time))
        # print(t)

        if (ep + 1) % EPOCH_CHECK == 0:
            with torch.no_grad():
                encoder.eval()
                # eval
                train_batches = data_loader.create_batches("train")
                dev_batches = data_loader.create_batches("dev")
                start_time = time.time()

                recall, tot = eval_data(encoder, train_batches, dev_batches, similarity_measure, dev_arg_dict)
                dev_pivot_acc = recall[0] / float(tot)
                dev_encode_acc = recall[1] / float(tot)
                if dev_encode_acc > best_accs["encode_acc"]:
                    best_accs["encode_acc"] = dev_encode_acc
                    best_accs["pivot_acc"] = dev_pivot_acc
                    last_update = ep + 1
                    save_model(encoder, ep + 1, train_loss / batch_num, optimizer, args.model_path + "_" + "best" + ".tar")
                save_model(encoder, ep + 1, train_loss / batch_num, optimizer, args.model_path + "_" + "last" + ".tar")
                print("[INFO] epoch {:d}: encoding/pivoting dev acc={:.4f}/{:.4f}, time={:.2f}".format(
                                                                                            ep, dev_encode_acc, dev_pivot_acc,
                                                                                            time.time()-start_time))
                if args.lr_decay and ep + 1 - last_update > UPDATE_PATIENT:
                    new_lr = optimizer.param_groups[0]['lr'] * args.lr_scaler
                    best_info = torch.load(args.model_path + "_" + "best" + ".tar")
                    encoder.load_state_dict(best_info["model_state_dict"])
                    optimizer.load_state_dict(best_info["optimizer_state_dict"])
                    optimizer.param_groups[0]['lr'] = new_lr
                    print("[INFO] reload best model ..")

                if ep + 1 - last_update > PATIENT:
                    print("[FINAL] in epoch {}, the best develop encoding/pivoting accuracy = {:.4f}/{:.4f}".format(ep + 1,
                                                                                                                    best_accs["encode_acc"],
                                                                                                                    best_accs["pivot_acc"]))
                    break
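
reset_bias is referenced above but not defined in this excerpt; its in-code comment says it sets all but the forget-gate bias to 0. A sketch consistent with that comment, using PyTorch's (b_ii | b_if | b_ig | b_io) bias layout, could be:

import torch
import torch.nn as nn

def reset_bias(lstm: nn.LSTM):
    # Sketch based on the comment in `run`: zero every bias term except the
    # forget-gate slice, which lives at [hidden_size : 2 * hidden_size].
    with torch.no_grad():
        for name, param in lstm.named_parameters():
            if "bias" in name:
                hidden = lstm.hidden_size
                forget = param[hidden:2 * hidden].clone()
                param.zero_()
                param[hidden:2 * hidden] = forget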
Code example #7
    def __init__(self, embed_len=32, num_actions=3, **kwargs):
        super(InverseModel, self).__init__()
        self.embed_len = embed_len
        self.encoder = Encoder(embed_len=embed_len, **kwargs)
        self.action_predictor = LinearModel(in_feat=2 * self.embed_len,
                                            out_feat=num_actions)
Code example #8
def eval_dataset(model: Encoder, similarity_calculator: Similarity,
                 base_data_loader: BaseDataLoader, encoded_test_file,
                 load_encoded_test, encoded_kb_file, load_encoded_kb,
                 intermediate_stuff, method, trg_encoding_num,
                 mid_encoding_num, result_files: dict, record_recall: bool):
    with torch.no_grad():
        model.eval()
        model.to(device)
        encoded_test, test_gold_kb_id, test_data_plain = get_encodings(
            model,
            base_data_loader,
            load_encoded_test,
            encoded_test_file,
            is_src=True,
            is_mid=False,
            encoding_num=1)
        encoded_kb, kb_ids, kb_entity_string = get_encodings(
            model,
            base_data_loader,
            load_encoded_kb,
            encoded_kb_file,
            is_src=False,
            is_mid=False,
            encoding_num=trg_encoding_num)
        intermediate_info = {}
        if method != "base":
            intermediate_encodings = {}
            intermediate_kb_id = {}
            intermediate_plain_text = {}
            for stuff in intermediate_stuff:
                # name identifies the content of this intermediate data
                name, data_loader, encoded_file, load_encoded, is_src, is_mid = stuff
                encoded_stuff, gold_kb_id, plain_text = get_encodings(
                    model,
                    data_loader,
                    load_encoded,
                    encoded_file,
                    is_src=is_src,
                    is_mid=is_mid,
                    encoding_num=mid_encoding_num)
                intermediate_encodings[name] = encoded_stuff
                intermediate_kb_id[name] = gold_kb_id
                intermediate_plain_text[name] = plain_text
            intermediate_info["encodings"] = intermediate_encodings
            intermediate_info["kb_id"] = intermediate_kb_id
            intermediate_info["plain_text"] = intermediate_plain_text
        start_time = time.time()
        calc_result(encoded_test,
                    test_gold_kb_id,
                    test_data_plain,
                    encoded_kb,
                    kb_ids,
                    kb_entity_string,
                    intermediate_info,
                    method,
                    similarity_calculator,
                    result_files,
                    trg_encoding_num,
                    mid_encoding_num,
                    record_recall=record_recall)

        print(
            "[INFO] take {:.4f}s to calculate similarity".format(time.time() -
                                                                 start_time))