def execute(self, context):
        """Optionally create the model, then start the SageMaker transform job.

        If ``self.model_config`` is set, the model is created first via
        ``SageMakerHook.create_model``. The transform job described by
        ``self.transform_job_config`` is then created, optionally waiting for
        completion.

        :param context: Airflow task context (not used by this method).
        :return: the response returned by ``SageMakerHook.create_transform_job``.
        :raises AirflowException: if the response's HTTP status code is not 200.
        """
        sagemaker = SageMakerHook(
            sagemaker_conn_id=self.sagemaker_conn_id,
            use_db_config=self.use_db_config,
            region_name=self.region_name,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time
        )

        if self.model_config:
            # Pass values as logging arguments so formatting is deferred to the logger.
            self.log.info(
                "Creating SageMaker Model %s for transform job",
                self.model_config['ModelName'],
            )
            sagemaker.create_model(self.model_config)

        self.log.info(
            "Creating SageMaker transform Job %s.",
            self.transform_job_config['TransformJobName'],
        )
        response = sagemaker.create_transform_job(
            self.transform_job_config,
            wait_for_completion=self.wait_for_completion)
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(
                'Sagemaker transform Job creation failed: %s' % response)
        return response
コード例 #2
0
 def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, mock_check_training):
     """Expect three log-aware polls and one plain describe when print_log=True."""
     mock_check_training.return_value = True
     mock_describe.side_effect = [
         (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RETURN, 0),
         (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RETURN, 0),
         (LogState.COMPLETE, DESCRIBE_TRAINING_COMPELETED_RETURN, 0),
     ]

     session = mock.Mock()
     log_session = mock.Mock()
     session.create_training_job.return_value = test_arn_return
     session.describe_training_job.return_value = DESCRIBE_TRAINING_COMPELETED_RETURN
     log_session.describe_log_streams.side_effect = LIFECYCLE_LOG_STREAMS
     log_session.get_log_events.side_effect = STREAM_LOG_EVENTS
     mock_client.return_value = session
     mock_log_client.return_value = log_session

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
     hook.create_training_job(
         create_training_params,
         wait_for_completion=True,
         print_log=True,
         check_interval=1,
     )

     self.assertEqual(mock_describe.call_count, 3)
     self.assertEqual(session.describe_training_job.call_count, 1)
コード例 #3
0
 def test_configure_s3_resources(self, mock_load_file, mock_create_bucket):
     """Expect bucket creation, one file upload, and a config reduced to Image/Role."""
     sagemaker_hook = SageMakerHook()
     expected_config = {
         'Image': image,
         'Role': role,
     }
     sagemaker_hook.configure_s3_resources(test_evaluation_config)

     self.assertEqual(test_evaluation_config, expected_config)
     mock_create_bucket.assert_called_once_with(bucket_name=bucket)
     mock_load_file.assert_called_once_with(path, key, bucket)
コード例 #4
0
 def test_describe_training_job_with_logs_complete(self, mock_client, mock_log_client):
     """Expect (COMPLETE, {}, 0) back when polling a job already in LogState.COMPLETE."""
     session = mock.Mock()
     session.describe_training_job.return_value = DESCRIBE_TRAINING_COMPELETED_RETURN
     mock_client.return_value = session

     log_session = mock.Mock()
     log_session.describe_log_streams.side_effect = LIFECYCLE_LOG_STREAMS
     log_session.get_log_events.side_effect = STREAM_LOG_EVENTS
     mock_log_client.return_value = log_session

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.describe_training_job_with_log(
         job_name=job_name,
         positions={},
         stream_names=[],
         instance_count=1,
         state=LogState.COMPLETE,
         last_description={},
         last_describe_job_call=0,
     )

     self.assertEqual(result, (LogState.COMPLETE, {}, 0))
コード例 #5
0
 def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, mock_check_training):
     """With print_log=True, expect 3 log-aware polls and a single describe call."""
     mock_check_training.return_value = True
     poll_results = [
         (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RETURN, 0),
         (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RETURN, 0),
         (LogState.COMPLETE, DESCRIBE_TRAINING_COMPELETED_RETURN, 0),
     ]
     mock_describe.side_effect = poll_results

     client_session = mock.Mock()
     logs_session = mock.Mock()
     client_session.create_training_job.return_value = test_arn_return
     client_session.describe_training_job.return_value = DESCRIBE_TRAINING_COMPELETED_RETURN
     logs_session.describe_log_streams.side_effect = LIFECYCLE_LOG_STREAMS
     logs_session.get_log_events.side_effect = STREAM_LOG_EVENTS
     mock_client.return_value = client_session
     mock_log_client.return_value = logs_session

     SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1').create_training_job(
         create_training_params,
         wait_for_completion=True,
         print_log=True,
         check_interval=1,
     )

     self.assertEqual(mock_describe.call_count, 3)
     self.assertEqual(client_session.describe_training_job.call_count, 1)
