Example #1

# A unittest.TestCase method; module paths are assumed from the
# pytorch_translate / fairseq repos:
#   import numpy as np
#   import torch
#   from fairseq import utils
#   from fairseq.data import language_pair_dataset
#   from pytorch_translate import collect_top_k_probs, tasks
#   from pytorch_translate.test import utils as test_utils
def test_collect_top_k_probs(self):
        test_args = test_utils.ModelParamsDict(arch="hybrid_transformer_rnn")
        _, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
        self.task = tasks.DictionaryHolderTask(src_dict, tgt_dict)
        model = self.task.build_model(test_args)

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            model.cuda()
        model.eval()

        binarized_source = test_utils.create_dummy_binarized_dataset()
        binarized_target = test_utils.create_dummy_binarized_dataset(
            append_eos=True)
        dataset = language_pair_dataset.LanguagePairDataset(
            src=binarized_source,
            src_sizes=binarized_source.sizes,
            src_dict=self.task.src_dict,
            tgt=binarized_target,
            tgt_sizes=binarized_target.sizes,
            tgt_dict=self.task.dst_dict,
            left_pad_source=False,
        )

        top_k_scores, top_k_indices = collect_top_k_probs.compute_top_k(
            task=self.task,
            models=[model],
            dataset=dataset,
            k=3,
            use_cuda=use_cuda,
            max_tokens=None,
            max_sentences=None,
            progress_bar_args=None,
        )

        batch = language_pair_dataset.collate(
            [dataset[0]],
            pad_idx=self.task.src_dict.pad(),
            eos_idx=self.task.src_dict.eos(),
            left_pad_source=False,
        )

        sample = batch["net_input"]
        if use_cuda:
            sample = utils.move_to_cuda(sample)

        with torch.no_grad():
            net_output = model(**sample)
            probs = model.get_normalized_probs(net_output, log_probs=False)

        top_probs, top_indices = torch.topk(probs[0, 0], k=3)
        if use_cuda:
            top_probs = top_probs.cpu()
            top_indices = top_indices.cpu()

        np.testing.assert_array_equal(top_k_indices[0], top_indices.numpy())
        normalized_probs = (top_probs / top_probs.sum()).numpy()
        np.testing.assert_almost_equal(top_k_scores[0], normalized_probs)
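
The two assertions above encode the contract of `compute_top_k`: for each
target position it keeps the k highest-probability tokens and renormalizes
their probabilities to sum to 1. A minimal sketch of that invariant in plain
NumPy (the values are illustrative):

import numpy as np

# Illustrative full-vocabulary distribution for a single target position.
vocab_probs = np.array([0.05, 0.40, 0.10, 0.25, 0.20])

k = 3
# Indices of the k largest probabilities, then the probabilities themselves.
top_indices = np.argsort(vocab_probs)[::-1][:k]  # -> [1, 3, 4]
top_probs = vocab_probs[top_indices]             # -> [0.40, 0.25, 0.20]

# Renormalize over only the retained entries, as the test expects.
top_scores = top_probs / top_probs.sum()
assert np.isclose(top_scores.sum(), 1.0)
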
Example #2
    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.  Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): the source language ID of each
                    sample in the batch, of shape `(bsz, 1)`

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `tgt_lang_id` (LongTensor): the target language ID of each
                  sample in the batch, of shape `(bsz, 1)`
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res
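
A minimal usage sketch for the collater above, assuming `dataset` is a
LanguagePairDataset like the one built in Example #1 (the keys follow the
docstring):

batch = dataset.collater([dataset[0], dataset[1]])
print(batch["id"])                             # example IDs in original order
print(batch["ntokens"])                        # total number of tokens
print(batch["net_input"]["src_tokens"].shape)  # (bsz, src_len)
print(batch["net_input"]["src_lengths"])       # unpadded source lengths
print(batch["target"].shape)                   # (bsz, tgt_len)
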
Example #3

    def get_iterator(self, samples, batch_size):
        # Chunk `samples` into consecutive mini-batches; `math` and the local
        # `collate` are assumed to be imported/defined at module level.
        batches = []
        for i in range(math.ceil(len(samples) / batch_size)):
            sample = samples[i * batch_size:(i + 1) * batch_size]
            batch = collate(
                sample,
                pad_idx=self.src_dict.pad(),
                eos_idx=self.src_dict.eos(),
                left_pad_source=True,
                state_machine=False,
            )
            batches.append(batch)
        return batches
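
The slicing above yields math.ceil(len(samples) / batch_size) batches, with a
smaller final batch when the sizes do not divide evenly; a quick illustration:

import math

samples = list(range(10))
batch_size = 4
chunks = [
    samples[i * batch_size:(i + 1) * batch_size]
    for i in range(math.ceil(len(samples) / batch_size))
]
assert chunks == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
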
Example #4
# Requires `import torch`; `TestDataset` and the local `collate` are assumed
# to be defined in the same test-utils module this helper comes from.
def dummy_dataloader(
    samples,
    padding_idx=1,
    eos_idx=2,
    batch_size=None,
):
    if batch_size is None:
        batch_size = len(samples)

    # add any missing data to samples
    for i, sample in enumerate(samples):
        if "id" not in sample:
            sample["id"] = i

    # create dataloader
    dataset = TestDataset(samples)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=(lambda samples: collate(samples, padding_idx, eos_idx)),
    )
    return iter(dataloader)
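
A hedged usage sketch for dummy_dataloader: each sample dict is assumed to
carry the tensors the module's collate expects (the 'source'/'target' field
names are an assumption, mirroring the 'id' field filled in above):

import torch

samples = [
    {"source": torch.LongTensor([4, 5, 2]), "target": torch.LongTensor([6, 7, 2])},
    {"source": torch.LongTensor([8, 2]), "target": torch.LongTensor([9, 2])},
]
batch = next(dummy_dataloader(samples))  # a single batch, since batch_size=None
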
Example #5
    def get_new_sample_for_hypotheses(self, orig_sample):
        """
        Extract hypotheses from *orig_sample* and return a new collated sample.
        """
        ids = orig_sample['id'].tolist()
        pad_idx = self.source_dictionary.pad()
        # Flatten to one new sample per (source sentence, hypothesis) pair:
        # each source is paired with one of its generated hypotheses as target.
        samples = [
            {
                'id': ids[i],
                'source': utils.strip_pad(
                    orig_sample['net_input']['src_tokens'][i, :], pad_idx
                ),
                'target': hypo['tokens'],
            }
            for i, hypos_i in enumerate(orig_sample['hypos'])
            for hypo in hypos_i
        ]
        return language_pair_dataset.collate(
            samples,
            pad_idx=pad_idx,
            eos_idx=self.source_dictionary.eos(),
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            sort=False,
        )
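
The comprehension above flattens a nested structure: orig_sample['hypos'] is
assumed to hold, for each source sentence, a list of hypothesis dicts with a
`tokens` tensor, as fairseq generators typically return. A sketch of that
layout with illustrative values (pad index 1, eos index 2):

import torch

orig_sample = {
    'id': torch.LongTensor([0, 1]),
    'net_input': {'src_tokens': torch.LongTensor([[4, 5, 2], [1, 6, 2]])},
    'hypos': [
        [{'tokens': torch.LongTensor([7, 8, 2])}],  # hypotheses for sentence 0
        [{'tokens': torch.LongTensor([9, 2])}],     # hypotheses for sentence 1
    ],
}
# get_new_sample_for_hypotheses(orig_sample) would then collate two new
# (source, hypothesis) pairs, stripping the left padding from sentence 1.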