Пример #1
0
    def test_masked_class_logsoftmax_basic(self):
        """Unmasked output equals torch's LogSoftmax; masking a candidate changes it."""
        # Model scores laid out as batch x M x K.
        preds = torch.tensor([[[2.0, 2.0, 1.0], [3.0, 5.0, 4.0]]])
        # Reference values: plain log-softmax over the K axis, no masking.
        torch_log_preds = torch.nn.LogSoftmax(dim=2)(preds)

        # Case 1: only the sign of the ids matters (negative means masked);
        # every id here is non-negative, so nothing is masked.
        unmasked_ids = torch.tensor([[[1, 3, 4], [5, 3, 1]]])
        keep_all = torch.where(unmasked_ids < 0, torch.zeros_like(preds),
                               torch.ones_like(preds))
        unmasked_out = eval_utils.masked_class_logsoftmax(pred=preds,
                                                          mask=keep_all)
        assert torch.allclose(torch_log_preds, unmasked_out)

        # Case 2: a negative id masks one candidate, so the result must
        # diverge from the torch function, which knows nothing about masks.
        partly_masked_ids = torch.tensor([[[1, 3, 4], [5, 3, -1]]])
        partial_mask = torch.where(partly_masked_ids < 0,
                                   torch.zeros_like(preds),
                                   torch.ones_like(preds))
        masked_out = eval_utils.masked_class_logsoftmax(pred=preds,
                                                        mask=partial_mask)
        assert not torch.allclose(torch_log_preds, masked_out)

        # After exponentiating, the masked slot should carry ~zero probability.
        expected_probs = torch.tensor([[[0.422319, 0.422319, 0.155362],
                                        [0.119203, 0.880797, 0.0]]])
        assert torch.allclose(expected_probs, torch.exp(masked_out))
Пример #2
0
def disambig_loss(intermediate_output_dict, Y, active):
    """Sum the entity disambiguation NLL loss over every prediction head.

    Args:
        intermediate_output_dict: output dict from the Emmental task flow
        Y: gold labels
        active: whether examples are "active" or not (used in Emmental slicing)

    Returns: loss
    """
    # Under distributed training there is one "training" flag per process;
    # reading the first one is enough.
    training = intermediate_output_dict[PRED_LAYER]["training"][0].item()
    assert type(training) is bool
    head_scores = intermediate_output_dict[PRED_LAYER]["final_scores"][
        DISAMBIG]
    cand_mask = intermediate_output_dict["_input_"]["entity_cand_eid_mask"][
        active]
    gold = Y[active]

    # At eval time a gold QID may be NIC even when the model never predicts a
    # NIC candidate; such golds carry label -2 and are always incorrect.
    # NLLLoss only accepts classes 0..#classes-1 plus the pad index, so the -2
    # labels are remapped to the ignored index -1 for the loss computation
    # only. As this is just for eval, it does not affect training.
    if training:
        targets = gold
    else:
        targets = torch.where(gold == -2, -torch.ones_like(gold), gold)

    criterion = nn.NLLLoss(ignore_index=-1)
    total = 0
    for scores in head_scores.values():
        log_probs = eval_utils.masked_class_logsoftmax(pred=scores[active],
                                                       mask=~cand_mask)
        # batch x M x K -> batch x K x M, the class-second layout expected by
        # the "k-dimensional" NLLLoss.
        total += criterion(log_probs.transpose(1, 2), targets.long())

    return total
Пример #3
0
 def test_masked_class_logsoftmax_grads_excluded_alias(self):
     """Gradients vanish on masked candidates and on mentions labelled -1."""
     # Scores, batch x M x K.
     preds = torch.tensor([[[2.0, 4.0], [1.0, 4.0], [8.0, 2.0]]],
                          requires_grad=True)
     # Candidate ids, batch x M x K; negative ids are masked out.
     entity_ids = torch.tensor([[[1, -1], [4, 5], [8, 9]]])
     # Gold classes, batch x M; -1 is the NLLLoss ignore_index used below.
     true_entity_class = torch.tensor([[0, -1, 1]])

     keep = torch.ones_like(preds)
     drop = torch.zeros_like(preds)
     mask = torch.where(entity_ids < 0, drop, keep)

     log_probs = eval_utils.masked_class_logsoftmax(pred=preds, mask=mask)
     criterion = torch.nn.NLLLoss(ignore_index=-1)
     criterion(log_probs.transpose(1, 2), true_entity_class).backward()

     # Broadcast the per-mention label across that mention's K candidates.
     label_per_cand = true_entity_class.unsqueeze(-1).expand_as(entity_ids)
     has_real_grad = (entity_ids != -1) & (label_per_cand != -1)
     # Positions expected to carry a real gradient are overwritten with 1 so
     # that only the positions expected to be zero are actually compared.
     masked_actual_grad = torch.where(has_real_grad, torch.ones_like(preds),
                                      preds.grad)
     expected_grad = torch.tensor([[[1.0, 0.0], [0.0, 0.0], [1.0, 1.0]]])
     assert torch.allclose(expected_grad, masked_actual_grad)
