def test_push_mode_instance_errors(self):
    """Passing an already-built mode instance (not a constructor) raises ValueError."""

    class A(TorchDispatchMode):
        pass

    # ``with a, b:`` is equivalent to nesting ``b`` inside ``a``, so the
    # ValueError raised on entering the mode is still caught by the assertion.
    with self.assertRaisesRegex(ValueError, 'instance of TorchDispatchMode'), \
            push_torch_dispatch_mode(A(inner=None)):
        pass
def test_torch_dispatch_mode_stack(self) -> None:
    """Nested dispatch modes see an op innermost-first ("B" before "A")."""
    logs = []

    class Logger(TorchDispatchMode):
        def __init__(self, name):
            self.name = name

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            # Record which mode handled the op, then run it normally.
            logs.append(self.name)
            return func(*args, **(kwargs or {}))

    x = torch.randn(1)
    outer_ctor = partial(Logger, "A")
    inner_ctor = partial(Logger, "B")
    with push_torch_dispatch_mode(outer_ctor):
        with push_torch_dispatch_mode(inner_ctor):
            x + x  # result is irrelevant; only the logged mode order matters
    self.assertEqual(logs, ["B", "A"])
def test_push_torch_dispatch_mode(self) -> None:
    """A mode pushed as a bare class or as a ``partial`` constructor intercepts ops."""

    class ErrorA(RuntimeError):
        # Dedicated exception type so the assertions below can tell the mode's
        # error apart from any other RuntimeError.
        def __init__(self, msg=None):
            # __init__ must not return a value; just delegate to RuntimeError.
            super().__init__(msg)

    class A(TorchDispatchMode):
        def __init__(self, msg=None):
            self.msg = msg

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            # Fail on every dispatched op so the test can observe the mode is active.
            raise ErrorA(self.msg)

    x = torch.randn(3)

    # The class itself is accepted as the mode constructor.
    with self.assertRaises(ErrorA):
        with push_torch_dispatch_mode(A):
            torch.add(x, x)

    # Arguments bound via ``partial`` reach the mode (the msg shows up in the error).
    with self.assertRaisesRegex(ErrorA, r"partial constructor"):
        with push_torch_dispatch_mode(partial(A, "partial constructor")):
            x + x
def dispatch_trace(
        root: Union[torch.nn.Module, Callable],
        concrete_args: Optional[Tuple[Any, ...]] = None,
        trace_factory_functions: bool = False,
) -> GraphModule:
    """Trace ``root`` with a ``PythonKeyTracer`` and wrap the graph in a ``GraphModule``.

    Args:
        root: module or callable to trace.
        concrete_args: forwarded unchanged to ``tracer.trace``.
            NOTE(review): ``torch.fx.Tracer.trace`` documents ``concrete_args``
            as a dict keyed by argument name; the Tuple annotation may be
            inaccurate — confirm.
        trace_factory_functions: when True, tracing runs with a
            ``ProxyTorchDispatchMode`` bound to the same tracer pushed as the
            dispatch mode (presumably so ops without tensor inputs, i.e.
            factory functions, are recorded as well).

    Returns:
        ``GraphModule(tracer.root, graph, name)``. ``name`` is the class name
        for modules, and ``__name__`` for other callables. Callables without
        ``__name__`` fall back to their class name.
    """
    import contextlib

    tracer = PythonKeyTracer()
    if trace_factory_functions:
        mode_ctx = push_torch_dispatch_mode(
            functools.partial(ProxyTorchDispatchMode, tracer))
    else:
        mode_ctx = contextlib.nullcontext()
    with mode_ctx:
        graph = tracer.trace(root, concrete_args)

    if isinstance(root, torch.nn.Module):
        name = root.__class__.__name__
    else:
        # functools.partial objects and callable instances have no __name__;
        # use their class name instead of raising AttributeError.
        name = getattr(root, "__name__", root.__class__.__name__)
    return GraphModule(tracer.root, graph, name)
def test_push_mode_returns_unrelated(self):
    """A mode constructor that does not return a TorchDispatchMode is rejected."""

    def make_nothing(*, inner):
        # Accepts the keyword-only ``inner`` argument but produces no mode.
        return None

    with self.assertRaisesRegex(ValueError, 'return a TorchDispatchMode'):
        with push_torch_dispatch_mode(make_nothing):
            pass
def capture_logs_with_logging_tensor_mode():
    """Yield the object from ``capture_logs(True)`` while ``LoggingTensorMode`` is pushed.

    The dispatch mode is entered first and exited last, so it is active for
    the whole time logs are being captured.

    NOTE(review): this is a generator. It is presumably wrapped with
    ``contextlib.contextmanager`` by a decorator outside this view — confirm.
    """
    with push_torch_dispatch_mode(LoggingTensorMode):
        with capture_logs(True) as logs:
            yield logs