示例#1
0
 def __init__(
     self,
     examples,
     padded_keys=None,
     device_prep_keys=None,
     padding_func=batch_pad_right,
     padding_kwargs=None,
     apply_default_convert=True,
     nonpadded_stack=True,
 ):
     """Collate a list of example dicts into batched attributes.

     Args:
         examples: list of dicts; all examples are assumed to share the
             keys of ``examples[0]``.
         padded_keys: keys whose values should be padded into a PaddedData;
             if None, every key whose (converted) value is a torch.Tensor
             is padded.
         device_prep_keys: keys marked for later device transfer; if None,
             every key whose value is a torch.Tensor is marked.
         padding_func: called as ``padding_func(values, **padding_kwargs)``
             on the list of values of each padded key.
         padding_kwargs: extra keyword arguments for ``padding_func``.
             Defaults to None (treated as an empty dict).
         apply_default_convert: run torch's default_convert on each value
             list first (numpy -> torch, etc.).
         nonpadded_stack: run mod_default_collate on non-padded values
             (stacks equal-sized tensors into a batch tensor, etc.).
     """
     # FIX: the original used `padding_kwargs={}` -- a mutable default that
     # is shared across every call. Use None as the public default and
     # create a fresh dict per call instead (behavior is unchanged).
     if padding_kwargs is None:
         padding_kwargs = {}
     self.__keys = list(examples[0].keys())
     self.__padded_keys = []
     self.__device_prep_keys = []
     for key in self.__keys:
         values = [example[key] for example in examples]
         # Default convert usually does the right thing (numpy2torch etc.)
         if apply_default_convert:
             values = default_convert(values)
         if (padded_keys is not None and key in padded_keys) or (
             padded_keys is None and isinstance(values[0], torch.Tensor)
         ):
             # Padding and PaddedData
             self.__padded_keys.append(key)
             padded = PaddedData(*padding_func(values, **padding_kwargs))
             setattr(self, key, padded)
         else:
             # Default PyTorch collate usually does the right thing
             # (convert lists of equal sized tensors to batch tensors, etc.)
             if nonpadded_stack:
                 values = mod_default_collate(values)
             setattr(self, key, values)
         if (device_prep_keys is not None and key in device_prep_keys) or (
             device_prep_keys is None and isinstance(values[0], torch.Tensor)
         ):
             self.__device_prep_keys.append(key)
示例#2
0
    def __getitem__(self, i):
        """Return sample ``i`` as a dict with ``feature`` and ``target``,
        converted by torch's default_convert.

        Lazy samples are materialized via ``generate()`` before conversion.
        """
        data = self.datas[i]
        # BUG FIX: the original tested `isinstance(input, LazyInputData)`.
        # `input` is the Python builtin function, so the check was always
        # False and lazy samples were never generated. Test the fetched
        # sample itself.
        if isinstance(data, LazyInputData):
            data = data.generate()

        return default_convert(dict(
            feature=data.feature,
            target=data.target,
        ))
示例#3
0
 def calculate_scores(
     self, arm_indices: List[int], arm_contexts: Tuple[np.ndarray, ...]
 ) -> List[float]:
     """Score each candidate arm with the reward model.

     Returns one float per arm; all zeros when no reward model is set.
     """
     if not self.reward_model:
         # No trained model yet: every arm gets a neutral score.
         return list(np.zeros(len(arm_indices)))
     # default_convert turns the numpy context arrays into torch tensors,
     # which are unpacked as positional model inputs.
     model_inputs: torch.Tensor = default_convert(arm_contexts)
     predictions: torch.Tensor = self.reward_model(*model_inputs)
     return predictions.detach().cpu().numpy().tolist()
示例#4
0
 def _to_tensor(obj):
     if type(obj) is torch.Tensor:
         return obj
     # converting obj[:] makes sure we get the data out of any db.feature object
     maybe_tensor = default_convert(obj[:])
     if type(maybe_tensor) is torch.Tensor:
         return maybe_tensor
     try:
         obj = torch.tensor(obj)
     except Exception as e:
         raise e
     return obj
