コード例 #1
0
ファイル: test_tf_cuda.py プロジェクト: tlegen-k/adviser
    def test_sieve_multiple(self, context: Context) -> None:
        """Check the sieve on a stream of several TensorFlow candidates.

        With CUDA 10.0 configured in the runtime environment, three candidates
        are fed to the unit (tensorflow-gpu 1.12.0, tensorflow 2.0.0 and
        tensorflow 1.13.0); only tensorflow 2.0.0 is expected to be yielded.
        """
        context.project.runtime_environment.cuda_version = "10.0"
        pypi = Source("https://pypi.org/simple")

        # Candidates are built in the same order in which they are sieved.
        specs = (
            ("tensorflow-gpu", "==1.12.0"),
            ("tensorflow", "==2.0.0"),
            ("tensorflow", "==1.13.0"),
        )
        candidates = tuple(
            PackageVersion(name=name, version=version, develop=False, index=pypi) for name, version in specs
        )
        expected = candidates[1]

        sieve = TensorFlowCUDASieve()
        with sieve.assigned_context(context):
            sieve.pre_run()
            kept = list(sieve.run(candidate for candidate in candidates))
            assert len(kept) == 1
            assert kept[0] == expected
コード例 #2
0
ファイル: test_tf_cuda.py プロジェクト: tlegen-k/adviser
    def test_unknown_tensorflow(self, context: Context, package_name: str) -> None:
        """Check that a release labelled as unknown (42.30.03) is passed through by the sieve.

        CUDA 10.0 is configured in the runtime environment; the very same
        package version object that was fed in must be the only item yielded.
        """
        context.project.runtime_environment.cuda_version = "10.0"
        pypi = Source("https://pypi.org/simple")
        unknown_release = PackageVersion(name=package_name, version="==42.30.03", develop=False, index=pypi)

        sieve = TensorFlowCUDASieve()
        with sieve.assigned_context(context):
            sieve.pre_run()
            survivors = list(sieve.run(candidate for candidate in (unknown_release,)))
            assert len(survivors) == 1
            assert survivors[0] is unknown_release, "The pipeline unit should keep the unknown TensorFlow release"
コード例 #3
0
ファイル: test_tf_cuda.py プロジェクト: tlegen-k/adviser
    def test_run_no_yield(self, context: Context, pv: Tuple[str, str], cuda_version: str) -> None:
        """Check that a package outside the CUDA support matrix is filtered out.

        ``pv`` is indexed as a ``(name, version)`` pair and ``cuda_version`` is
        written to the runtime environment before the unit runs; nothing is
        expected to be yielded.

        See the official docs for listing:
           https://www.tensorflow.org/install/source#gpu
        """
        context.project.runtime_environment.cuda_version = cuda_version
        name = pv[0]
        release = pv[1]
        candidate = PackageVersion(
            name=name,
            version=f"=={release}",
            develop=False,
            index=Source("https://pypi.org/simple"),
        )

        sieve = TensorFlowCUDASieve()
        with sieve.assigned_context(context):
            sieve.pre_run()
            survivors = list(sieve.run(item for item in (candidate,)))
            assert not survivors
コード例 #4
0
ファイル: test_tf_cuda.py プロジェクト: tlegen-k/adviser
    def test_run_yield(self, context: Context, package_name: str, package_version: str, cuda_version: str) -> None:
        """Test packages the pipeline unit yields respecting CUDA version used.

        ``cuda_version`` is written to the runtime environment, a single
        package version pinned to ``package_name==package_version`` is fed to
        the unit, and that same object is expected to be the only one yielded.

        The arguments are presumably supplied by a parametrization defined
        outside this view -- confirm against the decorator on this test.

        See the official docs for listing:
           https://www.tensorflow.org/install/source#gpu
        """
        context.project.runtime_environment.cuda_version = cuda_version
        # Use a distinct local name: rebinding the ``package_version: str``
        # parameter to a PackageVersion object contradicts its annotation and
        # makes the name mean two different things within one body.
        candidate = PackageVersion(
            name=package_name,
            version=f"=={package_version}",
            develop=False,
            index=Source("https://pypi.org/simple"),
        )

        unit = TensorFlowCUDASieve()
        with unit.assigned_context(context):
            unit.pre_run()
            result = list(unit.run((pv for pv in (candidate,))))
            assert len(result) == 1
            assert result[0] is candidate
コード例 #5
0
ファイル: test_tf_cuda.py プロジェクト: tlegen-k/adviser
    def test_pre_run(
        self, context: Context, cuda_version: str, expected_tf_1_support: str, expected_tf_2_support: str
    ) -> None:
        """Check the state of the sieve before and after ``pre_run``.

        ``expected_tf_1_support`` and ``expected_tf_2_support`` are attribute
        names looked up on the unit; after ``pre_run`` the two CUDA support
        attributes must be the very objects those names refer to, and the
        logged-messages set must be empty again.
        """
        sieve = TensorFlowCUDASieve()

        # A freshly created unit holds the _EMPTY marker for both TF lines.
        empty_marker = sieve._EMPTY
        assert sieve._tf_1_cuda_support is empty_marker
        assert sieve._tf_2_cuda_support is empty_marker

        # Seed the message set so that its reset by pre_run is observable.
        sieve._messages_logged.add("foo")
        assert sieve._messages_logged

        context.project.runtime_environment.cuda_version = cuda_version

        with sieve.assigned_context(context):
            sieve.pre_run()

        assert not sieve._messages_logged
        expected_tf_1 = getattr(sieve, expected_tf_1_support)
        assert sieve._tf_1_cuda_support is expected_tf_1
        expected_tf_2 = getattr(sieve, expected_tf_2_support)
        assert sieve._tf_2_cuda_support is expected_tf_2