コード例 #6
0
 def test_describe_training_job_with_logs_in_progress(self, mock_time, mock_client, mock_log_client):
     """Polling from WAIT_IN_PROGRESS with the clock patched to 50 should yield (JOB_COMPLETE, {}, 50)."""
     session = mock.Mock()
     session.describe_training_job.return_value = DESCRIBE_TRAINING_COMPELETED_RETURN

     log_session = mock.Mock()
     log_session.describe_log_streams.side_effect = LIFECYCLE_LOG_STREAMS
     log_session.get_log_events.side_effect = STREAM_LOG_EVENTS

     mock_time.return_value = 50
     mock_client.return_value = session
     mock_log_client.return_value = log_session

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.describe_training_job_with_log(
         job_name=job_name,
         positions={},
         stream_names=[],
         instance_count=1,
         state=LogState.WAIT_IN_PROGRESS,
         last_description={},
         last_describe_job_call=0,
     )

     self.assertEqual(result, (LogState.JOB_COMPLETE, {}, 50))
コード例 #7
0
    def get_sagemaker_response(self):
        """Fetch the latest training-job description for ``self.job_name``.

        With ``print_log`` set, log resources are initialised on first use and
        ``describe_training_job_with_log`` is used; that call also refreshes
        ``self.state`` and ``self.last_describe_job_call``. Otherwise a plain
        ``describe_training_job`` call is made. When the resulting status is
        neither non-terminal nor failed, the billable seconds are logged.

        :return: the latest description (also kept in ``self.last_description``).
        """
        hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

        if not self.print_log:
            self.last_description = hook.describe_training_job(
                self.job_name)
        else:
            if not self.log_resource_inited:
                self.init_log_resource(hook)
            polled = hook.describe_training_job_with_log(
                self.job_name,
                self.positions,
                self.stream_names,
                self.instance_count,
                self.state,
                self.last_description,
                self.last_describe_job_call,
            )
            self.state, self.last_description, self.last_describe_job_call = polled

        description = self.last_description
        status = self.state_from_response(description)
        if status not in self.non_terminal_states() and status not in self.failed_states():
            # Training wall-clock duration multiplied by the number of instances.
            duration = description['TrainingEndTime'] - description['TrainingStartTime']
            billable_time = duration * description['ResourceConfig']['InstanceCount']
            self.log.info('Billable seconds: %s',
                          int(billable_time.total_seconds()) + 1)

        return description
コード例 #8
0
    def get_sagemaker_response(self):
        """Log a poke message and return the description of the training job."""
        hook = SageMakerHook(aws_conn_id=self.aws_conn_id,
                             region_name=self.region_name)
        self.log.info('Poking Sagemaker Training Job %s', self.job_name)
        return hook.describe_training_job(self.job_name)
コード例 #9
0
    def get_sagemaker_response(self):
        """Log a poke message and return the description of the tuning job."""
        hook = SageMakerHook(aws_conn_id=self.aws_conn_id,
                             region_name=self.region_name)
        self.log.info('Poking Sagemaker Tuning Job %s', self.job_name)
        return hook.describe_tuning_job(self.job_name)
コード例 #10
0
 def test_check_for_url(self, mock_check_bucket, mock_check_key, mock_client):
     """With the patched bucket/key checks failing first, expect two errors and then True."""
     mock_client.return_value = None
     hook = SageMakerHook()
     mock_check_bucket.side_effect = [False, True, True]
     mock_check_key.side_effect = [False, True]

     for _ in range(2):
         self.assertRaises(AirflowException, hook.check_for_url, data_url)
     self.assertEqual(hook.check_for_url(data_url), True)
コード例 #11
0
 def test_create_model(self, mock_client):
     """create_model should unpack the params into the client call and return its result."""
     session = mock.Mock()
     session.create_model.return_value = test_arn_return
     mock_client.return_value = session

     result = SageMakerHook(aws_conn_id='sagemaker_test_conn_id').create_model(
         create_model_params)

     session.create_model.assert_called_once_with(**create_model_params)
     self.assertEqual(result, test_arn_return)
