def execute(self, context):
    """Create the SageMaker transform job described by ``transform_job_config``.

    When ``model_config`` is set, the model is created first so that the
    transform job can reference it.

    :param context: Airflow task context (not used by this method).
    :return: the response returned by ``SageMakerHook.create_transform_job``.
    :raises AirflowException: if the response's HTTP status code is not 200.
    """
    sagemaker = SageMakerHook(
        sagemaker_conn_id=self.sagemaker_conn_id,
        use_db_config=self.use_db_config,
        region_name=self.region_name,
        check_interval=self.check_interval,
        max_ingestion_time=self.max_ingestion_time,
    )

    if self.model_config:
        # Pass values as logging args so the logger does the interpolation.
        self.log.info(
            "Creating SageMaker Model %s for transform job",
            self.model_config['ModelName'],
        )
        sagemaker.create_model(self.model_config)

    self.log.info(
        "Creating SageMaker transform Job %s.",
        self.transform_job_config['TransformJobName'],
    )
    response = sagemaker.create_transform_job(
        self.transform_job_config,
        wait_for_completion=self.wait_for_completion,
    )
    if response['ResponseMetadata']['HTTPStatusCode'] != 200:
        raise AirflowException(
            'Sagemaker transform Job creation failed: %s' % response)
    return response
def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, mock_check_training):
    """Training with ``print_log=True`` polls through the patched describe.

    The patched describe (presumably ``describe_training_job_with_log``) is
    scripted through three log states. After the run it must have been called
    three times, and the plain ``describe_training_job`` exactly once.
    """
    mock_check_training.return_value = True
    mock_describe.side_effect = [
        (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RETURN, 0),
        (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RETURN, 0),
        (LogState.COMPLETE, DESCRIBE_TRAINING_COMPELETED_RETURN, 0),
    ]

    # SageMaker client stub: creation yields an ARN, describe reports completion.
    sagemaker_client = mock.Mock(**{
        'create_training_job.return_value': test_arn_return,
        'describe_training_job.return_value': DESCRIBE_TRAINING_COMPELETED_RETURN,
    })
    # Log client stub: canned stream listings and log events.
    logs_client = mock.Mock(**{
        'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS,
        'get_log_events.side_effect': STREAM_LOG_EVENTS,
    })
    mock_client.return_value = sagemaker_client
    mock_log_client.return_value = logs_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
    hook.create_training_job(
        create_training_params,
        wait_for_completion=True,
        print_log=True,
        check_interval=1,
    )

    self.assertEqual(mock_describe.call_count, 3)
    self.assertEqual(sagemaker_client.describe_training_job.call_count, 1)
def test_configure_s3_resources(self, mock_load_file, mock_create_bucket):
    """configure_s3_resources rewrites the config and touches S3 once.

    Afterwards the config must contain only Image and Role. The patched
    create_bucket and load_file must each have been called exactly once.
    """
    expected_config = {'Image': image, 'Role': role}
    hook = SageMakerHook()

    hook.configure_s3_resources(test_evaluation_config)

    self.assertEqual(test_evaluation_config, expected_config)
    mock_create_bucket.assert_called_once_with(bucket_name=bucket)
    mock_load_file.assert_called_once_with(path, key, bucket)
def test_describe_training_job_with_logs_complete(self, mock_client, mock_log_client):
    """A job already in LogState.COMPLETE has its inputs echoed back.

    The passed state, last description and last-call timestamp are expected
    to be returned as-is.
    """
    mock_client.return_value = mock.Mock(**{
        'describe_training_job.return_value': DESCRIBE_TRAINING_COMPELETED_RETURN,
    })
    mock_log_client.return_value = mock.Mock(**{
        'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS,
        'get_log_events.side_effect': STREAM_LOG_EVENTS,
    })

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_training_job_with_log(
        job_name=job_name,
        positions={},
        stream_names=[],
        instance_count=1,
        state=LogState.COMPLETE,
        last_description={},
        last_describe_job_call=0,
    )

    self.assertEqual(result, (LogState.COMPLETE, {}, 0))
