def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> torch.LongTensor:
        """Close all open beams and return the best hypotheses as one padded tensor.

        Every beam of a batch item that is not yet marked done is offered to
        that item's hypothesis container (``beam_hyp.add``). Then the
        ``num_beam_hyps_to_keep`` highest-scoring hypotheses of each item are
        written, best first, into consecutive rows of the result.

        Args:
            input_ids: running sequences, one row per beam, indexed as
                ``batch_idx * num_beams + beam_id``.
            final_beam_scores: score of each running beam, same indexing.
            final_beam_tokens: not used here; kept for interface compatibility.
            final_beam_indices: not used here; kept for interface compatibility.
            pad_token_id: value written after the end of shorter hypotheses;
                required when the kept hypotheses differ in length.
            eos_token_id: appended after each hypothesis when it still fits
                within ``self.max_length``. When ``None``, no EOS is written
                and no extra column is reserved for it.

        Returns:
            Tensor of shape ``(batch_size * num_beam_hyps_to_keep, seq_len)``.
        """
        batch_size = len(self._beam_hyps)
        keep = self.num_beam_hyps_to_keep

        # finalize all open beam hypotheses and add to generated hypotheses;
        # the hypothesis container decides which ones it keeps
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_hyp.add(final_tokens, final_score)

        # retrieve the best hypotheses (highest score first) of every batch item
        sent_lengths = input_ids.new(batch_size * keep)
        best = []
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(keep):
                best_hyp = sorted_hyps.pop()[1]
                sent_lengths[keep * i + j] = len(best_hyp)
                best.append(best_hyp)

        # Reserve one extra column for EOS only when there is an EOS token to write;
        # otherwise that column would stay uninitialized (``Tensor.new(*sizes)`` does
        # not initialize memory) and writing ``None`` into it would raise.
        extra = 1 if eos_token_id is not None else 0
        sent_max_len = min(sent_lengths.max().item() + extra, self.max_length)
        decoded: torch.LongTensor = input_ids.new(batch_size * keep, sent_max_len)
        # shorter hypotheses are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, hypo in enumerate(best):
            decoded[i, :sent_lengths[i]] = hypo
            if eos_token_id is not None and sent_lengths[i] < self.max_length:
                decoded[i, sent_lengths[i]] = eos_token_id
        return decoded
示例#2
0
 def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
     """Return 1 if ``predicted`` equals the leading tokens of any row of ``targets``, else 0.

     ``targets`` is indexed as ``(num_targets, target_len)``. Only the first
     ``len(predicted)`` columns of each row are compared, so a target row
     matches when the prediction is a prefix of it.
     """
     # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
     # Check if target is big enough to cover prediction (including start/end symbols)
     if len(predicted) > targets.size(1):
         return 0
     predicted_tensor = targets.new_tensor(predicted)
     targets_trimmed = targets[:, :len(predicted)]
     # A row matches when every compared position is equal; any matching row counts.
     # ``all``/``any`` are well-defined for empty predictions or zero target rows,
     # and the result is converted to ``int`` to match the early-return path above.
     return int(targets_trimmed.eq(predicted_tensor).all(dim=1).any().item())
示例#3
0
    def _forward(self,
                 xs: torch.FloatTensor,
                 ilens: torch.LongTensor,
                 olens: torch.LongTensor = None,
                 ds: torch.LongTensor = None,
                 ps: torch.FloatTensor = None,
                 es: torch.FloatTensor = None,
                 in_masks: torch.LongTensor = None,
                 out_masks: torch.LongTensor = None,
                 is_inference: bool = False):
        """Run encoder, variance adaptor, decoder and (optional) postnet.

        The supplied ``in_masks`` / ``out_masks`` are bit-inverted before being
        handed to the variance adaptor (the training caller builds them with
        ``make_non_pad_mask``). During training the ground-truth ``ds``, ``ps``
        and ``es`` are passed to the adaptor; during inference they are not.

        Returns:
            ``(before_outs, after_outs)`` when ``is_inference`` is true, otherwise
            ``(before_outs, after_outs, d_outs, p_outs, e_outs)``.
        """
        hs, _ = self.encoder.forward(xs, self._source_mask(ilens))

        # speaker embeddings are not used
        d_masks = None if in_masks is None else ~in_masks
        v_masks = None if out_masks is None else ~out_masks

        adaptor = self.variance_adaptor
        if not is_inference:
            hs, d_outs, p_outs, e_outs = adaptor.forward(
                hs, ds, ilens, ps, es, d_masks, v_masks)
        else:
            hs, d_outs, p_outs, e_outs = adaptor.inference(
                hs, ilens, d_masks, v_masks)

        # decoder mask: built from the target lengths, divided by the reduction factor
        h_masks = None
        if olens is not None:
            if self.reduction_factor > 1:
                decoder_lens = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                decoder_lens = olens
            h_masks = self._source_mask(decoder_lens)

        zs, _ = self.decoder.forward(hs, h_masks)
        before_outs = self.feat_out.forward(zs).view(zs.shape[0], -1, self.odim)

        # the postnet works on (batch, channels, time), hence the transposes
        after_outs = before_outs
        if self.postnet is not None:
            refinement = self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
            after_outs = before_outs + refinement

        if not is_inference:
            return before_outs, after_outs, d_outs, p_outs, e_outs
        return before_outs, after_outs
