def _create_container(
    self, id: int, status: ContainerInstanceStatus
) -> ContainerInstance:
    """Build a ContainerInstance whose task ARN and IP address both embed ``id``.

    Args:
        id: numeric suffix used in the task ARN and as the last IPv4 octet.
        status: status assigned to the new container.
    """
    task_arn = f"arn:aws:ecs:region:account_id:task/container_id_{id}"
    ip_address = f"192.0.2.{id}"
    return ContainerInstance(task_arn, ip_address, status)
def gen_dummy_container_instance() -> ContainerInstance:
    """Return a fixed, COMPLETED ContainerInstance for use as a unit-test fixture."""
    dummy_task_arn = "arn:aws:ecs:us-west-2:000000000000:task/cluster-name/subnet"
    return ContainerInstance(
        instance_id=dummy_task_arn,
        ip_address="10.0.10.242",
        status=ContainerInstanceStatus.COMPLETED,
    )
def setUp(self) -> None:
    """Create a COMPLETED stage-state instance holding two COMPLETED containers."""
    container_specs = (
        ("test_container_instance_1", "192.0.2.4"),
        ("test_container_instance_2", "192.0.2.5"),
    )
    completed_containers = [
        ContainerInstance(
            instance_id=container_id,
            ip_address=container_ip,
            status=ContainerInstanceStatus.COMPLETED,
        )
        for container_id, container_ip in container_specs
    ]
    start_ts = 1646642432
    self.stage_state_instance = StageStateInstance(
        instance_id="stage_state_instance",
        stage_name="test_stage",
        status=StageStateInstanceStatus.COMPLETED,
        containers=completed_containers,
        creation_ts=start_ts,
        # The stage is recorded as ending five seconds after creation.
        end_ts=start_ts + 5,
    )
async def test_wait_for_containers_fail(self, get_containers) -> None:
    """Container 1 ends COMPLETED and container 2 ends FAILED; both are returned in order."""

    def _container(
        task_arn: str, ip: str, status: ContainerInstanceStatus
    ) -> ContainerInstance:
        return ContainerInstance(task_arn, ip, status)

    first_started = _container(
        "arn:aws:ecs:region:account_id:task/container_id_1",
        "192.0.2.0",
        ContainerInstanceStatus.STARTED,
    )
    second_started = _container(
        "arn:aws:ecs:region:account_id:task/container_id_2",
        "192.0.2.1",
        ContainerInstanceStatus.STARTED,
    )
    first_completed = _container(
        "arn:aws:ecs:region:account_id:task/container_id_1",
        "192.0.2.0",
        ContainerInstanceStatus.COMPLETED,
    )
    second_failed = _container(
        "arn:aws:ecs:region:account_id:task/container_id_2",
        "192.0.2.1",
        ContainerInstanceStatus.FAILED,
    )

    # Each call to the patched get_containers yields the next list below.
    get_containers.side_effect = [
        [first_started],
        [first_completed],
        [second_failed],
    ]

    final_containers = await RunBinaryBaseService.wait_for_containers_async(
        self.onedocker_svc, [first_started, second_started], poll=0
    )

    self.assertEqual(final_containers[0], first_completed)
    self.assertEqual(final_containers[1], second_failed)