Пример #4
0
 def disambig_loss(self, training, outs, true_label, mask):
     """Returns the entity disambiguation loss on prediction heads.

     Args:
         training: falsy during eval, in which case gold labels equal to -2
             are remapped to -1 for the loss computation only
         outs: mapping from loss head name to that head's scores
         true_label: mapping of gold label tensors; heads whose name contains
             FINAL_LOSS use ``true_label[FINAL_LOSS]``, every other head uses
             the entry stored under its own name. The tensors are never
             modified.
         mask: candidate mask; its inverse is handed to
             ``eval_utils.masked_class_logsoftmax``

     Returns: a LossPackage holding one loss entry per head
     """
     device = next(iter(outs.values()))[0].device
     loss = LossPackage(device)
     for loss_head, out in outs.items():
         label_key = FINAL_LOSS if FINAL_LOSS in loss_head else loss_head
         loss_labels = true_label[label_key]
         if not training:
             # During eval, even if our model does not predict a NIC candidate, we allow for a NIC gold QID.
             # This qid gets assigned the label of -2 and is always incorrect in eval_wrapper.
             # As NLLLoss assumes classes of 0 to #classes-1 except for pad idx, the -2 labels are
             # replaced by -1 for the loss computation only. masked_fill returns a new tensor, so the
             # caller's labels keep their -2 entries even if the loss computation raises.
             loss_labels = loss_labels.masked_fill(loss_labels == -2, -1)
         # batch x M x K -> transpose -> swap K classes with M spans for "k-dimensional" NLLloss
         log_probs = eval_utils.masked_class_logsoftmax(pred=out, mask=~mask).transpose(1, 2)
         loss.add_loss(loss_head, self.crit_pred(log_probs, loss_labels.long().to(device)))
     return loss
Пример #5
0
    def test_masked_class_logsoftmax_grads(self):
        """Masked candidates get zero gradient; unmasked gradients match CrossEntropyLoss."""
        # Gradients w.r.t. preds are what flow back into the rest of the
        # network, so those are the ones checked here.
        preds = torch.tensor(
            [[[2.0, 4.0, 1.0], [3.0, 5.0, 4.0], [1.0, 4.0, 6.0]]],
            requires_grad=True)
        true_entity_class = torch.tensor([[1, 0, 2]])
        nll = torch.nn.NLLLoss(ignore_index=-1)

        # ---- Part 1: some candidates are masked (negative ids) ----
        masked_ids = torch.tensor([[[1, 3, -1], [5, -1, -1], [4, 5, 6]]])
        mask = torch.where(masked_ids < 0, torch.zeros_like(preds),
                           torch.ones_like(preds))
        log_probs = eval_utils.masked_class_logsoftmax(pred=preds, mask=mask)
        nll(log_probs.transpose(1, 2), true_entity_class).backward()

        # Masked candidates must receive zero gradient. Positions with a
        # positive id are overwritten with 1 so that only the masked
        # positions are really compared against the expectation.
        grad_on_masked = torch.where(masked_ids > 0, torch.ones_like(preds),
                                     preds.grad)
        expected_grad = torch.tensor([[[1.0, 1.0, 0.0], [1.0, 0.0, 0.0],
                                       [1.0, 1.0, 1.0]]])
        assert torch.allclose(expected_grad, grad_on_masked)

        # Reset the accumulated gradient before calling backward again.
        preds.grad.zero_()

        # ---- Part 2: nothing masked, result must match plain PyTorch ----
        unmasked_ids = torch.tensor([[[1, 3, 1], [5, 4, 8], [4, 5, 6]]])
        mask = torch.where(unmasked_ids < 0, torch.zeros_like(preds),
                           torch.ones_like(preds))
        log_probs = eval_utils.masked_class_logsoftmax(pred=preds, mask=mask)
        nll(log_probs.transpose(1, 2), true_entity_class).backward()
        # Clone before zeroing: preds.grad is reused by the next backward.
        actual_grad = preds.grad.clone()
        preds.grad.zero_()

        reference_loss = torch.nn.CrossEntropyLoss()(preds.transpose(1, 2),
                                                     true_entity_class)
        reference_loss.backward()
        assert torch.allclose(preds.grad, actual_grad)
Пример #6
0
def disambig_output(intermediate_output_dict):
    """Convert the final disambiguation scores into probabilities for Emmental.

    Args:
        intermediate_output_dict: output dict from Emmental task flow

    Returns: NED probabilities for candidates (B x M x K)
    """
    disambig_scores = intermediate_output_dict[PRED_LAYER]["final_scores"][
        DISAMBIG]
    final_scores = disambig_scores[FINAL_LOSS]
    cand_mask = intermediate_output_dict["_input_"]["entity_cand_eid_mask"]
    log_probs = eval_utils.masked_class_logsoftmax(pred=final_scores,
                                                   mask=~cand_mask)
    # Exponentiate the log-probabilities to get probabilities.
    return log_probs.exp()
Пример #7
0
 def test_masked_class_logsoftmax_masking(self):
     """Unmasked entries equal a log-softmax taken over the unmasked scores only."""
     preds = torch.tensor([[[2., 4., 1.], [3., 5., 4.]]])
     # Negative ids mark masked candidates.
     entity_ids = torch.tensor([[[1, 3, -1], [5, -1, -1]]])
     mask = torch.where(entity_ids < 0, torch.zeros_like(preds), torch.ones_like(preds))

     # Hand-computed log-softmax over the two unmasked scores of mention 0.
     unmasked_scores = torch.tensor([[2., 4.]])
     log_norm = torch.log(torch.sum(torch.exp(unmasked_scores)))
     # Only unmasked values need to match: the output is multiplied by the
     # mask below, which zeroes every masked slot. Mention 1 keeps a single
     # candidate, expected at log-probability 0.
     expected_log_probs = torch.tensor([[[
         unmasked_scores[0][0]-log_norm, unmasked_scores[0][1]-log_norm, 0],
         [0, 0, 0]]])

     raw_log_probs = eval_utils.masked_class_logsoftmax(pred=preds, mask=mask)
     assert torch.allclose(expected_log_probs, raw_log_probs * mask)
