示例#1
0
    def _check_environment_consistency(self,
                                       worker: base_worker.WorkerBase) -> None:
        """Verify the worker reproduces the caller's process environment.

        Compares the worker's working directory, interpreter path, `sys.path`
        and the file `torch` was imported from against the caller's own
        values, then deletes every name the check created in the worker.
        """
        # The worker has to mirror the caller: if cwd, interpreter or import
        # path differ, imports inside the worker may fail or resolve to the
        # wrong files.
        worker.run("""
            import os
            import sys

            cwd = os.getcwd()
            sys_executable = sys.executable
            sys_path = sys.path
        """)
        caller_env = {
            "cwd": os.getcwd(),
            "sys_executable": sys.executable,
            "sys_path": sys.path,
        }
        for name, expected in caller_env.items():
            self.assertEqual(worker.load(name), expected)

        # `torch` matters most: picking up a different installation would not
        # fail loudly, it would silently give garbage results.
        worker.run("""
            import torch
            torch_file = torch.__file__
        """)
        self.assertEqual(worker.load("torch_file"), torch.__file__)

        injected_names = ("os", "sys", "cwd", "sys_executable", "sys_path",
                          "torch", "torch_file")
        self._subtest_cleanup(worker, injected_names)
示例#2
0
 def _check_load_stmt(self, worker: base_worker.WorkerBase) -> None:
     """Check that `load_stmt` returns the value of an expression run in the worker.

     The expression includes an arithmetic sub-expression and an integer key.
     The result is compared against the same dict built locally, and the check
     ends with an (empty) namespace cleanup.
     """
     expected = {"a": 1 + 3, 2: "b"}
     from_worker = worker.load_stmt('{"a": 1 + 3, 2: "b"}')
     self.assertDictEqual(expected, from_worker)
     self._subtest_cleanup(worker, ())
示例#3
0
    def _check_basic_store_and_load(self,
                                    worker: base_worker.WorkerBase) -> None:
        """Store a value, read it back, delete it, and confirm it is gone.

        Loading the deleted name must raise `NameError`. The check then
        verifies that the worker namespace is clean, as the other `_check_*`
        helpers do.
        """
        worker.store("y", 2)
        self.assertEqual(worker.load("y"), 2)

        worker.run("del y")

        with self.assertRaisesRegex(NameError, "name 'y' is not defined"):
            worker.load("y")

        # `y` was already deleted above, so there is nothing left to remove;
        # this call exists to run the namespace-isolation check.
        self._subtest_cleanup(worker, ())
示例#4
0
 def _test_namespace_isolation(self, worker: base_worker.WorkerBase):
     """Assert the worker's globals contain nothing beyond the permitted names.

     The worker reports each global as ``repr(type(value))`` rather than the
     value itself. Presumably this is so the report can be loaded even when
     a value is not transferable; confirm against the worker implementation.
     Only ``__builtins__`` and the worker implementation namespace are
     allowed. Any other name is reported in the assertion failure with its
     type.
     """
     worker_global_vars: typing.Dict[str, str] = worker.load_stmt(
         r"{k: repr(type(v)) for k, v in globals().items()}")
     permitted = frozenset(
         ("__builtins__", subprocess_rpc.WORKER_IMPL_NAMESPACE))
     leaked = {
         name: type_repr
         for name, type_repr in worker_global_vars.items()
         if name not in permitted
     }
     self.assertDictEqual(leaked, {})
示例#5
0
    def _check_custom_store_and_load(self,
                                     worker: base_worker.WorkerBase) -> None:
        """Check that user-defined class instances cannot cross the boundary.

        Both directions must fail with an "unmarshallable object" ValueError:
        storing a caller-side `CustomClass` instance into the worker, and
        loading back an instance of a class defined inside the worker.
        """
        unmarshallable = "unmarshallable object"

        # Caller -> worker.
        with self.assertRaisesRegex(ValueError, unmarshallable):
            worker.store("my_class", CustomClass())

        # Worker -> caller: define a class of the same name inside the
        # worker, instantiate it there, and try to pull the instance out.
        worker.run("""
            class CustomClass:
                pass

            my_class = CustomClass()
        """)
        with self.assertRaisesRegex(ValueError, unmarshallable):
            worker.load("my_class")

        self._subtest_cleanup(worker, ("my_class", "CustomClass"))
示例#6
0
    def _check_complex_stmts(self, worker: base_worker.WorkerBase) -> None:
        """Run multi-statement snippets in the worker and inspect the results.

        Covers three cases:

        - a function definition whose body contains a blank line;
        - reuse of that worker-side definition in a later `run` call;
        - a worker-side function reading a global that was set with `store`.

        Finishes by deleting the names it created and checking that the
        worker namespace is clean.
        """
        worker.run("""
            def test_fn():
                x = 10
                y = 2

                # Make sure we can handle blank lines.
                return x + y
            z = test_fn()
        """)
        self.assertEqual(worker.load("z"), 12)

        # Ensure definitions persist across invocations: `test_fn`, defined by
        # the previous `run` call, is still callable here.
        worker.run("z = test_fn() + 1")
        self.assertEqual(worker.load("z"), 13)

        # Ensure invocations have access to global variables: `captured_var`
        # is set via `store` and then read (as a module-level global, not a
        # true closure) from inside a function defined by `run`.
        worker.store("captured_var", 5)
        worker.run("""
            def test_fn():
                # Make sure closures work properly
                return captured_var + 1
            z = test_fn()
        """)
        self.assertEqual(worker.load("z"), 6)

        self._subtest_cleanup(worker, ("captured_var", "z", "test_fn"))
示例#7
0
 def _subtest_cleanup(self, worker: base_worker.WorkerBase,
                      test_vars: typing.Tuple[str, ...]) -> None:
     """Delete `test_vars` from the worker, then verify its namespace is clean.

     Each name becomes its own ``del`` statement, and all of them are sent in
     one `run` call. An empty tuple yields an empty string, which is
     presumably a no-op for `run`, so only the isolation check has an effect.
     """
     del_statements = "\n".join(f"del {name}" for name in test_vars)
     worker.run(del_statements)
     self._test_namespace_isolation(worker)
示例#8
0
    def _test_exceptions(self, worker: base_worker.WorkerBase):
        """Check that exceptions raised inside the worker surface in the caller.

        For each snippet, `worker.run` must raise the listed exception type
        with a message matching the given pattern.
        """
        cases = (
            (AssertionError, "False is not True",
             "assert False, 'False is not True'"),
            (ValueError, "Test msg", "raise ValueError('Test msg')"),
        )
        for exc_type, pattern, snippet in cases:
            with self.assertRaisesRegex(exc_type, pattern):
                worker.run(snippet)