示例#4
0
    def finalize(self, input_ids: torch.LongTensor,
                 final_beam_scores: torch.FloatTensor):
        """Close all beams and return the best hypotheses with their scores.

        Every running beam is offered to its batch item's hypothesis container
        (``beam_hyp.add``). Then the ``num_beam_hyps_to_keep`` highest-scoring
        hypotheses per item are returned, best first.

        NOTE(review): unlike sibling ``finalize`` variants, this one does not skip
        batch items already marked done; confirm that is intended for this scorer.

        Args:
            input_ids: running sequences, one row per beam, indexed as
                ``batch_idx * num_beams + beam_id``.
            final_beam_scores: score of each running beam, same indexing.

        Returns:
            dict with ``"sequences"`` (one row per kept hypothesis) and
            ``"sequence_scores"`` (float32 scores on ``input_ids.device``).

        Raises:
            RuntimeError: (from ``torch.stack``) if the kept hypotheses do not all
                have the same number of tokens; they cannot be packed without padding.
        """
        batch_size = len(self._beam_hyps)
        device = input_ids.device
        keep = self.num_beam_hyps_to_keep

        # finalize all open beam hypotheses and add to generated hypotheses;
        # the hypothesis container decides which ones it keeps
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_hyp.add(final_tokens, final_score)

        # retrieve the best hypotheses (highest score first) of every batch item
        best = []
        best_scores = torch.zeros(batch_size * keep,
                                  device=device,
                                  dtype=torch.float32)
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(keep):
                best_hyp_tuple = sorted_hyps.pop()
                best.append(best_hyp_tuple[1])
                best_scores[i * keep + j] = best_hyp_tuple[0]

        # ``stack`` fails loudly on length mismatches, whereas ``cat(...).view(n, -1)``
        # would silently split the concatenated tokens at the wrong boundaries.
        return {
            "sequences": torch.stack([hyp.reshape(-1) for hyp in best], dim=0),
            "sequence_scores": best_scores
        }
示例#5
0
    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        max_length: int,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> Tuple[torch.LongTensor]:
        """Close open beams (preferring constraint-satisfying ones) and pack the best hypotheses.

        For every batch item not yet done, each running beam whose tokens pass
        ``self.check_completes_constraints`` is added to the item's hypothesis
        container. If fewer than ``num_beam_hyps_to_keep`` beams qualify, the
        remaining beams are added in beam order until that many have been added.

        Args:
            input_ids: running sequences, indexed as ``batch_idx * num_beams + beam_id``.
            final_beam_scores: score of each running beam, same indexing.
            final_beam_tokens: not used here; kept for interface compatibility.
            final_beam_indices: not used here; kept for interface compatibility.
            max_length: upper bound on the output width (``None`` means unbounded).
            pad_token_id: required when the kept hypotheses differ in length.
            eos_token_id: written after each hypothesis if it fits.

        Returns:
            ``UserDict`` with ``"sequences"`` and ``"sequence_scores"``.
        """
        batch_size = len(self._beam_hyps)

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue

            # all open beam hypotheses that complete the constraints are added;
            # the hypothesis container keeps the best ones
            ids_collect = []
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]

                completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
                if completes_constraint:
                    beam_hyp.add(final_tokens, final_score)
                    ids_collect.append(beam_id)

            # due to overly complex constraints or other factors, sometimes we can't guarantee a successful
            # generation. In these cases we top up with the remaining beams, in beam order, until
            # enough hypotheses were added. Recording each added beam in ``ids_collect`` is what
            # lets the loop stop; without it every unconstrained beam would be added and could
            # displace the constraint-satisfying hypotheses from the result.
            if len(ids_collect) < self.num_beam_hyps_to_keep:
                for beam_id in range(self.num_beams):
                    if beam_id not in ids_collect:
                        batch_beam_idx = batch_idx * self.num_beams + beam_id
                        final_score = final_beam_scores[batch_beam_idx].item()
                        final_tokens = input_ids[batch_beam_idx]
                        beam_hyp.add(final_tokens, final_score)
                        ids_collect.append(beam_id)
                    if len(ids_collect) >= self.num_beam_hyps_to_keep:
                        break

        # select the best hypotheses
        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
        best = []
        best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

        # retrieve best hypotheses (highest score first)
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

                # append to lists
                best.append(best_hyp)
                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

        # prepare for adding eos
        sent_lengths_max = sent_lengths.max().item() + 1

        sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, hypo in enumerate(best):
            decoded[i, : sent_lengths[i]] = hypo
            if sent_lengths[i] < sent_max_len:
                decoded[i, sent_lengths[i]] = eos_token_id

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
            }
        )
