class TestECSOperator(unittest.TestCase):
    """Tests for ECSOperator with AwsHook replaced by a mock.

    The ECS client the operator talks to is
    ``aws_hook_mock.return_value.get_client_type.return_value``.
    """

    @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
    def setUp(self, aws_hook_mock):
        configuration.load_test_config()
        self.aws_hook_mock = aws_hook_mock
        # AwsHook is only patched while setUp runs; the operator keeps
        # aws_hook_mock.return_value as its hook (asserted in test_init).
        self.ecs = ECSOperator(task_id='task', task_definition='t', cluster='c',
                               overrides={}, aws_conn_id=None,
                               region_name='eu-west-1')

    def test_init(self):
        self.assertEqual(self.ecs.region_name, 'eu-west-1')
        self.assertEqual(self.ecs.task_definition, 't')
        self.assertEqual(self.ecs.aws_conn_id, None)
        self.assertEqual(self.ecs.cluster, 'c')
        self.assertEqual(self.ecs.overrides, {})
        self.assertEqual(self.ecs.hook, self.aws_hook_mock.return_value)
        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)

    def test_template_fields_overrides(self):
        self.assertEqual(self.ecs.template_fields, ('overrides', ))

    @mock.patch.object(ECSOperator, '_wait_for_task_ended')
    @mock.patch.object(ECSOperator, '_check_success_task')
    def test_execute_without_failures(self, check_mock, wait_mock):
        """A run_task response without failures records the task ARN."""
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES

        self.ecs.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'ecs', region_name='eu-west-1')
        client_mock.run_task.assert_called_once_with(
            cluster='c',
            launchType='EC2',
            overrides={},
            startedBy=mock.ANY,  # Can be 'airflow' or 'Airflow'
            taskDefinition='t')
        wait_mock.assert_called_once_with()
        check_mock.assert_called_once_with()
        self.assertEqual(
            self.ecs.arn,
            'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
        )

    def test_execute_with_failures(self):
        """A non-empty 'failures' list in the run_task response raises."""
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        # Copy so the shared fixture is not mutated for other tests.
        resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES)
        resp_failures['failures'].append('dummy error')
        client_mock.run_task.return_value = resp_failures

        with self.assertRaises(AirflowException):
            self.ecs.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'ecs', region_name='eu-west-1')
        client_mock.run_task.assert_called_once_with(
            cluster='c',
            launchType='EC2',
            overrides={},
            startedBy=mock.ANY,  # Can be 'airflow' or 'Airflow'
            taskDefinition='t')

    def test_wait_end_tasks(self):
        client_mock = mock.Mock()
        self.ecs.arn = 'arn'
        self.ecs.client = client_mock

        self.ecs._wait_for_task_ended()

        client_mock.get_waiter.assert_called_once_with('tasks_stopped')
        client_mock.get_waiter.return_value.wait.assert_called_once_with(
            cluster='c', tasks=['arn'])
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(
            sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts)

    def test_check_success_tasks_raises(self):
        client_mock = mock.Mock()
        self.ecs.arn = 'arn'
        self.ecs.client = client_mock
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'foo',
                    'lastStatus': 'STOPPED',
                    'exitCode': 1
                }]
            }]
        }

        with self.assertRaises(Exception) as e:
            self.ecs._check_success_task()

        # Ordering of str(dict) is not guaranteed.
        self.assertIn("This task is not in success state ", str(e.exception))
        self.assertIn("'name': 'foo'", str(e.exception))
        self.assertIn("'lastStatus': 'STOPPED'", str(e.exception))
        self.assertIn("'exitCode': 1", str(e.exception))
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_tasks_raises_pending(self):
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'container-name',
                    'lastStatus': 'PENDING'
                }]
            }]
        }

        with self.assertRaises(Exception) as e:
            self.ecs._check_success_task()

        # Ordering of str(dict) is not guaranteed.
        self.assertIn("This task is still pending ", str(e.exception))
        self.assertIn("'name': 'container-name'", str(e.exception))
        self.assertIn("'lastStatus': 'PENDING'", str(e.exception))
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_tasks_raises_multiple(self):
        """A container without lastStatus next to a STOPPED/0 one does not raise."""
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'foo',
                    'exitCode': 1
                }, {
                    'name': 'bar',
                    'lastStatus': 'STOPPED',
                    'exitCode': 0
                }]
            }]
        }
        self.ecs._check_success_task()
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_task_not_raises(self):
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'container-name',
                    'lastStatus': 'STOPPED',
                    'exitCode': 0
                }]
            }]
        }
        self.ecs._check_success_task()
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
class TestECSOperator(unittest.TestCase):
    """Exercise ECSOperator against a mocked AwsHook and ECS client.

    ``setUp`` builds an operator with a task group, placement constraints
    and an awsvpc network configuration. The ECS client it uses is the
    ``get_client_type`` return value of the mocked hook.
    """

    @staticmethod
    def _placement_constraints():
        """Return a new copy of the placement constraints used by these tests."""
        return [{
            'expression': 'attribute:ecs.instance-type =~ t2.*',
            'type': 'memberOf'
        }]

    @staticmethod
    def _network_configuration():
        """Return a new copy of the awsvpc network configuration."""
        return {
            'awsvpcConfiguration': {
                'securityGroups': ['sg-123abc'],
                'subnets': ['subnet-123456ab']
            }
        }

    def _expected_run_task_kwargs(self, launch_type, **extra):
        """Keyword arguments run_task should receive, plus any ``extra`` ones."""
        expected = dict(
            cluster='c',
            launchType=launch_type,
            overrides={},
            startedBy=mock.ANY,  # Can be 'airflow' or 'Airflow'
            taskDefinition='t',
            group='group',
            placementConstraints=self._placement_constraints(),
            networkConfiguration=self._network_configuration(),
        )
        expected.update(extra)
        return expected

    def _attach_client(self, describe_tasks_response=None):
        """Give self.ecs a new mock client and the task ARN 'arn'."""
        client = mock.Mock()
        self.ecs.client = client
        self.ecs.arn = 'arn'
        if describe_tasks_response is not None:
            client.describe_tasks.return_value = describe_tasks_response
        return client

    def _assert_described_once(self, client):
        client.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def _assert_message_contains(self, exc, *fragments):
        """Check each fragment appears in str(exc), regardless of dict ordering."""
        message = str(exc)
        for fragment in fragments:
            self.assertIn(fragment, message)

    @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
    def setUp(self, aws_hook_mock):
        self.aws_hook_mock = aws_hook_mock
        self.ecs_operator_args = dict(
            task_id='task',
            task_definition='t',
            cluster='c',
            overrides={},
            aws_conn_id=None,
            region_name='eu-west-1',
            group='group',
            placement_constraints=self._placement_constraints(),
            network_configuration=self._network_configuration(),
        )
        self.ecs = ECSOperator(**self.ecs_operator_args)

    def test_init(self):
        expected_attributes = (
            ('region_name', 'eu-west-1'),
            ('task_definition', 't'),
            ('aws_conn_id', None),
            ('cluster', 'c'),
            ('overrides', {}),
            ('hook', self.aws_hook_mock.return_value),
        )
        for attr_name, attr_value in expected_attributes:
            self.assertEqual(getattr(self.ecs, attr_name), attr_value)
        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)

    def test_template_fields_overrides(self):
        self.assertEqual(self.ecs.template_fields, ('overrides',))

    @parameterized.expand([['EC2'], ['FARGATE']])
    @mock.patch.object(ECSOperator, '_wait_for_task_ended')
    @mock.patch.object(ECSOperator, '_check_success_task')
    @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
    def test_execute_without_failures(self, launch_type, hook_cls_mock,
                                      check_success_mock, wait_ended_mock):
        client = hook_cls_mock.return_value.get_client_type.return_value
        client.run_task.return_value = RESPONSE_WITHOUT_FAILURES

        operator = ECSOperator(launch_type=launch_type, **self.ecs_operator_args)
        operator.execute(None)

        hook_cls_mock.return_value.get_client_type.assert_called_once_with(
            'ecs', region_name='eu-west-1')
        # platformVersion is only expected for FARGATE launches.
        platform = {'platformVersion': 'LATEST'} if launch_type == 'FARGATE' else {}
        client.run_task.assert_called_once_with(
            **self._expected_run_task_kwargs(launch_type, **platform))
        wait_ended_mock.assert_called_once_with()
        check_success_mock.assert_called_once_with()
        self.assertEqual(
            operator.arn,
            'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
        )

    def test_execute_with_failures(self):
        client = self.aws_hook_mock.return_value.get_client_type.return_value
        failing_response = deepcopy(RESPONSE_WITHOUT_FAILURES)
        failing_response['failures'].append('dummy error')
        client.run_task.return_value = failing_response

        with self.assertRaises(AirflowException):
            self.ecs.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'ecs', region_name='eu-west-1')
        client.run_task.assert_called_once_with(**self._expected_run_task_kwargs('EC2'))

    def test_wait_end_tasks(self):
        client = self._attach_client()

        self.ecs._wait_for_task_ended()

        waiter = client.get_waiter.return_value
        client.get_waiter.assert_called_once_with('tasks_stopped')
        waiter.wait.assert_called_once_with(cluster='c', tasks=['arn'])
        self.assertEqual(sys.maxsize, waiter.config.max_attempts)

    def test_check_success_tasks_raises(self):
        client = self._attach_client({
            'tasks': [{
                'containers': [{'name': 'foo', 'lastStatus': 'STOPPED', 'exitCode': 1}]
            }]
        })

        with self.assertRaises(Exception) as ctx:
            self.ecs._check_success_task()

        self._assert_message_contains(
            ctx.exception,
            "This task is not in success state ",
            "'name': 'foo'",
            "'lastStatus': 'STOPPED'",
            "'exitCode': 1",
        )
        self._assert_described_once(client)

    def test_check_success_tasks_raises_pending(self):
        client = self._attach_client({
            'tasks': [{
                'containers': [{'name': 'container-name', 'lastStatus': 'PENDING'}]
            }]
        })

        with self.assertRaises(Exception) as ctx:
            self.ecs._check_success_task()

        self._assert_message_contains(
            ctx.exception,
            "This task is still pending ",
            "'name': 'container-name'",
            "'lastStatus': 'PENDING'",
        )
        self._assert_described_once(client)

    def test_check_success_tasks_raises_multiple(self):
        client = self._attach_client({
            'tasks': [{
                'containers': [
                    {'name': 'foo', 'exitCode': 1},
                    {'name': 'bar', 'lastStatus': 'STOPPED', 'exitCode': 0},
                ]
            }]
        })
        self.ecs._check_success_task()
        self._assert_described_once(client)

    def test_host_terminated_raises(self):
        client = self._attach_client({
            'tasks': [{
                'stoppedReason': 'Host EC2 (instance i-1234567890abcdef) terminated.',
                "containers": [{
                    "containerArn": "arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868",  # noqa: E501
                    "lastStatus": "RUNNING",
                    "name": "wordpress",
                    "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55"  # noqa: E501
                }],
                "desiredStatus": "STOPPED",
                "lastStatus": "STOPPED",
                "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",  # noqa: E501
                "taskDefinitionArn": "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11"  # noqa: E501
            }]
        })

        with self.assertRaises(AirflowException) as ctx:
            self.ecs._check_success_task()

        self._assert_message_contains(
            ctx.exception,
            "The task was stopped because the host instance terminated:",
            "Host EC2 (",
            ") terminated",
        )
        self._assert_described_once(client)

    def test_check_success_task_not_raises(self):
        client = self._attach_client({
            'tasks': [{
                'containers': [{'name': 'container-name', 'lastStatus': 'STOPPED', 'exitCode': 0}]
            }]
        })
        self.ecs._check_success_task()
        self._assert_described_once(client)
