Example #1
class KglmDisc(Model):
    """
    Knowledge graph language model discriminator (for importance sampling).

    Parameters
    ----------
    vocab : ``Vocabulary``
        The model vocabulary.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 token_embedder: TextFieldEmbedder,
                 entity_embedder: TextFieldEmbedder,
                 relation_embedder: TextFieldEmbedder,
                 knowledge_graph_path: str,
                 use_shortlist: bool,
                 hidden_size: int,
                 num_layers: int,
                 cutoff: int = 30,
                 tie_weights: bool = False,
                 dropout: float = 0.4,
                 dropouth: float = 0.3,
                 dropouti: float = 0.65,
                 dropoute: float = 0.1,
                 wdrop: float = 0.5,
                 alpha: float = 2.0,
                 beta: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
        super(KglmDisc, self).__init__(vocab)

        # We extract the `Embedding` layers from the `TokenEmbedders` to apply dropout later on.
        # pylint: disable=protected-access
        self._token_embedder = token_embedder._token_embedders['tokens']
        self._entity_embedder = entity_embedder._token_embedders['entity_ids']
        self._relation_embedder = relation_embedder._token_embedders['relations']
        self._recent_entities = RecentEntities(cutoff=cutoff)
        self._knowledge_graph_lookup = KnowledgeGraphLookup(knowledge_graph_path, vocab=vocab)
        self._use_shortlist = use_shortlist
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._cutoff = cutoff
        self._tie_weights = tie_weights

        # Dropout
        self._locked_dropout = LockedDropout()
        self._dropout = dropout
        self._dropouth = dropouth
        self._dropouti = dropouti
        self._dropoute = dropoute
        self._wdrop = wdrop

        # Regularization strength
        self._alpha = alpha
        self._beta = beta

        # RNN Encoders.
        entity_embedding_dim = entity_embedder.get_output_dim()
        token_embedding_dim = token_embedder.get_output_dim()
        self.entity_embedding_dim = entity_embedding_dim
        self.token_embedding_dim = token_embedding_dim

        rnns: List[torch.nn.Module] = []
        for i in range(num_layers):
            if i == 0:
                input_size = token_embedding_dim
            else:
                input_size = hidden_size
            if i == num_layers - 1:
                output_size = token_embedding_dim + 2 * entity_embedding_dim
            else:
                output_size = hidden_size
            rnns.append(torch.nn.LSTM(input_size, output_size, batch_first=True))
        rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in rnns]
        self.rnns = torch.nn.ModuleList(rnns)

        # Various linear transformations.
        self._fc_mention_type = torch.nn.Linear(
            in_features=token_embedding_dim,
            out_features=4)

        if not use_shortlist:
            self._fc_new_entity = torch.nn.Linear(
                in_features=entity_embedding_dim,
                out_features=vocab.get_vocab_size('entity_ids'))

            if tie_weights:
                self._fc_new_entity.weight = self._entity_embedder.weight

        self._state: Optional[Dict[str, Any]] = None

        # Metrics
        self._unk_index = vocab.get_token_index(DEFAULT_OOV_TOKEN)
        self._unk_penalty = math.log(vocab.get_vocab_size('tokens_unk'))
        self._avg_mention_type_loss = Average()
        self._avg_new_entity_loss = Average()
        self._avg_knowledge_graph_entity_loss = Average()
        self._new_mention_f1 = F1Measure(positive_label=1)
        self._kg_mention_f1 = F1Measure(positive_label=2)
        self._new_entity_accuracy = CategoricalAccuracy()
        self._new_entity_accuracy20 = CategoricalAccuracy(top_k=20)
        self._parent_ppl = Ppl()
        self._relation_ppl = Ppl()

        initializer(self)

    def sample(self,
               source: Dict[str, torch.Tensor],
               target: Dict[str, torch.Tensor],
               reset: torch.Tensor,
               metadata: Dict[str, Any],
               alias_copy_inds: torch.Tensor,
               shortlist: Dict[str, torch.Tensor] = None,
               **kwargs) -> Dict[str, Any]:  # **kwargs intended to eat the other fields if they are provided.
        """
        Samples annotations for the generative model. Note that unlike ``forward``, this function
        expects inputs from a **generative** dataset reader, not a **discriminative** one.

        Returns a dict with a ``'sample'`` entry containing the sampled annotation fields and a
        ``'logp'`` entry containing the log-probability of the sample under the proposal.
        """

        # Tensorize the alias_database - this will only perform the operation once.
        alias_database = metadata[0]['alias_database']
        alias_database.tensorize(vocab=self.vocab)

        # Reset the model if needed
        if reset.any() and (self._state is not None):
            for layer in range(self._num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)
        self._recent_entities.reset(reset)

        logp = 0.0

        mask = get_text_field_mask(target).byte()
        # We encode the target tokens (**not** source) since the discriminative model makes
        # predictions on the current token, but the generative model expects labels for the
        # **next** (i.e. target) token!
        encoded, *_ = self._encode_source(target['tokens'])
        splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2
        encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1)

        # Compute new mention logits
        mention_logits = self._fc_mention_type(encoded_token)
        mention_probs = F.softmax(mention_logits, dim=-1)
        mention_type = parallel_sample(mention_probs)
        mention_logp = mention_probs.gather(-1, mention_type.unsqueeze(-1)).log()
        mention_logp[~mask] = 0
        mention_logp = mention_logp.sum()

        # Compute entity logits
        new_entity_mask = mention_type.eq(1)
        new_entity_logits = self._new_entity_logits(encoded_head + encoded_relation, shortlist)
        if self._use_shortlist:
            # If using shortlist, then samples are indexed w.r.t the shortlist and entity_ids must be looked up
            shortlist_mask = get_text_field_mask(shortlist)
            new_entity_probs = masked_softmax(new_entity_logits, shortlist_mask)
            shortlist_inds = torch.zeros_like(mention_type)
            # Some sequences may be full of padding in which case the shortlist
            # is empty
            not_just_padding = shortlist_mask.byte().any(-1)
            shortlist_inds[not_just_padding] = parallel_sample(new_entity_probs[not_just_padding])
            shortlist_inds[~new_entity_mask] = 0
            _new_entity_logp = new_entity_probs.gather(-1, shortlist_inds.unsqueeze(-1)).log()
            new_entity_samples = shortlist['entity_ids'].gather(1, shortlist_inds)
        else:
            # If not using shortlist, then samples are indexed w.r.t. the global vocab
            new_entity_probs = F.softmax(new_entity_logits, dim=-1)
            new_entity_samples = parallel_sample(new_entity_probs)
            _new_entity_logp = new_entity_probs.gather(-1, new_entity_samples.unsqueeze(-1)).log()
            shortlist_inds = None
        # Zero out masked tokens and non-new entity predictions
        _new_entity_logp[~mask] = 0
        _new_entity_logp[~new_entity_mask] = 0
        new_entity_logp = _new_entity_logp.sum()

        # Start filling in the entity ids
        entity_ids = torch.zeros_like(target['tokens'])
        entity_ids[new_entity_mask] = new_entity_samples[new_entity_mask]

        # We also need the raw entity ids, so remap the sampled ids.
        raw_entity_ids = torch.zeros_like(target['tokens'])
        for *index, entity_id in nested_enumerate(entity_ids.tolist()):
            token = self.vocab.get_token_from_index(entity_id, 'entity_ids')
            raw_entity_id = self.vocab.get_token_index(token, 'raw_entity_ids')
            raw_entity_ids[tuple(index)] = raw_entity_id

        # Derived mentions need to be computed sequentially.
        parent_ids = torch.zeros_like(target['tokens']).unsqueeze(-1)
        derived_entity_mask = mention_type.eq(2)
        derived_entity_logp = 0.0

        sequence_length = target['tokens'].shape[1]

        for i in range(sequence_length):

            current_mask = derived_entity_mask[:, i] & mask[:, i]

            # -------------------  SAMPLE PARENTS ---------------------

            # Update recent entities with **current** entity only
            current_entity_id = entity_ids[:, i].unsqueeze(1)
            candidate_ids, candidate_mask = self._recent_entities(current_entity_id)

            # If no mentions are derived, there is no point continuing after entities have been updated.
            if not current_mask.any():
                continue

            # Otherwise we proceed
            candidate_embeddings = self._entity_embedder(candidate_ids)

            # Compute logits w.r.t **current** hidden state only
            current_head_encoding = encoded_head[:, i].unsqueeze(1)
            selection_logits = torch.bmm(current_head_encoding, candidate_embeddings.transpose(1, 2))
            selection_probs = masked_softmax(selection_logits, candidate_mask)

            # Only sample if there is at least one viable candidate (e.g. if a sampling distribution
            # has no probability mass we cannot sample from it). Return zero as the parent for
            # non-viable distributions.
            viable_candidate_mask = candidate_mask.any(-1).squeeze()
            _parent_ids = torch.zeros_like(current_entity_id)
            parent_logp = torch.zeros_like(current_entity_id, dtype=torch.float32)
            if viable_candidate_mask.any():
                viable_candidate_ids = candidate_ids[viable_candidate_mask]
                viable_candidate_probs = selection_probs[viable_candidate_mask]
                viable_parent_samples = parallel_sample(viable_candidate_probs)
                viable_logp = viable_candidate_probs.gather(-1, viable_parent_samples.unsqueeze(-1)).log()
                viable_parent_ids = viable_candidate_ids.gather(-1, viable_parent_samples)
                _parent_ids[viable_candidate_mask] = viable_parent_ids
                parent_logp[viable_candidate_mask] = viable_logp.squeeze(-1)

            parent_ids[current_mask, i] = _parent_ids[current_mask]  # TODO: Double-check
            derived_entity_logp += parent_logp[current_mask].sum()

            # ---------------------- SAMPLE RELATION -----------------------------

            # Lookup sampled parent ids in the knowledge graph
            indices, parent_ids_list, relations_list, tail_ids_list = self._knowledge_graph_lookup(_parent_ids)
            relation_embeddings = [self._relation_embedder(r) for r in relations_list]

            # Sample tail ids
            current_relation_encoding = encoded_relation[:, i].unsqueeze(1)
            _raw_tail_ids = torch.zeros_like(_parent_ids).squeeze(-1)
            _tail_ids = torch.zeros_like(_parent_ids).squeeze(-1)
            for index, relation_embedding, tail_id_lookup in zip(indices, relation_embeddings, tail_ids_list):
                # Compute the score for each relation w.r.t the current encoding. NOTE: In the loss
                # code index has a slice. We don't need that here since there is always a
                # **single** parent.
                logits = torch.mv(relation_embedding, current_relation_encoding[index])
                # Convert to probability
                tail_probs = F.softmax(logits, dim=-1)
                # Sample
                tail_sample = torch.multinomial(tail_probs, 1)
                # Get logp. Ignoring the current_mask here is **super** dodgy, but since we forced
                # null parents to zero we shouldn't be accumulating probabilities for unused predictions.
                tail_logp = tail_probs.gather(-1, tail_sample).log()
                derived_entity_logp += tail_logp.sum()  # Sum is redundant, just need it to make logp a scalar

                # Map back to raw id
                raw_tail_id = tail_id_lookup[tail_sample]
                # Convert raw id to id
                tail_id_string = self.vocab.get_token_from_index(raw_tail_id.item(), 'raw_entity_ids')
                tail_id = self.vocab.get_token_index(tail_id_string, 'entity_ids')

                _raw_tail_ids[index[:-1]] = raw_tail_id
                _tail_ids[index[:-1]] = tail_id

            raw_entity_ids[current_mask, i] = _raw_tail_ids[current_mask]  # TODO: Double-check
            entity_ids[current_mask, i] = _tail_ids[current_mask]  # TODO: Double-check

            self._recent_entities.insert(_tail_ids, current_mask)

            # --------------------- CONTINUE MENTIONS ---------------------------------------
            continue_mask = mention_type[:, i].eq(3) & mask[:, i]
            if not continue_mask.any() or i == 0:
                continue
            raw_entity_ids[continue_mask, i] = raw_entity_ids[continue_mask, i-1]
            entity_ids[continue_mask, i] = entity_ids[continue_mask, i-1]
            parent_ids[continue_mask, i] = parent_ids[continue_mask, i-1]
            if self._use_shortlist:
                shortlist_inds[continue_mask, i] = shortlist_inds[continue_mask, i-1]
            alias_copy_inds[continue_mask, i] = alias_copy_inds[continue_mask, i-1]

        # Lastly, because entities won't always match the true entity ids,
        # we need to zero out any alias copy ids that won't be valid.
        if 'raw_entity_ids' in kwargs:
            true_raw_entity_ids = kwargs['raw_entity_ids']['raw_entity_ids']
            invalid_id_mask = ~true_raw_entity_ids.eq(raw_entity_ids)
            alias_copy_inds[invalid_id_mask] = 0

        # Pass denotes fields that are passed directly from input to output.
        sample = {
            'source': source,  # Pass
            'target': target,  # Pass
            'reset': reset,  # Pass
            'metadata': metadata,  # Pass
            'mention_type': mention_type,
            'raw_entity_ids': {'raw_entity_ids': raw_entity_ids},
            'entity_ids': {'entity_ids': entity_ids},
            'parent_ids': {'entity_ids': parent_ids},
            'relations': {'relations': None},  # We aren't using them - eventually should remove entirely
            'shortlist': shortlist,  # Pass
            'shortlist_inds': shortlist_inds,
            'alias_copy_inds': alias_copy_inds
        }
        logp = mention_logp + new_entity_logp + derived_entity_logp
        return {'sample': sample, 'logp': logp}

    @overrides
    def forward(self,  # pylint: disable=arguments-differ
                source: Dict[str, torch.Tensor],
                reset: torch.Tensor,
                metadata: List[Dict[str, Any]],
                mention_type: torch.Tensor = None,
                raw_entity_ids: Dict[str, torch.Tensor] = None,
                entity_ids: Dict[str, torch.Tensor] = None,
                parent_ids: Dict[str, torch.Tensor] = None,
                relations: Dict[str, torch.Tensor] = None,
                shortlist: Dict[str, torch.Tensor] = None,
                shortlist_inds: torch.Tensor = None) -> Dict[str, torch.Tensor]:

        # Tensorize the alias_database - this will only perform the operation once.
        alias_database = metadata[0]['alias_database']
        alias_database.tensorize(vocab=self.vocab)

        # Reset the model if needed
        if reset.any() and (self._state is not None):
            for layer in range(self._num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)
        self._recent_entities.reset(reset)

        if entity_ids is not None:
            output_dict = self._forward_loop(
                source=source,
                alias_database=alias_database,
                mention_type=mention_type,
                raw_entity_ids=raw_entity_ids,
                entity_ids=entity_ids,
                parent_ids=parent_ids,
                relations=relations,
                shortlist=shortlist,
                shortlist_inds=shortlist_inds)
        else:
            # TODO: Figure out what we want here - probably to do some kind of inference on
            # entities / mention types.
            output_dict = {}

        return output_dict

    def _encode_source(self, source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        # Extract and embed source tokens.
        source_embeddings = embedded_dropout(
            embed=self._token_embedder,
            words=source,
            dropout=self._dropoute if self.training else 0)
        source_embeddings = self._locked_dropout(source_embeddings, self._dropouti)

        # Encode.
        current_input = source_embeddings
        hidden_states = []
        for layer, rnn in enumerate(self.rnns):
            # Retrieve previous hidden state for layer.
            if self._state is not None:
                prev_hidden = self._state['layer_%i' % layer]
            else:
                prev_hidden = None
            # Forward-pass.
            output, hidden = rnn(current_input, prev_hidden)
            output = output.contiguous()
            # Update hidden state for layer.
            hidden = tuple(h.detach() for h in hidden)
            hidden_states.append(hidden)
            # Apply dropout.
            if layer == self._num_layers - 1:
                dropped_output = self._locked_dropout(output, self._dropout)
            else:
                dropped_output = self._locked_dropout(output, self._dropouth)
            current_input = dropped_output
        encoded = current_input

        alpha_loss = dropped_output.pow(2).mean()
        beta_loss = (output[:, 1:] - output[:, :-1]).pow(2).mean()

        # Update state.
        self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_states)}

        return encoded, alpha_loss, beta_loss

    def _mention_type_loss(self,
                           encoded: torch.Tensor,
                           mention_type: torch.Tensor,
                           mask: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss for predicting whether or not the next token will be part of an
        entity mention.
        """
        logits = self._fc_mention_type(encoded)
        mention_type_loss = sequence_cross_entropy_with_logits(logits, mention_type, mask,
                                                               average='token')
        # if not self.training:
        self._new_mention_f1(predictions=logits,
                             gold_labels=mention_type,
                             mask=mask)
        self._kg_mention_f1(predictions=logits,
                            gold_labels=mention_type,
                            mask=mask)

        return mention_type_loss

    def _new_entity_logits(self,
                           encoded: torch.Tensor,
                           shortlist: torch.Tensor) -> torch.Tensor:
        if self._use_shortlist:
            # Embed the shortlist entries
            shortlist_embeddings = embedded_dropout(
                embed=self._entity_embedder,
                words=shortlist['entity_ids'],
                dropout=self._dropoute if self.training else 0)
            # Compute logits using inner product between the predicted entity embedding and the
            # embeddings of entities in the shortlist
            encodings = self._locked_dropout(encoded, self._dropout)
            logits = torch.bmm(encodings, shortlist_embeddings.transpose(1, 2))
        else:
            logits = self._fc_new_entity(encoded)
        return logits

    def _new_entity_loss(self,
                         encoded: torch.Tensor,
                         target_inds: torch.Tensor,
                         shortlist: torch.Tensor,
                         target_mask: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        target_inds : ``torch.Tensor``
            Either the shortlist inds if using shortlist, otherwise the target entity ids.
        """
        logits = self._new_entity_logits(encoded, shortlist)
        if self._use_shortlist:
            # Take masked softmax to get log probabilities and gather the targets.
            shortlist_mask = get_text_field_mask(shortlist)
            log_probs = masked_log_softmax(logits, shortlist_mask)
        else:
            log_probs = F.log_softmax(logits, dim=-1)
            num_categories = log_probs.shape[-1]
            log_probs = log_probs.view(-1, num_categories)
            target_inds = target_inds.view(-1)
        target_log_probs = torch.gather(log_probs, -1, target_inds.unsqueeze(-1)).squeeze(-1)

        mask = ~target_inds.eq(0)
        target_log_probs[~mask] = 0

        if mask.any():
            self._new_entity_accuracy(predictions=log_probs[mask],
                                      gold_labels=target_inds[mask])
            self._new_entity_accuracy20(predictions=log_probs[mask],
                                        gold_labels=target_inds[mask])

        return -target_log_probs.sum() / (target_mask.sum() + 1e-13)

    def _parent_log_probs(self,
                          encoded_head: torch.Tensor,
                          entity_ids: torch.Tensor,
                          parent_ids: torch.Tensor) -> torch.Tensor:
        # Lookup recent entities (which are candidates for parents) and get their embeddings.
        candidate_ids, candidate_mask = self._recent_entities(entity_ids)
        logger.debug('Candidate ids shape: %s', candidate_ids.shape)
        candidate_embeddings = embedded_dropout(self._entity_embedder,
                                                words=candidate_ids,
                                                dropout=self._dropoute if self.training else 0)

        # Logits are computed using a general bilinear form that measures the similarity between
        # the projected hidden state and the embeddings of candidate entities
        encoded = self._locked_dropout(encoded_head, self._dropout)
        selection_logits = torch.bmm(encoded, candidate_embeddings.transpose(1, 2))

        # Get log probabilities using masked softmax (need to double check mask works properly).

        # shape: (batch_size, sequence_length, num_candidates)
        log_probs = masked_log_softmax(selection_logits, candidate_mask)

        # Now for the tricky part. We need to convert the parent ids to a mask that selects the
        # relevant probabilities from log_probs. To do this we need to align the candidates with
        # the parent ids, which can be achieved by an element-wise equality comparison. We also
        # need to ensure that null parents are not selected.

        # shape: (batch_size, sequence_length, num_parents, 1)
        _parent_ids = parent_ids.unsqueeze(-1)

        batch_size, num_candidates = candidate_ids.shape
        # shape: (batch_size, 1, 1, num_candidates)
        _candidate_ids = candidate_ids.view(batch_size, 1, 1, num_candidates)

        # shape: (batch_size, sequence_length, num_parents, num_candidates)
        is_parent = _parent_ids.eq(_candidate_ids)
        # shape: (batch_size, 1, 1, num_candidates)
        non_null = ~_candidate_ids.eq(0)

        # Since multiplication is addition in log-space, we can apply mask by adding its log (+
        # some small constant for numerical stability).
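        # For example, a mask value of 1 contributes log(1 + 1e-45) ≈ 0, leaving the candidate's
        # log probability unchanged, while a mask value of 0 contributes log(1e-45) ≈ -103.6,
        # which effectively removes that candidate from the distribution.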
        mask = is_parent & non_null
        masked_log_probs = log_probs.unsqueeze(2) + (mask.float() + 1e-45).log()
        logger.debug('Masked log probs shape: %s', masked_log_probs.shape)

        # Lastly, we need to get rid of the num_candidates dimension. The easy way to do this would
        # be to marginalize it out. However, since our data is sparse (the last two dims are
        # essentially a delta function) this would add a lot of unnecessary terms to the computation graph.
        # To get around this we are going to try to use a gather.
        _, index = torch.max(mask, dim=-1, keepdim=True)
        target_log_probs = torch.gather(masked_log_probs, dim=-1, index=index).squeeze(-1)

        return target_log_probs

    def _relation_log_probs(self,
                            encoded_relation: torch.Tensor,
                            raw_entity_ids: torch.Tensor,
                            parent_ids: torch.Tensor) -> torch.Tensor:

        # Lookup edges out of parents
        indices, parent_ids_list, relations_list, tail_ids_list = self._knowledge_graph_lookup(parent_ids)

        # Embed relations
        relation_embeddings = [self._relation_embedder(r) for r in relations_list]

        # Logits are computed using a general bi-linear form that measures the similarity between
        # the projected hidden state and the embeddings of relations
        encoded = self._locked_dropout(encoded_relation, self._dropout)

        # This is a little funky, but to avoid massive amounts of padding we are going to just
        # iterate over the relation and tail_id vectors one-by-one.
        # shape: (batch_size, sequence_length, num_parents, num_relations)
        target_log_probs = encoded.new_empty(*parent_ids.shape).fill_(math.log(1e-45))
        for index, parent_id, relation_embedding, tail_id in zip(indices, parent_ids_list, relation_embeddings, tail_ids_list):
            # First we compute the score for each relation w.r.t the current encoding, and convert
            # the scores to log-probabilities
            logits = torch.mv(relation_embedding, encoded[index[:-1]])
            logger.debug('Relation logits shape: %s', logits.shape)
            log_probs = F.log_softmax(logits, dim=-1)

            # Next we gather the log probs for edges with the correct tail entity and sum them up
            target_id = raw_entity_ids[index[:-1]]
            mask = tail_id.eq(target_id)
            relevant_log_probs = log_probs.masked_select(mask)
            target_log_prob = torch.logsumexp(relevant_log_probs, dim=0)
            target_log_probs[index] = target_log_prob

        return target_log_probs

    def _knowledge_graph_entity_loss(self,
                                     encoded_head: torch.Tensor,
                                     encoded_relation: torch.Tensor,
                                     raw_entity_ids: torch.Tensor,
                                     entity_ids: torch.Tensor,
                                     parent_ids: torch.Tensor,
                                     target_mask: torch.Tensor) -> torch.Tensor:
        # First get the log probabilities of the parents and relations that lead to the current
        # entity.
        parent_log_probs = self._parent_log_probs(encoded_head, entity_ids, parent_ids)
        relation_log_probs = self._relation_log_probs(encoded_relation, raw_entity_ids, parent_ids)
        # Next take their product + marginalize
        combined_log_probs = parent_log_probs + relation_log_probs
        target_log_probs = torch.logsumexp(combined_log_probs, dim=-1)
        # Zero out any non-kg predictions
        mask = ~parent_ids.eq(0).all(dim=-1)
        target_log_probs = target_log_probs * mask.float()
        # If validating, measure ppl of the predictions:
        # if not self.training:
        self._parent_ppl(-torch.logsumexp(parent_log_probs, dim=-1)[mask].sum(), mask.float().sum())
        self._relation_ppl(-torch.logsumexp(relation_log_probs, dim=-1)[mask].sum(), mask.float().sum())
        # Lastly return the tokenwise average loss
        return -target_log_probs.sum() / (target_mask.sum() + 1e-13)

    def _forward_loop(self,
                      source: Dict[str, torch.Tensor],
                      alias_database: AliasDatabase,
                      mention_type: torch.Tensor,
                      raw_entity_ids: Dict[str, torch.Tensor],
                      entity_ids: Dict[str, torch.Tensor],
                      parent_ids: Dict[str, torch.Tensor],
                      relations: Dict[str, torch.Tensor],
                      shortlist: Dict[str, torch.Tensor],
                      shortlist_inds: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Get the token mask and extract indexed text fields.
        # shape: (batch_size, sequence_length)
        target_mask = get_text_field_mask(source)
        source = source['tokens']
        raw_entity_ids = raw_entity_ids['raw_entity_ids']
        entity_ids = entity_ids['entity_ids']
        parent_ids = parent_ids['entity_ids']
        relations = relations['relations']

        logger.debug('Source & Target shape: %s', source.shape)
        logger.debug('Entity ids shape: %s', entity_ids.shape)
        logger.debug('Relations & Parent ids shape: %s', relations.shape)
        logger.debug('Shortlist shape: %s', shortlist['entity_ids'].shape)
        # Embed source tokens.
        # shape: (batch_size, sequence_length, embedding_dim)
        encoded, alpha_loss, beta_loss = self._encode_source(source)
        splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2
        encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1)

        # Predict whether or not the next token will be an entity mention, and if so which type.
        mention_type_loss = self._mention_type_loss(encoded_token, mention_type, target_mask)
        self._avg_mention_type_loss(float(mention_type_loss))

        # For new mentions, predict which entity (among those in the supplied shortlist) will be
        # mentioned.
        if self._use_shortlist:
            new_entity_loss = self._new_entity_loss(encoded_head + encoded_relation,
                                                    shortlist_inds,
                                                    shortlist,
                                                    target_mask)
        else:
            new_entity_loss = self._new_entity_loss(encoded_head + encoded_relation,
                                                    entity_ids,
                                                    None,
                                                    target_mask)

        self._avg_new_entity_loss(float(new_entity_loss))

        # For derived mentions, first predict which parent(s) to expand...
        knowledge_graph_entity_loss = self._knowledge_graph_entity_loss(encoded_head,
                                                                        encoded_relation,
                                                                        raw_entity_ids,
                                                                        entity_ids,
                                                                        parent_ids,
                                                                        target_mask)
        self._avg_knowledge_graph_entity_loss(float(knowledge_graph_entity_loss))

        # Compute total loss
        loss = mention_type_loss + new_entity_loss + knowledge_graph_entity_loss

        # Activation regularization
        if self._alpha:
            loss = loss + self._alpha * alpha_loss
        # Temporal activation regularization (slowness)
        if self._beta:
            loss = loss + self._beta * beta_loss

        return {'loss': loss}

    @overrides
    def train(self, mode=True):
        # TODO: This is a temporary hack to ensure that the internal state resets when the model
        # switches from training to evaluation. The complication arises from potentially differing
        # batch sizes (e.g. the `reset` tensor will not be the right size).
        # In future implementations this should be handled more robustly.
        super().train(mode)
        self._state = None

    @overrides
    def eval(self):
        # TODO: See train.
        super().eval()
        self._state = None

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        out = {
            'type': self._avg_mention_type_loss.get_metric(reset),
            'new': self._avg_new_entity_loss.get_metric(reset),
            'kg': self._avg_knowledge_graph_entity_loss.get_metric(reset),
        }
        # if not self.training:
        p, r, f = self._new_mention_f1.get_metric(reset)
        out['new_p'] = p
        out['new_r'] = r
        out['new_f1'] = f
        p, r, f = self._kg_mention_f1.get_metric(reset)
        out['kg_p'] = p
        out['kg_r'] = r
        out['kg_f1'] = f
        out['new_ent_acc'] = self._new_entity_accuracy.get_metric(reset)
        out['new_ent_acc_20'] = self._new_entity_accuracy20.get_metric(reset)
        out['parent_ppl'] = self._parent_ppl.get_metric(reset)
        out['relation_ppl'] = self._relation_ppl.get_metric(reset)
        return out
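
The discriminator above is only half of the importance-sampling story: its ``sample`` method produces annotations together with their proposal log-probability. The sketch below (not from the repository) shows how those outputs could be combined with a generative KGLM to estimate the marginal likelihood of a batch; the ``kglm`` model, its ``**annotation`` call signature, and the ``'logp'`` key it returns are assumptions made for illustration.

import math

import torch


def estimate_log_marginal(kglm_disc: KglmDisc, kglm, batch: dict, num_samples: int = 100) -> torch.Tensor:
    """Estimate log p(x) ≈ logsumexp_k [log p(x, z_k) - log q(z_k | x)] - log K."""
    log_weights = []
    for _ in range(num_samples):
        # Draw an annotation sample z_k from the discriminative proposal q(z | x).
        # (Resetting model state between samples is glossed over here.)
        proposal = kglm_disc.sample(**batch)
        annotation, proposal_logp = proposal['sample'], proposal['logp']
        # Score the joint log p(x, z_k) under the (hypothetical) generative model.
        with torch.no_grad():
            model_logp = kglm(**annotation)['logp']
        log_weights.append(model_logp - proposal_logp)
    return torch.logsumexp(torch.stack(log_weights), dim=0) - math.log(num_samples)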
Example #2
class AliasCopynet(Model):
    """
    Oracle alias copynet language model.

    Parameters
    ----------
    vocab : ``Vocabulary``
        The model vocabulary.
    token_embedder : ``TextFieldEmbedder``
        Used to embed tokens.
    entity_embedder : ``TextFieldEmbedder``
        Used to embed entities.
    alias_encoder : ``Seq2SeqEncoder``
        Used to encode the sequences of alias tokens.
    hidden_size : ``int``
        LSTM hidden layer size.
    num_layers : ``int``
        Number of LSTM layers to use in the encoder.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 token_embedder: TextFieldEmbedder,
                 entity_embedder: TextFieldEmbedder,
                 alias_encoder: Seq2SeqEncoder,
                 hidden_size: int,
                 num_layers: int,
                 dropout: float = 0.4,
                 dropouth: float = 0.3,
                 dropouti: float = 0.65,
                 dropoute: float = 0.1,
                 wdrop: float = 0.5,
                 alpha: float = 2.0,
                 beta: float = 1.0,
                 tie_weights: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
        super(AliasCopynet, self).__init__(vocab)

        # Model architecture - Note: we need to extract the `Embedding` layers from the
        # `TokenEmbedders` to apply dropout later on.
        # pylint: disable=protected-access
        self._token_embedder = token_embedder._token_embedders['tokens']
        self._entity_embedder = entity_embedder._token_embedders['entity_ids']
        self._alias_encoder = alias_encoder
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._tie_weights = tie_weights

        # Dropout
        self._locked_dropout = LockedDropout()
        self._dropout = dropout
        self._dropouth = dropouth
        self._dropouti = dropouti
        self._dropoute = dropoute
        self._wdrop = wdrop

        # Regularization strength
        self._alpha = alpha
        self._beta = beta

        # RNN Encoders. TODO: Experiment with separate encoder for aliases.
        entity_embedding_dim = entity_embedder.get_output_dim()
        token_embedding_dim = token_embedder.get_output_dim()
        assert entity_embedding_dim == token_embedding_dim
        embedding_dim = token_embedding_dim

        rnns: List[torch.nn.Module] = []
        for i in range(num_layers):
            if i == 0:
                input_size = token_embedding_dim
            else:
                input_size = hidden_size
            if (i == num_layers - 1) and tie_weights:
                output_size = token_embedding_dim
            else:
                output_size = hidden_size
            rnns.append(torch.nn.LSTM(input_size, output_size, batch_first=True))
        rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in rnns]
        self.rnns = torch.nn.ModuleList(rnns)

        # Various linear transformations.
        self._fc_mention = torch.nn.Linear(
            in_features=embedding_dim,
            out_features=2)

        self._fc_entity = torch.nn.Linear(
            in_features=embedding_dim,
            out_features=embedding_dim)

        self._fc_condense = torch.nn.Linear(
            in_features=2 * embedding_dim,
            out_features=embedding_dim)

        self._fc_generate = torch.nn.Linear(
            in_features=embedding_dim,
            out_features=vocab.get_vocab_size('tokens'))

        self._fc_copy = torch.nn.Linear(
            in_features=embedding_dim,
            out_features=embedding_dim)

        if tie_weights:
            self._fc_generate.weight = self._token_embedder.weight

        self._state: Optional[Dict[str, Any]] = None

        # Metrics
        # self._avg_mention_loss = Average()
        # self._avg_entity_loss = Average()
        # self._avg_vocab_loss = Average()
        self._unk_index = vocab.get_token_index(DEFAULT_OOV_TOKEN)
        self._unk_penalty = math.log(vocab.get_vocab_size('tokens_unk'))
        self._ppl = Ppl()
        self._upp = Ppl()
        self._kg_ppl = Ppl()  # Knowledge-graph ppl
        self._bg_ppl = Ppl()  # Background ppl

        initializer(self)

    @overrides
    def forward(self,  # pylint: disable=arguments-differ
                source: Dict[str, torch.Tensor],
                target: Dict[str, torch.Tensor],
                reset: torch.Tensor,
                entity_ids: torch.Tensor = None,
                shortlist: Dict[str, torch.Tensor] = None,
                shortlist_inds: torch.Tensor = None,
                alias_copy_inds: torch.Tensor = None,
                alias_tokens: Dict[str, torch.Tensor] = None,
                alias_inds: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        alias_tokens = alias_tokens['tokens']
        # Inds have fixed size and don't get truncated on split so truncate
        # now.
        alias_inds = alias_inds[:, :, :alias_tokens.shape[2]]

        # Reset the model if needed
        if reset.any() and (self._state is not None):
            for layer in range(self._num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)

        if entity_ids is not None:
            output_dict = self._forward_loop(
                source=source,
                target=target,
                alias_copy_inds=alias_copy_inds,
                alias_tokens=alias_tokens,
                alias_inds=alias_inds)
        else:
            # TODO: Figure out what we want here - probably to do some kind of inference on
            # entities / mention types.
            output_dict = {}

        return output_dict

    def _mention_loss(self,
                      encoded: torch.Tensor,
                      targets: torch.Tensor,
                      mask: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss for predicting whether or not the next token will be part of an
        entity mention.
        """
        logits = self._fc_mention(encoded)
        mention_loss = sequence_cross_entropy_with_logits(logits, targets, mask,
                                                          average='token')
        return mention_loss

    def _entity_loss(self,
                     encoded: torch.Tensor,
                     targets: torch.Tensor,
                     mask: torch.Tensor,
                     shortlist_embeddings: torch.Tensor,
                     shortlist_mask: torch.Tensor) -> torch.Tensor:
        # Logits are computed using a bilinear form that measures the similarity between the
        # projected hidden state and the embeddings of entities in the shortlist
        projected = self._fc_entity(encoded)
        projected = self._locked_dropout(projected, self._dropout)
        logits = torch.bmm(projected, shortlist_embeddings.transpose(1, 2))

        # There are technically two masks that need to be accounted for: a class-wise mask which
        # specifies which logits to ignore in the class dimension, and a token-wise mask (e.g.
        # `mask`) which avoids measuring loss for predictions on non-mention tokens. In practice,
        # we only need the class-wise mask since the non-mention tokens cannot be associated with a
        # valid target.
        batch_size = encoded.shape[0]
        entity_loss = 0.0
        for i in range(batch_size):
            entity_loss += F.cross_entropy(
                input=logits[i],
                target=targets[i],
                weight=shortlist_mask[i].float(),
                reduction='sum')
        entity_loss = entity_loss / (mask.float().sum() + 1e-13)
        return entity_loss

    def _copy_scores(self,
                     encoded: torch.Tensor,
                     alias_tokens: torch.Tensor) -> torch.Tensor:
        # Begin by flattening the tokens so that they fit the expected shape of a
        # ``Seq2SeqEncoder``.
        batch_size, sequence_length, alias_length = alias_tokens.shape
        flattened = alias_tokens.view(-1, alias_length)
        copy_mask = flattened != 0
        if copy_mask.sum() == 0:
            return encoded.new_zeros((batch_size, sequence_length, alias_length),
                                     dtype=torch.float32)

        # Embed and encode the alias tokens.
        embedded = self._token_embedder(flattened)
        mask = flattened.gt(0)
        encoded_aliases = self._alias_encoder(embedded, mask)

        # Equation 8 in the CopyNet paper recommends applying the additional step.
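        # Concretely, matching the two lines below: score_copy(x_j) = s_t^T tanh(W_c h_j), where
        # h_j is the encoding of the j-th alias token and s_t is the current hidden state.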
        projected = torch.tanh(self._fc_copy(encoded_aliases))
        projected = self._locked_dropout(projected, self._dropout)

        # This part gets a little funky - we need to make sure that the first dimension in
        # `projected` and `hidden` is batch_size x sequence_length.
        encoded = encoded.view(batch_size * sequence_length, 1, -1)
        projected = projected.view(batch_size * sequence_length, alias_length, -1)
        copy_scores = torch.bmm(encoded, projected.transpose(1, 2)).squeeze(1)
        copy_scores = copy_scores.view(batch_size, sequence_length, -1).contiguous()

        return copy_scores

    def _vocab_loss(self,
                    generate_scores: torch.Tensor,
                    copy_scores: torch.Tensor,
                    target_tokens: torch.Tensor,
                    target_alias_indices: torch.Tensor,
                    mask: torch.Tensor,
                    alias_indices: torch.Tensor,
                    alias_tokens: torch.Tensor):
        mention_mask = target_alias_indices.gt(0)
        batch_size, sequence_length, vocab_size = generate_scores.shape
        copy_sequence_length = copy_scores.shape[-1]

        # Flat sequences make life **much** easier.
        flattened_targets = target_tokens.view(batch_size * sequence_length, 1)
        flattened_mask = mask.view(-1, 1).byte()
        alias_mask = alias_indices.view(batch_size, sequence_length, -1).gt(0)

        # The log-probability distribution is then given by taking the masked log softmax.
        generate_log_probs = masked_log_softmax(generate_scores,
                                                torch.ones_like(generate_scores))
        copy_log_probs = masked_log_softmax(copy_scores, alias_mask)

        # GENERATE LOSS ###
        # The generated token loss is a simple cross-entropy calculation; we can just gather
        # the log probabilities...
        flattened_log_probs = generate_log_probs.view(batch_size * sequence_length, -1)
        generate_log_probs = flattened_log_probs.gather(1, flattened_targets)
        # ...except we need to ignore the contribution of UNK tokens that are
        # copied (always in the simplified model). To do that we create a mask
        # which is 1 only if the token is not a copied UNK (or padding).
        unks = target_tokens.eq(self._unk_index).view(-1, 1)
        copied = target_alias_indices.gt(0).view(-1, 1)
        generate_mask = ~copied & flattened_mask
        # Since we are in log-space we apply the mask by addition.
        generate_log_probs = generate_log_probs + (generate_mask.float() + 1e-45).log()

        # COPY LOSS ###
        copy_log_probs = copy_log_probs.view(batch_size * sequence_length, -1)
        # When computing the loss we need to get the log probability of **only** the copied tokens.
        alias_indices = alias_indices.view(batch_size * sequence_length, -1)
        target_alias_indices = target_alias_indices.view(-1, 1)
        copy_mask = alias_indices.eq(target_alias_indices) & flattened_mask & target_alias_indices.gt(0)
        copy_log_probs = copy_log_probs + (copy_mask.float() + 1e-45).log()

        # COMBINED LOSS ###
        # The final loss term is computed using our log probs computed w.r.t. the entire
        # vocabulary.
        kg_mask = (mention_mask & mask.byte()).view(-1)
        bg_mask = (~mention_mask & mask.byte()).view(-1)
        mask = mask.byte().view(-1)

        combined_log_probs = torch.cat((generate_log_probs,
                                        copy_log_probs),
                                       dim=1)
        combined_log_probs = torch.logsumexp(combined_log_probs,
                                             dim=1)
        vocab_loss = -combined_log_probs[mask].sum() / (mask.float().sum() + 1e-13)

        # PERPLEXITY ###
        # Our perplexity terms are computed using the log probs computed w.r.t the source
        # vocabulary.

        # For UPP we penalize **only** p(UNK); not the copy probabilities!
        penalized_log_probs = generate_log_probs - self._unk_penalty * unks.float()
        penalized_log_probs = torch.cat((penalized_log_probs,
                                         copy_log_probs),
                                        dim=1)
        penalized_log_probs = torch.logsumexp(penalized_log_probs,
                                              dim=1)

        self._ppl(-combined_log_probs[mask].sum(), mask.float().sum() + 1e-13)
        self._upp(-penalized_log_probs[mask].sum(), mask.float().sum() + 1e-13)
        if kg_mask.any():
            self._kg_ppl(-combined_log_probs[kg_mask].sum(), kg_mask.float().sum() + 1e-13)
        if bg_mask.any():
            self._bg_ppl(-combined_log_probs[bg_mask].sum(), bg_mask.float().sum() + 1e-13)

        return vocab_loss

    def _forward_loop(self,
                      source: Dict[str, torch.Tensor],
                      target: Dict[str, torch.Tensor],
                      alias_copy_inds: torch.Tensor,
                      alias_tokens: torch.Tensor,
                      alias_inds: torch.Tensor) -> Dict[str, torch.Tensor]:

        # Get the token mask and unwrap the target tokens.
        target_mask = get_text_field_mask(target)
        target = target['tokens']

        # Embed source tokens.
        source = source['tokens']
        source_embeddings = embedded_dropout(
            embed=self._token_embedder,
            words=source,
            dropout=self._dropoute if self.training else 0)
        source_embeddings = self._locked_dropout(source_embeddings, self._dropouti)

        # Encode source tokens.
        current_input = source_embeddings
        hidden_states = []
        for layer, rnn in enumerate(self.rnns):
            # Retrieve previous hidden state for layer.
            if self._state is not None:
                prev_hidden = self._state['layer_%i' % layer]
            else:
                prev_hidden = None
            # Forward-pass.
            output, hidden = rnn(current_input, prev_hidden)
            output = output.contiguous()
            # Update hidden state for layer.
            hidden = tuple(h.detach() for h in hidden)
            hidden_states.append(hidden)
            # Apply dropout.
            if layer == self._num_layers - 1:
                dropped_output = self._locked_dropout(output, self._dropout)
            else:
                dropped_output = self._locked_dropout(output, self._dropouth)
            current_input = dropped_output
        encoded = current_input
        self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_states)}

        # Predict generation-mode scores by projecting the encoder output onto the vocabulary.
        generate_scores = self._fc_generate(encoded)

        # Predict copy-mode scores.
        # alias_tokens, alias_inds = alias_database.lookup(entity_ids)
        copy_scores = self._copy_scores(encoded, alias_tokens)

        # Combine scores to get vocab loss
        vocab_loss = self._vocab_loss(generate_scores,
                                      copy_scores,
                                      target,
                                      alias_copy_inds,
                                      target_mask,
                                      alias_inds,
                                      alias_tokens)

        # Compute total loss
        loss = vocab_loss

        # Activation regularization
        if self._alpha:
            loss = loss + self._alpha * dropped_output.pow(2).mean()
        # Temporal activation regularization (slowness)
        if self._beta:
            loss = loss + self._beta * (output[:, 1:] - output[:, :-1]).pow(2).mean()

        return {'loss': loss}

    @overrides
    def train(self, mode=True):
        # TODO: This is a temporary hack to ensure that the internal state resets when the model
        # switches from training to evaluation. The complication arises from potentially differing
        # batch sizes (e.g. the `reset` tensor will not be the right size). In future
        # implementations this should be handled more robustly.
        super().train(mode)
        self._state = None

    @overrides
    def eval(self):
        # TODO: See train.
        super().eval()
        self._state = None

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            'ppl': self._ppl.get_metric(reset),
            'upp': self._upp.get_metric(reset),
            'kg_ppl': self._kg_ppl.get_metric(reset),
            'bg_ppl': self._bg_ppl.get_metric(reset)
        }
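
The core of ``_vocab_loss`` above is a two-way mixture between generating a token from the vocabulary and copying it from an alias. The toy snippet below (standalone, with random scores; not part of the model) isolates that pattern: gather the generation log-probability of each target, mask the copy log-probabilities down to positions that actually match the target, and combine the two with ``logsumexp`` so that either route can account for the token.

import torch
import torch.nn.functional as F

flat_length, vocab_size, alias_length = 6, 10, 4           # flat_length = batch_size * sequence_length
generate_scores = torch.randn(flat_length, vocab_size)     # generation-head scores
copy_scores = torch.randn(flat_length, alias_length)       # copy-head scores
targets = torch.randint(vocab_size, (flat_length, 1))      # target token ids

# Pretend the target token appears at the first alias position in every other row.
copy_mask = torch.zeros(flat_length, alias_length, dtype=torch.bool)
copy_mask[::2, 0] = True

generate_log_probs = F.log_softmax(generate_scores, dim=-1).gather(1, targets)
copy_log_probs = F.log_softmax(copy_scores, dim=-1) + (copy_mask.float() + 1e-45).log()
combined = torch.logsumexp(torch.cat((generate_log_probs, copy_log_probs), dim=1), dim=1)
loss = -combined.mean()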
Example #3
class AwdLstmLanguageModel(Model):
    """
    Port of the awd-lstm-lm model from:
        https://github.com/salesforce/awd-lstm-lm/

    Parameters
    ----------
    vocab : ``Vocabulary``
        The model vocabulary.
    embedding_size : ``int``
        The dimension of the token embeddings.
    hidden_size : ``int``
        LSTM hidden layer size (note: not needed if num_layers == 1).
    num_layers : ``int``
        Number of LSTM layers to use in the encoder.
    splits : ``List[int]``, optional (default=``[]``)
        Splits to use in the adaptive softmax.
    dropout : ``float``, optional (default=``0.4``)
        Dropout applied to the output of the final LSTM layer.
    dropouth : ``float``, optional (default=``0.3``)
        Dropout applied between LSTM layers.
    dropouti : ``float``, optional (default=``0.65``)
        Dropout applied to the embedded input.
    dropoute : ``float``, optional (default=``0.1``)
        Embedding (word-level) dropout.
    wdrop : ``float``, optional (default=``0.5``)
        Weight dropout applied to the LSTM hidden-to-hidden weights.
    tie_weights : ``bool``, optional (default=``False``)
        Whether to tie embedding and output projection weights.
    initializer: ``InitializerApplicator``,  optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 embedding_size: int,
                 hidden_size: int,
                 num_layers: int,
                 splits: List[int] = [],
                 dropout: float = 0.4,
                 dropouth: float = 0.3,
                 dropouti: float = 0.65,
                 dropoute: float = 0.1,
                 wdrop: float = 0.5,
                 alpha: float = 2.0,
                 beta: float = 1.0,
                 tie_weights: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
        super(AwdLstmLanguageModel, self).__init__(vocab)

        # Model architecture
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.tie_weights = tie_weights
        self.splits = splits
        self.alpha = alpha
        self.beta = beta

        # Dropout stuff
        self.locked_dropout = LockedDropout()
        self.dropouti = dropouti
        self.dropouth = dropouth
        self.dropoute = dropoute
        self.dropout = dropout

        # Initialize empty state dict
        self._state: Optional[Dict[str, Any]] = None

        # Tokens are manually embedded instead of using a TokenEmbedder to make using
        # embedding_dropout easier.
        self.embedder = torch.nn.Embedding(vocab.get_vocab_size(namespace='tokens'),
                                           embedding_size)

        rnns: List[torch.nn.Module] = []
        for i in range(num_layers):
            if i == 0:
                input_size = embedding_size
            else:
                input_size = hidden_size
            if (i == num_layers - 1) and tie_weights:
                output_size = embedding_size
            else:
                output_size = hidden_size
            rnns.append(torch.nn.LSTM(input_size, output_size, batch_first=True))
        rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in rnns]
        self.rnns = torch.nn.ModuleList(rnns)

        self.decoder = torch.nn.Linear(output_size, vocab.get_vocab_size(namespace='tokens'))

        # Optionally tie weights
        if tie_weights:
            # pylint: disable=protected-access
            self.decoder.weight = self.embedder.weight

        initializer(self)

        self._unk_index = vocab.get_token_index(DEFAULT_OOV_TOKEN)
        self._unk_penalty = math.log(vocab.get_vocab_size('tokens_unk'))

        self.ppl = Ppl()
        self.upp = Ppl()

    @overrides
    def forward(self,  # pylint: disable=arguments-differ
                source: Dict[str, torch.Tensor],
                target: Dict[str, torch.Tensor],
                reset: torch.Tensor = None) -> Dict[str, torch.Tensor]:

        # NOTE: The check below is only needed for the fancy splitting iterator; Merity et al.
        # just propagate the hidden state no matter what.
        # To make life easier when evaluating the model we use a BasicIterator
        # so that we do not need to worry about the sequence truncation
        # performed by our splitting iterators. To accommodate this, we assume
        # that if reset is not given, then everything gets reset.
        if reset is None:
            self._state = None
        elif reset.all() and (self._state is not None):
            logger.debug('RESET')
            self._state = None
        elif reset.any() and (self._state is not None):
            for layer in range(self.num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)

        target_mask = get_text_field_mask(target)
        source = source['tokens']
        target = target['tokens']

        embeddings = embedded_dropout(self.embedder, source,
                                      dropout=self.dropoute if self.training else 0)
        embeddings = self.locked_dropout(embeddings, self.dropouti)

        # Iterate through RNN layers
        current_input = embeddings
        current_hidden = []
        outputs = []
        dropped_outputs = []
        for layer, rnn in enumerate(self.rnns):

            # Bookkeeping
            if self._state is not None:
                prev_hidden = self._state['layer_%i' % layer]
            else:
                prev_hidden = None

            # Forward-pass
            output, hidden = rnn(current_input, prev_hidden)

            # More bookkeeping
            output = output.contiguous()
            outputs.append(output)
            hidden = tuple(h.detach() for h in hidden)
            current_hidden.append(hidden)

            # Apply dropout
            if layer == self.num_layers - 1:
                current_input = self.locked_dropout(output, self.dropout)
                dropped_outputs.append(output)
            else:
                current_input = self.locked_dropout(output, self.dropouth)
                dropped_outputs.append(current_input)

        # Compute logits and loss
        logits = self.decoder(current_input)
        loss = sequence_cross_entropy_with_logits(logits, target.contiguous(),
                                                  target_mask,
                                                  average="token")
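        # Number of unmasked target tokens (the epsilon guards against a zero count); used to
        # convert the per-token loss back into a sum for the perplexity metrics.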
        num_tokens = target_mask.float().sum() + 1e-13

        # Activation regularization
        if self.alpha:
            loss = loss + self.alpha * current_input.pow(2).mean()
        # Temporal activation regularization (slowness)
        if self.beta:
            loss = loss + self.beta * (output[:, 1:] - output[:, :-1]).pow(2).mean()

        # Update metrics and state
        unks = target.eq(self._unk_index)
        unk_penalty = self._unk_penalty * unks.float().sum()

        self.ppl(loss * num_tokens, num_tokens)
        self.upp(loss * num_tokens + unk_penalty, num_tokens)
        self._state = {'layer_%i' % l: h for l, h in enumerate(current_hidden)}

        return {'loss': loss}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            'ppl': self.ppl.get_metric(reset),
            'upp': self.upp.get_metric(reset)
        }
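
A minimal usage sketch for the example above, assuming an already-constructed model instance and an AllenNLP-style batch with 'source', 'target' and 'reset' entries (the names model and batch below are hypothetical):

import torch

with torch.no_grad():
    # model and batch are assumed to exist; the names are illustrative only.
    output = model(source=batch['source'],
                   target=batch['target'],
                   reset=batch['reset'])
print('loss:', output['loss'].item())
# 'ppl' (perplexity) and 'upp' (unknown-penalized perplexity) accumulated by the model.
print(model.get_metrics())
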
Example #4
0
class NoStory(Model):
    """
    Knowledge graph language model - generative story

    Parameters
    ----------
    vocab : ``Vocabulary``
        The model vocabulary.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 token_embedder: TextFieldEmbedder,
                 entity_embedder: TextFieldEmbedder,
                 alias_encoder: Seq2SeqEncoder,
                 use_shortlist: bool,
                 hidden_size: int,
                 num_layers: int,
                 cutoff: int = 30,
                 tie_weights: bool = False,
                 dropout: float = 0.4,
                 dropouth: float = 0.3,
                 dropouti: float = 0.65,
                 dropoute: float = 0.1,
                 wdrop: float = 0.5,
                 alpha: float = 2.0,
                 beta: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
        super(NoStory, self).__init__(vocab)

        # We extract the `Embedding` layers from the `TokenEmbedders` to apply dropout later on.
        # pylint: disable=protected-access
        self._token_embedder = token_embedder._token_embedders['tokens']
        self._entity_embedder = entity_embedder._token_embedders['entity_ids']
        self._alias_encoder = alias_encoder
        self._recent_entities = RecentEntities(cutoff=cutoff)
        self._use_shortlist = use_shortlist
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._cutoff = cutoff
        self._tie_weights = tie_weights

        # Dropout
        self._locked_dropout = LockedDropout()
        self._dropout = dropout
        self._dropouth = dropouth
        self._dropouti = dropouti
        self._dropoute = dropoute
        self._wdrop = wdrop

        # Regularization strength
        self._alpha = alpha
        self._beta = beta

        # RNN Encoders.
        entity_embedding_dim = entity_embedder.get_output_dim()
        token_embedding_dim = token_embedder.get_output_dim()
        self.entity_embedding_dim = entity_embedding_dim
        self.token_embedding_dim = token_embedding_dim

        rnns: List[torch.nn.Module] = []
        for i in range(num_layers):
            if i == 0:
                input_size = token_embedding_dim
            else:
                input_size = hidden_size
            if i == num_layers - 1:
                output_size = token_embedding_dim + entity_embedding_dim
            else:
                output_size = hidden_size
            rnns.append(torch.nn.LSTM(
                input_size, output_size, batch_first=True))
        rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop)
                for rnn in rnns]
        self.rnns = torch.nn.ModuleList(rnns)

        # Various linear transformations.
        self._fc_mention_type = torch.nn.Linear(
            in_features=token_embedding_dim,
            out_features=2)

        if not use_shortlist:
            self._fc_new_entity = torch.nn.Linear(
                in_features=entity_embedding_dim,
                out_features=vocab.get_vocab_size('entity_ids'))

            if tie_weights:
                self._fc_new_entity.weight = self._entity_embedder.weight

        self._fc_condense = torch.nn.Linear(
            in_features=token_embedding_dim + entity_embedding_dim,
            out_features=token_embedding_dim)

        self._fc_generate = torch.nn.Linear(
            in_features=token_embedding_dim,
            out_features=vocab.get_vocab_size('tokens'))

        self._fc_copy = torch.nn.Linear(
            in_features=token_embedding_dim,
            out_features=token_embedding_dim)

        if tie_weights:
            self._fc_generate.weight = self._token_embedder.weight

        self._state: Optional[Dict[str, Any]] = None

        # Metrics
        self._unk_index = vocab.get_token_index(DEFAULT_OOV_TOKEN)
        self._unk_penalty = math.log(vocab.get_vocab_size('tokens_unk'))
        self._ppl = Ppl()
        self._upp = Ppl()
        self._kg_ppl = Ppl()  # Knowledge-graph ppl
        self._bg_ppl = Ppl()  # Background ppl
        self._avg_mention_type_loss = Average()
        self._avg_new_entity_loss = Average()
        self._avg_vocab_loss = Average()
        self._new_mention_f1 = F1Measure(positive_label=1)
        self._new_entity_accuracy = CategoricalAccuracy()
        self._new_entity_accuracy20 = CategoricalAccuracy(top_k=20)

        initializer(self)

    @overrides
    def forward(self,  # pylint: disable=arguments-differ
                source: Dict[str, torch.Tensor],
                target: Dict[str, torch.Tensor],
                reset: torch.Tensor,
                metadata: List[Dict[str, Any]],
                mention_type: torch.Tensor = None,
                raw_entity_ids: Dict[str, torch.Tensor] = None,
                entity_ids: Dict[str, torch.Tensor] = None,
                parent_ids: Dict[str, torch.Tensor] = None,
                relations: Dict[str, torch.Tensor] = None,
                shortlist: Dict[str, torch.Tensor] = None,
                shortlist_inds: torch.Tensor = None,
                alias_copy_inds: torch.Tensor = None) -> Dict[str, torch.Tensor]:

        # Tensorize the alias_database - this will only perform the operation once.
        alias_database = metadata[0]['alias_database']
        alias_database.tensorize(vocab=self.vocab)

        # Reset the model if needed
        if reset.any() and (self._state is not None):
            for layer in range(self._num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)
        self._recent_entities.reset(reset)

        if entity_ids is not None:
            output_dict = self._forward_loop(
                source=source,
                target=target,
                alias_database=alias_database,
                mention_type=mention_type,
                raw_entity_ids=raw_entity_ids,
                entity_ids=entity_ids,
                parent_ids=parent_ids,
                relations=relations,
                shortlist=shortlist,
                shortlist_inds=shortlist_inds,
                alias_copy_inds=alias_copy_inds)
        else:
            # TODO: Figure out what we want here - probably to do some kind of inference on
            # entities / mention types.
            output_dict = {}

        return output_dict

    def _encode_source(self, source: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        # Extract and embed source tokens.
        source_embeddings = embedded_dropout(
            embed=self._token_embedder,
            words=source,
            dropout=self._dropoute if self.training else 0)
        source_embeddings = self._locked_dropout(
            source_embeddings, self._dropouti)

        # Encode.
        current_input = source_embeddings
        hidden_states = []
        for layer, rnn in enumerate(self.rnns):
            # Retrieve previous hidden state for layer.
            if self._state is not None:
                prev_hidden = self._state['layer_%i' % layer]
            else:
                prev_hidden = None
            # Forward-pass.
            output, hidden = rnn(current_input, prev_hidden)
            output = output.contiguous()
            # Update hidden state for layer.
            hidden = tuple(h.detach() for h in hidden)
            hidden_states.append(hidden)
            # Apply dropout.
            if layer == self._num_layers - 1:
                dropped_output = self._locked_dropout(output, self._dropout)
            else:
                dropped_output = self._locked_dropout(output, self._dropouth)
            current_input = dropped_output
        encoded = current_input

        # Activation regularization (alpha) and temporal activation regularization (beta,
        # the "slowness" penalty) computed on the final layer's output; added to the loss
        # in _forward_loop.
        alpha_loss = dropped_output.pow(2).mean()
        beta_loss = (output[:, 1:] - output[:, :-1]).pow(2).mean()

        # Update state.
        self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_states)}

        return encoded, alpha_loss, beta_loss

    def _mention_type_loss(self,
                           encoded: torch.Tensor,
                           mention_type: torch.Tensor,
                           mask: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss for predicting whether or not the next token will be part of an
        entity mention.
        """
        logits = self._fc_mention_type(encoded)
        mention_type_loss = sequence_cross_entropy_with_logits(logits, mention_type, mask,
                                                               average='token')
        # if not self.training:
        self._new_mention_f1(predictions=logits,
                             gold_labels=mention_type,
                             mask=mask)

        return mention_type_loss

    def _new_entity_loss(self,
                         encoded: torch.Tensor,
                         target_inds: torch.Tensor,
                         shortlist: torch.Tensor,
                         target_mask: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ==========
        target_inds : ``torch.Tensor``
            Either the shortlist inds if using shortlist, otherwise the target entity ids.
        """
        if self._use_shortlist:

            # First we embed the shortlist entries
            shortlist_mask = get_text_field_mask(shortlist)
            shortlist_embeddings = embedded_dropout(
                embed=self._entity_embedder,
                words=shortlist['entity_ids'],
                dropout=self._dropoute if self.training else 0)

            # Logits are computed as the inner product between the predicted entity embedding
            # and the embeddings of the entities in the shortlist.
            encodings = self._locked_dropout(encoded, self._dropout)
            logits = torch.bmm(encodings, shortlist_embeddings.transpose(1, 2))

            # Take masked softmax to get log probabilities and gather the targets.
            log_probs = masked_log_softmax(logits, shortlist_mask)
            target_log_probs = torch.gather(
                log_probs, -1, target_inds.unsqueeze(-1)).squeeze(-1)

            # If not generating a new mention, the action is deterministic - so the loss is 0 for these tokens.
            mask = ~target_inds.eq(0)
            target_log_probs[~mask] = 0

            # if not self.training:
            self._new_entity_accuracy(predictions=log_probs[mask],
                                      gold_labels=target_inds[mask])
            self._new_entity_accuracy20(predictions=log_probs[mask],
                                        gold_labels=target_inds[mask])

            # Return the token-wise average loss
            return -target_log_probs.sum() / (target_mask.sum() + 1e-13)

        else:
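            # Without a shortlist, score every entity in the vocabulary and gather the
            # log-probability assigned to the target entity id at each position.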
            logits = self._fc_new_entity(encoded)
            log_probs = F.log_softmax(logits, dim=-1)

            num_categories = log_probs.shape[-1]
            flat_log_probs = log_probs.view(-1, num_categories)
            flat_target_inds = target_inds.view(-1)
            target_log_probs = torch.gather(
                flat_log_probs, -1, flat_target_inds.unsqueeze(-1)).squeeze(-1)

            mask = ~flat_target_inds.eq(0)
            target_log_probs[~mask] = 0

            self._new_entity_accuracy(predictions=flat_log_probs[mask],
                                      gold_labels=flat_target_inds[mask])
            self._new_entity_accuracy20(predictions=flat_log_probs[mask],
                                        gold_labels=flat_target_inds[mask])

            return -target_log_probs.sum() / (target_mask.sum() + 1e-13)

    def _parent_log_probs(self,
                          encoded_head: torch.Tensor,
                          entity_ids: torch.Tensor,
                          parent_ids: torch.Tensor) -> torch.Tensor:
        # Lookup recent entities (which are candidates for parents) and get their embeddings.
        candidate_ids, candidate_mask = self._recent_entities(entity_ids)
        logger.debug('Candidate ids shape: %s', candidate_ids.shape)
        candidate_embeddings = embedded_dropout(self._entity_embedder,
                                                words=candidate_ids,
                                                dropout=self._dropoute if self.training else 0)

        # Logits are computed using a general bilinear form that measures the similarity between
        # the projected hidden state and the embeddings of candidate entities
        encoded = self._locked_dropout(encoded_head, self._dropout)
        selection_logits = torch.bmm(
            encoded, candidate_embeddings.transpose(1, 2))

        # Get log probabilities using masked softmax (need to double check mask works properly).

        # shape: (batch_size, sequence_length, num_candidates)
        log_probs = masked_log_softmax(selection_logits, candidate_mask)

        # Now for the tricky part. We need to convert the parent ids to a mask that selects the
        # relevant probabilities from log_probs. To do this we need to align the candidates with
        # the parent ids, which can be achieved by an element-wise equality comparison. We also
        # need to ensure that null parents are not selected.

        # shape: (batch_size, sequence_length, num_parents, 1)
        _parent_ids = parent_ids.unsqueeze(-1)

        batch_size, num_candidates = candidate_ids.shape
        # shape: (batch_size, 1, 1, num_candidates)
        _candidate_ids = candidate_ids.view(batch_size, 1, 1, num_candidates)

        # shape: (batch_size, sequence_length, num_parents, num_candidates)
        is_parent = _parent_ids.eq(_candidate_ids)
        # shape: (batch_size, 1, 1, num_candidates)
        non_null = ~_candidate_ids.eq(0)

        # Since multiplication is addition in log-space, we can apply the mask by adding its
        # log (plus a small constant for numerical stability).
        mask = is_parent & non_null
        masked_log_probs = log_probs.unsqueeze(
            2) + (mask.float() + 1e-45).log()
        logger.debug('Masked log probs shape: %s', masked_log_probs.shape)

        # Lastly, we need to get rid of the num_candidates dimension. The easy way to do this would
        # be to marginalize it out. However, since our data is sparse (the last two dims are
        # essentially a delta function) this would add a lot of unnecessary terms to the computation graph.
        # To get around this we are going to try to use a gather.
        _, index = torch.max(mask, dim=-1, keepdim=True)
        target_log_probs = torch.gather(
            masked_log_probs, dim=-1, index=index).squeeze(-1)

        return target_log_probs

    def _generate_scores(self,
                         encoded: torch.Tensor,
                         entity_ids: torch.Tensor) -> torch.Tensor:
        entity_embeddings = embedded_dropout(embed=self._entity_embedder,
                                             words=entity_ids,
                                             dropout=self._dropoute if self.training else 0)
        concatenated = torch.cat((encoded, entity_embeddings), dim=-1)
        condensed = self._fc_condense(concatenated)
        return self._fc_generate(condensed)

    def _copy_scores(self,
                     encoded: torch.Tensor,
                     alias_tokens: torch.Tensor) -> torch.Tensor:
        # Begin by flattening the tokens so that they fit the expected shape of a
        # ``Seq2SeqEncoder``.
        batch_size, sequence_length, num_aliases, alias_length = alias_tokens.shape
        flattened = alias_tokens.view(-1, alias_length)
        copy_mask = flattened != 0
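        # If no alias tokens are present anywhere in the batch there is nothing to copy, so
        # return all-zero copy scores.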
        if copy_mask.sum() == 0:
            return encoded.new_zeros((batch_size, sequence_length, num_aliases * alias_length),
                                     dtype=torch.float32)

        # Embed and encode the alias tokens.
        embedded = self._token_embedder(flattened)
        mask = flattened.gt(0)
        encoded_aliases = self._alias_encoder(embedded, mask)

        # Equation 8 in the CopyNet paper recommends applying this additional transformation.
        projected = torch.tanh(self._fc_copy(encoded_aliases))
        projected = self._locked_dropout(projected, self._dropout)

        # This part gets a little funky - we need to make sure that the first dimension in
        # `projected` and `encoded` is batch_size * sequence_length.
        encoded = encoded.view(batch_size * sequence_length, 1, -1)
        projected = projected.view(
            batch_size * sequence_length, -1, num_aliases * alias_length)
        copy_scores = torch.bmm(encoded, projected).squeeze()
        copy_scores = copy_scores.view(
            batch_size, sequence_length, -1).contiguous()
        logger.debug('Copy scores shape: %s', copy_scores.shape)

        return copy_scores

    def _vocab_loss(self,
                    generate_scores: torch.Tensor,
                    copy_scores: torch.Tensor,
                    target_tokens: torch.Tensor,
                    target_alias_indices: torch.Tensor,
                    mask: torch.Tensor,
                    alias_indices: torch.Tensor,
                    mention_mask: torch.Tensor):
        batch_size, sequence_length, vocab_size = generate_scores.shape
        copy_sequence_length = copy_scores.shape[-1]

        # Flat sequences make life **much** easier.
        flattened_targets = target_tokens.view(batch_size * sequence_length, 1)
        flattened_mask = mask.view(-1, 1).byte()

        # In order to obtain proper log probabilities we create a mask to omit padding alias tokens
        # from the calculation.
        alias_mask = alias_indices.view(batch_size, sequence_length, -1).gt(0)
        score_mask = mask.new_ones(
            batch_size, sequence_length, vocab_size + copy_sequence_length)
        score_mask[:, :, vocab_size:] = alias_mask

        # The log-probability distribution is then given by taking the masked log softmax.
        concatenated_scores = torch.cat((generate_scores, copy_scores), dim=-1)
        log_probs = masked_log_softmax(concatenated_scores, score_mask)

        # GENERATE LOSS ###
        # The generated-token loss is a simple cross-entropy calculation; we can just gather
        # the log probabilities...
        flattened_log_probs = log_probs.view(batch_size * sequence_length, -1)
        generate_log_probs_source_vocab = flattened_log_probs.gather(
            1, flattened_targets)
        # ...except we need to ignore the contribution of UNK tokens that are copied (only when
        # computing the loss). To do that we create a mask which is 1 only if the token is not a
        # copied UNK (or padding).
        unks = target_tokens.eq(self._unk_index).view(-1, 1)
        copied = target_alias_indices.gt(0).view(-1, 1)
        generate_mask = ~(unks & copied) & flattened_mask
        # Since we are in log-space we apply the mask by addition.
        generate_log_probs_extended_vocab = generate_log_probs_source_vocab + \
            (generate_mask.float() + 1e-45).log()

        # COPY LOSS ###
        copy_log_probs = flattened_log_probs[:, vocab_size:]
        # When computing the loss we need to get the log probability of **only** the copied tokens.
        alias_indices = alias_indices.view(batch_size * sequence_length, -1)
        target_alias_indices = target_alias_indices.view(-1, 1)
        copy_mask = alias_indices.eq(
            target_alias_indices) & flattened_mask & target_alias_indices.gt(0)
        copy_log_probs = copy_log_probs + (copy_mask.float() + 1e-45).log()

        # COMBINED LOSS ###
        # The final loss term is computed using the log probs taken w.r.t. the entire
        # vocabulary.
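        # A target token may be produced either by generating it or by copying it from an
        # alias; summing those probabilities corresponds to a logsumexp in log-space.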
        combined_log_probs_extended_vocab = torch.cat((generate_log_probs_extended_vocab,
                                                       copy_log_probs),
                                                      dim=1)
        combined_log_probs_extended_vocab = torch.logsumexp(combined_log_probs_extended_vocab,
                                                            dim=1)
        flattened_mask = flattened_mask.squeeze()
        # Zero out padding loss
        combined_log_probs_extended_vocab = combined_log_probs_extended_vocab * \
            flattened_mask.float()
        vocab_loss = -combined_log_probs_extended_vocab.sum() / (mask.sum() + 1e-13)

        # Unknown penalty - only applies to non-copied unks
        true_unks = unks.squeeze() & ~copied.squeeze() & flattened_mask
        penalized_log_probs = combined_log_probs_extended_vocab - \
            self._unk_penalty * true_unks.float()
        penalized_log_probs[~flattened_mask] = 0
        penalized_vocab_loss = -penalized_log_probs.sum() / (mask.sum() + 1e-13)

        # PERPLEXITY ###
        # Our perplexity terms are computed using the log probs computed w.r.t the source
        # vocabulary.
        combined_log_probs_source_vocab = torch.cat((generate_log_probs_source_vocab,
                                                     copy_log_probs),
                                                    dim=1)
        combined_log_probs_source_vocab = torch.logsumexp(combined_log_probs_source_vocab,
                                                          dim=1)

        # For UPP we penalize **only** p(UNK); not the copy probabilities!
        penalized_log_probs_source_vocab = generate_log_probs_source_vocab - \
            self._unk_penalty * unks.float()
        penalized_log_probs_source_vocab = torch.cat((penalized_log_probs_source_vocab,
                                                      copy_log_probs),
                                                     dim=1)
        penalized_log_probs_source_vocab = torch.logsumexp(penalized_log_probs_source_vocab,
                                                           dim=1)

        # Split tokens into those that are part of an entity mention (kg) and background
        # tokens (bg), so that separate perplexities can be tracked for each.
        kg_mask = (mention_mask * mask.byte()).view(-1)
        bg_mask = ((1 - mention_mask) * mask.byte()).view(-1)
        mask = (kg_mask | bg_mask)

        self._ppl(-combined_log_probs_source_vocab[mask].sum(),
                  mask.float().sum() + 1e-13)
        self._upp(-penalized_log_probs_source_vocab[mask].sum(),
                  mask.float().sum() + 1e-13)
        if kg_mask.any():
            self._kg_ppl(-combined_log_probs_source_vocab[kg_mask].sum(
            ), kg_mask.float().sum() + 1e-13)
        if bg_mask.any():
            self._bg_ppl(-combined_log_probs_source_vocab[bg_mask].sum(
            ), bg_mask.float().sum() + 1e-13)

        return vocab_loss, penalized_vocab_loss

    def _forward_loop(self,
                      source: Dict[str, torch.Tensor],
                      target: Dict[str, torch.Tensor],
                      alias_database: AliasDatabase,
                      mention_type: torch.Tensor,
                      raw_entity_ids: Dict[str, torch.Tensor],
                      entity_ids: Dict[str, torch.Tensor],
                      parent_ids: Dict[str, torch.Tensor],
                      relations: Dict[str, torch.Tensor],
                      shortlist: Dict[str, torch.Tensor],
                      shortlist_inds: torch.Tensor,
                      alias_copy_inds: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Get the token mask and extract indexed text fields.
        # shape: (batch_size, sequence_length)
        target_mask = get_text_field_mask(target)
        source = source['tokens']
        target = target['tokens']
        raw_entity_ids = raw_entity_ids['raw_entity_ids']
        entity_ids = entity_ids['entity_ids']

        logger.debug('Source & Target shape: %s', source.shape)
        logger.debug('Entity ids shape: %s', entity_ids.shape)
        logger.debug('Shortlist shape: %s', shortlist['entity_ids'].shape)
        # Embed source tokens.
        # shape: (batch_size, sequence_length, embedding_dim)
        encoded, alpha_loss, beta_loss = self._encode_source(source)
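        # The final RNN layer outputs the concatenation of a token representation and an
        # entity ("head") representation (see the output sizes in __init__); split them apart.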
        splits = [self.token_embedding_dim] + [self.entity_embedding_dim]
        encoded_token, encoded_head = encoded.split(splits, dim=-1)

        # Predict whether or not the next token will be part of an entity mention.
        mention_type = mention_type.gt(0).long()  # Map 1, 2 -> 1
        mention_type_loss = self._mention_type_loss(
            encoded_token, mention_type, target_mask)
        self._avg_mention_type_loss(float(mention_type_loss))

        # For new mentions, predict which entity (among those in the supplied shortlist) will be
        # mentioned.
        if self._use_shortlist:
            new_entity_loss = self._new_entity_loss(encoded_head,
                                                    shortlist_inds,
                                                    shortlist,
                                                    target_mask)
        else:
            new_entity_loss = self._new_entity_loss(encoded_head,
                                                    entity_ids,
                                                    None,
                                                    target_mask)

        self._avg_new_entity_loss(float(new_entity_loss))

        # Predict generation-mode scores. Note: these are w.r.t. entity_ids since we need the embeddings.
        generate_scores = self._generate_scores(encoded_token, entity_ids)

        # Predict copy-mode scores. Note: these are w.r.t. raw_entity_ids since we need to look up aliases.
        alias_tokens, alias_inds = alias_database.lookup(raw_entity_ids)
        copy_scores = self._copy_scores(encoded_token, alias_tokens)

        # Combine scores to get vocab loss
        vocab_loss, penalized_vocab_loss = self._vocab_loss(generate_scores,
                                                            copy_scores,
                                                            target,
                                                            alias_copy_inds,
                                                            target_mask,
                                                            alias_inds,
                                                            entity_ids.gt(0))
        self._avg_vocab_loss(float(vocab_loss))

        # Compute total loss. Also compute logp (needed for importance sampling evaluation).
        loss = vocab_loss + mention_type_loss + new_entity_loss
        logp = -(vocab_loss + mention_type_loss +
                 new_entity_loss) * target_mask.sum()
        penalized_logp = - \
            (penalized_vocab_loss + mention_type_loss +
             new_entity_loss) * target_mask.sum()

        # Activation regularization
        if self._alpha:
            loss = loss + self._alpha * alpha_loss
        # Temporal activation regularization (slowness)
        if self._beta:
            loss = loss + self._beta * beta_loss

        return {'loss': loss, 'logp': logp, 'penalized_logp': penalized_logp}

    @overrides
    def train(self, mode=True):
        # TODO: This is a temporary hack to ensure that the internal state resets when the model
        # switches from training to evaluation. The complication arises from potentially differing
        # batch sizes (e.g. the `reset` tensor will not be the right size). In future
        # implementations this should be handled more robustly.
        super().train(mode)
        self._state = None

    @overrides
    def eval(self):
        # TODO: See train.
        super().eval()
        self._state = None

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        out = {
            'ppl': self._ppl.get_metric(reset),
            'upp': self._upp.get_metric(reset),
            'kg_ppl': self._kg_ppl.get_metric(reset),
            'bg_ppl': self._bg_ppl.get_metric(reset),
            'type': self._avg_mention_type_loss.get_metric(reset),
            'new': self._avg_new_entity_loss.get_metric(reset),
            'vocab': self._avg_vocab_loss.get_metric(reset),
        }
        # if not self.training:
        p, r, f = self._new_mention_f1.get_metric(reset)
        out['new_p'] = p
        out['new_r'] = r
        out['new_f1'] = f
        out['new_ent_acc'] = self._new_entity_accuracy.get_metric(reset)
        out['new_ent_acc_20'] = self._new_entity_accuracy20.get_metric(reset)
        return out