async def test_pid_run_stage_with_exception(
    self,
    mock_instance_repo,
    mock_aws_container_service,
    mock_onedocker_service,
    mock_s3_storage_service,
    mock_pid_shard_stage,
    mock_pid_prepare_stage,
) -> None:
    """Check that a stage whose run fails stays in the DAG and can be rerun.

    The shard stage is run first, then the prepare stage is run once with a
    broken ``run`` and once with a working one. The number of DAG nodes is
    checked after every step.
    """
    # Calling a patched stage class hands back the stage mock the test
    # configures; bind each one once.
    shard_stage = mock_pid_shard_stage()
    prepare_stage = mock_pid_prepare_stage()
    shard_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)

    dispatcher = PIDDispatcher(
        instance_id="344", instance_repository=mock_instance_repo
    )
    dispatcher.build_stages(
        input_path="abc.text",
        output_path="def.txt",
        num_shards=50,
        protocol=PIDProtocol.UNION_PID,
        role=PIDRole.PARTNER,
        storage_svc=mock_s3_storage_service,
        onedocker_svc=mock_onedocker_service,
        # pyre-fixme[6]: For 8th param expected `DefaultDict[str,
        #  OneDockerBinaryConfig]` but got `DefaultDict[Variable[_KT], str]`.
        onedocker_binary_config_map=defaultdict(lambda: "OD_CONFIG"),
    )

    # Step 1: the shard stage completes, leaving two nodes in the DAG.
    await dispatcher.run_stage(shard_stage)
    self.assertEqual(len(dispatcher.dag.nodes), 2)

    # Step 2: an Exception instance is not callable, so invoking ``run``
    # fails; the dispatcher must report PIDStageFailureError and keep the
    # stage in the DAG.
    prepare_stage.run = Exception()
    with self.assertRaises(PIDStageFailureError):
        await dispatcher.run_stage(prepare_stage)
    self.assertEqual(len(dispatcher.dag.nodes), 2)

    # Step 3: with a working ``run`` the same stage can be retried and is
    # then removed from the DAG.
    prepare_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)
    await dispatcher.run_stage(prepare_stage)
    self.assertEqual(len(dispatcher.dag.nodes), 1)
async def test_union_pid_run_all_order(
    self,
    mock_instance_repo,
    mock_aws_container_service,
    mock_onedocker_service,
    mock_s3_storage_service,
    mock_pid_shard_stage,
    mock_pid_prepare_stage,
    mock_pid_run_protocol_stage,
) -> None:
    """Check that ``run_all`` drains a three-stage DAG, running each stage once.

    All three stage mocks share a single ``run`` mock, so its call count
    after ``run_all`` equals the number of stages executed.
    """
    shared_run = AsyncMock(return_value=PIDStageStatus.COMPLETED)
    for stage_factory in (
        mock_pid_shard_stage,
        mock_pid_prepare_stage,
        mock_pid_run_protocol_stage,
    ):
        stage_factory().run = shared_run

    job_id = "456"
    job_protocol = PIDProtocol.UNION_PID
    job_role = PIDRole.PARTNER
    shard_count = 50
    source_path = "abc.text"
    target_path = "def.txt"

    dispatcher = PIDDispatcher(
        instance_id=job_id, instance_repository=mock_instance_repo
    )

    # The repository hands the dispatcher a sample instance.
    stored_instance = self._get_sample_pid_instance(
        instance_id=job_id,
        protocol=job_protocol,
        pid_role=job_role,
        num_shards=shard_count,
        is_validating=False,
        input_path=source_path,
        output_path=target_path,
    )
    dispatcher.instance_repository.read = MagicMock(return_value=stored_instance)

    dispatcher.build_stages(
        input_path=source_path,
        output_path=target_path,
        num_shards=shard_count,
        protocol=job_protocol,
        role=job_role,
        storage_svc=mock_s3_storage_service,
        onedocker_svc=mock_onedocker_service,
        # pyre-fixme[6]: For 8th param expected `DefaultDict[str,
        #  OneDockerBinaryConfig]` but got `DefaultDict[Variable[_KT], str]`.
        onedocker_binary_config_map=defaultdict(lambda: "OD_CONFIG"),
    )

    # Three nodes before the run, none afterwards.
    self.assertEqual(len(dispatcher.dag.nodes), 3)
    await dispatcher.run_all()
    self.assertEqual(len(dispatcher.dag.nodes), 0)

    # One run() call per (mocked) stage.
    self.assertEqual(shared_run.mock.call_count, 3)
