def test_configure_s3_resources(self, mock_load_file, mock_create_bucket):
    """configure_s3_resources must create the bucket and upload the file.

    After the call, the evaluation config is expected to equal a mapping that
    holds only the 'Image' and 'Role' entries.
    """
    sagemaker_hook = SageMakerHook()
    expected_config = {'Image': image, 'Role': role}

    sagemaker_hook.configure_s3_resources(test_evaluation_config)

    self.assertEqual(test_evaluation_config, expected_config)
    mock_create_bucket.assert_called_once_with(bucket_name=bucket)
    mock_load_file.assert_called_once_with(path, key, bucket)
def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, mock_check_training):
    """Training with print_log=True drives the patched describe-with-log call.

    The three queued (state, description, timestamp) tuples must all be
    consumed, while the client's plain describe_training_job is hit once.
    """
    mock_check_training.return_value = True
    mock_describe.side_effect = [
        (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RETURN, 0),
        (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RETURN, 0),
        (LogState.COMPLETE, DESCRIBE_TRAINING_COMPLETED_RETURN, 0),
    ]

    sagemaker_client = mock.Mock()
    sagemaker_client.configure_mock(**{
        'create_training_job.return_value': test_arn_return,
        'describe_training_job.return_value': DESCRIBE_TRAINING_COMPLETED_RETURN,
    })
    mock_client.return_value = sagemaker_client

    logs_client = mock.Mock()
    logs_client.configure_mock(**{
        'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS,
        'get_log_events.side_effect': STREAM_LOG_EVENTS,
    })
    mock_log_client.return_value = logs_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
    hook.create_training_job(
        create_training_params,
        wait_for_completion=True,
        print_log=True,
        check_interval=1,
    )

    self.assertEqual(mock_describe.call_count, 3)
    self.assertEqual(sagemaker_client.describe_training_job.call_count, 1)
def test_describe_training_job_with_logs_complete(self, mock_client, mock_log_client):
    """With state=LogState.COMPLETE the call echoes back (COMPLETE, {}, 0)."""
    sagemaker_client = mock.Mock()
    sagemaker_client.configure_mock(**{
        'describe_training_job.return_value': DESCRIBE_TRAINING_COMPLETED_RETURN,
    })
    mock_client.return_value = sagemaker_client

    logs_client = mock.Mock()
    logs_client.configure_mock(**{
        'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS,
        'get_log_events.side_effect': STREAM_LOG_EVENTS,
    })
    mock_log_client.return_value = logs_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_training_job_with_log(
        job_name=job_name,
        positions={},
        stream_names=[],
        instance_count=1,
        state=LogState.COMPLETE,
        last_description={},
        last_describe_job_call=0,
    )

    self.assertEqual(result, (LogState.COMPLETE, {}, 0))
def get_sagemaker_response(self):
    """Fetch the latest training-job description, tailing logs when requested.

    If ``print_log`` is set, log-tailing state is initialised on first use via
    ``init_log_resource`` and then advanced through
    ``describe_training_job_with_log``. Otherwise the job is described
    directly. When the resulting status is in neither ``non_terminal_states()``
    nor ``failed_states()``, the billable seconds are logged.

    :return: the most recent job description (also kept on ``self.last_description``).
    """
    hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

    if not self.print_log:
        self.last_description = hook.describe_training_job(self.job_name)
    else:
        if not self.log_resource_inited:
            self.init_log_resource(hook)
        (
            self.state,
            self.last_description,
            self.last_describe_job_call,
        ) = hook.describe_training_job_with_log(
            self.job_name,
            self.positions,
            self.stream_names,
            self.instance_count,
            self.state,
            self.last_description,
            self.last_describe_job_call,
        )

    status = self.state_from_response(self.last_description)
    finished_successfully = (
        status not in self.non_terminal_states()
        and status not in self.failed_states()
    )
    if finished_successfully:
        description = self.last_description
        elapsed = description['TrainingEndTime'] - description['TrainingStartTime']
        billable_time = elapsed * description['ResourceConfig']['InstanceCount']
        self.log.info('Billable seconds: %s', int(billable_time.total_seconds()) + 1)

    return self.last_description
def test_training_throws_error_when_failed_with_wait(self, mock_client, mock_check_training):
    """A FAILED description while waiting raises AirflowException.

    Only three of the four queued describe responses may be consumed.
    """
    mock_check_training.return_value = True
    describe_responses = [
        DESCRIBE_TRAINING_INPROGRESS_RETURN,
        DESCRIBE_TRAINING_STOPPING_RETURN,
        DESCRIBE_TRAINING_FAILED_RETURN,
        DESCRIBE_TRAINING_COMPLETED_RETURN,
    ]
    client = mock.Mock()
    client.configure_mock(**{
        'create_training_job.return_value': test_arn_return,
        'describe_training_job.side_effect': describe_responses,
    })
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
    with pytest.raises(AirflowException):
        hook.create_training_job(
            create_training_params,
            wait_for_completion=True,
            print_log=False,
            check_interval=1,
        )

    # The trailing COMPLETED response must never be reached.
    assert client.describe_training_job.call_count == 3
