def test_include(
    self, builder_context: PipelineBuilderContext, recommendation_type: RecommendationType, cuda_version: str
) -> None:
    """Check the sieve is registered when no decision type is set and a CUDA version is configured.

    With ``decision_type`` cleared, the builder context is expected to describe an
    adviser pipeline (asserted below). ``should_include`` must then yield at least
    one configuration for the unit.
    """
    builder_context.decision_type = None
    builder_context.recommendation_type = recommendation_type
    runtime_environment = builder_context.project.runtime_environment
    runtime_environment.cuda_version = cuda_version

    assert builder_context.is_adviser_pipeline()
    # Materialize the result; an empty list means the unit was not included.
    included = list(TensorFlowCUDASieve.should_include(builder_context))
    assert included
def test_pre_run(
    self, context: Context, cuda_version: str, expected_tf_1_support: str, expected_tf_2_support: str
) -> None:
    """Check ``pre_run`` clears logged messages and selects the CUDA support tables.

    ``expected_tf_1_support`` and ``expected_tf_2_support`` are names of attributes on
    the unit. After ``pre_run`` the two support slots (presumably for TensorFlow 1.x
    and 2.x) must refer to exactly those attributes.
    """
    sieve = TensorFlowCUDASieve()

    # Before pre_run() both support slots hold the empty sentinel.
    assert sieve._tf_1_cuda_support is sieve._EMPTY
    assert sieve._tf_2_cuda_support is sieve._EMPTY

    # Seed a stale message so we can verify that pre_run() resets the set.
    logged = sieve._messages_logged
    logged.add("foo")
    assert sieve._messages_logged

    context.project.runtime_environment.cuda_version = cuda_version
    with sieve.assigned_context(context):
        sieve.pre_run()

    assert not sieve._messages_logged
    assert sieve._tf_1_cuda_support is getattr(sieve, expected_tf_1_support)
    assert sieve._tf_2_cuda_support is getattr(sieve, expected_tf_2_support)
def test_sieve_multiple(self, context: Context) -> None:
    """Filter several TensorFlow releases in one run against CUDA 10.0.

    Of ``tensorflow-gpu==1.12.0``, ``tensorflow==2.0.0`` and ``tensorflow==1.13.0``,
    only ``tensorflow==2.0.0`` is expected to pass the sieve. It must be yielded as
    the very same object that was fed in.
    """
    context.project.runtime_environment.cuda_version = "10.0"
    source = Source("https://pypi.org/simple")
    pv_1 = PackageVersion(
        name="tensorflow-gpu",
        version="==1.12.0",
        develop=False,
        index=source,
    )
    pv_2 = PackageVersion(
        name="tensorflow",
        version="==2.0.0",
        develop=False,
        index=source,
    )
    pv_3 = PackageVersion(
        name="tensorflow",
        version="==1.13.0",
        develop=False,
        index=source,
    )

    unit = TensorFlowCUDASieve()
    with unit.assigned_context(context):
        unit.pre_run()
        result = list(unit.run((pv for pv in (pv_1, pv_2, pv_3))))

    assert len(result) == 1
    # Identity rather than equality: the sieve passes through the objects it keeps.
    # This matches the other run() tests.
    assert result[0] is pv_2
def test_no_include(
    self,
    builder_context: PipelineBuilderContext,
    recommendation_type: RecommendationType,
    decision_type: DecisionType,
    cuda_version: str,
) -> None:
    """Check the sieve is not registered for the given pipeline configuration.

    The builder context is configured from the parametrized recommendation type,
    decision type and CUDA version. It must describe either an adviser or a
    dependency-monkey pipeline, and ``should_include`` must return ``None``.
    """
    builder_context.decision_type = decision_type
    builder_context.recommendation_type = recommendation_type
    runtime_environment = builder_context.project.runtime_environment
    runtime_environment.cuda_version = cuda_version

    # Sanity check on the fixture: short-circuits exactly like a plain `or`.
    known_pipeline = builder_context.is_adviser_pipeline() or builder_context.is_dependency_monkey_pipeline()
    assert known_pipeline
    assert TensorFlowCUDASieve.should_include(builder_context) is None
def test_unknown_tensorflow(self, context: Context, package_name: str) -> None:
    """Check a TensorFlow release the sieve does not recognise is passed through, not discarded."""
    context.project.runtime_environment.cuda_version = "10.0"
    unknown_release = PackageVersion(
        name=package_name,
        version="==42.30.03",
        develop=False,
        index=Source("https://pypi.org/simple"),
    )

    sieve = TensorFlowCUDASieve()
    with sieve.assigned_context(context):
        sieve.pre_run()
        kept = list(sieve.run(candidate for candidate in (unknown_release,)))

    assert len(kept) == 1
    assert kept[0] is unknown_release, "The pipeline unit should keep the unknown TensorFlow release"
def test_run_no_yield(self, context: Context, pv: Tuple[str, str], cuda_version: str) -> None:
    """Check releases outside the TensorFlow/CUDA support matrix are discarded.

    ``pv`` is a ``(package name, version)`` pair. The package built from it must
    not be yielded by the sieve for the given ``cuda_version``.

    See the official docs for the support matrix:
    https://www.tensorflow.org/install/source#gpu
    """
    package_name, version = pv
    context.project.runtime_environment.cuda_version = cuda_version
    package_version = PackageVersion(
        name=package_name,
        version=f"=={version}",
        develop=False,
        index=Source("https://pypi.org/simple"),
    )

    unit = TensorFlowCUDASieve()
    with unit.assigned_context(context):
        unit.pre_run()
        # Distinct loop name: reusing ``pv`` here would shadow the parameter tuple.
        result = list(unit.run((candidate for candidate in (package_version,))))

    assert result == []
def test_run_yield(self, context: Context, package_name: str, package_version: str, cuda_version: str) -> None:
    """Check releases within the TensorFlow/CUDA support matrix are kept.

    The package ``package_name==package_version`` must be yielded by the sieve,
    as the same object, for the given ``cuda_version``.

    See the official docs for the support matrix:
    https://www.tensorflow.org/install/source#gpu
    """
    context.project.runtime_environment.cuda_version = cuda_version
    # Bind to a new name rather than rebinding the ``str``-annotated parameter.
    candidate = PackageVersion(
        name=package_name,
        version=f"=={package_version}",
        develop=False,
        index=Source("https://pypi.org/simple"),
    )

    unit = TensorFlowCUDASieve()
    with unit.assigned_context(context):
        unit.pre_run()
        result = list(unit.run((pv for pv in (candidate,))))

    assert len(result) == 1
    assert result[0] is candidate