コード例 #1
0
    def test_switch_back_modified_path(self):
        """Round-tripping 2.x -> 1.x keeps user-added path entries.

        Two extra entries are prepended to both PATH and PYTHONPATH before
        switching. Afterwards each variable must start with a TF1 entry. That
        entry is followed by the extra entries and then the original value
        without its own first entry.
        """
        extra_entries = os.pathsep.join(("/bar/foo", "quux/baz"))
        os.environ["PATH"] = os.pathsep.join(
            (extra_entries, self._original_os_path))
        os.environ["PYTHONPATH"] = os.pathsep.join(
            (extra_entries, self._original_os_pythonpath))

        for version in ("2.x", "1.x"):
            _tensorflow_magics._tensorflow_version(version)

        tf1_root = _tensorflow_magics._VERSIONS["1"].path

        self._assert_starts_with(sys.path[0], tf1_root)
        self.assertEqual(sys.path, self._original_sys_path)

        # PATH: TF1 "bin" entry first, then the extra entries, then whatever
        # followed the original leading entry (nothing if there was none).
        os_path_first, os_path_remainder = os.environ["PATH"].split(
            os.pathsep, 1)
        self._assert_starts_with(os_path_first, tf1_root)
        self._assert_ends_with(os_path_first, "bin")
        _, *os_path_rest = self._original_os_path.split(os.pathsep, 1)
        self.assertEqual(os_path_remainder,
                         os.pathsep.join([extra_entries] + os_path_rest))

        # PYTHONPATH: same layout, without the "bin" suffix requirement.
        pythonpath_first, pythonpath_remainder = (
            os.environ["PYTHONPATH"].split(os.pathsep, 1))
        self._assert_starts_with(pythonpath_first, tf1_root)
        _, *pythonpath_rest = self._original_os_pythonpath.split(os.pathsep, 1)
        self.assertEqual(pythonpath_remainder,
                         os.pathsep.join([extra_entries] + pythonpath_rest))
コード例 #2
0
    def test_switch_back_default_path(self):
        """A 2.x -> 1.x round trip restores sys.path, PATH and PYTHONPATH."""
        for version in ("2.x", "1.x"):
            _tensorflow_magics._tensorflow_version(version)

        expected_env = {
            "PATH": self._original_os_path,
            "PYTHONPATH": self._original_os_pythonpath,
        }
        self.assertEqual(sys.path, self._original_sys_path)
        for name, value in expected_env.items():
            self.assertEqual(os.environ[name], value)
コード例 #3
0
  def test_switch_back_no_paths(self):
    """Unset PATH/PYTHONPATH stay empty after a 2.x -> 1.x round trip."""
    for var in ("PATH", "PYTHONPATH"):
      os.environ.pop(var, None)

    for version in ("2.x", "1.x"):
      _tensorflow_magics._tensorflow_version(version)

    self.assertEqual(sys.path, self._original_sys_path)
    for var in ("PATH", "PYTHONPATH"):
      # An unset variable and an empty one are both accepted.
      self.assertEqual(os.environ.get(var, ""), "")
コード例 #4
0
    def test_switch_1x_to_2x_default_path(self):
        """Switching to 2.x drops the first entry of each search path."""
        _tensorflow_magics._tensorflow_version("2.x")

        def without_first(value):
            # Everything after the first separator; "" when there is none.
            return os.pathsep.join(value.split(os.pathsep, 1)[1:])

        self.assertEqual(sys.path, self._original_sys_path[1:])
        self.assertEqual(os.environ["PATH"],
                         without_first(self._original_os_path))
        self.assertEqual(os.environ["PYTHONPATH"],
                         without_first(self._original_os_pythonpath))
コード例 #5
0
    def test_switch_back_with_paths(self):
        """Custom PATH/PYTHONPATH survive a 1.x -> 2.x round trip unchanged."""
        custom_pythonpath = os.pathsep.join(("/foo/bar", "/baz/quux"))
        custom_os_path = os.pathsep.join(("/bar/foo", "/quux/baz"))
        os.environ.update(PYTHONPATH=custom_pythonpath, PATH=custom_os_path)

        for version in ("1.x", "2.x"):
            _tensorflow_magics._tensorflow_version(version)

        self.assertEqual(sys.path, self._original_sys_path)
        self.assertEqual(os.environ["PATH"], custom_os_path)
        self.assertEqual(os.environ["PYTHONPATH"], custom_pythonpath)
コード例 #6
0
  def test_tpu_version_switch(self):
    """Each version switch posts a request for that version.

    `_get_tf_version` is stubbed with `mock.patch.object`, so the stub is
    removed when each `with` block exits. Assigning to the module attribute
    directly would leave it in place for every later test.
    """
    versions = _tensorflow_magics._VERSIONS
    for key, magic_arg in (("2", "2.x"), ("1", "1.x")):
      with mock.patch.object(
          _tensorflow_magics, "_get_tf_version",
          return_value=versions[key].version):
        _tensorflow_magics._tensorflow_version(magic_arg)

    expected = "http://0.0.0.0:8475/requestversion/{}"
    calls = [
        mock.call(expected.format(versions["2"].version)),
        mock.call(expected.format(versions["1"].version)),
    ]
    self.assertEqual(requests.post.mock_calls, calls)