async def _run_sub_test(
    wait_for_containers: bool,
    expected_container_status: ContainerInstanceStatus,
) -> None:
    """Call ``PIDShardStage.shard`` with ``mock_sharder`` returning one container.

    Checks that the returned stage status equals the status derived from
    that container, and that the sharder mock was invoked exactly once.

    Args:
        wait_for_containers: forwarded to ``PIDShardStage.shard``.
        expected_container_status: status given to the container that
            ``mock_sharder`` returns.
    """
    with patch.object(PIDStage, "update_instance_containers"):
        binary_config = OneDockerBinaryConfig(
            tmp_directory="/test_tmp_directory/",
            binary_version="latest",
            repository_path="test_path/",
        )
        sharded_container = ContainerInstance(
            instance_id="123",
            ip_address="192.0.2.0",
            status=expected_container_status,
        )
        mock_sharder.return_value = [sharded_container]

        shard_stage = PIDShardStage(
            stage=UnionPIDStage.PUBLISHER_SHARD,
            instance_repository=mock_instance_repo,
            storage_svc=mock_storage_svc,
            onedocker_svc=mock_onedocker_svc,
            onedocker_binary_config=binary_config,
        )

        input_path, output_path = "foo", "bar"
        num_shards = 1
        hmac_key = "CoXbp7BOEvAN9L1CB2DAORHHr3hB7wE7tpxMYm07tc0="

        # Shard 0 of the output is expected at "<output_path>_0".
        self.assertEqual(
            f"{output_path}_0",
            PIDShardStage.get_sharded_filepath(output_path, 0),
        )

        shard_status = await shard_stage.shard(
            "123",
            input_path,
            output_path,
            num_shards,
            hmac_key,
            wait_for_containers=wait_for_containers,
        )
        derived_status = PIDStage.get_stage_status_from_containers(
            [sharded_container]
        )
        self.assertEqual(derived_status, shard_status)
        mock_sharder.assert_called_once()
async def test_get_status_logs_a_helpful_error_when_the_validation_fails(
    self, mock_get_pc_status_from_stage_state
) -> None:
    """A validation-failed status makes get_status log the ECS task id and console link."""
    pc_instance = self._pc_instance

    region = "us-west-1"
    account_id = "1234567890"
    cluster_name = "test-cluster-name"
    task_id = "test-task-id-123"
    task_arn = f"arn:aws:ecs:{region}:{account_id}:task/{cluster_name}/{task_id}"

    unioned_pc_instances = [
        StageStateInstance(
            instance_id="instance-id-0",
            stage_name="stage-name-1",
            containers=[ContainerInstance(instance_id=task_arn)],
        )
    ]
    # pyre-fixme[8]: Attribute has type `List[Union[StageStateInstance,
    #  PCSMPCInstance, PIDInstance, PostProcessingInstance]]`; used as
    #  `List[StageStateInstance]`.
    pc_instance.instances = unioned_pc_instances

    failed_status = PrivateComputationInstanceStatus.INPUT_DATA_VALIDATION_FAILED

    onedocker_svc_mock = MagicMock()
    onedocker_svc_mock.get_cluster.side_effect = [cluster_name]
    validator_config = PCValidatorConfig(
        region=region,
        pc_pre_validator_enabled=True,
    )
    failed_task_link = f"https://{region}.console.aws.amazon.com/ecs/home?region={region}#/clusters/{cluster_name}/tasks/{task_id}/details"
    logger_mock = MagicMock()
    mock_get_pc_status_from_stage_state.side_effect = [failed_status]

    stage_service = InputDataValidationStageService(
        validator_config,
        onedocker_svc_mock,
        self.onedocker_binary_config_map,
    )
    # Swap in a mock logger so the emitted error message can be asserted.
    stage_service._logger = logger_mock

    self.assertEqual(stage_service.get_status(pc_instance), failed_status)
    logger_mock.error.assert_called_with(
        f"[PCPreValidation] - stage failed because of some failed validations. Please check the logs in ECS for task id '{task_id}' to see the validation issues:\n"
        + f"Failed task link: {failed_task_link}"
    )
