Exemplo n.º 1
0
 def test_enable_torch_dispatch_mode_respects_no_dispatch(self) -> None:
     """Check that ``no_dispatch()`` bypasses an active torch dispatch mode.

     A tensor created under the mode must be a ``LoggingTensorMode``
     instance. A reference tensor created inside a nested ``no_dispatch()``
     block must equal that instance's ``.elem``.
     """
     with enable_torch_dispatch_mode(LoggingTensorMode):
         under_mode = torch.ones([2, 3])
         # NOTE(review): this asserts the factory output is an instance of
         # LoggingTensorMode itself, which only holds if LoggingTensorMode is
         # a tensor subclass in this torch version — confirm.
         self.assertIsInstance(under_mode, LoggingTensorMode)
         with no_dispatch():
             reference = torch.ones([2, 3])
             self.assertEqual(under_mode.elem, reference)
Exemplo n.º 2
0
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                """Run ``func`` on unwrapped tensors, re-wrap the outputs, and log the call.

                Args:
                    func: the operator being dispatched.
                    types: participating tensor subclass types (unused here).
                    args: positional arguments; ``NonWrapperSublass`` leaves are
                        replaced by their ``.elem`` before calling ``func``.
                    kwargs: keyword arguments, unwrapped the same way; ``None`` is
                        treated as "no keyword arguments".

                Returns:
                    ``func``'s result with every ``torch.Tensor`` leaf wrapped in
                    ``NonWrapperSublass``.
                """
                # The signature defaults kwargs to None, but ``**None`` raises TypeError.
                if kwargs is None:
                    kwargs = {}

                def unwrap(e):
                    return e.elem if isinstance(e, NonWrapperSublass) else e

                def wrap(e):
                    return NonWrapperSublass(e) if isinstance(e, torch.Tensor) else e

                # no_dispatch is only needed if you use enable_python_mode.
                # It prevents infinite recursion.
                with no_dispatch():
                    rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
                # NOTE(review): the message has no %-placeholders, so args/kwargs/rs
                # are only useful to a handler that reads ``record.args`` directly;
                # a standard formatter would report a formatting error when this
                # record is emitted — confirm a custom handler is installed.
                logging.getLogger("NonWrapperSublass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
                return rs
Exemplo n.º 3
0
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                """Run ``func`` on unwrapped tensors; return ``None`` for ``add`` ops.

                Args:
                    func: the operator being dispatched.
                    types: participating tensor subclass types (unused here).
                    args: positional arguments; ``SubclassWithNone`` leaves are
                        replaced by their ``.elem`` before calling ``func``.
                    kwargs: keyword arguments, unwrapped the same way; ``None`` is
                        treated as "no keyword arguments".

                Returns:
                    ``None`` when ``func.overloadpacket.__name__ == "add"`` (the op
                    is still executed); otherwise ``func``'s result with every
                    ``torch.Tensor`` leaf wrapped in ``SubclassWithNone``.
                """
                # The signature defaults kwargs to None, but ``**None`` raises TypeError.
                if kwargs is None:
                    kwargs = {}

                def unwrap(e):
                    return e.elem if isinstance(e, SubclassWithNone) else e

                def wrap(e):
                    return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e

                # no_dispatch is only needed if you use enable_python_mode.
                # It prevents infinite recursion.
                with no_dispatch():
                    rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
                # The computed result of ``add`` is discarded and None returned instead.
                if func.overloadpacket.__name__ == "add":
                    return None
                return rs
Exemplo n.º 4
0
 def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
     """Handle ``aten.split`` only; reject every other operator.

     For ``aten.split``, calls ``torch.split`` on the positional ``args`` with
     dispatch disabled and converts the result with ``list_type`` (a name
     defined outside this view).

     Raises:
         AssertionError: if ``func`` is not ``torch.ops.aten.split``.
     """
     if func == torch.ops.aten.split:
         # NOTE(review): kwargs is never forwarded, so any keyword arguments
         # are silently dropped — confirm that is intended.
         with no_dispatch():
             pieces = torch.split(*args)
             return list_type(pieces)
     raise AssertionError(f"unrecognized func: {func}")
Exemplo n.º 5
0
 def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
     """Call ``func`` directly on ``args``/``kwargs`` with torch dispatch disabled.

     ``types`` is unused. ``kwargs=None`` is treated as "no keyword arguments".

     Returns:
         Whatever ``func`` returns.
     """
     # The signature defaults kwargs to None, but ``**None`` raises TypeError.
     if kwargs is None:
         kwargs = {}
     with no_dispatch():
         return func(*args, **kwargs)
Exemplo n.º 6
0
 def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
     """Delegate to ``cls._torch_dispatch`` with torch dispatch disabled.

     All four arguments are forwarded positionally and unchanged, and the
     delegate's return value is returned as-is.
     """
     with no_dispatch():
         result = cls._torch_dispatch(func, types, args, kwargs)
     return result