コード例 #12
0
 def test_describe_endpoint(self, mock_client):
     """describe_endpoint should pass EndpointName and return the client result."""
     session = mock.Mock()
     session.describe_endpoint.return_value = 'InProgress'
     mock_client.return_value = session

     result = SageMakerHook(aws_conn_id='sagemaker_test_conn_id').describe_endpoint(
         endpoint_name)

     session.describe_endpoint.assert_called_once_with(EndpointName=endpoint_name)
     self.assertEqual(result, 'InProgress')
コード例 #13
0
 def test_create_endpoint_config(self, mock_client):
     """create_endpoint_config should unpack the params into the client call."""
     session = mock.Mock()
     session.create_endpoint_config.return_value = test_arn_return
     mock_client.return_value = session

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.create_endpoint_config(create_endpoint_config_params)

     session.create_endpoint_config.assert_called_once_with(
         **create_endpoint_config_params)
     self.assertEqual(result, test_arn_return)
コード例 #14
0
 def test_describe_model(self, mock_client):
     """describe_model should pass ModelName and return the client result."""
     session = mock.Mock()
     session.describe_model.return_value = model_name
     mock_client.return_value = session

     result = SageMakerHook(aws_conn_id='sagemaker_test_conn_id').describe_model(
         model_name)

     session.describe_model.assert_called_once_with(ModelName=model_name)
     self.assertEqual(result, model_name)
コード例 #15
0
 def test_configure_s3_resources(self, mock_load_file, mock_create_bucket):
     """Expect bucket creation, one upload, and the config left as Image/Role only."""
     sagemaker_hook = SageMakerHook()
     expected = dict(Image=image, Role=role)

     sagemaker_hook.configure_s3_resources(test_evaluation_config)

     self.assertEqual(test_evaluation_config, expected)
     mock_create_bucket.assert_called_once_with(bucket_name=bucket)
     mock_load_file.assert_called_once_with(path, key, bucket)
コード例 #16
0
 def test_describe_training_job(self, mock_client):
     """describe_training_job should pass TrainingJobName and return the client result."""
     session = mock.Mock()
     session.describe_training_job.return_value = 'InProgress'
     mock_client.return_value = session

     result = SageMakerHook(aws_conn_id='sagemaker_test_conn_id').describe_training_job(
         job_name)

     session.describe_training_job.assert_called_once_with(TrainingJobName=job_name)
     self.assertEqual(result, 'InProgress')
コード例 #17
0
 def test_describe_transform_job(self, mock_client):
     """describe_transform_job should pass TransformJobName and return the client result."""
     session = mock.Mock()
     session.describe_transform_job.return_value = 'InProgress'
     mock_client.return_value = session

     result = SageMakerHook(aws_conn_id='sagemaker_test_conn_id').describe_transform_job(
         job_name)

     session.describe_transform_job.assert_called_once_with(TransformJobName=job_name)
     self.assertEqual(result, 'InProgress')
コード例 #18
0
 def test_describe_endpoint_config(self, mock_client):
     """describe_endpoint_config should pass EndpointConfigName and return the client result."""
     session = mock.Mock()
     session.describe_endpoint_config.return_value = config_name
     mock_client.return_value = session

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.describe_endpoint_config(config_name)

     session.describe_endpoint_config.assert_called_once_with(
         EndpointConfigName=config_name)
     self.assertEqual(result, config_name)
コード例 #19
0
 def test_describe_model(self, mock_client):
     """The hook's describe_model forwards ModelName and returns what the client returns."""
     client_session = mock.Mock()
     client_session.describe_model.return_value = model_name
     mock_client.return_value = client_session

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     described = hook.describe_model(model_name)

     client_session.describe_model.assert_called_once_with(ModelName=model_name)
     self.assertEqual(described, model_name)
コード例 #20
0
 def test_update_endpoint(self, mock_client):
     """update_endpoint (no waiting) should unpack the params into the client call."""
     session = mock.Mock()
     session.update_endpoint.return_value = test_arn_return
     mock_client.return_value = session

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.update_endpoint(update_endpoint_params, wait_for_completion=False)

     session.update_endpoint.assert_called_once_with(**update_endpoint_params)
     self.assertEqual(result, test_arn_return)
