def _index_value(self, key, value, index, error_type: Type[Exception] = RuntimeError):
    """Extract the element at ``index`` from one stored entry.

    Dicts are wrapped in a ``DictIndexProxy`` instead of being indexed right away.
    Tuples (including named tuples) are indexed element by element. Tensors and
    arrays are indexed along their leading axis. Lists are indexed directly.

    :param key: name of the entry, used for nested proxy prefixes and error messages
    :param value: the entry to index
    :param index: index to apply to the entry
    :param error_type: exception class raised when ``value`` cannot be indexed
    :return: the indexed element (or a proxy / rebuilt tuple for containers)
    :raises error_type: if ``value`` has an unsupported type
    """
    if isinstance(value, dict):
        # Nested dicts are handed to a proxy, which carries the full key prefix
        return DictIndexProxy(value, index, self._prefix + key)

    if isinstance(value, tuple):
        # A tuple cannot be proxied, so every element is indexed immediately.
        # type(value) keeps named tuple classes intact; element keys use positions.
        indexed_parts = (
            self._index_value(f'{key}[{pos}]', part, index, error_type)
            for pos, part in enumerate(value)
        )
        return new_tuple(type(value), indexed_parts)

    if isinstance(value, (to.Tensor, np.ndarray)):
        # Select along the leading dimension, keeping any trailing dimensions
        return value[index, ...]

    if isinstance(value, list):
        return value[index]

    raise error_type(
        f'Entry {self._prefix}{key} has un-gettable type {type(value)}'
    )
def _slice_entry(self, entry, index: slice):
    """Apply a slice to one stored entry, recursing into containers.

    Dict values and tuple elements are sliced recursively. Tensors and arrays
    are sliced along their leading axis. Lists are sliced directly.

    :param entry: the entry to slice
    :param index: slice applied to the leading dimension
    :return: the sliced entry, or ``None`` if the entry type is not supported
    """
    if isinstance(entry, dict):
        sliced = {}
        for name, sub_entry in entry.items():
            sliced[name] = self._slice_entry(sub_entry, index)
        return sliced
    if isinstance(entry, tuple):
        # Rebuild with the original (possibly named) tuple type
        return new_tuple(type(entry), (self._slice_entry(part, index) for part in entry))
    if isinstance(entry, (to.Tensor, np.ndarray)):
        return entry[index, ...]
    if isinstance(entry, list):
        return entry[index]
    # Unsupported entry types map to None rather than raising
    return None
def __map_tensors(self, mapper, elem):
    """Apply ``mapper`` to every leaf of a (possibly nested) entry.

    Dicts are updated in place and returned. Tuples are rebuilt with their
    original type. Anything else counts as a leaf and is passed to ``mapper``.

    :param mapper: callable applied to each leaf value
    :param elem: entry to transform
    :return: the transformed entry (the very same dict object for dict inputs)
    """
    if isinstance(elem, tuple):
        # Tuples are immutable, so a new one of the same type is built
        return new_tuple(type(elem), (self.__map_tensors(mapper, part) for part in elem))

    if not isinstance(elem, dict):
        # Leaf element
        return mapper(elem)

    # Only existing keys are reassigned, so updating while iterating is safe
    for name in elem:
        elem[name] = self.__map_tensors(mapper, elem[name])
    return elem
def _truncate_after_last(self, entry):
    """Drop the trailing extra element of an entry that is one element too long.

    Tensors/arrays whose leading dimension, and lists whose length, equal
    ``self.length + 1`` lose their last element. Dicts and tuples are processed
    recursively. Everything else is returned unchanged.

    :param entry: the entry to truncate
    :return: the possibly truncated entry
    """
    if isinstance(entry, dict):
        return {name: self._truncate_after_last(sub_entry) for name, sub_entry in entry.items()}
    if isinstance(entry, tuple):
        # Rebuild with the original (possibly named) tuple type
        return new_tuple(type(entry), (self._truncate_after_last(part) for part in entry))

    if isinstance(entry, (to.Tensor, np.ndarray)):
        has_extra = entry.shape[0] == self.length + 1
        return entry[:-1, ...] if has_extra else entry
    if isinstance(entry, list):
        has_extra = len(entry) == self.length + 1
        return entry[:-1] if has_extra else entry

    # Unsupported types are passed through untouched
    return entry