def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, mock_check_training):
    """Waiting on a training job with log printing drives the log-state machine.

    Three scripted describe results (WAIT_IN_PROGRESS -> JOB_COMPLETE ->
    COMPLETE) are expected to be consumed. The SageMaker client's own describe
    is expected to be hit only once.
    """
    log_states = [
        (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RETURN, 0),
        (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RETURN, 0),
        (LogState.COMPLETE, DESCRIBE_TRAINING_COMPELETED_RETURN, 0),
    ]
    mock_check_training.return_value = True
    mock_describe.side_effect = log_states

    client_stub = mock.Mock(**{
        'create_training_job.return_value': test_arn_return,
        'describe_training_job.return_value': DESCRIBE_TRAINING_COMPELETED_RETURN,
    })
    log_stub = mock.Mock(**{
        'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS,
        'get_log_events.side_effect': STREAM_LOG_EVENTS,
    })
    mock_client.return_value = client_stub
    mock_log_client.return_value = log_stub

    SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1').create_training_job(
        create_training_params,
        wait_for_completion=True,
        print_log=True,
        check_interval=1,
    )

    self.assertEqual(mock_describe.call_count, 3)
    self.assertEqual(client_stub.describe_training_job.call_count, 1)
def test_describe_training_job_with_logs_in_progress(self, mock_time, mock_client, mock_log_client):
    """A WAIT_IN_PROGRESS job whose describe reports completion moves to JOB_COMPLETE.

    The patched clock (presumably ``time.time``) reads 50. That value is
    expected back as the new last-describe timestamp.
    """
    mock_time.return_value = 50
    mock_client.return_value = mock.Mock(**{
        'describe_training_job.return_value': DESCRIBE_TRAINING_COMPELETED_RETURN,
    })
    mock_log_client.return_value = mock.Mock(**{
        'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS,
        'get_log_events.side_effect': STREAM_LOG_EVENTS,
    })

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_training_job_with_log(
        job_name=job_name,
        positions={},
        stream_names=[],
        instance_count=1,
        state=LogState.WAIT_IN_PROGRESS,
        last_description={},
        last_describe_job_call=0,
    )

    self.assertEqual(result, (LogState.JOB_COMPLETE, {}, 50))
def get_sagemaker_response(self):
    """Fetch the latest description of the watched training job.

    With ``print_log`` disabled a plain describe is issued. Otherwise the log
    resources are initialised on first use, and the log-aware describe also
    updates ``state`` and ``last_describe_job_call``. When the job is in a
    status that is neither non-terminal nor failed, the billable seconds are
    logged.

    :return: the most recent job description (also kept on
        ``self.last_description``).
    """
    hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

    if not self.print_log:
        self.last_description = hook.describe_training_job(
            self.job_name)
    else:
        if not self.log_resource_inited:
            self.init_log_resource(hook)
        (self.state,
         self.last_description,
         self.last_describe_job_call) = hook.describe_training_job_with_log(
            self.job_name,
            self.positions,
            self.stream_names,
            self.instance_count,
            self.state,
            self.last_description,
            self.last_describe_job_call,
        )

    status = self.state_from_response(self.last_description)
    finished_ok = (status not in self.non_terminal_states()
                   and status not in self.failed_states())
    if finished_ok:
        description = self.last_description
        wall_time = description['TrainingEndTime'] - description['TrainingStartTime']
        billable_time = wall_time * description['ResourceConfig']['InstanceCount']
        self.log.info('Billable seconds: %s',
                      int(billable_time.total_seconds()) + 1)

    return self.last_description
def get_sagemaker_response(self):
    """Return the current description of the training job being poked."""
    hook = SageMakerHook(
        aws_conn_id=self.aws_conn_id,
        region_name=self.region_name,
    )
    self.log.info('Poking Sagemaker Training Job %s', self.job_name)
    description = hook.describe_training_job(self.job_name)
    return description
