def set_up_operator(self, aws_hook_mock, **kwargs): self.aws_hook_mock = aws_hook_mock self.ecs_operator_args = { 'task_id': 'task', 'task_definition': 't', 'cluster': 'c', 'overrides': {}, 'aws_conn_id': None, 'region_name': 'eu-west-1', 'group': 'group', 'placement_constraints': [ {'expression': 'attribute:ecs.instance-type =~ t2.*', 'type': 'memberOf'} ], 'placement_strategy': [{'field': 'memory', 'type': 'binpack'}], 'network_configuration': { 'awsvpcConfiguration': {'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab']} }, 'propagate_tags': 'TASK_DEFINITION', } self.ecs = ECSOperator(**self.ecs_operator_args, **kwargs) self.ecs.get_hook()
class TestECSOperator(unittest.TestCase): @mock.patch('airflow.providers.amazon.aws.operators.ecs.AwsHook') def setUp(self, aws_hook_mock): self.aws_hook_mock = aws_hook_mock self.ecs_operator_args = { 'task_id': 'task', 'task_definition': 't', 'cluster': 'c', 'overrides': {}, 'aws_conn_id': None, 'region_name': 'eu-west-1', 'group': 'group', 'placement_constraints': [{ 'expression': 'attribute:ecs.instance-type =~ t2.*', 'type': 'memberOf' }], 'network_configuration': { 'awsvpcConfiguration': { 'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab'] } } } self.ecs = ECSOperator(**self.ecs_operator_args) def test_init(self): self.assertEqual(self.ecs.region_name, 'eu-west-1') self.assertEqual(self.ecs.task_definition, 't') self.assertEqual(self.ecs.aws_conn_id, None) self.assertEqual(self.ecs.cluster, 'c') self.assertEqual(self.ecs.overrides, {}) self.assertEqual(self.ecs.hook, self.aws_hook_mock.return_value) self.aws_hook_mock.assert_called_once_with(aws_conn_id=None) def test_template_fields_overrides(self): self.assertEqual(self.ecs.template_fields, ('overrides', )) @parameterized.expand([ ['EC2', None], ['FARGATE', None], ['EC2', { 'testTagKey': 'testTagValue' }], ]) @mock.patch.object(ECSOperator, '_wait_for_task_ended') @mock.patch.object(ECSOperator, '_check_success_task') @mock.patch('airflow.providers.amazon.aws.operators.ecs.AwsHook') def test_execute_without_failures(self, launch_type, tags, aws_hook_mock, check_mock, wait_mock): client_mock = aws_hook_mock.return_value.get_client_type.return_value client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES ecs = ECSOperator(launch_type=launch_type, tags=tags, **self.ecs_operator_args) ecs.execute(None) aws_hook_mock.return_value.get_client_type.assert_called_once_with( 'ecs', region_name='eu-west-1') extend_args = {} if launch_type == 'FARGATE': extend_args['platformVersion'] = 'LATEST' if tags: extend_args['tags'] = [{ 'key': k, 'value': v } for (k, v) in tags.items()] client_mock.run_task.assert_called_once_with( cluster='c', launchType=launch_type, overrides={}, startedBy=mock.ANY, # Can by 'airflow' or 'Airflow' taskDefinition='t', group='group', placementConstraints=[{ 'expression': 'attribute:ecs.instance-type =~ t2.*', 'type': 'memberOf' }], networkConfiguration={ 'awsvpcConfiguration': { 'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab'] } }, **extend_args) wait_mock.assert_called_once_with() check_mock.assert_called_once_with() self.assertEqual( ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55' ) def test_execute_with_failures(self): client_mock = self.aws_hook_mock.return_value.get_client_type.return_value resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES) resp_failures['failures'].append('dummy error') client_mock.run_task.return_value = resp_failures with self.assertRaises(AirflowException): self.ecs.execute(None) self.aws_hook_mock.return_value.get_client_type.assert_called_once_with( 'ecs', region_name='eu-west-1') client_mock.run_task.assert_called_once_with( cluster='c', launchType='EC2', overrides={}, startedBy=mock.ANY, # Can by 'airflow' or 'Airflow' taskDefinition='t', group='group', placementConstraints=[{ 'expression': 'attribute:ecs.instance-type =~ t2.*', 'type': 'memberOf' }], networkConfiguration={ 'awsvpcConfiguration': { 'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab'], } }) def test_wait_end_tasks(self): client_mock = mock.Mock() self.ecs.arn = 'arn' self.ecs.client = client_mock self.ecs._wait_for_task_ended() client_mock.get_waiter.assert_called_once_with('tasks_stopped') client_mock.get_waiter.return_value.wait.assert_called_once_with( cluster='c', tasks=['arn']) self.assertEqual( sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts) def test_check_success_tasks_raises(self): client_mock = mock.Mock() self.ecs.arn = 'arn' self.ecs.client = client_mock client_mock.describe_tasks.return_value = { 'tasks': [{ 'containers': [{ 'name': 'foo', 'lastStatus': 'STOPPED', 'exitCode': 1 }] }] } with self.assertRaises(Exception) as e: self.ecs._check_success_task() # Ordering of str(dict) is not guaranteed. self.assertIn("This task is not in success state ", str(e.exception)) self.assertIn("'name': 'foo'", str(e.exception)) self.assertIn("'lastStatus': 'STOPPED'", str(e.exception)) self.assertIn("'exitCode': 1", str(e.exception)) client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_tasks_raises_pending(self): client_mock = mock.Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { 'tasks': [{ 'containers': [{ 'name': 'container-name', 'lastStatus': 'PENDING' }] }] } with self.assertRaises(Exception) as e: self.ecs._check_success_task() # Ordering of str(dict) is not guaranteed. self.assertIn("This task is still pending ", str(e.exception)) self.assertIn("'name': 'container-name'", str(e.exception)) self.assertIn("'lastStatus': 'PENDING'", str(e.exception)) client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_tasks_raises_multiple(self): client_mock = mock.Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { 'tasks': [{ 'containers': [{ 'name': 'foo', 'exitCode': 1 }, { 'name': 'bar', 'lastStatus': 'STOPPED', 'exitCode': 0 }] }] } self.ecs._check_success_task() client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_host_terminated_raises(self): client_mock = mock.Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { 'tasks': [{ 'stoppedReason': 'Host EC2 (instance i-1234567890abcdef) terminated.', "containers": [{ "containerArn": "arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868", # noqa: E501 # pylint: disable=line-too-long "lastStatus": "RUNNING", "name": "wordpress", "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55" # noqa: E501 # pylint: disable=line-too-long }], "desiredStatus": "STOPPED", "lastStatus": "STOPPED", "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55", # noqa: E501 # pylint: disable=line-too-long "taskDefinitionArn": "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11" # noqa: E501 # pylint: disable=line-too-long }] } with self.assertRaises(AirflowException) as e: self.ecs._check_success_task() self.assertIn( "The task was stopped because the host instance terminated:", str(e.exception)) self.assertIn("Host EC2 (", str(e.exception)) self.assertIn(") terminated", str(e.exception)) client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_task_not_raises(self): client_mock = mock.Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { 'tasks': [{ 'containers': [{ 'name': 'container-name', 'lastStatus': 'STOPPED', 'exitCode': 0 }] }] } self.ecs._check_success_task() client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
concurrency=2, catchup=True) SCRAPER = ECSOperator( task_id='news_scraper_historical', dag=NEWS_CRAWLER_DAG, # set this connection id in airflow web UI aws_conn_id="aws_default", cluster="news-crawler-ecs", task_definition="ecs-task-definition", launch_type="FARGATE", overrides={ "containerOverrides": [ { "name": "news-crawler-container", "command": ['{{ds}}'] }, ], }, network_configuration={ "awsvpcConfiguration": { # set these variables in airflow web UI "securityGroups": [Variable.get("SECURITY_GROUP_ID")], "subnets": [Variable.get("SUBNET_ID")], "assignPublicIp": "ENABLED", }, }, awslogs_group="/ecs/ecs-task-definition") # pylint: disable=pointless-statement SCRAPER
hello_world = ECSOperator( task_id="hello_world", dag=dag, aws_conn_id="aws_ecs", cluster="c", task_definition="hello-world", launch_type="FARGATE", overrides={ "containerOverrides": [ { "name": "hello-world-container", "command": ["echo", "hello", "world"], }, ], }, network_configuration={ "awsvpcConfiguration": { "securityGroups": [os.environ.get("SECURITY_GROUP_ID", "sg-123abc")], "subnets": [os.environ.get("SUBNET_ID", "subnet-123456ab")], }, }, tags={ "Customer": "X", "Project": "Y", "Application": "Z", "Version": "0.0.1", "Environment": "Development", }, awslogs_group="/ecs/hello-world", awslogs_stream_prefix= "prefix_b/hello-world-container", # prefix with container name )
class TestECSOperator(unittest.TestCase): @mock.patch('airflow.providers.amazon.aws.operators.ecs.AwsBaseHook') def set_up_operator(self, aws_hook_mock, **kwargs): self.aws_hook_mock = aws_hook_mock self.ecs_operator_args = { 'task_id': 'task', 'task_definition': 't', 'cluster': 'c', 'overrides': {}, 'aws_conn_id': None, 'region_name': 'eu-west-1', 'group': 'group', 'placement_constraints': [{ 'expression': 'attribute:ecs.instance-type =~ t2.*', 'type': 'memberOf' }], 'placement_strategy': [{ 'field': 'memory', 'type': 'binpack' }], 'network_configuration': { 'awsvpcConfiguration': { 'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab'] } }, 'propagate_tags': 'TASK_DEFINITION', } self.ecs = ECSOperator(**self.ecs_operator_args, **kwargs) self.ecs.get_hook() def setUp(self): self.set_up_operator() # pylint: disable=no-value-for-parameter def test_init(self): assert self.ecs.region_name == 'eu-west-1' assert self.ecs.task_definition == 't' assert self.ecs.aws_conn_id is None assert self.ecs.cluster == 'c' assert self.ecs.overrides == {} self.ecs.get_hook() assert self.ecs.hook == self.aws_hook_mock.return_value self.aws_hook_mock.assert_called_once() def test_template_fields_overrides(self): assert self.ecs.template_fields == ('overrides', ) @parameterized.expand([ ['EC2', None], ['FARGATE', None], ['EC2', { 'testTagKey': 'testTagValue' }], ['', { 'testTagKey': 'testTagValue' }], ]) @mock.patch.object(ECSOperator, '_wait_for_task_ended') @mock.patch.object(ECSOperator, '_check_success_task') def test_execute_without_failures(self, launch_type, tags, check_mock, wait_mock): self.set_up_operator(launch_type=launch_type, tags=tags) # pylint: disable=no-value-for-parameter client_mock = self.aws_hook_mock.return_value.get_conn.return_value client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES self.ecs.execute(None) self.aws_hook_mock.return_value.get_conn.assert_called_once() extend_args = {} if launch_type: extend_args['launchType'] = launch_type if launch_type == 'FARGATE': extend_args['platformVersion'] = 'LATEST' if tags: extend_args['tags'] = [{ 'key': k, 'value': v } for (k, v) in tags.items()] client_mock.run_task.assert_called_once_with( cluster='c', overrides={}, startedBy=mock.ANY, # Can by 'airflow' or 'Airflow' taskDefinition='t', group='group', placementConstraints=[{ 'expression': 'attribute:ecs.instance-type =~ t2.*', 'type': 'memberOf' }], placementStrategy=[{ 'field': 'memory', 'type': 'binpack' }], networkConfiguration={ 'awsvpcConfiguration': { 'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab'] } }, propagateTags='TASK_DEFINITION', **extend_args, ) wait_mock.assert_called_once_with() check_mock.assert_called_once_with() assert self.ecs.arn == 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55' def test_execute_with_failures(self): client_mock = self.aws_hook_mock.return_value.get_conn.return_value resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES) resp_failures['failures'].append('dummy error') client_mock.run_task.return_value = resp_failures with pytest.raises(AirflowException): self.ecs.execute(None) self.aws_hook_mock.return_value.get_conn.assert_called_once() client_mock.run_task.assert_called_once_with( cluster='c', launchType='EC2', overrides={}, startedBy=mock.ANY, # Can by 'airflow' or 'Airflow' taskDefinition='t', group='group', placementConstraints=[{ 'expression': 'attribute:ecs.instance-type =~ t2.*', 'type': 'memberOf' }], placementStrategy=[{ 'field': 'memory', 'type': 'binpack' }], networkConfiguration={ 'awsvpcConfiguration': { 'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab'], } }, propagateTags='TASK_DEFINITION', ) def test_wait_end_tasks(self): client_mock = mock.Mock() self.ecs.arn = 'arn' self.ecs.client = client_mock self.ecs._wait_for_task_ended() client_mock.get_waiter.assert_called_once_with('tasks_stopped') client_mock.get_waiter.return_value.wait.assert_called_once_with( cluster='c', tasks=['arn']) assert sys.maxsize == client_mock.get_waiter.return_value.config.max_attempts def test_check_success_tasks_raises(self): client_mock = mock.Mock() self.ecs.arn = 'arn' self.ecs.client = client_mock client_mock.describe_tasks.return_value = { 'tasks': [{ 'containers': [{ 'name': 'foo', 'lastStatus': 'STOPPED', 'exitCode': 1 }] }] } with pytest.raises(Exception) as ctx: self.ecs._check_success_task() # Ordering of str(dict) is not guaranteed. assert "This task is not in success state " in str(ctx.value) assert "'name': 'foo'" in str(ctx.value) assert "'lastStatus': 'STOPPED'" in str(ctx.value) assert "'exitCode': 1" in str(ctx.value) client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_tasks_raises_pending(self): client_mock = mock.Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { 'tasks': [{ 'containers': [{ 'name': 'container-name', 'lastStatus': 'PENDING' }] }] } with pytest.raises(Exception) as ctx: self.ecs._check_success_task() # Ordering of str(dict) is not guaranteed. assert "This task is still pending " in str(ctx.value) assert "'name': 'container-name'" in str(ctx.value) assert "'lastStatus': 'PENDING'" in str(ctx.value) client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_tasks_raises_multiple(self): client_mock = mock.Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { 'tasks': [{ 'containers': [ { 'name': 'foo', 'exitCode': 1 }, { 'name': 'bar', 'lastStatus': 'STOPPED', 'exitCode': 0 }, ] }] } self.ecs._check_success_task() client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_host_terminated_raises(self): client_mock = mock.Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { 'tasks': [{ 'stoppedReason': 'Host EC2 (instance i-1234567890abcdef) terminated.', "containers": [{ "containerArn": "arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868", # noqa: E501 # pylint: disable=line-too-long "lastStatus": "RUNNING", "name": "wordpress", "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55", # noqa: E501 # pylint: disable=line-too-long }], "desiredStatus": "STOPPED", "lastStatus": "STOPPED", "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55", # noqa: E501 # pylint: disable=line-too-long "taskDefinitionArn": "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11", # noqa: E501 # pylint: disable=line-too-long }] } with pytest.raises(AirflowException) as ctx: self.ecs._check_success_task() assert "The task was stopped because the host instance terminated:" in str( ctx.value) assert "Host EC2 (" in str(ctx.value) assert ") terminated" in str(ctx.value) client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_task_not_raises(self): client_mock = mock.Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { 'tasks': [{ 'containers': [{ 'name': 'container-name', 'lastStatus': 'STOPPED', 'exitCode': 0 }] }] } self.ecs._check_success_task() client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) @parameterized.expand([ ['EC2', None], ['FARGATE', None], ['EC2', { 'testTagKey': 'testTagValue' }], ['', { 'testTagKey': 'testTagValue' }], ]) @mock.patch.object(ECSOperator, '_wait_for_task_ended') @mock.patch.object(ECSOperator, '_check_success_task') @mock.patch.object(ECSOperator, '_start_task') def test_reattach_successful(self, launch_type, tags, start_mock, check_mock, wait_mock): self.set_up_operator(launch_type=launch_type, tags=tags) # pylint: disable=no-value-for-parameter client_mock = self.aws_hook_mock.return_value.get_conn.return_value client_mock.describe_task_definition.return_value = { 'taskDefinition': { 'family': 'f' } } client_mock.list_tasks.return_value = { 'taskArns': [ 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55' ] } self.ecs.reattach = True self.ecs.execute(None) self.aws_hook_mock.return_value.get_conn.assert_called_once() extend_args = {} if launch_type: extend_args['launchType'] = launch_type if launch_type == 'FARGATE': extend_args['platformVersion'] = 'LATEST' if tags: extend_args['tags'] = [{ 'key': k, 'value': v } for (k, v) in tags.items()] client_mock.describe_task_definition.assert_called_once_with('t') client_mock.list_tasks.assert_called_once_with(cluster='c', launchType=launch_type, desiredStatus='RUNNING', family='f') start_mock.assert_not_called() wait_mock.assert_called_once_with() check_mock.assert_called_once_with() assert self.ecs.arn == 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55' @mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output") def test_execute_xcom_with_log(self, mock_cloudwatch_log_message): self.ecs.do_xcom_push = True assert self.ecs.execute( None) == mock_cloudwatch_log_message.return_value @mock.patch.object(ECSOperator, '_last_log_message', return_value=None) def test_execute_xcom_with_no_log(self, mock_cloudwatch_log_message): self.ecs.do_xcom_push = True assert self.ecs.execute( None) == mock_cloudwatch_log_message.return_value @mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output") def test_execute_xcom_disabled(self, mock_cloudwatch_log_message): self.ecs.do_xcom_push = False assert self.ecs.execute(None) is None
for i in range(int(con_cnt)): esc_taske = ECSOperator( task_id="airflow_task_dev_container_task_" + str(i), dag=dag, aws_conn_id="aws_default", region_name="ap-northeast-2", task_definition="dev-task-definition-01", cluster="dev-aicel-cluster", launch_type="FARGATE", overrides={ "containerOverrides": [ { "name": "dev-task-definition-01", "command": ["python", "/app/test.py"], }, ], }, network_configuration={ "awsvpcConfiguration": { "assignPublicIp": "ENABLED", "subnets": [os.environ.get("SUBNET_ID", "subnet-246db769")], }, }, tags={ "Owner": "airflow", "Environment": "Development", }, awslogs_group="/ecs/dev-task-definition-01", awslogs_stream_prefix="ecs/dev-task-definition-01", retries=AIRFLOW_ECS_OPERATOR_RETRIES, retry_delay=timedelta(seconds=10))