async def test_pid_shard_stage_service(
    self,
    pc_role: PrivateComputationRole,
    test_num_containers: int,
    has_hmac_key: bool,
) -> None:
    """Run ``PIDShardStageService.run_async`` against mocked OneDocker calls.

    Verifies the arguments passed to ``start_containers`` and that exactly
    one ``StageStateInstance`` holding the started containers ends up on the
    returned instance.

    Args:
        pc_role: role used to build the sample instance and expected args.
        test_num_containers: number of container instances the mocked
            OneDocker service returns.
        has_hmac_key: when true, ``self.test_hmac_key`` is passed to the
            sample instance; otherwise ``None`` is passed.
    """
    hmac_key_expected = self.test_hmac_key if has_hmac_key else None
    pc_instance = self.create_sample_pc_instance(
        pc_role, test_num_containers, hmac_key_expected
    )
    stage_svc = PIDShardStageService(
        storage_svc=self.mock_storage_svc,
        onedocker_svc=self.mock_onedocker_svc,
        onedocker_binary_config_map=self.onedocker_binary_config_map,
        container_timeout=self.container_timeout,
    )

    # The mocked OneDocker service reports these containers both when
    # starting them and when waiting for pending ones.
    containers = [
        self.create_container_instance() for _ in range(test_num_containers)
    ]
    self.mock_onedocker_svc.start_containers = MagicMock(return_value=containers)
    self.mock_onedocker_svc.wait_for_pending_containers = AsyncMock(
        return_value=containers
    )

    updated_pc_instance = await stage_svc.run_async(pc_instance=pc_instance)

    env_vars = {
        "ONEDOCKER_REPOSITORY_PATH": self.onedocker_binary_config.repository_path
    }
    args_ls_expect = self.get_args_expect(
        pc_role, test_num_containers, has_hmac_key
    )
    # test the start_containers is called with expected parameters
    self.mock_onedocker_svc.start_containers.assert_called_with(
        package_name=self.binary_name,
        version=self.onedocker_binary_config.binary_version,
        cmd_args_list=args_ls_expect,
        timeout=self.container_timeout,
        env_vars=env_vars,
    )

    # test the return value is as expected
    self.assertEqual(
        len(updated_pc_instance.instances),
        1,
        "Failed to add the StageStateInstance into pc_instance",
    )
    stage_state_expect = StageStateInstance(
        pc_instance.instance_id,
        pc_instance.current_stage.name,
        containers=containers,
    )
    stage_state_actual = updated_pc_instance.instances[0]
    self.assertEqual(
        stage_state_actual,
        stage_state_expect,
        "Appended StageStateInstance is not as expected",
    )