Пример #8
0
 def test_masked_class_logsoftmax_with_loss(self):
     """With nothing masked, the scorer loss on masked log-softmax equals CrossEntropyLoss."""
     # Model outputs, shape batch x M x K.
     preds = torch.tensor([[[2., 2., 1.], [3., 5., 4.]]])
     true_entity_class = torch.tensor([[0,1]])
     # Only the sign of the ids matters (negative means masked); all of
     # these are non-negative, so every candidate is kept.
     entity_ids = torch.tensor([[[1, 3, 4], [5, 3, 1]]])
     mask = torch.where(entity_ids < 0, torch.zeros_like(preds), torch.ones_like(preds))

     log_probs = eval_utils.masked_class_logsoftmax(pred=preds, mask=mask)
     pred_loss = self.trainer.scorer.crit_pred(log_probs.transpose(1,2), true_entity_class)

     # CrossEntropyLoss wants the class axis second: batch_size x K x M.
     reference_scores = preds.transpose(1,2)
     torch_loss = torch.nn.CrossEntropyLoss()(reference_scores, true_entity_class)
     assert torch.allclose(torch_loss, pred_loss)
Пример #9
0
 def test_masked_class_logsoftmax_grads_full_mask(self):
     """A mention whose candidates are all masked, and masked candidates in general, get zero gradient."""
     # Scores, batch x M x K.
     preds = torch.tensor([[[2., 4.], [3., 5.], [1., 4.]]], requires_grad=True)
     # Candidate ids, batch x M x K; the middle mention is fully masked.
     entity_ids = torch.tensor([[[1, -1], [-1, -1], [4, 5]]])
     # Gold classes, batch x M.
     true_entity_class = torch.tensor([[0, -1, 1]])

     mask = torch.where(entity_ids < 0, torch.zeros_like(preds), torch.ones_like(preds))
     log_probs = eval_utils.masked_class_logsoftmax(pred=preds, mask=mask)
     loss_value = self.trainer.scorer.crit_pred(log_probs.transpose(1,2), true_entity_class)
     loss_value.backward()

     # Broadcast each mention's label across its K candidates.
     label_per_cand = true_entity_class.unsqueeze(-1).expand_as(entity_ids)
     has_real_grad = (entity_ids != -1) & (label_per_cand != -1)
     # Positions expected to be non-zero are overwritten with 1 so that only
     # the positions expected to be zero are actually compared.
     masked_actual_grad = torch.where(has_real_grad, torch.ones_like(preds), preds.grad)
     expected_grad = torch.tensor([[[1., 0.], [0., 0.], [1., 1.]]])
     assert torch.allclose(expected_grad, masked_actual_grad)
Пример #10
0
    def label_mentions(self, text):
        """Extract the mentions in ``text`` and disambiguate each one with the model.

        Args:
            text: input handed to ``self.extract_mentions``

        Returns:
            ``(pred_cands, pred_probs, titles)``: three parallel lists with one
            entry per mention whose best candidate probability is above
            ``self.threshold``. Falls through (returning None) when the
            sentence splitter yields no sub-sentence at all.

        Raises:
            ValueError: if the sentence splitter returns more than one
                sub-sentence (overflow is not supported here).
        """
        sample = self.extract_mentions(text)
        max_aliases = self.args.data_config.max_aliases
        idxs_arr, _, spans_arr, phrase_tokens_arr = sentence_utils.split_sentence(
            max_aliases=max_aliases,
            phrase=sample['sentence'],
            spans=sample['spans'],
            aliases=sample['aliases'],
            aliases_seen_by_model=list(range(len(sample['aliases']))),
            seq_len=self.args.data_config.max_word_token_len,
            word_symbols=self.word_db)
        aliases_arr = [[sample['aliases'][idx] for idx in idxs]
                       for idxs in idxs_arr]
        word_indices_arr = [
            self.word_db.convert_tokens_to_ids(pt) for pt in phrase_tokens_arr
        ]

        if len(idxs_arr) > 1:
            # TODO: support sentences that overflow due to long sequence length or too many mentions
            raise ValueError(
                'Overflowing sentences not currently supported in Annotator')

        # iterate over each sample in the split; because of the check above
        # this loop body runs at most once and returns from inside it
        for sub_idx, aliases in enumerate(aliases_arr):
            # Builtin int replaces the np.int alias, which NumPy removed in
            # release 1.24 (np.int was only ever an alias of the builtin).
            example_aliases = np.ones(max_aliases, dtype=int) * PAD_ID
            example_aliases_locs_start = np.ones(max_aliases) * PAD_ID
            example_aliases_locs_end = np.ones(max_aliases) * PAD_ID

            for mention_idx, alias in enumerate(aliases):
                # index of the alias in the alias table
                example_aliases[mention_idx] = self.entity_db.get_alias_idx(
                    alias)

                # alias_idx_pair: start and end of the mention span
                span_start_idx, span_end_idx = spans_arr[sub_idx][mention_idx]
                example_aliases_locs_start[mention_idx] = span_start_idx
                example_aliases_locs_end[mention_idx] = span_end_idx

            # get word indices
            word_indices = word_indices_arr[sub_idx]

            # entity indices from alias table (these are the candidates)
            entity_indices = self.alias_table(example_aliases)

            # all CPU embs have to be retrieved on the fly
            batch_on_the_fly_data = {}
            for emb_name, emb in self.batch_on_the_fly_embs.items():
                batch_on_the_fly_data[emb_name] = torch.tensor(
                    emb.batch_prep(example_aliases, entity_indices),
                    device=self.device)

            # unsqueeze(0) / the extra list level add a leading batch axis of 1
            outs, entity_pack, _ = self.model(
                alias_idx_pair_sent=[
                    torch.tensor(example_aliases_locs_start,
                                 device=self.device).unsqueeze(0),
                    torch.tensor(example_aliases_locs_end,
                                 device=self.device).unsqueeze(0)
                ],
                word_indices=torch.tensor([word_indices], device=self.device),
                alias_indices=torch.tensor(example_aliases,
                                           device=self.device).unsqueeze(0),
                entity_indices=torch.tensor(entity_indices,
                                            device=self.device).unsqueeze(0),
                batch_prepped_data={},
                batch_on_the_fly_data=batch_on_the_fly_data)

            entity_cands = eval_utils.map_aliases_to_candidates(
                self.args.data_config.train_in_candidates, self.entity_db,
                aliases)
            # recover predictions: probabilities over candidates (dim 2)
            probs = torch.exp(
                eval_utils.masked_class_logsoftmax(
                    pred=outs[DISAMBIG][FINAL_LOSS],
                    mask=~entity_pack.mask,
                    dim=2))
            max_probs, max_probs_indices = probs.max(2)

            pred_cands = []
            pred_probs = []
            titles = []
            for alias_idx in range(len(aliases)):
                pred_idx = max_probs_indices[0][alias_idx]
                pred_prob = max_probs[0][alias_idx].item()
                pred_qid = entity_cands[alias_idx][pred_idx]
                # keep only predictions the model is confident enough about
                if pred_prob > self.threshold:
                    pred_cands.append(pred_qid)
                    pred_probs.append(pred_prob)
                    titles.append(
                        self.entity_db.
                        get_title(pred_qid) if pred_qid != 'NC' else 'NC')

            return pred_cands, pred_probs, titles