async def _run_sub_test(
    wait_for_containers: bool,
    expected_container_status: ContainerInstanceStatus,
) -> None:
    """Call ``PIDPrepareStage.prepare`` with the per-container preparer stubbed out.

    The stubbed ``prepare_on_container_async`` returns a single container.
    The returned stage status must equal the status derived from that
    container.

    Args:
        wait_for_containers: forwarded to ``PIDPrepareStage.prepare``.
        expected_container_status: status given to the stubbed container.
    """
    with patch.object(
        CppUnionPIDDataPreparerService, "prepare_on_container_async"
    ) as prepare_on_container_mock, patch.object(
        PIDStage, "update_instance_containers"
    ):
        prepared_container = ContainerInstance(
            instance_id="123",
            ip_address="192.0.2.0",
            status=expected_container_status,
        )
        prepare_on_container_mock.return_value = prepared_container

        prepare_stage = PIDPrepareStage(
            stage=UnionPIDStage.PUBLISHER_PREPARE,
            instance_repository=mock_instance_repo,
            storage_svc="STORAGE",  # pyre-ignore
            onedocker_svc="ONEDOCKER",  # pyre-ignore
            onedocker_binary_config=MagicMock(
                task_definition="offline-task:1#container",
                tmp_directory="/tmp/",
                binary_version="latest",
            ),
        )
        prepare_status = await prepare_stage.prepare(
            instance_id="123",
            input_path="in",
            output_path="out",
            num_shards=1,
            wait_for_containers=wait_for_containers,
        )
        self.assertEqual(
            PIDStage.get_stage_status_from_containers([prepared_container]),
            prepare_status,
        )
async def _run_sub_test(wait_for_containers: bool) -> None:
    """Exercise ``PIDShardStage.run``: happy path, missing input, and multiple inputs.

    Args:
        wait_for_containers: forwarded to ``stage.run``. The happy path is
            expected to end COMPLETED when True and STARTED when False.
    """
    ip = "192.0.2.0"
    container = ContainerInstance(instance_id="123", ip_address=ip)
    mock_onedocker_svc.start_containers = MagicMock(return_value=[container])
    mock_onedocker_svc.wait_for_pending_containers = AsyncMock(
        return_value=[container]
    )
    # The stub container's status mirrors whether the run waits for it.
    container.status = (
        ContainerInstanceStatus.COMPLETED
        if wait_for_containers
        else ContainerInstanceStatus.STARTED
    )
    mock_wait_for_containers_async.return_value = [container]
    test_onedocker_binary_config = OneDockerBinaryConfig(
        tmp_directory="/test_tmp_directory/",
        binary_version="latest",
        repository_path="test_path/",
    )

    def _new_stage() -> PIDShardStage:
        """Build a publisher-shard stage wired to the shared mocks."""
        return PIDShardStage(
            stage=UnionPIDStage.PUBLISHER_SHARD,
            instance_repository=mock_instance_repo,
            storage_svc=mock_storage_svc,
            onedocker_svc=mock_onedocker_svc,
            onedocker_binary_config=test_onedocker_binary_config,
        )

    instance_id = "444"
    stage_input = PIDStageInput(
        input_paths=["in"],
        output_paths=["out"],
        num_shards=123,
        instance_id=instance_id,
    )

    # Basic test: all good
    with patch.object(PIDShardStage, "files_exist") as mock_fe:
        mock_fe.return_value = True
        stage = _new_stage()
        status = await stage.run(
            stage_input, wait_for_containers=wait_for_containers
        )
        self.assertEqual(
            PIDStageStatus.COMPLETED
            if wait_for_containers
            else PIDStageStatus.STARTED,
            status,
        )
        mock_onedocker_svc.start_containers.assert_called_once()
        if wait_for_containers:
            mock_wait_for_containers_async.assert_called_once()
        else:
            mock_wait_for_containers_async.assert_not_called()
        # The instance record is read and updated four times during the run,
        # with the last read targeting this instance id.
        mock_instance_repo.read.assert_called_with(instance_id)
        self.assertEqual(mock_instance_repo.read.call_count, 4)
        self.assertEqual(mock_instance_repo.update.call_count, 4)

    # Input not ready: the stage reports FAILED.
    with patch.object(PIDShardStage, "files_exist") as mock_fe:
        mock_fe.return_value = False
        status = await stage.run(
            stage_input, wait_for_containers=wait_for_containers
        )
        self.assertEqual(PIDStageStatus.FAILED, status)

    # Multiple input paths (invariant exception).
    # Setup is kept outside assertRaises so only stage.run can satisfy it;
    # nothing after the raising call would execute inside the block.
    with patch.object(PIDShardStage, "files_exist") as mock_fe:
        mock_fe.return_value = True
        stage_input.input_paths = ["in1", "in2"]
        stage = _new_stage()
        with self.assertRaises(ValueError):
            await stage.run(
                stage_input, wait_for_containers=wait_for_containers
            )