async def test_pid_run_protocol_stage(
    self, pc_role: PrivateComputationRole, multikey_enabled: bool
) -> None:
    """Run ``PIDRunProtocolStageService.run_async`` against mocked OneDocker calls.

    Verifies the arguments passed to ``start_containers`` and the
    ``StageStateInstance`` appended to the returned instance.

    Args:
        pc_role: role used to build the sample instance, pick the binary
            name and build the expected args.
        multikey_enabled: forwarded to the stage service; together with
            ``self.test_num_containers == 1`` it selects the multikey
            protocol for the expected values.
    """
    protocol = (
        PIDProtocol.UNION_PID_MULTIKEY
        if self.test_num_containers == 1 and multikey_enabled
        else PIDProtocol.UNION_PID
    )
    pc_instance = self.create_sample_pc_instance(pc_role)
    stage_svc = PIDRunProtocolStageService(
        storage_svc=self.mock_storage_svc,
        onedocker_svc=self.mock_onedocker_svc,
        onedocker_binary_config_map=self.onedocker_binary_config_map,
        multikey_enabled=multikey_enabled,
    )

    # The mocked OneDocker service reports these containers both when
    # starting them and when waiting for pending ones.
    containers = [
        await self.create_container_instance()
        for _ in range(self.test_num_containers)
    ]
    self.mock_onedocker_svc.start_containers = MagicMock(return_value=containers)
    self.mock_onedocker_svc.wait_for_pending_containers = AsyncMock(
        return_value=containers
    )

    updated_pc_instance = await stage_svc.run_async(
        pc_instance=pc_instance, server_ips=self.server_ips
    )

    binary_name = PIDRunProtocolBinaryService.get_binary_name(protocol, pc_role)
    binary_config = self.onedocker_binary_config_map[binary_name]
    env_vars = {ONEDOCKER_REPOSITORY_PATH: binary_config.repository_path}
    args_str_expect = self.get_args_expect(pc_role, protocol, self.use_row_numbers)
    # test the start_containers is called with expected parameters
    self.mock_onedocker_svc.start_containers.assert_called_with(
        package_name=binary_name,
        version=binary_config.binary_version,
        cmd_args_list=args_str_expect,
        timeout=DEFAULT_CONTAINER_TIMEOUT_IN_SEC,
        env_vars=env_vars,
    )

    # test the return value is as expected
    # NOTE(review): the expected count is the number of containers, yet the
    # check below expects one StageStateInstance holding all containers;
    # this presumably only holds while test_num_containers is 1 - confirm.
    self.assertEqual(
        len(updated_pc_instance.instances),
        self.test_num_containers,
        "Failed to add the StageStateInstance into pc_instance",
    )
    stage_state_expect = StageStateInstance(
        pc_instance.instance_id,
        pc_instance.current_stage.name,
        containers=containers,
    )
    stage_state_actual = updated_pc_instance.instances[0]
    self.assertEqual(
        stage_state_actual,
        stage_state_expect,
        "Appended StageStateInstance is not as expected",
    )
async def test_pid_prepare_stage_service(
    self,
    pc_role: PrivateComputationRole,
    multikey_enabled: bool,
    test_num_containers: int,
) -> None:
    """Run ``PIDPrepareStageService.run_async`` against mocked OneDocker calls.

    Checks the ``start_containers`` arguments and that a single
    ``StageStateInstance`` carrying the started containers is appended to
    the returned instance.

    Args:
        pc_role: role used for the sample instance and the expected args.
        multikey_enabled: forwarded to the stage service; with exactly one
            container it selects the multikey protocol.
        test_num_containers: number of containers the mocked service returns.
    """
    # Work out the protocol, then the max column count expected for it.
    if test_num_containers == 1 and multikey_enabled:
        chosen_protocol = PIDProtocol.UNION_PID_MULTIKEY
    else:
        chosen_protocol = PIDProtocol.UNION_PID
    expected_max_columns = 1
    if chosen_protocol is PIDProtocol.UNION_PID_MULTIKEY:
        expected_max_columns = DEFAULT_MULTIKEY_PROTOCOL_MAX_COLUMN_COUNT

    pc_instance = self.create_sample_pc_instance(pc_role, test_num_containers)
    prepare_service = PIDPrepareStageService(
        storage_svc=self.mock_storage_svc,
        onedocker_svc=self.mock_onedocker_svc,
        onedocker_binary_config_map=self.onedocker_binary_config_map,
        multikey_enabled=multikey_enabled,
    )

    fake_containers = [
        self.create_container_instance() for _ in range(test_num_containers)
    ]
    onedocker = self.mock_onedocker_svc
    onedocker.start_containers = MagicMock(return_value=fake_containers)
    onedocker.wait_for_pending_containers = AsyncMock(return_value=fake_containers)

    result_instance = await prepare_service.run_async(pc_instance=pc_instance)

    # start_containers must have received the expected arguments.
    expected_env = {
        "ONEDOCKER_REPOSITORY_PATH": self.onedocker_binary_config.repository_path
    }
    expected_args = self.get_args_expected(
        pc_role, test_num_containers, expected_max_columns
    )
    self.mock_onedocker_svc.start_containers.assert_called_with(
        package_name=self.binary_name,
        version=self.onedocker_binary_config.binary_version,
        cmd_args_list=expected_args,
        timeout=self.container_timeout,
        env_vars=expected_env,
    )

    # Exactly one stage state should have been appended ...
    self.assertEqual(
        len(result_instance.instances),
        1,
        "Failed to add the StageStateInstance into pc_instance",
    )
    # ... and it should describe the containers handed out above.
    expected_state = StageStateInstance(
        pc_instance.instance_id,
        pc_instance.current_stage.name,
        containers=fake_containers,
    )
    self.assertEqual(
        result_instance.instances[0],
        expected_state,
        "Appended StageStateInstance is not as expected",
    )
