def test_algorithm_enable_network_isolation_no_product_id(session):
    """enable_network_isolation() is False when DescribeAlgorithm has no ProductId."""
    # Work on a private copy so the estimator can never mutate the shared
    # module-level fixture, and make the "no ProductId" precondition explicit
    # instead of relying on the fixture's current contents.
    response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE)
    response.pop("ProductId", None)
    session.sagemaker_client.describe_algorithm = Mock(return_value=response)

    estimator = AlgorithmEstimator(
        algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
        role="SageMakerRole",
        train_instance_type="ml.m4.xlarge",
        train_instance_count=1,
        sagemaker_session=session,
    )

    network_isolation = estimator.enable_network_isolation()
    assert network_isolation is False
def test_algorithm_enable_network_isolation_with_product_id(session):
    """enable_network_isolation() is True once DescribeAlgorithm reports a ProductId."""
    # Deep-copy the shared fixture, then layer the ProductId key on top of it.
    described = {
        **copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE),
        "ProductId": "some-product-id",
    }
    session.sagemaker_client.describe_algorithm = Mock(return_value=described)

    algo_estimator = AlgorithmEstimator(
        algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
        role="SageMakerRole",
        train_instance_type="ml.m4.xlarge",
        train_instance_count=1,
        sagemaker_session=session,
    )

    assert algo_estimator.enable_network_isolation() is True