def get_sagemaker_response(self):
    """Return the current description of the tuning job being poked."""
    hook = SageMakerHook(
        aws_conn_id=self.aws_conn_id,
        region_name=self.region_name,
    )
    self.log.info('Poking Sagemaker Tuning Job %s', self.job_name)
    description = hook.describe_tuning_job(self.job_name)
    return description
def test_check_for_url(self, mock_check_bucket, mock_check_key, mock_client):
    """check_for_url raises twice, then succeeds.

    The bucket check is scripted False, True, True and the key check
    False, True. The first two calls must raise AirflowException and the
    third must return True.
    """
    mock_client.return_value = None
    hook = SageMakerHook()
    mock_check_bucket.side_effect = [False, True, True]
    mock_check_key.side_effect = [False, True]

    for _ in range(2):
        with self.assertRaises(AirflowException):
            hook.check_for_url(data_url)
    self.assertEqual(hook.check_for_url(data_url), True)
def test_create_model(self, mock_client):
    """create_model forwards the config as kwargs and returns the client reply."""
    sagemaker_client = mock.Mock(**{'create_model.return_value': test_arn_return})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.create_model(create_model_params)

    sagemaker_client.create_model.assert_called_once_with(**create_model_params)
    self.assertEqual(result, test_arn_return)
def test_describe_endpoint(self, mock_client):
    """describe_endpoint sends EndpointName and returns the client reply."""
    sagemaker_client = mock.Mock(**{'describe_endpoint.return_value': 'InProgress'})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_endpoint(endpoint_name)

    sagemaker_client.describe_endpoint.assert_called_once_with(
        EndpointName=endpoint_name)
    self.assertEqual(result, 'InProgress')
def test_create_endpoint_config(self, mock_client):
    """create_endpoint_config forwards the config as kwargs and returns the reply."""
    sagemaker_client = mock.Mock(
        **{'create_endpoint_config.return_value': test_arn_return})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.create_endpoint_config(create_endpoint_config_params)

    sagemaker_client.create_endpoint_config.assert_called_once_with(
        **create_endpoint_config_params)
    self.assertEqual(result, test_arn_return)
def test_describe_model(self, mock_client):
    """describe_model sends ModelName and returns the client reply."""
    sagemaker_client = mock.Mock(**{'describe_model.return_value': model_name})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_model(model_name)

    sagemaker_client.describe_model.assert_called_once_with(ModelName=model_name)
    self.assertEqual(result, model_name)
def test_configure_s3_resources(self, mock_load_file, mock_create_bucket):
    """After configure_s3_resources only Image and Role remain in the config.

    The patched create_bucket and load_file are each expected to be called
    exactly once, with the module-level bucket/path/key fixtures.
    """
    expected_config = dict(Image=image, Role=role)
    hook = SageMakerHook()

    hook.configure_s3_resources(test_evaluation_config)

    self.assertEqual(test_evaluation_config, expected_config)
    mock_create_bucket.assert_called_once_with(bucket_name=bucket)
    mock_load_file.assert_called_once_with(path, key, bucket)
def test_describe_training_job(self, mock_client):
    """describe_training_job sends TrainingJobName and returns the client reply."""
    sagemaker_client = mock.Mock(
        **{'describe_training_job.return_value': 'InProgress'})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_training_job(job_name)

    sagemaker_client.describe_training_job.assert_called_once_with(
        TrainingJobName=job_name)
    self.assertEqual(result, 'InProgress')
def test_describe_transform_job(self, mock_client):
    """describe_transform_job sends TransformJobName and returns the client reply."""
    sagemaker_client = mock.Mock(
        **{'describe_transform_job.return_value': 'InProgress'})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_transform_job(job_name)

    sagemaker_client.describe_transform_job.assert_called_once_with(
        TransformJobName=job_name)
    self.assertEqual(result, 'InProgress')
def test_describe_endpoint_config(self, mock_client):
    """describe_endpoint_config sends EndpointConfigName and returns the reply."""
    sagemaker_client = mock.Mock(
        **{'describe_endpoint_config.return_value': config_name})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.describe_endpoint_config(config_name)

    sagemaker_client.describe_endpoint_config.assert_called_once_with(
        EndpointConfigName=config_name)
    self.assertEqual(result, config_name)