def test_describe_transform_job(self, mock_client):
    """describe_transform_job passes TransformJobName and returns the client result."""
    client = mock.Mock()
    client.configure_mock(**{'describe_transform_job.return_value': 'InProgress'})
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_transform_job(job_name)

    client.describe_transform_job.assert_called_once_with(TransformJobName=job_name)
    assert result == 'InProgress'
def test_describe_endpoint_config(self, mock_client):
    """describe_endpoint_config passes EndpointConfigName and returns the client result."""
    client = mock.Mock()
    client.configure_mock(**{'describe_endpoint_config.return_value': config_name})
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_endpoint_config(config_name)

    client.describe_endpoint_config.assert_called_once_with(EndpointConfigName=config_name)
    self.assertEqual(result, config_name)
def test_create_model(self, mock_client):
    """create_model unpacks its params into the client call and returns its result."""
    client = mock.Mock()
    client.configure_mock(**{'create_model.return_value': test_arn_return})
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    arn = hook.create_model(create_model_params)

    client.create_model.assert_called_once_with(**create_model_params)
    self.assertEqual(arn, test_arn_return)
def test_create_tuning_job(self, mock_client, mock_check_tuning_config):
    """create_tuning_job (no wait) forwards params and returns the client result."""
    client = mock.Mock()
    client.configure_mock(**{
        'create_hyper_parameter_tuning_job.return_value': test_arn_return,
    })
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    arn = hook.create_tuning_job(create_tuning_params, wait_for_completion=False)

    client.create_hyper_parameter_tuning_job.assert_called_once_with(**create_tuning_params)
    self.assertEqual(arn, test_arn_return)
def test_update_endpoint(self, mock_client):
    """update_endpoint (no wait) forwards params and returns the client result."""
    client = mock.Mock()
    client.configure_mock(**{'update_endpoint.return_value': test_arn_return})
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    arn = hook.update_endpoint(update_endpoint_params, wait_for_completion=False)

    client.update_endpoint.assert_called_once_with(**update_endpoint_params)
    assert arn == test_arn_return
def test_create_transform_job_fs(self, mock_client):
    """create_transform_job (no wait) forwards the fs params and returns the client result."""
    client = mock.Mock()
    client.configure_mock(**{'create_transform_job.return_value': test_arn_return})
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    arn = hook.create_transform_job(create_transform_params_fs, wait_for_completion=False)

    client.create_transform_job.assert_called_once_with(**create_transform_params_fs)
    assert arn == test_arn_return
def test_create_endpoint_config(self, mock_client):
    """create_endpoint_config forwards params and returns the client result."""
    client = mock.Mock()
    client.configure_mock(**{'create_endpoint_config.return_value': test_arn_return})
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    arn = hook.create_endpoint_config(create_endpoint_config_params)

    client.create_endpoint_config.assert_called_once_with(**create_endpoint_config_params)
    assert arn == test_arn_return
def test_describe_model(self, mock_client):
    """describe_model passes ModelName and returns the client result."""
    client = mock.Mock()
    client.configure_mock(**{'describe_model.return_value': model_name})
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_model(model_name)

    client.describe_model.assert_called_once_with(ModelName=model_name)
    assert result == model_name
def test_describe_endpoint(self, mock_client):
    """describe_endpoint passes EndpointName and returns the client result."""
    client = mock.Mock()
    client.configure_mock(**{'describe_endpoint.return_value': 'InProgress'})
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_endpoint(endpoint_name)

    client.describe_endpoint.assert_called_once_with(EndpointName=endpoint_name)
    assert result == 'InProgress'
def test_check_s3_url(self, mock_check_prefix, mock_check_bucket, mock_check_key, mock_client):
    """check_s3_url raises twice, then returns True twice, for the queued mock answers."""
    mock_client.return_value = None
    hook = SageMakerHook()
    mock_check_bucket.side_effect = [False, True, True, True]
    mock_check_key.side_effect = [False, True, False]
    mock_check_prefix.side_effect = [False, True, True]

    # Presumably: call 1 = missing bucket, call 2 = bucket but no key/prefix.
    for _ in range(2):
        with self.assertRaises(AirflowException):
            hook.check_s3_url(data_url)

    # Presumably: call 3 = key found, call 4 = prefix found.
    for _ in range(2):
        self.assertEqual(hook.check_s3_url(data_url), True)
