Example #1
 def run_wiki2vec(self):
     # Acquire the connection before the try block so the finally clause can always close it.
     db_connection = get_connection()
     try:
         with db_connection.cursor() as cursor:
             self.load_caches(cursor)
             self.wiki2vec = load_wiki2vec()
             self.init_entity_embeds_wiki2vec()
             self.context_encoder = ContextEncoder(
                 self.wiki2vec, self.lookups.token_idx_lookup, self.device)
             self.encoder = SimpleJointModel(self.context_encoder)
             if not self.run_params.load_model:
                 with self.experiment.train(['error', 'loss']):
                     self.log.status('Training')
                     trainer = self._get_trainer(cursor, self.encoder)
                     trainer.train()
                     torch.save(self.encoder.state_dict(),
                                './' + self.experiment.model_name)
             else:
                 path = self.experiment.model_name if self.run_params.load_path is None else self.run_params.load_path
                 self.encoder.load_state_dict(torch.load(path))
                 self.encoder = nn.DataParallel(self.encoder)
                 self.encoder = self.encoder.to(self.device).module
             with self.experiment.test(['accuracy', 'TP', 'num_samples']):
                 self.log.status('Testing')
                 tester = self._get_tester(cursor, self.context_encoder)
                 tester.test()
     finally:
         db_connection.close()
Example #2
 def run_sum_encoder(self):
     db_connection = get_connection(self.paths.env)
     try:
         with db_connection.cursor() as cursor:
             self.load_caches(cursor)
             pad_vector = self.lookups.embedding(
                 torch.tensor([self.lookups.token_idx_lookup['<PAD>']],
                              device=self.lookups.embedding.weight.device)
             ).squeeze()
             with open('./wiki_idf.json') as fh:
                 token_idf = json.load(fh)
                 self.idf = {
                     self.lookups.token_idx_lookup[token]: score
                     for token, score in token_idf.items()
                     if token in self.lookups.token_idx_lookup
                 }
             self.init_entity_embeds_sum_encoder(cursor)
             if self.model_params.only_mention_encoder:
                 self.context_encoder = ContextEncoderModel(
                     self.lookups.embedding, use_cnts=True, idf=self.idf)
             else:
                 self.context_encoder = MentionEncoderModel(
                     self.lookups.embedding,
                     (1 - self.train_params.dropout_drop_prob),
                     use_cnts=True,
                     idf=self.idf)
             self.context_encoder.to(self.device)
             self.encoder = SimpleJointModel(self.entity_embeds,
                                             self.context_encoder)
             if self.run_params.load_model:
                 path = self.experiment.model_name if self.run_params.load_path is None else self.run_params.load_path
                 self.encoder.load_state_dict(torch.load(path))
             if self.run_params.continue_training:
                 fields = ['error', 'loss']
                 with self.experiment.train(fields):
                     self.log.status('Training')
                     trainer = self._get_trainer(cursor, self.encoder)
                     trainer.train()
                     torch.save(self.encoder.state_dict(),
                                './' + self.experiment.model_name)
             with self.experiment.test(['accuracy', 'TP', 'num_samples']):
                 self.log.status('Testing')
                 tester = self._get_tester(cursor, self.encoder)
                 tester.test()
     finally:
         db_connection.close()
Example #3
 def run_deep_el(self):
     db_connection = get_connection(self.paths.env)
     try:
         with db_connection.cursor() as cursor:
             self.load_caches(cursor)
             pad_vector = self.lookups.embedding(
                 torch.tensor([self.lookups.token_idx_lookup['<PAD>']],
                              device=self.lookups.embedding.weight.device)
             ).squeeze()
             self.init_entity_embeds_deep_el()
             entity_ids_by_freq = self._get_entity_ids_by_freq(cursor)
             if self.model_params.use_adaptive_softmax:
                 self.lookups = self.lookups.set(
                     'entity_labels',
                     _.from_pairs(
                         zip(entity_ids_by_freq,
                             range(len(entity_ids_by_freq)))))
             self.adaptive_logits = self._get_adaptive_calc_logits()
             self.encoder = JointModel(
                 self.model_params.embed_len,
                 self.model_params.context_embed_len,
                 self.model_params.word_embed_len,
                 self.model_params.local_encoder_lstm_size,
                 self.model_params.document_encoder_lstm_size,
                 self.model_params.num_lstm_layers,
                 self.train_params.dropout_drop_prob,
                 self.entity_embeds,
                 self.lookups.embedding,
                 pad_vector,
                 self.adaptive_logits,
                 self.model_params.use_deep_network,
                 self.model_params.use_lstm_local,
                 self.model_params.num_cnn_local_filters,
                 self.model_params.use_cnn_local,
                 self.model_params.ablation,
                 use_stacker=self.model_params.use_stacker,
                 desc_is_avg=self.model_params.desc_is_avg)
             if self.run_params.load_model:
                 path = self.experiment.model_name if self.run_params.load_path is None else self.run_params.load_path
                 self.encoder.load_state_dict(torch.load(path))
             if self.run_params.continue_training:
                 fields = [
                     'mention_context_error', 'document_context_error',
                     'loss'
                 ]
                 with self.experiment.train(fields):
                     self.log.status('Training')
                     trainer = self._get_trainer(cursor, self.encoder)
                     trainer.train()
                     torch.save(self.encoder.state_dict(),
                                './' + self.experiment.model_name)
             with self.experiment.test(['accuracy', 'TP', 'num_samples']):
                 self.log.status('Testing')
                 tester = self._get_tester(cursor, self.encoder)
                 tester.test()
     finally:
         db_connection.close()
Example #4
class Runner(object):
    def __init__(self,
                 device,
                 paths=default_paths,
                 train_params=default_train_params,
                 model_params=default_model_params,
                 run_params=default_run_params):
        self.train_params = m().update(default_train_params).update(
            train_params)
        self.model_params = m().update(default_model_params).update(
            model_params)
        self.run_params = m().update(default_run_params).update(run_params)
        self.paths = m().update(default_paths).update(paths)
        self.experiment = Experiment(
            self.train_params.update(self.run_params).update(
                self.model_params))
        self.log = self.experiment.log
        self.lookups = m()
        self.device = device
        self.model_params = self.model_params.set('context_embed_len',
                                                  self.model_params.embed_len)
        if not hasattr(self.model_params, 'adaptive_softmax_cutoffs'):
            self.model_params = self.model_params.set(
                'adaptive_softmax_cutoffs',
                [1000, 10000, 100000, 300000, 500000])
        self.paths = self.paths.set('word_embedding',
                                    self._get_word_embedding_path())
        self.page_id_order: Optional[list] = None
        self.num_train_pages: Optional[int] = None
        self.page_id_order_train: Optional[list] = None
        self.page_id_order_test: Optional[list] = None
        self.entity_embeds: Optional[nn.Embedding] = None
        self.adaptive_logits = {'desc': None, 'mention': None}
        self.encoder = None
        self.use_conll = self.run_params.use_conll
        self.context_encoder = None
        self.wiki2vec = None

    def _get_word_embedding_path(self):
        dim = self.model_params.word_embed_len
        if self.model_params.word_embedding_set.lower() == 'glove':
            return f'./glove.6B.{dim}d.txt'
        else:
            raise NotImplementedError(
                'Only loading from glove is currently supported')

    def _get_entity_ids_by_freq(self, cursor):
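        # Entity ids sorted by how often they are mentioned (most frequent first); used to
        # assign labels when the adaptive softmax orders classes by frequency.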
        query = 'select entity_id, count(*) from entity_mentions group by `entity_id` order by count(*) desc'
        cursor.execute(query)
        sorted_rows = cursor.fetchall()
        return [row['entity_id'] for row in sorted_rows]

    def load_caches(self, cursor):
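        # Loads the candidate/label lookups, optionally re-labels entities that meet the
        # min_mentions threshold, builds the pretrained word-embedding table, and splits
        # the page id order into train and test portions.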
        self.log.status('Loading entity candidate_ids lookup')
        lookups = load_entity_candidate_ids_and_label_lookup(
            self.paths.lookups, self.train_params.train_size)
        if not hasattr(self.model_params, 'num_entities'):
            self.log.status('Getting number of entities')
            if self.train_params.min_mentions > 1:
                query = 'select id from entities where num_mentions >= ' + str(
                    self.train_params.min_mentions)
                cursor.execute(query)
                valid_old_entity_labels_to_ids = {
                    lookups['entity_labels'][row['id']]: row['id']
                    for row in cursor.fetchall()
                }
                new_labels = {}
                new_prior = defaultdict(dict)
                for entity, counts in lookups['entity_candidates_prior'].items():
                    for old_label in counts.keys():
                        if old_label in valid_old_entity_labels_to_ids:
                            entity_id = valid_old_entity_labels_to_ids[
                                old_label]
                            new_label = new_labels.get(entity_id,
                                                       len(new_labels))
                            new_labels[entity_id] = new_label
                            new_prior[entity][new_label] = counts[old_label]
                lookups = {
                    'entity_candidates_prior': new_prior,
                    'entity_labels': new_labels
                }
            self.model_params = self.model_params.set(
                'num_entities', len(lookups['entity_labels']))
        self.log.status('Loading word embedding lookup')
        embedding_dict = get_embedding_dict(
            self.paths.word_embedding,
            embedding_dim=self.model_params.word_embed_len)
        token_idx_lookup = dict(
            zip(embedding_dict.keys(), range(len(embedding_dict))))
        embedding = nn.Embedding.from_pretrained(
            torch.stack([embedding_dict[token]
                         for token in token_idx_lookup]).to(self.device),
            freeze=self.model_params.freeze_word_embeddings)
        self.lookups = self.lookups.update({
            'entity_candidates_prior': lookups['entity_candidates_prior'],
            'entity_labels': lookups['entity_labels'],
            'embedding': embedding,
            'token_idx_lookup': token_idx_lookup
        })
        self.log.status('Getting page id order')
        self.page_id_order = load_page_id_order(self.paths.page_id_order)
        self.num_train_pages = int(
            len(self.page_id_order) * self.train_params.train_size)
        self.page_id_order_train = self.page_id_order[:self.num_train_pages]
        self.page_id_order_test = self.page_id_order[self.num_train_pages:]

    def _get_entity_tokens(self, num_entities):
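        # Tokenizes each entity's description text into token ids (unknown tokens map to
        # '<UNK>') and pads the per-entity id lists into one batch tensor; entities
        # without text get a one-token placeholder.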
        mapper = lambda token: self.lookups.token_idx_lookup.get(
            token, self.lookups.token_idx_lookup['<UNK>'])
        entity_indexed_tokens = {
            self.lookups.entity_labels[entity_id]:
            _.map_(parse_for_tokens(text), mapper)
            for entity_id, text in get_entity_text().items()
            if entity_id in self.lookups.entity_labels
        }
        entity_indexed_tokens_list = [
            entity_indexed_tokens[i] if i in entity_indexed_tokens else [1]
            for i in range(num_entities)
        ]
        return torch.tensor(pad_batch_list(0, entity_indexed_tokens_list),
                            device=self.device)

    def _get_entity_wikivecs(self, num_entities):
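        # Looks up a wiki2vec vector for each entity label (in label order), falling back
        # to a random vector for labels that have no entry.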
        vecs_by_label = {
            self.lookups.entity_labels[entity_id]:
            self.wiki2vec.get_entity_vector(text)
            for entity_id, text in get_entity_text().items()
            if entity_id in self.lookups.entity_labels
        }
        vecs_in_order = [
            vecs_by_label[i]
            if i in vecs_by_label else torch.randn(self.model_params.embed_len)
            for i in range(num_entities)
        ]
        return torch.tensor(vecs_in_order, device=self.device)

    def _sum_in_batches(self, by_token):
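        # Sums each entity's word vectors, processing the token-id matrix in chunks so the
        # embedding lookup stays within memory.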
        results = []
        for chunk in torch.chunk(by_token, 100):
            entity_word_vecs = self.lookups.embedding(chunk)
            results.append(entity_word_vecs.sum(1))
        return torch.cat(results, 0)

    def init_entity_embeds_deep_el(self):
        if self.model_params.word_embed_len == self.model_params.embed_len:
            entities_by_token = self._get_entity_tokens(
                self.model_params.num_entities)
            entity_embed_weights = nn.Parameter(
                self._sum_in_batches(entities_by_token))
        else:
            print(
                f'word embed len: {self.model_params.word_embed_len} != entity embed len {self.model_params.embed_len}. Not initializing entity embeds with word embeddings'
            )
            entity_embed_weights = nn.Parameter(
                torch.Tensor(self.model_params.num_entities,
                             self.model_params.embed_len))
            entity_embed_weights.data.normal_(
                0, 1.0 / math.sqrt(self.model_params.embed_len))
        self.entity_embeds = nn.Embedding(self.model_params.num_entities,
                                          self.model_params.embed_len,
                                          _weight=entity_embed_weights).to(
                                              self.device)

    def init_entity_embeds_wiki2vec(self):
        entity_wikivecs = self._get_entity_wikivecs(
            self.model_params.num_entities)
        entity_embed_weights = nn.Parameter(entity_wikivecs)
        self.entity_embeds = nn.Embedding(self.model_params.num_entities,
                                          self.model_params.embed_len,
                                          _weight=entity_embed_weights).to(
                                              self.device)

    def _get_dataset(self, cursor, is_test, use_fast_sampler=False):
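        # Returns the CoNLL dataset when use_conll is set; otherwise builds a
        # MentionContextDataset over the train or test page id split.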
        page_ids = self.page_id_order_test if is_test else self.page_id_order_train
        if self.use_conll:
            return CoNLLDataset(cursor, self.lookups.entity_candidates_prior,
                                self.lookups.embedding,
                                self.lookups.token_idx_lookup,
                                self.model_params.num_entities,
                                self.model_params.num_candidates,
                                self.lookups.entity_labels)
        else:
            return MentionContextDataset(
                cursor,
                page_ids,
                self.lookups.entity_candidates_prior,
                self.lookups.entity_labels,
                self.lookups.embedding,
                self.lookups.token_idx_lookup,
                self.train_params.batch_size,
                self.model_params.num_entities,
                self.model_params.num_candidates,
                cheat=self.run_params.cheat,
                buffer_scale=self.run_params.buffer_scale,
                min_mentions=self.train_params.min_mentions,
                use_fast_sampler=use_fast_sampler,
                use_wiki2vec=self.model_params.use_wiki2vec,
                start_from_page_num=self.train_params.start_from_page_num)

    def _get_sampler(self,
                     cursor,
                     is_test,
                     limit=None,
                     use_fast_sampler=False):
        if self.use_conll:
            return BatchSampler(RandomSampler(self._dataset),
                                self.train_params.batch_size, False)
        else:
            page_ids = self.page_id_order_test if is_test else self.page_id_order_train
            return MentionContextBatchSampler(
                cursor,
                page_ids,
                self.train_params.batch_size,
                self.train_params.min_mentions,
                limit=limit,
                use_fast_sampler=use_fast_sampler)

    def _calc_logits(self, encoded, candidate_entity_ids):
        if self.model_params.use_wiki2vec:
            return Logits()(encoded, self.entity_embeds(candidate_entity_ids))
        else:
            desc_embeds, mention_context_embeds = encoded
            if self.model_params.use_adaptive_softmax:
                raise NotImplementedError('No longer supported')
            elif self.model_params.use_ranking_loss:
                raise NotImplementedError('No longer supported')
            else:
                logits = Logits()
                desc_logits = logits(desc_embeds,
                                     self.entity_embeds(candidate_entity_ids))
                mention_logits = logits(
                    mention_context_embeds,
                    self.entity_embeds(candidate_entity_ids))
            return desc_logits, mention_logits

    def _calc_loss(self, scores, labels_for_batch):
        desc_score, mention_score = scores
        if self.model_params.use_adaptive_softmax:
            raise NotImplementedError('No longer supported')
        elif self.model_params.use_ranking_loss:
            raise NotImplementedError('No longer supported')
        else:
            criterion = nn.CrossEntropyLoss()
            desc_loss = criterion(desc_score, labels_for_batch)
            mention_loss = criterion(mention_score, labels_for_batch)
        return desc_loss + mention_loss

    def _get_trainer(self, cursor, model):
        self._trainer = Trainer(
            device=self.device,
            embedding=self.lookups.embedding,
            token_idx_lookup=self.lookups.token_idx_lookup,
            model=model,
            get_dataset=lambda: self._get_dataset(
                cursor,
                is_test=False,
                use_fast_sampler=self.train_params.use_fast_sampler),
            get_batch_sampler=lambda: self._get_sampler(
                cursor,
                is_test=False,
                limit=self.train_params.dataset_limit,
                use_fast_sampler=self.train_params.use_fast_sampler),
            num_epochs=self.train_params.num_epochs,
            experiment=self.experiment,
            calc_loss=self._calc_loss,
            calc_logits=self._calc_logits,
            logits_and_softmax=self._get_logits_and_softmax(),
            adaptive_logits=self.adaptive_logits,
            use_adaptive_softmax=self.model_params.use_adaptive_softmax,
            clip_grad=self.train_params.clip_grad,
            use_wiki2vec=self.model_params.use_wiki2vec)
        return self._trainer

    def _get_logits_and_softmax(self):
        def get_calc(context):
            if self.model_params.use_adaptive_softmax:
                softmax = self.adaptive_logits[context].log_prob
                calc = lambda hidden, _: softmax(hidden)
            else:
                calc_logits = Logits()
                softmax = Softmax()
                calc = lambda hidden, candidate_entity_ids: softmax(
                    calc_logits(hidden, self.entity_embeds(candidate_entity_ids)))
            return calc

        return {context: get_calc(context) for context in ['desc', 'mention']}

    def _get_tester(self, cursor, model):
        logits_and_softmax = self._get_logits_and_softmax()
        test_dataset = self._get_dataset(cursor, is_test=True)
        self._dataset = test_dataset
        batch_sampler = self._get_sampler(cursor, is_test=True)
        return Tester(
            dataset=test_dataset,
            batch_sampler=batch_sampler,
            model=model,
            logits_and_softmax=logits_and_softmax,
            embedding=self.lookups.embedding,
            token_idx_lookup=self.lookups.token_idx_lookup,
            device=self.device,
            experiment=self.experiment,
            ablation=self.model_params.ablation,
            use_adaptive_softmax=self.model_params.use_adaptive_softmax,
            use_wiki2vec=self.model_params.use_wiki2vec)

    def _get_adaptive_calc_logits(self):
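        # Builds a single AdaptiveLogSoftmaxWithLoss over the entity embeddings and shares
        # it between the 'desc' and 'mention' contexts.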
        def get_calc(context):
            if self.model_params.use_hardcoded_cutoffs:
                vocab_size = self.entity_embeds.weight.shape[0]
                cutoffs = self.model_params.adaptive_softmax_cutoffs
            else:
                raise NotImplementedError
            in_features = self.entity_embeds.weight.shape[1]
            n_classes = self.entity_embeds.weight.shape[0]
            return AdaptiveLogSoftmaxWithLoss(in_features,
                                              n_classes,
                                              cutoffs,
                                              div_value=1.0).to(self.device)

        calc = get_calc('desc_and_mention')
        return {context: calc for context in ['desc', 'mention']}

    def run_deep_el(self):
        db_connection = get_connection()
        try:
            with db_connection.cursor() as cursor:
                self.load_caches(cursor)
                pad_vector = self.lookups.embedding(
                    torch.tensor([self.lookups.token_idx_lookup['<PAD>']],
                                 device=self.lookups.embedding.weight.device)
                ).squeeze()
                self.init_entity_embeds_deep_el()
                entity_ids_by_freq = self._get_entity_ids_by_freq(cursor)
                if self.model_params.use_adaptive_softmax:
                    self.lookups = self.lookups.set(
                        'entity_labels',
                        _.from_pairs(
                            zip(entity_ids_by_freq,
                                range(len(entity_ids_by_freq)))))
                self.adaptive_logits = self._get_adaptive_calc_logits()
                self.encoder = JointModel(
                    self.model_params.embed_len,
                    self.model_params.context_embed_len,
                    self.model_params.word_embed_len,
                    self.model_params.local_encoder_lstm_size,
                    self.model_params.document_encoder_lstm_size,
                    self.model_params.num_lstm_layers,
                    self.train_params.dropout_drop_prob, self.entity_embeds,
                    self.lookups.embedding, pad_vector, self.adaptive_logits,
                    self.model_params.use_deep_network,
                    self.model_params.use_lstm_local,
                    self.model_params.num_cnn_local_filters,
                    self.model_params.use_cnn_local)
                if self.run_params.load_model:
                    path = self.experiment.model_name if self.run_params.load_path is None else self.run_params.load_path
                    self.encoder.load_state_dict(torch.load(path))
                    self.encoder = nn.DataParallel(self.encoder)
                    self.encoder = self.encoder.to(self.device).module
                if self.run_params.continue_training:
                    if self.model_params.use_wiki2vec:
                        fields = ['context_error', 'loss']
                    else:
                        fields = [
                            'mention_context_error', 'document_context_error',
                            'loss'
                        ]
                    with self.experiment.train(fields):
                        self.log.status('Training')
                        trainer = self._get_trainer(cursor, self.encoder)
                        trainer.train()
                        torch.save(self.encoder.state_dict(),
                                   './' + self.experiment.model_name)
                with self.experiment.test(['accuracy', 'TP', 'num_samples']):
                    self.log.status('Testing')
                    tester = self._get_tester(cursor, self.encoder)
                    tester.test()
        finally:
            db_connection.close()

    def run_wiki2vec(self):
        db_connection = get_connection()
        try:
            with db_connection.cursor() as cursor:
                self.load_caches(cursor)
                self.wiki2vec = load_wiki2vec()
                self.init_entity_embeds_wiki2vec()
                self.context_encoder = ContextEncoder(
                    self.wiki2vec, self.lookups.token_idx_lookup, self.device)
                self.encoder = SimpleJointModel(self.context_encoder)
                if not self.run_params.load_model:
                    with self.experiment.train(['error', 'loss']):
                        self.log.status('Training')
                        trainer = self._get_trainer(cursor, self.encoder)
                        trainer.train()
                        torch.save(self.encoder.state_dict(),
                                   './' + self.experiment.model_name)
                else:
                    path = self.experiment.model_name if self.run_params.load_path is None else self.run_params.load_path
                    self.encoder.load_state_dict(torch.load(path))
                    self.encoder = nn.DataParallel(self.encoder)
                    self.encoder = self.encoder.to(self.device).module
                with self.experiment.test(['accuracy', 'TP', 'num_samples']):
                    self.log.status('Testing')
                    tester = self._get_tester(cursor, self.context_encoder)
                    tester.test()
        finally:
            db_connection.close()

    def run(self):
        if self.model_params.use_wiki2vec:
            self.run_wiki2vec()
        else:
            self.run_deep_el()
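
A minimal usage sketch for the Runner above. The device choice and the run_params overrides shown here are illustrative assumptions; any settings not passed explicitly come from the default_* parameter maps.

import torch

# Hypothetical invocation: pick a device and let the defaults supply the rest.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
runner = Runner(device,
                run_params={'load_model': False, 'continue_training': True})
runner.run()  # dispatches to run_wiki2vec() or run_deep_el() via model_params.use_wiki2vec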