示例#5
0
def tokenize_plus(tokenizer, max_len, pad_to_max_length, text):
    """Tokenize one text and return its encoding with values as torch tensors.

    When ``max_len`` is None the tokenizer's own ``max_len`` is used.
    """
    if max_len is None:
        max_len = tokenizer.max_len
    # encode_plus handles a single text (driven via seqtools.smap).
    # return_tensors="pt" is deliberately NOT used: it would produce a
    # batched (1, 512) tensor and break downstream batching, so values are
    # converted manually below.
    encode_dict = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_len,
        pad_to_max_length=pad_to_max_length,
    )
    # Replace values in place so the encoding keeps its original mapping type.
    for key, value in encode_dict.items():
        encode_dict[key] = np.array(value)
    return default_convert(encode_dict)
示例#6
0
def _convert_to_tensor(data: Any) -> Any:
    """
    Maps all kind of collections and numbers to tensors

    Args:
        data: the data to convert to tensor

    Returns:
        the converted data

    """
    if isinstance(data, numbers.Number):
        return torch.tensor([data])
    else:
        return default_convert(data)
示例#7
0
def collate_episodes(batch):
    """Manually processes collected batch of experience

    Args:
        batch (TYPE): Current mini batch of replay data

    Returns:
        TYPE: Processed mini batch of replay data (a dict_values view,
        one converted entry per Experience field).

    NOTE(review): only the FIRST sample in ``batch`` is used. The original
    converted every sample and then kept index [0]; that wasted work is
    removed here while preserving the (possibly intentional) behavior --
    confirm whether the whole batch was meant to be collated.
    """
    first = batch[0]
    batch_dict = {
        name: collate.default_convert(first[i])
        for i, name in enumerate(Experience._fields)
    }

    return batch_dict.values()
示例#8
0
    def __call__(self, *args, **kwargs):
        batch = self._collate_fn(*args, **kwargs)

        if self._transforms is not None:

            if isinstance(batch, Mapping):
                batch = self._transforms(**batch)
            elif isinstance(batch, Sequence):
                batch = self._transforms(*batch)
            else:
                batch = self._transforms(batch)

        if self._auto_convert:
            batch = default_convert(batch)

        return batch
示例#9
0
File: loader.py  Project: kiminh/rising
    def __call__(self, *args, **kwargs) -> Any:
        """
        Apply batch workflow: collate -> augmentation -> default_convert

        Args:
            *args: positional batch arguments
            **kwargs: keyword batch arguments

        Returns:
            Any: batched and augmented data
        """
        batch = self._collate_fn(*args, **kwargs)

        if self._transforms is not None:
            batch = self._transform_call(batch, self._transforms)

        if self._auto_convert:
            batch = default_convert(batch)

        return batch
示例#10
0
 def __getitem__(self, i):
     return default_convert(self.dataset[i])
示例#11
0
 def get_sample_batch(self):
     """Return the first element of the first training example, converted
     to torch-friendly types."""
     sample = self.train_dataset[0][0]
     return default_convert(sample)
示例#12
0
 def collate_fn(self, batch):
     """Convert a raw batch with torch's default_convert (no stacking)."""
     converted = collate.default_convert(batch)
     return converted
示例#13
0
def fa_convert(t):
    # Type-preserving variant of torch's `default_convert`: known collate
    # types go straight through default_convert, while other sequences are
    # converted element-wise and rebuilt with their original type.
    if isinstance(t, _collate_types):
        return default_convert(t)
    if isinstance(t, Sequence):
        return type(t)([fa_convert(s) for s in t])
    return default_convert(t)
示例#14
0
def fa_convert(t):
    "A replacement for PyTorch `default_convert` which maintains types and handles `Sequence`s"
    # Known collate types go straight through default_convert; other
    # sequences are converted element-wise and rebuilt with their own type.
    if isinstance(t, _collate_types):
        return default_convert(t)
    if isinstance(t, Sequence):
        return type(t)([fa_convert(s) for s in t])
    return default_convert(t)