async def create_container_instance(self) -> ContainerInstance:
    """Return a COMPLETED ContainerInstance stub with a fixed id and loopback IP."""
    stub_container = ContainerInstance(
        instance_id="test_container_instance_123",
        status=ContainerInstanceStatus.COMPLETED,
        ip_address="127.0.0.1",
    )
    return stub_container
async def _run_sub_test(wait_for_containers: bool) -> None:
    """Run the publisher PID protocol stage against mocked OneDocker and storage.

    Args:
        wait_for_containers: forwarded to ``run``. When True the stage is
            expected to end COMPLETED and perform one extra instance
            read/update; when False it is expected to end STARTED.
    """
    ip = "192.0.2.0"
    container = ContainerInstance(instance_id="123", ip_address=ip)
    mock_onedocker_service.start_containers = MagicMock(return_value=[container])
    mock_onedocker_service.wait_for_pending_containers = AsyncMock(
        return_value=[container]
    )
    # The stub container's status mirrors whether the run waits for it.
    container.status = (
        ContainerInstanceStatus.COMPLETED
        if wait_for_containers
        else ContainerInstanceStatus.STARTED
    )
    mock_wait_for_containers_async.return_value = [container]

    with patch.object(
        PIDProtocolRunStage, "files_exist"
    ) as mock_files_exist, patch.object(
        PIDProtocolRunStage, "put_server_ips"
    ) as mock_put_server_ips:
        mock_files_exist.return_value = True
        num_shards = 2
        input_path = "in"
        output_path = "out"

        # Run publisher
        publisher_run_stage = PIDProtocolRunStage(
            stage=UnionPIDStage.PUBLISHER_RUN_PID,
            instance_repository=mock_instance_repo,
            storage_svc=mock_storage_service,
            onedocker_svc=mock_onedocker_service,
            onedocker_binary_config=self.onedocker_binary_config,
        )
        instance_id = "123"
        stage_input = PIDStageInput(
            input_paths=[input_path],
            output_paths=[output_path],
            num_shards=num_shards,
            instance_id=instance_id,
        )

        # When waiting for containers the stage should finish; otherwise it
        # should only start and then return.
        self.assertEqual(
            PIDStageStatus.COMPLETED
            if wait_for_containers
            else PIDStageStatus.STARTED,
            await publisher_run_stage.run(
                stage_input=stage_input,
                wait_for_containers=wait_for_containers,
            ),
        )

        # wait_for_containers_async is used only when waiting was requested.
        if wait_for_containers:
            mock_wait_for_containers_async.assert_called_once()
        else:
            mock_wait_for_containers_async.assert_not_called()

        # Containers are started in one call carrying one cmd-args entry per shard.
        mock_onedocker_service.start_containers.assert_called_once()
        _, called_kwargs = mock_onedocker_service.start_containers.call_args_list[0]
        self.assertEqual(num_shards, len(called_kwargs["cmd_args_list"]))

        # Check `put_server_ips` received the started container's IP.
        mock_put_server_ips.assert_called_once_with(
            instance_id=instance_id, server_ips=[ip]
        )

        # Without waiting there are 4 instance reads and 4 updates. Waiting adds
        # one more of each: the update that marks the instance status and
        # containers as complete.
        mock_instance_repo.read.assert_called_with(instance_id)
        self.assertEqual(
            mock_instance_repo.read.call_count, 4 + int(wait_for_containers)
        )
        self.assertEqual(
            mock_instance_repo.update.call_count,
            4 + int(wait_for_containers),
        )
