def test_update_embedding(self):
    """Check that updating an entity embedding changes it, records the timestep
    in ``last_seen`` and keeps the autograd graph connected."""
    dim, capacity, n_sequences = 4, 10, 1
    step = 1

    dynamic_embedding = DynamicEmbedding(dim, capacity)
    dynamic_embedding.reset_states(n_sequences)

    hidden = torch.randn((n_sequences, dim), requires_grad=True)
    target_slots = torch.tensor([0])  # pylint: disable=E1102

    # Snapshot the embedding first so the update can be compared against it.
    before = dynamic_embedding.embeddings[0, 0].clone()
    dynamic_embedding.update_embeddings(hidden, target_slots, step)
    after = dynamic_embedding.embeddings[0, 0]

    # The update must actually modify the embedding.
    self.assertFalse(torch.allclose(before, after))

    # The slot must remember the timestep at which it was last updated.
    self.assertEqual(dynamic_embedding.last_seen[0, 0], 1)

    # Gradients must reach both the initial embedding and the hidden state.
    after.sum().backward()
    self.assertIsNotNone(dynamic_embedding._initial_embedding.grad)
    self.assertIsNotNone(hidden.grad)
class EntityNLM(Model):
    """
    Implementation of the Entity Neural Language Model from:
        https://arxiv.org/abs/1708.00781

    Parameters
    ----------
    vocab : ``Vocabulary``
        The model vocabulary.
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed tokens.
    encoder : ``Seq2SeqEncoder``
        Used to encode the sequence of token embeddings.
    embedding_dim : ``int``
        The dimension of entity / length embeddings. Should match the encoder output size.
    max_mention_length : ``int``
        Maximum entity mention length.
    max_embeddings : ``int``
        Maximum number of embeddings.
    tie_weights : ``bool``
        Whether to tie embedding and output weights.
    variational_dropout_rate : ``float``, optional
        Dropout rate of variational dropout applied to input embeddings. Default: 0.0
    dropout_rate : ``float``, optional
        Dropout rate applied to hidden states. Default: 0.0
    initializer : ``InitializerApplicator``, optional
        Used to initialize model parameters.
    """
    # pylint: disable=line-too-long

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 embedding_dim: int,
                 max_mention_length: int,
                 max_embeddings: int,
                 tie_weights: bool,
                 variational_dropout_rate: float = 0.0,
                 dropout_rate: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
        super().__init__(vocab)

        self._text_field_embedder = text_field_embedder
        self._encoder = encoder
        self._embedding_dim = embedding_dim
        self._max_mention_length = max_mention_length
        self._max_embeddings = max_embeddings
        self._tie_weights = tie_weights
        self._variational_dropout_rate = variational_dropout_rate
        self._dropout_rate = dropout_rate

        # Carries the last timestep of the previous chunk across calls to ``forward``.
        self._state: Optional[StateDict] = None

        # Input variational dropout
        self._variational_dropout = InputVariationalDropout(variational_dropout_rate)
        self._dropout = torch.nn.Dropout(dropout_rate)

        # For entity type prediction
        self._entity_type_projection = torch.nn.Linear(in_features=embedding_dim,
                                                       out_features=2,
                                                       bias=False)
        self._dynamic_embeddings = DynamicEmbedding(embedding_dim=embedding_dim,
                                                    max_embeddings=max_embeddings)

        # For mention length prediction
        self._mention_length_projection = torch.nn.Linear(in_features=2 * embedding_dim,
                                                          out_features=max_mention_length)

        # For next word prediction
        self._dummy_context_embedding = Parameter(
            F.normalize(torch.randn(1, embedding_dim)))  # TODO: Maybe squeeze
        self._entity_output_projection = torch.nn.Linear(in_features=embedding_dim,
                                                         out_features=embedding_dim,
                                                         bias=False)
        self._context_output_projection = torch.nn.Linear(in_features=embedding_dim,
                                                          out_features=embedding_dim,
                                                          bias=False)
        self._vocab_projection = torch.nn.Linear(in_features=embedding_dim,
                                                 out_features=vocab.get_vocab_size('tokens'))

        # self._perplexity = Perplexity()
        # self._unknown_penalized_perplexity = UnknownPenalizedPerplexity(self.vocab)
        self._entity_type_accuracy = CategoricalAccuracy()
        self._entity_id_accuracy = CategoricalAccuracy()
        self._mention_length_accuracy = CategoricalAccuracy()

        # Share the output projection matrix with the 'tokens' embedder (done once, just before
        # the initializer runs).
        if tie_weights:
            self._vocab_projection.weight = self._text_field_embedder._token_embedders['tokens'].weight  # pylint: disable=W0212

        initializer(self)

    @overrides
    def forward(self,  # pylint: disable=arguments-differ
                tokens: Dict[str, torch.Tensor],
                entity_types: Optional[torch.Tensor] = None,
                entity_ids: Optional[torch.Tensor] = None,
                mention_lengths: Optional[torch.Tensor] = None,
                reset: bool = False) -> Dict[str, torch.Tensor]:
        """
        Computes the loss during training / validation.

        Parameters
        ----------
        tokens : ``Dict[str, torch.Tensor]``
            Dictionary of token tensors; ``tokens['tokens']`` has shape
            ``(batch_size, sequence_length)`` and contains the sequence of tokens.
        entity_types : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` indicating whether or not the
            corresponding token belongs to a mention.
        entity_ids : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` containing the ids of the
            entities the corresponding token is mentioning.
        mention_lengths : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` tracking how many remaining
            tokens (including the current one) there are in the mention.
        reset : ``bool``
            Whether or not to reset the model's state. This should be done at the start of each
            new sequence.

        Returns
        -------
        The output dictionary of ``_forward_loop`` (losses and ``logp``) when ``entity_types``
        is given, otherwise an empty dictionary.
        """
        batch_size = tokens['tokens'].shape[0]

        if reset:
            self.reset_states(batch_size)
        else:
            self.detach_states()

        if entity_types is not None:
            output_dict = self._forward_loop(tokens=tokens,
                                             entity_types=entity_types,
                                             entity_ids=entity_ids,
                                             mention_lengths=mention_lengths)
        else:
            output_dict = {}

        # TODO Some evaluation stuff (when ``not self.training``)

        return output_dict

    def _forward_loop(self,
                      tokens: Dict[str, torch.Tensor],
                      entity_types: torch.Tensor,
                      entity_ids: torch.Tensor,
                      mention_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Performs the forward pass to calculate the loss on a chunk of training data.

        Parameters
        ----------
        tokens : ``Dict[str, torch.Tensor]``
            Dictionary of token tensors; ``tokens['tokens']`` has shape
            ``(batch_size, sequence_length)`` and contains the sequence of tokens.
        entity_types : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` indicating whether or not the
            corresponding token belongs to a mention.
        entity_ids : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` containing the ids of the
            entities the corresponding token is mentioning.
        mention_lengths : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` tracking how many remaining
            tokens (including the current one) there are in the mention.

        Returns
        -------
        An output dictionary consisting of:
        entity_type_loss : ``torch.Tensor``
            The loss of entity type predictions.
        entity_id_loss : ``torch.Tensor``
            The loss of entity id predictions.
        mention_length_loss : ``torch.Tensor``
            The loss of mention length predictions.
        vocab_loss : ``torch.Tensor``
            The loss of vocab word predictions.
        loss : ``torch.Tensor``
            The combined loss.
        logp : ``torch.Tensor``
            Instance level log-probabilities
        """
        batch_size, sequence_length = tokens['tokens'].shape

        # The model state allows us to recover the last timestep from the previous chunk in the
        # split. If it does not exist, then we are processing a new batch.
        if self._state is not None:
            tokens = {
                field: torch.cat((self._state['prev_tokens'][field], tokens[field]), dim=1)
                for field in tokens
            }
            entity_types = torch.cat((self._state['prev_entity_types'], entity_types), dim=1)
            entity_ids = torch.cat((self._state['prev_entity_ids'], entity_ids), dim=1)
            mention_lengths = torch.cat((self._state['prev_mention_lengths'], mention_lengths), dim=1)
            contexts = self._state['prev_contexts']
            sequence_length += 1
        else:
            contexts = self._dummy_context_embedding.repeat(batch_size, 1)

        # Embed tokens and get RNN hidden state.
        mask = get_text_field_mask(tokens)
        embeddings = self._text_field_embedder(tokens)
        embeddings = self._variational_dropout(embeddings)
        hidden = self._encoder(embeddings, mask)

        # Initialize losses
        entity_type_loss = 0.0
        entity_id_loss = 0.0
        mention_length_loss = 0.0
        vocab_loss = 0.0
        logp = hidden.new_zeros(batch_size)

        # We dynamically add entities and update their representations in sequence. The following
        # loop is designed to imitate as closely as possible lines 219-313 in:
        #   https://github.com/jiyfeng/entitynlm/blob/master/entitynlm.h
        # while still being carried out in batch.
        for timestep in range(sequence_length - 1):

            current_entity_types = entity_types[:, timestep]
            current_entity_ids = entity_ids[:, timestep]
            current_mention_lengths = mention_lengths[:, timestep]
            current_hidden = self._dropout(hidden[:, timestep])

            next_entity_types = entity_types[:, timestep + 1]
            next_entity_ids = entity_ids[:, timestep + 1]
            next_mention_lengths = mention_lengths[:, timestep + 1]
            next_mask = mask[:, timestep + 1]
            next_tokens = tokens['tokens'][:, timestep + 1]

            # We add new entities to any sequence where the current entity id matches the number of
            # embeddings that currently exist for that sequence (this means we need a new one since
            # there is an additional dummy embedding).
            new_entities = current_entity_ids == self._dynamic_embeddings.num_embeddings
            self._dynamic_embeddings.add_embeddings(timestep, new_entities)

            # We also perform updates of the currently observed entities.
            self._dynamic_embeddings.update_embeddings(hidden=current_hidden,
                                                       update_indices=current_entity_ids,
                                                       timestep=timestep,
                                                       mask=current_entity_types)

            # This part is a little counter-intuitive. Because the above code adds a new embedding
            # whenever the **current** entity id matches the number of embeddings, we are one
            # embedding short if the **next** entity id has not been seen before. To deal with
            # this, we use the null embedding (e.g. the first one we created) as a proxy for the
            # new entity's embedding (since it is on average what the new entity's embedding will
            # be initialized in the next timestep). It might seem more sensible to just create the
            # embedding now, but we cannot because of the subsequent update (since this would
            # require access to the **next** hidden state, which does not exist during generation).
            next_entity_ids = next_entity_ids.clone()  # This prevents mutating the source data.
            next_entity_ids[next_entity_ids == self._dynamic_embeddings.num_embeddings] = 0

            # We only predict the types / ids / lengths of the next mention if we are not currently
            # in the process of generating it (e.g. if the current remaining mention length is 1).
            # Indexing / masking with ``predict_all`` makes it possible to do this in batch.
            predict_all = (current_mention_lengths == 1) * next_mask.byte()
            if predict_all.sum() > 0:

                # Equation 3 in the paper.
                entity_type_logits = self._entity_type_projection(current_hidden[predict_all])
                _entity_type_loss = F.cross_entropy(entity_type_logits,
                                                    next_entity_types[predict_all].long(),
                                                    reduction='none')
                entity_type_loss += _entity_type_loss.sum()

                entity_type_logp = torch.zeros_like(next_entity_types, dtype=torch.float32)
                entity_type_logp[predict_all] = -_entity_type_loss
                logp += entity_type_logp

                self._entity_type_accuracy(predictions=entity_type_logits,
                                           gold_labels=next_entity_types[predict_all].long())

                # Only proceed to predict entity and mention length if there is in fact an entity.
                predict_em = next_entity_types * predict_all
                if predict_em.sum() > 0:

                    # Equation 4 in the paper.
                    entity_id_prediction_outputs = self._dynamic_embeddings(hidden=current_hidden,
                                                                            target=next_entity_ids,
                                                                            mask=predict_em)
                    _entity_id_loss = entity_id_prediction_outputs['loss']
                    entity_id_loss += _entity_id_loss.sum()

                    entity_id_logp = torch.zeros_like(next_entity_ids, dtype=torch.float32)
                    entity_id_logp[predict_em] = -_entity_id_loss
                    logp += entity_id_logp

                    self._entity_id_accuracy(predictions=entity_id_prediction_outputs['logits'],
                                             gold_labels=next_entity_ids[predict_em])

                    # Equation 5 in the paper.
                    next_entity_embeddings = self._dynamic_embeddings.embeddings[
                        predict_em, next_entity_ids[predict_em]]
                    next_entity_embeddings = self._dropout(next_entity_embeddings)
                    concatenated = torch.cat((current_hidden[predict_em], next_entity_embeddings),
                                             dim=-1)
                    mention_length_logits = self._mention_length_projection(concatenated)
                    _mention_length_loss = F.cross_entropy(mention_length_logits,
                                                           next_mention_lengths[predict_em],
                                                           reduction='none')
                    mention_length_loss += _mention_length_loss.sum()

                    mention_length_logp = torch.zeros_like(next_mention_lengths,
                                                           dtype=torch.float32)
                    mention_length_logp[predict_em] = -_mention_length_loss
                    logp += mention_length_logp

                    self._mention_length_accuracy(predictions=mention_length_logits,
                                                  gold_labels=next_mention_lengths[predict_em])

            # Always predict the next word. This is done using the hidden state and contextual bias.
            entity_embeddings = self._dynamic_embeddings.embeddings[
                next_entity_types, next_entity_ids[next_entity_types]]
            entity_embeddings = self._entity_output_projection(entity_embeddings)

            # Complement of the entity mask: positions whose next token is NOT part of a mention.
            non_entity_mask = 1 - next_entity_types
            context_embeddings = contexts[non_entity_mask]
            context_embeddings = self._context_output_projection(context_embeddings)

            # The checks in the following block of code are required to prevent adding empty
            # tensors to vocab_features (which causes a floating point error).
            vocab_features = current_hidden.clone()
            if next_entity_types.sum() > 0:
                vocab_features[next_entity_types] = vocab_features[next_entity_types] + entity_embeddings
            # The count must be taken over the complement mask itself. ``1 - mask.sum()`` would
            # only be positive when the batch holds no entity at all, which silently dropped the
            # context bias for non-entity tokens in mixed batches.
            if non_entity_mask.sum() > 0:
                vocab_features[non_entity_mask] = vocab_features[non_entity_mask] + context_embeddings
            vocab_logits = self._vocab_projection(vocab_features)

            _vocab_loss = F.cross_entropy(vocab_logits, next_tokens, reduction='none')
            _vocab_loss = _vocab_loss * next_mask.float()
            vocab_loss += _vocab_loss.sum()
            logp += -_vocab_loss

            # self._perplexity(logits=vocab_logits,
            #                  labels=next_tokens,
            #                  mask=next_mask.float())
            # self._unknown_penalized_perplexity(logits=vocab_logits,
            #                                    labels=next_tokens,
            #                                    mask=next_mask.float())

            # Lastly update contexts
            contexts = current_hidden

        # Normalize the losses
        num_tokens = mask.sum()
        entity_type_loss /= num_tokens
        entity_id_loss /= num_tokens
        mention_length_loss /= num_tokens
        vocab_loss /= num_tokens
        total_loss = entity_type_loss + entity_id_loss + mention_length_loss + vocab_loss

        output_dict = {
            'entity_type_loss': entity_type_loss,
            'entity_id_loss': entity_id_loss,
            'mention_length_loss': mention_length_loss,
            'vocab_loss': vocab_loss,
            'loss': total_loss,
            'logp': logp
        }

        # Update the model state
        self._state = {
            'prev_tokens': {
                field: tokens[field][:, -1].unsqueeze(1).detach()
                for field in tokens
            },
            'prev_entity_types': entity_types[:, -1].unsqueeze(1).detach(),
            'prev_entity_ids': entity_ids[:, -1].unsqueeze(1).detach(),
            'prev_mention_lengths': mention_lengths[:, -1].unsqueeze(1).detach(),
            'prev_contexts': contexts.detach()
        }

        return output_dict

    def reset_states(self, batch_size: int) -> None:
        """Resets the model's internals. Should be called at the start of a new batch."""
        self._encoder.reset_states()
        self._dynamic_embeddings.reset_states(batch_size)
        self._state = None

    def detach_states(self):
        """Detaches the model's state to enforce truncated backpropagation."""
        self._dynamic_embeddings.detach_states()

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Returns the entity type / entity id / mention length accuracies."""
        return {
            # 'ppl': self._perplexity.get_metric(reset),
            # 'upp': self._unknown_penalized_perplexity.get_metric(reset),
            'et_acc': self._entity_type_accuracy.get_metric(reset),
            'eid_acc': self._entity_id_accuracy.get_metric(reset),
            'ml_acc': self._mention_length_accuracy.get_metric(reset)
        }
class EntityNLMDiscriminator(Model):
    """
    Implementation of the discriminative model from:
        https://arxiv.org/abs/1708.00781
    used to draw importance samples.

    Parameters
    ----------
    vocab : ``Vocabulary``
        The model vocabulary.
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed tokens.
    encoder : ``Seq2SeqEncoder``
        Used to encode the sequence of token embeddings.
    embedding_dim : ``int``
        The dimension of entity / length embeddings. Should match the encoder output size.
    max_mention_length : ``int``
        Maximum entity mention length.
    max_embeddings : ``int``
        Maximum number of embeddings.
    variational_dropout_rate : ``float``, optional
        Dropout rate of variational dropout applied to input embeddings. Default: 0.0
    dropout_rate : ``float``, optional
        Dropout rate applied to hidden states. Default: 0.0
    initializer : ``InitializerApplicator``, optional
        Used to initialize model parameters.
    """
    # pylint: disable=line-too-long

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 embedding_dim: int,
                 max_mention_length: int,
                 max_embeddings: int,
                 variational_dropout_rate: float = 0.0,
                 dropout_rate: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
        super().__init__(vocab)

        self._text_field_embedder = text_field_embedder
        self._encoder = encoder
        self._embedding_dim = embedding_dim
        self._max_mention_length = max_mention_length
        self._max_embeddings = max_embeddings

        # Carries the mention lengths of the last timestep across calls to ``forward``.
        self._state: Optional[StateDict] = None

        # Input variational dropout
        self._variational_dropout = InputVariationalDropout(variational_dropout_rate)
        self._dropout = torch.nn.Dropout(dropout_rate)

        # For entity type prediction
        self._entity_type_projection = torch.nn.Linear(in_features=embedding_dim,
                                                       out_features=2,
                                                       bias=False)
        self._dynamic_embeddings = DynamicEmbedding(embedding_dim=embedding_dim,
                                                    max_embeddings=max_embeddings)

        # For mention length prediction
        self._mention_length_projection = torch.nn.Linear(in_features=2*embedding_dim,
                                                          out_features=max_mention_length)

        self._entity_type_accuracy = CategoricalAccuracy()
        self._entity_id_accuracy = CategoricalAccuracy()
        self._mention_length_accuracy = CategoricalAccuracy()

        initializer(self)

    @overrides
    def forward(self,  # pylint: disable=arguments-differ
                tokens: Dict[str, torch.Tensor],
                entity_types: Optional[torch.Tensor] = None,
                entity_ids: Optional[torch.Tensor] = None,
                mention_lengths: Optional[torch.Tensor] = None,
                reset: bool = False) -> Dict[str, torch.Tensor]:
        """
        Computes the loss during training / validation.

        Parameters
        ----------
        tokens : ``Dict[str, torch.Tensor]``
            Dictionary of token tensors; ``tokens['tokens']`` has shape
            ``(batch_size, sequence_length)`` and contains the sequence of tokens.
        entity_types : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` indicating whether or not the
            corresponding token belongs to a mention.
        entity_ids : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` containing the ids of the
            entities the corresponding token is mentioning.
        mention_lengths : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` tracking how many remaining
            tokens (including the current one) there are in the mention.
        reset : ``bool``
            Whether or not to reset the model's state. This should be done at the start of each
            new sequence.

        Returns
        -------
        The output dictionary of ``_forward_loop`` (the individual and combined losses) when
        ``entity_types`` is given, otherwise an empty dictionary.
        """
        batch_size = tokens['tokens'].shape[0]

        if reset:
            self.reset_states(batch_size)
        else:
            self.detach_states()

        if entity_types is not None:
            output_dict = self._forward_loop(tokens=tokens,
                                             entity_types=entity_types,
                                             entity_ids=entity_ids,
                                             mention_lengths=mention_lengths)
        else:
            output_dict = {}

        return output_dict

    def sample(self, tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Generates a sample from the discriminative model.

        WARNING: Unlike during training, this function expects the full (unsplit) sequence of
        tokens.

        Parameters
        ----------
        tokens : ``Dict[str, torch.Tensor]``
            Dictionary of token tensors; ``tokens['tokens']`` has shape
            ``(batch_size, sequence_length)`` and contains the sequence of tokens.

        Returns
        -------
        An output dictionary consisting of:
        logp : ``torch.Tensor``
            A tensor of shape ``(batch_size,)`` containing the log-probability of each sample,
            summed over all sampled decisions.
        sample : ``Dict[str, torch.Tensor]``
            A dictionary holding the sampled annotations:
            entity_types : ``torch.Tensor``
                A tensor of shape ``(batch_size, sequence_length)`` indicating whether or not the
                corresponding token belongs to a mention.
            entity_ids : ``torch.Tensor``
                A tensor of shape ``(batch_size, sequence_length)`` containing the ids of the
                entities the corresponding token is mentioning.
            mention_lengths : ``torch.Tensor``
                A tensor of shape ``(batch_size, sequence_length)`` tracking how many remaining
                tokens (including the current one) there are in the mention.
        """
        batch_size, sequence_length = tokens['tokens'].shape

        # We will use a standard iterator during evaluation instead of a split iterator. Otherwise
        # it will be a pain to handle generating multiple samples for a sequence since there's no
        # way to get back to the first split.
        self.reset_states(batch_size)

        # Embed tokens and get RNN hidden state.
        mask = get_text_field_mask(tokens)
        embeddings = self._text_field_embedder(tokens)
        hidden = self._encoder(embeddings, mask)
        prev_mention_lengths = tokens['tokens'].new_ones(batch_size)

        # Initialize outputs
        logp = hidden.new_zeros(batch_size)  # Track total logp for **each** generated sample
        entity_types = torch.zeros_like(tokens['tokens'], dtype=torch.uint8)
        entity_ids = torch.zeros_like(tokens['tokens'])
        mention_lengths = torch.ones_like(tokens['tokens'])

        # Generate outputs
        for timestep in range(sequence_length):

            current_hidden = hidden[:, timestep]

            # We only predict types / ids / lengths if the previous mention is terminated.
            predict_mask = prev_mention_lengths == 1
            predict_mask = predict_mask * mask[:, timestep].byte()
            if predict_mask.sum() > 0:

                # Predict entity types
                entity_type_logits = self._entity_type_projection(current_hidden[predict_mask])
                entity_type_logp = F.log_softmax(entity_type_logits, dim=-1)
                entity_type_prediction_logp, entity_type_predictions = sample_from_logp(entity_type_logp)
                entity_type_predictions = entity_type_predictions.byte()
                entity_types[predict_mask, timestep] = entity_type_predictions
                logp[predict_mask] += entity_type_prediction_logp

                # Only predict entity and mention lengths if we predicted that there was a mention
                predict_em = entity_types[:, timestep] * predict_mask
                if predict_em.sum() > 0:

                    # Predict entity ids
                    entity_id_prediction_outputs = self._dynamic_embeddings(hidden=current_hidden,
                                                                            mask=predict_em)
                    entity_id_logits = entity_id_prediction_outputs['logits']
                    entity_id_logp = F.log_softmax(entity_id_logits, dim=-1)
                    entity_id_prediction_logp, entity_id_predictions = sample_from_logp(entity_id_logp)

                    # Predict mention lengths - we do this before writing the
                    # entity id predictions since we'll need to reindex the new
                    # entities, but need the null embeddings here.
                    predicted_entity_embeddings = self._dynamic_embeddings.embeddings[
                        predict_em, entity_id_predictions]
                    concatenated = torch.cat((current_hidden[predict_em], predicted_entity_embeddings),
                                             dim=-1)
                    mention_length_logits = self._mention_length_projection(concatenated)
                    mention_length_logp = F.log_softmax(mention_length_logits, dim=-1)
                    mention_length_prediction_logp, mention_length_predictions = sample_from_logp(mention_length_logp)

                    # Write predictions. A predicted id of 0 denotes a new entity, which is
                    # relabelled with the sequence's current number of embeddings.
                    new_entity_mask = entity_id_predictions == 0
                    new_entity_labels = self._dynamic_embeddings.num_embeddings[predict_em]
                    entity_id_predictions[new_entity_mask] = new_entity_labels[new_entity_mask]
                    entity_ids[predict_em, timestep] = entity_id_predictions
                    logp[predict_em] += entity_id_prediction_logp

                    mention_lengths[predict_em, timestep] = mention_length_predictions
                    logp[predict_em] += mention_length_prediction_logp

                    # Add / update entity embeddings
                    new_entities = entity_ids[:, timestep] == self._dynamic_embeddings.num_embeddings
                    self._dynamic_embeddings.add_embeddings(timestep, new_entities)

                    self._dynamic_embeddings.update_embeddings(hidden=current_hidden,
                                                               update_indices=entity_ids[:, timestep],
                                                               timestep=timestep,
                                                               mask=predict_em)

            # If the previous mentions are ongoing, we assign the output deterministically. Mention
            # lengths decrease by 1, all other outputs are copied from the previous timestep. Do
            # not need to add anything to logp since these 'predictions' have probability 1 under
            # the model.
            deterministic_mask = prev_mention_lengths > 1
            deterministic_mask = deterministic_mask * mask[:, timestep].byte()
            # The guard must fire as soon as ONE sequence has an ongoing mention; comparing
            # against 1 skipped the copy whenever exactly one sequence needed it.
            if deterministic_mask.sum() > 0:
                entity_types[deterministic_mask, timestep] = entity_types[deterministic_mask, timestep - 1]
                entity_ids[deterministic_mask, timestep] = entity_ids[deterministic_mask, timestep - 1]
                mention_lengths[deterministic_mask, timestep] = mention_lengths[deterministic_mask, timestep - 1] - 1

            # Update mention lengths for next timestep
            prev_mention_lengths = mention_lengths[:, timestep]

        return {
            'logp': logp,
            'sample': {
                'entity_types': entity_types,
                'entity_ids': entity_ids,
                'mention_lengths': mention_lengths
            }
        }

    def _forward_loop(self,
                      tokens: Dict[str, torch.Tensor],
                      entity_types: torch.Tensor,
                      entity_ids: torch.Tensor,
                      mention_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Performs the forward pass to calculate the loss on a chunk of training data.

        Parameters
        ----------
        tokens : ``Dict[str, torch.Tensor]``
            Dictionary of token tensors; ``tokens['tokens']`` has shape
            ``(batch_size, sequence_length)`` and contains the sequence of tokens.
        entity_types : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` indicating whether or not the
            corresponding token belongs to a mention.
        entity_ids : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` containing the ids of the
            entities the corresponding token is mentioning.
        mention_lengths : ``torch.Tensor``
            A tensor of shape ``(batch_size, sequence_length)`` tracking how many remaining
            tokens (including the current one) there are in the mention.

        Returns
        -------
        An output dictionary consisting of:
        entity_type_loss : ``torch.Tensor``
            The loss of entity type predictions.
        entity_id_loss : ``torch.Tensor``
            The loss of entity id predictions.
        mention_length_loss : ``torch.Tensor``
            The loss of mention length predictions.
        loss : ``torch.Tensor``
            The combined loss.
        """
        batch_size, sequence_length = tokens['tokens'].shape

        # Need to track previous mention lengths in order to know when to measure loss.
        if self._state is None:
            prev_mention_lengths = mention_lengths.new_ones(batch_size)
        else:
            prev_mention_lengths = self._state['prev_mention_lengths']

        # Embed tokens and get RNN hidden state.
        mask = get_text_field_mask(tokens)
        embeddings = self._text_field_embedder(tokens)
        embeddings = self._variational_dropout(embeddings)
        hidden = self._encoder(embeddings, mask)

        # Initialize losses
        entity_type_loss = torch.tensor(0.0, requires_grad=True, device=hidden.device)
        entity_id_loss = torch.tensor(0.0, requires_grad=True, device=hidden.device)
        mention_length_loss = torch.tensor(0.0, requires_grad=True, device=hidden.device)

        for timestep in range(sequence_length):

            current_entity_types = entity_types[:, timestep]
            current_entity_ids = entity_ids[:, timestep]
            current_mention_lengths = mention_lengths[:, timestep]
            current_hidden = self._dropout(hidden[:, timestep])

            # We only predict types / ids / lengths if we are not currently in the process of
            # generating a mention (e.g. if the previous remaining mention length is 1). Indexing /
            # masking with ``predict_all`` makes it possible to do this in batch.
            predict_all = prev_mention_lengths == 1
            predict_all = predict_all * mask[:, timestep].byte()
            if predict_all.sum() > 0:

                # Equation 3 in the paper.
                entity_type_logits = self._entity_type_projection(current_hidden[predict_all])
                entity_type_loss = entity_type_loss + F.cross_entropy(
                    entity_type_logits,
                    current_entity_types[predict_all].long(),
                    reduction='sum')
                self._entity_type_accuracy(predictions=entity_type_logits,
                                           gold_labels=current_entity_types[predict_all].long())

                # Only proceed to predict entity and mention length if there is in fact an entity.
                predict_em = current_entity_types * predict_all

                if predict_em.sum() > 0:
                    # Equation 4 in the paper. We want new entities to correspond to a prediction of
                    # zero, their embedding should be added after they've been predicted for the first
                    # time.
                    modified_entity_ids = current_entity_ids.clone()
                    modified_entity_ids[modified_entity_ids == self._dynamic_embeddings.num_embeddings] = 0
                    entity_id_prediction_outputs = self._dynamic_embeddings(hidden=current_hidden,
                                                                            target=modified_entity_ids,
                                                                            mask=predict_em)
                    entity_id_loss = entity_id_loss + entity_id_prediction_outputs['loss'].sum()
                    self._entity_id_accuracy(predictions=entity_id_prediction_outputs['logits'],
                                             gold_labels=modified_entity_ids[predict_em])

                    # Equation 5 in the paper.
                    predicted_entity_embeddings = self._dynamic_embeddings.embeddings[
                        predict_em, modified_entity_ids[predict_em]]
                    predicted_entity_embeddings = self._dropout(predicted_entity_embeddings)
                    concatenated = torch.cat((current_hidden[predict_em], predicted_entity_embeddings),
                                             dim=-1)
                    mention_length_logits = self._mention_length_projection(concatenated)
                    # Summed (not averaged) so that, like the other two losses, it is normalized
                    # exactly once by the number of unmasked tokens below.
                    mention_length_loss = mention_length_loss + F.cross_entropy(
                        mention_length_logits,
                        current_mention_lengths[predict_em],
                        reduction='sum')
                    self._mention_length_accuracy(predictions=mention_length_logits,
                                                  gold_labels=current_mention_lengths[predict_em])

            # We add new entities to any sequence where the current entity id matches the number of
            # embeddings that currently exist for that sequence (this means we need a new one since
            # there is an additional dummy embedding).
            new_entities = current_entity_ids == self._dynamic_embeddings.num_embeddings
            self._dynamic_embeddings.add_embeddings(timestep, new_entities)

            # We also perform updates of the currently observed entities.
            self._dynamic_embeddings.update_embeddings(hidden=current_hidden,
                                                       update_indices=current_entity_ids,
                                                       timestep=timestep,
                                                       mask=current_entity_types)

            prev_mention_lengths = current_mention_lengths

        # Normalize the losses
        num_tokens = mask.sum()
        entity_type_loss = entity_type_loss / num_tokens
        entity_id_loss = entity_id_loss / num_tokens
        mention_length_loss = mention_length_loss / num_tokens
        total_loss = entity_type_loss + entity_id_loss + mention_length_loss

        output_dict = {
            'entity_type_loss': entity_type_loss,
            'entity_id_loss': entity_id_loss,
            'mention_length_loss': mention_length_loss,
            'loss': total_loss
        }

        # Update state
        self._state = {
            'prev_mention_lengths': prev_mention_lengths.detach()
        }

        return output_dict

    def reset_states(self, batch_size: int) -> None:
        """Resets the model's internals. Should be called at the start of a new batch."""
        self._encoder.reset_states()
        self._dynamic_embeddings.reset_states(batch_size)
        self._state = None

    def detach_states(self):
        """Detaches the model's state to enforce truncated backpropagation."""
        self._dynamic_embeddings.detach_states()

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Returns the entity type / entity id / mention length accuracies."""
        return {
            'et_acc': self._entity_type_accuracy.get_metric(reset),
            'eid_acc': self._entity_id_accuracy.get_metric(reset),
            'ml_acc': self._mention_length_accuracy.get_metric(reset)
        }