def test_describe_tuning_job(self, mock_client):
    """describe_tuning_job passes HyperParameterTuningJobName and returns the client result."""
    client = mock.Mock()
    client.configure_mock(**{
        'describe_hyper_parameter_tuning_job.return_value': 'InProgress',
    })
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_tuning_job(job_name)

    describe_call = client.describe_hyper_parameter_tuning_job
    describe_call.assert_called_once_with(HyperParameterTuningJobName=job_name)
    self.assertEqual(result, 'InProgress')
def test_create_training_job(self, mock_client, mock_check_training):
    """create_training_job (no wait, no logs) forwards params and returns the client result."""
    mock_check_training.return_value = True
    client = mock.Mock()
    client.configure_mock(**{'create_training_job.return_value': test_arn_return})
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    arn = hook.create_training_job(
        create_training_params,
        wait_for_completion=False,
        print_log=False,
    )

    client.create_training_job.assert_called_once_with(**create_training_params)
    assert arn == test_arn_return
def preprocess_config(self):
    """Process the config into a usable form.

    Creates a SageMakerHook (stored on ``self.hook``) and asks it to configure
    the S3 resources referenced by ``self.config``. Then calls
    ``parse_config_integers`` and ``expand_role``, and finally logs the
    resulting config as pretty-printed JSON.
    """
    self.log.info('Preprocessing the config and doing required s3_operations')
    self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
    self.hook.configure_s3_resources(self.config)
    self.parse_config_integers()
    self.expand_role()
    # Hand the JSON to the logger as an argument instead of pre-formatting the
    # message with str.format(). The template stays constant and interpolation
    # is left to the logging framework.
    self.log.info(
        "After preprocessing the config is:\n %s",
        json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': ')),
    )
def preprocess_config(self):
    """Turn ``self.config`` into the form expected by SageMaker.

    Builds a SageMakerHook (kept on ``self.hook``) and lets it configure the S3
    resources the config refers to. Then runs ``parse_config_integers`` and
    ``expand_role``, and logs the final config as sorted, indented JSON.
    """
    self.log.info('Preprocessing the config and doing required s3_operations')

    self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
    self.hook.configure_s3_resources(self.config)
    self.parse_config_integers()
    self.expand_role()

    pretty_config = json.dumps(
        self.config,
        sort_keys=True,
        indent=4,
        separators=(",", ": "),
    )
    self.log.info("After preprocessing the config is:\n %s", pretty_config)
def get_hook(self) -> SageMakerHook:
    """Return ``self.hook``, lazily creating a SageMakerHook when it is unset/falsy."""
    if not self.hook:
        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
    return self.hook
def test_training_ends_with_wait(self, mock_client, mock_check_training):
    """Waiting for completion must consume all four queued describe responses."""
    mock_check_training.return_value = True
    # NOTE(review): this test references DESCRIBE_TRAINING_COMPELETED_RETURN
    # (sic). Confirm the module defines that spelling and not
    # DESCRIBE_TRAINING_COMPLETED_RETURN.
    describe_responses = [
        DESCRIBE_TRAINING_INPROGRESS_RETURN,
        DESCRIBE_TRAINING_STOPPING_RETURN,
        DESCRIBE_TRAINING_COMPELETED_RETURN,
        DESCRIBE_TRAINING_COMPELETED_RETURN,
    ]
    client = mock.Mock()
    client.configure_mock(**{
        'create_training_job.return_value': test_arn_return,
        'describe_training_job.side_effect': describe_responses,
    })
    mock_client.return_value = client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
    hook.create_training_job(
        create_training_params,
        wait_for_completion=True,
        print_log=False,
        check_interval=1,
    )

    self.assertEqual(client.describe_training_job.call_count, 4)
def init_log_resource(self, hook: SageMakerHook) -> None:
    """Prime the log-tailing state from the job's current description.

    Stores the instance count and the description itself, and timestamps the
    describe call with ``time.monotonic()``. The state starts at
    ``LogState.TAILING`` while the status is in ``non_terminal_states()``,
    otherwise at ``LogState.COMPLETE``.
    """
    job_description = hook.describe_training_job(self.job_name)
    self.instance_count = job_description['ResourceConfig']['InstanceCount']

    if job_description['TrainingJobStatus'] in self.non_terminal_states():
        self.state = LogState.TAILING
    else:
        self.state = LogState.COMPLETE

    self.last_description = job_description
    self.last_describe_job_call = time.monotonic()
    self.log_resource_inited = True
