Example no. 1
 def run_wiki2vec(self):
     try:
         db_connection = get_connection()
         with db_connection.cursor() as cursor:
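             # Load cached lookups and the wiki2vec embeddings, then
             # assemble the context encoder and the joint model.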
             self.load_caches(cursor)
             self.wiki2vec = load_wiki2vec()
             self.init_entity_embeds_wiki2vec()
             self.context_encoder = ContextEncoder(
                 self.wiki2vec, self.lookups.token_idx_lookup, self.device)
             self.encoder = SimpleJointModel(self.context_encoder)
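             # Train a new model unless a saved checkpoint is being loaded.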
             if not self.run_params.load_model:
                 with self.experiment.train(['error', 'loss']):
                     self.log.status('Training')
                     trainer = self._get_trainer(cursor, self.encoder)
                     trainer.train()
                     torch.save(self.encoder.state_dict(),
                                './' + self.experiment.model_name)
             else:
                 path = self.experiment.model_name if self.run_params.load_path is None else self.run_params.load_path
                 self.encoder.load_state_dict(torch.load(path))
                 # Move the loaded encoder to the device; .module unwraps
                 # the DataParallel wrapper again.
                 self.encoder = nn.DataParallel(self.encoder)
                 self.encoder = self.encoder.to(self.device).module
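             # Evaluate on the test split in either case.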
             with self.experiment.test(['accuracy', 'TP', 'num_samples']):
                 self.log.status('Testing')
                 tester = self._get_tester(cursor, self.context_encoder)
                 tester.test()
     finally:
         db_connection.close()
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_dotenv(dotenv_path='.env')
    paths = m(lookups=os.getenv("LOOKUPS_PATH"),
              page_id_order=os.getenv("PAGE_ID_ORDER_PATH"))
    model_params = m(freeze_word_embeddings=True)
    runner = Runner(device, paths=paths, model_params=model_params)
    runner.load_caches()
    runner.init_entity_embeds()
    db_connection = get_connection()
    with db_connection.cursor() as cursor:
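        # Build a DataLoader over the test split using the runner's dataset
        # and batch sampler.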
        dataloader = DataLoader(dataset=runner._get_dataset(cursor,
                                                            is_test=True),
                                batch_sampler=runner._get_sampler(
                                    cursor, is_test=True),
                                collate_fn=collate)
        # Candidate-recall check: count how often the gold label appears in
        # the candidate set, tallying one sample per mention.
        acc = 0
        n = 0
        for batch in dataloader:
            for label, candidate_ids in zip(batch['label'],
                                            batch['candidate_ids']):
                if int(label) in candidate_ids.tolist():
                    acc += 1
                n += 1
            print(acc, n)
        print(acc, n)
Example no. 3
 def run_deep_el(self):
     try:
         db_connection = get_connection(self.paths.env)
         with db_connection.cursor() as cursor:
             self.load_caches(cursor)
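             # Embedding vector of the '<PAD>' token (presumably used by
             # JointModel to pad variable-length token sequences).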
             pad_vector = self.lookups.embedding(
                 torch.tensor([self.lookups.token_idx_lookup['<PAD>']],
                              device=self.lookups.embedding.weight.device)
             ).squeeze()
             self.init_entity_embeds_deep_el()
             entity_ids_by_freq = self._get_entity_ids_by_freq(cursor)
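             # For adaptive softmax, map entity ids (presumably ordered by
             # corpus frequency) to contiguous label indices.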
             if self.model_params.use_adaptive_softmax:
                 self.lookups = self.lookups.set(
                     'entity_labels',
                     _.from_pairs(
                         zip(entity_ids_by_freq,
                             range(len(entity_ids_by_freq)))))
             self.adaptive_logits = self._get_adaptive_calc_logits()
             self.encoder = JointModel(
                 self.model_params.embed_len,
                 self.model_params.context_embed_len,
                 self.model_params.word_embed_len,
                 self.model_params.local_encoder_lstm_size,
                 self.model_params.document_encoder_lstm_size,
                 self.model_params.num_lstm_layers,
                 self.train_params.dropout_drop_prob,
                 self.entity_embeds,
                 self.lookups.embedding,
                 pad_vector,
                 self.adaptive_logits,
                 self.model_params.use_deep_network,
                 self.model_params.use_lstm_local,
                 self.model_params.num_cnn_local_filters,
                 self.model_params.use_cnn_local,
                 self.model_params.ablation,
                 use_stacker=self.model_params.use_stacker,
                 desc_is_avg=self.model_params.desc_is_avg)
             if self.run_params.load_model:
                 path = self.experiment.model_name if self.run_params.load_path is None else self.run_params.load_path
                 self.encoder.load_state_dict(torch.load(path))
             if self.run_params.continue_training:
                 fields = [
                     'mention_context_error', 'document_context_error',
                     'loss'
                 ]
                 with self.experiment.train(fields):
                     self.log.status('Training')
                     trainer = self._get_trainer(cursor, self.encoder)
                     trainer.train()
                     torch.save(self.encoder.state_dict(),
                                './' + self.experiment.model_name)
             with self.experiment.test(['accuracy', 'TP', 'num_samples']):
                 self.log.status('Testing')
                 tester = self._get_tester(cursor, self.encoder)
                 tester.test()
     finally:
         db_connection.close()
Example no. 4
def main():
    p = get_cli_args(args)
    conll_path = 'custom.tsv' if p.run.use_custom else './AIDA-YAGO2-dataset.tsv'
    num_correct = 0
    missed_idxs = []
    guessed_when_missed = []
    db_connection = get_connection(p.run.env_path)
    model = load_model(p.model, p.train)
    with open('./tokens.pkl', 'rb') as fh:
        token_idx_lookup = pickle.load(fh)
    with open('./glove_token_idx_lookup.pkl', 'rb') as fh:
        full_token_idx_lookup = pickle.load(fh)
    with open('./val_test_indices.json', 'r') as fh:
        val_indices, test_indices = json.load(fh)
    model.eval()
    with torch.no_grad():
        with db_connection.cursor() as cursor:
            dataset = SimpleCoNLLDataset(cursor, token_idx_lookup,
                                         full_token_idx_lookup, conll_path,
                                         p.run.lookups_path, p.run.idf_path,
                                         p.train.train_size,
                                         p.run.txt_dataset_path)
            conll_test_set = DataLoader(
                dataset,
                batch_sampler=BatchSampler(
                    SubsetSequentialSampler(test_indices), p.run.batch_size,
                    False),
                collate_fn=collate_simple_mention_ranker)
            ctr = count()
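            # Infinite counter: zipping it below advances it once per mention,
            # so next(ctr) after the loop equals the number of mentions seen.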
            for batch in progressbar(conll_test_set):
                (candidate_ids, features), target_rankings = batch
                target = [ranking[0] for ranking in target_rankings]
                candidate_scores = model(features)
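                # candidate_scores is flat across all mentions in the batch;
                # slice out each mention's candidates and keep the argmax.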
                top_1 = []
                offset = 0
                for ids in candidate_ids:
                    ranking_size = len(ids)
                    top_1.append(ids[torch.argmax(
                        candidate_scores[offset:offset +
                                         ranking_size]).item()])
                    offset += ranking_size
                for guess, label, ids, idx in zip(top_1, target, candidate_ids,
                                                  ctr):
                    if guess == label:
                        num_correct += 1
                    else:
                        missed_idxs.append(idx)
                        guessed_when_missed.append(guess)
            print(num_correct / next(ctr))
            import ipdb
            ipdb.set_trace()
            with open('./missed_idxs', 'w') as fh:
                fh.write('\n'.join(
                    [str((idx, dataset[idx])) for idx in missed_idxs]))
            with open('./guessed_when_missed', 'w') as fh:
                fh.write('\n'.join([str(idx) for idx in guessed_when_missed]))
Example no. 5
 def run_sum_encoder(self):
     try:
         db_connection = get_connection(self.paths.env)
         with db_connection.cursor() as cursor:
             self.load_caches(cursor)
             pad_vector = self.lookups.embedding(
                 torch.tensor([self.lookups.token_idx_lookup['<PAD>']],
                              device=self.lookups.embedding.weight.device)
             ).squeeze()
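             # Convert token IDF scores from token strings to vocabulary
             # indices, dropping tokens outside the vocabulary.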
             with open('./wiki_idf.json') as fh:
                 token_idf = json.load(fh)
                 self.idf = {
                     self.lookups.token_idx_lookup[token]: score
                     for token, score in token_idf.items()
                     if token in self.lookups.token_idx_lookup
                 }
             self.init_entity_embeds_sum_encoder(cursor)
             if self.model_params.only_mention_encoder:
                 self.context_encoder = ContextEncoderModel(
                     self.lookups.embedding, use_cnts=True, idf=self.idf)
             else:
                 self.context_encoder = MentionEncoderModel(
                     self.lookups.embedding,
                     (1 - self.train_params.dropout_drop_prob),
                     use_cnts=True,
                     idf=self.idf)
             self.context_encoder.to(self.device)
             self.encoder = SimpleJointModel(self.entity_embeds,
                                             self.context_encoder)
             if self.run_params.load_model:
                 path = self.experiment.model_name if self.run_params.load_path is None else self.run_params.load_path
                 self.encoder.load_state_dict(torch.load(path))
             if self.run_params.continue_training:
                 fields = ['error', 'loss']
                 with self.experiment.train(fields):
                     self.log.status('Training')
                     trainer = self._get_trainer(cursor, self.encoder)
                     trainer.train()
                     torch.save(self.encoder.state_dict(),
                                './' + self.experiment.model_name)
             with self.experiment.test(['accuracy', 'TP', 'num_samples']):
                 self.log.status('Testing')
                 tester = self._get_tester(cursor, self.encoder)
                 tester.test()
     finally:
         db_connection.close()
Example no. 6
def main():
    p = get_cli_args(args)
    conll_path = 'custom.tsv' if p.run.use_custom else './AIDA-YAGO2-dataset.tsv'
    db_connection = get_connection(p.run.env_path)
    model = load_model(p.model, p.train)
    with open('./tokens.pkl', 'rb') as fh:
        token_idx_lookup = pickle.load(fh)
    with open('./glove_token_idx_lookup.pkl', 'rb') as fh:
        full_token_idx_lookup = pickle.load(fh)
    with open('./val_test_indices.json', 'r') as fh:
        val_indices, test_indices = json.load(fh)
    model.eval()
    with torch.no_grad():
        with db_connection.cursor() as cursor:
            doc_id_dataset = MentionCoNLLDataset(cursor,
                                                 './AIDA-YAGO2-dataset.tsv',
                                                 p.run.lookups_path,
                                                 p.train.train_size)
            dataset = SimpleCoNLLDataset(cursor, token_idx_lookup,
                                         full_token_idx_lookup, conll_path,
                                         p.run.lookups_path, p.run.idf_path,
                                         p.train.train_size,
                                         p.run.txt_dataset_path)
            # compats = load_npz('compats_wiki+conll_100000.npz')
            with open('./entity_to_row_id.pkl', 'rb') as fh:
                entity_id_to_row = pickle.load(fh)
            idf = get_idf(token_idx_lookup, p.run.idf_path)
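            # TF-IDF vectors for entity descriptions and their row-wise L2 norms
            # (presumably used for cosine similarity in compatibilities_from_ids).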
            desc_fs_sparse = csr_matrix(load_npz('./desc_fs.npz'))
            desc_vs = csr_matrix(sparse_to_tfidf_vs(idf, desc_fs_sparse))
            norm = np.sqrt((desc_vs.multiply(desc_vs)).sum(1))
            ctr = count()
            num_correct = 0
            num_in_val = 0
            num_correct_small = 0
            num_in_val_small = 0
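            # Group consecutive mentions by document id so that decoding can be
            # done one document at a time.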
            grouped = groupby(
                ((dataset[idx], doc_id_dataset.mention_doc_id[idx])
                 for idx in range(len(val_indices) + len(test_indices))),
                key=itemgetter(1))
            batches = [
                collate_simple_mention_ranker([data for data, doc_id in g])
                for doc_id, g in grouped
            ]
            val_indices = set(val_indices)
            for document_batch in progressbar(batches):
                (candidate_ids, features), target_rankings = document_batch
                target = [ranking[0] for ranking in target_rankings]
                candidate_scores = model(features)
                emissions = emissions_from_flat_scores(
                    [len(ids) for ids in candidate_ids], candidate_scores)
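                # Keep only the top few candidates per mention to keep the
                # pairwise compatibility computation small.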
                keep_top_n = 5
                top_emissions = []
                top_cands = []
                idxs_to_check = []
                for i, (emission, cand_ids, gold, idx) in enumerate(
                        zip(emissions, candidate_ids, target, ctr)):
                    # Record validation mentions whose gold entity is not the
                    # highest-scoring candidate before document-level decoding.
                    if (len(cand_ids) > 1
                            and cand_ids[np.argmax(emission)] != gold
                            and idx in val_indices):
                        idxs_to_check.append(i)

                    em, cand = zip(*nlargest(keep_top_n,
                                             zip(emission, cand_ids),
                                             key=itemgetter(0)))
                    top_emissions.append(np.array(em))
                    top_cands.append(cand)
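                # Pairwise compatibilities between the kept candidates, followed
                # by document-level decoding over emissions and compatibilities.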
                compatibilities = compatibilities_from_ids(
                    entity_id_to_row, desc_vs, norm, top_cands)
                top_1_idx = mp_shallow_tree_doc(top_emissions, compatibilities)
                # top_1_idx = [np.argmax(em) for em in top_emissions]
                top_1 = [
                    cand_ids[idx]
                    for cand_ids, idx in zip(top_cands, top_1_idx)
                ]
                for guess, label in zip(top_1, target):
                    num_in_val += 1
                    if guess == label: num_correct += 1
                for idx in idxs_to_check:
                    guess = top_1[idx]
                    label = target[idx]
                    num_in_val_small += 1
                    if guess == label: num_correct_small += 1
            print(num_correct / num_in_val)
            print(num_correct_small / num_in_val_small)