def test_collect_top_k_probs(self):
    """Check compute_top_k against a direct top-k of the model output.

    Builds a small hybrid transformer/RNN model on dummy data, runs
    ``collect_top_k_probs.compute_top_k`` with k=3, and compares its first
    row of indices and scores with ``torch.topk`` applied to the model's
    normalized probabilities at position ``[0, 0]`` of the first example.
    The reference scores are the top-k probabilities rescaled to sum to 1.
    """
    k = 3

    # Model and task setup.
    args = test_utils.ModelParamsDict(arch="hybrid_transformer_rnn")
    _, src_dict, tgt_dict = test_utils.prepare_inputs(args)
    self.task = tasks.DictionaryHolderTask(src_dict, tgt_dict)
    model = self.task.build_model(args)
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()
    model.eval()

    # Dummy parallel dataset.
    source_data = test_utils.create_dummy_binarized_dataset()
    target_data = test_utils.create_dummy_binarized_dataset(append_eos=True)
    dataset = language_pair_dataset.LanguagePairDataset(
        src=source_data,
        src_sizes=source_data.sizes,
        src_dict=self.task.src_dict,
        tgt=target_data,
        tgt_sizes=target_data.sizes,
        tgt_dict=self.task.dst_dict,
        left_pad_source=False,
    )

    # Values produced by the code under test.
    top_k_scores, top_k_indices = collect_top_k_probs.compute_top_k(
        task=self.task,
        models=[model],
        dataset=dataset,
        k=k,
        use_cuda=use_cuda,
        max_tokens=None,
        max_sentences=None,
        progress_bar_args=None,
    )

    # Reference values: run the model directly on the first example.
    first_batch = language_pair_dataset.collate(
        [dataset[0]],
        pad_idx=self.task.src_dict.pad(),
        eos_idx=self.task.src_dict.eos(),
        left_pad_source=False,
    )
    net_input = first_batch["net_input"]
    if use_cuda:
        net_input = utils.move_to_cuda(net_input)
    with torch.no_grad():
        net_output = model(**net_input)
        probs = model.get_normalized_probs(net_output, log_probs=False)
    expected_probs, expected_indices = torch.topk(probs[0, 0], k=k)
    if use_cuda:
        # .numpy() below requires host tensors.
        expected_probs = expected_probs.cpu()
        expected_indices = expected_indices.cpu()

    np.testing.assert_array_equal(top_k_indices[0], expected_indices.numpy())
    # Rescale the reference top-k probabilities to sum to 1 before comparing.
    expected_scores = (expected_probs / expected_probs.sum()).numpy()
    np.testing.assert_almost_equal(top_k_scores[0], expected_scores)
def collater(self, samples, pad_to_length=None):
    """Collate a list of samples into a single padded mini-batch.

    Args:
        samples (List[dict]): the samples to merge
        pad_to_length (dict, optional): a dictionary of
            {'source': source_pad_to_length, 'target': target_pad_to_length}
            giving the maximum length to pad the source and target to.

    Returns:
        dict: a mini-batch with the following keys:

            - `id` (LongTensor): example IDs in the original input order
            - `ntokens` (int): total number of tokens in the batch
            - `net_input` (dict): the input to the Model, containing keys:

              - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                the source sentence of shape `(bsz, src_len)`. Padding
                will appear on the left if *left_pad_source* is ``True``.
              - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                lengths of each source sentence of shape `(bsz)`
              - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                tokens in the target sentence, shifted right by one
                position for teacher forcing, of shape `(bsz, tgt_len)`.
                This key will not be present if *input_feeding* is
                ``False``. Padding will appear on the left if
                *left_pad_target* is ``True``.
              - `src_lang_id` (LongTensor): source language ID repeated
                for every sample, shape `(bsz, 1)`; only added when
                ``self.src_lang_id`` is not ``None``

            - `target` (LongTensor): a padded 2D Tensor of tokens in the
              target sentence of shape `(bsz, tgt_len)`. Padding will
              appear on the left if *left_pad_target* is ``True``.
            - `tgt_lang_id` (LongTensor): target language ID repeated for
              every sample, shape `(bsz, 1)`; only added when
              ``self.tgt_lang_id`` is not ``None``
    """
    batch = collate(
        samples,
        pad_idx=self.src_dict.pad(),
        eos_idx=self.eos,
        left_pad_source=self.left_pad_source,
        left_pad_target=self.left_pad_target,
        input_feeding=self.input_feeding,
        pad_to_length=pad_to_length,
        pad_to_multiple=self.pad_to_multiple,
    )

    # Nothing more to attach when no language IDs are configured.
    if self.src_lang_id is None and self.tgt_lang_id is None:
        return batch

    src_tokens = batch["net_input"]["src_tokens"]
    batch_size = src_tokens.size(0)

    def lang_id_column(lang_id):
        # One ID per sample, cast/moved to match src_tokens.
        column = torch.LongTensor([[lang_id]])
        return column.expand(batch_size, 1).to(src_tokens)

    if self.src_lang_id is not None:
        batch["net_input"]["src_lang_id"] = lang_id_column(self.src_lang_id)
    if self.tgt_lang_id is not None:
        batch["tgt_lang_id"] = lang_id_column(self.tgt_lang_id)
    return batch