def test_update_endpoint(self, mock_client):
    """update_endpoint (no waiting) forwards the params and returns the reply."""
    sagemaker_client = mock.Mock(**{'update_endpoint.return_value': test_arn_return})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.update_endpoint(update_endpoint_params,
                                  wait_for_completion=False)

    sagemaker_client.update_endpoint.assert_called_once_with(
        **update_endpoint_params)
    self.assertEqual(result, test_arn_return)
def test_create_transform_job(self, mock_client, mock_check_url):
    """create_transform_job (no waiting) forwards the params and returns the reply."""
    mock_check_url.return_value = True
    sagemaker_client = mock.Mock(
        **{'create_transform_job.return_value': test_arn_return})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.create_transform_job(create_transform_params,
                                       wait_for_completion=False)

    sagemaker_client.create_transform_job.assert_called_once_with(
        **create_transform_params)
    self.assertEqual(result, test_arn_return)
def test_create_tuning_job(self, mock_client, mock_check_tuning):
    """create_tuning_job (no waiting) forwards the params and returns the reply."""
    sagemaker_client = mock.Mock(
        **{'create_hyper_parameter_tuning_job.return_value': test_arn_return})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.create_tuning_job(create_tuning_params,
                                    wait_for_completion=False)

    sagemaker_client.create_hyper_parameter_tuning_job.assert_called_once_with(
        **create_tuning_params)
    self.assertEqual(result, test_arn_return)
def test_create_tuning_job(self, mock_client, mock_check_tuning):
    """create_tuning_job forwards the params as kwargs and returns the reply."""
    sagemaker_client = mock.Mock(
        **{'create_hyper_parameter_tuning_job.return_value': test_arn_return})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id')
    result = hook.create_tuning_job(create_tuning_params)

    sagemaker_client.create_hyper_parameter_tuning_job.assert_called_once_with(
        **create_tuning_params)
    self.assertEqual(result, test_arn_return)
def test_list_tuning_job(self, mock_client):
    """list_tuning_job maps its filters to NameContains/StatusEquals."""
    sagemaker_client = mock.Mock(**{
        'list_hyper_parameter_tuning_job.return_value': test_list_tuning_job_return,
    })
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id')
    result = hook.list_tuning_job(name_contains=job_name,
                                  status_equals='InProgress')

    sagemaker_client.list_hyper_parameter_tuning_job.assert_called_once_with(
        NameContains=job_name, StatusEquals='InProgress')
    self.assertEqual(result, test_list_tuning_job_return)
def test_list_training_job(self, mock_client):
    """list_training_job maps its filters to NameContains/StatusEquals."""
    sagemaker_client = mock.Mock(**{
        'list_training_jobs.return_value': test_list_training_job_return,
    })
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id')
    result = hook.list_training_job(name_contains=job_name,
                                    status_equals='InProgress')

    sagemaker_client.list_training_jobs.assert_called_once_with(
        NameContains=job_name, StatusEquals='InProgress')
    self.assertEqual(result, test_list_training_job_return)
def preprocess_config(self):
    """Prepare ``self.config`` before it is submitted to SageMaker.

    This method:

    * creates a ``SageMakerHook`` and stores it on ``self.hook``;
    * lets the hook configure the S3 resources referenced by the config;
    * coerces the configured integer fields (``parse_config_integers``);
    * calls ``expand_role``;
    * logs the resulting config as pretty-printed JSON.
    """
    self.log.info(
        'Preprocessing the config and doing required s3_operations'
    )
    self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

    self.hook.configure_s3_resources(self.config)
    self.parse_config_integers()
    self.expand_role()

    # Hand the JSON to the logger as an argument instead of pre-formatting
    # the message with str.format; the logger performs the interpolation.
    self.log.info(
        'After preprocessing the config is:\n %s',
        json.dumps(self.config, sort_keys=True, indent=4,
                   separators=(',', ': ')),
    )
