def finalize(
    self,
    input_ids: torch.LongTensor,
    final_beam_scores: torch.FloatTensor,
    final_beam_tokens: torch.LongTensor,
    final_beam_indices: torch.LongTensor,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
) -> torch.LongTensor:
    batch_size = len(self._beam_hyps)

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            continue

        # all open beam hypotheses are added to the beam hypothesis
        # beam hypothesis class automatically keeps the best beams
        for beam_id in range(self.num_beams):
            batch_beam_idx = batch_idx * self.num_beams + beam_id
            final_score = final_beam_scores[batch_beam_idx].item()
            final_tokens = input_ids[batch_beam_idx]
            beam_hyp.add(final_tokens, final_score)

    # select the best hypotheses
    sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
    best = []

    # retrieve best hypotheses
    for i, beam_hyp in enumerate(self._beam_hyps):
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
        for j in range(self.num_beam_hyps_to_keep):
            best_hyp = sorted_hyps.pop()[1]
            sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
            best.append(best_hyp)

    # prepare for adding eos
    sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
    decoded: torch.LongTensor = input_ids.new(
        batch_size * self.num_beam_hyps_to_keep, sent_max_len)

    # shorter batches are padded if needed
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`pad_token_id` has to be defined"
        decoded.fill_(pad_token_id)

    # fill with hypotheses and eos_token_id if the latter fits in
    for i, hypo in enumerate(best):
        decoded[i, :sent_lengths[i]] = hypo
        if sent_lengths[i] < self.max_length:
            decoded[i, sent_lengths[i]] = eos_token_id

    return decoded
def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
    # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
    # Check if target is big enough to cover prediction (including start/end symbols)
    if len(predicted) > targets.size(1):
        return 0
    predicted_tensor = targets.new(predicted)
    targets_trimmed = targets[:, :len(predicted)]
    # Return 1 if the predicted sequence is anywhere in the list of targets.
    return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0])
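# --- Hedged usage sketch (not part of the original snippet) ---
# Illustrates what _action_history_match computes: a truthy value when the predicted
# action sequence matches the prefix of some row of the padded `targets` tensor, and
# 0 otherwise. The tensor values below are made up for illustration; the return dtype
# (uint8/bool) depends on the PyTorch version, since it comes straight from eq/min/max.
import torch

example_targets = torch.tensor([[1, 2, 3, 0],
                                [1, 2, 4, 0]])
print(_action_history_match([1, 2, 3], example_targets))        # prefix of row 0 -> truthy
print(_action_history_match([7, 8], example_targets))           # matches no row -> falsy
print(_action_history_match([1, 2, 3, 0, 5], example_targets))  # longer than targets -> 0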
def _forward(self,
             xs: torch.FloatTensor,
             ilens: torch.LongTensor,
             olens: torch.LongTensor = None,
             ds: torch.LongTensor = None,
             ps: torch.FloatTensor = None,
             es: torch.FloatTensor = None,
             in_masks: torch.LongTensor = None,
             out_masks: torch.LongTensor = None,
             is_inference: bool = False):
    x_masks = self._source_mask(ilens)
    hs, _ = self.encoder.forward(xs, x_masks)  # ignore spk embedding

    d_masks = ~in_masks if in_masks is not None else None
    v_masks = ~out_masks if out_masks is not None else None

    if is_inference:
        hs, d_outs, p_outs, e_outs = self.variance_adaptor.inference(
            hs, ilens, d_masks, v_masks)
    else:
        hs, d_outs, p_outs, e_outs = self.variance_adaptor.forward(
            hs, ds, ilens, ps, es, d_masks, v_masks)

    # forward decoder
    if olens is not None:
        if self.reduction_factor > 1:
            olens_in = olens.new(
                [olen // self.reduction_factor for olen in olens])
        else:
            olens_in = olens
        h_masks = self._source_mask(olens_in)
    else:
        h_masks = None
    zs, _ = self.decoder.forward(hs, h_masks)
    before_outs = self.feat_out.forward(zs).view(zs.shape[0], -1, self.odim)

    # postnet
    if self.postnet is None:
        after_outs = before_outs
    else:
        after_outs = before_outs + self.postnet(
            before_outs.transpose(1, 2)).transpose(1, 2)

    if is_inference:
        return before_outs, after_outs
    else:
        return before_outs, after_outs, d_outs, p_outs, e_outs
def finalize(self, input_ids: torch.LongTensor, final_beam_scores: torch.FloatTensor):
    batch_size = len(self._beam_hyps)
    device = input_ids.device

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        # all open beam hypotheses are added to the beam hypothesis
        # beam hypothesis class automatically keeps the best beams
        for beam_id in range(self.num_beams):
            batch_beam_idx = batch_idx * self.num_beams + beam_id
            final_score = final_beam_scores[batch_beam_idx].item()
            final_tokens = input_ids[batch_beam_idx]
            beam_hyp.add(final_tokens, final_score)

    # select the best hypotheses
    sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
    best = []
    best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep,
                              device=device,
                              dtype=torch.float32)

    # retrieve best hypotheses
    for i, beam_hyp in enumerate(self._beam_hyps):
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
        for j in range(self.num_beam_hyps_to_keep):
            best_hyp_tuple = sorted_hyps.pop()
            best_score = best_hyp_tuple[0]
            best_hyp = best_hyp_tuple[1]
            sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

            # append to lists
            best.append(best_hyp)
            best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

    return {
        "sequences": torch.cat(best, dim=0).view(len(best), -1),
        "sequence_scores": best_scores
    }
def finalize(
    self,
    input_ids: torch.LongTensor,
    final_beam_scores: torch.FloatTensor,
    final_beam_tokens: torch.LongTensor,
    final_beam_indices: torch.LongTensor,
    max_length: int,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
) -> Tuple[torch.LongTensor]:
    batch_size = len(self._beam_hyps)

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            continue

        # all open beam hypotheses are added to the beam hypothesis
        # beam hypothesis class automatically keeps the best beams
        ids_collect = []
        for beam_id in range(self.num_beams):
            batch_beam_idx = batch_idx * self.num_beams + beam_id
            final_score = final_beam_scores[batch_beam_idx].item()
            final_tokens = input_ids[batch_beam_idx]

            completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
            if completes_constraint:
                beam_hyp.add(final_tokens, final_score)
                ids_collect.append(beam_id)

        # due to overly complex constraints or other factors, sometimes we can't guarantee a successful
        # generation. In these cases we simply return the highest scoring outputs.
        if len(ids_collect) < self.num_beam_hyps_to_keep:
            for beam_id in range(self.num_beams):
                if beam_id not in ids_collect:
                    batch_beam_idx = batch_idx * self.num_beams + beam_id
                    final_score = final_beam_scores[batch_beam_idx].item()
                    final_tokens = input_ids[batch_beam_idx]
                    beam_hyp.add(final_tokens, final_score)
                if len(ids_collect) >= self.num_beam_hyps_to_keep:
                    break

    # select the best hypotheses
    sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
    best = []
    best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep,
                              device=self.device,
                              dtype=torch.float32)

    # retrieve best hypotheses
    for i, beam_hyp in enumerate(self._beam_hyps):
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
        for j in range(self.num_beam_hyps_to_keep):
            best_hyp_tuple = sorted_hyps.pop()
            best_score = best_hyp_tuple[0]
            best_hyp = best_hyp_tuple[1]
            sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

            # append to lists
            best.append(best_hyp)
            best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

    # prepare for adding eos
    sent_lengths_max = sent_lengths.max().item() + 1
    sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
    decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

    # shorter batches are padded if needed
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`pad_token_id` has to be defined"
        decoded.fill_(pad_token_id)

    # fill with hypotheses and eos_token_id if the latter fits in
    for i, hypo in enumerate(best):
        decoded[i, :sent_lengths[i]] = hypo
        if sent_lengths[i] < sent_max_len:
            decoded[i, sent_lengths[i]] = eos_token_id

    return UserDict(
        {
            "sequences": decoded,
            "sequence_scores": best_scores,
        }
    )
def finalize(
    self,
    input_ids: torch.LongTensor,
    final_beam_scores: torch.FloatTensor,
    final_beam_tokens: torch.LongTensor,
    final_beam_indices: torch.LongTensor,
    max_length: int,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor]:
    batch_size = len(self._beam_hyps)

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            continue

        # all open beam hypotheses are added to the beam hypothesis
        # beam hypothesis class automatically keeps the best beams
        for beam_id in range(self.num_beams):
            batch_beam_idx = batch_idx * self.num_beams + beam_id
            final_score = final_beam_scores[batch_beam_idx].item()
            final_tokens = input_ids[batch_beam_idx]
            beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
            beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

    # select the best hypotheses
    sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
    best = []
    best_indices = []
    best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep,
                              device=self.device,
                              dtype=torch.float32)

    # retrieve best hypotheses
    for i, beam_hyp in enumerate(self._beam_hyps):
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
        for j in range(self.num_beam_hyps_to_keep):
            best_hyp_tuple = sorted_hyps.pop()
            best_score = best_hyp_tuple[0]
            best_hyp = best_hyp_tuple[1]
            best_index = best_hyp_tuple[2]
            sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

            # append hyp to lists
            best.append(best_hyp)

            # append indices to list
            best_indices.append(best_index)

            best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

    # prepare for adding eos
    sent_lengths_max = sent_lengths.max().item() + 1
    sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
    decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

    if len(best_indices) > 0 and best_indices[0] is not None:
        indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
    else:
        indices = None

    # shorter batches are padded if needed
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`pad_token_id` has to be defined"
        decoded.fill_(pad_token_id)

        if indices is not None:
            indices.fill_(-1)

    # fill with hypotheses and eos_token_id if the latter fits in
    for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
        decoded[i, :sent_lengths[i]] = hypo

        if indices is not None:
            indices[i, :len(best_idx)] = torch.tensor(best_idx)

        if sent_lengths[i] < sent_max_len:
            decoded[i, sent_lengths[i]] = eos_token_id

    return UserDict(
        {
            "sequences": decoded,
            "sequence_scores": best_scores,
            "beam_indices": indices,
        }
    )
def compute_partial_decoded_loss(
    self,
    batch: Batch,
    latent: torch.Tensor,
    encoder_states: Tuple[torch.Tensor, ...],
    cand_vecs: torch.LongTensor,
    label_inds: torch.LongTensor,
) -> torch.Tensor:
    """
    Compute partial loss from decoding outputs.

    Here, we consider each partially decoded sequence as a separate item from which
    to compute multiobjective scores.

    :param batch:
        batch being considered
    :param latent:
        decoder output representations
    :param encoder_states:
        encoder output representations
    :param cand_vecs:
        character candidate vectors
    :param label_inds:
        list of indices indicating which character is correct in the character candidates

    :return partial_loss:
        return loss for each batch item as a sum of the partial losses.
    """
    assert self.opt['multiobjective_latent_representation'] == 'decoder_final_layer'
    assert latent.dim() == 3 and latent.size(0) == cand_vecs.size(0)
    bsz, seq_len, dim = latent.size()
    seq_lens = []
    partial_char_losses = []
    seq_scores = []
    stride_length = 2
    for stride in range(0, bsz, stride_length):  # arbitrary stride for now
        # Compute new batches of items; latent reps, candidate vectors, etc.
        end_idx = min(stride + stride_length, bsz)
        new_bsz = batch.label_vec[stride:end_idx].ne(self.NULL_IDX).sum().item()
        new_latent = latent.new(new_bsz, seq_len, dim).fill_(0)
        new_cand_vecs = cand_vecs.new(new_bsz, *cand_vecs.shape[1:]).fill_(
            self.NULL_IDX
        )
        if new_cand_vecs.dim() == 2:
            new_cand_vecs = new_cand_vecs.unsqueeze(1).repeat(
                1, cand_vecs.size(0), 1
            )
        new_label_inds = label_inds[stride:end_idx].new(new_bsz).fill_(0)

        # For each batch item in the stride, we compute seq_length examples
        # where each example represents a partial output of the decoder.
        offset = 0
        for i in range(stride, end_idx):
            cand_vecs_i = cand_vecs if cand_vecs.dim() == 2 else cand_vecs[i]
            seq_len_i = batch.label_vec[i].ne(self.NULL_IDX).sum().item()
            seq_lens.append(seq_len_i)
            for j in range(seq_len_i):
                new_latent[offset + j, 0:j + 1, :] = latent[i:i + 1, 0:j + 1, :]
            new_cand_vecs[offset:offset + seq_len_i] = cand_vecs_i
            new_label_inds[offset:offset + seq_len_i] = label_inds[i:i + 1].repeat(
                seq_len_i
            )
            offset += seq_len_i

        assert isinstance(new_cand_vecs, torch.LongTensor)
        seq_score = self.get_multiobjective_output(
            new_latent, encoder_states, new_cand_vecs, 'partial'
        )
        partial_char_losses.append(
            self.multiobj_criterion(seq_score, new_label_inds)
        )
        seq_scores.append(seq_score)
    partial_char_loss = torch.cat(partial_char_losses, dim=0)
    seq_scores = torch.cat(seq_scores, dim=0)
    partial_char_loss_metric = partial_char_loss.new(bsz).fill_(0)
    offset = 0
    partial_char_scores = torch.zeros(
        batch.batchsize,
        batch.batchsize if cand_vecs.dim() == 2 else cand_vecs.size(1),
    ).to(latent)
    for i in range(bsz):
        partial_char_loss_metric[i] = partial_char_loss[
            offset:offset + seq_lens[i]
        ].mean()
        partial_char_scores[i] = seq_scores[
            partial_char_loss[offset:offset + seq_lens[i]].argmin()
        ]
    self.compute_multiobj_metrics(
        partial_char_loss_metric, partial_char_scores, label_inds, prefix='partial'
    )
    return partial_char_loss
def forward(self, xs: torch.FloatTensor, ilens: torch.LongTensor,
            ys: torch.FloatTensor, olens: torch.LongTensor,
            ds: torch.FloatTensor, ps: torch.FloatTensor,
            es: torch.FloatTensor):
    # rm padded part
    xs = xs[:, :max(ilens)]
    ys = ys[:, :max(olens)]
    ds = ds[:, :max(ilens)]
    ps = ps[:, :max(olens)]
    es = es[:, :max(olens)]
    in_masks = make_non_pad_mask(ilens).to(xs.device)
    out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)

    # ignore spk embedding
    before_outs, after_outs, d_outs, p_outs, e_outs = \
        self._forward(xs, ilens, olens, ds, ps, es,
                      in_masks=in_masks, out_masks=out_masks,
                      is_inference=False)

    if self.reduction_factor > 1:
        olens = olens.new(
            [olen - olen % self.reduction_factor for olen in olens])
        max_olen = max(olens)
        ys = ys[:, :max_olen]

    if self.use_masking:
        d_outs = d_outs.masked_select(in_masks)
        ds = ds.masked_select(in_masks)
        before_outs = before_outs.masked_select(out_masks)
        after_outs = after_outs.masked_select(out_masks)
        ys = ys.masked_select(out_masks)
        p_outs = p_outs.masked_select(out_masks)
        e_outs = e_outs.masked_select(out_masks)
        ps = ps.masked_select(out_masks)
        es = es.masked_select(out_masks)

    # calculate loss
    if self.postnet is None:
        l1_loss = F.l1_loss(after_outs, ys)
    else:
        l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
    duration_loss = self.duration_criterion(d_outs, ds)
    pitch_loss = self.mse_criterion(p_outs, ps)
    energy_loss = self.mse_criterion(e_outs, es)
    loss = l1_loss + duration_loss + pitch_loss + energy_loss

    # report loss
    report_keys = [
        {"l1_loss": l1_loss.item()},
        {"duration_loss": duration_loss.item()},
        {"pitch_loss": pitch_loss.item()},
        {"energy_loss": energy_loss.item()},
        {"loss": loss.item()},
    ]
    if self.use_scaled_pos_enc:
        report_keys += [
            {"encoder_alpha": self.encoder.embed[-1].alpha.data.item()},
            {"decoder_alpha": self.decoder.embed[-1].alpha.data.item()},
        ]
    self.reporter.report(report_keys)

    return loss
def finalize(
    self,
    input_ids: torch.LongTensor,
    final_beam_scores: torch.FloatTensor,
    final_beam_tokens: torch.LongTensor,
    final_beam_indices: torch.LongTensor,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    callback_handle: Optional = None,
    **model_kwargs,
) -> Tuple[torch.LongTensor]:
    batch_size = len(self._beam_hyps)

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            continue

        # all open beam hypotheses are added to the beam hypothesis
        # beam hypothesis class automatically keeps the best beams
        for beam_id in range(self.num_beams):
            batch_beam_idx = batch_idx * self.num_beams + beam_id
            final_score = final_beam_scores[batch_beam_idx].item()
            final_tokens = input_ids[batch_beam_idx]
            beam_hyp.add(final_tokens, final_score)

    # select the best hypotheses
    sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
    best = []
    best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep,
                              device=self.device,
                              dtype=torch.float32)

    # retrieve best hypotheses
    for i, beam_hyp in enumerate(self._beam_hyps):
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
        if callback_handle is not None:
            callback_handle(sorted_hyps, i, **model_kwargs)
        for j in range(self.num_beam_hyps_to_keep):
            best_hyp_tuple = sorted_hyps.pop()
            best_score = best_hyp_tuple[0]
            best_hyp = best_hyp_tuple[1]
            sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

            # append to lists
            best.append(best_hyp)
            best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

    # prepare for adding eos
    sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
    decoded: torch.LongTensor = input_ids.new(
        batch_size * self.num_beam_hyps_to_keep, sent_max_len)

    # shorter batches are padded if needed
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`pad_token_id` has to be defined"
        decoded.fill_(pad_token_id)

    # fill with hypotheses and eos_token_id if the latter fits in
    for i, hypo in enumerate(best):
        decoded[i, :sent_lengths[i]] = hypo
        if sent_lengths[i] < self.max_length:
            decoded[i, sent_lengths[i]] = eos_token_id

    return UserDict({
        "sequences": decoded,
        "sequence_scores": best_scores,
    })
def forward(self, input_seq: torch.LongTensor, target: torch.LongTensor) -> torch.FloatTensor:
    """Runs the Transformer.

    The Transformer expects both an input as well as a target sequence to be provided, and yields a probability
    distribution over all possible output tokens for each position in the target sequence.

    Args:
        input_seq (torch.LongTensor): The input sequence as (batch-size x input-seq-len)-tensor.
        target (torch.LongTensor): The target sequence as (batch-size x target-seq-len)-tensor.

    Returns:
        torch.FloatTensor: The computed probabilities for each position in ``target`` as a
            (batch-size x target-seq-len x output-size)-tensor.
    """
    # sanitize args
    if not isinstance(input_seq, torch.LongTensor) and not isinstance(input_seq, torch.cuda.LongTensor):
        raise TypeError("<input_seq> has to be a LongTensor!")
    if input_seq.dim() != 2:
        raise ValueError("<input_seq> has to have 2 dimensions!")
    if not isinstance(target, torch.LongTensor) and not isinstance(target, torch.cuda.LongTensor):
        raise TypeError("<target> has to be a LongTensor!")
    if target.dim() != 2:
        raise ValueError("<target> has to have 2 dimensions!")

    # create a tensor of indices, which is used to retrieve the according positional embeddings below
    index_seq = input_seq.new(range(input_seq.size(1))).unsqueeze(0).expand(input_seq.size(0), -1)

    # create padding mask for input
    padding_mask = util.create_padding_mask(input_seq, self._pad_index)

    # embed the provided input
    input_seq = self._word_emb(input_seq) + self._positional_emb(index_seq)

    # project input to the needed size
    input_seq = self._input_projection(input_seq)

    # run the encoder
    input_seq = self._encoder(input_seq, padding_mask=padding_mask)

    # create a tensor of indices, which is used to retrieve the positional embeddings for the targets below
    index_seq = target.new(range(target.size(1))).unsqueeze(0).expand(target.size(0), -1)

    # embed the provided targets
    target = self._word_emb(target) + self._positional_emb(index_seq)

    # project target to the needed size
    target = self._input_projection(target)

    # run the decoder
    output = self._decoder(input_seq, target, padding_mask=padding_mask)

    # project output to the needed size
    output = self._output_projection(output)

    # compute softmax
    return functional.softmax(output, dim=2)
def forward(self, batch: torch.LongTensor) -> torch.FloatTensor:
    """Computes the loss function.

    Args:
        batch (torch.LongTensor): A batch of training data, as (batch-size x max-seq-len)-tensor.

    Returns:
        torch.FloatTensor: The computed loss.
    """
    # sanitize args
    insanity.sanitize_type("batch", batch, torch.Tensor)
    if batch.dtype != torch.int64:
        raise TypeError("<batch> has to be a LongTensor!")
    if batch.dim() != 2:
        raise ValueError("<batch> has to be a 2d tensor!")

    # create the padding mask to use
    padding_mask = util.create_padding_mask(batch, self._pad_index)

    # create a tensor of indices, which is used to retrieve the according positional embeddings below
    index_seq = batch.new(range(batch.size(1))).unsqueeze(0).expand(batch.size(0), -1)

    # compute the sequence lengths for all samples in the batch
    seq_len = (batch != self._pad_index).sum(dim=1).cpu().numpy().tolist()

    # randomly choose the tokens to compute predictions for
    pred_mask = padding_mask.new(*batch.size()).zero_().long()    # all tokens being predicted
    mask_mask = padding_mask.new(*batch.size()).zero_().long()    # tokens replaced with <MASK>
    random_mask = padding_mask.new(*batch.size()).zero_().long()  # tokens replaced with random tokens
    for sample_idx, sample_len in enumerate(seq_len):  # iterate over all samples in the batch

        # determine how many tokens to compute predictions for
        num_pred = int(math.ceil(sample_len * self._prediction_rate))  # num of tokens predictions are computed for
        num_mask = int(math.floor(num_pred * self._mask_rate))         # num of tokens replaced with <MASK>
        num_random = int(math.ceil(num_pred * self._random_rate))      # num of tokens randomly replaced

        # randomly select indices to compute predictions for
        pred_indices = list(range(sample_len))
        random.shuffle(pred_indices)
        pred_indices = pred_indices[:num_pred]

        # prepare the <MASK>-mask
        for token_idx in pred_indices[:num_mask]:
            pred_mask[sample_idx, token_idx] = 1
            mask_mask[sample_idx, token_idx] = 1

        # prepare the random-mask
        for token_idx in pred_indices[num_mask:(num_mask + num_random)]:
            pred_mask[sample_idx, token_idx] = 1
            random_mask[sample_idx, token_idx] = 1

        # remaining tokens that predictions are computed for are left untouched
        for token_idx in pred_indices[(num_mask + num_random):]:
            pred_mask[sample_idx, token_idx] = 1

    # replace predicted tokens in the batch appropriately
    masked_batch = (
        batch * (1 - mask_mask) * (1 - random_mask) +
        mask_mask * batch.new(*batch.size()).fill_(self._mask_index) +
        random_mask * (batch.new(*batch.size()).double().uniform_() * self._word_emb.num_embeddings).long()
    )

    # embed the batch
    masked_batch = self._word_emb(masked_batch) + self._pos_emb(index_seq)

    # encode sequence in the batch using BERT
    enc = self._model(masked_batch, padding_mask)

    # turn encodings, the target token indices (that we seek to predict), and the prediction mask, into matrices,
    # such that each row corresponds with one token
    enc = enc.view(enc.size(0) * enc.size(1), enc.size(2))
    target = batch.view(-1)
    pred_mask = pred_mask.view(-1)

    # turn the prediction mask into a tensor of indices (to select below)
    pred_mask = pred_mask.new(np.where(pred_mask.detach().cpu().numpy())[0])

    # fetch embeddings and target values of those tokens that are being predicted
    enc = enc.index_select(0, pred_mask)
    target = target.index_select(0, pred_mask)

    # compute predictions for each encoded token + the according loss
    pred = self._output_layer(enc)
    loss = self._loss(pred, target)

    return loss
def finalize(self, input_ids: torch.LongTensor, final_beam_scores: torch.FloatTensor,
             final_beam_tokens: torch.LongTensor, final_beam_indices: torch.LongTensor,
             pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
             mems=None) -> Tuple[torch.LongTensor, List[torch.Tensor]]:
    batch_size = len(self._beam_hyps)

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            continue

        # need to add best num_beams hypotheses to generated hyps
        for beam_id in range(self.num_beams):
            batch_beam_idx = batch_idx * self.num_beams + beam_id
            final_score = final_beam_scores[batch_beam_idx].item()
            final_tokens = input_ids[batch_beam_idx]
            beam_hyp.add(final_tokens, final_score,
                         mems=[mem[[batch_beam_idx]] for mem in mems] if mems else None)

    # select the best hypotheses
    sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
    best = []

    # retrieve best hypotheses
    for i, beam_hyp in enumerate(self._beam_hyps):
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
        for j in range(self.num_beam_hyps_to_keep):
            score, best_hyp, mems = sorted_hyps.pop()
            sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
            best.append((best_hyp, mems, score))

    # prepare for adding eos
    sent_max_len = min(sent_lengths.max().item(), self.max_length)
    decoded: torch.LongTensor = input_ids.new(
        batch_size * self.num_beam_hyps_to_keep, sent_max_len)
    scores = final_beam_scores.new(batch_size * self.num_beam_hyps_to_keep)

    # shorter batches are padded if needed
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`pad_token_id` has to be defined"
        decoded.fill_(pad_token_id)

    # fill with hypotheses and eos_token_id if the latter fits in
    mems = []
    for i, (hypo, mem, score) in enumerate(best):
        scores[i] = score
        decoded[i, :sent_lengths[i]] = hypo
        if sent_lengths[i] < sent_max_len:
            decoded[i, sent_lengths[i]] = eos_token_id
        mems.append(mem)
    mems = [
        torch.cat([mem[i] for mem in mems], dim=0) for i in range(len(mems[0]))
    ] if mems and mems[0] else None
    return decoded, mems, scores
def sample_output(
        model: transformer.Transformer,
        input_seq: torch.LongTensor,
        eos_index: int,
        pad_index: int,
        max_len: int
) -> torch.LongTensor:
    """Samples an output sequence based on the provided input.

    Args:
        model (:class:`transformer.Transformer`): The model to use.
        input_seq (torch.LongTensor): The input sequence to be provided to the model. This has to be a
            (batch-size x input-seq-len)-tensor.
        eos_index (int): The index that indicates the end of a sequence.
        pad_index (int): The index that indicates a padding token in a sequence.
        max_len (int): The maximum length of the generated output.

    Returns:
        torch.LongTensor: The generated output sequence as (batch-size x output-seq-len)-tensor.
    """
    # sanitize args
    if not isinstance(model, transformer.Transformer):
        raise TypeError("The <model> has to be a transformer.Transformer!")
    if not isinstance(input_seq, torch.LongTensor) and not isinstance(input_seq, torch.cuda.LongTensor):
        raise TypeError("The <input_seq> has to be a LongTensor!")
    if input_seq.dim() != 2:
        raise ValueError("<input_seq> has to be a matrix!")
    if not isinstance(eos_index, int):
        raise TypeError("The <eos_index> has to be an integer!")
    if eos_index < 0 or eos_index >= model.output_size:
        raise ValueError("The <eos_index> is not a legal index in the vocabulary used by <model>!")
    if not isinstance(pad_index, int):
        raise TypeError("The <pad_index> has to be an integer!")
    if pad_index < 0 or pad_index >= model.output_size:
        raise ValueError("The <pad_index> is not a legal index in the vocabulary used by <model>!")
    if max_len is not None:
        if not isinstance(max_len, int):
            raise TypeError("<max_len> has to be an integer!")
        if max_len < 1:
            raise ValueError("<max_len> has to be > 0!")

    original_mode = model.training  # the original mode (train/eval) of the provided model
    batch_size = input_seq.size(0)  # number of samples in the provided input sequence

    # put model in evaluation mode
    model.eval()

    output_seq = []  # used to store the generated outputs for each position
    finished = [False] * batch_size

    for _ in range(max_len):

        # prepare the target to provide to the model
        # this is the current output with an additional final entry that is supposed to be predicted next
        # (which is why the concrete value does not matter)
        current_target = torch.cat(output_seq + [input_seq.new(batch_size, 1).zero_()], dim=1)

        # run the model
        probs = model(input_seq, current_target)[:, -1, :]

        # sample the next output from the computed probabilities
        output = torch.multinomial(probs, 1)

        # determine which samples have finished, and replace the sampled output with padding for those that
        # already have
        for sample_idx in range(batch_size):
            if finished[sample_idx]:
                output[sample_idx, 0] = pad_index
            elif output[sample_idx, 0].item() == eos_index:
                finished[sample_idx] = True

        # store created output
        output_seq.append(output)

        # check whether generation has been finished
        if all(finished):
            break

    # restore original mode of the model
    model.train(mode=original_mode)

    return torch.cat(output_seq, dim=1)
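# --- Hedged usage sketch (not part of the original snippet) ---
# Shows how sample_output is meant to be called. It assumes `model` is an already
# constructed transformer.Transformer (its constructor arguments are not shown above)
# and that the vocabulary reserves index 0 for padding and index 1 for end-of-sequence;
# both index choices are assumptions made only for this illustration.
import torch

input_seq = torch.randint(2, model.output_size, (4, 16))  # (batch-size x input-seq-len)
samples = sample_output(model, input_seq, eos_index=1, pad_index=0, max_len=30)
print(samples.shape)  # (4 x output-seq-len), where output-seq-len <= 30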