def test_sensor(self, mock_describe_job, hook_init, mock_client):
    """Sensor keeps poking until the training job reports the completed response.

    The describe mock yields an in-progress, a stopping, then a completed
    response. The sensor is expected to call it exactly three times. The
    hook constructor is expected to have received ``aws_conn_id='aws_test'``
    in three consecutive calls.
    """
    hook_init.return_value = None
    mock_describe_job.side_effect = [
        DESCRIBE_TRAINING_INPROGRESS_RESPONSE,
        DESCRIBE_TRAINING_STOPPING_RESPONSE,
        # Correctly spelled constant, matching the one used by test_sensor_with_log.
        DESCRIBE_TRAINING_COMPLETED_RESPONSE,
    ]
    sensor = SageMakerTrainingSensor(
        task_id='test_task',
        poke_interval=2,
        aws_conn_id='aws_test',
        job_name='test_job_name',
        print_log=False,
    )
    sensor.execute(None)

    # Three pokes: the loop terminates once the job is completed.
    self.assertEqual(mock_describe_job.call_count, 3)

    # The hook must have been initialized with the configured connection id.
    hook_init.assert_has_calls([mock.call(aws_conn_id='aws_test')] * 3)
def test_sensor_with_log(
    self,
    mock_describe_job,
    mock_describe_job_with_log,
    hook_init,
    mock_log_client,
    mock_client,
):
    """With ``print_log=True`` the sensor polls through the log-aware describe call.

    The log-aware mock returns three ``(LogState, response, 0)`` tuples,
    ending with ``LogState.COMPLETE``. It is expected to be called three
    times. The plain describe mock is expected to be called exactly once,
    and the hook constructor to have received ``aws_conn_id='aws_test'``.
    """
    hook_init.return_value = None
    mock_describe_job.return_value = DESCRIBE_TRAINING_COMPLETED_RESPONSE

    # Pair each log state with the describe response returned alongside it.
    states = (LogState.WAIT_IN_PROGRESS, LogState.JOB_COMPLETE, LogState.COMPLETE)
    responses = (
        DESCRIBE_TRAINING_INPROGRESS_RESPONSE,
        DESCRIBE_TRAINING_STOPPING_RESPONSE,
        DESCRIBE_TRAINING_COMPLETED_RESPONSE,
    )
    mock_describe_job_with_log.side_effect = [
        (state, response, 0) for state, response in zip(states, responses)
    ]

    log_sensor = SageMakerTrainingSensor(
        task_id='test_task',
        poke_interval=2,
        aws_conn_id='aws_test',
        job_name='test_job_name',
        print_log=True,
    )
    log_sensor.execute(None)

    self.assertEqual(mock_describe_job_with_log.call_count, 3)
    self.assertEqual(mock_describe_job.call_count, 1)
    hook_init.assert_has_calls([mock.call(aws_conn_id='aws_test')])
def test_sensor_with_failure(self, mock_describe_job, hook_init, mock_client):
    """A failed-job describe response makes the sensor raise AirflowException.

    The describe mock holds a single response (presumably a failed status;
    the constant is defined elsewhere in the module). It must be called
    exactly once with the job name, so the first response ends polling.
    """
    only_failure = [DESCRIBE_TRAINING_FAILED_RESPONSE]
    hook_init.return_value = None
    mock_describe_job.side_effect = only_failure

    failing_sensor = SageMakerTrainingSensor(
        task_id='test_task',
        poke_interval=2,
        aws_conn_id='aws_test',
        job_name='test_job_name',
        print_log=False,
    )

    with pytest.raises(AirflowException):
        failing_sensor.execute(None)

    mock_describe_job.assert_called_once_with('test_job_name')