def test_create_training_job_db_config(self, mock_client, mock_check_training):
    """With use_db_config=True, db_config is merged into the training request."""
    mock_check_training.return_value = True
    sagemaker_client = mock.Mock(
        **{'create_training_job.return_value': test_arn_return})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id',
                         use_db_config=True)
    result = hook.create_training_job(create_training_params,
                                      wait_for_completion=False)

    expected_request = copy.deepcopy(create_training_params)
    expected_request.update(db_config)
    sagemaker_client.create_training_job.assert_called_once_with(**expected_request)
    self.assertEqual(result, test_arn_return)
def test_create_transform_job_db_config(self, mock_client, mock_check_url):
    """With use_db_config=True, db_config is merged into the transform request."""
    mock_check_url.return_value = True
    sagemaker_client = mock.Mock(
        **{'create_transform_job.return_value': test_arn_return})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id',
                         use_db_config=True)
    result = hook.create_transform_job(create_transform_params,
                                       wait_for_completion=False)

    expected_request = copy.deepcopy(create_transform_params)
    expected_request.update(db_config)
    sagemaker_client.create_transform_job.assert_called_once_with(**expected_request)
    self.assertEqual(result, test_arn_return)
def test_check_s3_url(self, mock_check_prefix, mock_check_bucket, mock_check_key, mock_client):
    """check_s3_url raises twice and then succeeds twice.

    The outcomes are driven by the scripted bucket, key and prefix checks.
    """
    mock_client.return_value = None
    hook = SageMakerHook()
    mock_check_bucket.side_effect = [False, True, True, True]
    mock_check_key.side_effect = [False, True, False]
    mock_check_prefix.side_effect = [False, True, True]

    for _ in range(2):
        with self.assertRaises(AirflowException):
            hook.check_s3_url(data_url)
    for _ in range(2):
        self.assertEqual(hook.check_s3_url(data_url), True)
def test_training_ends_with_wait_on(self, mock_client, mock_check_training):
    """Waiting for completion consumes all four scripted describe results."""
    mock_check_training.return_value = True
    describe_results = [
        DESCRIBE_TRAINING_INPROGRESS_RETURN,
        DESCRIBE_TRAINING_STOPPING_RETURN,
        DESCRIBE_TRAINING_STOPPED_RETURN,
        DESCRIBE_TRAINING_COMPELETED_RETURN,
    ]
    sagemaker_client = mock.Mock(**{
        'create_training_job.return_value': test_arn_return,
        'describe_training_job.side_effect': describe_results,
    })
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id_1')
    hook.create_training_job(create_training_params, wait_for_completion=True)

    self.assertEqual(sagemaker_client.describe_training_job.call_count, 4)
def test_create_tuning_job_db_config(self, mock_client, mock_check_tuning):
    """With use_db_config=True, db_config is merged into the tuning request."""
    mock_check_tuning.return_value = True
    sagemaker_client = mock.Mock(
        **{'create_hyper_parameter_tuning_job.return_value': test_arn_return})
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id',
                         use_db_config=True)
    result = hook.create_tuning_job(create_tuning_params)

    expected_request = copy.deepcopy(create_tuning_params)
    expected_request.update(db_config)
    sagemaker_client.create_hyper_parameter_tuning_job.assert_called_once_with(
        **expected_request)
    self.assertEqual(result, test_arn_return)
def test_training_ends_with_wait(self, mock_client, mock_check_training):
    """Waiting without log printing calls the client's describe four times."""
    mock_check_training.return_value = True
    describe_results = [
        DESCRIBE_TRAINING_INPROGRESS_RETURN,
        DESCRIBE_TRAINING_STOPPING_RETURN,
        DESCRIBE_TRAINING_COMPELETED_RETURN,
        DESCRIBE_TRAINING_COMPELETED_RETURN,
    ]
    sagemaker_client = mock.Mock(**{
        'create_training_job.return_value': test_arn_return,
        'describe_training_job.side_effect': describe_results,
    })
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
    hook.create_training_job(
        create_training_params,
        wait_for_completion=True,
        print_log=False,
        check_interval=1,
    )

    self.assertEqual(sagemaker_client.describe_training_job.call_count, 4)
