Пример #1
0
 def test_configure_s3_resources(self, mock_load_file, mock_create_bucket):
     """configure_s3_resources creates the bucket, uploads the file and leaves only Image/Role in the config."""
     expected_config = {'Image': image, 'Role': role}
     sagemaker_hook = SageMakerHook()

     sagemaker_hook.configure_s3_resources(test_evaluation_config)

     self.assertEqual(test_evaluation_config, expected_config)
     mock_create_bucket.assert_called_once_with(bucket_name=bucket)
     mock_load_file.assert_called_once_with(path, key, bucket)
Пример #2
0
 def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, mock_check_training):
     """Waiting on a training job with print_log=True polls the log-aware describe until COMPLETE."""
     mock_check_training.return_value = True
     # Each tuple mirrors describe_training_job_with_log's return value:
     # (log state, job description, last describe-call time).
     mock_describe.side_effect = [
         (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RETURN, 0),
         (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RETURN, 0),
         (LogState.COMPLETE, DESCRIBE_TRAINING_COMPLETED_RETURN, 0),
     ]

     sagemaker_client = mock.Mock(**{
         'create_training_job.return_value': test_arn_return,
         'describe_training_job.return_value': DESCRIBE_TRAINING_COMPLETED_RETURN,
     })
     logs_client = mock.Mock(**{
         'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS,
         'get_log_events.side_effect': STREAM_LOG_EVENTS,
     })
     mock_client.return_value = sagemaker_client
     mock_log_client.return_value = logs_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
     hook.create_training_job(
         create_training_params,
         wait_for_completion=True,
         print_log=True,
         check_interval=1,
     )

     self.assertEqual(mock_describe.call_count, 3)
     self.assertEqual(sagemaker_client.describe_training_job.call_count, 1)
Пример #3
0
 def test_describe_training_job_with_logs_complete(self, mock_client, mock_log_client):
     """With the log state already COMPLETE, the call is expected to return (COMPLETE, {}, 0)."""
     sagemaker_client = mock.Mock(
         **{'describe_training_job.return_value': DESCRIBE_TRAINING_COMPLETED_RETURN}
     )
     logs_client = mock.Mock(**{
         'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS,
         'get_log_events.side_effect': STREAM_LOG_EVENTS,
     })
     mock_client.return_value = sagemaker_client
     mock_log_client.return_value = logs_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.describe_training_job_with_log(
         job_name=job_name,
         positions={},
         stream_names=[],
         instance_count=1,
         state=LogState.COMPLETE,
         last_description={},
         last_describe_job_call=0,
     )

     self.assertEqual(result, (LogState.COMPLETE, {}, 0))
Пример #4
0
    def get_sagemaker_response(self):
        """
        Fetch the latest training job description, tailing the job's logs when ``print_log`` is set.

        Always refreshes ``self.last_description``; when tailing logs it also refreshes
        ``self.state`` and ``self.last_describe_job_call``. If the resulting status is
        neither non-terminal nor failed, the billable seconds are logged.

        :return: the most recent training job description
        """
        hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

        if not self.print_log:
            self.last_description = hook.describe_training_job(self.job_name)
        else:
            # Log-tailing bookkeeping is set up lazily on the first poll.
            if not self.log_resource_inited:
                self.init_log_resource(hook)
            log_poll = hook.describe_training_job_with_log(
                self.job_name,
                self.positions,
                self.stream_names,
                self.instance_count,
                self.state,
                self.last_description,
                self.last_describe_job_call,
            )
            self.state, self.last_description, self.last_describe_job_call = log_poll

        description = self.last_description
        status = self.state_from_response(description)
        if status not in self.non_terminal_states() and status not in self.failed_states():
            # Billable time is the training wall-clock duration times the number of instances.
            duration = description['TrainingEndTime'] - description['TrainingStartTime']
            billable_time = duration * description['ResourceConfig']['InstanceCount']
            self.log.info('Billable seconds: %s', int(billable_time.total_seconds()) + 1)

        return self.last_description