async def _run_sub_test(wait_for_containers: bool, ) -> None:
    """Exercise ``PIDPrepareStage.run`` across several scenarios.

    The scenarios are: the happy path, argument forwarding to ``prepare``,
    missing input, and the multiple-input-paths invariant.

    Args:
        wait_for_containers: forwarded to ``stage.run``. The happy path is
            expected to end COMPLETED when True and STARTED when False.
    """
    server_ip = "192.0.2.0"
    stub_container = ContainerInstance(
        instance_id="123",
        ip_address=server_ip,
        status=ContainerInstanceStatus.STARTED,
    )
    mock_onedocker_svc.start_containers = MagicMock(return_value=[stub_container])
    mock_onedocker_svc.wait_for_pending_containers = AsyncMock(
        return_value=[stub_container]
    )
    if wait_for_containers:
        stub_container.status = ContainerInstanceStatus.COMPLETED
    else:
        stub_container.status = ContainerInstanceStatus.STARTED
    mock_wait_for_containers_async.return_value = [stub_container]

    def _build_stage() -> PIDPrepareStage:
        """Build a publisher-prepare stage with a fresh mocked binary config."""
        return PIDPrepareStage(
            stage=UnionPIDStage.PUBLISHER_PREPARE,
            instance_repository=mock_instance_repo,
            storage_svc=mock_storage_svc,
            onedocker_svc=mock_onedocker_svc,
            onedocker_binary_config=MagicMock(
                task_definition="offline-task:1#container",
                tmp_directory="/tmp/",
                binary_version="latest",
            ),
        )

    stage = _build_stage()
    instance_id = "444"
    stage_input = PIDStageInput(
        input_paths=["in"],
        output_paths=["out"],
        num_shards=2,
        instance_id=instance_id,
    )

    # Basic test: everything present.
    with patch.object(PIDPrepareStage, "files_exist") as files_exist_mock:
        files_exist_mock.return_value = True
        stage = _build_stage()
        run_status = await stage.run(
            stage_input, wait_for_containers=wait_for_containers
        )
        expected_run_status = (
            PIDStageStatus.COMPLETED
            if wait_for_containers
            else PIDStageStatus.STARTED
        )
        self.assertEqual(expected_run_status, run_status)
        # Two shards: one container start per shard.
        self.assertEqual(mock_onedocker_svc.start_containers.call_count, 2)
        if not wait_for_containers:
            mock_wait_for_containers_async.assert_not_called()
        else:
            self.assertEqual(mock_wait_for_containers_async.call_count, 2)
        mock_instance_repo.read.assert_called_with(instance_id)
        self.assertEqual(mock_instance_repo.read.call_count, 4)
        self.assertEqual(mock_instance_repo.update.call_count, 4)

    # run() forwards its arguments to prepare().
    with patch.object(
        PIDPrepareStage, "files_exist"
    ) as files_exist_mock, patch.object(
        PIDPrepareStage, "prepare"
    ) as prepare_mock:
        files_exist_mock.return_value = True
        await stage.run(stage_input, wait_for_containers=wait_for_containers)
        prepare_mock.assert_called_with(
            instance_id, "in", "out", 2, wait_for_containers, None
        )

    # Input not ready: the stage reports FAILED.
    with patch.object(PIDPrepareStage, "files_exist") as files_exist_mock:
        files_exist_mock.return_value = False
        run_status = await stage.run(
            stage_input, wait_for_containers=wait_for_containers
        )
        self.assertEqual(PIDStageStatus.FAILED, run_status)

    # Multiple input paths (invariant exception)
    with patch.object(PIDPrepareStage, "files_exist") as files_exist_mock:
        with self.assertRaises(ValueError):
            files_exist_mock.return_value = True
            stage_input.input_paths = ["in1", "in2"]
            stage = _build_stage()
            run_status = await stage.run(
                stage_input, wait_for_containers=wait_for_containers
            )