コード例 #21
0
 def test_create_transform_job(self, mock_client, mock_check_url):
     """create_transform_job (no waiting) should unpack the params into the client call."""
     mock_check_url.return_value = True
     session = mock.Mock()
     session.create_transform_job.return_value = test_arn_return
     mock_client.return_value = session

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.create_transform_job(create_transform_params, wait_for_completion=False)

     session.create_transform_job.assert_called_once_with(**create_transform_params)
     self.assertEqual(result, test_arn_return)
コード例 #22
0
 def test_create_tuning_job(self, mock_client, mock_check_tuning):
     """create_tuning_job (no waiting) should unpack the params into the client call."""
     session = mock.Mock()
     session.create_hyper_parameter_tuning_job.return_value = test_arn_return
     mock_client.return_value = session

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.create_tuning_job(create_tuning_params, wait_for_completion=False)

     session.create_hyper_parameter_tuning_job.assert_called_once_with(
         **create_tuning_params)
     self.assertEqual(result, test_arn_return)
コード例 #23
0
 def test_create_tuning_job(self, mock_client, mock_check_tuning):
     """create_tuning_job with default waiting should unpack the params into the client call."""
     client_session = mock.Mock()
     client_session.create_hyper_parameter_tuning_job.return_value = test_arn_return
     mock_client.return_value = client_session

     hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id')
     created = hook.create_tuning_job(create_tuning_params)

     client_session.create_hyper_parameter_tuning_job.assert_called_once_with(
         **create_tuning_params)
     self.assertEqual(created, test_arn_return)
コード例 #24
0
 def test_update_endpoint(self, mock_client):
     """Without waiting, update_endpoint returns the client's response unchanged."""
     client_session = mock.Mock()
     client_session.update_endpoint.return_value = test_arn_return
     mock_client.return_value = client_session

     updated = SageMakerHook(aws_conn_id='sagemaker_test_conn_id').update_endpoint(
         update_endpoint_params, wait_for_completion=False)

     client_session.update_endpoint.assert_called_once_with(**update_endpoint_params)
     self.assertEqual(updated, test_arn_return)
コード例 #25
0
 def test_list_tuning_job(self, mock_client):
     """list_tuning_job should map its filters to NameContains/StatusEquals."""
     session = mock.Mock()
     session.list_hyper_parameter_tuning_job.return_value = test_list_tuning_job_return
     mock_client.return_value = session

     hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id')
     result = hook.list_tuning_job(name_contains=job_name, status_equals='InProgress')

     session.list_hyper_parameter_tuning_job.assert_called_once_with(
         NameContains=job_name, StatusEquals='InProgress')
     self.assertEqual(result, test_list_tuning_job_return)
コード例 #26
0
 def test_list_training_job(self, mock_client):
     """list_training_job should map its filters to NameContains/StatusEquals."""
     session = mock.Mock()
     session.list_training_jobs.return_value = test_list_training_job_return
     mock_client.return_value = session

     hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id')
     result = hook.list_training_job(name_contains=job_name, status_equals='InProgress')

     session.list_training_jobs.assert_called_once_with(
         NameContains=job_name, StatusEquals='InProgress')
     self.assertEqual(result, test_list_training_job_return)
コード例 #27
0
    def preprocess_config(self):
        """Prepare ``self.config`` before it is sent to SageMaker.

        Builds a ``SageMakerHook`` (kept on ``self.hook``), lets it run the S3
        setup for the config, coerces integer fields, expands the role, and
        finally logs the resulting config as pretty-printed JSON.
        """
        self.log.info('Preprocessing the config and doing required s3_operations')
        hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        self.hook = hook

        hook.configure_s3_resources(self.config)
        self.parse_config_integers()
        self.expand_role()

        pretty_config = json.dumps(self.config, sort_keys=True, indent=4,
                                   separators=(',', ': '))
        self.log.info('After preprocessing the config is:\n {}'.format(pretty_config))
コード例 #28
0
 def test_create_training_job_db_config(self, mock_client, mock_check_training):
     """With use_db_config=True the client should receive the params merged with db_config."""
     mock_check_training.return_value = True
     session = mock.Mock()
     session.create_training_job.return_value = test_arn_return
     mock_client.return_value = session

     hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id', use_db_config=True)
     result = hook.create_training_job(create_training_params, wait_for_completion=False)

     # Build the expectation after the call, matching the original ordering.
     expected_config = copy.deepcopy(create_training_params)
     expected_config.update(db_config)
     session.create_training_job.assert_called_once_with(**expected_config)
     self.assertEqual(result, test_arn_return)