async def test_valid_custom_flow(
    self,
    mock_instance_repo,
    mock_aws_container_service,
    mock_onedocker_service,
    mock_s3_storage_service,
    mock_get_execution_flow,
    mock_pid_run_protocol_stage,
    mock_pid_prepare_stage,
    mock_pid_shard_stage,
) -> None:
    """Check that the dispatcher honours a custom, non-linear execution flow.

    The mocked flow makes both the prepare and the run-protocol stage
    direct successors of the shard stage. The test verifies that every
    stage runs exactly once and that both successors receive the shard
    output as their input.
    """
    # custom flow with non-linear dependency
    mock_get_execution_flow.return_value = PIDFlow(
        name="union_pid_advertiser",
        base_flow="union_pid",
        # pyre-fixme[6]: For 3rd param expected `Dict[UnionPIDStage, List[str]]`
        #  but got `List[Variable[_T]]`.
        extra_args=[],
        flow={
            UnionPIDStage.ADV_SHARD: [
                UnionPIDStage.ADV_PREPARE,
                UnionPIDStage.ADV_RUN_PID,
            ],
            UnionPIDStage.ADV_PREPARE: [],
            UnionPIDStage.ADV_RUN_PID: [],
        },
    )

    mock_pid_prepare_stage().run = AsyncMock(return_value=PIDStageStatus.COMPLETED)
    mock_pid_shard_stage().stage_type = UnionPIDStage.PUBLISHER_SHARD
    mock_pid_shard_stage().run = AsyncMock(return_value=PIDStageStatus.COMPLETED)
    # Set `stage_type` (not `type`) so this mock is configured through the
    # same attribute as the shard stage above.
    mock_pid_run_protocol_stage().stage_type = UnionPIDStage.PUBLISHER_RUN_PID
    mock_pid_run_protocol_stage().run = AsyncMock(
        return_value=PIDStageStatus.COMPLETED
    )

    dispatcher = PIDDispatcher(
        instance_id="456", instance_repository=mock_instance_repo
    )
    dispatcher.build_stages(
        input_path="abc.text",
        output_path="def.txt",
        num_shards=50,
        protocol=PIDProtocol.UNION_PID,
        role=PIDRole.PARTNER,
        storage_svc=mock_s3_storage_service,
        onedocker_svc=mock_onedocker_service,
        # pyre-fixme[6]: For 8th param expected `DefaultDict[str,
        #  OneDockerBinaryConfig]` but got `DefaultDict[Variable[_KT], str]`.
        onedocker_binary_config_map=defaultdict(lambda: "OD_CONFIG"),
    )
    self.assertEqual(len(dispatcher.dag.nodes), 3)

    await dispatcher.run_all()

    # Make sure each stage is called exactly once
    mock_pid_shard_stage().run.mock.assert_called_once()
    mock_pid_prepare_stage().run.mock.assert_called_once()
    mock_pid_run_protocol_stage().run.mock.assert_called_once()

    # The shard stage reads the raw input ...
    self.assertEqual(
        mock_pid_shard_stage().run.mock.call_args[0][0],
        PIDStageInput(
            input_paths=["abc.text"],
            output_paths=["def.txt_advertiser_sharded"],
            num_shards=50,
            instance_id="456",
        ),
    )
    # ... and both successors consume the shard output directly.
    self.assertEqual(
        mock_pid_prepare_stage().run.mock.call_args[0][0],
        PIDStageInput(
            input_paths=["def.txt_advertiser_sharded"],
            output_paths=["def.txt_advertiser_prepared"],
            num_shards=50,
            instance_id="456",
        ),
    )
    self.assertEqual(
        mock_pid_run_protocol_stage().run.mock.call_args[0][0],
        PIDStageInput(
            input_paths=["def.txt_advertiser_sharded"],
            output_paths=["def.txt_advertiser_pid_matched"],
            num_shards=50,
            instance_id="456",
        ),
    )
    self.assertEqual(len(dispatcher.dag.nodes), 0)  # all done