示例#6
0
    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        max_length: int,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.LongTensor]:
        """Flush open beams, select the best hypotheses and pack them (plus their beam indices).

        Args:
            input_ids: running sequences, indexed as ``batch_idx * num_beams + beam_id``.
            final_beam_scores: score of each running beam, same indexing.
            final_beam_tokens: not used here; kept for interface compatibility.
            final_beam_indices: not used here; kept for interface compatibility.
            max_length: upper bound on the output width (``None`` means unbounded).
            pad_token_id: required when the kept hypotheses differ in length.
            eos_token_id: written after each hypothesis when it fits.
            beam_indices: optional per-beam index history, same indexing as ``input_ids``.

        Returns:
            ``UserDict`` with ``"sequences"``, ``"sequence_scores"`` and
            ``"beam_indices"`` (``None`` when no index history was stored; otherwise
            padded with ``-1``).
        """
        n_beams = self.num_beams
        per_item = self.num_beam_hyps_to_keep
        total_rows = len(self._beam_hyps) * per_item

        # every unfinished batch item offers all of its running beams to its store
        for item_idx, hyp_store in enumerate(self._beam_hyps):
            if self._done[item_idx]:
                continue
            first_row = item_idx * n_beams
            for row in range(first_row, first_row + n_beams):
                hyp_store.add(
                    input_ids[row],
                    final_beam_scores[row].item(),
                    beam_indices=None if beam_indices is None else beam_indices[row],
                )

        # gather the top ``per_item`` hypotheses of each item, best first
        sent_lengths = input_ids.new(total_rows)
        chosen_hyps = []
        chosen_indices = []
        best_scores = torch.zeros(total_rows, device=self.device, dtype=torch.float32)
        out_row = 0
        for hyp_store in self._beam_hyps:
            ranked = sorted(hyp_store.beams, key=lambda entry: entry[0])
            for _ in range(per_item):
                entry = ranked.pop()
                hyp_score, hyp_tokens, hyp_beam_idx = entry[0], entry[1], entry[2]
                sent_lengths[out_row] = len(hyp_tokens)
                chosen_hyps.append(hyp_tokens)
                chosen_indices.append(hyp_beam_idx)
                best_scores[out_row] = hyp_score
                out_row += 1

        # one spare column for EOS, capped by max_length when given
        widest = sent_lengths.max().item() + 1
        sent_max_len = widest if max_length is None else min(widest, max_length)
        decoded: torch.LongTensor = input_ids.new(total_rows, sent_max_len)

        indices = None
        if chosen_indices and chosen_indices[0] is not None:
            indices = input_ids.new(total_rows, sent_max_len)

        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        if indices is not None:
            indices.fill_(-1)

        for out_row, (hyp_tokens, hyp_beam_idx) in enumerate(zip(chosen_hyps, chosen_indices)):
            length = sent_lengths[out_row]
            decoded[out_row, :length] = hyp_tokens
            if indices is not None:
                indices[out_row, : len(hyp_beam_idx)] = torch.tensor(hyp_beam_idx)
            if length < sent_max_len:
                decoded[out_row, length] = eos_token_id

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
                "beam_indices": indices,
            }
        )
