def _check_environment_consistency(self, worker: base_worker.WorkerBase) -> None:
    """Verify that the worker reproduces the caller's interpreter environment.

    The worker has to resolve imports exactly as the caller does. A differing
    working directory, interpreter, or ``sys.path`` could make imports fail,
    or make them resolve to different files than the caller would use.
    """
    worker.run("""
        import os
        import sys
        cwd = os.getcwd()
        sys_executable = sys.executable
        sys_path = sys.path
    """)
    expected_environment = (
        ("cwd", os.getcwd()),
        ("sys_executable", sys.executable),
        ("sys_path", sys.path),
    )
    for var_name, caller_value in expected_environment:
        self.assertEqual(worker.load(var_name), caller_value)

    # `torch` gets a dedicated check: picking up a different build than the
    # caller's would quietly produce garbage results instead of an error.
    worker.run("""
        import torch
        torch_file = torch.__file__
    """)
    self.assertEqual(worker.load("torch_file"), torch.__file__)

    names_bound_by_snippets = (
        "os", "sys", "cwd", "sys_executable", "sys_path", "torch", "torch_file",
    )
    self._subtest_cleanup(worker, names_bound_by_snippets)
def _check_load_stmt(self, worker: base_worker.WorkerBase) -> None:
    """Check that ``load_stmt`` evaluates an expression inside the worker.

    The expression mixes a ``str`` key with an ``int`` key and contains
    arithmetic, so the returned dict must match the evaluated value.
    """
    expected = {"a": 1 + 3, 2: "b"}
    evaluated = worker.load_stmt('{"a": 1 + 3, 2: "b"}')
    self.assertDictEqual(expected, evaluated)

    # The expression binds no names, so there is nothing to delete; the call
    # still runs the namespace-isolation check.
    self._subtest_cleanup(worker, ())
def _check_basic_store_and_load(self, worker: base_worker.WorkerBase) -> None:
    """Round-trip a value through the worker and verify deletion is visible.

    Stores ``y``, reads it back, deletes it inside the worker, and checks that
    a subsequent load fails with ``NameError``. Finally verifies that the
    worker namespace is back to its pristine state.
    """
    worker.store("y", 2)
    self.assertEqual(worker.load("y"), 2)

    worker.run("del y")
    with self.assertRaisesRegex(NameError, "name 'y' is not defined"):
        worker.load("y")

    # Like the other checks, finish by confirming namespace isolation.
    # `y` was already deleted above, so there are no names left to remove.
    self._subtest_cleanup(worker, ())
def _test_namespace_isolation(self, worker: base_worker.WorkerBase):
    """Assert that the worker's globals contain only the permitted entries.

    Any other global (e.g. one a check forgot to delete) is reported in the
    assertion failure together with the repr of its type.
    """
    worker_globals: typing.Dict[str, str] = worker.load_stmt(
        r"{k: repr(type(v)) for k, v in globals().items()}")
    permitted = {"__builtins__", subprocess_rpc.WORKER_IMPL_NAMESPACE}
    leftovers = {
        name: type_repr
        for name, type_repr in worker_globals.items()
        if name not in permitted
    }
    self.assertDictEqual(leftovers, {})
def _check_custom_store_and_load(self, worker: base_worker.WorkerBase) -> None:
    """Check that instances of user-defined classes are rejected in both directions.

    Both ``store`` (caller -> worker) and ``load`` (worker -> caller) are
    expected to raise ``ValueError`` mentioning "unmarshallable object".
    Presumably the values cross the boundary via ``marshal``; confirm in the
    worker implementation.
    """
    # Caller -> worker: a host-side instance cannot be stored.
    with self.assertRaisesRegex(ValueError, "unmarshallable object"):
        worker.store("my_class", CustomClass())

    # Worker -> caller: an instance created inside the worker cannot be loaded.
    define_instance = """
        class CustomClass:
            pass

        my_class = CustomClass()
    """
    worker.run(define_instance)
    with self.assertRaisesRegex(ValueError, "unmarshallable object"):
        worker.load("my_class")

    self._subtest_cleanup(worker, ("my_class", "CustomClass"))
def _check_complex_stmts(self, worker: base_worker.WorkerBase) -> None:
    """Exercise multi-line snippets: definitions, persistence, and globals.

    Checks that a snippet may define and call a function (including blank
    lines in its body), and that the function survives into later ``run``
    calls. It also checks that code run in the worker can read a global
    placed there by ``store``.
    """
    define_and_call = """
        def test_fn():
            x = 10
            y = 2

            # Make sure we can handle blank lines.
            return x + y
        z = test_fn()
    """
    worker.run(define_and_call)
    self.assertEqual(worker.load("z"), 12)

    # `test_fn` was defined by the previous run call and must still exist.
    worker.run("z = test_fn() + 1")
    self.assertEqual(worker.load("z"), 13)

    # A global injected via `store` must be visible to worker-side functions.
    worker.store("captured_var", 5)
    read_global = """
        def test_fn():
            # Make sure closures work properly
            return captured_var + 1
        z = test_fn()
    """
    worker.run(read_global)
    self.assertEqual(worker.load("z"), 6)

    self._subtest_cleanup(worker, ("captured_var", "z", "test_fn"))
def _subtest_cleanup(
    self,
    worker: base_worker.WorkerBase,
    test_vars: typing.Tuple[str, ...],
) -> None:
    """Delete a subtest's globals from the worker, then verify isolation.

    Args:
        worker: worker whose namespace the subtest populated.
        test_vars: names to ``del`` inside the worker, one statement per
            line. May be empty, in which case an empty snippet is run.
    """
    cleanup_snippet = "\n".join(map("del {}".format, test_vars))
    worker.run(cleanup_snippet)
    self._test_namespace_isolation(worker)
def _test_exceptions(self, worker: base_worker.WorkerBase):
    """Check that exceptions raised by worker code surface to the caller.

    Each case asserts both the exception type and (by regex) its message.
    """
    cases = (
        (AssertionError, "False is not True", "assert False, 'False is not True'"),
        (ValueError, "Test msg", "raise ValueError('Test msg')"),
    )
    for expected_type, message_pattern, snippet in cases:
        with self.assertRaisesRegex(expected_type, message_pattern):
            worker.run(snippet)