class TestAWSBatchOperator(unittest.TestCase):
    """Unit tests for AWSBatchOperator (``queue=`` keyword API).

    The AWS hook is patched while the operator is constructed, so the Batch
    client used by ``execute`` is a mock reachable through
    ``self.aws_hook_mock.return_value.get_client_type.return_value``.

    NOTE(review): this module defines ``TestAWSBatchOperator`` more than once;
    if these are module-level definitions, a later one rebinds the name and
    only the last definition is collected by the test runner. Confirm which
    version is meant to be kept.
    """

    JOB_NAME = '51455483-c62c-48ac-9b88-53a6a725baa3'
    # Job id that RESPONSE_WITHOUT_FAILURES is expected to carry (asserted in
    # test_execute_without_failures); also assigned directly by the tests that
    # call the private helpers.
    JOB_ID = '8ba9d676-4108-4474-9dca-8bbac1da9b19'

    @mock.patch('airflow.contrib.operators.awsbatch_operator.AwsHook')
    def setUp(self, aws_hook_mock):
        configuration.load_test_config()
        # Keep the patched hook class so tests can inspect how it was used.
        self.aws_hook_mock = aws_hook_mock
        self.batch = AWSBatchOperator(
            task_id='task',
            job_name=self.JOB_NAME,
            queue='queue',
            job_definition='hello-world',
            max_retries=5,
            overrides={},
            aws_conn_id=None,
            region_name='eu-west-1')

    def _attach_client_mock(self):
        """Give the operator a fresh client mock and the known job id; return the mock."""
        client_mock = mock.Mock()
        self.batch.jobId = self.JOB_ID
        self.batch.client = client_mock
        return client_mock

    def test_init(self):
        self.assertEqual(self.batch.job_name, self.JOB_NAME)
        self.assertEqual(self.batch.queue, 'queue')
        self.assertEqual(self.batch.job_definition, 'hello-world')
        self.assertEqual(self.batch.max_retries, 5)
        self.assertEqual(self.batch.overrides, {})
        self.assertEqual(self.batch.region_name, 'eu-west-1')
        self.assertIsNone(self.batch.aws_conn_id)
        self.assertEqual(self.batch.hook, self.aws_hook_mock.return_value)
        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)

    def test_template_fields_overrides(self):
        self.assertEqual(self.batch.template_fields, ('overrides',))

    @mock.patch.object(AWSBatchOperator, '_wait_for_task_ended')
    @mock.patch.object(AWSBatchOperator, '_check_success_task')
    def test_execute_without_failures(self, check_mock, wait_mock):
        # Stacked patch decorators apply bottom-up, so the _check_success_task
        # mock is the first injected argument.
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES

        self.batch.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'batch', region_name='eu-west-1')
        client_mock.submit_job.assert_called_once_with(
            jobQueue='queue',
            jobName=self.JOB_NAME,
            containerOverrides={},
            jobDefinition='hello-world')
        wait_mock.assert_called_once_with()
        check_mock.assert_called_once_with()
        self.assertEqual(self.batch.jobId, self.JOB_ID)

    def test_execute_with_failures(self):
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        # An empty submit_job response must make execute() raise.
        client_mock.submit_job.return_value = ""

        with self.assertRaises(AirflowException):
            self.batch.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'batch', region_name='eu-west-1')
        client_mock.submit_job.assert_called_once_with(
            jobQueue='queue',
            jobName=self.JOB_NAME,
            containerOverrides={},
            jobDefinition='hello-world')

    def test_wait_end_tasks(self):
        client_mock = self._attach_client_mock()

        self.batch._wait_for_task_ended()

        client_mock.get_waiter.assert_called_once_with('job_execution_complete')
        client_mock.get_waiter.return_value.wait.assert_called_once_with(
            jobs=[self.JOB_ID])
        # The waiter is configured to retry without a practical upper bound.
        self.assertEqual(
            sys.maxsize,
            client_mock.get_waiter.return_value.config.max_attempts)

    def test_check_success_tasks_raises(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {'jobs': []}

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        # Only the message prefix is asserted: the remainder may embed
        # str(dict), whose ordering is not guaranteed.
        self.assertIn('No job found for ', str(ctx.exception))

    def test_check_success_tasks_raises_failed(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {
            'jobs': [{
                'status': 'FAILED',
                'attempts': [{'exitCode': 1}],
            }]
        }

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        # Only the message prefix is asserted: the remainder may embed
        # str(dict), whose ordering is not guaranteed.
        self.assertIn('This containers encounter an error during execution ',
                      str(ctx.exception))

    def test_check_success_tasks_raises_pending(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {
            'jobs': [{'status': 'RUNNABLE'}]
        }

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        # Only the message prefix is asserted; the tail carries job details.
        self.assertIn('This task is still pending ', str(ctx.exception))

    def test_check_success_tasks_raises_multiple(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {
            'jobs': [{
                'status': 'FAILED',
                'attempts': [{'exitCode': 1}, {'exitCode': 10}],
            }]
        }

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        # Only the message prefix is asserted: the remainder may embed
        # str(dict), whose ordering is not guaranteed.
        self.assertIn('This containers encounter an error during execution ',
                      str(ctx.exception))

    def test_check_success_task_not_raises(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {
            'jobs': [{'status': 'SUCCEEDED'}]
        }

        self.batch._check_success_task()

        client_mock.describe_jobs.assert_called_once_with(jobs=[self.JOB_ID])
class TestAWSBatchOperator(unittest.TestCase):
    """Unit tests for AWSBatchOperator (``job_queue=`` keyword API).

    The AWS hook is patched while the operator is constructed, so the Batch
    client used by ``execute`` is a mock reachable through
    ``self.aws_hook_mock.return_value.get_client_type.return_value``.

    NOTE(review): this module defines ``TestAWSBatchOperator`` more than once;
    if these are module-level definitions, a later one rebinds the name and
    only the last definition is collected by the test runner. Confirm which
    version is meant to be kept.
    """

    JOB_NAME = '51455483-c62c-48ac-9b88-53a6a725baa3'
    # Job id that RESPONSE_WITHOUT_FAILURES is expected to carry (asserted in
    # test_execute_without_failures); also assigned directly by the tests that
    # call the private helpers.
    JOB_ID = '8ba9d676-4108-4474-9dca-8bbac1da9b19'

    @mock.patch('airflow.contrib.operators.awsbatch_operator.AwsHook')
    def setUp(self, aws_hook_mock):
        configuration.load_test_config()
        # Keep the patched hook class so tests can inspect how it was used.
        self.aws_hook_mock = aws_hook_mock
        self.batch = AWSBatchOperator(
            task_id='task',
            job_name=self.JOB_NAME,
            job_queue='queue',
            job_definition='hello-world',
            max_retries=5,
            overrides={},
            aws_conn_id=None,
            region_name='eu-west-1')

    def _attach_client_mock(self):
        """Give the operator a fresh client mock and the known job id; return the mock."""
        client_mock = mock.Mock()
        self.batch.jobId = self.JOB_ID
        self.batch.client = client_mock
        return client_mock

    def test_init(self):
        self.assertEqual(self.batch.job_name, self.JOB_NAME)
        self.assertEqual(self.batch.job_queue, 'queue')
        self.assertEqual(self.batch.job_definition, 'hello-world')
        self.assertEqual(self.batch.max_retries, 5)
        self.assertEqual(self.batch.overrides, {})
        self.assertEqual(self.batch.region_name, 'eu-west-1')
        self.assertIsNone(self.batch.aws_conn_id)
        self.assertEqual(self.batch.hook, self.aws_hook_mock.return_value)
        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)

    def test_template_fields_overrides(self):
        self.assertEqual(self.batch.template_fields, ('overrides',))

    @mock.patch.object(AWSBatchOperator, '_wait_for_task_ended')
    @mock.patch.object(AWSBatchOperator, '_check_success_task')
    def test_execute_without_failures(self, check_mock, wait_mock):
        # Stacked patch decorators apply bottom-up, so the _check_success_task
        # mock is the first injected argument.
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES

        self.batch.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'batch', region_name='eu-west-1')
        client_mock.submit_job.assert_called_once_with(
            jobQueue='queue',
            jobName=self.JOB_NAME,
            containerOverrides={},
            jobDefinition='hello-world')
        wait_mock.assert_called_once_with()
        check_mock.assert_called_once_with()
        self.assertEqual(self.batch.jobId, self.JOB_ID)

    def test_execute_with_failures(self):
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        # An empty submit_job response must make execute() raise.
        client_mock.submit_job.return_value = ""

        with self.assertRaises(AirflowException):
            self.batch.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'batch', region_name='eu-west-1')
        client_mock.submit_job.assert_called_once_with(
            jobQueue='queue',
            jobName=self.JOB_NAME,
            containerOverrides={},
            jobDefinition='hello-world')

    def test_wait_end_tasks(self):
        client_mock = self._attach_client_mock()

        self.batch._wait_for_task_ended()

        client_mock.get_waiter.assert_called_once_with('job_execution_complete')
        client_mock.get_waiter.return_value.wait.assert_called_once_with(
            jobs=[self.JOB_ID])
        # The waiter is configured to retry without a practical upper bound.
        self.assertEqual(
            sys.maxsize,
            client_mock.get_waiter.return_value.config.max_attempts)

    def test_check_success_tasks_raises(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {'jobs': []}

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        # Only the message prefix is asserted: the remainder may embed
        # str(dict), whose ordering is not guaranteed.
        self.assertIn('No job found for ', str(ctx.exception))

    def test_check_success_tasks_raises_failed(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {
            'jobs': [{
                'status': 'FAILED',
                'attempts': [{'exitCode': 1}],
            }]
        }

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        # Only the message prefix is asserted: the remainder may embed
        # str(dict), whose ordering is not guaranteed.
        self.assertIn('This containers encounter an error during execution ',
                      str(ctx.exception))

    def test_check_success_tasks_raises_pending(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {
            'jobs': [{'status': 'RUNNABLE'}]
        }

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        # Only the message prefix is asserted; the tail carries job details.
        self.assertIn('This task is still pending ', str(ctx.exception))

    def test_check_success_tasks_raises_multiple(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {
            'jobs': [{
                'status': 'FAILED',
                'attempts': [{'exitCode': 1}, {'exitCode': 10}],
            }]
        }

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        # Only the message prefix is asserted: the remainder may embed
        # str(dict), whose ordering is not guaranteed.
        self.assertIn('This containers encounter an error during execution ',
                      str(ctx.exception))

    def test_check_success_task_not_raises(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {
            'jobs': [{'status': 'SUCCEEDED'}]
        }

        self.batch._check_success_task()

        client_mock.describe_jobs.assert_called_once_with(jobs=[self.JOB_ID])
class TestAWSBatchOperator(unittest.TestCase):
    """Unit tests for AWSBatchOperator with status polling and retries.

    The AWS hook is patched while the operator is constructed, so the Batch
    client used by ``execute`` is a mock reachable through
    ``self.aws_hook_mock.return_value.get_client_type.return_value``. Tests of
    the private helpers instead assign a bare ``mock.Mock`` client directly.

    NOTE(review): this module defines ``TestAWSBatchOperator`` more than once;
    if these are module-level definitions, only the last one is collected by
    the test runner. Confirm which version is meant to be kept.
    """

    MAX_RETRIES = 2
    STATUS_RETRIES = 3

    @mock.patch("airflow.contrib.operators.awsbatch_operator.AwsHook")
    def setUp(self, aws_hook_mock):
        # Keep the patched hook class so tests can inspect how it was used.
        self.aws_hook_mock = aws_hook_mock
        self.batch = AWSBatchOperator(
            task_id="task",
            job_name=JOB_NAME,
            job_queue="queue",
            job_definition="hello-world",
            max_retries=self.MAX_RETRIES,
            status_retries=self.STATUS_RETRIES,
            parameters=None,
            overrides={},
            array_properties=None,
            aws_conn_id=None,
            region_name="eu-west-1",
        )

    def _attach_client_mock(self):
        """Give the operator a fresh client mock and JOB_ID; return the mock."""
        client_mock = mock.Mock()
        self.batch.jobId = JOB_ID
        self.batch.client = client_mock
        return client_mock

    def test_init(self):
        self.assertEqual(self.batch.job_name, JOB_NAME)
        self.assertEqual(self.batch.job_queue, "queue")
        self.assertEqual(self.batch.job_definition, "hello-world")
        self.assertEqual(self.batch.max_retries, self.MAX_RETRIES)
        self.assertEqual(self.batch.status_retries, self.STATUS_RETRIES)
        self.assertIsNone(self.batch.parameters)
        self.assertEqual(self.batch.overrides, {})
        # array_properties=None is passed in setUp; an empty dict is expected.
        self.assertEqual(self.batch.array_properties, {})
        self.assertEqual(self.batch.region_name, "eu-west-1")
        self.assertIsNone(self.batch.aws_conn_id)
        self.assertEqual(self.batch.hook, self.aws_hook_mock.return_value)
        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)

    def test_template_fields_overrides(self):
        self.assertEqual(self.batch.template_fields, (
            "job_name",
            "overrides",
            "parameters",
        ))

    @mock.patch.object(AWSBatchOperator, "_wait_for_task_ended")
    @mock.patch.object(AWSBatchOperator, "_check_success_task")
    def test_execute_without_failures(self, check_mock, wait_mock):
        # Stacked patch decorators apply bottom-up, so the _check_success_task
        # mock is the first injected argument.
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES

        self.batch.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            "batch", region_name="eu-west-1")
        client_mock.submit_job.assert_called_once_with(
            jobQueue="queue",
            jobName=JOB_NAME,
            containerOverrides={},
            jobDefinition="hello-world",
            arrayProperties={},
            parameters=None,
        )
        wait_mock.assert_called_once_with()
        check_mock.assert_called_once_with()
        self.assertEqual(self.batch.jobId, JOB_ID)

    def test_execute_with_failures(self):
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        # An empty submit_job response must make execute() raise.
        client_mock.submit_job.return_value = ""

        with self.assertRaises(AirflowException):
            self.batch.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            "batch", region_name="eu-west-1")
        client_mock.submit_job.assert_called_once_with(
            jobQueue="queue",
            jobName=JOB_NAME,
            containerOverrides={},
            jobDefinition="hello-world",
            arrayProperties={},
            parameters=None,
        )

    def test_wait_end_tasks(self):
        client_mock = self._attach_client_mock()

        self.batch._wait_for_task_ended()

        client_mock.get_waiter.assert_called_once_with("job_execution_complete")
        client_mock.get_waiter.return_value.wait.assert_called_once_with(
            jobs=[JOB_ID])
        # The waiter is configured to retry without a practical upper bound.
        self.assertEqual(
            sys.maxsize,
            client_mock.get_waiter.return_value.config.max_attempts)

    @mock.patch("airflow.contrib.operators.awsbatch_operator.randint")
    def test_poll_job_status_success(self, mock_randint):
        client_mock = self._attach_client_mock()
        mock_randint.return_value = 0  # don't pause in unit tests
        # A failing waiter; the operator is expected to fall back to polling
        # describe_jobs (that call is what this test asserts).
        client_mock.get_waiter.return_value.wait.side_effect = ValueError()
        client_mock.describe_jobs.return_value = {
            "jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]
        }

        self.batch._wait_for_task_ended()

        client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])

    @mock.patch("airflow.contrib.operators.awsbatch_operator.randint")
    def test_poll_job_status_running(self, mock_randint):
        client_mock = self._attach_client_mock()
        mock_randint.return_value = 0  # don't pause in unit tests
        # As above, the waiter fails; this time the job never leaves RUNNING.
        client_mock.get_waiter.return_value.wait.side_effect = ValueError()
        client_mock.describe_jobs.return_value = {
            "jobs": [{"jobId": JOB_ID, "status": "RUNNING"}]
        }

        self.batch._wait_for_task_ended()

        client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
        # Polling of a job stuck in RUNNING is bounded by max_retries.
        self.assertEqual(client_mock.describe_jobs.call_count, self.MAX_RETRIES)

    @mock.patch("airflow.contrib.operators.awsbatch_operator.randint")
    def test_poll_job_status_hit_api_throttle(self, mock_randint):
        client_mock = self._attach_client_mock()
        mock_randint.return_value = 0  # don't pause in unit tests
        client_mock.describe_jobs.side_effect = botocore.exceptions.ClientError(
            error_response={"Error": {"Code": "TooManyRequestsException"}},
            operation_name="get job description",
        )

        with self.assertRaises(Exception) as ctx:
            self.batch._poll_for_task_ended()

        self.assertIn("Failed to get job description", str(ctx.exception))
        client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
        # Throttled status requests are retried status_retries times.
        self.assertEqual(client_mock.describe_jobs.call_count, self.STATUS_RETRIES)

    def test_check_success_tasks_raises(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {"jobs": []}

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        self.assertIn("Failed to get job description", str(ctx.exception))

    def test_check_success_tasks_raises_failed(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {
            "jobs": [{
                "jobId": JOB_ID,
                "status": "FAILED",
                "statusReason": "This is an error reason",
                "attempts": [{"exitCode": 1}],
            }]
        }

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        # Only the stable message prefix is asserted.
        self.assertIn("Job ({}) failed with status ".format(JOB_ID),
                      str(ctx.exception))

    def test_check_success_tasks_raises_pending(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {
            "jobs": [{"jobId": JOB_ID, "status": "RUNNABLE"}]
        }

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        self.assertIn("Job ({}) is still pending".format(JOB_ID),
                      str(ctx.exception))

    def test_check_success_tasks_raises_multiple(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {
            "jobs": [{
                "jobId": JOB_ID,
                "status": "FAILED",
                "statusReason": "This is an error reason",
                "attempts": [{"exitCode": 1}, {"exitCode": 10}],
            }]
        }

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        # Only the stable message prefix is asserted.
        self.assertIn("Job ({}) failed with status ".format(JOB_ID),
                      str(ctx.exception))

    def test_check_success_task_not_raises(self):
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {
            "jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]
        }

        self.batch._check_success_task()

        client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])

    def test_check_success_task_raises_without_jobs(self):
        # Same input as test_check_success_tasks_raises, but additionally pins
        # the number of describe_jobs attempts to status_retries.
        client_mock = self._attach_client_mock()
        client_mock.describe_jobs.return_value = {"jobs": []}

        with self.assertRaises(Exception) as ctx:
            self.batch._check_success_task()

        client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
        self.assertEqual(client_mock.describe_jobs.call_count, self.STATUS_RETRIES)
        self.assertIn("Failed to get job description", str(ctx.exception))

    def test_kill_job(self):
        client_mock = self._attach_client_mock()
        client_mock.terminate_job.return_value = {}

        self.batch.on_kill()

        client_mock.terminate_job.assert_called_once_with(
            jobId=JOB_ID, reason="Task killed by the user")