    def test_push_mode_instance_errors(self):
        class A(TorchDispatchMode):
            pass

        with self.assertRaisesRegex(ValueError,
                                    'instance of TorchDispatchMode'):
            with push_torch_dispatch_mode(A(inner=None)):
                pass

    def test_torch_dispatch_mode_stack(self) -> None:
        logs = []

        class Logger(TorchDispatchMode):
            def __init__(self, name):
                self.name = name

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                logs.append(self.name)
                return func(*args, **kwargs)

        x = torch.randn(1)
        with push_torch_dispatch_mode(partial(Logger, "A")):
            with push_torch_dispatch_mode(partial(Logger, "B")):
                x + x
        self.assertEqual(logs, ["B", "A"])

    def test_push_torch_dispatch_mode(self) -> None:
        class ErrorA(RuntimeError):
            def __init__(self, msg=None):
                super().__init__(msg)

        class A(TorchDispatchMode):
            def __init__(self, msg=None):
                self.msg = msg

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                raise ErrorA(self.msg)

        x = torch.randn(3)
        with self.assertRaises(ErrorA):
            with push_torch_dispatch_mode(A):
                torch.add(x, x)

        with self.assertRaisesRegex(ErrorA, r"partial constructor"):
            with push_torch_dispatch_mode(partial(A, "partial constructor")):
                x + x
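
Taken together, these tests pin down the contract of this (since-removed) push_torch_dispatch_mode helper: it expects a constructor for the mode, either the TorchDispatchMode subclass itself or a partial over it, not an already-built instance. A minimal sketch of the accepted pattern, assuming the same names the tests above rely on (torch, partial, TorchDispatchMode, push_torch_dispatch_mode); RecordingMode is a hypothetical mode made up purely for illustration:

# Minimal sketch of the accepted pattern (old push_torch_dispatch_mode API).
# RecordingMode is hypothetical, made up for this illustration.
class RecordingMode(TorchDispatchMode):
    def __init__(self, record):
        self.record = record

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.record.append(str(func))         # note which op was dispatched
        return func(*args, **(kwargs or {}))  # run the real kernel unchanged

record = []
# Pass a constructor (the class or a partial over it), never an instance.
with push_torch_dispatch_mode(partial(RecordingMode, record)):
    torch.randn(2) + 1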
Example #4
def dispatch_trace(
        root: Union[torch.nn.Module, Callable],
        concrete_args: Optional[Tuple[Any, ...]] = None,
        trace_factory_functions: bool = False,
) -> GraphModule:
    tracer = PythonKeyTracer()
    if trace_factory_functions:
        with push_torch_dispatch_mode(functools.partial(ProxyTorchDispatchMode, tracer)):
            graph = tracer.trace(root, concrete_args)
    else:
        graph = tracer.trace(root, concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return GraphModule(tracer.root, graph, name)
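
For context, a hypothetical usage sketch of dispatch_trace, based only on its signature above; add_ones is made up for illustration. With trace_factory_functions=True, ProxyTorchDispatchMode is pushed while tracing, so factory calls such as torch.ones, which take no tensor inputs and would otherwise escape the tracer, are also recorded in the graph.

# Hypothetical usage sketch (add_ones is made up; assumes the imports used above).
import torch

def add_ones(x):
    # torch.ones has no tensor inputs, so it is only captured in the graph
    # when ProxyTorchDispatchMode is active, i.e. trace_factory_functions=True.
    return x + torch.ones(3)

gm = dispatch_trace(add_ones, trace_factory_functions=True)
print(gm.graph)  # the recorded FX graph, including the torch.ones call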
Example #5
    def test_push_mode_returns_unrelated(self):
        with self.assertRaisesRegex(ValueError, 'return a TorchDispatchMode'):
            with push_torch_dispatch_mode(lambda *, inner: None):
                pass
Example #6
@contextlib.contextmanager
def capture_logs_with_logging_tensor_mode():
    # Push LoggingTensorMode for the duration of the block and yield the captured logs.
    with push_torch_dispatch_mode(LoggingTensorMode), capture_logs(True) as logs:
        yield logs
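
A usage sketch for the helper above, assuming LoggingTensorMode and capture_logs behave as in PyTorch's logging-tensor test utilities: because a mode is pushed, every op dispatched inside the block is logged, even on plain tensors that were never wrapped in a logging tensor subclass.

# Usage sketch: plain tensors, no wrapping needed while the mode is active.
import torch

with capture_logs_with_logging_tensor_mode() as logs:
    x = torch.randn(2)
    x + x

# `logs` holds the captured log lines, one per dispatched op; the exact text
# is whatever the capture_logs helper formats.
print("\n".join(logs))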