示例#7
0
    def compute_partial_decoded_loss(
        self,
        batch: 'Batch',
        latent: torch.Tensor,
        encoder_states: Tuple[torch.Tensor, ...],
        cand_vecs: torch.LongTensor,
        label_inds: torch.LongTensor,
    ) -> torch.Tensor:
        """
        Compute partial loss from decoding outputs.

        Here, we consider each partially decoded sequence as a separate
        item from which to compute multiobjective scores. For a batch item whose
        label has ``n`` non-null tokens, ``n`` examples are built; the ``j``-th
        keeps only the first ``j + 1`` decoder positions of ``latent`` (the rest
        are zero). Items are processed in strides of two.

        :param batch:
            batch being considered
        :param latent:
            decoder output representations, (bsz, seq_len, dim)
        :param encoder_states:
            encoder output representations
        :param cand_vecs:
            character candidate vectors; shared by the whole batch (2-D) or one
            set per batch item (3-D)
        :param label_inds:
            list of indices indicating which character is correct in the character candidates

        :return partial_loss:
            the concatenated losses of all partial examples, grouped by batch item
            in batch order (assumes ``multiobj_criterion`` returns one loss per
            example — TODO confirm). Per-item means and the scores of each item's
            lowest-loss partial example are passed to ``compute_multiobj_metrics``.
        """
        assert self.opt['multiobjective_latent_representation'] == 'decoder_final_layer'
        assert latent.dim() == 3 and latent.size(0) == cand_vecs.size(0)
        bsz, seq_len, dim = latent.size()
        seq_lens = []
        partial_char_losses = []
        seq_scores = []
        stride_length = 2
        for stride in range(0, bsz, stride_length):  # arbitrary stride for now
            # Compute new batches of items; latent reps, candidate vectors, etc.
            end_idx = min(stride + stride_length, bsz)
            # one partial example per non-null label token in this stride
            new_bsz = batch.label_vec[stride:end_idx].ne(self.NULL_IDX).sum().item()
            new_latent = latent.new(new_bsz, seq_len, dim).fill_(0)
            new_cand_vecs = cand_vecs.new(new_bsz, *cand_vecs.shape[1:]).fill_(
                self.NULL_IDX
            )
            if new_cand_vecs.dim() == 2:
                # shared 2-D candidates: give every partial example its own copy
                new_cand_vecs = new_cand_vecs.unsqueeze(1).repeat(
                    1, cand_vecs.size(0), 1
                )
            new_label_inds = label_inds[stride:end_idx].new(new_bsz).fill_(0)

            # For each batch item in the stride, we compute seq_length examples
            # where each example represents a partial output of the decoder.
            offset = 0
            for i in range(stride, end_idx):
                cand_vecs_i = cand_vecs if cand_vecs.dim() == 2 else cand_vecs[i]
                seq_len_i = batch.label_vec[i].ne(self.NULL_IDX).sum().item()
                seq_lens.append(seq_len_i)
                for j in range(seq_len_i):
                    new_latent[offset + j, 0 : j + 1, :] = latent[
                        i : i + 1, 0 : j + 1, :
                    ]
                new_cand_vecs[offset : offset + seq_len_i] = cand_vecs_i
                new_label_inds[offset : offset + seq_len_i] = label_inds[
                    i : i + 1
                ].repeat(seq_len_i)
                offset += seq_len_i

            # compare dtypes: ``isinstance(t, torch.LongTensor)`` is False for CUDA tensors
            assert new_cand_vecs.dtype == torch.long
            seq_score = self.get_multiobjective_output(
                new_latent, encoder_states, new_cand_vecs, 'partial'
            )
            partial_char_losses.append(
                self.multiobj_criterion(seq_score, new_label_inds)
            )
            seq_scores.append(seq_score)
        partial_char_loss = torch.cat(partial_char_losses, dim=0)
        seq_scores = torch.cat(seq_scores, dim=0)
        partial_char_loss_metric = partial_char_loss.new(bsz).fill_(0)
        partial_char_scores = torch.zeros(
            batch.batchsize,
            batch.batchsize if cand_vecs.dim() == 2 else cand_vecs.size(1),
        ).to(latent)
        # Item i owns the seq_lens[i] consecutive entries starting at ``offset``;
        # the offset must advance per item, and argmin (relative to the slice)
        # must be shifted back into the flat tensor.
        offset = 0
        for i in range(bsz):
            item_losses = partial_char_loss[offset : offset + seq_lens[i]]
            partial_char_loss_metric[i] = item_losses.mean()
            partial_char_scores[i] = seq_scores[offset + item_losses.argmin()]
            offset += seq_lens[i]
        self.compute_multiobj_metrics(
            partial_char_loss_metric, partial_char_scores, label_inds, prefix='partial'
        )
        return partial_char_loss