コード例 #29
0
 def test_create_transform_job_db_config(self, mock_client, mock_check_url):
     """With use_db_config=True the transform call should receive params merged with db_config."""
     mock_check_url.return_value = True
     session = mock.Mock()
     session.create_transform_job.return_value = test_arn_return
     mock_client.return_value = session

     hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id', use_db_config=True)
     result = hook.create_transform_job(create_transform_params, wait_for_completion=False)

     # Build the expectation after the call, matching the original ordering.
     expected_config = copy.deepcopy(create_transform_params)
     expected_config.update(db_config)
     session.create_transform_job.assert_called_once_with(**expected_config)
     self.assertEqual(result, test_arn_return)
コード例 #30
0
 def test_check_s3_url(self, mock_check_prefix, mock_check_bucket, mock_check_key, mock_client):
     """With the patched S3 checks, expect two AirflowExceptions followed by two True results."""
     mock_client.return_value = None
     hook = SageMakerHook()
     mock_check_bucket.side_effect = [False, True, True, True]
     mock_check_key.side_effect = [False, True, False]
     mock_check_prefix.side_effect = [False, True, True]

     for _ in range(2):
         self.assertRaises(AirflowException, hook.check_s3_url, data_url)
     for _ in range(2):
         self.assertEqual(hook.check_s3_url(data_url), True)
コード例 #31
0
 def test_training_ends_with_wait_on(self, mock_client, mock_check_training):
     """Waiting for completion should poll describe_training_job four times here."""
     mock_check_training.return_value = True
     session = mock.Mock()
     session.create_training_job.return_value = test_arn_return
     session.describe_training_job.side_effect = [
         DESCRIBE_TRAINING_INPROGRESS_RETURN,
         DESCRIBE_TRAINING_STOPPING_RETURN,
         DESCRIBE_TRAINING_STOPPED_RETURN,
         DESCRIBE_TRAINING_COMPELETED_RETURN,
     ]
     mock_client.return_value = session

     hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id_1')
     hook.create_training_job(create_training_params, wait_for_completion=True)

     self.assertEqual(session.describe_training_job.call_count, 4)
コード例 #32
0
 def test_create_tuning_job_db_config(self, mock_client, mock_check_tuning):
     """With use_db_config=True the tuning call should receive params merged with db_config."""
     mock_check_tuning.return_value = True
     session = mock.Mock()
     session.create_hyper_parameter_tuning_job.return_value = test_arn_return
     mock_client.return_value = session

     hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id', use_db_config=True)
     result = hook.create_tuning_job(create_tuning_params)

     # Build the expectation after the call, matching the original ordering.
     expected_config = copy.deepcopy(create_tuning_params)
     expected_config.update(db_config)
     session.create_hyper_parameter_tuning_job.assert_called_once_with(**expected_config)
     self.assertEqual(result, test_arn_return)
コード例 #33
0
 def test_training_ends_with_wait(self, mock_client, mock_check_training):
     """Waiting without log printing should poll describe_training_job four times here."""
     mock_check_training.return_value = True
     session = mock.Mock()
     session.create_training_job.return_value = test_arn_return
     session.describe_training_job.side_effect = [
         DESCRIBE_TRAINING_INPROGRESS_RETURN,
         DESCRIBE_TRAINING_STOPPING_RETURN,
         DESCRIBE_TRAINING_COMPELETED_RETURN,
         DESCRIBE_TRAINING_COMPELETED_RETURN,
     ]
     mock_client.return_value = session

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
     hook.create_training_job(
         create_training_params,
         wait_for_completion=True,
         print_log=False,
         check_interval=1,
     )

     self.assertEqual(session.describe_training_job.call_count, 4)
コード例 #34
0
    def execute(self, context):
        """Start the SageMaker training job described by ``training_job_config``.

        :param context: Airflow task context (not used by this method).
        :return: the response returned by ``SageMakerHook.create_training_job``.
        :raises AirflowException: if the response's HTTP status code is not 200.
        """
        sagemaker = SageMakerHook(sagemaker_conn_id=self.sagemaker_conn_id,
                                  use_db_config=self.use_db_config,
                                  region_name=self.region_name,
                                  check_interval=self.check_interval,
                                  max_ingestion_time=self.max_ingestion_time)

        # Pass the job name as a logging argument so formatting is deferred to the logger.
        self.log.info("Creating SageMaker Training Job %s.",
                      self.training_job_config['TrainingJobName'])
        response = sagemaker.create_training_job(
            self.training_job_config,
            wait_for_completion=self.wait_for_completion)
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(
                'Sagemaker Training Job creation failed: %s' % response)
        return response
