示例#1
0
    def __torch_function__(self, func, types, args=(), kwargs=None):
        """
        Called whenever a torch.* or torch.nn.functional.* method is being called on a storch.Tensor. This wraps
        that method in the deterministic wrapper to properly handle all input arguments and outputs.

        NOTE(review): PyTorch's ``__torch_function__`` protocol recommends defining this hook as a classmethod
        (recent PyTorch versions emit a DeprecationWarning for plain-method definitions). ``self`` is not used
        here, so converting it looks safe — confirm no subclass calls ``Tensor.__torch_function__(inst, ...)``
        explicitly before changing it.

        :param func: The torch function being called.
        :param types: Types of the arguments that implement ``__torch_function__``; not used here.
        :param args: Positional arguments for ``func``.
        :param kwargs: Keyword arguments for ``func``; ``None`` is treated as no keyword arguments.
        :return: ``func(*args, **kwargs)`` for functions named in ``excluded_methods``; otherwise the result of
            ``storch.wrappers._handle_deterministic`` for ``func``.
        :raises IllegalStorchExposeError: if the name of ``func`` is listed in ``exception_methods``.
        """
        if kwargs is None:
            kwargs = {}
        func_name = func.__name__
        if func_name in exception_methods:
            raise IllegalStorchExposeError(
                f"Calling method {func_name} with storch tensors is not allowed."
            )
        if func_name in excluded_methods:
            # Excluded functions bypass the deterministic wrapper and receive the arguments as-is.
            return func(*args, **kwargs)

        if func_name in expand_methods:
            # Automatically expand empty plate dimensions. This is necessary for some loss functions, which
            # assume both inputs have exactly the same elements.
            return storch.wrappers._handle_deterministic(
                func, args, kwargs, expand_plates=True
            )

        return storch.wrappers._handle_deterministic(func, args, kwargs)
示例#2
0
 def wrapper(*args, **kwargs):
     """
     Call the wrapped ``fn`` after verifying that none of its arguments is a storch.Tensor.

     Both positional and keyword arguments are inspected, so a storch.Tensor cannot slip past the
     guard by being passed as a keyword argument.

     :raises IllegalStorchExposeError: if any positional or keyword argument is a storch.Tensor.
     """
     for a in (*args, *kwargs.values()):
         if isinstance(a, storch.Tensor):
             raise IllegalStorchExposeError(
                 "It is not allowed to call this method using storch.Tensor, likely "
                 "because it exposes its wrapped tensor to Python.")
     return fn(*args, **kwargs)
示例#3
0
    def __getattr__(self, item):
        """
        Called whenever an attribute is accessed on a storch.Tensor object that is not directly implemented by
        storch.Tensor. The lookup is deferred to torch.Tensor:

        - If torch.Tensor implements a callable named ``item``, it is wrapped with a deterministic wrapper bound
          to this storch.Tensor. Callables whose name is in ``excluded_methods`` are returned unwrapped, and
          those whose name is in ``exception_methods`` are refused.
        - Otherwise (e.g. properties and other non-callable descriptors), the attribute is read from the
          wrapped tensor ``self._tensor``.

        TODO: This should probably filter the methods

        :param item: Name of the attribute being looked up.
        :return: A deterministic-wrapped callable, a raw callable, or the wrapped tensor's attribute value.
        :raises AttributeError: if torch.Tensor has no attribute named ``item``.
        :raises IllegalStorchExposeError: if the callable's name is listed in ``exception_methods``.
        """
        # Raises AttributeError for names torch.Tensor does not define, as __getattr__ is expected to.
        attr = getattr(torch.Tensor, item)
        if not callable(attr):
            # Non-callable attributes must still produce a value; falling through would return None.
            # NOTE(review): this hands back the wrapped tensor's attribute as-is (possibly a plain torch.Tensor);
            # confirm that exposing it unwrapped is intended.
            return getattr(self._tensor, item)

        # Fall back to the looked-up name for callables that do not carry a __name__.
        func_name = getattr(attr, "__name__", item)
        if func_name in exception_methods:
            raise IllegalStorchExposeError(
                f"Calling method {func_name} with storch tensors is not allowed."
            )
        if func_name in excluded_methods:
            return attr
        return storch.wrappers._self_deterministic(attr, self)
示例#4
0
 def __contains__(self, item):
     """Refuse membership tests (``item in tensor``) on a storch tensor.

     Such a test would expose the storch tensor through an ``in`` statement, which is not allowed.

     :raises IllegalStorchExposeError: always.
     """
     message = "It is not allowed to expose storch tensors via in statements."
     raise IllegalStorchExposeError(message)
示例#5
0
 def numpy(self):
     """Block conversion of a storch tensor to a NumPy array.

     The storch tensor has to be unwrapped to a plain torch tensor before it can be used as an ``np.array``.

     :raises IllegalStorchExposeError: always.
     """
     reason = (
         "It is not allowed to convert storch tensors to numpy arrays. Make sure to unwrap "
         "storch tensors to normal torch tensor to use this tensor as a np.array."
     )
     raise IllegalStorchExposeError(reason)
示例#6
0
 def __nonzero__(self) -> builtins.bool:
     """Block truth-value conversion of a storch tensor.

     NOTE(review): ``__nonzero__`` is the Python 2 truth hook; in Python 3, ``bool()`` and ``if`` statements
     call ``__bool__`` instead. Confirm the class also defines ``__bool__`` so this guard is actually enforced.

     :raises IllegalStorchExposeError: always.
     """
     msg = (
         "It is not allowed to convert storch tensors to boolean. Make sure to unwrap "
         "storch tensors to normal torch tensor to use this tensor as a boolean."
     )
     raise IllegalStorchExposeError(msg)
示例#7
0
 def __long__(self):
     """Block conversion of a storch tensor to a long integer.

     NOTE(review): Python 3 has no ``long`` builtin and never invokes ``__long__`` implicitly, so this guard is
     only reached by explicit calls; integer conversion in Python 3 goes through ``__int__``/``__index__``.

     :raises IllegalStorchExposeError: always.
     """
     explanation = (
         "It is not allowed to convert storch tensors to long. Make sure to unwrap "
         "storch tensors to normal torch tensor to use this tensor as a long."
     )
     raise IllegalStorchExposeError(explanation)
示例#8
0
 def __index__(self) -> int:
     """Forbid using a storch tensor where Python needs an integer index (e.g. slicing, ``operator.index``).

     :raises IllegalStorchExposeError: always.
     """
     error_text = "Cannot use storch tensors as index."
     raise IllegalStorchExposeError(error_text)
示例#9
0
 def __and__(self, other):
     """Compute ``self & other`` on the wrapped tensor through ``storch.deterministic``.

     :param other: Right-hand operand; Python ``bool`` values are rejected.
     :return: The result of ``storch.deterministic(self._tensor.__and__)(other)``.
     :raises IllegalStorchExposeError: if ``other`` is a Python ``bool``.
     """
     # Per the error message, a bool operand would expose the underlying tensor as a bool, so it is refused.
     if not isinstance(other, bool):
         bitwise_and = storch.deterministic(self._tensor.__and__)
         return bitwise_and(other)
     raise IllegalStorchExposeError(
         "Calling 'and' with a bool exposes the underlying tensor as a bool."
     )