示例#8
0
    def forward(self, xs: torch.FloatTensor, ilens: torch.LongTensor,
                ys: torch.FloatTensor, olens: torch.LongTensor,
                ds: torch.FloatTensor, ps: torch.FloatTensor,
                es: torch.FloatTensor):
        """Compute the training loss and report its components.

        The total is the L1 feature loss (on both pre- and post-net outputs when a
        postnet exists) plus ``duration_criterion`` on durations and
        ``mse_criterion`` on pitch and energy. Each component's ``.item()`` value
        is sent to ``self.reporter``.

        Returns:
            The total loss tensor.
        """
        longest_in = max(ilens)
        longest_out = max(olens)

        # trim the padding that lies beyond the longest item of the batch
        xs, ds = xs[:, :longest_in], ds[:, :longest_in]
        ys = ys[:, :longest_out]
        ps = ps[:, :longest_out]
        es = es[:, :longest_out]

        # speaker embeddings are not used
        in_masks = make_non_pad_mask(ilens).to(xs.device)
        out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)

        (before_outs, after_outs, d_outs, p_outs, e_outs) = self._forward(
            xs, ilens, olens, ds, ps, es,
            in_masks=in_masks, out_masks=out_masks, is_inference=False)

        if self.reduction_factor > 1:
            # round every target length down to a multiple of the reduction factor
            # NOTE(review): out_masks above was built from the untrimmed olens and may be
            # longer than the trimmed ys here -- confirm shapes when reduction_factor > 1.
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            ys = ys[:, :max(olens)]

        if self.use_masking:
            # keep only the non-padded positions (token level for durations)
            d_outs, ds = d_outs.masked_select(in_masks), ds.masked_select(in_masks)
            before_outs = before_outs.masked_select(out_masks)
            after_outs = after_outs.masked_select(out_masks)
            ys = ys.masked_select(out_masks)
            p_outs, ps = p_outs.masked_select(out_masks), ps.masked_select(out_masks)
            e_outs, es = e_outs.masked_select(out_masks), es.masked_select(out_masks)

        l1_loss = F.l1_loss(after_outs, ys)
        if self.postnet is not None:
            l1_loss = l1_loss + F.l1_loss(before_outs, ys)
        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.mse_criterion(p_outs, ps)
        energy_loss = self.mse_criterion(e_outs, es)
        loss = l1_loss + duration_loss + pitch_loss + energy_loss

        # the reporter receives a list of single-entry dicts
        components = (
            ("l1_loss", l1_loss),
            ("duration_loss", duration_loss),
            ("pitch_loss", pitch_loss),
            ("energy_loss", energy_loss),
            ("loss", loss),
        )
        report_keys = [{name: value.item()} for name, value in components]
        if self.use_scaled_pos_enc:
            report_keys.append(
                {"encoder_alpha": self.encoder.embed[-1].alpha.data.item()})
            report_keys.append(
                {"decoder_alpha": self.decoder.embed[-1].alpha.data.item()})
        self.reporter.report(report_keys)
        return loss
    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        callback_handle: Optional = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor]:
        """Flush open beams, pick the top hypotheses and pack them into a padded batch.

        If ``callback_handle`` is given, it is called once per batch item with the
        item's hypotheses sorted by ascending score, the item index and
        ``**model_kwargs``. This happens before the best ones are popped from that
        same list.

        Args:
            input_ids: running sequences, indexed as ``batch_idx * num_beams + beam_id``.
            final_beam_scores: score of each running beam, same indexing.
            final_beam_tokens: not used here; kept for interface compatibility.
            final_beam_indices: not used here; kept for interface compatibility.
            pad_token_id: required when the kept hypotheses differ in length.
            eos_token_id: written after a hypothesis when it is shorter than
                ``self.max_length``.

        Returns:
            ``UserDict`` with ``"sequences"`` and ``"sequence_scores"``.
        """
        per_item = self.num_beam_hyps_to_keep
        total_rows = len(self._beam_hyps) * per_item

        # every unfinished batch item offers all of its running beams to its store
        for item_idx, hyp_store in enumerate(self._beam_hyps):
            if self._done[item_idx]:
                continue
            first_row = item_idx * self.num_beams
            for row in range(first_row, first_row + self.num_beams):
                hyp_store.add(input_ids[row], final_beam_scores[row].item())

        sent_lengths = input_ids.new(total_rows)
        chosen = []
        best_scores = torch.zeros(total_rows,
                                  device=self.device,
                                  dtype=torch.float32)

        for item_idx, hyp_store in enumerate(self._beam_hyps):
            ranked = sorted(hyp_store.beams, key=lambda entry: entry[0])
            if callback_handle is not None:
                callback_handle(ranked, item_idx, **model_kwargs)
            for rank in range(per_item):
                entry = ranked.pop()
                out_row = item_idx * per_item + rank
                sent_lengths[out_row] = len(entry[1])
                chosen.append(entry[1])
                best_scores[out_row] = entry[0]

        # one spare column for EOS, capped by self.max_length
        longest = sent_lengths.max().item()
        width = min(longest + 1, self.max_length)
        decoded: torch.LongTensor = input_ids.new(total_rows, width)
        if sent_lengths.min().item() != longest:
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        for out_row, hyp in enumerate(chosen):
            length = sent_lengths[out_row]
            decoded[out_row, :length] = hyp
            if length < self.max_length:
                decoded[out_row, length] = eos_token_id
        return UserDict({
            "sequences": decoded,
            "sequence_scores": best_scores,
        })