コード例 #35
0
    def get_sagemaker_response(self):
        """Return the latest description of the training job ``self.job_name``.

        If ``print_log`` is enabled, the log resources are set up on first use
        and ``describe_training_job_with_log`` refreshes ``self.state``,
        ``self.last_description`` and ``self.last_describe_job_call``.
        Otherwise only ``self.last_description`` is refreshed via
        ``describe_training_job``. Billable seconds are logged once the status
        is neither non-terminal nor failed.
        """
        sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

        if self.print_log:
            if not self.log_resource_inited:
                self.init_log_resource(sagemaker_hook)
            (self.state,
             self.last_description,
             self.last_describe_job_call) = sagemaker_hook.describe_training_job_with_log(
                 self.job_name, self.positions, self.stream_names,
                 self.instance_count, self.state, self.last_description,
                 self.last_describe_job_call)
        else:
            self.last_description = sagemaker_hook.describe_training_job(self.job_name)

        status = self.state_from_response(self.last_description)
        still_running = status in self.non_terminal_states()
        if not still_running and status not in self.failed_states():
            start = self.last_description['TrainingStartTime']
            end = self.last_description['TrainingEndTime']
            instances = self.last_description['ResourceConfig']['InstanceCount']
            billable_time = (end - start) * instances
            self.log.info('Billable seconds: %s', int(billable_time.total_seconds()) + 1)

        return self.last_description
コード例 #36
0
    def preprocess_config(self):
        """Run the S3 setup and type fix-ups on ``self.config``, then log it.

        A fresh ``SageMakerHook`` is stored on ``self.hook`` and used for the
        S3 work; integer fields are coerced and the role is expanded before the
        final config is logged as indented, key-sorted JSON.
        """
        self.log.info('Preprocessing the config and doing required s3_operations')
        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        self.hook.configure_s3_resources(self.config)

        for step in (self.parse_config_integers, self.expand_role):
            step()

        dumped = json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': '))
        self.log.info('After preprocessing the config is:\n {}'.format(dumped))
    def execute(self, context):
        """Start the SageMaker hyperparameter tuning job in ``tuning_job_config``.

        :param context: Airflow task context (not used by this method).
        :return: the response returned by ``SageMakerHook.create_tuning_job``.
        :raises AirflowException: if the response's HTTP status code is not 200.
        """
        sagemaker = SageMakerHook(sagemaker_conn_id=self.sagemaker_conn_id,
                                  region_name=self.region_name,
                                  use_db_config=self.use_db_config,
                                  check_interval=self.check_interval,
                                  max_ingestion_time=self.max_ingestion_time
                                  )

        # Pass the job name as a logging argument so formatting is deferred to the logger.
        self.log.info(
            "Creating SageMaker Hyper Parameter Tuning Job %s",
            self.tuning_job_config['HyperParameterTuningJobName'],
        )

        response = sagemaker.create_tuning_job(
            self.tuning_job_config,
            wait_for_completion=self.wait_for_completion
        )
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(
                "Sagemaker Tuning Job creation failed: %s" % response)
        return response
コード例 #38
0
    def test_check_valid_training(self, mock_check_url, mock_client):
        """check_training_config validates the data URL and accepts configs lacking InputDataConfig."""
        mock_client.return_value = None
        hook = SageMakerHook()

        hook.check_training_config(create_training_params)
        mock_check_url.assert_called_once_with(data_url)

        # InputDataConfig is optional: the check must also pass without it.
        params_without_input = create_training_params.copy()
        del params_without_input["InputDataConfig"]
        hook.check_training_config(params_without_input)
コード例 #39
0
 def test_training_throws_error_when_failed_with_wait_on(self, mock_client, mock_check_training):
     """A job that ends in the failed state should raise AirflowException after four polls."""
     mock_check_training.return_value = True
     session = mock.Mock()
     session.create_training_job.return_value = test_arn_return
     session.describe_training_job.side_effect = [
         DESCRIBE_TRAINING_INPROGRESS_RETURN,
         DESCRIBE_TRAINING_STOPPING_RETURN,
         DESCRIBE_TRAINING_STOPPED_RETURN,
         DESCRIBE_TRAINING_FAILED_RETURN,
     ]
     mock_client.return_value = session

     hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id_1')
     self.assertRaises(
         AirflowException,
         hook.create_training_job,
         create_training_params,
         wait_for_completion=True,
     )

     self.assertEqual(session.describe_training_job.call_count, 4)