async def test_union_pid_flow_valid_partner_with_data_path_spine_path(
    self,
    mock_instance_repo,
    mock_aws_container_service,
    mock_onedocker_service,
    mock_s3_storage_service,
    mock_pid_run_protocol_stage,
    mock_pid_prepare_stage,
    mock_pid_shard_stage,
) -> None:
    """Check stage inputs when ``data_path`` and ``spine_path`` are supplied.

    ``data_path`` is expected to replace the output prefix of the shard
    stage, and ``spine_path`` the output prefix of the run-protocol stage;
    the prepare stage keeps writing under ``output_path``.
    """
    shard_stage = mock_pid_shard_stage()
    prepare_stage = mock_pid_prepare_stage()
    run_protocol_stage = mock_pid_run_protocol_stage()

    prepare_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)
    shard_stage.stage_type = UnionPIDStage.PUBLISHER_SHARD
    shard_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)
    run_protocol_stage.stage_type = UnionPIDStage.PUBLISHER_RUN_PID
    run_protocol_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)

    dispatcher = PIDDispatcher(
        instance_id="456", instance_repository=mock_instance_repo
    )
    # Besides the usual arguments, two optional ones are passed here:
    # data_path (where the shard stage writes) and spine_path (where the
    # protocol run stage writes).
    dispatcher.build_stages(
        input_path="abc.text",
        output_path="def.txt",
        num_shards=50,
        protocol=PIDProtocol.UNION_PID,
        role=PIDRole.PARTNER,
        storage_svc=mock_s3_storage_service,
        onedocker_svc=mock_onedocker_service,
        # pyre-fixme[6]: For 8th param expected `DefaultDict[str,
        #  OneDockerBinaryConfig]` but got `DefaultDict[Variable[_KT], str]`.
        onedocker_binary_config_map=defaultdict(lambda: "OD_CONFIG"),
        data_path="data.txt",
        spine_path="spine.txt",
    )
    self.assertEqual(len(dispatcher.dag.nodes), 3)

    await dispatcher.run_all()

    # (stage, expected input paths, expected output paths), in pipeline order:
    #   shard    -> writes under data_path
    #   prepare  -> reads the data_path shards, writes under output_path
    #   protocol -> reads the prepared files, writes under spine_path
    expectations = (
        (shard_stage, ["abc.text"], ["data.txt_advertiser_sharded"]),
        (
            prepare_stage,
            ["data.txt_advertiser_sharded"],
            ["def.txt_advertiser_prepared"],
        ),
        (
            run_protocol_stage,
            ["def.txt_advertiser_prepared"],
            ["spine.txt_advertiser_pid_matched"],
        ),
    )

    # Every stage must have run exactly once ...
    for stage, _, _ in expectations:
        stage.run.mock.assert_called_once()

    # ... with the stage input derived from the paths above.
    for stage, expected_inputs, expected_outputs in expectations:
        self.assertEqual(
            stage.run.mock.call_args[0][0],
            PIDStageInput(
                input_paths=expected_inputs,
                output_paths=expected_outputs,
                num_shards=50,
                instance_id="456",
            ),
        )

    self.assertEqual(len(dispatcher.dag.nodes), 0)  # all done
