Example #1
    def seq2seq_forward_pass(
        self, xs: torch.LongTensor, ys: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[Any, ...]]:
        """
        Simulate a standard seq2seq encoder/decoder forward pass.

        Used in thorough decoding.

        :param xs:
            input tokens
        :param ys:
            teacher forced decoder outputs

        :return (logits, preds, encoder_states):
            logits: token output distribution
            preds: max probability token at each output position
            encoder_states: output states from the encoder
        """
        encoder_states = self.seq2seq_encoder(xs)  # type: ignore
        bsz = ys.size(0)
        seqlen = ys.size(1)
        inputs = ys.narrow(1, 0, seqlen - 1)
        dec_inputs = self._rag_model_interface.get_initial_forced_decoder_input(
            bsz,
            inputs,
            n_docs=1,
            start_idx=self.START_IDX,
            end_idx=self.END_IDX,
            input_turns_cnt=None,
        )
        latent, _ = self.seq2seq_decoder(dec_inputs, encoder_states,
                                         None)  # type: ignore
        logits = self.decoder_output(latent)
        _, preds = logits.max(dim=-1)
        return logits, preds, encoder_states
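To make the teacher-forcing setup above concrete, here is a minimal standalone sketch (not ParlAI code) of how the forced decoder input is built: the last label position is dropped and a BOS token is prepended. The START_IDX value and the BOS-prepend step are simplified assumptions standing in for get_initial_forced_decoder_input with n_docs=1.

import torch

START_IDX = 1  # hypothetical BOS id for this toy vocabulary
ys = torch.LongTensor([[5, 6, 7, 2],   # 2 = EOS in the toy vocab
                       [8, 9, 2, 0]])  # 0 = padding
bsz, seqlen = ys.shape

# Drop the final position so the decoder predicts token t from tokens < t.
inputs = ys.narrow(1, 0, seqlen - 1)

# Prepend BOS -- a simplified stand-in for get_initial_forced_decoder_input
# with n_docs=1 (no document expansion).
bos = torch.full((bsz, 1), START_IDX, dtype=torch.long)
dec_inputs = torch.cat([bos, inputs], dim=1)
print(dec_inputs)
# tensor([[1, 5, 6, 7],
#         [1, 8, 9, 2]])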
Example #2
    def __init__(self,
                 data: torch.LongTensor,
                 d_batch: int = 10,
                 n_batch_per_test: int = 1,
                 device="cpu",
                 n_ext_ctx=None):
        """
            input -- LongTensor -- the LongTensor is strictly ordered
        """
        self.d_batch = d_batch
        self.bptt = n_batch_per_test
        self.n_ext_ctx = 0 if n_ext_ctx is None else n_ext_ctx

        self.device = device

        # Work out how many full time steps fit when the data is split into
        # d_batch parallel columns.
        self.n_dates = data.size(0)  # number of training dates (time steps)
        self.n_batch = self.n_dates // d_batch

        # Trim off any extra elements that wouldn't cleanly fit.
        # Remove the very beginning of the time series - keep the most recent.
        data = data.narrow(0, self.n_dates % d_batch, self.n_batch * d_batch)

        # Evenly divide the input across d_batch columns, with time along dim 0.
        self.data = data.view(d_batch, -1).t().contiguous().to(device)

        # Number of bptt-sized mini-batches (n_batch is re-purposed here: it
        # previously held the number of time steps per column).
        self.n_batch = (self.n_batch + self.bptt - 1) // self.bptt
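A standalone sketch of the trim-and-reshape step above (not part of the original class), using a toy series of 10 steps and d_batch=3:

import torch

data = torch.arange(10)      # 10 "dates", oldest first
d_batch = 3
n_dates = data.size(0)
n_step = n_dates // d_batch  # 3 full time steps per column

# Drop the oldest n_dates % d_batch elements so the series divides evenly,
# keeping the most recent ones.
data = data.narrow(0, n_dates % d_batch, n_step * d_batch)

# Lay the series out as d_batch parallel columns, with time along dim 0.
batched = data.view(d_batch, -1).t().contiguous()
print(batched)
# tensor([[1, 4, 7],
#         [2, 5, 8],
#         [3, 6, 9]])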
Example #3
    def decode_forced(
            self, encoder_states: Tuple[torch.Tensor, ...],
            ys: torch.LongTensor) -> Tuple[torch.Tensor, torch.LongTensor]:
        """
        Decode with a fixed, true sequence, computing loss.

        Override TGM.decode_forced to both:
        1) handle BART eos/bos issues, and
        2) appropriately get forced decoder input.

        :param encoder_states:
            encoder output states
        :param ys:
            teacher forced label

        :return logits, preds:
            logits: output token distribution (as logits, not probabilities)
            preds: the maximum-probability token at each output position.
        """
        bsz = ys.size(0)
        seqlen = ys.size(1)
        inputs = ys.narrow(1, 0, seqlen - 1)
        if (ys[:, 0]
                == self.START_IDX).any() and self.generation_model != 'bart':
            raise AssertionError(
                "The Beginning of Sentence token is automatically added to the "
                "label in decode_forced, but you included it in the label. This means "
                "your model will have a double BOS token, which is probably not what "
                "you intended.")
        doc_scores = encoder_states[-1]

        inputs = self._rag_model_interface.get_initial_forced_decoder_input(
            bsz,
            inputs,
            n_docs=doc_scores.size(1) if doc_scores is not None else None,
            start_idx=self.START_IDX,
            end_idx=self.END_IDX,
            input_turns_cnt=encoder_states[2],
        )
        latent, _ = self.decoder(inputs, encoder_states)
        logits = self.output(latent)
        _, preds = logits.max(dim=-1)
        return logits, preds  # type: ignore
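The double-BOS guard above can be illustrated with a small standalone sketch (not ParlAI code), assuming a hypothetical START_IDX of 1:

import torch

START_IDX = 1  # hypothetical BOS id
ys = torch.LongTensor([[1, 5, 6, 2],   # this label mistakenly starts with BOS
                       [7, 8, 9, 2]])
if (ys[:, 0] == START_IDX).any():
    # decode_forced would raise here: BOS is prepended automatically, so a
    # label that already begins with BOS would yield a doubled BOS token.
    print("label already starts with BOS")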
Example #4
    def criterion(
        prediction: torch.FloatTensor, target: torch.LongTensor
    ) -> torch.FloatTensor:
        """Transformer encoder decoder loss function.

        Args:
            prediction (torch.FloatTensor): tensor of shape
                `batch_size, seq_len, vocab_size`.
            target (torch.LongTensor): tensor of shape `batch_size, seq_len`.

        Returns:
            torch.FloatTensor: loss value.
        """
        seq_len = target.size(1) - 1

        prediction = prediction.narrow(1, 0, seq_len).flatten(0, 1)
        target = target.narrow(1, 1, seq_len).flatten()

        return nn.functional.cross_entropy(prediction, target, ignore_index=0)
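A toy shape check mirroring the loss above (standalone sketch, not the original code): the prediction at position t is scored against the target at position t + 1, and target index 0 is treated as padding.

import torch
from torch import nn

batch_size, seq_len, vocab_size = 2, 5, 11
prediction = torch.randn(batch_size, seq_len, vocab_size)
target = torch.randint(1, vocab_size, (batch_size, seq_len))

shift = seq_len - 1
pred_flat = prediction.narrow(1, 0, shift).flatten(0, 1)  # (batch_size * shift, vocab_size)
tgt_flat = target.narrow(1, 1, shift).flatten()           # (batch_size * shift,)
loss = nn.functional.cross_entropy(pred_flat, tgt_flat, ignore_index=0)
print(pred_flat.shape, tgt_flat.shape, loss.item())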
Example #5
    def decode_forced(
        self, encoder_states: Tuple[Any], ys: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor, torch.Tensor, torch.BoolTensor]:
        """
        Override TGM.decode_forced to return latent states.

        Nearly copied verbatim, except for return type.
        """
        bsz = ys.size(0)
        seqlen = ys.size(1)
        inputs = ys.narrow(1, 0, seqlen - 1)
        if (ys[:, 0] == self.START_IDX).any():
            raise AssertionError(
                "The Beginning of Sentence token is automatically added to the "
                "label in decode_forced, but you included it in the label. This means "
                "your model will have a double BOS token, which is probably not what "
                "you intended."
            )
        inputs = self._get_initial_forced_decoder_input(bsz, inputs)
        latent, mask = self.decoder(inputs, encoder_states)
        logits = self.output(latent)
        _, preds = logits.max(dim=2)
        return logits, preds, latent, mask

    def decode_forced(
            self, encoder_states: Tuple[Any, ...], ys: torch.LongTensor
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """
        Decode with a fixed, true sequence, computing loss.

        Overriding `TGM.decode_forced` to bypass the assertion that BOS is not present,
        and to additionally insert EOS as the first token.
        """
        bsz = ys.size(0)
        seqlen = ys.size(1)
        inputs = ys.narrow(1, 0, seqlen - 1)
        inputs = torch.cat(
            [
                torch.LongTensor([self.END_IDX]).detach().expand(bsz,
                                                                 1).to(inputs),
                inputs,
            ],
            1,
        )
        latent, _ = self.decoder(inputs, encoder_states)
        logits = self.output(latent)
        _, preds = logits.max(dim=2)
        return logits, preds
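A standalone sketch (not ParlAI code) of the EOS-prefix trick in the second method, assuming a hypothetical END_IDX of 2; BART's decoder conventionally starts from the EOS token rather than BOS.

import torch

END_IDX = 2  # hypothetical EOS id
ys = torch.LongTensor([[5, 6, 7, 2],
                       [8, 9, 2, 0]])
bsz, seqlen = ys.shape

inputs = ys.narrow(1, 0, seqlen - 1)
inputs = torch.cat(
    [torch.LongTensor([END_IDX]).expand(bsz, 1).to(inputs), inputs], dim=1
)
print(inputs)
# tensor([[2, 5, 6, 7],
#         [2, 8, 9, 2]])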