コード例 #40
0
    def get_sagemaker_response(self):
        """Log a poke message and return the description of the transform job."""
        hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
        return hook.describe_transform_job(self.job_name)
コード例 #41
0
    def get_sagemaker_response(self):
        """Log a poke message and return the description of the endpoint."""
        hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        self.log.info('Poking Sagemaker Endpoint %s', self.endpoint_name)
        return hook.describe_endpoint(self.endpoint_name)
コード例 #42
0
    def get_sagemaker_response(self):
        """Describe ``self.endpoint_name`` via a new SageMakerHook, logging the poke first."""
        endpoint_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        self.log.info('Poking Sagemaker Endpoint %s', self.endpoint_name)
        description = endpoint_hook.describe_endpoint(self.endpoint_name)
        return description
コード例 #43
0
 def test_multi_stream_iter(self, mock_log_stream):
     """The first item from multi_stream_iter should be (0, event) from the first stream."""
     event = {'timestamp': 1}
     mock_log_stream.side_effect = [iter([event]), iter([]), None]
     events = SageMakerHook().multi_stream_iter('log', [None, None, None])
     self.assertEqual(next(events), (0, event))
コード例 #44
0
class SageMakerBaseOperator(BaseOperator):
    """
    Common base class for the SageMaker operators.

    Subclasses implement :meth:`execute`. This class keeps the (templated)
    request config and provides helpers to preprocess it: S3 setup through
    ``SageMakerHook``, coercion of selected fields to ``int``, and a role
    expansion hook that subclasses may override.

    :param config: The SageMaker request configuration, e.g. the configuration
        necessary to start a training job (templated)
    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    """

    template_fields = ['config']
    template_ext = ()
    ui_color = '#ededed'

    # Key paths into ``config`` whose leaf values are converted with int().
    integer_fields = []  # type: Iterable[Iterable[str]]

    @apply_defaults
    def __init__(self, config, aws_conn_id='aws_default', *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.aws_conn_id = aws_conn_id
        self.config = config
        # Populated by preprocess_config().
        self.hook = None

    def parse_integer(self, config, field):
        """Convert the value at key path ``field`` inside ``config`` to ``int``.

        Lists met along the way are walked element by element with the same
        remaining path; a missing key ends the descent without error.
        """
        at_leaf = len(field) == 1
        if isinstance(config, list):
            for item in config:
                self.parse_integer(item, field)
            return

        head = field[0]
        if head in config:
            if at_leaf:
                config[head] = int(config[head])
            else:
                self.parse_integer(config[head], field[1:])

    def parse_config_integers(self):
        """Apply :meth:`parse_integer` to ``self.config`` for every path in ``integer_fields``.

        Needed because a Jinja-rendered config may hold every field as a str.
        """
        for key_path in self.integer_fields:
            self.parse_integer(self.config, key_path)

    def expand_role(self):
        """Hook for subclasses to expand the IAM role in the config; no-op here."""

    def preprocess_config(self):
        """Run S3 setup, integer coercion and role expansion, then log the config."""
        self.log.info('Preprocessing the config and doing required s3_operations')
        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

        self.hook.configure_s3_resources(self.config)
        self.parse_config_integers()
        self.expand_role()

        config_json = json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': '))
        self.log.info('After preprocessing the config is:\n {}'.format(config_json))

    def execute(self, context):
        """Must be provided by each concrete operator."""
        raise NotImplementedError('Please implement execute() in sub class!')
コード例 #45
0
 def test_check_valid_tuning(self, mock_check_url, mock_client):
     """check_tuning_config should validate the data URL exactly once."""
     mock_client.return_value = None
     SageMakerHook().check_tuning_config(create_tuning_params)
     mock_check_url.assert_called_once_with(data_url)
コード例 #46
0
    def get_sagemaker_response(self):
        """Describe the transform job ``self.job_name`` using a new SageMakerHook."""
        transform_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
        description = transform_hook.describe_transform_job(self.job_name)
        return description
