def collate(self, examples: List[List[str]]) -> Batch:
    # For `MonoTextData`, each example is represented as a list of strings.
    # This method takes care of padding and numericalization.

    # If `pad_length` is `None`, pad to the longest sentence in the batch.
    text_ids = [
        self._vocab.map_tokens_to_ids_py(sent) for sent in examples
    ]
    text_ids, lengths = padded_batch(text_ids,
                                     self._pad_length,
                                     pad_value=self._vocab.pad_token_id)
    # Also pad the raw token lists (with empty strings) so that `text`
    # stays aligned with `text_ids`.
    pad_length = self._pad_length or max(lengths)
    examples = [
        sent + [''] * (pad_length - len(sent))
        if len(sent) < pad_length else sent
        for sent in examples
    ]

    text_ids = torch.from_numpy(text_ids).to(device=self.device)
    lengths = torch.tensor(lengths, dtype=torch.long, device=self.device)
    return Batch(len(examples),
                 text=examples,
                 text_ids=text_ids,
                 length=lengths)
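The snippet above leans on a `padded_batch` helper from the surrounding library. As a rough sketch only (the name suffix, signature, and return convention here are assumptions, not taken from the source), such a helper could look like this:

import numpy as np
from typing import List, Optional, Tuple

def padded_batch_sketch(examples: List[List[int]],
                        pad_length: Optional[int] = None,
                        pad_value: int = 0) -> Tuple[np.ndarray, List[int]]:
    # Hypothetical stand-in for the library's `padded_batch`: right-pads each
    # ID sequence to `pad_length` (or the batch maximum when `pad_length` is
    # None) and returns the padded int64 array plus the original lengths.
    lengths = [len(ex) for ex in examples]
    pad_length = pad_length or max(lengths)
    padded = np.full((len(examples), pad_length), pad_value, dtype=np.int64)
    for i, ex in enumerate(examples):
        padded[i, :len(ex)] = ex
    return padded, lengths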
def collate(self, examples) -> Batch:
    # Each example carries one entry per wrapped database. Transpose the
    # batch so every database collates its own column, then collect the
    # renamed fields into a single batch dictionary.
    transposed_examples = map(list, zip(*examples))
    batch: Dict[str, Any] = {}
    for i, transposed_example in enumerate(transposed_examples):
        kth_batch = self._databases[i].collate(transposed_example)
        for key, name in self._names[i].items():
            batch.update({name: kth_batch[key]})
    return Batch(len(examples), batch=batch)
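The transpose via `zip(*examples)` is the core of this collate: it regroups per-example tuples into one list per wrapped database. A tiny self-contained illustration (toy data, not from the source):

# Toy data: three examples, each with one field per wrapped "database".
examples = [("a", 1), ("b", 2), ("c", 3)]
transposed = list(map(list, zip(*examples)))
assert transposed == [["a", "b", "c"], [1, 2, 3]]
# Each inner list would then be handed to the corresponding database's collate.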
Example #3
def collate(self, examples: List[Example]) -> Batch:
    # Right-pad source and target index sequences with 0 up to the longest
    # sequence on each side of the batch.
    src_pad_length = max(len(src) for src, _ in examples)
    tgt_pad_length = max(len(tgt) for _, tgt in examples)
    batch_size = len(examples)
    src_indices = np.zeros((batch_size, src_pad_length), dtype=np.int64)
    tgt_indices = np.zeros((batch_size, tgt_pad_length), dtype=np.int64)
    for b_idx, (src, tgt) in enumerate(examples):
        src_indices[b_idx, :len(src)] = src
        tgt_indices[b_idx, :len(tgt)] = tgt
    src_indices = torch.from_numpy(src_indices).to(device=self.device)
    tgt_indices = torch.from_numpy(tgt_indices).to(device=self.device)
    return Batch(batch_size, src=src_indices, tgt=tgt_indices)
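A toy run of the same padding logic (made-up index sequences, not from the source) shows how shorter sequences end up right-padded with 0:

import numpy as np

examples = [([5, 6, 7], [1, 2]), ([8, 9], [3, 4, 5, 6])]
src_pad_length = max(len(src) for src, _ in examples)  # 3
src_indices = np.zeros((len(examples), src_pad_length), dtype=np.int64)
for b_idx, (src, _) in enumerate(examples):
    src_indices[b_idx, :len(src)] = src
# src_indices -> [[5, 6, 7],
#                 [8, 9, 0]]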
def collate(self, examples: List[Tuple[List[str], List[str]]]) -> Batch:
    # For `PairedTextData`, each example is represented as a tuple of two
    # lists of strings (source tokens, target tokens).
    # This method takes care of padding and numericalization.

    # If `pad_length` is `None`, pad to the longest sentence in the batch.
    src_examples = [example[0] for example in examples]
    source_ids = [
        self._src_vocab.map_tokens_to_ids_py(sent) for sent in src_examples
    ]
    source_ids, source_lengths = padded_batch(
        source_ids,
        self._src_pad_length,
        pad_value=self._src_vocab.pad_token_id)
    src_pad_length = self._src_pad_length or max(source_lengths)
    src_examples = [
        sent + [''] * (src_pad_length - len(sent))
        if len(sent) < src_pad_length else sent
        for sent in src_examples
    ]

    source_ids = torch.from_numpy(source_ids).to(device=self.device)
    source_lengths = torch.tensor(source_lengths,
                                  dtype=torch.long,
                                  device=self.device)

    tgt_examples = [example[1] for example in examples]
    target_ids = [
        self._tgt_vocab.map_tokens_to_ids_py(sent) for sent in tgt_examples
    ]
    target_ids, target_lengths = padded_batch(
        target_ids,
        self._tgt_pad_length,
        pad_value=self._tgt_vocab.pad_token_id)
    tgt_pad_length = self._tgt_pad_length or max(target_lengths)
    tgt_examples = [
        sent + [''] * (tgt_pad_length - len(sent))
        if len(sent) < tgt_pad_length else sent
        for sent in tgt_examples
    ]

    target_ids = torch.from_numpy(target_ids).to(device=self.device)
    target_lengths = torch.tensor(target_lengths,
                                  dtype=torch.long,
                                  device=self.device)

    return Batch(len(examples),
                 source_text=src_examples,
                 source_text_ids=source_ids,
                 source_length=source_lengths,
                 target_text=tgt_examples,
                 target_text_ids=target_ids,
                 target_length=target_lengths)
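The empty-string padding of the raw token lists mirrors the ID padding, so that `source_text` and `source_text_ids` stay aligned position by position. A minimal illustration (toy tokens and an assumed pad length, not from the source):

src_examples = [["hello", "world"], ["hi"]]
src_pad_length = 2  # assumed: the longest sentence in this toy batch
src_examples = [
    sent + [''] * (src_pad_length - len(sent))
    if len(sent) < src_pad_length else sent
    for sent in src_examples
]
assert src_examples == [["hello", "world"], ["hi", ""]]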
Example #5
def collate(self, examples: List[Dict[str, Any]]) -> Batch:
    batch = {}
    for key, descriptor in self._features.items():
        values = [ex[key] for ex in examples]
        if descriptor.shape is not None:
            # FixedLenFeature: stack the values into a single array; pad
            # only when the leading dimension is declared as `None`.
            # NumPy functions work on PyTorch tensors too.
            if len(descriptor.shape) > 0 and descriptor.shape[0] is None:
                values, _ = padded_batch(values)
            else:
                values = np.stack(values, axis=0)
            if (not torch.is_tensor(values)
                    and descriptor.dtype not in [np.str_, np.bytes_]):
                values = torch.from_numpy(values)
        else:
            # VarLenFeature: keep the values as a plain Python list.
            pass
        batch[key] = values
    return Batch(len(examples), batch)
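The branching above depends only on the descriptor's `shape`: a fully fixed shape is stacked, a leading `None` dimension is padded, and a missing shape is left as a Python list. A small, hedged sketch with a stand-in descriptor type (the real descriptor class is not shown in the snippet):

from collections import namedtuple
import numpy as np

# Hypothetical stand-in for the snippet's feature descriptors.
Descriptor = namedtuple("Descriptor", ["shape", "dtype"])

for desc in (Descriptor((3,), np.int64),      # fixed shape        -> np.stack
             Descriptor((None,), np.int64),   # ragged leading dim -> pad
             Descriptor(None, np.int64)):     # no shape           -> Python list
    if desc.shape is None:
        print("keep as list")
    elif len(desc.shape) > 0 and desc.shape[0] is None:
        print("pad with padded_batch")
    else:
        print("stack with np.stack")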
Example #6
def collate(self, examples: List[Example]) -> Batch:
    numbers = np.asarray([ex[0] for ex in examples])
    strings = np.asarray([ex[1] for ex in examples])
    return Batch(len(numbers), numbers=numbers, strings=strings)
Example #7
def collate(self, examples):
    return Batch(len(examples), text=examples)
Example #8
def collate(self, examples: List[Union[int, float]]) -> Batch:
    # Convert the list of scalar values into a tensor of the configured
    # data type on the target device.
    examples_np = np.array(examples, dtype=self._to_data_type)
    collated_examples = torch.from_numpy(examples_np).to(
        device=self.device)
    return Batch(len(examples), batch={self.data_name: collated_examples})
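For scalar data the whole batch collapses into a single tensor of the configured dtype. A toy usage (made-up values and dtype; `_to_data_type` and `data_name` are attributes of the surrounding class, not shown here):

import numpy as np
import torch

examples = [1.0, 2.5, 3.0]
collated = torch.from_numpy(np.array(examples, dtype=np.float32))
# collated -> tensor([1.0000, 2.5000, 3.0000])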