class TestECSOperator(unittest.TestCase):
    """Tests for ECSOperator with AwsHook replaced by a mock.

    The ECS client the operator talks to is
    ``aws_hook_mock.return_value.get_client_type.return_value``.
    """

    @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
    def setUp(self, aws_hook_mock):
        configuration.load_test_config()
        self.aws_hook_mock = aws_hook_mock
        self.ecs = ECSOperator(
            task_id='task',
            task_definition='t',
            cluster='c',
            overrides={},
            aws_conn_id=None,
            region_name='eu-west-1')

    def test_init(self):
        self.assertEqual(self.ecs.region_name, 'eu-west-1')
        self.assertEqual(self.ecs.task_definition, 't')
        self.assertEqual(self.ecs.aws_conn_id, None)
        self.assertEqual(self.ecs.cluster, 'c')
        self.assertEqual(self.ecs.overrides, {})
        self.assertEqual(self.ecs.hook, self.aws_hook_mock.return_value)
        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)

    def test_template_fields_overrides(self):
        self.assertEqual(self.ecs.template_fields, ('overrides',))

    @mock.patch.object(ECSOperator, '_wait_for_task_ended')
    @mock.patch.object(ECSOperator, '_check_success_task')
    def test_execute_without_failures(self, check_mock, wait_mock):
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES

        self.ecs.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'ecs', region_name='eu-west-1')
        client_mock.run_task.assert_called_once_with(
            cluster='c',
            overrides={},
            startedBy='Airflow',
            taskDefinition='t'
        )
        wait_mock.assert_called_once_with()
        check_mock.assert_called_once_with()
        self.assertEqual(
            self.ecs.arn,
            'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55')

    def test_execute_with_failures(self):
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        # Copy so the shared fixture is not mutated for other tests.
        resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES)
        resp_failures['failures'].append('dummy error')
        client_mock.run_task.return_value = resp_failures

        with self.assertRaises(AirflowException):
            self.ecs.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'ecs', region_name='eu-west-1')
        client_mock.run_task.assert_called_once_with(
            cluster='c',
            overrides={},
            startedBy='Airflow',
            taskDefinition='t'
        )

    def test_wait_end_tasks(self):
        client_mock = mock.Mock()
        self.ecs.arn = 'arn'
        self.ecs.client = client_mock

        self.ecs._wait_for_task_ended()

        client_mock.get_waiter.assert_called_once_with('tasks_stopped')
        client_mock.get_waiter.return_value.wait.assert_called_once_with(
            cluster='c', tasks=['arn'])
        # sys.maxint no longer exists on Python 3; sys.maxsize is its replacement.
        self.assertEqual(
            sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts)

    def test_check_success_tasks_raises(self):
        client_mock = mock.Mock()
        self.ecs.arn = 'arn'
        self.ecs.client = client_mock
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'foo',
                    'lastStatus': 'STOPPED',
                    'exitCode': 1
                }]
            }]
        }

        with self.assertRaises(Exception) as e:
            self.ecs._check_success_task()

        # Ordering of str(dict) is not guaranteed across Python versions, so
        # match fragments instead of the full rendered message.
        self.assertIn("This task is not in success state ", str(e.exception))
        self.assertIn("'name': 'foo'", str(e.exception))
        self.assertIn("'lastStatus': 'STOPPED'", str(e.exception))
        self.assertIn("'exitCode': 1", str(e.exception))
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_tasks_raises_pending(self):
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'container-name',
                    'lastStatus': 'PENDING'
                }]
            }]
        }

        with self.assertRaises(Exception) as e:
            self.ecs._check_success_task()

        # Ordering of str(dict) is not guaranteed; match fragments.
        self.assertIn("This task is still pending ", str(e.exception))
        self.assertIn("'name': 'container-name'", str(e.exception))
        self.assertIn("'lastStatus': 'PENDING'", str(e.exception))
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_tasks_raises_multiple(self):
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'foo',
                    'exitCode': 1
                }, {
                    'name': 'bar',
                    'lastStatus': 'STOPPED',
                    'exitCode': 0
                }]
            }]
        }
        self.ecs._check_success_task()
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_task_not_raises(self):
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'container-name',
                    'lastStatus': 'STOPPED',
                    'exitCode': 0
                }]
            }]
        }
        self.ecs._check_success_task()
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
class TestECSOperator(unittest.TestCase):
    """Tests for ECSOperator configured with group, placement and networking.

    AwsHook is replaced by a mock; the ECS client the operator talks to is
    ``aws_hook_mock.return_value.get_client_type.return_value``.
    """

    @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
    def setUp(self, aws_hook_mock):
        configuration.load_test_config()
        self.aws_hook_mock = aws_hook_mock
        self.ecs = ECSOperator(
            task_id='task',
            task_definition='t',
            cluster='c',
            overrides={},
            aws_conn_id=None,
            region_name='eu-west-1',
            group='group',
            placement_constraints=[
                {
                    'expression': 'attribute:ecs.instance-type =~ t2.*',
                    'type': 'memberOf'
                }
            ],
            network_configuration={
                'awsvpcConfiguration': {
                    'securityGroups': ['sg-123abc']
                }
            }
        )

    def test_init(self):
        self.assertEqual(self.ecs.region_name, 'eu-west-1')
        self.assertEqual(self.ecs.task_definition, 't')
        self.assertEqual(self.ecs.aws_conn_id, None)
        self.assertEqual(self.ecs.cluster, 'c')
        self.assertEqual(self.ecs.overrides, {})
        self.assertEqual(self.ecs.hook, self.aws_hook_mock.return_value)
        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)

    def test_template_fields_overrides(self):
        self.assertEqual(self.ecs.template_fields, ('overrides',))

    @mock.patch.object(ECSOperator, '_wait_for_task_ended')
    @mock.patch.object(ECSOperator, '_check_success_task')
    def test_execute_without_failures(self, check_mock, wait_mock):
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES

        self.ecs.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'ecs', region_name='eu-west-1')
        client_mock.run_task.assert_called_once_with(
            cluster='c',
            launchType='EC2',
            overrides={},
            startedBy=mock.ANY,  # Can be 'airflow' or 'Airflow'
            taskDefinition='t',
            group='group',
            placementConstraints=[
                {
                    'expression': 'attribute:ecs.instance-type =~ t2.*',
                    'type': 'memberOf'
                }
            ],
            platformVersion='LATEST',
            networkConfiguration={
                'awsvpcConfiguration': {
                    'securityGroups': ['sg-123abc']
                }
            }
        )
        wait_mock.assert_called_once_with()
        check_mock.assert_called_once_with()
        self.assertEqual(
            self.ecs.arn,
            'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55')

    def test_execute_with_failures(self):
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        # Copy so the shared fixture is not mutated for other tests.
        resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES)
        resp_failures['failures'].append('dummy error')
        client_mock.run_task.return_value = resp_failures

        with self.assertRaises(AirflowException):
            self.ecs.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'ecs', region_name='eu-west-1')
        client_mock.run_task.assert_called_once_with(
            cluster='c',
            launchType='EC2',
            overrides={},
            startedBy=mock.ANY,  # Can be 'airflow' or 'Airflow'
            taskDefinition='t',
            group='group',
            placementConstraints=[
                {
                    'expression': 'attribute:ecs.instance-type =~ t2.*',
                    'type': 'memberOf'
                }
            ],
            platformVersion='LATEST',
            networkConfiguration={
                'awsvpcConfiguration': {
                    'securityGroups': ['sg-123abc']
                }
            }
        )

    def test_wait_end_tasks(self):
        client_mock = mock.Mock()
        self.ecs.arn = 'arn'
        self.ecs.client = client_mock

        self.ecs._wait_for_task_ended()

        client_mock.get_waiter.assert_called_once_with('tasks_stopped')
        client_mock.get_waiter.return_value.wait.assert_called_once_with(
            cluster='c', tasks=['arn'])
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(
            sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts)

    def test_check_success_tasks_raises(self):
        client_mock = mock.Mock()
        self.ecs.arn = 'arn'
        self.ecs.client = client_mock
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'foo',
                    'lastStatus': 'STOPPED',
                    'exitCode': 1
                }]
            }]
        }

        with self.assertRaises(Exception) as e:
            self.ecs._check_success_task()

        # Ordering of str(dict) is not guaranteed.
        self.assertIn("This task is not in success state ", str(e.exception))
        self.assertIn("'name': 'foo'", str(e.exception))
        self.assertIn("'lastStatus': 'STOPPED'", str(e.exception))
        self.assertIn("'exitCode': 1", str(e.exception))
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_tasks_raises_pending(self):
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'container-name',
                    'lastStatus': 'PENDING'
                }]
            }]
        }

        with self.assertRaises(Exception) as e:
            self.ecs._check_success_task()

        # Ordering of str(dict) is not guaranteed.
        self.assertIn("This task is still pending ", str(e.exception))
        self.assertIn("'name': 'container-name'", str(e.exception))
        self.assertIn("'lastStatus': 'PENDING'", str(e.exception))
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_tasks_raises_multiple(self):
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'foo',
                    'exitCode': 1
                }, {
                    'name': 'bar',
                    'lastStatus': 'STOPPED',
                    'exitCode': 0
                }]
            }]
        }
        self.ecs._check_success_task()
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_task_not_raises(self):
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'container-name',
                    'lastStatus': 'STOPPED',
                    'exitCode': 0
                }]
            }]
        }
        self.ecs._check_success_task()
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
class TestECSOperator(unittest.TestCase):
    """Exercise ECSOperator against a mocked AwsHook and ECS client.

    ``setUp`` loads the test configuration and builds an operator with a task
    group, placement constraints and an awsvpc network configuration. The ECS
    client it uses is the ``get_client_type`` return value of the mocked hook.
    """

    @staticmethod
    def _placement_constraints():
        """Return a new copy of the placement constraints used by these tests."""
        return [{
            'expression': 'attribute:ecs.instance-type =~ t2.*',
            'type': 'memberOf'
        }]

    @staticmethod
    def _network_configuration():
        """Return a new copy of the awsvpc network configuration."""
        return {
            'awsvpcConfiguration': {
                'securityGroups': ['sg-123abc'],
                'subnets': ['subnet-123456ab']
            }
        }

    def _expected_run_task_kwargs(self, launch_type, **extra):
        """Keyword arguments run_task should receive, plus any ``extra`` ones."""
        expected = dict(
            cluster='c',
            launchType=launch_type,
            overrides={},
            startedBy=mock.ANY,  # Can be 'airflow' or 'Airflow'
            taskDefinition='t',
            group='group',
            placementConstraints=self._placement_constraints(),
            networkConfiguration=self._network_configuration(),
        )
        expected.update(extra)
        return expected

    def _attach_client(self, describe_tasks_response=None):
        """Give self.ecs a new mock client and the task ARN 'arn'."""
        client = mock.Mock()
        self.ecs.client = client
        self.ecs.arn = 'arn'
        if describe_tasks_response is not None:
            client.describe_tasks.return_value = describe_tasks_response
        return client

    def _assert_described_once(self, client):
        client.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def _assert_message_contains(self, exc, *fragments):
        """Check each fragment appears in str(exc), regardless of dict ordering."""
        message = str(exc)
        for fragment in fragments:
            self.assertIn(fragment, message)

    @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
    def setUp(self, aws_hook_mock):
        configuration.load_test_config()
        self.aws_hook_mock = aws_hook_mock
        self.ecs_operator_args = dict(
            task_id='task',
            task_definition='t',
            cluster='c',
            overrides={},
            aws_conn_id=None,
            region_name='eu-west-1',
            group='group',
            placement_constraints=self._placement_constraints(),
            network_configuration=self._network_configuration(),
        )
        self.ecs = ECSOperator(**self.ecs_operator_args)

    def test_init(self):
        expected_attributes = (
            ('region_name', 'eu-west-1'),
            ('task_definition', 't'),
            ('aws_conn_id', None),
            ('cluster', 'c'),
            ('overrides', {}),
            ('hook', self.aws_hook_mock.return_value),
        )
        for attr_name, attr_value in expected_attributes:
            self.assertEqual(getattr(self.ecs, attr_name), attr_value)
        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)

    def test_template_fields_overrides(self):
        self.assertEqual(self.ecs.template_fields, ('overrides',))

    @parameterized.expand([['EC2'], ['FARGATE']])
    @mock.patch.object(ECSOperator, '_wait_for_task_ended')
    @mock.patch.object(ECSOperator, '_check_success_task')
    @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
    def test_execute_without_failures(self, launch_type, hook_cls_mock,
                                      check_success_mock, wait_ended_mock):
        client = hook_cls_mock.return_value.get_client_type.return_value
        client.run_task.return_value = RESPONSE_WITHOUT_FAILURES

        operator = ECSOperator(launch_type=launch_type, **self.ecs_operator_args)
        operator.execute(None)

        hook_cls_mock.return_value.get_client_type.assert_called_once_with(
            'ecs', region_name='eu-west-1')
        # platformVersion is only expected for FARGATE launches.
        platform = {'platformVersion': 'LATEST'} if launch_type == 'FARGATE' else {}
        client.run_task.assert_called_once_with(
            **self._expected_run_task_kwargs(launch_type, **platform))
        wait_ended_mock.assert_called_once_with()
        check_success_mock.assert_called_once_with()
        self.assertEqual(
            operator.arn,
            'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
        )

    def test_execute_with_failures(self):
        client = self.aws_hook_mock.return_value.get_client_type.return_value
        failing_response = deepcopy(RESPONSE_WITHOUT_FAILURES)
        failing_response['failures'].append('dummy error')
        client.run_task.return_value = failing_response

        with self.assertRaises(AirflowException):
            self.ecs.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'ecs', region_name='eu-west-1')
        client.run_task.assert_called_once_with(**self._expected_run_task_kwargs('EC2'))

    def test_wait_end_tasks(self):
        client = self._attach_client()

        self.ecs._wait_for_task_ended()

        waiter = client.get_waiter.return_value
        client.get_waiter.assert_called_once_with('tasks_stopped')
        waiter.wait.assert_called_once_with(cluster='c', tasks=['arn'])
        self.assertEqual(sys.maxsize, waiter.config.max_attempts)

    def test_check_success_tasks_raises(self):
        client = self._attach_client({
            'tasks': [{
                'containers': [{'name': 'foo', 'lastStatus': 'STOPPED', 'exitCode': 1}]
            }]
        })

        with self.assertRaises(Exception) as ctx:
            self.ecs._check_success_task()

        self._assert_message_contains(
            ctx.exception,
            "This task is not in success state ",
            "'name': 'foo'",
            "'lastStatus': 'STOPPED'",
            "'exitCode': 1",
        )
        self._assert_described_once(client)

    def test_check_success_tasks_raises_pending(self):
        client = self._attach_client({
            'tasks': [{
                'containers': [{'name': 'container-name', 'lastStatus': 'PENDING'}]
            }]
        })

        with self.assertRaises(Exception) as ctx:
            self.ecs._check_success_task()

        self._assert_message_contains(
            ctx.exception,
            "This task is still pending ",
            "'name': 'container-name'",
            "'lastStatus': 'PENDING'",
        )
        self._assert_described_once(client)

    def test_check_success_tasks_raises_multiple(self):
        client = self._attach_client({
            'tasks': [{
                'containers': [
                    {'name': 'foo', 'exitCode': 1},
                    {'name': 'bar', 'lastStatus': 'STOPPED', 'exitCode': 0},
                ]
            }]
        })
        self.ecs._check_success_task()
        self._assert_described_once(client)

    def test_host_terminated_raises(self):
        client = self._attach_client({
            'tasks': [{
                'stoppedReason': 'Host EC2 (instance i-1234567890abcdef) terminated.',
                "containers": [{
                    "containerArn": "arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868",  # noqa: E501
                    "lastStatus": "RUNNING",
                    "name": "wordpress",
                    "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55"  # noqa: E501
                }],
                "desiredStatus": "STOPPED",
                "lastStatus": "STOPPED",
                "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",  # noqa: E501
                "taskDefinitionArn": "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11"  # noqa: E501
            }]
        })

        with self.assertRaises(AirflowException) as ctx:
            self.ecs._check_success_task()

        self._assert_message_contains(
            ctx.exception,
            "The task was stopped because the host instance terminated:",
            "Host EC2 (",
            ") terminated",
        )
        self._assert_described_once(client)

    def test_check_success_task_not_raises(self):
        client = self._attach_client({
            'tasks': [{
                'containers': [{'name': 'container-name', 'lastStatus': 'STOPPED', 'exitCode': 0}]
            }]
        })
        self.ecs._check_success_task()
        self._assert_described_once(client)