Пример #5
0
 def test_training_throws_error_when_failed_with_wait(self, mock_client, mock_check_training):
     """A FAILED status while waiting for completion raises AirflowException and stops polling."""
     mock_check_training.return_value = True
     # Successive describe_training_job calls return these in order; the test expects
     # polling to stop at the FAILED entry, so the final COMPLETED one is never consumed.
     describe_sequence = [
         DESCRIBE_TRAINING_INPROGRESS_RETURN,
         DESCRIBE_TRAINING_STOPPING_RETURN,
         DESCRIBE_TRAINING_FAILED_RETURN,
         DESCRIBE_TRAINING_COMPLETED_RETURN,
     ]
     sagemaker_client = mock.Mock(**{
         'create_training_job.return_value': test_arn_return,
         'describe_training_job.side_effect': describe_sequence,
     })
     mock_client.return_value = sagemaker_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
     with pytest.raises(AirflowException):
         hook.create_training_job(
             create_training_params,
             wait_for_completion=True,
             print_log=False,
             check_interval=1,
         )

     assert sagemaker_client.describe_training_job.call_count == 3
Пример #6
0
 def test_describe_transform_job(self, mock_client):
     """describe_transform_job forwards the job name and returns the client's response."""
     sagemaker_client = mock.Mock(**{'describe_transform_job.return_value': 'InProgress'})
     mock_client.return_value = sagemaker_client

     result = SageMakerHook(aws_conn_id='sagemaker_test_conn_id').describe_transform_job(job_name)

     sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=job_name)
     assert result == 'InProgress'
Пример #7
0
 def test_describe_endpoint_config(self, mock_client):
     """describe_endpoint_config forwards the config name and returns the client's response."""
     sagemaker_client = mock.Mock(**{'describe_endpoint_config.return_value': config_name})
     mock_client.return_value = sagemaker_client

     result = SageMakerHook(aws_conn_id='sagemaker_test_conn_id').describe_endpoint_config(config_name)

     sagemaker_client.describe_endpoint_config.assert_called_once_with(EndpointConfigName=config_name)
     self.assertEqual(result, config_name)
Пример #8
0
 def test_create_model(self, mock_client):
     """create_model unpacks the params into the client call and returns its result."""
     sagemaker_client = mock.Mock(**{'create_model.return_value': test_arn_return})
     mock_client.return_value = sagemaker_client

     result = SageMakerHook(aws_conn_id='sagemaker_test_conn_id').create_model(create_model_params)

     sagemaker_client.create_model.assert_called_once_with(**create_model_params)
     self.assertEqual(result, test_arn_return)
Пример #9
0
 def test_create_tuning_job(self, mock_client, mock_check_tuning_config):
     """create_tuning_job (without waiting) unpacks the params into the client call and returns its result."""
     sagemaker_client = mock.Mock(
         **{'create_hyper_parameter_tuning_job.return_value': test_arn_return}
     )
     mock_client.return_value = sagemaker_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.create_tuning_job(create_tuning_params, wait_for_completion=False)

     sagemaker_client.create_hyper_parameter_tuning_job.assert_called_once_with(**create_tuning_params)
     self.assertEqual(result, test_arn_return)
Пример #10
0
 def test_update_endpoint(self, mock_client):
     """update_endpoint (without waiting) unpacks the params into the client call and returns its result."""
     sagemaker_client = mock.Mock(**{'update_endpoint.return_value': test_arn_return})
     mock_client.return_value = sagemaker_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.update_endpoint(update_endpoint_params, wait_for_completion=False)

     sagemaker_client.update_endpoint.assert_called_once_with(**update_endpoint_params)
     assert result == test_arn_return
Пример #11
0
 def test_create_transform_job_fs(self, mock_client):
     """create_transform_job (without waiting) unpacks the *_fs params into the client call."""
     sagemaker_client = mock.Mock(**{'create_transform_job.return_value': test_arn_return})
     mock_client.return_value = sagemaker_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.create_transform_job(create_transform_params_fs, wait_for_completion=False)

     sagemaker_client.create_transform_job.assert_called_once_with(**create_transform_params_fs)
     assert result == test_arn_return
