def _prepare_batch_for_alignment(self, sample, hypothesis):
    """Build source/target tensors pairing every beam hypothesis with its source.

    The source side of ``sample`` is repeated ``self.beam_size`` times, so row
    ``i * self.beam_size + j`` of the returned source tensors belongs to
    example ``i``. Hypotheses are flattened in the same example-major order,
    so that row also holds beam ``j`` of example ``i`` (see the note below on
    the beam-count assumption).

    Args:
        sample: batch dict. Reads ``sample['net_input']['src_tokens']``, a 2-D
            tensor whose first dimension is the batch size, and
            ``sample['net_input']['src_lengths']``, a 1-D tensor with one
            entry per example.
        hypothesis: per-example iterable of beams. Each beam is a dict
            carrying a ``'tokens'`` entry.

    Returns:
        Tuple ``(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)``:

        - ``src_tokens``, ``src_lengths``: flattened to ``bsz * beam_size`` rows.
        - ``prev_output_tokens``: the hypothesis tokens collated with
          ``move_eos_to_beginning=True`` (presumably the decoder input).
        - ``tgt_tokens``: the same tokens collated with
          ``move_eos_to_beginning=False``.
    """
    net_input = sample['net_input']
    beam = self.beam_size

    src_tokens = net_input['src_tokens']
    bsz = src_tokens.shape[0]
    # (bsz, src_len) -> (bsz, beam, src_len) -> (bsz * beam, src_len); each
    # source row is repeated once per beam, example-major.
    src_tokens = src_tokens[:, None, :].expand(-1, beam, -1).contiguous().view(bsz * beam, -1)
    src_lengths = net_input['src_lengths']
    src_lengths = src_lengths[:, None].expand(-1, beam).contiguous().view(bsz * beam)

    # Build the flattened token list once and reuse it for both collations.
    # NOTE(review): rows only line up with the expanded source if every
    # example has exactly self.beam_size beams — TODO confirm with caller.
    hypo_tokens = [b['tokens'] for example in hypothesis for b in example]
    prev_output_tokens = data_utils.collate_tokens(
        hypo_tokens, self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=True,
    )
    tgt_tokens = data_utils.collate_tokens(
        hypo_tokens, self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=False,
    )
    return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
Beispiel #2
0
 def merge(key, left_pad, move_eos_to_beginning=False):
     """Collate field ``key`` of every sample in ``samples`` into one batch.

     Any key other than ``'path'`` goes straight to ``data_utils.collate_tokens``.

     For ``'path'``, one shared random subset of path indices is drawn for the
     whole batch. The subset size comes from ``kwargs.get('num_sample')``,
     drawn from ``range(PATH_NUM)``. For each sample the method then gathers:

     - head entries at ``path.terminals[2 * i]``;
     - tail entries at ``path.terminals[2 * i + 1]``;
     - body entries at ``path[i]``.

     These are handed to ``data_utils.collate_paths``.
     """
     if key != 'path':
         column = [sample[key] for sample in samples]
         return data_utils.collate_tokens(
             column,
             pad_idx,
             eos_idx,
             left_pad,
             move_eos_to_beginning,
         )

     # One draw for the whole batch, so every sample uses the same path ids.
     # NOTE(review): if 'num_sample' is absent, k is None and random.sample
     # raises TypeError — confirm callers always pass it.
     chosen = random.sample(range(PATH_NUM), kwargs.get('num_sample'))

     # terminals store two entries per path: index 2*i feeds head, 2*i+1 feeds tail.
     all_terminals = [sample['path.terminals'] for sample in samples]
     heads = [[terms[2 * i] for i in chosen] for terms in all_terminals]
     tails = [[terms[2 * i + 1] for i in chosen] for terms in all_terminals]

     all_paths = [sample['path'] for sample in samples]
     bodies = [[p[i] for i in chosen] for p in all_paths]

     return data_utils.collate_paths(
         heads,
         bodies,
         tails,
         pad_idx,
         eos_idx,
         left_pad,
         move_eos_to_beginning,
     )
Beispiel #3
0
 def merge(key, left_pad, move_eos_to_beginning=False):
     """Collate field ``key`` of each sample in ``samples`` via ``collate_tokens``.

     ``pad_idx`` and ``eos_idx`` are taken from the enclosing scope.
     ``left_pad`` and ``move_eos_to_beginning`` are forwarded unchanged.
     """
     column = [sample[key] for sample in samples]
     return data_utils.collate_tokens(
         column, pad_idx, eos_idx, left_pad, move_eos_to_beginning,
     )
Beispiel #4
0
 def merge(key):
     """Collate field ``key`` of each sample, passing ``pad_idx`` as the pad index."""
     column = [sample[key] for sample in samples]
     return data_utils.collate_tokens(column, pad_idx)
Beispiel #5
0
 def collater(self, samples):
     """Collate ``samples`` directly with ``data_utils.collate_tokens``.

     ``self.pad_idx`` is passed as the pad index and ``self.left_pad`` as
     ``left_pad``.
     """
     pad, left = self.pad_idx, self.left_pad
     return data_utils.collate_tokens(samples, pad, left_pad=left)