def __init__(self,
             vocab: Vocabulary,
             source_embedder: TextFieldEmbedder,
             encoder: Seq2SeqEncoder,
             max_decoding_steps: int,
             target_namespace: str = "tokens",
             attention_function: SimilarityFunction = None,
             scheduled_sampling_ratio: float = 0.0,
             label_smoothing: float = None,
             target_embedding_dim: int = None,
             target_tokens_embedder: TokenEmbedder = None) -> None:
    super(PretrSeq2Seq, self).__init__(vocab)
    self._label_smoothing = label_smoothing
    self._source_embedder = source_embedder
    self._encoder = encoder
    self._max_decoding_steps = max_decoding_steps
    self._target_namespace = target_namespace
    self._attention_function = attention_function
    self._scheduled_sampling_ratio = scheduled_sampling_ratio

    # We need the start symbol to provide as the input at the first timestep of decoding, and
    # end symbol as a way to indicate the end of the decoded sequence.
    self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
    self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
    num_classes = self.vocab.get_vocab_size(self._target_namespace)

    # Decoder output dim needs to be the same as the encoder output dim since we initialize the
    # hidden state of the decoder with that of the final hidden states of the encoder. Also, if
    # we're using attention with ``DotProductSimilarity``, this is needed.
    self._decoder_output_dim = self._encoder.get_output_dim()
    target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()
    self._target_embedder = Embedding(num_classes, target_embedding_dim)

    # PRETRAINED PART
    if target_tokens_embedder:
        target_embedding_dim = target_tokens_embedder.get_output_dim()
        self._target_embedder = target_tokens_embedder

    if self._attention_function:
        self._decoder_attention = LegacyAttention(self._attention_function)
        # The output of attention, a weighted average over encoder outputs, will be
        # concatenated to the input vector of the decoder at each time step.
        self._decoder_input_dim = self._encoder.get_output_dim() + target_embedding_dim
    else:
        self._decoder_input_dim = target_embedding_dim

    # TODO (pradeep): Do not hardcode decoder cell type.
    self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)
    self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)
def __init__(self,
             input_dim: int,
             decoder_rnn_output_dim: int,
             output_projection_dim: int,
             max_decoding_steps: int,
             attention: Attention = None,
             attention_function: SimilarityFunction = None,
             scheduled_sampling_ratio: float = 0.) -> None:
    super(AttentionalRnnDecoder, self).__init__()
    self._scheduled_sampling_ratio = scheduled_sampling_ratio
    self._input_dim = input_dim
    self._output_dim = output_projection_dim

    # Attention mechanism applied to the encoder output for each step.
    if attention:
        if attention_function:
            raise ConfigurationError("You can only specify an attention module or an "
                                     "attention function, but not both.")
        self._attention = attention
    elif attention_function:
        self._attention = LegacyAttention(attention_function)
    else:
        self._attention = None

    if self._attention:
        # If using attention, a weighted average over encoder outputs will be concatenated
        # to the previous target embedding to form the input to the decoder at each
        # time step. (A standalone sketch of one such step follows this constructor.)
        decoder_rnn_input_dim = input_dim + output_projection_dim
    else:
        # Otherwise, the input to the decoder is just the previous target embedding.
        decoder_rnn_input_dim = output_projection_dim

    # We'll use an LSTM cell as the recurrent cell that produces a hidden state
    # for the decoder at each time step.
    # TODO (pradeep): Do not hardcode decoder cell type.
    self._decoder_cell = LSTMCell(decoder_rnn_input_dim, decoder_rnn_output_dim)

    # We project the hidden state from the decoder into the output vocabulary space
    # in order to get log probabilities of each target token, at each time step.
    self._output_projection_layer = Linear(decoder_rnn_output_dim, output_projection_dim)

    # The maximum number of decoding steps to take at prediction time.
    self._max_decoding_steps = max_decoding_steps
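# Hedged sketch (not part of AttentionalRnnDecoder above): one decoder step in which an attended
# summary of the encoder outputs is concatenated to the previous target embedding, as the comments
# in the constructor describe. The dimensions, the dot-product attention, and the zero-initialized
# states below are illustrative assumptions, not values taken from the model.
import torch
import torch.nn.functional as F

batch, src_len, enc_dim, emb_dim, hid_dim = 2, 7, 16, 8, 16

encoder_outputs = torch.randn(batch, src_len, enc_dim)   # encoder hidden states
prev_target_emb = torch.randn(batch, emb_dim)            # embedding of the previous prediction
decoder_cell = torch.nn.LSTMCell(enc_dim + emb_dim, hid_dim)
hidden = torch.zeros(batch, hid_dim)                     # decoder hidden state
context = torch.zeros(batch, hid_dim)                    # decoder cell state

# Dot-product attention over the encoder outputs (requires hid_dim == enc_dim here).
scores = torch.bmm(encoder_outputs, hidden.unsqueeze(-1)).squeeze(-1)   # (batch, src_len)
weights = F.softmax(scores, dim=-1)
attended = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)  # (batch, enc_dim)

# The attended summary is concatenated to the previous target embedding to form the decoder
# input for this time step, then one LSTMCell step is taken.
decoder_input = torch.cat((attended, prev_target_emb), dim=-1)
hidden, context = decoder_cell(decoder_input, (hidden, context))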
def __init__(self,
             vocab: Vocabulary,
             source_embedder: TextFieldEmbedder,
             encoder: Seq2SeqEncoder,
             max_decoding_steps: int,
             attention: Attention = None,
             attention_function: SimilarityFunction = None,
             beam_size: int = None,
             target_namespace: str = "tokens",
             target_embedding_dim: int = None,
             scheduled_sampling_ratio: float = 0.,
             use_bleu: bool = True,
             emb_dropout: float = 0.5) -> None:
    super(Seq2Seq, self).__init__(vocab)
    self._target_namespace = target_namespace
    self._scheduled_sampling_ratio = scheduled_sampling_ratio

    # We need the start symbol to provide as the input at the first timestep of decoding, and
    # end symbol as a way to indicate the end of the decoded sequence.
    self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
    self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

    if use_bleu:
        pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace)  # pylint: disable=protected-access
        self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
    else:
        self._bleu = None
    self._token_based_metric = TokenSequenceAccuracy()

    # At prediction time, we use a beam search to find the most likely sequence of target tokens.
    beam_size = beam_size or 1
    self._max_decoding_steps = max_decoding_steps
    self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)

    # Dense embedding of source vocab tokens.
    self._source_embedder = source_embedder
    self._emb_dropout = Dropout(p=emb_dropout)

    # Encodes the sequence of source embeddings into a sequence of hidden states.
    self._encoder = encoder

    num_classes = self.vocab.get_vocab_size(self._target_namespace)

    # Attention mechanism applied to the encoder output for each step.
    if attention:
        if attention_function:
            raise ConfigurationError("You can only specify an attention module or an "
                                     "attention function, but not both.")
        self._attention = attention
    elif attention_function:
        self._attention = LegacyAttention(attention_function)
    else:
        self._attention = None

    # Dense embedding of vocab words in the target space.
    target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
    self._target_embedder = Embedding(num_classes, target_embedding_dim)

    # Decoder output dim needs to be the same as the encoder output dim since we initialize the
    # hidden state of the decoder with the final hidden state of the encoder.
    self._encoder_output_dim = self._encoder.get_output_dim()
    self._decoder_output_dim = self._encoder_output_dim

    if self._attention:
        # If using attention, a weighted average over encoder outputs will be concatenated
        # to the previous target embedding to form the input to the decoder at each
        # time step.
        self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
    else:
        # Otherwise, the input to the decoder is just the previous target embedding.
        self._decoder_input_dim = target_embedding_dim

    # We'll use an LSTM cell as the recurrent cell that produces a hidden state
    # for the decoder at each time step.
    # TODO (pradeep): Do not hardcode decoder cell type.
    self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

    # We project the hidden state from the decoder into the output vocabulary space
    # in order to get log probabilities of each target token, at each time step.
    self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)
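# Hedged sketch (not part of Seq2Seq above): the scheduled-sampling rule controlled by
# ``scheduled_sampling_ratio``, mirroring how a training loop chooses the next decoder input.
# With ratio 0.0 the model always sees gold tokens (teacher forcing); with 1.0 it always sees
# its own previous predictions. The token values below are made up for illustration.
import torch

scheduled_sampling_ratio = 0.25              # illustrative value
gold_tokens = torch.tensor([4, 7])           # target tokens for this time step
last_predictions = torch.tensor([4, 9])      # model's predictions from the previous step

if torch.rand(1).item() < scheduled_sampling_ratio:
    input_choices = last_predictions         # feed back the model's own predictions
else:
    input_choices = gold_tokens              # feed the ground-truth tokens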
class PhnMoChA(Model):
    """
    This ``PhnMoChA`` class is a :class:`Model` which takes a sequence, encodes it, and then
    uses the encoded representations to decode another sequence.  You can use this as the basis for
    a neural machine translation system, an abstractive summarization system, or any other common
    seq2seq problem.  The model here is simple, but should be a decent starting place for
    implementing recent models for these tasks.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the
        same value as the source embedder's.
    attention : ``Attention``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    attention_function : ``SimilarityFunction``, optional (default = None)
        This is if you want to use the legacy implementation of attention. This will be deprecated
        since it consumes more memory than the specialized attention modules.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
""" def __init__( self, vocab: Vocabulary, encoder: Seq2SeqEncoder, input_size: int, target_embedding_dim: int, decoder_hidden_dim: int, max_decoding_steps: int, max_decoding_ratio: float = 1.5, dep_parser: Model = None, pos_tagger: Model = None, cmvn: str = 'none', delta: int = 0, time_mask_width: int = 0, freq_mask_width: int = 0, time_mask_max_ratio: float = 0.0, dec_layers: int = 1, layerwise_pretraining: List[Tuple[int, int]] = None, cnn: Seq2SeqEncoder = None, conv_lstm: Seq2SeqEncoder = None, train_at_phn_level: bool = False, rnnt_layer: Model = None, phn_ctc_layer: Model = None, ctc_layer: Model = None, projection_layer: nn.Module = None, tie_proj: bool = False, att_ratio: float = 0.0, dep_ratio: float = 0.0, pos_ratio: float = 0.0, attention: Attention = None, attention_function: SimilarityFunction = None, latency_penalty: float = 0.0, loss_type: str = "nll", beam_size: int = 1, target_namespace: str = "tokens", phoneme_target_namespace: str = "phonemes", dropout: float = 0.0, blank: str = "_", sampling_strategy: str = "max", from_candidates: bool = False, scheduled_sampling_ratio: float = 0., initializer: InitializerApplicator = InitializerApplicator() ) -> None: super(PhnMoChA, self).__init__(vocab) self._input_size = input_size self._target_namespace = target_namespace self._phn_target_namespace = phoneme_target_namespace self._scheduled_sampling_ratio = scheduled_sampling_ratio self._sampling_strategy = sampling_strategy self._train_at_phn_level = train_at_phn_level self._blank = blank self._dep_parser = dep_parser self._pos_tagger = pos_tagger self._ctc_layer = ctc_layer self._rnnt_layer = rnnt_layer self._phn_ctc_layer = phn_ctc_layer self._projection_layer = projection_layer if tie_proj: self._rnnt_layer.set_projection_layer(projection_layer) self._att_ratio = att_ratio self._dep_ratio = dep_ratio self._pos_ratio = pos_ratio self._loss_type = loss_type # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence. self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) self._pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) # pylint: disable=protected-access self._phn_pad_index = self.vocab.get_token_index( self.vocab._padding_token, self._phn_target_namespace) # pylint: disable=protected-access exclude_indices = {self._pad_index, self._end_index, self._start_index} self._logs: Dict[str, Union[Metric, None]] = { "att_wer": (WER(exclude_indices=exclude_indices) if self._att_ratio > 0 else None), "att_bleu": (BLEU(exclude_indices=exclude_indices) if self._att_ratio > 0 else None), "att_loss": (Average() if self._att_ratio > 0 else None), "phn_ctc_loss": (Average() if self._phn_ctc_layer else None), "ctc_loss": (Average() if self._ctc_layer else None), "rnnt_loss": (Average() if self._rnnt_layer else None), "dal_loss": (Average() if latency_penalty > 0.0 else None), "dep_loss": (Average() if self._dep_parser else None), "pos_loss": (Average() if self._pos_tagger else None), "tag_loss": (Average() if self._dep_parser else None), "arc_loss": (Average() if self._dep_parser else None) } # At prediction time, we use a beam search to find the most likely sequence of target tokens. 
        self._max_decoding_steps = max_decoding_steps
        self._max_decoding_ratio = max_decoding_ratio
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder
        self._cnn = cnn
        self._conv_lstm = conv_lstm

        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        self._num_classes = num_classes

        # Attention mechanism applied to the encoder output for each step.
        if attention:
            if attention_function:
                raise ConfigurationError("You can only specify an attention module or an "
                                         "attention function, but not both.")
            self._attention = attention
        elif attention_function:
            self._attention = LegacyAttention(attention_function)
        else:
            self._attention = None

        # Dense embedding of vocab words in the target space.
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = decoder_hidden_dim
        self._dec_layers = dec_layers
        if self._decoder_output_dim != self._encoder_output_dim:
            self.bridge = nn.Linear(self._encoder_output_dim,
                                    self._dec_layers * self._decoder_output_dim,
                                    bias=False)

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
            self.att_out = Linear(self._decoder_output_dim + self._encoder_output_dim,
                                  self._decoder_output_dim, bias=True)
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder = nn.LSTM(self._decoder_input_dim, self._decoder_output_dim,
                                num_layers=self._dec_layers, batch_first=True)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

        self._input_norm = lambda x: x
        if cmvn == 'global':
            self._input_norm = nn.BatchNorm1d(self._input_size * (delta + 1))
        elif cmvn == 'utt':
            self._input_norm = nn.InstanceNorm1d(self._input_size * (delta + 1))

        self._delta = None
        if delta > 0:
            self._delta = Delta(order=delta)

        self._epoch_num = float("inf")
        self._layerwise_pretraining = layerwise_pretraining
        try:
            if isinstance(self._encoder, PytorchSeq2SeqWrapper):
                self._num_layers = self._encoder._module.num_layers
            else:
                self._num_layers = self._encoder.num_layers
        except AttributeError:
            self._num_layers = float("inf")
        self._output_layer_num = self._num_layers

        self._loss = None
        self._from_candidates = from_candidates
        if loss_type == "ocd":
            self._loss = OCDLoss(self._end_index, 1e-7, 1e-7, 5)
        elif loss_type == "edocd":
            self._loss = EDOCDLoss(self._end_index, 1e-7, 1e-7, 5)
        self._latency_penalty = latency_penalty
        self._target_granularity = self._target_namespace
        self.time_mask = TimeMask(time_mask_width, time_mask_max_ratio)
        self.freq_mask = FreqMask(freq_mask_width)

        initializer(self)

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward(self,  # type: ignore
                source_features: torch.FloatTensor,
                source_lengths: torch.LongTensor,
                target_tokens: Dict[str, torch.LongTensor] = None,
                words: Dict[str, torch.LongTensor] = None,
                segments: torch.LongTensor = None,
                pos_tags: torch.LongTensor = None,
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None,
                epoch_num: int = None,
                dataset: str = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_features : ``torch.FloatTensor``
            Padded source feature frames. These are normalized, optionally masked, and then
            passed through the encoder.
        source_lengths : ``torch.LongTensor``
            Lengths of the source feature sequences.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
            Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
            target tokens are also represented as a `TextField`.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        output_dict = {}
        if dataset is not None:
            self._target_granularity = dataset[0]

        if epoch_num is not None:
            self._epoch_num = epoch_num[0]
        self.set_output_layer_num()

        source_mask = util.get_mask_from_sequence_lengths(source_lengths, source_features.size(1)).bool()
        source_features = source_features.unsqueeze(1)  # make a channel dim
        if self._delta:
            source_features = self._delta(source_features)
        batch_size, n_channels, timesteps, feature_size = source_features.size()
        source_features = self._input_norm(
            source_features.transpose(-2, -1).reshape(batch_size, -1, timesteps)
        ).view(batch_size, n_channels, feature_size, timesteps).transpose(-2, -1)
        source_features = self.time_mask(source_features, source_mask)
        source_features = self.freq_mask(source_features, source_mask)
        source_features = source_features.masked_fill(
            ~source_mask.unsqueeze(1).unsqueeze(-1).expand_as(source_features), 0.0)

        state = self._encode(source_features, source_lengths)
        source_lengths = util.get_lengths_from_binary_sequence_mask(state["source_mask"])
        target_tokens["mask"] = (target_tokens[self._target_namespace] != self._pad_index).bool()

        if self._phn_ctc_layer and \
                (self._phn_target_namespace in self._target_granularity or self._train_at_phn_level):
            raise NotImplementedError
            # logits = self._projection_layer(state["encoder_outputs"])
            # phn_ctc_output_dict = self._phn_ctc_layer(logits, source_lengths, target_tokens)
            # output_dict.update({f"phn_ctc_{key}": value for key, value in phn_ctc_output_dict.items()})

        if self._rnnt_layer is not None and self._rnnt_layer.loss_ratio > 0.0:
            rnnt_output_dict = self._rnnt_layer(state["encoder_outputs"], source_lengths, target_tokens)
            output_dict.update({f"rnnt_{key}": value for key, value in rnnt_output_dict.items()})

        if self._ctc_layer is not None and self._ctc_layer.loss_ratio > 0.0:
            logits = self._projection_layer(state["encoder_outputs"])
            ctc_output_dict = self._ctc_layer(logits, source_lengths, target_tokens)
            output_dict.update({f"ctc_{key}": value for key, value in ctc_output_dict.items()})

        if target_tokens and self._att_ratio > 0.0 and \
                self._target_granularity == self._target_namespace:
            targets = target_tokens[self._target_namespace]
            output_dict["target_tokens"] = targets
            target_mask = util.get_text_field_mask(target_tokens)
            if self._train_at_phn_level:
                raise NotImplementedError
                # state = self._get_phn_level_representations(
                #     state["encoder_outputs"].detach().requires_grad_(True),
                #     state["source_mask"],
                #     output_dict["phn_ctc"])
            state = self._init_decoder_state(state)
            output_dict.update(self._forward_loop(state, target_tokens))
            self._logs["att_wer"](output_dict["predictions"], targets)

            if self._dep_parser or self._pos_tagger:
                relevant_mask = target_mask[:, 1:]
                attention_contexts, _ = _remove_eos(output_dict["attention_contexts"], relevant_mask)
                if segments is not None:
                    segments, _ = remove_sentence_boundaries(segments, target_mask)
                    attention_contexts, _ = char_to_word(attention_contexts, segments)
                contexts = {"tokens": attention_contexts}
                if self._dep_parser:
                    parser_outputs = self._dep_parser(contexts, pos_tags, metadata, head_tags, head_indices)
                    parser_outputs["dep_loss"] = parser_outputs.pop("loss")
                    output_dict.update(parser_outputs)
                if self._pos_tagger:
                    tagger_outputs = self._pos_tagger(contexts, pos_tags, metadata)
tagger_outputs["pos_loss"] = tagger_outputs.pop("loss") output_dict.update(tagger_outputs) if not self.training: if self._target_granularity == self._target_namespace: if self._att_ratio > 0.0: state = self._init_decoder_state(state) predictions = self._forward_beam_search(state) output_dict.update(predictions) if target_tokens: targets = target_tokens[self._target_namespace] # shape: (batch_size, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"] # shape: (batch_size, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] self._logs["att_bleu"](best_predictions, targets) self._logs["att_wer"](best_predictions, targets) log_dict = self.decode(output_dict) verbose_target = [ self._indices_to_tokens(tokens.tolist()[1:]) for tokens in target_tokens[self._target_namespace] ] verbose_best_pred = [ beams[0] for beams in log_dict["predicted_tokens"] ] sep = " " if self._target_namespace == 'tokens' else "" with open(f"preds.{epoch_num[0]}.txt", "a+") as fp: fp.write("\n".join([ sep.join( map(lambda s: re.sub(self._blank, " ", s), words)) for words in verbose_best_pred ])) fp.write("\n") with open(f"golds.{epoch_num[0]}.txt", "a+") as fp: fp.write("\n".join([ sep.join( map(lambda s: re.sub(self._blank, " ", s), words)) for words in verbose_target ])) fp.write("\n") # for gold, pred in zip(verbose_target, verbose_best_pred): # print(gold, pred) if self.training: output_dict = self._collect_losses( output_dict, ctc=(self._ctc_layer.loss_ratio if self._ctc_layer else 0), rnnt=(self._rnnt_layer.loss_ratio if self._rnnt_layer else 0), att=self._att_ratio, dal=self._latency_penalty, dep=self._dep_ratio, pos=self._pos_ratio) if torch.isnan(output_dict["loss"]).any() or \ (torch.abs(output_dict["loss"]) == float('inf')).any(): for key, _ in output_dict.items(): if "loss" in key: output_dict[key] = output_dict[key].new_zeros( size=(), requires_grad=True).clone() self._update_metrics(output_dict) return output_dict def _indices_to_tokens(self, indices): # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [ self.vocab.get_token_from_index(x, namespace=self._target_namespace) for x in indices ] return predicted_tokens @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Finalize predictions. This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives within the ``forward`` method. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. 
""" def _decode_predictions(input_key: str, output_key: str, beam=False): if input_key in output_dict: if beam: all_predicted_tokens = [ list(map(self._indices_to_tokens, beams)) for beams in sanitize(output_dict[input_key]) ] else: all_predicted_tokens = list( map(self._indices_to_tokens, sanitize(output_dict[input_key]))) output_dict[output_key] = all_predicted_tokens _decode_predictions("predictions", "predicted_tokens", beam=True) _decode_predictions("ctc_predictions", "ctc_predicted_tokens") _decode_predictions("rnnt_predictions", "rnnt_predicted_tokens") _decode_predictions("target_tokens", "targets") return output_dict def _encode(self, source_features: torch.FloatTensor, source_lengths: torch.LongTensor) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length, encoder_input_dim) if self._cnn is not None: source_features, source_lengths = self._cnn( source_features, source_lengths) source_mask = util.get_mask_from_sequence_lengths( source_lengths, source_features.size(1)) if self._conv_lstm is not None: source_features = self._conv_lstm(source_features, source_mask) if not isinstance(self._encoder, AWDRNN): encoder_outputs = self._encoder(source_features, source_mask) else: encoder_outputs, _, source_lengths = self._encoder( source_features, source_lengths, self._output_layer_num) source_mask = util.get_mask_from_sequence_lengths( source_lengths, encoder_outputs.size(1)) # shape: (batch_size, max_input_sequence_length, encoder_output_dim) return {"source_mask": source_mask, "encoder_outputs": encoder_outputs} def _get_phn_level_representations( self, features: torch.FloatTensor, mask: torch.BoolTensor, phn_log_probs: torch.Tensor) -> Dict[str, torch.Tensor]: phn_enc_outs, segment_lengths = averaging_tensor_of_same_label( features, phn_log_probs, mask=mask) state = { "encoder_outputs": phn_enc_outs, "source_mask": util.get_mask_from_sequence_lengths(segment_lengths, int(max(segment_lengths))) } return state def _init_decoder_state( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size(0) # shape: (batch_size, encoder_output_dim) encoder_outputs = state["encoder_outputs"] source_mask = state["source_mask"] final_encoder_output = util.get_final_encoder_states( encoder_outputs, source_mask, self._encoder.is_bidirectional()) if self._encoder_output_dim != self._dec_layers * self._decoder_output_dim: final_encoder_output = self.bridge(final_encoder_output) initial_decoder_input = final_encoder_output.view(-1, self._dec_layers, self._decoder_output_dim) \ .contiguous() # Initialize the decoder hidden state with the final output of the encoder. # shape: (batch_size, decoder_output_dim) state["decoder_hidden"] = initial_decoder_input state["decoder_output"] = initial_decoder_input[:, 0] # shape: (batch_size, decoder_output_dim) state["decoder_context"] = encoder_outputs.new_zeros( batch_size, self._dec_layers, self._decoder_output_dim) state["attention"] = None if isinstance(self._attention, StatefulAttention): state["att_keys"], state["att_values"] = \ self._attention.init_state(encoder_outputs) return state def _forward_loop( self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: """ Make forward pass during training or do greedy search during prediction. Notes ----- We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results. 
""" # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size = source_mask.size()[0] candidates = None if target_tokens: # shape: (batch_size, max_target_sequence_length) targets = target_tokens[self._target_namespace] _, target_sequence_length = targets.size() if self._loss is not None: candidates = target_to_candidates(targets, self._num_classes, ignore_indices=[ self._pad_index, self._start_index, self._end_index ]) # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. if isinstance(self._loss, EDOCDLoss): num_decoding_steps = int( target_sequence_length * self._max_decoding_ratio) - 1 else: num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps # Initialize target predictions with the start index. # shape: (batch_size,) last_predictions = source_mask.new_full((batch_size, ), fill_value=self._start_index) step_logits: List[torch.Tensor] = [] step_predictions: List[torch.Tensor] = [] step_attns: List[torch.Tensor] = [] step_attn_cxts: List[torch.Tensor] = [] for timestep in range(num_decoding_steps): if self.training and torch.rand( 1).item() < self._scheduled_sampling_ratio: # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio # during training. # shape: (batch_size,) input_choices = last_predictions elif self._loss is not None: # shape: (batch_size,) input_choices = last_predictions elif not target_tokens: # shape: (batch_size,) input_choices = last_predictions else: # shape: (batch_size,) input_choices = targets[:, timestep] # shape: (batch_size, num_classes) output_projections, state = self._prepare_output_projections( input_choices, state) # list of tensors, shape: (batch_size, 1, num_classes) step_logits.append(output_projections.unsqueeze(1)) # list of tensors, shape: (batch_size, 1, num_encoding_steps) if self._attention: step_attns.append(state["attention"].unsqueeze(1)) step_attn_cxts.append(state["attention_contexts"].unsqueeze(1)) # shape: (batch_size, num_classes) class_probabilities = F.softmax(output_projections, dim=-1) # shape (predicted_classes): (batch_size,) predicted_classes = maybe_sample_from_candidates( class_probabilities, candidates=(candidates if self._from_candidates else None), strategy=self._sampling_strategy) # shape (predicted_classes): (batch_size,) last_predictions = predicted_classes step_predictions.append(last_predictions.unsqueeze(1)) # shape: (batch_size, num_decoding_steps) predictions = torch.cat(step_predictions, 1) output_dict = { "predictions": predictions, } # shape: (batch_size, num_decoding_steps, num_encoding_steps) if self._attention: output_dict["attentions"] = torch.cat(step_attns, 1) output_dict["attention_contexts"] = torch.cat(step_attn_cxts, 1) if target_tokens: # shape: (batch_size, num_decoding_steps, num_classes) logits = torch.cat(step_logits, 1) # Compute loss. 
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, predictions, targets, target_mask, candidates)
            output_dict["att_loss"] = loss
            if self._latency_penalty > 0.0:
                DAL = differentiable_average_lagging(output_dict["attentions"], source_mask,
                                                     target_mask[:, 1:])
                output_dict["dal"] = DAL

        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections into the target space,
        which can then be used to get probabilities of each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, decoder_output_dim)
        decoder_output = state["decoder_output"]

        attention = state["attention"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        # shape: (group_size, decoder_output_dim + target_embedding_dim)
        decoder_input = torch.cat((embedded_input, decoder_output), -1)

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        outputs, (decoder_hidden, decoder_context) = self._decoder(
            decoder_input.unsqueeze(1),
            (decoder_hidden.transpose(1, 0).contiguous(),
             decoder_context.transpose(1, 0).contiguous()))
        decoder_hidden = decoder_hidden.transpose(1, 0).contiguous()
        decoder_context = decoder_context.transpose(1, 0).contiguous()
        outputs = outputs.squeeze(1)

        if self._attention:
            # shape: (group_size, encoder_output_dim)
            attended_output, attention = self._prepare_attended_output(outputs, state)

            # shape: (group_size, decoder_output_dim)
            decoder_output = torch.tanh(self.att_out(torch.cat((attended_output, outputs), -1)))

            state["attention"] = attention
            state["attention_contexts"] = attended_output
        else:
            # shape: (group_size, target_embedding_dim)
            decoder_output = outputs

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context
        state["decoder_output"] = decoder_output

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_output)

        return output_projections, state

    def _prepare_attended_output(self,
                                 decoder_hidden_state: torch.Tensor,
                                 state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs = state["encoder_outputs"]
        source_mask = state["source_mask"]
        prev_attention = state["attention"]
        # These keys are only present when a StatefulAttention module is used.
        att_keys = state.get("att_keys")
        att_values = state.get("att_values")

        # shape: (batch_size, max_input_sequence_length)
        mode = "soft" if self.training else "hard"
        if isinstance(self._attention, MonotonicAttention):
            encoder_outs: Dict[str, torch.Tensor] = {
                "value": state["encoder_outputs"],
                "mask": state["source_mask"]
            }
            monotonic_attention, chunk_attention = self._attention(
                encoder_outs, decoder_hidden_state, prev_attention, mode=mode)
            # shape: (batch_size, encoder_output_dim)
            attended_output = util.weighted_sum(encoder_outputs, chunk_attention)
            attention = monotonic_attention
        elif isinstance(self._attention, StatefulAttention):
            attended_output, attention = self._attention(
                decoder_hidden_state, att_keys, att_values, source_mask)
        else:
            attention = self._attention(decoder_hidden_state, source_mask)
            attended_output = util.weighted_sum(encoder_outputs, attention)

        return attended_output, attention

    # @staticmethod
    def _get_loss(self,
                  logits: torch.FloatTensor,
                  predictions: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor,
                  candidates: torch.LongTensor = None) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison. Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6

           We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>

        (A standalone numeric sketch of this alignment follows this class.)
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        if self._loss is not None:
            if isinstance(self._loss, OCDLoss) or isinstance(self._loss, EDOCDLoss):
                self._loss.update_temperature(self._epoch_num)

            if isinstance(self._loss, EDOCDLoss):
                log_probs = F.log_softmax(logits, dim=-1)
                return self._loss(log_probs, predictions, relevant_targets, relevant_mask)
            else:
                raise NotImplementedError

        return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)

    def _collect_losses(self,
                        output_dict: Dict[str, torch.Tensor],
                        phn_ctc: float = 1.0,
                        ctc: float = 1.0,
                        rnnt: float = 1.0,
                        att: float = 1.0,
                        dep: float = 1.0,
                        pos: float = 1.0,
                        dal: float = 1.0) -> torch.Tensor:
        loss = 0.0
        if "phn_ctc_loss" in output_dict:
            loss += phn_ctc * output_dict["phn_ctc_loss"]
        if "ctc_loss" in output_dict:
            loss += ctc * output_dict["ctc_loss"]
        if "rnnt_loss" in output_dict:
            loss += rnnt * output_dict["rnnt_loss"]
        if "att_loss" in output_dict:
            loss += att * output_dict["att_loss"]
        if "dep_loss" in output_dict:
            loss += dep * output_dict["dep_loss"]
        if "pos_loss" in output_dict:
            loss += pos * output_dict["pos_loss"]
        if "dal" in output_dict:
            loss += dal * output_dict["dal"]
        output_dict["loss"] = loss
        return output_dict

    def _update_metrics(self, output_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        for key, track_func in self._logs.items():
            try:
                value = output_dict[key]
                value = value.item() if isinstance(value, torch.Tensor) else value
                track_func(value)
            except KeyError:
                continue

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        for key, metric_tracker in self._logs.items():
            if "phn" in key and self._phn_target_namespace not in self._target_granularity:
                continue
            if "att" in key and self._target_namespace not in self._target_granularity:
                continue
            if metric_tracker is not None:
                metric_values = metric_tracker.get_metric(reset=reset)
                if isinstance(metric_values, dict):
                    all_metrics.update(metric_values)
                else:
                    all_metrics[key] = metric_values
        if self._ctc_layer:
            all_metrics.update({f"ctc_{key}": value
                                for key, value in self._ctc_layer.get_metrics(reset=reset).items()})
        if self._rnnt_layer:
            all_metrics.update({f"rnnt_{key}": value
                                for key, value in self._rnnt_layer.get_metrics(reset=reset).items()})
        if not self.training:
            if self._dep_parser:
                all_metrics.update(self._dep_parser.get_metrics(reset=reset))
            if self._pos_tagger:
                all_metrics.update(self._pos_tagger.get_metrics(reset=reset))
        return all_metrics

    def set_output_layer_num(self):
        output_layer_num = self._num_layers
        if self._layerwise_pretraining is not None:
            for epoch, layer_num in self._layerwise_pretraining:
                if self._epoch_num < epoch:
                    break
                output_layer_num = layer_num
        self._output_layer_num = output_layer_num
        return output_layer_num
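# Hedged sketch (not part of PhnMoChA above): a numeric illustration of the shift-by-one alignment
# that ``_get_loss`` describes. Targets include the start symbol, so the logits for decoding step i
# are scored against the target token at position i + 1, with the mask shifted the same way.
# Plain PyTorch masked cross entropy stands in for the library call; all shapes and values are made up.
import torch
import torch.nn.functional as F

batch, steps, num_classes = 2, 5, 10
logits = torch.randn(batch, steps, num_classes)               # no logit is produced for <S>
targets = torch.randint(0, num_classes, (batch, steps + 1))   # <S> w1 w2 w3 <E> <P>
target_mask = torch.ones(batch, steps + 1)

relevant_targets = targets[:, 1:]    # drop <S>:        (batch, steps)
relevant_mask = target_mask[:, 1:]   # shifted the same: (batch, steps)

token_nll = F.cross_entropy(logits.reshape(-1, num_classes),
                            relevant_targets.reshape(-1),
                            reduction="none").view(batch, steps)
loss = (token_nll * relevant_mask).sum() / relevant_mask.sum()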
def __init__(self,
             vocab: Vocabulary,
             source_embedder: TextFieldEmbedder,
             encoder: Seq2SeqEncoder,
             max_decoding_steps: int,
             target_namespace: str = "tokens",
             target_embedding_dim: int = None,
             attention_function: SimilarityFunction = None,
             scheduled_sampling_ratio: float = 0.0,
             weight_function="softmax",
             gumbel_tau: float = 0.66,
             gumbel_hard: bool = True,
             gumbel_eps: float = 1e-10,
             infer_with: str = "distribution",
             self_feed_with: str = "argmax_distribution") -> None:
    super(Rnn2RnnDifferentiableNll, self).__init__(vocab)
    self._source_embedder = source_embedder
    self._encoder = encoder
    self._max_decoding_steps = max_decoding_steps
    self._target_namespace = target_namespace
    self._attention_function = attention_function
    self._scheduled_sampling_ratio = scheduled_sampling_ratio

    # We need the start symbol to provide as the input at the first timestep of decoding, and
    # end symbol as a way to indicate the end of the decoded sequence.
    self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
    self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
    num_classes = self.vocab.get_vocab_size(self._target_namespace)

    # Decoder output dim needs to be the same as the encoder output dim since we initialize the
    # hidden state of the decoder with that of the final hidden states of the encoder. Also, if
    # we're using attention with ``DotProductSimilarity``, this is needed.
    self._decoder_output_dim = self._encoder.get_output_dim()
    target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()
    self._target_embedder = Embedding(num_classes, target_embedding_dim)

    if self._attention_function:
        self._decoder_attention = LegacyAttention(self._attention_function)
        # The output of attention, a weighted average over encoder outputs, will be
        # concatenated to the input vector of the decoder at each time step.
        self._decoder_input_dim = self._encoder.get_output_dim() + target_embedding_dim
    else:
        self._decoder_input_dim = target_embedding_dim

    # TODO (pradeep): Do not hardcode decoder cell type.
    self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)
    self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

    self._weights_calculation_function = weight_function

    # Gumbel-softmax parameters used when a relaxed distribution over target tokens, rather than
    # a hard prediction, is fed back to the decoder. (A standalone sketch follows this constructor.)
    self._gumbel_tau = gumbel_tau
    self._gumbel_hard = gumbel_hard
    self._gamble_eps = gumbel_eps

    if self_feed_with not in {"distribution", "argmax_logits", "argmax_distribution",
                              "detach_distribution"}:
        raise ValueError("Allowed values for self_feed_with are "
                         "{distribution, argmax_logits, argmax_distribution, detach_distribution}")

    if infer_with not in {"distribution", "argmax_logits", "argmax_distribution"}:
        raise ValueError("Allowed values for infer_with are "
                         "{distribution, argmax_logits, argmax_distribution}")

    self._infer_with = infer_with
    self._self_feed_with = self_feed_with
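# Hedged sketch (not part of Rnn2RnnDifferentiableNll above): how a relaxed, differentiable
# prediction can be fed back to a decoder. A Gumbel-softmax over the output logits gives a
# (near) one-hot distribution; multiplying it by the embedding matrix yields a differentiable
# "embedding of the prediction". The shapes and the small vocabulary below are illustrative assumptions.
import torch
import torch.nn.functional as F

num_classes, target_embedding_dim, batch = 10, 8, 2
embedding_weight = torch.randn(num_classes, target_embedding_dim)  # target embedding matrix
logits = torch.randn(batch, num_classes)                           # decoder output projections

# tau controls the sharpness of the relaxation; hard=True returns a straight-through one-hot.
soft_prediction = F.gumbel_softmax(logits, tau=0.66, hard=True)    # (batch, num_classes)

# Differentiable lookup: a weighted mix of embedding rows instead of an index lookup.
next_input_embedding = soft_prediction @ embedding_weight          # (batch, target_embedding_dim)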
def __init__(self, vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             encoder: Seq2SeqEncoder,
             intent_encoder: Seq2SeqEncoder = None,
             tag_encoder: Seq2SeqEncoder = None,
             attention: Attention = None,
             attention_function: SimilarityFunction = None,
             context_for_intent: bool = True,
             context_for_tag: bool = True,
             attention_for_intent: bool = True,
             attention_for_tag: bool = True,
             sequence_label_namespace: str = "labels",
             intent_label_namespace: str = "intent_labels",
             feedforward: Optional[FeedForward] = None,
             label_encoding: Optional[str] = None,
             include_start_end_transitions: bool = True,
             crf_decoding: bool = False,
             constrain_crf_decoding: bool = None,
             focal_loss_gamma: float = None,
             nongeneral_intent_weight: float = 5.,
             num_train_examples: float = None,
             calculate_span_f1: bool = None,
             dropout: Optional[float] = None,
             verbose_metrics: bool = False,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: Optional[RegularizerApplicator] = None) -> None:
    super().__init__(vocab, regularizer)
    self.context_for_intent = context_for_intent
    self.context_for_tag = context_for_tag
    self.attention_for_intent = attention_for_intent
    self.attention_for_tag = attention_for_tag
    self.sequence_label_namespace = sequence_label_namespace
    self.intent_label_namespace = intent_label_namespace
    self.text_field_embedder = text_field_embedder
    self.num_tags = self.vocab.get_vocab_size(sequence_label_namespace)
    self.num_intents = self.vocab.get_vocab_size(intent_label_namespace)
    self.encoder = encoder
    self.intent_encoder = intent_encoder
    self.tag_encoder = tag_encoder
    self._feedforward = feedforward
    self._verbose_metrics = verbose_metrics
    self.rl = False

    if attention:
        if attention_function:
            raise ConfigurationError("You can only specify an attention module or an "
                                     "attention function, but not both.")
        self.attention = attention
    elif attention_function:
        self.attention = LegacyAttention(attention_function)

    if dropout:
        self.dropout = torch.nn.Dropout(dropout)
    else:
        self.dropout = None

    projection_input_dim = feedforward.get_output_dim() if self._feedforward else self.encoder.get_output_dim()
    if self.context_for_intent:
        projection_input_dim += self.encoder.get_output_dim()
    if self.attention_for_intent:
        projection_input_dim += self.encoder.get_output_dim()
    self.intent_projection_layer = Linear(projection_input_dim, self.num_intents)

    # Per-intent positive-class weights for the BCE loss, estimated from the retained label
    # counts when the number of training examples is known. (A standalone sketch of this
    # weighting follows this constructor.)
    if num_train_examples:
        try:
            pos_weight = torch.tensor([log10((num_train_examples - self.vocab._retained_counter[intent_label_namespace][t]) /
                                             self.vocab._retained_counter[intent_label_namespace][t])
                                       for i, t in self.vocab.get_index_to_token_vocabulary(intent_label_namespace).items()])
        except:
            pos_weight = torch.tensor([1. for i, t in
                                       self.vocab.get_index_to_token_vocabulary(intent_label_namespace).items()])
    else:
        # pos_weight = torch.tensor([(lambda t: 1. if "general" in t else nongeneral_intent_weight)(t) for i, t in
        #                            self.vocab.get_index_to_token_vocabulary(intent_label_namespace).items()])
        pos_weight = torch.tensor([(lambda t: nongeneral_intent_weight if "Request" in t else 1.)(t) for i, t in
                                   self.vocab.get_index_to_token_vocabulary(intent_label_namespace).items()])
    self.intent_loss = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="none")

    tag_projection_input_dim = feedforward.get_output_dim() if self._feedforward else self.encoder.get_output_dim()
    if self.context_for_tag:
        tag_projection_input_dim += self.encoder.get_output_dim()
    if self.attention_for_tag:
        tag_projection_input_dim += self.encoder.get_output_dim()
    self.tag_projection_layer = TimeDistributed(Linear(tag_projection_input_dim, self.num_tags))

    # if constrain_crf_decoding and calculate_span_f1 are not
    # provided, (i.e., they're None), set them to True
    # if label_encoding is provided and False if it isn't.
    if constrain_crf_decoding is None:
        constrain_crf_decoding = label_encoding is not None
    if calculate_span_f1 is None:
        calculate_span_f1 = label_encoding is not None

    self.label_encoding = label_encoding
    if constrain_crf_decoding:
        if not label_encoding:
            raise ConfigurationError("constrain_crf_decoding is True, but "
                                     "no label_encoding was specified.")
        labels = self.vocab.get_index_to_token_vocabulary(sequence_label_namespace)
        constraints = allowed_transitions(label_encoding, labels)
    else:
        constraints = None

    self.include_start_end_transitions = include_start_end_transitions
    if crf_decoding:
        self.crf = ConditionalRandomField(
            self.num_tags, constraints,
            include_start_end_transitions=include_start_end_transitions
        )
    else:
        self.crf = None

    self._intent_f1_metric = MultiLabelF1Measure(vocab, namespace=intent_label_namespace)
    self.calculate_span_f1 = calculate_span_f1
    if calculate_span_f1:
        if not label_encoding:
            raise ConfigurationError("calculate_span_f1 is True, but "
                                     "no label_encoding was specified.")
        self._f1_metric = SpanBasedF1Measure(vocab,
                                             tag_namespace=sequence_label_namespace,
                                             label_encoding=label_encoding)
    self._dai_f1_metric = DialogActItemF1Measure()

    check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
                           "text field embedding dim", "encoder input dim")
    if feedforward is not None:
        check_dimensions_match(encoder.get_output_dim(), feedforward.get_input_dim(),
                               "encoder output dim", "feedforward input dim")
    initializer(self)
def __init__(self,
             vocab: Vocabulary,
             source_embedding: Embedding,
             target_embedding: Embedding,
             encoder: Seq2SeqEncoder,
             target_namespace: str,
             max_decoding_steps: int,
             attention: Attention = None,
             attention_function: SimilarityFunction = None,
             beam_size: int = None,
             scheduled_sampling_ratio: float = 0.) -> None:
    super().__init__(vocab)
    self._target_namespace = target_namespace
    self._scheduled_sampling_ratio = scheduled_sampling_ratio

    # We need the start symbol to provide as the input at the first timestep of decoding, and
    # end symbol as a way to indicate the end of the decoded sequence.
    self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
    self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

    # Dense embedding of source vocab tokens.
    self._source_embedding = source_embedding

    # Encodes the sequence of source embeddings into a sequence of hidden states.
    self._encoder = encoder

    num_classes = self.vocab.get_vocab_size(self._target_namespace)

    # Attention mechanism applied to the encoder output for each step.
    if attention:
        if attention_function:
            raise ConfigurationError("You can only specify an attention module or an "
                                     "attention function, but not both.")
        self._attention = attention
    elif attention_function:
        self._attention = LegacyAttention(attention_function)
    else:
        self._attention = None

    # Dense embedding of vocab words in the target space.
    self._target_embedding = target_embedding
    target_embedding_dim = self._target_embedding.get_output_dim()

    # Decoder output dim needs to be the same as the encoder output dim since we initialize the
    # hidden state of the decoder with the final hidden state of the encoder.
    self._encoder_output_dim = self._encoder.get_output_dim()
    self._decoder_output_dim = self._encoder_output_dim

    if self._attention:
        # If using attention, a weighted average over encoder output will be concatenated
        # to the previous target embedding to form the input to the decoder at each
        # time step.
        self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
    else:
        # Otherwise, the input to the decoder is just the previous target embedding.
        self._decoder_input_dim = target_embedding_dim

    # We'll use an LSTM cell as the recurrent cell that produces a hidden state
    # for the decoder at each time step.
    # TODO (pradeep): Do not hardcode decoder cell type.
    self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

    # We project the hidden state from the decoder into the output vocabulary space
    # in order to get log probabilities of each target token, at each time step.
    self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

    # At prediction time, we can use a beam search to find the most likely sequence of target tokens.
    # If the beam_size parameter is not given, we'll just use a greedy search (equivalent to beam_size = 1).
    self._max_decoding_steps = max_decoding_steps
    if beam_size is not None:
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)
    else:
        self._beam_search = None