Ejemplo n.º 1
0
 def _index_value(self,
                  key,
                  value,
                  index,
                  error_type: Type[Exception] = RuntimeError):
     # Obtain indexed element from value
     if isinstance(value, dict):
         # Return subdict proxy
         return DictIndexProxy(value, index, self._prefix + key)
     elif isinstance(value, tuple):
         # Return tuple of slices
         # Since we can't proxy a tuple, we slice eagerly
         # Use type(value) to support named tuples. (the keys is still index though)
         return new_tuple(
             type(value),
             (self._index_value(f'{key}[{i}]', v, index, error_type)
              for i, v in enumerate(value)))
     elif isinstance(value, (to.Tensor, np.ndarray)):
         # Return slice of ndarray / tensor
         return value[index, ...]
     elif isinstance(value, list):
         # Return list item
         return value[index]
     else:
         # Unsupported type
         raise error_type(
             f'Entry {self._prefix}{key} has un-gettable type {type(value)}'
         )
Ejemplo n.º 2
0
 def _slice_entry(self, entry, index: slice):
     if isinstance(entry, dict):
         return {k: self._slice_entry(v, index) for k, v in entry.items()}
     if isinstance(entry, tuple):
         return new_tuple(type(entry),
                          (self._slice_entry(e, index) for e in entry))
     elif isinstance(entry, (to.Tensor, np.ndarray)):
         return entry[index, ...]
     elif isinstance(entry, list):
         return entry[index]
     else:
         return None  # unsupported
Ejemplo n.º 3
0
    def __map_tensors(self, mapper, elem):
        if isinstance(elem, dict):
            # Modify dict in-place
            for k in elem.keys():
                elem[k] = self.__map_tensors(mapper, elem[k])
            return elem
        if isinstance(elem, tuple):
            # Can't modify in place since it's a tuple
            return new_tuple(type(elem), (self.__map_tensors(mapper, part)
                                          for part in elem))

        # Tensor element
        return mapper(elem)
Ejemplo n.º 4
0
 def _truncate_after_last(self, entry):
     if isinstance(entry, dict):
         return {k: self._truncate_after_last(v) for k, v in entry.items()}
     if isinstance(entry, tuple):
         return new_tuple(type(entry),
                          (self._truncate_after_last(v) for v in entry))
     elif isinstance(entry, (to.Tensor, np.ndarray)):
         if entry.shape[0] == self.length + 1:
             return entry[:-1, ...]
     elif isinstance(entry, list):
         if len(entry) == self.length + 1:
             return entry[:-1]
     # No truncation
     return entry