import torch
from torch.nn import ModuleDict


class Stack(torch.nn.Module):
    """Sums the outputs of several models applied to the same input.

    Attributes:
        models (ModuleDict): mapping from model names to the stacked submodules.
    """

    def __init__(self, model_dict, mode='sum'):
        """
        Args:
            model_dict (dict): dictionary of the models to stack, keyed by name.
            mode (str, optional): how the model outputs are combined
                (only summation is implemented here).
        """
        super().__init__()
        self.models = ModuleDict(model_dict)

    def forward(self, x):
        """
        Args:
            x: input batch passed to every stacked model.

        Returns:
            torch.Tensor: single-element tensor holding the summed model outputs.
        """
        for i, key in enumerate(self.models.keys()):
            if i == 0:
                result = self.models[key](x).sum().reshape(-1)
            else:
                new_result = self.models[key](x).sum().reshape(-1)
                result += new_result
        return result
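# Usage sketch (hypothetical submodules, not from the original source): each
# model's output is reduced to a single scalar and the scalars are summed
# across models, so the result is a one-element tensor.
stack = Stack({'net_a': torch.nn.Linear(8, 1), 'net_b': torch.nn.Linear(8, 1)})
y = stack(torch.randn(4, 8))  # tensor of shape (1,)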
def __init__(self, backbone: nn.Module, decoders: nn.ModuleDict, tasks: list):
    super(MultiTaskModel, self).__init__()
    assert set(decoders.keys()) == set(tasks)
    self.backbone = backbone
    self.decoders = decoders
    self.tasks = tasks
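# A forward pass that typically accompanies the constructor above (a sketch,
# not taken from the original source): run the shared backbone once, then
# apply every task-specific decoder head to the shared features. The
# F.interpolate call assumes dense-prediction tasks whose outputs are resized
# back to the input resolution.
def forward(self, x):
    out_size = x.size()[2:]
    shared_representation = self.backbone(x)
    return {task: F.interpolate(self.decoders[task](shared_representation),
                                out_size, mode='bilinear')
            for task in self.tasks}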
class Stack(torch.nn.Module):
    def __init__(self, model_dict, mode='sum'):
        super().__init__()
        implemented_modes = ['sum', 'mean']
        if mode not in implemented_modes:
            raise NotImplementedError(
                '{} mode is not implemented for Stack'.format(mode))
        # to implement: a check for readout keys
        self.models = ModuleDict(model_dict)
        self.mode = mode

    def forward(self, batch, keys_to_combine=['energy', 'energy_grad']):
        # run each model on the batch
        result_list = [self.models[key](batch) for key in self.models.keys()]

        # combine the per-model results key by key
        combine_results = dict()
        for i, result in enumerate(result_list):
            for key in keys_to_combine:
                if i == 0:
                    combine_results[key] = result[key]
                else:
                    combine_results[key] += result[key]

        if self.mode == 'mean':
            for key in keys_to_combine:
                combine_results[key] /= len(result_list)

        return combine_results
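# Usage sketch: combine two models whose forwards return dicts containing the
# keys being stacked. `EnergyModel` below is a hypothetical stand-in, not part
# of the original source.
class EnergyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 1)

    def forward(self, batch):
        # a toy dict output with the default keys expected by Stack
        return {'energy': self.net(batch).sum(),
                'energy_grad': torch.ones_like(batch)}


ensemble = Stack({'model_0': EnergyModel(), 'model_1': EnergyModel()}, mode='mean')
combined = ensemble(torch.randn(4, 8))  # {'energy': ..., 'energy_grad': ...}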
def __init__(self, tasks: list, loss_ft: nn.ModuleDict, loss_weights: dict):
    super(MultiTaskLoss, self).__init__()
    assert set(tasks) == set(loss_ft.keys())
    assert set(tasks) == set(loss_weights.keys())
    self.tasks = tasks
    self.loss_ft = loss_ft
    self.loss_weights = loss_weights
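# The matching forward pass (a sketch, not from the original source): apply
# each task's loss module and return the weighted total alongside the
# per-task losses.
def forward(self, pred, gt):
    out = {task: self.loss_ft[task](pred[task], gt[task]) for task in self.tasks}
    out['total'] = torch.sum(torch.stack(
        [self.loss_weights[task] * out[task] for task in self.tasks]))
    return out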
def _get_input_features(self, module_dict: ModuleDict) -> int:
    r"""Get the number of input features from the given ModuleDict instance.

    Args:
        module_dict (ModuleDict): the dictionary which contains the keys and
            their corresponding modules.

    Returns:
        the number of distinct input features used.
    """
    indices = [self._get_keys(key) for key in module_dict.keys()]
    return len(set(flatten(indices)))
def filter_models(models: nn.ModuleDict, keys) -> nn.ModuleDict:
    if isinstance(keys, list):
        missing_keys = set(keys) - models.keys()
        checkraise(
            len(missing_keys) == 0,
            ValueError,
            'models dictionary does not contain keys {}',
            missing_keys,
        )
        return nn.ModuleDict({k: models[k] for k in keys})

    if isinstance(keys, dict):
        missing_keys = set(keys.keys()) - models.keys()
        checkraise(
            len(missing_keys) == 0,
            ValueError,
            'models dictionary does not contain keys {}',
            missing_keys,
        )
        return nn.ModuleDict(
            {k: filter_models(models[k], v) for k, v in keys.items()})

    raise NotImplementedError
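# Usage sketch (hypothetical module tree, assuming `checkraise` from the same
# module is in scope): a list of keys selects a flat subset, while a dict of
# keys recurses into nested ModuleDicts.
import torch.nn as nn

models = nn.ModuleDict({
    'encoder': nn.Linear(4, 4),
    'heads': nn.ModuleDict({'policy': nn.Linear(4, 2), 'value': nn.Linear(4, 1)}),
})
flat = filter_models(models, ['encoder'])             # keeps only 'encoder'
nested = filter_models(models, {'heads': ['value']})  # keeps heads -> value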
def apply_module_dict(modules: nn.ModuleDict, encoded: torch.Tensor, **kwargs) -> torch.Tensor:
    """
    Predict the next entry using the given modules and equation.

    :param nn.ModuleDict modules:
        Dictionary of modules to be applied. Modules will be applied in ascending
        order of their keys. We expect three types of modules:
        nn.Linear, nn.LayerNorm, and MultiheadAttention.
    :param torch.Tensor encoded:
        Float Tensor that represents encoded vectors.
        Shape [B, T, H], where B = batch size, T = length of equation, and H = hidden dimension.
    :keyword torch.Tensor key_value:
        Float Tensor that represents key and value vectors when computing attention.
        Shape [B, K, H], where K = length of keys.
    :keyword torch.Tensor key_ignorance_mask:
        Bool Tensor whose True values at (b, k) make the attention layer ignore the
        k-th key on the b-th item in the batch. Shape [B, K].
    :keyword attention_mask:
        Bool Tensor whose True values at (t, k) make the attention layer ignore the
        k-th key when computing the t-th query. Shape [T, K].
    :rtype: torch.Tensor
    :return:
        Float Tensor that indicates the scores under the given information.
        Shape will be [B, T, ?].
    """
    output = encoded
    keys = sorted(modules.keys())

    # Apply modules in ascending order of keys.
    for key in keys:
        layer = modules[key]
        if isinstance(layer, (MultiheadAttention, MultiheadAttentionWeights)):
            output = layer(query=output, **kwargs)
        else:
            output = layer(output)

    return output
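# Usage sketch (hypothetical layer names): modules run in ascending key order,
# so zero-padded numeric prefixes give a predictable pipeline. Attention
# layers would receive the keyword arguments; plain layers are called
# positionally, as in the isinstance branch above.
import torch
import torch.nn as nn

modules = nn.ModuleDict({
    '0_norm': nn.LayerNorm(128),
    '1_dense': nn.Linear(128, 128),
})
encoded = torch.randn(2, 5, 128)  # [B, T, H]
scores = apply_module_dict(modules, encoded)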
class Event2Mind(Model):
    """
    This ``Event2Mind`` class is a :class:`Model` which takes an event
    sequence, encodes it, and then uses the encoded representation to decode
    several mental state sequences.

    It is based on `the paper by Rashkin et al.
    <https://www.semanticscholar.org/paper/Event2Mind/b89f8a9b2192a8f2018eead6b135ed30a1f2144d>`_

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same
        namespace (``tokens``) or the target tokens can have a different namespace, in
        which case it needs to be specified as ``target_namespace``.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences.
    embedding_dropout : float, required
        The amount of dropout to apply after the source tokens have been embedded.
    encoder : ``Seq2VecEncoder``, required
        The encoder of the "encoder/decoder" model.
    max_decoding_steps : int, required
        Length of decoded sequences.
    beam_size : int, optional (default = 10)
        The width of the beam search.
    target_names : ``List[str]``, optional (default = ['xintent', 'xreact', 'oreact'])
        Names of the target fields matching those in the ``Instance`` objects.
    target_namespace : str, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to
        specify the target's namespace here. If not, we'll assume it is "tokens", which
        is also the default choice for the source side, and this might cause them to
        share vocabularies.
    target_embedding_dim : int, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll
        use the same value as the source embedder's.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 embedding_dropout: float,
                 encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 beam_size: int = 10,
                 target_names: List[str] = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None) -> None:
        super().__init__(vocab)
        target_names = target_names or ["xintent", "xreact", "oreact"]

        # Note: The original tweaks the embeddings for "personx" to be the mean
        # across the embeddings for "he", "she", "him" and "her". Similarly for
        # "personx's" and so forth. We could consider that here as well.
        self._source_embedder = source_embedder
        self._embedding_dropout = nn.Dropout(embedding_dropout)
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of
        # decoding, and the end symbol as a way to indicate the end of the decoded
        # sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

        # Warning: The different decoders share a vocabulary! This may be
        # counterintuitive, but consider the case of xreact and oreact. A
        # reaction of "happy" could easily apply to both the subject of the
        # event and others. This could become less appropriate as more decoders
        # are added.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Decoder output dim needs to be the same as the encoder output dim since we
        # initialize the hidden state of the decoder with the final hidden state of
        # the encoder.
        self._decoder_output_dim = self._encoder.get_output_dim()
        target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()

        self._states = ModuleDict()
        for name in target_names:
            self._states[name] = StateDecoder(
                    num_classes,
                    target_embedding_dim,
                    self._decoder_output_dim
            )

        self._beam_search = BeamSearch(
                self._end_index, beam_size=beam_size, max_steps=max_decoding_steps
        )

    def _update_recall(self,
                       all_top_k_predictions: torch.Tensor,
                       target_tokens: Dict[str, torch.LongTensor],
                       target_recall: UnigramRecall) -> None:
        targets = target_tokens["tokens"]
        target_mask = get_text_field_mask(target_tokens)
        # See comment in _get_loss.
        # TODO(brendanr): Do we need contiguous here?
        relevant_targets = targets[:, 1:].contiguous()
        relevant_mask = target_mask[:, 1:].contiguous()
        target_recall(
                all_top_k_predictions,
                relevant_targets,
                relevant_mask,
                self._end_index
        )

    def _get_num_decoding_steps(self,
                                target_tokens: Optional[Dict[str, torch.LongTensor]]) -> int:
        if target_tokens:
            targets = target_tokens["tokens"]
            target_sequence_length = targets.size()[1]
            # The last input from the target is either padding or the end
            # symbol. Either way, we don't have to process it. (To be clear,
            # we do still output and compare against the end symbol, but there
            # is no need to take the end symbol as input to the decoder.)
            return target_sequence_length - 1
        else:
            return self._max_decoding_steps

    @overrides
    def forward(self,  # type: ignore
                source: Dict[str, torch.LongTensor],
                **target_tokens: Dict[str, Dict[str, torch.LongTensor]]) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing the target sequences.

        Parameters
        ----------
        source : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the source
            ``TextField``. This will be passed through a ``TextFieldEmbedder``
            and then through an encoder.
        target_tokens : ``Dict[str, Dict[str, torch.LongTensor]]``
            Dictionary from name to output of ``Textfield.as_array()`` applied
            on target ``TextField``. We assume that the target tokens are also
            represented as a ``TextField``.
        """
        # (batch_size, input_sequence_length, embedding_dim)
        embedded_input = self._embedding_dropout(self._source_embedder(source))
        source_mask = get_text_field_mask(source)
        # (batch_size, encoder_output_dim)
        final_encoder_output = self._encoder(embedded_input, source_mask)
        output_dict = {}

        # Perform greedy search so we can get the loss.
        if target_tokens:
            if target_tokens.keys() != self._states.keys():
                target_only = target_tokens.keys() - self._states.keys()
                states_only = self._states.keys() - target_tokens.keys()
                raise Exception("Mismatch between target_tokens and self._states. Keys in " +
                                f"targets only: {target_only} Keys in states only: {states_only}")

            total_loss = 0
            for name, state in self._states.items():
                loss = self.greedy_search(
                        final_encoder_output=final_encoder_output,
                        target_tokens=target_tokens[name],
                        target_embedder=state.embedder,
                        decoder_cell=state.decoder_cell,
                        output_projection_layer=state.output_projection_layer
                )
                total_loss += loss
                output_dict[f"{name}_loss"] = loss

            # Use mean loss (instead of the sum of the losses) to be comparable to the paper.
            output_dict["loss"] = total_loss / len(self._states)

        # Perform beam search to obtain the predictions.
        if not self.training:
            batch_size = final_encoder_output.size()[0]
            for name, state in self._states.items():
                start_predictions = final_encoder_output.new_full(
                        (batch_size,), fill_value=self._start_index, dtype=torch.long)
                start_state = {"decoder_hidden": final_encoder_output}

                # (batch_size, beam_size, num_decoding_steps)
                all_top_k_predictions, log_probabilities = self._beam_search.search(
                        start_predictions, start_state, state.take_step)

                if target_tokens:
                    self._update_recall(all_top_k_predictions, target_tokens[name], state.recall)

                output_dict[f"{name}_top_k_predictions"] = all_top_k_predictions
                output_dict[f"{name}_top_k_log_probabilities"] = log_probabilities

        return output_dict

    def greedy_search(self,
                      final_encoder_output: torch.LongTensor,
                      target_tokens: Dict[str, torch.LongTensor],
                      target_embedder: Embedding,
                      decoder_cell: GRUCell,
                      output_projection_layer: Linear) -> torch.FloatTensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the cross entropy between this sequence and ``target_tokens``.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()`` applied on some target ``TextField``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell : ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer : ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._get_num_decoding_steps(target_tokens)
        targets = target_tokens["tokens"]
        decoder_hidden = final_encoder_output
        step_logits = []
        for timestep in range(num_decoding_steps):
            # See https://github.com/allenai/allennlp/issues/1134.
            input_choices = targets[:, timestep]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
        # (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        target_mask = get_text_field_mask(target_tokens)
        return self._get_loss(logits, targets, target_mask)

    def greedy_predict(self,
                       final_encoder_output: torch.LongTensor,
                       target_embedder: Embedding,
                       decoder_cell: GRUCell,
                       output_projection_layer: Linear) -> torch.Tensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the predicted sequence.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell : ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer : ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        batch_size = final_encoder_output.size()[0]
        predictions = [final_encoder_output.new_full(
                (batch_size,), fill_value=self._start_index, dtype=torch.long
        )]
        for _ in range(num_decoding_steps):
            input_choices = predictions[-1]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            predictions.append(predicted_classes)
        all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
        # Drop start symbol and return.
        return all_predictions[:, 1:]

    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.FloatTensor:
        """
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size,
        num_decoding_steps+1) and corresponding masks of size (batch_size,
        num_decoding_steps+1) and computes cross entropy loss while taking the
        mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits``
        because the decoder does not need to compute the output corresponding to the
        last timestep of ``targets``. This method aligns the inputs appropriately to
        compute the loss.

        During training, we want the logit corresponding to timestep i to be similar
        to the target token from timestep i + 1. That is, the targets should be
        shifted by one timestep for appropriate comparison. Consider a single example
        where the target has 3 words, and padding is to 7 tokens.

           The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P>
           and the mask would be                     1   1  1  1  1   0   0
           and let the logits be                     l1  l2 l3 l4 l5  l6

        We actually need to compare:

           the sequence           w1  w2 w3 <E> <P> <P>
           with masks             1   1  1  1   0   0
           against                l1  l2 l3 l4  l5  l6
           (where the input was)  <S> w1 w2 w3  <E> <P>
        """
        relevant_targets = targets[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
        loss = sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
        return loss

    def decode_all(self, predicted_indices: torch.Tensor) -> List[List[str]]:
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            indices = list(indices)
            # Collect indices up to the first end_symbol.
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                                for x in indices]
            all_predicted_tokens.append(predicted_tokens)
        return all_predicted_tokens

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, List[List[str]]]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds fields for the tokens to the ``output_dict``.
        """
        for name in self._states:
            top_k_predicted_indices = output_dict[f"{name}_top_k_predictions"][0]
            output_dict[f"{name}_top_k_predicted_tokens"] = [self.decode_all(top_k_predicted_indices)]

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics = {}
        # Recall@10 needs beam search, which doesn't happen during training.
        if not self.training:
            for name, state in self._states.items():
                all_metrics[name] = state.recall.get_metric(reset=reset)
        return all_metrics
def get_trunk_forward_outputs(
    feat: torch.Tensor,
    out_feat_keys: List[str],
    feature_blocks: nn.ModuleDict,
    feature_mapping: Dict[str, str] = None,
    use_checkpointing: bool = False,
    checkpointing_splits: int = 2,
) -> List[torch.Tensor]:
    """
    Args:
        feat: model input.
        out_feat_keys: a list/tuple with the feature names of the features that
            the function should return. By default the last feature of the
            network is returned.
        feature_blocks: ModuleDict containing feature blocks in the model.
        feature_mapping: an optional correspondence table in between the
            requested feature names and the model's.
        use_checkpointing: whether to wrap the blocks in activation checkpointing.
        checkpointing_splits: number of checkpointed chunks to split the trunk into.

    Returns:
        out_feats: a list with the requested output features placed in the same
            order as in `out_feat_keys`.
    """
    # Sanitize inputs
    if feature_mapping is not None:
        out_feat_keys = [feature_mapping[f] for f in out_feat_keys]
    out_feat_keys, max_out_feat = parse_out_keys_arg(
        out_feat_keys, list(feature_blocks.keys())
    )

    # Forward pass over the trunk
    unique_out_feats = {}
    unique_out_feat_keys = list(set(out_feat_keys))

    # FIXME: Ideally this should only be done once at construction time
    if use_checkpointing:
        feature_blocks = checkpoint_trunk(
            feature_blocks, unique_out_feat_keys, checkpointing_splits
        )

        # If feat is the first input to the network, it doesn't have requires_grad,
        # which will keep checkpoint's backward function from being called. So we
        # need to set it to True here.
        feat.requires_grad = True

    # Go through the blocks, and save the features as we go.
    # NOTE: we are not doing several forward passes but instead just checking
    # whether the feature is requested to be returned.
    for i, (feature_name, feature_block) in enumerate(feature_blocks.items()):
        # The last chunk has to be non-volatile
        if use_checkpointing and i < len(feature_blocks) - 1:
            # Un-freeze the running stats in any BN layer
            for m in filter(lambda x: isinstance(x, _bn_cls), feature_block.modules()):
                m.track_running_stats = m.training

            feat = checkpoint(feature_block, feat)

            # Freeze the running stats in any BN layer;
            # the checkpointing process will have to do another FW pass
            for m in filter(lambda x: isinstance(x, _bn_cls), feature_block.modules()):
                m.track_running_stats = False
        else:
            feat = feature_block(feat)

        # This feature is requested, store. If the same feature is requested
        # several times, we return the feature several times.
        if feature_name in unique_out_feat_keys:
            unique_out_feats[feature_name] = feat

        # Early exit if all the features have been collected
        if i == max_out_feat and not use_checkpointing:
            break

    # Now return the features as requested by the user. If there are no
    # duplicate keys, return as is.
    if len(unique_out_feat_keys) == len(out_feat_keys):
        return list(unique_out_feats.values())

    output_feats = []
    for key_name in out_feat_keys:
        output_feats.append(unique_out_feats[key_name])
    return output_feats
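# Usage sketch (hypothetical trunk, not from the original source): request an
# intermediate and the final feature in one pass; the returned list follows
# the order of `out_feat_keys`.
import torch
import torch.nn as nn

trunk = nn.ModuleDict({
    'conv1': nn.Conv2d(3, 16, 3, padding=1),
    'conv2': nn.Conv2d(16, 32, 3, padding=1),
    'pool': nn.AdaptiveAvgPool2d(1),
})
feats = get_trunk_forward_outputs(
    torch.randn(2, 3, 32, 32),
    out_feat_keys=['conv2', 'pool'],
    feature_blocks=trunk,
)  # [conv2 activations, pooled features]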
def get_tunk_forward_interpolated_outputs(
    input_type: str,  # bgr or rgb or lab
    interpolate_out_feat_key_name: str,
    remove_padding_before_feat_key_name: str,
    feat: MultiDimensionalTensor,
    feature_blocks: nn.ModuleDict,
    feature_mapping: Dict[str, str] = None,
    use_checkpointing: bool = False,
    checkpointing_splits: int = 2,
) -> List[torch.Tensor]:
    """
    Args:
        input_type (str): whether the model input should be RGB or BGR or LAB
        interpolate_out_feat_key_name (str): what feature dimensions should be
            used to interpolate the mask
        remove_padding_before_feat_key_name (str): name of the feature block for
            which the input should have padding removed using the interpolated mask
        feat (MultiDimensionalTensor): model input
        feature_blocks (nn.ModuleDict): ModuleDict containing feature blocks in
            the model
        feature_mapping (Dict[str, str]): an optional correspondence table in
            between the requested feature names and the model's.

    Returns:
        out_feats: a list with the requested output features placed in the same
            order as in `out_feat_keys`.
    """
    if feature_mapping is not None:
        interpolate_out_feat_key_name = feature_mapping[interpolate_out_feat_key_name]
    model_input = transform_model_input_data_type(feat.tensor, input_type)
    out = get_trunk_forward_outputs(
        feat=model_input,
        out_feat_keys=[interpolate_out_feat_key_name],
        feature_blocks=feature_blocks,
        use_checkpointing=use_checkpointing,
        checkpointing_splits=checkpointing_splits,
    )
    # mask is of shape N x H x W and has 1.0 value for places with padding
    # we interpolate the mask spatially to N x out.shape[-2] x out.shape[-1].
    interp_mask = F.interpolate(feat.mask[None].float(), size=out[0].shape[-2:]).to(
        torch.bool
    )[0]

    # we want to iterate over the rest of the feature blocks now
    _, max_out_feat = parse_out_keys_arg(
        [interpolate_out_feat_key_name], list(feature_blocks.keys())
    )
    for i, (feature_name, feature_block) in enumerate(feature_blocks.items()):
        # We have already done the forward up to max_out_feat;
        # we forward through the rest of the blocks now.
        if i >= (max_out_feat + 1):
            if remove_padding_before_feat_key_name and (
                feature_name == remove_padding_before_feat_key_name
            ):
                # negate the mask so that the padded locations have 0.0 and the non-padded
                # locations have 1.0. This is used to extract the h, w of the original tensors.
                interp_mask = (~interp_mask).chunk(len(feat.image_sizes))
                tensors = out[0].chunk(len(feat.image_sizes))
                res = []
                # use a separate index so we don't shadow the outer loop variable
                for j, tensor in enumerate(tensors):
                    w = torch.sum(interp_mask[j][0], dim=0)[0]
                    h = torch.sum(interp_mask[j][0], dim=1)[0]
                    res.append(feature_block(tensor[:, :, :w, :h]))
                out[0] = torch.cat(res)
            else:
                out[0] = feature_block(out[0])
    return out