def execute(self, context):
    """Create the SageMaker training job described by ``training_job_config``.

    :param context: Airflow task context (not used by this method).
    :return: the response returned by ``SageMakerHook.create_training_job``.
    :raises AirflowException: if the response's HTTP status code is not 200.
    """
    sagemaker = SageMakerHook(
        sagemaker_conn_id=self.sagemaker_conn_id,
        use_db_config=self.use_db_config,
        region_name=self.region_name,
        check_interval=self.check_interval,
        max_ingestion_time=self.max_ingestion_time,
    )

    # Pass the job name as a logging arg so the logger does the interpolation.
    self.log.info("Creating SageMaker Training Job %s.",
                  self.training_job_config['TrainingJobName'])

    response = sagemaker.create_training_job(
        self.training_job_config,
        wait_for_completion=self.wait_for_completion,
    )
    if response['ResponseMetadata']['HTTPStatusCode'] != 200:
        raise AirflowException(
            'Sagemaker Training Job creation failed: %s' % response)
    return response
def get_sagemaker_response(self):
    """Refresh and return the description of the watched training job.

    If ``print_log`` is set, the log resources are initialised once and the
    log-aware describe call is used. That call also updates ``self.state`` and
    ``self.last_describe_job_call``. Otherwise a plain describe is used. For a
    status outside both the non-terminal and failed sets, the billable seconds
    (duration times instance count, plus one) are logged.

    :return: the latest description, also stored on ``self.last_description``.
    """
    hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

    if self.print_log:
        if not self.log_resource_inited:
            self.init_log_resource(hook)
        log_result = hook.describe_training_job_with_log(
            self.job_name, self.positions, self.stream_names,
            self.instance_count, self.state, self.last_description,
            self.last_describe_job_call,
        )
        self.state, self.last_description, self.last_describe_job_call = log_result
    else:
        self.last_description = hook.describe_training_job(self.job_name)

    status = self.state_from_response(self.last_description)
    if not (status in self.non_terminal_states()
            or status in self.failed_states()):
        desc = self.last_description
        instances = desc['ResourceConfig']['InstanceCount']
        billable_time = (desc['TrainingEndTime'] - desc['TrainingStartTime']) * instances
        self.log.info('Billable seconds: %s',
                      int(billable_time.total_seconds()) + 1)

    return self.last_description
def execute(self, context):
    """Create the SageMaker hyperparameter tuning job from ``tuning_job_config``.

    :param context: Airflow task context (not used by this method).
    :return: the response returned by ``SageMakerHook.create_tuning_job``.
    :raises AirflowException: if the response's HTTP status code is not 200.
    """
    sagemaker = SageMakerHook(
        sagemaker_conn_id=self.sagemaker_conn_id,
        region_name=self.region_name,
        use_db_config=self.use_db_config,
        check_interval=self.check_interval,
        max_ingestion_time=self.max_ingestion_time,
    )

    # Pass the job name as a logging arg so the logger does the interpolation.
    self.log.info(
        "Creating SageMaker Hyper Parameter Tuning Job %s",
        self.tuning_job_config['HyperParameterTuningJobName'],
    )

    response = sagemaker.create_tuning_job(
        self.tuning_job_config,
        wait_for_completion=self.wait_for_completion,
    )
    if response['ResponseMetadata']['HTTPStatusCode'] != 200:
        raise AirflowException(
            "Sagemaker Tuning Job creation failed: %s" % response)
    return response
def test_check_valid_training(self, mock_check_url, mock_client):
    """check_training_config validates the data URL and tolerates no InputDataConfig."""
    mock_client.return_value = None
    hook = SageMakerHook()

    hook.check_training_config(create_training_params)
    mock_check_url.assert_called_once_with(data_url)

    # InputDataConfig is optional: the check must still pass without it.
    params_without_input = create_training_params.copy()
    del params_without_input["InputDataConfig"]
    hook.check_training_config(params_without_input)
