Code example #1
0
    def collate(self, examples: List[Tuple[List[str], List[str]]]) -> Batch:
        """Collate a batch for `PairedTextData`.

        Each example is a tuple ``(source_tokens, target_tokens)``, both
        lists of strings. Tokens are numericalized via the source/target
        vocabularies and padded by `padded_batch`; when the configured pad
        length is `None`, sentences are padded to the longest one in the
        batch. The raw token lists are padded with empty strings to the
        same width so text and ids stay aligned.
        """
        src_texts = [ex[0] for ex in examples]
        tgt_texts = [ex[1] for ex in examples]

        # --- Source side: numericalize, pad, move to device. ---
        src_id_lists = [
            self._src_vocab.map_tokens_to_ids_py(tokens) for tokens in src_texts
        ]
        src_ids, src_lengths = padded_batch(
            src_id_lists, self._src_pad_length,
            pad_value=self._src_vocab.pad_token_id)
        src_width = self._src_pad_length or max(src_lengths)
        src_texts = [
            tokens if len(tokens) >= src_width
            else tokens + [''] * (src_width - len(tokens))
            for tokens in src_texts
        ]
        src_ids = torch.from_numpy(src_ids).to(device=self.device)
        src_lengths = torch.tensor(
            src_lengths, dtype=torch.long, device=self.device)

        # --- Target side: same treatment with the target vocabulary. ---
        tgt_id_lists = [
            self._tgt_vocab.map_tokens_to_ids_py(tokens) for tokens in tgt_texts
        ]
        tgt_ids, tgt_lengths = padded_batch(
            tgt_id_lists, self._tgt_pad_length,
            pad_value=self._tgt_vocab.pad_token_id)
        tgt_width = self._tgt_pad_length or max(tgt_lengths)
        tgt_texts = [
            tokens if len(tokens) >= tgt_width
            else tokens + [''] * (tgt_width - len(tokens))
            for tokens in tgt_texts
        ]
        tgt_ids = torch.from_numpy(tgt_ids).to(device=self.device)
        tgt_lengths = torch.tensor(
            tgt_lengths, dtype=torch.long, device=self.device)

        return Batch(len(examples),
                     source_text=src_texts,
                     source_text_ids=src_ids,
                     source_length=src_lengths,
                     target_text=tgt_texts,
                     target_text_ids=tgt_ids,
                     target_length=tgt_lengths)
Code example #2
0
    def collate(self, examples: List[List[str]]) -> Batch:
        """Collate a batch for `MonoTextData`.

        Each example is a list of token strings. Tokens are numericalized
        through the vocabulary and padded by `padded_batch`; when
        `_pad_length` is `None`, the batch is padded to its longest
        sentence. The raw token lists are padded with empty strings so
        text and ids share the same width.
        """
        id_lists = [
            self._vocab.map_tokens_to_ids_py(tokens) for tokens in examples
        ]
        text_ids, lengths = padded_batch(
            id_lists, self._pad_length, pad_value=self._vocab.pad_token_id)

        # Pad the raw token lists to match the numericalized width.
        width = self._pad_length or max(lengths)
        examples = [
            tokens if len(tokens) >= width
            else tokens + [''] * (width - len(tokens))
            for tokens in examples
        ]

        return Batch(len(examples), batch={
            self.text_name: examples,
            self.text_id_name: torch.from_numpy(text_ids),
            self.length_name: torch.tensor(lengths, dtype=torch.long),
        })
Code example #3
0
 def collate(self, examples: List[Dict[str, Any]]) -> Batch:
     """Collate a batch of feature dictionaries into a `Batch`.

     For each declared feature, the per-example values are gathered into a
     column and handled according to the feature's collate method:
     `List` keeps the raw Python objects, `StackedTensor` stacks them
     along a new batch axis, and anything else is padded via
     `padded_batch`. Non-string NumPy results are converted to tensors.
     """
     batch = {}
     for name, desc in self._features.items():
         column = [example[name] for example in examples]
         if desc.collate_method is CollateMethod.List:
             # Leave raw Python objects untouched.
             batch[name] = column
             continue
         # NumPy functions work on PyTorch tensors too.
         if desc.collate_method is CollateMethod.StackedTensor:
             column = np.stack(column, axis=0)
         else:  # padded_tensor
             column, _ = padded_batch(column)
         if desc.dtype not in [np.str_, np.bytes_] and \
                 not isinstance(column, torch.Tensor):
             column = torch.from_numpy(column)
         batch[name] = column
     return Batch(len(examples), batch)
Code example #4
0
 def collate(self, examples: List[Dict[str, Any]]) -> Batch:
     """Collate a batch of feature dictionaries into a `Batch`.

     Features with a declared `shape` are batched into tensors: if the
     leading dimension is `None` the columns are padded via
     `padded_batch`, otherwise they are stacked along a new batch axis.
     Features without a shape (variable-length) are kept as plain Python
     lists. Non-string NumPy results are converted to tensors.
     """
     batch = {}
     for name, desc in self._features.items():
         column = [example[name] for example in examples]
         if desc.shape is None:
             # VarLenFeature: leave as a Python list.
             batch[name] = column
             continue
         # NumPy functions work on PyTorch tensors too.
         if len(desc.shape) > 0 and desc.shape[0] is None:
             # Variable leading dimension: pad to the batch maximum.
             column, _ = padded_batch(column)
         else:
             column = np.stack(column, axis=0)
         if desc.dtype not in [np.str_, np.bytes_] and \
                 not isinstance(column, torch.Tensor):
             column = torch.from_numpy(column)
         batch[name] = column
     return Batch(len(examples), batch)