def test_algorithm_create_transformer_without_completed_training_job(session):
    """``transformer()`` must fail when no training job is attached to the estimator.

    The estimator below never has ``latest_training_job`` assigned, so asking it
    for a transformer is expected to raise ``RuntimeError`` with an explanatory
    message.
    """
    session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE)
    estimator = AlgorithmEstimator(
        algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
        role="SageMakerRole",
        train_instance_type="ml.m4.xlarge",
        train_instance_count=1,
        sagemaker_session=session,
    )

    with pytest.raises(RuntimeError) as error:
        estimator.transformer(instance_count=1, instance_type="ml.m4.xlarge")

    # Check the raised exception itself. ``str(error)`` stringifies the pytest
    # ExceptionInfo wrapper, whose repr-style text may truncate a long message.
    assert "No finished training job found associated with this estimator" in str(error.value)
def test_algorithm_create_transformer_with_product_id(create_model, sagemaker_session):
    """An algorithm described with a ``ProductId`` yields a transformer whose ``env`` is None.

    ``create_model`` is presumably a patched mock injected by a decorator that is
    not visible here (TODO confirm). Its return value is stubbed with a named mock
    model.
    """
    # Deep-copy so the shared module-level response fixture is not mutated.
    described = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE)
    described["ProductId"] = "some-product-id"
    sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=described)

    arn = "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees"
    estimator = AlgorithmEstimator(
        algorithm_arn=arn,
        role="SageMakerRole",
        train_instance_type="ml.m4.xlarge",
        train_instance_count=1,
        sagemaker_session=sagemaker_session,
    )
    # Attach a training job so transformer() has one to build from.
    estimator.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")

    stub_model = Mock()
    stub_model.name = "my-model"
    create_model.return_value = stub_model

    result = estimator.transformer(instance_count=1, instance_type="ml.m4.xlarge")

    assert result.env is None
def test_algorithm_create_transformer(create_model, sagemaker_session):
    """``transformer()`` on a trained estimator builds a Transformer for the created model.

    ``create_model`` is presumably a patched mock injected by a decorator that is
    not visible here (TODO confirm). The test checks that it was invoked and that
    the resulting transformer references the stubbed model's name.
    """
    sagemaker_session.sagemaker_client.describe_algorithm = Mock(
        return_value=DESCRIBE_ALGORITHM_RESPONSE
    )

    arn = "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees"
    estimator = AlgorithmEstimator(
        algorithm_arn=arn,
        role="SageMakerRole",
        train_instance_type="ml.m4.xlarge",
        train_instance_count=1,
        sagemaker_session=sagemaker_session,
    )
    # Attach a training job so transformer() has one to build from.
    estimator.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")

    stub_model = Mock()
    stub_model.name = "my-model"
    create_model.return_value = stub_model

    result = estimator.transformer(instance_count=1, instance_type="ml.m4.xlarge")

    assert isinstance(result, Transformer)
    create_model.assert_called()
    assert result.model_name == "my-model"