Пример #11
0
    def forward(self, context_matrix_dict, alias_idx_pair_sent, entity_pack,
                sent_emb):
        """Run every slice head and combine the heads into a final prediction.

        For each head in ``self.train_heads`` this computes an expert
        representation, a prediction from the shared prediction head and an
        indicator output. The expert representations are then combined with
        softmax attention weights over the heads and, unless
        ``self.remove_final_loss`` is set, passed through
        ``self.final_pred_head``.

        Args:
            context_matrix_dict: dict of context matrices; the key
                "context_matrix_main" is added to it here when missing
            alias_idx_pair_sent: forwarded unchanged to the alias attention modules
            entity_pack: object whose ``mask`` attribute is used as the
                entity mask and, inverted, as the log-softmax mask
            sent_emb: sentence embedding forwarded to the alias attention modules

        Returns:
            ``(out, reweighted_rep)``: ``out`` holds the per-head and final
            scores under DISAMBIG and the per-head indicator outputs under
            INDICATOR; ``reweighted_rep`` is the attention-weighted sum of the
            expert representations, with the same shape as the context matrix.
        """
        out = {DISAMBIG: {}, INDICATOR: {}}
        indicator_outputs = {}
        expert_slice_repr = {}
        predictor_outputs = {}

        # Build the main context matrix lazily; note this mutates the dict
        # passed in by the caller.
        if "context_matrix_main" not in context_matrix_dict:
            context_matrix_dict[
                "context_matrix_main"] = model_utils.generate_final_context_matrix(
                    context_matrix_dict)

        context_matrix = context_matrix_dict["context_matrix_main"]

        batch_size, M, K, H = context_matrix.shape
        assert M == self.M
        assert K == self.K
        assert H == self.hidden_size
        for i, slice_head in enumerate(self.train_heads):
            # Generate slice expert representation per head
            # context_matrix is batch x M x K x H
            expert_slice_repr[slice_head] = self.transform_modules[slice_head](
                context_matrix)
            # Pass the expert slice representation through the shared prediction layer
            # Predictor_outputs is batch x M x K
            predictor_outputs[slice_head] = self.shared_slice_pred_head(
                expert_slice_repr[slice_head]).squeeze(-1)
            # Get an alias_matrix output (batch x M x H)
            # TODO: remove extra inputs
            # The head index i selects this head's slice embeddings.
            alias_matrix, alias_word_weights = self.ind_alias_mha[slice_head](
                sent_embedding=sent_emb,
                entity_embedding=context_matrix,
                entity_mask=entity_pack.mask,
                alias_idx_pair_sent=alias_idx_pair_sent,
                slice_emb_alias=self.slice_emb_ind_alias(
                    torch.tensor(i, device=context_matrix.device)),
                slice_emb_ent=self.slice_emb_ind_ent(
                    torch.tensor(i, device=context_matrix.device)))
            # Get indicator head outputs; size batch x M x 2 per head
            indicator_outputs[slice_head] = self.indicator_heads[slice_head](
                alias_matrix)

        # Generate predictions via log-softmax + taking the "positive" class label
        # (index 1 of the last axis)
        # Output size is batch x M x num_slices
        indicator_preds = torch.cat(
            [
                F.log_softmax(indicator_outputs[slice_head],
                              dim=-1)[:, :, 1].unsqueeze(-1)
                for slice_head in self.train_heads
            ],
            dim=-1,
        )
        assert not torch.isnan(indicator_preds).any()
        assert not torch.isinf(indicator_preds).any()
        # Compute the "confidence": masked log-softmax of each head's scores
        # Output size should be batch x M x K x num_slices
        predictor_confidences = torch.cat(
            [
                eval_utils.masked_class_logsoftmax(
                    pred=predictor_outputs[slice_head],
                    mask=~entity_pack.mask).unsqueeze(-1)
                for slice_head in self.train_heads
            ],
            dim=-1,
        )
        assert not torch.isnan(predictor_confidences).any()
        assert not torch.isinf(predictor_confidences).any()
        if self.use_ind_attn:
            # NOTE(review): prod is the same tensor object as indicator_preds,
            # so the assignment below edits indicator_preds in place. It holds
            # log_softmax outputs (non-positive values), so `prod < 0.1` looks
            # true for every element -- confirm this thresholding is intended.
            prod = indicator_preds  # * margin_confidence
            prod[prod < 0.1] = -100000.0 * self.temperature
            attention_weights = F.softmax(prod / self.temperature, dim=-1)
        else:
            # Take margin confidence over K values to generate confidences of batch x M x num_slices
            # vals holds the two largest log-probabilities along the K axis
            vals, indices = torch.topk(predictor_confidences, k=2, dim=2)
            margin_confidence = (vals[:, :, 0, :] -
                                 vals[:, :, 1, :]) / vals.sum(2)
            assert list(margin_confidence.shape) == [
                batch_size, self.M, len(self.train_heads)
            ]
            attention_weights = F.softmax(
                (indicator_preds + margin_confidence) / self.temperature,
                dim=-1)

        assert not torch.isnan(attention_weights).any()
        assert not torch.isinf(attention_weights).any()

        # attention_weights is batch x M x num_slices
        # slice_representations is batch_size x M x K x num_slices x feat_dim
        slice_representations = torch.stack(
            [expert_slice_repr[slice_head] for slice_head in self.train_heads],
            dim=3)

        # attention_weights becomes batch_size x M x K x num_slices x H of slice_representations
        attention_weights = attention_weights.unsqueeze(2).unsqueeze(
            -1).expand_as(slice_representations)
        # Reweight representations with weighted sum across slices
        reweighted_rep = torch.sum(attention_weights * slice_representations,
                                   dim=3)
        assert reweighted_rep.shape == context_matrix.shape
        # Expose the raw per-head scores and indicator outputs to the caller
        for slice_head in self.train_heads:
            out[DISAMBIG][train_utils.get_slice_head_pred_name(
                slice_head)] = predictor_outputs[slice_head]
        for slice_head in self.train_heads:
            out[INDICATOR][train_utils.get_slice_head_ind_name(
                slice_head)] = indicator_outputs[slice_head]
        # Used for debugging: reuse the BASE_SLICE head's scores as the final
        # scores instead of running the final prediction layer
        if self.remove_final_loss:
            out[DISAMBIG][FINAL_LOSS] = out[DISAMBIG][
                train_utils.get_slice_head_pred_name(BASE_SLICE)]
        else:
            # Pass through the final prediction layer
            out[DISAMBIG][FINAL_LOSS] = self.final_pred_head(
                reweighted_rep).squeeze(-1)
        return out, reweighted_rep