示例#10
0
    def forward(self, input_seq: torch.LongTensor,
                target: torch.LongTensor) -> torch.FloatTensor:
        """Runs the Transformer.

        The Transformer expects both an input as well as a target sequence to be provided, and yields a probability
        distribution over all possible output tokens for each position in the target sequence.

        Args:
            input_seq (torch.LongTensor): The input sequence as (batch-size x input-seq-len)-tensor.
            target (torch.LongTensor): The target sequence as (batch-size x target-seq-len)-tensor.

        Returns:
            torch.FloatTensor: The computed probabilities for each position in ``target`` as a
                (batch-size x target-seq-len x output-size)-tensor.

        Raises:
            TypeError: If ``input_seq`` or ``target`` is not a tensor of dtype ``torch.int64``.
            ValueError: If ``input_seq`` or ``target`` is not 2-dimensional.
        """
        # sanitize args -- check the dtype instead of the legacy ``torch.LongTensor`` /
        # ``torch.cuda.LongTensor`` classes, which reject int64 tensors on any other device
        if not isinstance(input_seq, torch.Tensor) or input_seq.dtype != torch.int64:
            raise TypeError("<input_seq> has to be a LongTensor!")
        if input_seq.dim() != 2:
            raise ValueError("<input_seq> has to have 2 dimensions!")
        if not isinstance(target, torch.Tensor) or target.dtype != torch.int64:
            raise TypeError("<target> has to be a LongTensor!")
        if target.dim() != 2:
            raise ValueError("<target> has to have 2 dimensions!")

        # position indices 0..len-1 for every sample, used to look up the positional embeddings
        index_seq = torch.arange(
            input_seq.size(1), device=input_seq.device
        ).unsqueeze(0).expand(input_seq.size(0), -1)

        # create padding mask for input
        padding_mask = util.create_padding_mask(input_seq, self._pad_index)

        # embed the provided input
        input_seq = self._word_emb(input_seq) + self._positional_emb(index_seq)

        # project input to the needed size
        input_seq = self._input_projection(input_seq)

        # run the encoder
        input_seq = self._encoder(input_seq, padding_mask=padding_mask)

        # position indices for the targets
        index_seq = torch.arange(
            target.size(1), device=target.device
        ).unsqueeze(0).expand(target.size(0), -1)

        # embed the provided targets
        target = self._word_emb(target) + self._positional_emb(index_seq)

        # project target to the needed size
        target = self._input_projection(target)

        # run the decoder
        output = self._decoder(input_seq, target, padding_mask=padding_mask)

        # project output to the needed size
        output = self._output_projection(output)

        # compute softmax over the output vocabulary
        return functional.softmax(output, dim=2)
示例#11
0
    def forward(self, batch: torch.LongTensor) -> torch.FloatTensor:
        """Computes the masked-language-model loss for a batch.

        For every sample, ``ceil(seq_len * self._prediction_rate)`` of its first ``seq_len`` positions are picked at
        random, where ``seq_len`` is the sample's number of non-padding tokens. Of those picked positions,
        ``floor(num_pred * self._mask_rate)`` are replaced with ``self._mask_index``, up to
        ``ceil(num_pred * self._random_rate)`` further ones are replaced with a uniformly drawn token id, and the
        rest are left unchanged. The corrupted batch is embedded (word + positional embeddings) and encoded by
        ``self._model``. ``self._loss`` is then computed between ``self._output_layer``'s predictions and the
        original tokens at all picked positions.

        Args:
            batch (torch.LongTensor): A batch of training data, as (batch-size x max-seq-len)-tensor.

        Returns:
            torch.FloatTensor: The computed loss.

        Raises:
            TypeError: If ``batch`` does not have dtype ``torch.int64``.
            ValueError: If ``batch`` is not 2-dimensional.
        """
        # sanitize args
        insanity.sanitize_type("batch", batch, torch.Tensor)
        if batch.dtype != torch.int64:
            raise TypeError("<batch> has to be a LongTensor!")
        if batch.dim() != 2:
            raise ValueError("<batch> has to be a 2d tensor!")
        
        # create the padding mask to use
        padding_mask = util.create_padding_mask(batch, self._pad_index)
        
        # create a tensor of position indices (0..max-seq-len-1 per sample) used to retrieve the positional embeddings
        index_seq = batch.new(range(batch.size(1))).unsqueeze(0).expand(batch.size(0), -1)
        
        # compute the number of non-padding tokens of every sample in the batch
        seq_len = (batch != self._pad_index).sum(dim=1).cpu().numpy().tolist()
        
        # randomly choose the tokens to compute predictions for (0/1 masks with the same shape as the batch)
        pred_mask = padding_mask.new(*batch.size()).zero_().long()  # all tokens being predicted
        mask_mask = padding_mask.new(*batch.size()).zero_().long()  # tokens replaced with <MASK>
        random_mask = padding_mask.new(*batch.size()).zero_().long()  # tokens replaced with random tokens
        for sample_idx, sample_len in enumerate(seq_len):  # iterate over all samples in the batch
            
            # determine how many tokens to compute predictions for
            num_pred = int(math.ceil(sample_len * self._prediction_rate))  # num of tokens predictions are computed for
            num_mask = int(math.floor(num_pred * self._mask_rate))  # num of tokens replaced with <MASK>
            num_random = int(math.ceil(num_pred * self._random_rate))  # num of tokens randomly replaced
            
            # randomly select indices to compute predictions for
            # NOTE(review): candidates are positions 0..sample_len-1, which are the non-padding positions only if
            # padding is trailing -- confirm against how batches are built
            pred_indices = list(range(sample_len))
            random.shuffle(pred_indices)
            pred_indices = pred_indices[:num_pred]
            
            # prepare the <MASK>-mask
            for token_idx in pred_indices[:num_mask]:
                pred_mask[sample_idx, token_idx] = 1
                mask_mask[sample_idx, token_idx] = 1
            
            # prepare the random-mask (the slice simply yields fewer tokens if num_mask + num_random > num_pred)
            for token_idx in pred_indices[num_mask:(num_mask + num_random)]:
                pred_mask[sample_idx, token_idx] = 1
                random_mask[sample_idx, token_idx] = 1
            
            # remaining tokens that predictions are computed for are left untouched
            for token_idx in pred_indices[(num_mask + num_random):]:
                pred_mask[sample_idx, token_idx] = 1
        
        # replace predicted tokens in the batch appropriately: keep the original token, or use <MASK>, or a random
        # id in [0, num_embeddings) (the two masks are disjoint, being different slices of pred_indices)
        masked_batch = (
                batch * (1 - mask_mask) * (1 - random_mask) +
                mask_mask * batch.new(*batch.size()).fill_(self._mask_index) +
                random_mask * (batch.new(*batch.size()).double().uniform_() * self._word_emb.num_embeddings).long()
        )
        
        # embed the batch
        masked_batch = self._word_emb(masked_batch) + self._pos_emb(index_seq)
        
        # encode sequence in the batch using BERT
        enc = self._model(masked_batch, padding_mask)
        
        # turn encodings, the target token indices (that we seek to predict), and the prediction mask, into matrices,
        # such that each row corresponds with one token
        enc = enc.view(enc.size(0) * enc.size(1), enc.size(2))
        target = batch.view(-1)
        pred_mask = pred_mask.view(-1)
        
        # turn the prediction mask into a tensor of flat indices (to select below)
        pred_mask = pred_mask.new(np.where(pred_mask.detach().cpu().numpy())[0])
        
        # fetch embeddings and target values of those tokens that are being predicted
        enc = enc.index_select(0, pred_mask)
        target = target.index_select(0, pred_mask)
        
        # compute predictions for each encoded token + the according loss (targets are the uncorrupted tokens)
        pred = self._output_layer(enc)
        loss = self._loss(pred, target)
        
        return loss