コード例 #7
0
    def test_handle_tf_install_after_setting_version(self):
        """_handle_tf_install leaves an explicitly chosen 1.x setup alone."""
        _tensorflow_magics._tensorflow_version("1.x")
        _tensorflow_magics._handle_tf_install()

        # The magic was invoked, so _handle_tf_install should be a no-op: the
        # TF1 entries installed above must still lead every search path.
        self.assertTrue(_tensorflow_magics._explicitly_set())
        tf1_root = _tensorflow_magics._VERSIONS["1"].path

        self._assert_starts_with(sys.path[0], tf1_root)

        os_path_first = os.environ["PATH"].partition(os.pathsep)[0]
        self._assert_starts_with(os_path_first, tf1_root)
        self._assert_ends_with(os_path_first, "bin")

        pythonpath_first = os.environ["PYTHONPATH"].partition(os.pathsep)[0]
        self._assert_starts_with(pythonpath_first, tf1_root)
コード例 #8
0
  def test_switch_1x_to_2x_no_paths(self):
    """Switching to 2.x with PATH/PYTHONPATH unset adds only TF2 entries."""
    for var in ("PATH", "PYTHONPATH"):
      os.environ.pop(var, None)
    tf2_root = _tensorflow_magics._VERSIONS["2"].path

    _tensorflow_magics._tensorflow_version("2.x")

    # Exactly one TF2 entry is put in front of the original sys.path.
    self.assertEqual(sys.path[1:], self._original_sys_path)
    self._assert_starts_with(sys.path[0], tf2_root)

    # PYTHONPATH consists of a single TF2 entry.
    pythonpath = os.environ["PYTHONPATH"]
    self._assert_starts_with(pythonpath, tf2_root)
    self._assert_len(pythonpath.split(os.pathsep), 1)

    # PATH is a TF2 entry ending in "bin" followed by a separator and an empty
    # remainder.
    os_path_first, os_path_rest = os.environ["PATH"].split(os.pathsep, 1)
    self._assert_starts_with(os_path_first, tf2_root)
    self._assert_ends_with(os_path_first, "bin")
    self.assertEqual(os_path_rest, "")
コード例 #9
0
    def test_switch_1x_to_2x_modified_path(self):
        """Switching to 2.x keeps user-prepended PATH/PYTHONPATH entries.

        Only the first entry of each original value is removed. The extra
        entries stay in front of the rest of that value.
        """
        extra_entries = os.pathsep.join(("/bar/foo", "quux/baz"))
        originals = (("PATH", self._original_os_path),
                     ("PYTHONPATH", self._original_os_pythonpath))
        for var, original in originals:
            os.environ[var] = os.pathsep.join((extra_entries, original))

        _tensorflow_magics._tensorflow_version("2.x")

        self.assertEqual(sys.path, self._original_sys_path[1:])
        for var, original in originals:
            _, *rest = original.split(os.pathsep, 1)
            self.assertEqual(os.environ[var],
                             os.pathsep.join([extra_entries] + rest))
コード例 #10
0
  def test_switch_1x_to_2x_existing_paths(self):
    """Switching to 2.x puts TF2 entries ahead of pre-existing paths."""
    existing_pythonpath = os.pathsep.join(("/foo/bar", "/baz/quux"))
    existing_os_path = os.pathsep.join(("/bar/foo", "/quux/baz"))
    os.environ["PYTHONPATH"] = existing_pythonpath
    os.environ["PATH"] = existing_os_path
    tf2_root = _tensorflow_magics._VERSIONS["2"].path

    _tensorflow_magics._tensorflow_version("2.x")

    self.assertEqual(sys.path[1:], self._original_sys_path)
    self._assert_starts_with(sys.path[0], tf2_root)

    # PYTHONPATH: one TF2 entry, then the untouched pre-existing value.
    head, tail = os.environ["PYTHONPATH"].split(os.pathsep, 1)
    self._assert_starts_with(head, tf2_root)
    self.assertEqual(tail, existing_pythonpath)

    # PATH: one TF2 "bin" entry, then the untouched pre-existing value.
    head, tail = os.environ["PATH"].split(os.pathsep, 1)
    self._assert_starts_with(head, tf2_root)
    self._assert_ends_with(head, "bin")
    self.assertEqual(tail, existing_os_path)
コード例 #11
0
    def test_tpu_version_switch(self):
        """Each version switch posts a request for that version.

        `_get_tf_version` is stubbed with `mock.patch.object`, so the stub is
        removed when each `with` block exits. Assigning to the module
        attribute directly would leave it in place for every later test.
        """
        versions = _tensorflow_magics._VERSIONS
        for key, magic_arg in (("1", "1.x"), ("2", "2.x")):
            with mock.patch.object(_tensorflow_magics, "_get_tf_version",
                                   return_value=versions[key].version):
                _tensorflow_magics._tensorflow_version(magic_arg)

        expected = "http://0.0.0.0:8475/requestversion/{}"
        # TODO(b/151765674): Remove this comment and the initial expected call
        # once the TPU VM defaults to TF2.
        # We expect one additional call from initialization as the TPU VM
        # defaults to TF1 and we default to TF2.
        calls = [
            mock.call(expected.format(versions["2"].version)),
            mock.call(expected.format(versions["1"].version)),
            mock.call(expected.format(versions["2"].version)),
        ]
        self.assertEqual(requests.post.mock_calls, calls)
コード例 #12
0
  def test_switch_back_does_not_import(self):
    """A 2.x -> 1.x round trip must not import tensorflow as a side effect."""
    for version in ("2.x", "1.x"):
      _tensorflow_magics._tensorflow_version(version)

    self.assertNotIn("tensorflow", sys.modules)