Пример #12
0
 def test_create_endpoint_config(self, mock_client):
     """create_endpoint_config unpacks the params into the client call and returns its result."""
     sagemaker_client = mock.Mock(**{'create_endpoint_config.return_value': test_arn_return})
     mock_client.return_value = sagemaker_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.create_endpoint_config(create_endpoint_config_params)

     sagemaker_client.create_endpoint_config.assert_called_once_with(**create_endpoint_config_params)
     assert result == test_arn_return
Пример #13
0
 def test_describe_model(self, mock_client):
     """describe_model forwards the model name and returns the client's response."""
     sagemaker_client = mock.Mock(**{'describe_model.return_value': model_name})
     mock_client.return_value = sagemaker_client

     result = SageMakerHook(aws_conn_id='sagemaker_test_conn_id').describe_model(model_name)

     sagemaker_client.describe_model.assert_called_once_with(ModelName=model_name)
     assert result == model_name
Пример #14
0
 def test_describe_endpoint(self, mock_client):
     """describe_endpoint forwards the endpoint name and returns the client's response."""
     sagemaker_client = mock.Mock(**{'describe_endpoint.return_value': 'InProgress'})
     mock_client.return_value = sagemaker_client

     result = SageMakerHook(aws_conn_id='sagemaker_test_conn_id').describe_endpoint(endpoint_name)

     sagemaker_client.describe_endpoint.assert_called_once_with(EndpointName=endpoint_name)
     assert result == 'InProgress'
Пример #15
0
 def test_check_s3_url(self, mock_check_prefix, mock_check_bucket, mock_check_key, mock_client):
     """check_s3_url raises twice, then returns True twice, driven by the mocked bucket/key/prefix answers."""
     mock_client.return_value = None
     hook = SageMakerHook()
     # Each side_effect list is consumed one item per call to the corresponding mock.
     mock_check_bucket.side_effect = [False, True, True, True]
     mock_check_key.side_effect = [False, True, False]
     mock_check_prefix.side_effect = [False, True, True]

     for _ in range(2):
         with self.assertRaises(AirflowException):
             hook.check_s3_url(data_url)
     for _ in range(2):
         self.assertEqual(hook.check_s3_url(data_url), True)
Пример #16
0
 def test_describe_tuning_job(self, mock_client):
     """describe_tuning_job forwards the job name and returns the client's response."""
     sagemaker_client = mock.Mock(
         **{'describe_hyper_parameter_tuning_job.return_value': 'InProgress'}
     )
     mock_client.return_value = sagemaker_client

     result = SageMakerHook(aws_conn_id='sagemaker_test_conn_id').describe_tuning_job(job_name)

     describe_call = sagemaker_client.describe_hyper_parameter_tuning_job
     describe_call.assert_called_once_with(HyperParameterTuningJobName=job_name)
     self.assertEqual(result, 'InProgress')
Пример #17
0
 def test_create_training_job(self, mock_client, mock_check_training):
     """create_training_job (no wait, no logs) unpacks the params into the client call and returns its result."""
     mock_check_training.return_value = True
     sagemaker_client = mock.Mock(**{'create_training_job.return_value': test_arn_return})
     mock_client.return_value = sagemaker_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.create_training_job(create_training_params, wait_for_completion=False, print_log=False)

     sagemaker_client.create_training_job.assert_called_once_with(**create_training_params)
     assert result == test_arn_return