コード例 #47
0
class SageMakerBaseOperator(BaseOperator):
    """
    This is the base operator for all SageMaker operators.

    Subclasses declare in ``integer_fields`` the key paths of ``config`` that
    must hold integers, and implement ``execute``.

    :param config: The SageMaker configuration this operator works on
        (templated). Values at the paths in ``integer_fields`` are cast to
        int by ``preprocess_config``.
    :type config: dict
    :param aws_conn_id: The AWS connection ID to use.
    :type aws_conn_id: str
    """

    template_fields = ['config']
    template_ext = ()
    ui_color = '#ededed'

    # Key paths (sequences of keys) into ``config`` whose values must be
    # ints; subclasses override this. Lists met along a path are traversed
    # element-wise by parse_integer.
    integer_fields = []

    @apply_defaults
    def __init__(self,
                 config,
                 aws_conn_id='aws_default',
                 *args, **kwargs):
        super(SageMakerBaseOperator, self).__init__(*args, **kwargs)

        self.aws_conn_id = aws_conn_id
        self.config = config
        # Set by preprocess_config().
        self.hook = None

    def parse_integer(self, config, field):
        """Cast, in place, the value at key path ``field`` inside ``config`` to int.

        :param config: a mapping, or a list whose elements are each processed
            with the same ``field``.
        :param field: non-empty sequence of keys; every key but the last
            selects a nested container, and the last names the value to convert.

        Keys missing from ``config`` are skipped silently. Errors raised by
        ``int()`` (e.g. for a non-numeric string) propagate.
        """
        # A list applies the same path to each of its elements.
        if isinstance(config, list):
            for sub_config in config:
                self.parse_integer(sub_config, field)
            return

        head, tail = field[0], field[1:]
        if head not in config:
            return
        if tail:
            # Keys remain: descend into the nested value.
            self.parse_integer(config[head], tail)
        else:
            config[head] = int(config[head])

    def parse_config_integers(self):
        """Cast every path in ``integer_fields`` to int inside ``self.config``.

        A Jinja-rendered config may hold all of its values as strings.
        """
        for field in self.integer_fields:
            self.parse_integer(self.config, field)

    def expand_role(self):
        """Extension point called by ``preprocess_config``; a no-op here."""
        pass

    def preprocess_config(self):
        """Prepare ``self.config`` before a subclass sends it to SageMaker.

        Creates ``self.hook``, calls its ``configure_s3_resources`` on the
        config, casts ``integer_fields`` to int, calls ``expand_role``, and
        logs the resulting config as pretty-printed JSON.
        """
        self.log.info(
            'Preprocessing the config and doing required s3_operations'
        )
        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)

        self.hook.configure_s3_resources(self.config)
        self.parse_config_integers()
        self.expand_role()

        # Hand the rendered JSON to the logger as an argument rather than
        # formatting the message eagerly with str.format.
        self.log.info(
            'After preprocessing the config is:\n %s',
            json.dumps(self.config, sort_keys=True, indent=4,
                       separators=(',', ': '))
        )

    def execute(self, context):
        """Abstract entry point; every concrete subclass must override it."""
        raise NotImplementedError('Please implement execute() in sub class!')
コード例 #48
0
 def test_multi_stream_iter(self, mock_log_stream):
     """The first item from multi_stream_iter is stream 0's only event, tagged with index 0."""
     log_event = {'timestamp': 1}
     mock_log_stream.side_effect = [iter([log_event]), iter([]), None]
     events = SageMakerHook().multi_stream_iter('log', [None] * 3)
     self.assertEqual(next(events), (0, log_event))
コード例 #49
0
 def test_conn(self, mock_get_client_type):
     """The aws_conn_id passed to the constructor is stored on the hook."""
     conn_id = 'sagemaker_test_conn_id'
     self.assertEqual(SageMakerHook(aws_conn_id=conn_id).aws_conn_id, conn_id)
コード例 #50
0
 def test_check_valid_tuning(self, mock_check_url, mock_client):
     """check_tuning_config accepts a valid config and probes data_url exactly once."""
     mock_client.return_value = None
     tuning_hook = SageMakerHook()
     tuning_hook.check_tuning_config(create_tuning_params)
     mock_check_url.assert_called_once_with(data_url)
コード例 #51
0
 def test_conn(self, mock_get_client):
     """The hook keeps sagemaker_conn_id and requests one 'sagemaker' client in the given region."""
     region = 'us-east-1'
     hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id',
                          region_name=region)
     self.assertEqual(hook.sagemaker_conn_id, 'sagemaker_test_conn_id')
     mock_get_client.assert_called_once_with('sagemaker', region_name=region)