def test_check_s3_url(self, mock_check_prefix, mock_check_bucket, mock_check_key, mock_client):
    """Two queued failures raise AirflowException; the next two calls return True."""
    mock_client.return_value = None
    hook = SageMakerHook()
    mock_check_bucket.side_effect = [False, True, True, True]
    mock_check_key.side_effect = [False, True, False]
    mock_check_prefix.side_effect = [False, True, True]

    for _ in range(2):
        with pytest.raises(AirflowException):
            hook.check_s3_url(data_url)

    for _ in range(2):
        assert hook.check_s3_url(data_url) is True
def test_check_valid_training(self, mock_check_url, mock_client):
    """check_training_config validates the data URL and tolerates a missing InputDataConfig."""
    mock_client.return_value = None
    hook = SageMakerHook()

    hook.check_training_config(create_training_params)
    mock_check_url.assert_called_once_with(data_url)

    # InputDataConfig is optional: validation must also succeed without it.
    params_without_input = create_training_params.copy()
    del params_without_input["InputDataConfig"]
    hook.check_training_config(params_without_input)
def hook(self):
    """Construct a SageMakerHook bound to this object's ``aws_conn_id``."""
    conn_id = self.aws_conn_id
    return SageMakerHook(aws_conn_id=conn_id)
class SageMakerBaseOperator(BaseOperator):
    """
    Common base for every SageMaker operator.

    Subclasses must implement ``execute``. They may also list key paths in
    ``integer_fields`` whose values in ``config`` get coerced to ``int``
    during preprocessing.

    :param config: The configuration necessary to start a training job (templated)
    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    """

    template_fields = ['config']
    template_ext = ()
    ui_color = '#ededed'

    integer_fields = []  # type: Iterable[Iterable[str]]

    @apply_defaults
    def __init__(self, config, aws_conn_id='aws_default', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.aws_conn_id = aws_conn_id
        self.config = config
        self.hook = None

    def parse_integer(self, config, field):
        """Recursively coerce the value at key path ``field`` inside ``config`` to int.

        A list met along the way is walked element by element with the same
        remaining path. A key that is absent ends the walk silently.
        """
        is_leaf = len(field) == 1

        if isinstance(config, list):
            for item in config:
                self.parse_integer(item, field)
            return

        key = field[0]
        if key not in config:
            return

        if is_leaf:
            config[key] = int(config[key])
        else:
            self.parse_integer(config[key], field[1:])

    def parse_config_integers(self):
        """
        Coerce every key path in ``integer_fields`` to int within ``self.config``.

        This matters when the config was rendered by Jinja and all fields are str.
        """
        for field in self.integer_fields:
            self.parse_integer(self.config, field)

    def expand_role(self):
        """Placeholder for calling boto3's expand_role(), which expands an IAM role name into an ARN."""

    def preprocess_config(self):
        """Turn ``self.config`` into the form expected by SageMaker.

        Builds a SageMakerHook (kept on ``self.hook``) and lets it configure the
        S3 resources the config refers to. Then coerces the integer fields,
        calls ``expand_role``, and logs the final config as sorted, indented JSON.
        """
        self.log.info('Preprocessing the config and doing required s3_operations')

        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        self.hook.configure_s3_resources(self.config)
        self.parse_config_integers()
        self.expand_role()

        pretty_config = json.dumps(
            self.config,
            sort_keys=True,
            indent=4,
            separators=(",", ": "),
        )
        self.log.info("After preprocessing the config is:\n %s", pretty_config)

    def execute(self, context):
        raise NotImplementedError('Please implement execute() in sub class!')
def get_hook(self):
    """Return ``self.hook``, building a SageMakerHook first if none is set yet."""
    if self.hook:
        return self.hook
    self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
    return self.hook
def test_conn(self, mock_get_client_type):
    """The hook keeps the aws_conn_id it was constructed with."""
    conn_id = 'sagemaker_test_conn_id'
    hook = SageMakerHook(aws_conn_id=conn_id)
    self.assertEqual(hook.aws_conn_id, 'sagemaker_test_conn_id')
def test_check_valid_tuning(self, mock_check_url, mock_client):
    """check_tuning_config validates the data URL exactly once."""
    mock_client.return_value = None
    SageMakerHook().check_tuning_config(create_tuning_params)
    mock_check_url.assert_called_once_with(data_url)
def test_multi_stream_iter(self, mock_log_stream):
    """The first item yielded is (0, event) when only stream 0 has an event."""
    event = {'timestamp': 1}
    mock_log_stream.side_effect = [iter([event]), iter([]), None]

    hook = SageMakerHook()
    streams = hook.multi_stream_iter('log', [None, None, None])

    self.assertEqual(next(streams), (0, event))