Пример #18
0
    def preprocess_config(self):
        """
        Prepare ``self.config`` for submission to SageMaker.

        Creates a SageMakerHook for ``self.aws_conn_id`` (stored on ``self.hook``), lets it
        perform the S3 operations described by the config, converts integer fields via
        ``parse_config_integers()``, calls ``expand_role()``, and finally logs the resulting
        config as pretty-printed JSON.
        """
        self.log.info('Preprocessing the config and doing required s3_operations')
        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

        self.hook.configure_s3_resources(self.config)
        self.parse_config_integers()
        self.expand_role()

        # Pass the JSON as a %-style logging argument instead of pre-building the message with
        # str.format, so message interpolation is left to the logging framework (only performed
        # if the record is actually emitted).
        self.log.info(
            'After preprocessing the config is:\n %s',
            json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': ')),
        )
    def preprocess_config(self):
        """
        Normalise ``self.config`` before it is sent to SageMaker, then log it.

        Steps: build the hook for ``self.aws_conn_id``, let it configure the S3 resources,
        coerce integer fields, expand the role, and log the config as sorted, indented JSON.
        """
        self.log.info('Preprocessing the config and doing required s3_operations')
        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

        self.hook.configure_s3_resources(self.config)
        self.parse_config_integers()
        self.expand_role()

        pretty_config = json.dumps(
            self.config,
            sort_keys=True,
            indent=4,
            separators=(",", ": "),
        )
        self.log.info("After preprocessing the config is:\n %s", pretty_config)
Пример #20
0
    def get_hook(self) -> SageMakerHook:
        """Return ``self.hook``, first creating a SageMakerHook for ``self.aws_conn_id`` if it is unset/falsy."""
        if not self.hook:
            self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        return self.hook
Пример #21
0
 def test_training_ends_with_wait(self, mock_client, mock_check_training):
     """Waiting for completion polls describe_training_job through all four mocked responses."""
     mock_check_training.return_value = True
     describe_sequence = [
         DESCRIBE_TRAINING_INPROGRESS_RETURN,
         DESCRIBE_TRAINING_STOPPING_RETURN,
         DESCRIBE_TRAINING_COMPELETED_RETURN,
         DESCRIBE_TRAINING_COMPELETED_RETURN,
     ]
     sagemaker_client = mock.Mock(**{
         'create_training_job.return_value': test_arn_return,
         'describe_training_job.side_effect': describe_sequence,
     })
     mock_client.return_value = sagemaker_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
     hook.create_training_job(
         create_training_params,
         wait_for_completion=True,
         print_log=False,
         check_interval=1,
     )

     self.assertEqual(sagemaker_client.describe_training_job.call_count, 4)
Пример #22
0
 def init_log_resource(self, hook: SageMakerHook) -> None:
     """
     Seed the log-tailing state from the training job's current description.

     Records the instance count and the description, chooses ``LogState.TAILING`` while the
     job is in a non-terminal state (``LogState.COMPLETE`` otherwise), stamps the describe
     time with ``time.monotonic()`` and marks the log resource as initialised.
     """
     description = hook.describe_training_job(self.job_name)
     self.instance_count = description['ResourceConfig']['InstanceCount']

     if description['TrainingJobStatus'] in self.non_terminal_states():
         self.state = LogState.TAILING
     else:
         self.state = LogState.COMPLETE

     self.last_description = description
     self.last_describe_job_call = time.monotonic()
     self.log_resource_inited = True
Пример #23
0
 def test_check_s3_url(self, mock_check_prefix, mock_check_bucket, mock_check_key, mock_client):
     """check_s3_url raises twice, then returns True twice, driven by the mocked bucket/key/prefix answers."""
     mock_client.return_value = None
     hook = SageMakerHook()
     # Each side_effect list is consumed one item per call to the corresponding mock.
     mock_check_bucket.side_effect = [False, True, True, True]
     mock_check_key.side_effect = [False, True, False]
     mock_check_prefix.side_effect = [False, True, True]

     for _ in range(2):
         with pytest.raises(AirflowException):
             hook.check_s3_url(data_url)
     for _ in range(2):
         assert hook.check_s3_url(data_url) is True
Пример #24
0
    def test_check_valid_training(self, mock_check_url, mock_client):
        """check_training_config checks the data URL and also accepts a config lacking InputDataConfig."""
        mock_client.return_value = None
        hook = SageMakerHook()

        hook.check_training_config(create_training_params)
        mock_check_url.assert_called_once_with(data_url)

        # InputDataConfig is optional, so the check must also succeed once it is removed.
        params_without_input = create_training_params.copy()
        del params_without_input["InputDataConfig"]
        hook.check_training_config(params_without_input)