示例#12
0
    def finalize(self,
                 input_ids: torch.LongTensor,
                 final_beam_scores: torch.FloatTensor,
                 final_beam_tokens: torch.LongTensor,
                 final_beam_indices: torch.LongTensor,
                 pad_token_id: Optional[int] = None,
                 eos_token_id: Optional[int] = None,
                 mems=None) -> Tuple[torch.LongTensor, Optional[List[torch.Tensor]], torch.FloatTensor]:
        """Close open beams and return the best hypotheses, their memories and scores.

        Every beam of a batch item not yet marked done is offered to that item's
        hypothesis container together with its slice of each tensor in ``mems``
        (``mem[[row]]``, which keeps the leading dimension).

        Unlike sibling variants, no extra column is reserved for EOS: the output
        width is ``min(longest hypothesis, self.max_length)``. EOS is written only
        after hypotheses shorter than that width.

        Args:
            input_ids: running sequences, indexed as ``batch_idx * num_beams + beam_id``.
            final_beam_scores: score of each running beam, same indexing.
            final_beam_tokens: not used here; kept for interface compatibility.
            final_beam_indices: not used here; kept for interface compatibility.
            pad_token_id: required when the kept hypotheses differ in length.
            eos_token_id: written after hypotheses shorter than the output width.
            mems: optional list of per-beam memory tensors, same row indexing.

        Returns:
            ``(decoded, mems, scores)``. ``mems`` is, per memory position, the
            concatenation (dim 0) of the kept hypotheses' memories, or ``None`` when
            no memories were stored.
        """
        batch_size = len(self._beam_hyps)
        keep = self.num_beam_hyps_to_keep

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue

            # need to add best num_beams hypotheses to generated hyps
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                # list-index so every memory keeps its leading dimension
                beam_mems = [mem[[batch_beam_idx]]
                             for mem in mems] if mems else None
                beam_hyp.add(final_tokens, final_score, mems=beam_mems)

        # retrieve the best hypotheses (highest score first) of every batch item
        sent_lengths = input_ids.new(batch_size * keep)
        best = []
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(keep):
                score, best_hyp, hyp_mems = sorted_hyps.pop()
                sent_lengths[keep * i + j] = len(best_hyp)
                best.append((best_hyp, hyp_mems, score))

        # prepare the output (no extra EOS column is reserved here)
        sent_max_len = min(sent_lengths.max().item(), self.max_length)
        decoded: torch.LongTensor = input_ids.new(batch_size * keep, sent_max_len)
        scores = final_beam_scores.new(batch_size * keep)
        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        # fill with hypotheses and eos_token_id if the latter fits in
        best_mems = []
        for i, (hypo, hyp_mems, score) in enumerate(best):
            scores[i] = score
            decoded[i, :sent_lengths[i]] = hypo
            if sent_lengths[i] < sent_max_len:
                decoded[i, sent_lengths[i]] = eos_token_id
            best_mems.append(hyp_mems)
        merged_mems = [
            torch.cat([mem[i] for mem in best_mems], dim=0)
            for i in range(len(best_mems[0]))
        ] if best_mems and best_mems[0] else None
        return decoded, merged_mems, scores