async def test_union_pid_run_only_unfinished_stages(
    self,
    mock_instance_repo,
    mock_aws_container_service,
    mock_onedocker_service,
    mock_s3_storage_service,
    mock_pid_shard_stage,
    mock_pid_prepare_stage,
    mock_pid_run_protocol_stage,
) -> None:
    """Check that ``run_all`` skips completed stages and retries failed ones.

    The stored instance records the shard stage as COMPLETED and the prepare
    stage as FAILED, so only the prepare and run-protocol stages should be
    executed.
    """
    shard_stage = mock_pid_shard_stage()
    prepare_stage = mock_pid_prepare_stage()
    run_protocol_stage = mock_pid_run_protocol_stage()
    for stage in (shard_stage, prepare_stage, run_protocol_stage):
        stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)

    job_id = "456"
    job_protocol = PIDProtocol.UNION_PID
    job_role = PIDRole.PARTNER
    shard_count = 50
    source_path = "abc.text"
    target_path = "def.txt"

    dispatcher = PIDDispatcher(
        instance_id=job_id, instance_repository=mock_instance_repo
    )

    stored_instance = self._get_sample_pid_instance(
        instance_id=job_id,
        protocol=job_protocol,
        pid_role=job_role,
        num_shards=shard_count,
        is_validating=False,
        input_path=source_path,
        output_path=target_path,
    )
    # Seed the stored history: shard already succeeded, prepare was tried
    # and failed, run-protocol has no record at all.
    stored_instance.stages_status[shard_stage.stage_type] = PIDStageStatus.COMPLETED
    stored_instance.stages_status[prepare_stage.stage_type] = PIDStageStatus.FAILED
    dispatcher.instance_repository.read = MagicMock(return_value=stored_instance)

    dispatcher.build_stages(
        input_path=source_path,
        output_path=target_path,
        num_shards=shard_count,
        protocol=job_protocol,
        role=job_role,
        storage_svc=mock_s3_storage_service,
        onedocker_svc=mock_onedocker_service,
        # pyre-fixme[6]: For 8th param expected `DefaultDict[str,
        #  OneDockerBinaryConfig]` but got `DefaultDict[Variable[_KT], str]`.
        onedocker_binary_config_map=defaultdict(lambda: "OD_CONFIG"),
    )

    # The finished shard stage is not part of the DAG, leaving two nodes.
    self.assertEqual(len(dispatcher.dag.nodes), 2)
    await dispatcher.run_all()
    # Nothing is left once everything has run.
    self.assertEqual(len(dispatcher.dag.nodes), 0)

    # Already finished: must not run again.
    shard_stage.run.mock.assert_not_called()
    # Previously failed: must be retried.
    prepare_stage.run.mock.assert_called_once()
    # Never attempted: must run.
    run_protocol_stage.run.mock.assert_called_once()
async def test_union_pid_run_stages_one_by_one(
    self,
    mock_instance_repo,
    mock_aws_container_service,
    mock_onedocker_service,
    mock_s3_storage_service,
    mock_pid_shard_stage,
    mock_pid_prepare_stage,
    mock_pid_run_protocol_stage,
) -> None:
    """Check ``run_stage`` when stages are driven individually.

    Stages run in order shrink the DAG by one node each. Running a stage out
    of order, or rerunning a finished stage, must raise
    ``PIDStageFailureError`` without invoking the stage again.
    """
    shard_stage = mock_pid_shard_stage()
    prepare_stage = mock_pid_prepare_stage()
    run_protocol_stage = mock_pid_run_protocol_stage()
    for stage in (shard_stage, prepare_stage, run_protocol_stage):
        stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)

    dispatcher = PIDDispatcher(
        instance_id="456", instance_repository=mock_instance_repo
    )
    dispatcher.build_stages(
        input_path="abc.text",
        output_path="def.txt",
        num_shards=50,
        protocol=PIDProtocol.UNION_PID,
        role=PIDRole.PARTNER,
        storage_svc=mock_s3_storage_service,
        onedocker_svc=mock_onedocker_service,
        # pyre-fixme[6]: For 8th param expected `DefaultDict[str,
        #  OneDockerBinaryConfig]` but got `DefaultDict[Variable[_KT], str]`.
        onedocker_binary_config_map=defaultdict(lambda: "OD_CONFIG"),
    )

    # All three stages are pending before anything runs.
    self.assertEqual(len(dispatcher.dag.nodes), 3)

    await dispatcher.run_stage(shard_stage)
    self.assertEqual(len(dispatcher.dag.nodes), 2)

    # Skipping ahead to the protocol stage is rejected ...
    with self.assertRaises(PIDStageFailureError):
        await dispatcher.run_stage(run_protocol_stage)
    # ... and leaves the DAG untouched.
    self.assertEqual(len(dispatcher.dag.nodes), 2)

    # Resume in the proper order.
    await dispatcher.run_stage(prepare_stage)
    self.assertEqual(len(dispatcher.dag.nodes), 1)

    await dispatcher.run_stage(run_protocol_stage)
    self.assertEqual(len(dispatcher.dag.nodes), 0)

    # A stage that has already completed cannot be run a second time.
    with self.assertRaises(PIDStageFailureError):
        await dispatcher.run_stage(run_protocol_stage)

    # Neither rejected attempt may have reached the stage's run().
    for stage in (shard_stage, prepare_stage, run_protocol_stage):
        stage.run.mock.assert_called_once()