Пример #25
0
 def hook(self):
     """Create a SageMakerHook bound to this object's ``aws_conn_id``."""
     conn_id = self.aws_conn_id
     return SageMakerHook(aws_conn_id=conn_id)
class SageMakerBaseOperator(BaseOperator):
    """
    Common base class for the SageMaker operators.

    Provides config preprocessing (S3 setup through the hook, integer coercion of
    templated fields, role expansion); subclasses must implement :meth:`execute`.

    :param config: The configuration necessary to start a training job (templated)
    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    """

    template_fields = ['config']
    template_ext = ()
    ui_color = '#ededed'

    # Key paths into ``config`` whose values parse_config_integers() coerces to int.
    # Empty here; presumably overridden by subclasses.
    integer_fields = []  # type: Iterable[Iterable[str]]

    @apply_defaults
    def __init__(self, config, aws_conn_id='aws_default', *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.aws_conn_id = aws_conn_id
        self.config = config
        # Populated by preprocess_config().
        self.hook = None

    def parse_integer(self, config, field):
        """
        Recursively convert the value at key path ``field`` inside ``config`` to ``int``.

        A list met anywhere along the path is walked element by element with the same
        remaining path. Keys missing from a mapping are skipped silently.
        """
        is_leaf = len(field) == 1

        if isinstance(config, list):
            for item in config:
                self.parse_integer(item, field)
            return

        if is_leaf:
            key = field[0]
            if key in config:
                config[key] = int(config[key])
            return

        key, remaining = field[0], field[1:]
        if key in config:
            self.parse_integer(config[key], remaining)

    def parse_config_integers(self):
        """
        Coerce every key path listed in ``integer_fields`` to ``int`` inside ``self.config``.

        Needed because a config rendered by Jinja has all of its fields as strings.
        """
        for key_path in self.integer_fields:
            self.parse_integer(self.config, key_path)

    def expand_role(self):
        """Placeholder for boto3's expand_role() (IAM role name -> ARN); does nothing here."""

    def preprocess_config(self):
        """Configure S3 resources, coerce integer fields and expand the role, then log the config."""
        self.log.info('Preprocessing the config and doing required s3_operations')
        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

        self.hook.configure_s3_resources(self.config)
        self.parse_config_integers()
        self.expand_role()

        pretty_config = json.dumps(self.config, sort_keys=True, indent=4, separators=(",", ": "))
        self.log.info("After preprocessing the config is:\n %s", pretty_config)

    def execute(self, context):
        """Must be provided by subclasses; the base implementation always raises."""
        raise NotImplementedError('Please implement execute() in sub class!')
Пример #27
0
 def get_hook(self):
     """Return ``self.hook``, building a SageMakerHook for ``self.aws_conn_id`` first if it is unset/falsy."""
     if self.hook:
         return self.hook
     self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
     return self.hook
Пример #28
0
 def test_conn(self, mock_get_client_type):
     """The hook keeps the connection id it was constructed with."""
     conn_id = 'sagemaker_test_conn_id'
     hook = SageMakerHook(aws_conn_id=conn_id)
     self.assertEqual(hook.aws_conn_id, conn_id)
Пример #29
0
 def test_check_valid_tuning(self, mock_check_url, mock_client):
     """check_tuning_config checks the tuning job's data URL exactly once."""
     mock_client.return_value = None
     SageMakerHook().check_tuning_config(create_tuning_params)
     mock_check_url.assert_called_once_with(data_url)
Пример #30
0
 def test_multi_stream_iter(self, mock_log_stream):
     """The first item yielded is (0, event) for the only event in the first mocked stream."""
     first_event = {'timestamp': 1}
     mock_log_stream.side_effect = [iter([first_event]), iter([]), None]
     hook = SageMakerHook()

     events = hook.multi_stream_iter('log', [None] * 3)

     self.assertEqual(next(events), (0, first_event))