class TestAWSDataSyncOperatorDelete(AWSDataSyncTestCaseBase):
    """Tests for AWSDataSyncOperator configured with delete_task_after_execution=True."""

    def set_up_operator(self, task_arn="self"):
        """Instantiate the operator under test.

        The sentinel value "self" selects the task ARN created by the
        test-case fixture; passing None exercises the validation path.
        """
        if task_arn == "self":
            task_arn = self.task_arn
        # Build the operator that deletes its task after running
        self.datasync = AWSDataSyncOperator(
            task_id="test_aws_datasync_delete_task_operator",
            dag=self.dag,
            task_arn=task_arn,
            delete_task_after_execution=True,
            wait_interval_seconds=0,
        )

    def test_init(self, mock_get_conn):
        self.set_up_operator()
        # Airflow built-ins
        self.assertEqual(self.datasync.task_id, MOCK_DATA["delete_task_id"])
        # Defaults
        self.assertEqual(self.datasync.aws_conn_id, "aws_default")
        # Assignments
        self.assertEqual(self.datasync.task_arn, self.task_arn)
        # ### Check mocks:
        mock_get_conn.assert_not_called()

    def test_init_fails(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        with self.assertRaises(AirflowException):
            self.set_up_operator(task_arn=None)
        # ### Check mocks:
        mock_get_conn.assert_not_called()

    def test_delete_task(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        self.set_up_operator()

        # Record how many tasks and locations exist before execution
        self.assertEqual(len(self.client.list_tasks()["Tasks"]), 1)
        self.assertEqual(len(self.client.list_locations()["Locations"]), 2)

        # Execute the operator
        result = self.datasync.execute(None)
        self.assertIsNotNone(result)
        self.assertEqual(result["TaskArn"], self.task_arn)

        # The task was deleted; the locations were left alone
        self.assertEqual(len(self.client.list_tasks()["Tasks"]), 0)
        self.assertEqual(len(self.client.list_locations()["Locations"]), 2)
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_execute_specific_task(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        # Create a second task and point the operator at it explicitly
        task_arn = self.client.create_task(
            SourceLocationArn=self.source_location_arn,
            DestinationLocationArn=self.destination_location_arn,
        )["TaskArn"]
        self.set_up_operator(task_arn=task_arn)
        result = self.datasync.execute(None)

        self.assertEqual(result["TaskArn"], task_arn)
        self.assertEqual(self.datasync.task_arn, task_arn)
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_xcom_push(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        self.set_up_operator()
        ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow())
        ti.run()
        # The executed task's ARN must be pushed to XCom under "return_value"
        pushed = ti.xcom_pull(
            task_ids=self.datasync.task_id, key="return_value"
        )["TaskArn"]
        self.assertEqual(pushed, self.task_arn)
        # ### Check mocks:
        mock_get_conn.assert_called()
class TestAWSDataSyncOperator(AWSDataSyncTestCaseBase):
    """Tests for AWSDataSyncOperator executing an existing DataSync task."""

    def set_up_operator(self, task_arn="self"):
        """Instantiate the operator; "self" selects the fixture's task ARN."""
        if task_arn == "self":
            task_arn = self.task_arn
        # Build the operator under test
        self.datasync = AWSDataSyncOperator(
            task_id="test_aws_datasync_task_operator",
            dag=self.dag,
            wait_interval_seconds=0,
            task_arn=task_arn,
        )

    def test_init(self, mock_get_conn):
        self.set_up_operator()
        # Airflow built-ins
        self.assertEqual(self.datasync.task_id, MOCK_DATA["task_id"])
        # Defaults
        self.assertEqual(self.datasync.aws_conn_id, "aws_default")
        self.assertEqual(self.datasync.wait_interval_seconds, 0)
        # Assignments
        self.assertEqual(self.datasync.task_arn, self.task_arn)
        # ### Check mocks:
        mock_get_conn.assert_not_called()

    def test_init_fails(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        with self.assertRaises(AirflowException):
            self.set_up_operator(task_arn=None)
        # ### Check mocks:
        mock_get_conn.assert_not_called()

    def test_execute_task(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        # Configure the Operator with the specific task_arn
        self.set_up_operator()
        self.assertEqual(self.datasync.task_arn, self.task_arn)

        # Snapshot task/location counts before execution
        len_tasks_before = len(self.client.list_tasks()["Tasks"])
        len_locations_before = len(self.client.list_locations()["Locations"])

        # Execute the operator
        result = self.datasync.execute(None)
        self.assertIsNotNone(result)
        task_execution_arn = result["TaskExecutionArn"]
        self.assertIsNotNone(task_execution_arn)

        # Running an existing task must create no new tasks or locations
        self.assertEqual(len(self.client.list_tasks()["Tasks"]), len_tasks_before)
        self.assertEqual(
            len(self.client.list_locations()["Locations"]), len_locations_before
        )

        # Check with the DataSync client what happened
        task_execution = self.client.describe_task_execution(
            TaskExecutionArn=task_execution_arn
        )
        self.assertEqual(task_execution["Status"], "SUCCESS")

        # Insist that this specific task was executed, not anything else
        task_execution_arn = task_execution["TaskExecutionArn"]
        # format of task_execution_arn:
        # arn:aws:datasync:us-east-1:111222333444:task/task-00000000000000003/execution/exec-00000000000000004
        # format of task_arn:
        # arn:aws:datasync:us-east-1:111222333444:task/task-00000000000000003
        self.assertEqual("/".join(task_execution_arn.split("/")[:2]), self.task_arn)
        # ### Check mocks:
        mock_get_conn.assert_called()

    @mock.patch.object(AWSDataSyncHook, "wait_for_task_execution")
    def test_failed_task(self, mock_wait, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # Simulate the task execution never reaching SUCCESS
        mock_wait.return_value = False
        # ### Begin tests:
        self.set_up_operator()
        # A failed execution must surface as an AirflowException
        with self.assertRaises(AirflowException):
            self.datasync.execute(None)
        # ### Check mocks:
        mock_get_conn.assert_called()

    @mock.patch.object(AWSDataSyncHook, "wait_for_task_execution")
    def test_killed_task(self, mock_wait, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        # Kill the task when doing wait_for_task_execution
        def kill_task(*args):
            self.datasync.on_kill()
            return True

        mock_wait.side_effect = kill_task
        self.set_up_operator()

        # Execute the operator
        result = self.datasync.execute(None)
        self.assertIsNotNone(result)
        task_execution_arn = result["TaskExecutionArn"]
        self.assertIsNotNone(task_execution_arn)

        # The task itself returns to AVAILABLE, the execution ends in ERROR
        task = self.client.describe_task(TaskArn=self.task_arn)
        self.assertEqual(task["Status"], "AVAILABLE")
        task_execution = self.client.describe_task_execution(
            TaskExecutionArn=task_execution_arn
        )
        self.assertEqual(task_execution["Status"], "ERROR")
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_execute_specific_task(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        # Create a second task and point the operator at it explicitly
        task_arn = self.client.create_task(
            SourceLocationArn=self.source_location_arn,
            DestinationLocationArn=self.destination_location_arn,
        )["TaskArn"]
        self.set_up_operator(task_arn=task_arn)
        result = self.datasync.execute(None)

        self.assertEqual(result["TaskArn"], task_arn)
        self.assertEqual(self.datasync.task_arn, task_arn)
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_xcom_push(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        self.set_up_operator()
        ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow())
        ti.run()
        # The operator must push its result dict to XCom
        xcom_result = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")
        self.assertIsNotNone(xcom_result)
        # ### Check mocks:
        mock_get_conn.assert_called()
class TestAWSDataSyncOperatorGetTasks(AWSDataSyncTestCaseBase):
    """Tests for AWSDataSyncOperator discovering tasks from location URIs."""

    def set_up_operator(
        self,
        task_arn=None,
        source_location_uri=SOURCE_LOCATION_URI,
        destination_location_uri=DESTINATION_LOCATION_URI,
        allow_random_task_choice=False,
    ):
        """Instantiate the operator under test.

        With task_arn=None the operator must locate (or create) a task from
        the given source/destination location URIs.
        """
        # Create operator
        self.datasync = AWSDataSyncOperator(
            task_id="test_aws_datasync_get_tasks_operator",
            dag=self.dag,
            task_arn=task_arn,
            source_location_uri=source_location_uri,
            destination_location_uri=destination_location_uri,
            create_source_location_kwargs=MOCK_DATA["create_source_location_kwargs"],
            create_destination_location_kwargs=MOCK_DATA[
                "create_destination_location_kwargs"
            ],
            create_task_kwargs=MOCK_DATA["create_task_kwargs"],
            allow_random_task_choice=allow_random_task_choice,
            wait_interval_seconds=0,
        )

    def test_init(self, mock_get_conn):
        self.set_up_operator()
        # Airflow built-ins
        self.assertEqual(self.datasync.task_id, MOCK_DATA["get_task_id"])
        # Defaults
        self.assertEqual(self.datasync.aws_conn_id, "aws_default")
        self.assertFalse(self.datasync.allow_random_location_choice)
        # Assignments
        self.assertEqual(
            self.datasync.source_location_uri, MOCK_DATA["source_location_uri"]
        )
        self.assertEqual(
            self.datasync.destination_location_uri,
            MOCK_DATA["destination_location_uri"],
        )
        self.assertFalse(self.datasync.allow_random_task_choice)
        # ### Check mocks:
        mock_get_conn.assert_not_called()

    def test_init_fails(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        with self.assertRaises(AirflowException):
            self.set_up_operator(source_location_uri=None)
        with self.assertRaises(AirflowException):
            self.set_up_operator(destination_location_uri=None)
        with self.assertRaises(AirflowException):
            self.set_up_operator(
                source_location_uri=None, destination_location_uri=None
            )
        # ### Check mocks:
        mock_get_conn.assert_not_called()

    def test_get_no_location(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        self.set_up_operator()

        # Remove every existing location so the operator has to create them
        locations = self.client.list_locations()
        for location in locations["Locations"]:
            self.client.delete_location(LocationArn=location["LocationArn"])
        locations = self.client.list_locations()
        self.assertEqual(len(locations["Locations"]), 0)

        # Execute the task
        result = self.datasync.execute(None)
        self.assertIsNotNone(result)

        # BUGFIX: the original asserted len(locations) == 2 on the response
        # dict (which has a single "Locations" key) and duplicated the
        # assertIsNotNone(result) check. Count the actual locations instead.
        locations = self.client.list_locations()
        self.assertEqual(len(locations["Locations"]), 2)
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_get_no_tasks2(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        self.set_up_operator()

        # Remove every existing task so the operator has to create one
        tasks = self.client.list_tasks()
        for task in tasks["Tasks"]:
            self.client.delete_task(TaskArn=task["TaskArn"])
        tasks = self.client.list_tasks()
        self.assertEqual(len(tasks["Tasks"]), 0)

        # Execute the task
        result = self.datasync.execute(None)
        self.assertIsNotNone(result)
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_get_one_task(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        # Make sure we dont cheat
        self.set_up_operator()
        self.assertEqual(self.datasync.task_arn, None)

        # Check how many tasks and locations we have
        tasks = self.client.list_tasks()
        self.assertEqual(len(tasks["Tasks"]), 1)
        locations = self.client.list_locations()
        self.assertEqual(len(locations["Locations"]), 2)

        # Execute the task
        result = self.datasync.execute(None)
        self.assertIsNotNone(result)

        task_arn = result["TaskArn"]
        self.assertIsNotNone(task_arn)
        self.assertTrue(task_arn)
        self.assertEqual(task_arn, self.task_arn)

        # Assert 0 additional task and 0 additional locations
        tasks = self.client.list_tasks()
        self.assertEqual(len(tasks["Tasks"]), 1)
        locations = self.client.list_locations()
        self.assertEqual(len(locations["Locations"]), 2)
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_get_many_tasks(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        self.set_up_operator()

        # A second task between the same locations makes the choice ambiguous
        self.client.create_task(
            SourceLocationArn=self.source_location_arn,
            DestinationLocationArn=self.destination_location_arn,
        )

        # Check how many tasks and locations we have
        tasks = self.client.list_tasks()
        self.assertEqual(len(tasks["Tasks"]), 2)
        locations = self.client.list_locations()
        self.assertEqual(len(locations["Locations"]), 2)

        # Ambiguous task choice must fail unless random choice is allowed
        with self.assertRaises(AirflowException):
            self.datasync.execute(None)

        # Assert 0 additional task and 0 additional locations
        tasks = self.client.list_tasks()
        self.assertEqual(len(tasks["Tasks"]), 2)
        locations = self.client.list_locations()
        self.assertEqual(len(locations["Locations"]), 2)

        self.set_up_operator(task_arn=self.task_arn, allow_random_task_choice=True)
        self.datasync.execute(None)
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_execute_specific_task(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        task_arn = self.client.create_task(
            SourceLocationArn=self.source_location_arn,
            DestinationLocationArn=self.destination_location_arn,
        )["TaskArn"]

        self.set_up_operator(task_arn=task_arn)
        result = self.datasync.execute(None)
        self.assertEqual(result["TaskArn"], task_arn)
        self.assertEqual(self.datasync.task_arn, task_arn)
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_xcom_push(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        self.set_up_operator()
        ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow())
        ti.run()
        pushed_task_arn = ti.xcom_pull(
            task_ids=self.datasync.task_id, key="return_value"
        )["TaskArn"]
        self.assertEqual(pushed_task_arn, self.task_arn)
        # ### Check mocks:
        mock_get_conn.assert_called()
class TestAWSDataSyncOperatorUpdate(AWSDataSyncTestCaseBase):
    """Tests for AWSDataSyncOperator applying update_task_kwargs to a task."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Populated by set_up_operator in each test
        self.datasync = None

    def set_up_operator(self, task_arn="self", update_task_kwargs="default"):
        """Instantiate the operator; sentinels select the fixture defaults."""
        if task_arn == "self":
            task_arn = self.task_arn
        if update_task_kwargs == "default":
            update_task_kwargs = {
                "Options": {"VerifyMode": "BEST_EFFORT", "Atime": "NONE"}
            }
        # Build the operator under test
        self.datasync = AWSDataSyncOperator(
            task_id="test_aws_datasync_update_task_operator",
            dag=self.dag,
            task_arn=task_arn,
            update_task_kwargs=update_task_kwargs,
            wait_interval_seconds=0,
        )

    def test_init(self, mock_get_conn):
        self.set_up_operator()
        # Airflow built-ins
        self.assertEqual(self.datasync.task_id, MOCK_DATA["update_task_id"])
        # Defaults
        self.assertEqual(self.datasync.aws_conn_id, "aws_default")
        # Assignments
        self.assertEqual(self.datasync.task_arn, self.task_arn)
        self.assertEqual(
            self.datasync.update_task_kwargs, MOCK_DATA["update_task_kwargs"]
        )
        # ### Check mocks:
        mock_get_conn.assert_not_called()

    def test_init_fails(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        with self.assertRaises(AirflowException):
            self.set_up_operator(task_arn=None)
        # ### Check mocks:
        mock_get_conn.assert_not_called()

    def test_update_task(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        self.set_up_operator()

        # Before execution the task has no Options set
        task = self.client.describe_task(TaskArn=self.task_arn)
        self.assertNotIn("Options", task)

        # Execute the operator
        result = self.datasync.execute(None)
        self.assertIsNotNone(result)
        self.assertEqual(result["TaskArn"], self.task_arn)
        self.assertIsNotNone(self.datasync.task_arn)

        # After execution the Options must match the requested update
        task = self.client.describe_task(TaskArn=self.task_arn)
        self.assertEqual(task["Options"], UPDATE_TASK_KWARGS["Options"])
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_execute_specific_task(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        # Create a second task and point the operator at it explicitly
        task_arn = self.client.create_task(
            SourceLocationArn=self.source_location_arn,
            DestinationLocationArn=self.destination_location_arn,
        )["TaskArn"]
        self.set_up_operator(task_arn=task_arn)
        result = self.datasync.execute(None)

        self.assertEqual(result["TaskArn"], task_arn)
        self.assertEqual(self.datasync.task_arn, task_arn)
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_xcom_push(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        self.set_up_operator()
        ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow())
        ti.run()
        # The updated task's ARN must be pushed to XCom under "return_value"
        pushed = ti.xcom_pull(
            task_ids=self.datasync.task_id, key="return_value"
        )["TaskArn"]
        self.assertEqual(pushed, self.task_arn)
        # ### Check mocks:
        mock_get_conn.assert_called()
class TestAWSDataSyncOperatorCreate(AWSDataSyncTestCaseBase):
    """Tests for AWSDataSyncOperator creating locations and tasks on demand."""

    def set_up_operator(
        self,
        task_arn=None,
        source_location_uri=SOURCE_LOCATION_URI,
        destination_location_uri=DESTINATION_LOCATION_URI,
        allow_random_location_choice=False,
    ):
        """Instantiate the operator under test.

        With task_arn=None the operator is expected to create whatever
        locations/tasks are missing from the given URIs.
        """
        # Create operator
        self.datasync = AWSDataSyncOperator(
            task_id="test_aws_datasync_create_task_operator",
            dag=self.dag,
            task_arn=task_arn,
            source_location_uri=source_location_uri,
            destination_location_uri=destination_location_uri,
            create_task_kwargs={"Options": {"VerifyMode": "NONE", "Atime": "NONE"}},
            create_source_location_kwargs={
                "Subdirectory": SOURCE_SUBDIR,
                "ServerHostname": SOURCE_HOST_NAME,
                "User": "******",
                "Password": "******",
                "AgentArns": ["some_agent"],
            },
            create_destination_location_kwargs={
                "S3BucketArn": DESTINATION_LOCATION_ARN,
                "S3Config": {"BucketAccessRoleArn": "myrole"},
            },
            allow_random_location_choice=allow_random_location_choice,
            wait_interval_seconds=0,
        )

    def test_init(self, mock_get_conn):
        self.set_up_operator()
        # Airflow built-ins
        self.assertEqual(self.datasync.task_id, MOCK_DATA["create_task_id"])
        # Defaults
        self.assertEqual(self.datasync.aws_conn_id, "aws_default")
        self.assertFalse(self.datasync.allow_random_task_choice)
        self.assertFalse(  # Empty dict
            self.datasync.task_execution_kwargs
        )
        # Assignments
        self.assertEqual(
            self.datasync.source_location_uri, MOCK_DATA["source_location_uri"]
        )
        self.assertEqual(
            self.datasync.destination_location_uri,
            MOCK_DATA["destination_location_uri"],
        )
        self.assertEqual(
            self.datasync.create_task_kwargs, MOCK_DATA["create_task_kwargs"]
        )
        self.assertEqual(
            self.datasync.create_source_location_kwargs,
            MOCK_DATA["create_source_location_kwargs"],
        )
        self.assertEqual(
            self.datasync.create_destination_location_kwargs,
            MOCK_DATA["create_destination_location_kwargs"],
        )
        self.assertFalse(self.datasync.allow_random_location_choice)
        # ### Check mocks:
        mock_get_conn.assert_not_called()

    def test_init_fails(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        with self.assertRaises(AirflowException):
            self.set_up_operator(source_location_uri=None)
        with self.assertRaises(AirflowException):
            self.set_up_operator(destination_location_uri=None)
        with self.assertRaises(AirflowException):
            self.set_up_operator(
                source_location_uri=None, destination_location_uri=None
            )
        # ### Check mocks:
        mock_get_conn.assert_not_called()

    def test_create_task(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        self.set_up_operator()

        # Delete all tasks:
        tasks = self.client.list_tasks()
        for task in tasks["Tasks"]:
            self.client.delete_task(TaskArn=task["TaskArn"])

        # Check how many tasks and locations we have
        tasks = self.client.list_tasks()
        self.assertEqual(len(tasks["Tasks"]), 0)
        locations = self.client.list_locations()
        self.assertEqual(len(locations["Locations"]), 2)

        # Execute the task
        result = self.datasync.execute(None)
        self.assertIsNotNone(result)
        task_arn = result["TaskArn"]

        # Assert 1 additional task and 0 additional locations
        tasks = self.client.list_tasks()
        self.assertEqual(len(tasks["Tasks"]), 1)
        locations = self.client.list_locations()
        self.assertEqual(len(locations["Locations"]), 2)

        # Check task metadata
        task = self.client.describe_task(TaskArn=task_arn)
        self.assertEqual(task["Options"], CREATE_TASK_KWARGS["Options"])
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_create_task_and_location(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        self.set_up_operator()

        # Delete all tasks:
        tasks = self.client.list_tasks()
        for task in tasks["Tasks"]:
            self.client.delete_task(TaskArn=task["TaskArn"])
        # Delete all locations:
        locations = self.client.list_locations()
        for location in locations["Locations"]:
            self.client.delete_location(LocationArn=location["LocationArn"])

        # Check how many tasks and locations we have
        tasks = self.client.list_tasks()
        self.assertEqual(len(tasks["Tasks"]), 0)
        locations = self.client.list_locations()
        self.assertEqual(len(locations["Locations"]), 0)

        # Execute the task
        result = self.datasync.execute(None)
        self.assertIsNotNone(result)

        # Assert 1 additional task and 2 additional locations
        tasks = self.client.list_tasks()
        self.assertEqual(len(tasks["Tasks"]), 1)
        locations = self.client.list_locations()
        self.assertEqual(len(locations["Locations"]), 2)
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_dont_create_task(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        tasks = self.client.list_tasks()
        tasks_before = len(tasks["Tasks"])

        # An explicit task_arn means nothing new should be created
        self.set_up_operator(task_arn=self.task_arn)
        self.datasync.execute(None)

        tasks = self.client.list_tasks()
        tasks_after = len(tasks["Tasks"])
        self.assertEqual(tasks_before, tasks_after)
        # ### Check mocks:
        mock_get_conn.assert_called()

    # BUGFIX: this method was named create_task_many_locations (no "test_"
    # prefix), so unittest discovery silently skipped it. Renamed so it runs.
    def test_create_task_many_locations(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        # Create duplicate source location to choose from
        self.client.create_location_smb(**MOCK_DATA["create_source_location_kwargs"])

        # Ambiguous location choice must fail unless random choice is allowed
        self.set_up_operator(task_arn=self.task_arn)
        with self.assertRaises(AirflowException):
            self.datasync.execute(None)

        self.set_up_operator(task_arn=self.task_arn, allow_random_location_choice=True)
        self.datasync.execute(None)
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_execute_specific_task(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        task_arn = self.client.create_task(
            SourceLocationArn=self.source_location_arn,
            DestinationLocationArn=self.destination_location_arn,
        )["TaskArn"]

        self.set_up_operator(task_arn=task_arn)
        result = self.datasync.execute(None)
        self.assertEqual(result["TaskArn"], task_arn)
        self.assertEqual(self.datasync.task_arn, task_arn)
        # ### Check mocks:
        mock_get_conn.assert_called()

    def test_xcom_push(self, mock_get_conn):
        # ### Set up mocks:
        mock_get_conn.return_value = self.client
        # ### Begin tests:
        self.set_up_operator()
        ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow())
        ti.run()
        xcom_result = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")
        self.assertIsNotNone(xcom_result)
        # ### Check mocks:
        mock_get_conn.assert_called()