def test_training_throws_error_when_failed_with_wait_on(
        self, mock_client, mock_check_training):
    """A job whose final describe reports failure raises AirflowException."""
    mock_check_training.return_value = True
    describe_results = [
        DESCRIBE_TRAINING_INPROGRESS_RETURN,
        DESCRIBE_TRAINING_STOPPING_RETURN,
        DESCRIBE_TRAINING_STOPPED_RETURN,
        DESCRIBE_TRAINING_FAILED_RETURN,
    ]
    sagemaker_client = mock.Mock(**{
        'create_training_job.return_value': test_arn_return,
        'describe_training_job.side_effect': describe_results,
    })
    mock_client.return_value = sagemaker_client

    hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id_1')
    with self.assertRaises(AirflowException):
        hook.create_training_job(create_training_params,
                                 wait_for_completion=True)

    self.assertEqual(sagemaker_client.describe_training_job.call_count, 4)
def get_sagemaker_response(self):
    """Return the current description of the transform job being poked."""
    hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
    self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
    description = hook.describe_transform_job(self.job_name)
    return description
def get_sagemaker_response(self):
    """Return the current description of the endpoint being poked."""
    hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
    self.log.info('Poking Sagemaker Endpoint %s', self.endpoint_name)
    description = hook.describe_endpoint(self.endpoint_name)
    return description
def test_multi_stream_iter(self, mock_log_stream):
    """The first item from multi_stream_iter pairs stream index 0 with its event."""
    first_event = {'timestamp': 1}
    mock_log_stream.side_effect = [iter([first_event]), iter([]), None]

    hook = SageMakerHook()
    merged = hook.multi_stream_iter('log', [None, None, None])

    self.assertEqual(next(merged), (0, first_event))