示例#13
0
def sample_output(
        model: transformer.Transformer,
        input_seq: torch.LongTensor,
        eos_index: int,
        pad_index: int,
        max_len: int
) -> torch.LongTensor:
    """Samples an output sequence based on the provided input.

    Generation proceeds one position at a time. The model is run on ``input_seq`` together with the
    output generated so far, plus a dummy placeholder for the position to predict. The next token of
    every sample is then drawn with :func:`torch.multinomial` from the model's output at the last
    position. Once a sample has produced ``eos_index``, all of its later positions are filled with
    ``pad_index``. Generation stops as soon as every sample has finished, or after ``max_len``
    positions.

    The model is put into evaluation mode, and gradients are disabled while sampling. Its original
    train/eval mode is restored afterwards, even if an exception is raised.

    Args:
        model (:class:`transformer.Transformer`): The model to use. Its output at the last target
            position is passed straight to :func:`torch.multinomial`, so it is expected to contain
            non-negative sampling weights (presumably probabilities -- confirm against the model).
        input_seq (torch.LongTensor): The input sequence to be provided to the model. This has to be a
            (batch-size x input-seq-len)-tensor with dtype ``torch.long`` (on any device).
        eos_index (int): The index that indicates the end of a sequence.
        pad_index (int): The index that indicates a padding token in a sequence.
        max_len (int or None): The maximum length of the generated output. If ``None``, generation
            continues until every sample has produced ``eos_index``. Note that this never terminates
            if the model never samples ``eos_index``.

    Returns:
        torch.LongTensor: The generated output sequence as (batch-size x output-seq-len)-tensor, where
            output-seq-len is at most ``max_len``.

    Raises:
        TypeError: If any argument has the wrong type.
        ValueError: If ``input_seq`` is not a matrix, ``eos_index``/``pad_index`` lie outside
            ``[0, model.output_size)``, or ``max_len`` is smaller than 1.
    """
    # sanitize args
    if not isinstance(model, transformer.Transformer):
        raise TypeError("The <model> has to be a transformer.Transformer!")
    # check the dtype rather than the legacy CPU/CUDA tensor classes, so that int64 tensors on any
    # device are accepted
    if not isinstance(input_seq, torch.Tensor) or input_seq.dtype != torch.long:
        raise TypeError("The <input_seq> has to be a LongTensor!")
    if input_seq.dim() != 2:
        raise ValueError("<input_seq> has to be a matrix!")
    if not isinstance(eos_index, int):
        raise TypeError("The <eos_index> has to be an integer!")
    if eos_index < 0 or eos_index >= model.output_size:
        raise ValueError("The <eos_index> is not a legal index in the vocabulary used by <model>!")
    if not isinstance(pad_index, int):
        raise TypeError("The <pad_index> has to be an integer!")
    if pad_index < 0 or pad_index >= model.output_size:
        raise ValueError("The <pad_index> is not a legal index in the vocabulary used by <model>!")
    if max_len is not None:
        if not isinstance(max_len, int):
            raise TypeError("<max_len> has to be an integer!")
        if max_len < 1:
            raise ValueError("<max_len> has to be > 0!")

    original_mode = model.training  # the original mode (train/eval) of the provided model
    batch_size = input_seq.size(0)  # number of samples in the provided input sequence

    output_seq = []  # used to store the generated outputs (batch-size x 1 each) for each position
    # finished[i] becomes True once sample i has produced eos_index.
    # It is kept on input_seq's device, since the generated outputs are concatenated with
    # input_seq-derived tensors below anyway.
    finished = torch.zeros(batch_size, dtype=torch.bool, device=input_seq.device)

    # put model in evaluation mode; the finally clause guarantees the original mode is restored
    model.eval()
    try:
        # sampling needs no gradients, and this avoids building an autograd graph for every step
        with torch.no_grad():
            step = 0
            while max_len is None or step < max_len:
                step += 1

                # prepare the target to provide to the model
                # this is the current output with an additional final entry that is supposed to be
                # predicted next (which is why the concrete value does not matter)
                current_target = torch.cat(output_seq + [input_seq.new_zeros(batch_size, 1)], dim=1)

                # run the model and keep only its output for the last (to-be-predicted) position
                probs = model(input_seq, current_target)[:, -1, :]

                # sample next output from the computed probabilities
                output = torch.multinomial(probs, 1)

                # samples that finished at an earlier step only receive padding from now on ...
                output.masked_fill_(finished.unsqueeze(1), pad_index)
                # ... and samples that just produced EOS are marked as finished
                finished |= output[:, 0] == eos_index

                # store created output
                output_seq.append(output)

                # check whether generation has been finished
                if finished.all():
                    break
    finally:
        # restore original mode of the model, even if sampling failed
        model.train(mode=original_mode)

    return torch.cat(output_seq, dim=1)