def get_iterator(self, samples, batch_size):
    """Split *samples* into consecutive chunks and collate each one.

    Args:
        samples: a sliceable sequence of samples.
        batch_size: number of samples per chunk; the final chunk may be
            shorter.

    Returns:
        list: one collated batch per chunk, in input order. Each batch is
        built with the source dictionary's pad/eos indices,
        ``left_pad_source=True`` and ``state_machine=False``.
    """
    num_batches = math.ceil(len(samples) / batch_size)
    return [
        collate(
            samples[start:start + batch_size],
            pad_idx=self.src_dict.pad(),
            eos_idx=self.src_dict.eos(),
            left_pad_source=True,
            state_machine=False,
        )
        for start in (index * batch_size for index in range(0, num_batches))
    ]
def dummy_dataloader(
    samples,
    padding_idx=1,
    eos_idx=2,
    batch_size=None,
):
    """Return an iterator over collated batches built from *samples*.

    Args:
        samples: list of sample dicts. Modified in place: any sample
            lacking an "id" key receives its list position as the id.
        padding_idx: padding index forwarded to ``collate``.
        eos_idx: end-of-sentence index forwarded to ``collate``.
        batch_size: samples per batch; when ``None`` all samples form a
            single batch.

    Returns:
        An iterator over a ``torch.utils.data.DataLoader`` wrapping the
        samples in a ``TestDataset``.
    """
    if batch_size is None:
        batch_size = len(samples)

    # Fill in ids for samples that do not carry one.
    for position, item in enumerate(samples):
        if "id" in item:
            continue
        item["id"] = position

    def collate_batch(batch_samples):
        return collate(batch_samples, padding_idx, eos_idx)

    loader = torch.utils.data.DataLoader(
        TestDataset(samples),
        batch_size=batch_size,
        collate_fn=collate_batch,
    )
    return iter(loader)
def dummy_dataloader(
    samples,
    padding_idx=1,
    eos_idx=2,
    batch_size=None,
):
    """Wrap *samples* in a DataLoader and return an iterator over it.

    Samples without an 'id' key are given their index in *samples* as the
    id (the dicts are updated in place). Batches are formed by calling
    ``collate`` with *padding_idx* and *eos_idx*. If *batch_size* is
    ``None``, every sample goes into one batch.
    """
    effective_batch_size = len(samples) if batch_size is None else batch_size

    # Add any missing ids before the dataset is built.
    for index, entry in enumerate(samples):
        if 'id' not in entry:
            entry['id'] = index

    def _collate(batch_samples):
        return collate(batch_samples, padding_idx, eos_idx)

    wrapped = TestDataset(samples)
    loader = torch.utils.data.DataLoader(
        wrapped,
        batch_size=effective_batch_size,
        collate_fn=_collate,
    )
    return iter(loader)
def get_new_sample_for_hypotheses(self, orig_sample):
    """
    Build a new collated sample whose targets are the hypotheses stored
    in *orig_sample*.

    Every hypothesis under ``orig_sample['hypos'][i]`` becomes one entry
    that pairs the i-th source sentence (with padding stripped) and the
    i-th example id with that hypothesis' tokens. Entries keep their
    order because collation is called with ``sort=False``.
    """
    pad_idx = self.source_dictionary.pad()
    example_ids = orig_sample['id'].tolist()

    entries = []
    for row, row_hypos in enumerate(orig_sample['hypos']):
        for hypo in row_hypos:
            # Strip padding once per hypothesis so each entry gets its own
            # source tensor.
            source = utils.strip_pad(
                orig_sample['net_input']['src_tokens'][row, :], pad_idx
            )
            entries.append({
                'id': example_ids[row],
                'source': source,
                'target': hypo['tokens'],
            })

    return language_pair_dataset.collate(
        entries,
        pad_idx=pad_idx,
        eos_idx=self.source_dictionary.eos(),
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        sort=False,
    )