Example #1
 def forward_padding_mask(
     self,
     features: torch.Tensor,
     padding_mask: torch.Tensor,
 ) -> torch.Tensor:
     # Pool the frame-level padding mask down to the feature resolution.
     # Drop trailing frames that do not fill a whole feature step.
     extra = padding_mask.size(1) % features.size(1)
     if extra > 0:
         padding_mask = padding_mask[:, :-extra]
     # A feature step is padding only if all frames it covers are padding.
     padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
     padding_mask = padding_mask.all(-1)
     return padding_mask
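For context, a minimal sketch of what this pooling does (shapes and the `features` tensor below are illustrative stand-ins, not the original model's values):

import torch

batch, frames, steps = 2, 10, 3           # 10 input frames pooled to 3 feature steps
features = torch.randn(batch, steps, 8)   # stand-in for the feature extractor output
padding_mask = torch.zeros(batch, frames, dtype=torch.bool)
padding_mask[1, 6:] = True                # last 4 frames of sample 1 are padding

# 10 % 3 == 1, so one trailing frame is dropped; each feature step then
# covers 3 frames and is masked only if all of them are padding.
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
    padding_mask = padding_mask[:, :-extra]
pooled = padding_mask.view(batch, features.size(1), -1).all(-1)
print(pooled)  # tensor([[False, False, False], [False, False,  True]])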
Example #2
    def generate_mask(
        self,
        silence: Tensor,
    ):
        """
        :param silence: bool (batch_size, length)
        :return:
            output: bool (batch_size, length - window_length + 1)
        """
        # Receptive field of the stacked layers: 1 + 2 + 4 + ... + 2**layer_num.
        window_length = 1 + numpy.sum(2 ** numpy.arange(1, self.layer_num + 1))

        # Build a step-1 sliding-window view over the time axis, taking the
        # strides from the tensor itself so the view is valid for any batch size.
        silence = silence.unsqueeze(2)
        silence = silence.as_strided(
            size=(silence.shape[0], silence.shape[1] - (window_length - 1),
                  window_length),
            stride=(silence.stride(0), silence.stride(1), silence.stride(1)),
        )
        # Keep positions whose surrounding window is not entirely silent.
        return ~(silence.all(dim=2))
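A minimal usage sketch (`layer_num = 2` and the tensors below are illustrative; the method normally runs inside its model class):

import numpy
import torch

layer_num = 2
window_length = int(1 + numpy.sum(2 ** numpy.arange(1, layer_num + 1)))  # 1 + 2 + 4 = 7

silence = torch.zeros(1, 10, dtype=torch.bool)
silence[0, :8] = True  # the first 8 steps are silent

s = silence.unsqueeze(2)
windows = s.as_strided(
    size=(s.shape[0], s.shape[1] - (window_length - 1), window_length),
    stride=(s.stride(0), s.stride(1), s.stride(1)),
)
print(~windows.all(dim=2))  # tensor([[False, False,  True,  True]])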
Example #3
    def masked_select(self, mask: Tensor) -> Hypotheses:
        """Returns a new instance of :class:`~Hypotheses` where all its attributes are selected from this instance and
        along the batch dimension, specified as the boolean mask which is a :class:`~BoolTensor`. The padding columns of
        new `sequences`, `dec_out` and `lm_dec_out` are also truncated.
        Note: this function will NOT modify this instance.

        Args:
            mask (Tensor): boolean mask tensor of shape `(batch,)`

        Returns:
            hyps (Hypotheses): selected hypotheses
        """
        assert self.size() == mask.size(0)

        # Nothing to select: no hypotheses, or the mask keeps every entry.
        if self.size() == 0 or mask.all():
            return self

        # Select per-hypothesis entries and truncate the time axis to the
        # longest sequence that survives the mask.
        scores = self.scores[mask]
        sequence_lengths = self.sequence_lengths[mask]
        num_emissions = self.num_emissions[mask]
        max_length = sequence_lengths.max() if sequence_lengths.size(0) > 0 else None
        sequences = self.sequences[:, :max_length][mask, :]
        if self.cached_state is not None:
            # Decoder state tensors are laid out with the batch on dim 1.
            cached_state = {}
            for k, v in self.cached_state.items():
                cached_state[k] = v[:, mask, ...] if v is not None else None
        else:
            cached_state = None
        if self.dec_out is not None:
            dec_out = self.dec_out[:, :max_length, :][mask, :, :]
        else:
            dec_out = None
        if self.prev_tokens is not None:
            prev_tokens = self.prev_tokens[mask]
        else:
            prev_tokens = None
        if self.alignments is not None:
            alignments = self.alignments[mask, :, :]
        else:
            alignments = None
        if self.lm_scores is not None:
            lm_scores = self.lm_scores[mask]
        else:
            lm_scores = None
        if self.lm_cached_state is not None:
            # LM state tensors are likewise laid out with the batch on dim 1.
            lm_cached_state = {}
            for k, v in self.lm_cached_state.items():
                lm_cached_state[k] = v[:, mask, ...] if v is not None else None
        else:
            lm_cached_state = None
        if self.lm_dec_out is not None:
            lm_dec_out = self.lm_dec_out[:, :max_length, :][mask, :, :]
        else:
            lm_dec_out = None

        return Hypotheses(
            scores=scores,
            sequences=sequences,
            sequence_lengths=sequence_lengths,
            num_emissions=num_emissions,
            cached_state=cached_state,
            dec_out=dec_out,
            alignments=alignments,
            prev_tokens=prev_tokens,
            lm_scores=lm_scores,
            lm_cached_state=lm_cached_state,
            lm_dec_out=lm_dec_out,
        )
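As a usage sketch, the same select-then-truncate pattern on toy stand-ins for the container's fields (the real `Hypotheses` class is defined elsewhere in the decoding code):

import torch

# Three padded hypotheses; keep the ones whose score clears a threshold.
scores = torch.tensor([-0.5, -2.1, -0.9])
sequence_lengths = torch.tensor([4, 2, 3])
sequences = torch.randint(0, 10, (3, 6))    # padded to length 6

mask = scores > -1.0                        # BoolTensor of shape (batch,)
max_length = sequence_lengths[mask].max()   # 4: columns past this are padding
kept = sequences[:, :max_length][mask, :]
print(kept.shape)                           # torch.Size([2, 4])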
Example #4
    def forward(self,  # pylint: disable=arguments-differ
                source: Dict[str, torch.Tensor],
                target: Dict[str, torch.Tensor],
                reset: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:

        # The reset handling below only needs to be this careful for the fancy
        # iterator; Merity et al. just propagate the hidden state no matter what.
        # To make life easier when evaluating the model we use a BasicIterator
        # so that we do not need to worry about the sequence truncation
        # performed by our splitting iterators. To accommodate this, we assume
        # that if reset is not given, then everything gets reset.
        if reset is None:
            self._state = None
        elif reset.all() and (self._state is not None):
            logger.debug('RESET')
            self._state = None
        elif reset.any() and (self._state is not None):
            for layer in range(self.num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)

        target_mask = get_text_field_mask(target)
        source = source['tokens']
        target = target['tokens']

        embeddings = embedded_dropout(self.embedder, source,
                                      dropout=self.dropoute if self.training else 0)
        embeddings = self.locked_dropout(embeddings, self.dropouti)

        # Iterate through RNN layers
        current_input = embeddings
        current_hidden = []
        outputs = []
        dropped_outputs = []
        for layer, rnn in enumerate(self.rnns):

            # Bookkeeping
            if self._state is not None:
                prev_hidden = self._state['layer_%i' % layer]
            else:
                prev_hidden = None

            # Forward-pass
            output, hidden = rnn(current_input, prev_hidden)

            # More bookkeeping
            output = output.contiguous()
            outputs.append(output)
            hidden = tuple(h.detach() for h in hidden)
            current_hidden.append(hidden)

            # Apply dropout
            if layer == self.num_layers - 1:
                current_input = self.locked_dropout(output, self.dropout)
                dropped_outputs.append(output)
            else:
                current_input = self.locked_dropout(output, self.dropouth)
                dropped_outputs.append(current_input)

        # Compute logits and loss
        logits = self.decoder(current_input)
        loss = sequence_cross_entropy_with_logits(logits, target.contiguous(),
                                                  target_mask,
                                                  average="token")
        num_tokens = target_mask.float().sum() + 1e-13

        # Activation regularization
        if self.alpha:
            loss = loss + self.alpha * current_input.pow(2).mean()
        # Temporal activation regularization (slowness)
        if self.beta:
            loss = loss + self.beta * (output[:, 1:] - output[:, :-1]).pow(2).mean()

        # Update metrics and state
        unks = target.eq(self._unk_index)
        unk_penalty = self._unk_penalty * unks.float().sum()

        self.ppl(loss * num_tokens, num_tokens)
        self.upp(loss * num_tokens + unk_penalty, num_tokens)
        self._state = {'layer_%i' % l: h for l, h in enumerate(current_hidden)}

        return {'loss': loss}
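The two penalties near the end are the activation regularization (AR) and temporal activation regularization (TAR) terms from Merity et al.'s AWD-LSTM; a standalone sketch of just that math, with made-up coefficients:

import torch

alpha, beta = 2.0, 1.0            # illustrative AR/TAR weights
output = torch.randn(4, 20, 512)  # (batch, seq_len, hidden) RNN output

ar = alpha * output.pow(2).mean()                            # discourage large activations
tar = beta * (output[:, 1:] - output[:, :-1]).pow(2).mean()  # discourage rapid change over time
loss = ar + tar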