class SageMakerBaseOperator(BaseOperator):
    """
    This is the base operator for all SageMaker operators.

    :param config: The configuration necessary to start the SageMaker job
        handled by the concrete operator (templated)
    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    """

    template_fields = ['config']
    template_ext = ()
    ui_color = '#ededed'

    # Key paths into ``config`` whose values are coerced to int by
    # parse_config_integers(); empty here, presumably set by subclasses.
    integer_fields = []  # type: Iterable[Iterable[str]]

    @apply_defaults
    def __init__(self,
                 config,
                 aws_conn_id='aws_default',
                 *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.aws_conn_id = aws_conn_id
        self.config = config
        # Set by preprocess_config().
        self.hook = None

    def parse_integer(self, config, field):
        """Convert, in place, the value at key path ``field`` to ``int``.

        Lists encountered along the path are walked element by element, so
        the same path is applied to every item. Keys missing from the config
        are skipped silently.

        :param config: the (sub-)configuration to walk.
        :param field: sequence of keys leading to the value to convert.
        """
        if isinstance(config, list):
            for sub_config in config:
                self.parse_integer(sub_config, field)
            return

        head, tail = field[0], field[1:]
        if head not in config:
            return
        if tail:
            # More keys remain: descend into the nested value.
            self.parse_integer(config[head], tail)
        else:
            config[head] = int(config[head])

    def parse_config_integers(self):
        """Apply parse_integer to ``self.config`` for every path in integer_fields."""
        # Parse the integer fields of training config to integers
        # in case the config is rendered by Jinja and all fields are str
        for field in self.integer_fields:
            self.parse_integer(self.config, field)

    def expand_role(self):
        """No-op in the base class; a hook for subclasses (presumably role expansion)."""
        pass

    def preprocess_config(self):
        """Prepare ``self.config``: S3 setup, integer coercion, role expansion.

        Creates a ``SageMakerHook`` (kept on ``self.hook``) and logs the
        final config as pretty-printed JSON.
        """
        self.log.info(
            'Preprocessing the config and doing required s3_operations'
        )
        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

        self.hook.configure_s3_resources(self.config)
        self.parse_config_integers()
        self.expand_role()

        # Let the logger interpolate instead of pre-formatting with str.format.
        self.log.info(
            'After preprocessing the config is:\n %s',
            json.dumps(self.config, sort_keys=True, indent=4,
                       separators=(',', ': ')),
        )

    def execute(self, context):
        """Must be implemented by concrete SageMaker operators."""
        raise NotImplementedError('Please implement execute() in sub class!')
def test_check_valid_tuning(self, mock_check_url, mock_client):
    """check_tuning_config checks the tuning job's data URL exactly once."""
    mock_client.return_value = None
    tuning_hook = SageMakerHook()

    tuning_hook.check_tuning_config(create_tuning_params)

    mock_check_url.assert_called_once_with(data_url)
class SageMakerBaseOperator(BaseOperator):
    """
    This is the base operator for all SageMaker operators.

    :param config: The configuration necessary to start the SageMaker job
        handled by the concrete operator (templated)
    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    """

    template_fields = ['config']
    template_ext = ()
    ui_color = '#ededed'

    # Key paths into ``config`` whose values are coerced to int by
    # parse_config_integers(); empty here, presumably set by subclasses.
    integer_fields = []

    @apply_defaults
    def __init__(self,
                 config,
                 aws_conn_id='aws_default',
                 *args, **kwargs):
        super(SageMakerBaseOperator, self).__init__(*args, **kwargs)

        self.aws_conn_id = aws_conn_id
        self.config = config
        # Set by preprocess_config().
        self.hook = None

    def parse_integer(self, config, field):
        """Convert, in place, the value at key path ``field`` to ``int``.

        Lists encountered along the path are walked element by element, so
        the same path is applied to every item. Keys missing from the config
        are skipped silently.

        :param config: the (sub-)configuration to walk.
        :param field: sequence of keys leading to the value to convert.
        """
        if isinstance(config, list):
            for sub_config in config:
                self.parse_integer(sub_config, field)
            return

        head, tail = field[0], field[1:]
        if head not in config:
            return
        if tail:
            # More keys remain: descend into the nested value.
            self.parse_integer(config[head], tail)
        else:
            config[head] = int(config[head])

    def parse_config_integers(self):
        """Apply parse_integer to ``self.config`` for every path in integer_fields."""
        # Parse the integer fields of training config to integers
        # in case the config is rendered by Jinja and all fields are str
        for field in self.integer_fields:
            self.parse_integer(self.config, field)

    def expand_role(self):
        """No-op in the base class; a hook for subclasses (presumably role expansion)."""
        pass

    def preprocess_config(self):
        """Prepare ``self.config``: S3 setup, integer coercion, role expansion.

        Creates a ``SageMakerHook`` (kept on ``self.hook``) and logs the
        final config as pretty-printed JSON.
        """
        self.log.info(
            'Preprocessing the config and doing required s3_operations'
        )
        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

        self.hook.configure_s3_resources(self.config)
        self.parse_config_integers()
        self.expand_role()

        # Let the logger interpolate instead of pre-formatting with str.format.
        self.log.info(
            'After preprocessing the config is:\n %s',
            json.dumps(self.config, sort_keys=True, indent=4,
                       separators=(',', ': ')),
        )

    def execute(self, context):
        """Must be implemented by concrete SageMaker operators."""
        raise NotImplementedError('Please implement execute() in sub class!')
def test_conn(self, mock_get_client_type):
    """The hook keeps the aws_conn_id it was constructed with."""
    expected_conn_id = 'sagemaker_test_conn_id'
    hook = SageMakerHook(aws_conn_id=expected_conn_id)
    self.assertEqual(hook.aws_conn_id, expected_conn_id)
def test_conn(self, mock_get_client):
    """Constructing the hook stores the conn id and requests a regional client once."""
    conn_id = 'sagemaker_test_conn_id'
    region = 'us-east-1'

    hook = SageMakerHook(sagemaker_conn_id=conn_id, region_name=region)

    self.assertEqual(hook.sagemaker_conn_id, conn_id)
    mock_get_client.assert_called_once_with('sagemaker', region_name=region)