def __torch_function__(self, func, types, args=(), kwargs=None):
    """Intercept a torch.* or torch.nn.functional.* call made on a storch.Tensor.

    The call is dispatched by the function's ``__name__``:

    * names in ``exception_methods`` are refused outright;
    * names in ``excluded_methods`` are called directly with the arguments as
      received, bypassing the deterministic wrapper;
    * names in ``expand_methods`` go through the deterministic wrapper with
      ``expand_plates=True``;
    * everything else goes through the deterministic wrapper unchanged.

    Args:
        func: The torch function being invoked.
        types: Part of the ``__torch_function__`` protocol; not used here.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call; ``None`` is treated as ``{}``.

    Returns:
        Whatever ``func`` or the deterministic wrapper returns.

    Raises:
        IllegalStorchExposeError: if ``func`` is listed in ``exception_methods``.
    """
    if kwargs is None:
        kwargs = {}
    name = func.__name__

    if name in exception_methods:
        raise IllegalStorchExposeError(
            "Calling method " + name + " with storch tensors is not allowed."
        )
    if name in excluded_methods:
        return func(*args, **kwargs)

    if name not in expand_methods:
        return storch.wrappers._handle_deterministic(func, args, kwargs)
    # Expand empty plate dimensions automatically: some loss functions assume
    # both inputs have exactly the same elements.
    return storch.wrappers._handle_deterministic(
        func, args, kwargs, expand_plates=True
    )
def wrapper(*args, **kwargs):
    """Call the wrapped ``fn`` unless a storch.Tensor is among its arguments.

    Both positional and keyword arguments are inspected.

    Returns:
        The result of ``fn(*args, **kwargs)``.

    Raises:
        IllegalStorchExposeError: if any positional or keyword argument is a
            storch.Tensor.
    """
    # Keyword values are checked as well; otherwise ``fn(x=storch_tensor)``
    # would slip past the guard.
    if any(isinstance(a, storch.Tensor) for a in (*args, *kwargs.values())):
        raise IllegalStorchExposeError(
            "It is not allowed to call this method using storch.Tensor, likely "
            "because it exposes its wrapped tensor to Python.")
    return fn(*args, **kwargs)
def __getattr__(self, item):
    """Resolve an attribute not implemented directly by storch.Tensor.

    ``item`` is looked up on the ``torch.Tensor`` class. If it is absent,
    ``getattr`` raises ``AttributeError``. Callable attributes are handled as
    follows:

    * names in ``exception_methods`` raise ``IllegalStorchExposeError``;
    * names in ``excluded_methods`` are returned as the raw torch.Tensor
      attribute;
    * anything else is wrapped with ``storch.wrappers._self_deterministic``,
      bound to this storch tensor.

    NOTE(review): non-callable attributes of torch.Tensor (e.g. descriptors)
    fall through and resolve to ``None``. Confirm whether they should instead
    be deferred to the wrapped tensor or rejected.
    """
    # TODO: this should probably filter the methods.
    attr = getattr(torch.Tensor, item)
    if not isinstance(attr, Callable):
        return None

    name = attr.__name__
    if name in exception_methods:
        raise IllegalStorchExposeError(
            "Calling method " + name + " with storch tensors is not allowed."
        )
    if name in excluded_methods:
        return attr
    return storch.wrappers._self_deterministic(attr, self)
def __contains__(self, item):
    """Refuse ``item in tensor`` membership tests on storch tensors.

    Raises:
        IllegalStorchExposeError: always.
    """
    message = "It is not allowed to expose storch tensors via in statements."
    raise IllegalStorchExposeError(message)
def numpy(self):
    """Refuse conversion of a storch tensor to a NumPy array.

    Callers should unwrap to a plain torch tensor first.

    Raises:
        IllegalStorchExposeError: always.
    """
    message = (
        "It is not allowed to convert storch tensors to numpy arrays. "
        "Make sure to unwrap storch tensors to normal torch tensor "
        "to use this tensor as a np.array."
    )
    raise IllegalStorchExposeError(message)
def __nonzero__(self) -> builtins.bool:
    """Refuse conversion of a storch tensor to a Python boolean.

    Callers should unwrap to a plain torch tensor first.

    Raises:
        IllegalStorchExposeError: always.
    """
    raise IllegalStorchExposeError(
        "It is not allowed to convert storch tensors to boolean. Make sure to unwrap "
        "storch tensors to normal torch tensor to use this tensor as a boolean."
    )

# Python 3 consults ``__bool__`` (not the Python 2 ``__nonzero__``) for
# ``bool(t)`` and ``if t:``; without this alias the guard above is never hit.
__bool__ = __nonzero__
def __long__(self):
    """Refuse conversion of a storch tensor to a long integer.

    Callers should unwrap to a plain torch tensor first.

    NOTE(review): ``__long__`` is the Python 2 hook for ``long(x)``. Python 3
    never calls it, because ``int(x)`` uses ``__int__`` / ``__index__``.
    Consider also providing ``__int__``.

    Raises:
        IllegalStorchExposeError: always.
    """
    message = (
        "It is not allowed to convert storch tensors to long. "
        "Make sure to unwrap storch tensors to normal torch tensor "
        "to use this tensor as a long."
    )
    raise IllegalStorchExposeError(message)
def __index__(self) -> int:
    """Refuse use of a storch tensor where Python needs an integer index.

    Raises:
        IllegalStorchExposeError: always.
    """
    message = "Cannot use storch tensors as index."
    raise IllegalStorchExposeError(message)
def __and__(self, other):
    """Compute ``self & other`` through storch's deterministic wrapping.

    Args:
        other: The right-hand operand; must not be a Python ``bool``.

    Returns:
        The result of ``torch.Tensor.__and__`` applied through
        ``storch.wrappers._self_deterministic`` with this storch tensor as self.

    Raises:
        IllegalStorchExposeError: if ``other`` is a Python ``bool``.
    """
    if isinstance(other, bool):
        raise IllegalStorchExposeError(
            "Calling 'and' with a bool exposes the underlying tensor as a bool."
        )
    # Bind ``self`` the same way __getattr__ binds torch.Tensor methods. The
    # previous form passed the bound method of the raw ``self._tensor``, so
    # ``self`` never reached the wrapper as a storch argument (presumably
    # dropping its plates/parents from the result; confirm).
    return storch.wrappers._self_deterministic(torch.Tensor.__and__, self)(other)