Пример #12
0
    def forward(self, alias_idx_pair_sent, sent_embedding, entity_embedding,
                batch_prepped, batch_on_the_fly_data):
        """Contextualize the entity candidates over ``self.num_model_stages`` stages.

        Each stage runs entity-to-word attention and entity self-attention,
        merges the two with that stage's combine module, and computes a
        KG-biased copy of the merged matrix. Every stage except the last also
        produces an intermediate prediction that is stored in ``out`` and
        added back into the representation fed to the next stage.

        Args:
            alias_idx_pair_sent: not referenced in this method body
            sent_embedding: object with ``tensor`` and ``mask`` attributes
                for the sentence
            entity_embedding: object with ``tensor`` and ``mask`` attributes
                for the entity candidates
            batch_prepped: not referenced in this method body
            batch_on_the_fly_data: dict; the entry under REL_INDICES_KEY,
                when present, supplies the KG connectivity bias

        Returns:
            ``(context_mat_dict, out)``: the dict holds the final context
            matrices with and without the KG bias, each reshaped to
            batch x M x K x hidden_size; ``out[DISAMBIG]`` holds the
            intermediate stage predictions.

        Side effects: attention weights are stored in
        ``self.attention_weights`` and ``self.e2e_entity_mask`` is moved to
        the device of the entity padding mask.
        """
        batch_size = sent_embedding.tensor.shape[0]
        out = {DISAMBIG: {}}

        # Prepare inputs for attention modules

        # Get the KG metadata for the KG module - this is 0/1 if pair is connected
        if REL_INDICES_KEY in batch_on_the_fly_data:
            kg_bias = batch_on_the_fly_data[REL_INDICES_KEY].float().to(
                sent_embedding.tensor.device).reshape(batch_size,
                                                      self.M * self.K,
                                                      self.M * self.K)
        else:
            kg_bias = torch.zeros(batch_size, self.M * self.K, self.M *
                                  self.K).to(sent_embedding.tensor.device)
        # Add self.kg_weight on the diagonal (each candidate paired with itself)
        kg_bias_diag = kg_bias + self.kg_weight * torch.eye(
            self.M * self.K).repeat(batch_size, 1, 1).view(
                batch_size, self.M * self.K, self.M * self.K).to(
                    kg_bias.device)
        # Zero entries are filled with -1e9 before self.softmax; presumably so
        # they get ~zero weight -- softmax axis is set elsewhere, TODO confirm
        kg_bias_norm = self.softmax(
            kg_bias_diag.masked_fill((kg_bias_diag == 0), float(-1e9)))

        sent_tensor = sent_embedding.tensor.transpose(0, 1)

        # Resize the alias embeddings and the entity mask from B x M x K x D to B x (M*K) x D
        entity_mask = entity_embedding.mask
        entity_embedding = entity_embedding.tensor.contiguous().view(
            batch_size, self.M * self.K, self.hidden_size)
        entity_embedding = entity_embedding.transpose(
            0, 1)  # reshape for attention
        key_padding_mask_entities = entity_mask.contiguous().view(
            batch_size, self.M * self.K)

        # Iterate through stages
        # NOTE(review): context_matrix_nokg and context_matrix_main are bound
        # only in the last stage's else-branch, so the code after the loop
        # needs self.num_model_stages >= 1 -- confirm this is guaranteed.
        query_tensor = entity_embedding
        for stage_index in range(self.num_model_stages):
            # As we are adding a residual in the attention modules, we can make embs empty
            embs = []
            #============================================================================
            # Phrase module: compute attention between entities and words
            #============================================================================
            word_entity_attn_context, word_entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_entity_word"](
                    q=query_tensor,
                    x=sent_tensor,
                    key_mask=sent_embedding.mask,
                    attn_mask=None)
            # Add embeddings to be merged in the output
            embs.append(word_entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_entity_word"] = word_entity_attn_weights

            #============================================================================
            # Co-occurrence module: compute self attention over entities
            #============================================================================
            # Move entity mask over
            self.e2e_entity_mask = self.e2e_entity_mask.to(
                key_padding_mask_entities.device)

            entity_attn_context, entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_self_entity"](
                    x=query_tensor,
                    key_mask=key_padding_mask_entities,
                    attn_mask=self.e2e_entity_mask)

            # Mask out MxK of single aliases, alias_indices is batch x M, mask is true when single alias
            # An alias counts as non-null when not all K of its mask entries
            # are set; the post mask is true for examples with exactly one
            # such alias, and their self-attention context is zeroed below.
            non_null_aliases = (self.K - key_padding_mask_entities.reshape(
                batch_size, self.M, self.K).sum(-1)) != 0
            entity_attn_post_mask = (
                non_null_aliases.sum(1) == 1).unsqueeze(1).expand(
                    batch_size, self.K * self.M).transpose(0, 1)
            entity_attn_post_mask = entity_attn_post_mask.unsqueeze(
                -1).expand_as(entity_attn_context)
            entity_attn_context = torch.where(
                entity_attn_post_mask, torch.zeros_like(entity_attn_context),
                entity_attn_context)

            # Add embeddings to be merged in the output
            embs.append(entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_self_entity"] = entity_attn_weights

            context_matrix = self.combine_modules[
                f"stage_{stage_index}_combine"](embs)

            #============================================================================
            # KG module: add in KG connectivity bias
            #============================================================================
            # bmm needs the batch axis first, hence the transposes around it
            context_matrix_kg = torch.bmm(kg_bias_norm,
                                          context_matrix.transpose(
                                              0, 1)).transpose(0, 1)

            if stage_index < self.num_model_stages - 1:
                # Intermediate stage: score the candidates with and without
                # the KG-averaged context using the same stage head
                pred = self.predict_layers[DISAMBIG][
                    train_utils.get_stage_head_name(stage_index)](
                        context_matrix)
                pred = pred.transpose(0, 1).squeeze(2).reshape(
                    batch_size, self.M, self.K)
                pred_kg = self.predict_layers[DISAMBIG][
                    train_utils.get_stage_head_name(stage_index)](
                        (context_matrix + context_matrix_kg) / 2)
                pred_kg = pred_kg.transpose(0, 1).squeeze(2).reshape(
                    batch_size, self.M, self.K)
                # The stage output is the element-wise max of the two scores
                out[DISAMBIG][
                    f"{train_utils.get_stage_head_name(stage_index)}"] = torch.max(
                        torch.cat([pred.unsqueeze(3),
                                   pred_kg.unsqueeze(3)],
                                  dim=-1),
                        dim=-1)[0]
                # The feedback below uses the non-KG score (pred), not the max
                pred_logit = eval_utils.masked_class_logsoftmax(
                    pred=pred, mask=~entity_mask)
                context_norm = model_utils.normalize_matrix(
                    (context_matrix + context_matrix_kg) / 2, dim=2)
                assert not torch.isnan(context_norm).any()
                assert not torch.isinf(context_norm).any()
                # Add predictions so model can learn from previous predictions
                context_matrix = context_norm + pred_logit.contiguous().view(
                    batch_size, self.M * self.K, 1).transpose(0, 1)
            else:
                # Final stage: keep both the plain and the KG-averaged matrix
                context_matrix_nokg = context_matrix
                context_matrix = (context_matrix + context_matrix_kg) / 2
                context_matrix_main = context_matrix

            query_tensor = context_matrix

        # Back to batch-first layout: batch x M x K x hidden_size
        context_matrix_nokg = context_matrix_nokg.transpose(0, 1).reshape(
            batch_size, self.M, self.K, self.hidden_size)
        context_matrix_main = context_matrix_main.transpose(0, 1).reshape(
            batch_size, self.M, self.K, self.hidden_size)

        # context_mat_dict is the contextualized entity embeddings, out is the predictions
        context_mat_dict = {
            "context_matrix_nokg": context_matrix_nokg,
            MAIN_CONTEXT_MATRIX: context_matrix_main
        }
        return context_mat_dict, out
Пример #13
0
        def label_mentions(self, text_list):
        if type(text_list) is str:
            text_list = [text_list]
        else:
            assert type(text_list) is list and len(text_list) > 0 and type(
                text_list[0]) is str, f"We only accept inputs of strings and lists of strings"

        ebs = self.args.run_config.eval_batch_size
        total_start_exs = 0
        total_final_exs = 0
        dropped_by_thresh = 0

        final_char_spans = []

        batch_example_aliases = []
        batch_example_aliases_locs_start = []
        batch_example_aliases_locs_end = []
        batch_example_alias_list_pos = []
        batch_example_true_entities = []
        batch_word_indices = []
        batch_spans_arr = []
        batch_aliases_arr = []
        batch_idx_unq = []
        batch_subsplit_idx = []
        for idx_unq, text in tqdm(enumerate(text_list), desc="Prepping data", total=len(text_list)):
            sample = self.extract_mentions(text)
            total_start_exs += len(sample['aliases'])
            char_spans = self.get_char_spans(sample['spans'], text)

            final_char_spans.append(char_spans)

            idxs_arr, aliases_to_predict_per_split, spans_arr, phrase_tokens_arr, pos_idxs = sentence_utils.split_sentence(
                max_aliases=self.args.data_config.max_aliases,
                phrase=sample['sentence'],
                spans=sample['spans'],
                aliases=sample['aliases'],
                aliases_seen_by_model=[i for i in range(len(sample['aliases']))],
                seq_len=self.args.data_config.max_word_token_len,
                word_symbols=self.word_db)
            aliases_arr = [[sample['aliases'][idx] for idx in idxs] for idxs in idxs_arr]
            old_spans_arr = [[sample['spans'][idx] for idx in idxs] for idxs in idxs_arr]
            qids_arr = [[sample['qids'][idx] for idx in idxs] for idxs in idxs_arr]
            word_indices_arr = [self.word_db.convert_tokens_to_ids(pt) for pt in phrase_tokens_arr]
            # iterate over each sample in the split

            for sub_idx in range(len(idxs_arr)):
                # ====================================================
                # GENERATE MODEL INPUTS
                # ====================================================
                aliases_to_predict_arr = aliases_to_predict_per_split[sub_idx]

                assert len(aliases_to_predict_arr) >= 0, f'There are no aliases to predict for an example. This should not happen at this point.'
                assert len(aliases_arr[
                               sub_idx]) <= self.args.data_config.max_aliases, f'Each example should have no more that {self.args.data_config.max_aliases} max aliases. {sample} does.'

                example_aliases = np.ones(self.args.data_config.max_aliases) * PAD_ID
                example_aliases_locs_start = np.ones(self.args.data_config.max_aliases) * PAD_ID
                example_aliases_locs_end = np.ones(self.args.data_config.max_aliases) * PAD_ID
                example_alias_list_pos = np.ones(self.args.data_config.max_aliases) * PAD_ID
                example_true_entities = np.ones(self.args.data_config.max_aliases) * PAD_ID

                for mention_idx, alias in enumerate(aliases_arr[sub_idx]):
                    span_start_idx, span_end_idx = spans_arr[sub_idx][mention_idx]
                    # generate indexes into alias table.
                    alias_trie_idx = self.entity_db.get_alias_idx(alias)
                    alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                    if not qids_arr[sub_idx][mention_idx] in alias_qids:
                        # assert not data_args.train_in_candidates
                        if not self.args.data_config.train_in_candidates:
                            # set class label to be "not in candidate set"
                            true_entity_idx = 0
                        else:
                            true_entity_idx = -2
                    else:
                        # Here we are getting the correct class label for training.
                        # Our training is "which of the max_entities entity candidates is the right one (class labels 1 to max_entities) or is it none of these (class label 0)".
                        # + (not train_in_candidates) is to ensure label 0 is reserved for "not in candidate set" class
                        true_entity_idx = np.nonzero(alias_qids == qids_arr[sub_idx][mention_idx])[0][0] + (
                            not self.args.data_config.train_in_candidates)
                    example_aliases[mention_idx] = alias_trie_idx
                    example_aliases_locs_start[mention_idx] = span_start_idx
                    # The span_idxs are [start, end). We want [start, end]. So subtract 1 from end idx.
                    example_aliases_locs_end[mention_idx] = span_end_idx - 1
                    example_alias_list_pos[mention_idx] = idxs_arr[sub_idx][mention_idx]
                    # leave as -1 if it's not an alias we want to predict; we get these if we split a sentence and need to only predict subsets
                    if mention_idx in aliases_to_predict_arr:
                        example_true_entities[mention_idx] = true_entity_idx

                # get word indices
                word_indices = word_indices_arr[sub_idx]

                batch_example_aliases.append(example_aliases)
                batch_example_aliases_locs_start.append(example_aliases_locs_start)
                batch_example_aliases_locs_end.append(example_aliases_locs_end)
                batch_example_alias_list_pos.append(example_alias_list_pos)
                batch_example_true_entities.append(example_true_entities)
                batch_word_indices.append(word_indices)
                batch_aliases_arr.append(aliases_arr[sub_idx])
                # Add the original sample spans because spans_arr is w.r.t. BERT subword tokens
                batch_spans_arr.append(old_spans_arr[sub_idx])
                batch_idx_unq.append(idx_unq)
                batch_subsplit_idx.append(sub_idx)

        batch_example_aliases = torch.tensor(batch_example_aliases).long()
        batch_example_aliases_locs_start = torch.tensor(batch_example_aliases_locs_start, device=self.device)
        batch_example_aliases_locs_end = torch.tensor(batch_example_aliases_locs_end, device=self.device)
        batch_example_true_entities = torch.tensor(batch_example_true_entities, device=self.device)
        batch_word_indices = torch.tensor(batch_word_indices, device=self.device)

        final_pred_cands = [[] for _ in range(len(text_list))]
        final_all_cands = [[] for _ in range(len(text_list))]
        final_cand_probs = [[] for _ in range(len(text_list))]
        final_pred_probs = [[] for _ in range(len(text_list))]
        final_titles = [[] for _ in range(len(text_list))]
        final_spans = [[] for _ in range(len(text_list))]
        final_aliases = [[] for _ in range(len(text_list))]
        for b_i in tqdm(range(0, batch_example_aliases.shape[0], ebs), desc="Evaluating model"):
            # entity indices from alias table (these are the candidates)
            batch_entity_indices = self.alias_table(batch_example_aliases[b_i:b_i + ebs])

            # all CPU embs have to be retrieved on the fly
            batch_on_the_fly_data = {}
            for emb_name, emb in self.batch_on_the_fly_embs.items():
                batch_prep = []
                for j in range(b_i, min(b_i + ebs, batch_example_aliases.shape[0])):
                    batch_prep.append(emb.batch_prep(batch_example_aliases[j], batch_entity_indices[j - b_i]))
                batch_on_the_fly_data[emb_name] = torch.tensor(batch_prep, device=self.device)

            alias_idx_pair_sent = [batch_example_aliases_locs_start[b_i:b_i + ebs], batch_example_aliases_locs_end[b_i:b_i + ebs]]
            word_indices = batch_word_indices[b_i:b_i + ebs]
            alias_indices = batch_example_aliases[b_i:b_i + ebs]
            entity_indices = torch.tensor(batch_entity_indices, device=self.device)

            outs, entity_pack, _ = self.model(
                alias_idx_pair_sent=alias_idx_pair_sent,
                word_indices=word_indices,
                alias_indices=alias_indices,
                entity_indices=entity_indices,
                batch_prepped_data={},
                batch_on_the_fly_data=batch_on_the_fly_data)

            # ====================================================
            # EVALUATE MODEL OUTPUTS
            # ====================================================

            final_loss_vals = outs[DISAMBIG][FINAL_LOSS]
            # recover predictions
            probs = torch.exp(eval_utils.masked_class_logsoftmax(pred=final_loss_vals,
                                                                 mask=~entity_pack.mask, dim=2))
            max_probs, max_probs_indices = probs.max(2)
            for ex_i in range(final_loss_vals.shape[0]):
                idx_unq = batch_idx_unq[b_i + ex_i]
                subsplit_idx = batch_subsplit_idx[b_i + ex_i]
                entity_cands = eval_utils.map_aliases_to_candidates(self.args.data_config.train_in_candidates,
                                                                    self.entity_db,
                                                                    batch_aliases_arr[b_i + ex_i])

                # probs[ex_i] is a single example, so we can reshape it to (max_aliases, candidate dimension)
                probs_ex = probs[ex_i].detach().cpu().numpy().reshape(self.args.data_config.max_aliases, probs.shape[2])
                for alias_idx, true_entity_pos_idx in enumerate(batch_example_true_entities[b_i + ex_i]):
                    if true_entity_pos_idx != PAD_ID:
                        pred_idx = max_probs_indices[ex_i][alias_idx]
                        pred_prob = max_probs[ex_i][alias_idx].item()
                        all_cands = entity_cands[alias_idx]
                        pred_qid = all_cands[pred_idx]
                        if pred_prob > self.threshold:
                            final_all_cands[idx_unq].append(all_cands)
                            final_cand_probs[idx_unq].append(probs_ex[alias_idx])
                            final_pred_cands[idx_unq].append(pred_qid)
                            final_pred_probs[idx_unq].append(pred_prob)
                            final_aliases[idx_unq].append(batch_aliases_arr[b_i + ex_i][alias_idx])
                            final_spans[idx_unq].append(batch_spans_arr[b_i + ex_i][alias_idx])
                            final_titles[idx_unq].append(self.entity_db.get_title(pred_qid) if pred_qid != 'NC' else 'NC')
                            total_final_exs += 1
                        else:
                            dropped_by_thresh += 1
        assert total_final_exs + dropped_by_thresh == total_start_exs, f"Something went wrong and we have predicted fewer mentions than extracted. Start {total_start_exs}, Out {total_final_exs}, No cand {dropped_by_thresh}"
        return final_pred_cands, final_pred_probs, final_titles, final_all_cands, final_cand_probs, final_spans, final_aliases