def collate(self, examples: List[Tuple[List[str], List[str]]]) -> Batch:
    # For `PairedTextData`, each example is represented as a tuple of two
    # lists of strings (source tokens and target tokens).
    # `collate` takes care of padding and numericalization.
    # If `pad_length` is `None`, pad to the longest sentence in the batch.
    src_examples = [example[0] for example in examples]
    source_ids = [
        self._src_vocab.map_tokens_to_ids_py(sent) for sent in src_examples
    ]
    source_ids, source_lengths = padded_batch(
        source_ids, self._src_pad_length,
        pad_value=self._src_vocab.pad_token_id)
    src_pad_length = self._src_pad_length or max(source_lengths)
    src_examples = [
        sent + [''] * (src_pad_length - len(sent))
        if len(sent) < src_pad_length else sent
        for sent in src_examples
    ]
    source_ids = torch.from_numpy(source_ids).to(device=self.device)
    source_lengths = torch.tensor(source_lengths, dtype=torch.long,
                                  device=self.device)

    tgt_examples = [example[1] for example in examples]
    target_ids = [
        self._tgt_vocab.map_tokens_to_ids_py(sent) for sent in tgt_examples
    ]
    target_ids, target_lengths = padded_batch(
        target_ids, self._tgt_pad_length,
        pad_value=self._tgt_vocab.pad_token_id)
    tgt_pad_length = self._tgt_pad_length or max(target_lengths)
    tgt_examples = [
        sent + [''] * (tgt_pad_length - len(sent))
        if len(sent) < tgt_pad_length else sent
        for sent in tgt_examples
    ]
    target_ids = torch.from_numpy(target_ids).to(device=self.device)
    target_lengths = torch.tensor(target_lengths, dtype=torch.long,
                                  device=self.device)

    return Batch(len(examples),
                 source_text=src_examples,
                 source_text_ids=source_ids,
                 source_length=source_lengths,
                 target_text=tgt_examples,
                 target_text_ids=target_ids,
                 target_length=target_lengths)

def collate(self, examples: List[List[str]]) -> Batch:
    # For `MonoTextData`, each example is represented as a list of strings.
    # `collate` takes care of padding and numericalization.
    # If `pad_length` is `None`, pad to the longest sentence in the batch.
    text_ids = [
        self._vocab.map_tokens_to_ids_py(sent) for sent in examples
    ]
    text_ids, lengths = padded_batch(text_ids, self._pad_length,
                                     pad_value=self._vocab.pad_token_id)
    # Also pad the raw text examples so they line up with `text_ids`.
    pad_length = self._pad_length or max(lengths)
    examples = [
        sent + [''] * (pad_length - len(sent))
        if len(sent) < pad_length else sent
        for sent in examples
    ]

    text_ids = torch.from_numpy(text_ids)
    lengths = torch.tensor(lengths, dtype=torch.long)

    batch = {
        self.text_name: examples,
        self.text_id_name: text_ids,
        self.length_name: lengths,
    }
    return Batch(len(examples), batch=batch)

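# Note: the `padded_batch` helper used above is not shown in this excerpt.
# Below is a minimal sketch of its assumed behavior, inferred only from the
# call sites in these `collate` methods (pad each ID sequence to `pad_length`,
# or to the longest sequence in the batch when `pad_length` is `None`, and
# also return the original lengths). It is an illustration, not the library
# implementation.
import numpy as np
from typing import List, Optional, Tuple

def padded_batch(examples: List[List[int]],
                 pad_length: Optional[int] = None,
                 pad_value: int = 0) -> Tuple[np.ndarray, List[int]]:
    # Record the original (unpadded) lengths first.
    lengths = [len(sent) for sent in examples]
    pad_length = pad_length or max(lengths)
    # Fill a (batch_size, pad_length) matrix with the pad value, then copy
    # each sequence into its row.
    padded = np.full((len(examples), pad_length), pad_value, dtype=np.int64)
    for i, sent in enumerate(examples):
        padded[i, :len(sent)] = sent
    return padded, lengths
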
def collate(self, examples: List[Dict[str, Any]]) -> Batch:
    batch = {}
    for key, descriptor in self._features.items():
        values = [ex[key] for ex in examples]
        if descriptor.collate_method is not CollateMethod.List:
            # NumPy functions work on PyTorch tensors too.
            if descriptor.collate_method is CollateMethod.StackedTensor:
                values = np.stack(values, axis=0)
            else:  # padded_tensor
                values, _ = padded_batch(values)
            if (not isinstance(values, torch.Tensor) and
                    descriptor.dtype not in [np.str_, np.bytes_]):
                values = torch.from_numpy(values)
        else:
            # Just put everything in a Python list.
            pass
        batch[key] = values
    return Batch(len(examples), batch)

def collate(self, examples: List[Dict[str, Any]]) -> Batch:
    batch = {}
    for key, descriptor in self._features.items():
        values = [ex[key] for ex in examples]
        if descriptor.shape is not None:
            # FixedLenFeature. NumPy functions work on PyTorch tensors too.
            if len(descriptor.shape) > 0 and descriptor.shape[0] is None:
                # Variable first dimension: pad to the batch maximum.
                values, _ = padded_batch(values)
            else:
                # Fully fixed shape: stack directly, no padding needed.
                values = np.stack(values, axis=0)
            if (not isinstance(values, torch.Tensor) and
                    descriptor.dtype not in [np.str_, np.bytes_]):
                values = torch.from_numpy(values)
        else:
            # VarLenFeature, just put everything in a Python list.
            pass
        batch[key] = values
    return Batch(len(examples), batch)

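# The two record-data variants above choose between stacking and padding per
# feature. Below is a minimal NumPy sketch of that stack-vs-pad decision,
# using made-up feature values (the names `fixed` and `variable` are
# hypothetical and only for illustration).
import numpy as np

# Fixed-shape features: every example has the same shape, so stack directly.
fixed = [np.ones((3,), dtype=np.int64) for _ in range(2)]
stacked = np.stack(fixed, axis=0)            # shape (2, 3)

# Variable first dimension: pad each example to the batch maximum.
variable = [np.arange(n, dtype=np.int64) for n in (2, 5)]
max_len = max(len(v) for v in variable)
padded = np.zeros((len(variable), max_len), dtype=np.int64)
for i, v in enumerate(variable):
    padded[i, :len(v)] = v                   # shape (2, 5), zero-padded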