async def test_pid_run_stage_with_exception(
        self,
        mock_instance_repo,
        mock_aws_container_service,
        mock_onedocker_service,
        mock_s3_storage_service,
        mock_pid_shard_stage,
        mock_pid_prepare_stage,
    ) -> None:
        """Check that a stage which failed once can be run again.

        The shard stage is run first. The prepare stage is then made to fail,
        which must surface as ``PIDStageFailureError`` and leave the DAG
        untouched, and is finally rerun with a succeeding ``run``.
        """
        shard_stage = mock_pid_shard_stage()
        shard_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)

        dispatcher = PIDDispatcher(
            instance_id="344",
            instance_repository=mock_instance_repo,
        )
        dispatcher.build_stages(
            input_path="abc.text",
            output_path="def.txt",
            num_shards=50,
            protocol=PIDProtocol.UNION_PID,
            role=PIDRole.PARTNER,
            storage_svc=mock_s3_storage_service,
            onedocker_svc=mock_onedocker_service,
            # pyre-fixme[6]: For 8th param expected `DefaultDict[str,
            #  OneDockerBinaryConfig]` but got `DefaultDict[Variable[_KT], str]`.
            onedocker_binary_config_map=defaultdict(lambda: "OD_CONFIG"),
        )

        # Running the shard stage leaves two nodes in the DAG.
        await dispatcher.run_stage(shard_stage)
        self.assertEqual(len(dispatcher.dag.nodes), 2)

        # Make the prepare stage fail.
        # NOTE(review): ``run`` becomes an Exception *instance*, not a callable
        # that raises; presumably the dispatcher reports the resulting failure
        # as PIDStageFailureError -- confirm against PIDDispatcher.run_stage.
        prepare_stage = mock_pid_prepare_stage()
        prepare_stage.run = Exception()
        with self.assertRaises(PIDStageFailureError):
            await dispatcher.run_stage(prepare_stage)
        # The failed attempt must not remove anything from the DAG.
        self.assertEqual(len(dispatcher.dag.nodes), 2)

        # Retry the same stage, this time with a run() that completes.
        prepare_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)
        await dispatcher.run_stage(prepare_stage)
        # The successful rerun removes the stage from the DAG.
        self.assertEqual(len(dispatcher.dag.nodes), 1)
    async def test_union_pid_run_all_order(
        self,
        mock_instance_repo,
        mock_aws_container_service,
        mock_onedocker_service,
        mock_s3_storage_service,
        mock_pid_shard_stage,
        mock_pid_prepare_stage,
        mock_pid_run_protocol_stage,
    ) -> None:
        """Run the full UNION_PID partner flow and check the DAG is drained.

        The three mocked stages share a single ``run`` mock, so its call count
        equals the number of stages executed by ``run_all``.
        """
        shared_run = AsyncMock(return_value=PIDStageStatus.COMPLETED)
        for stage_factory in (
            mock_pid_shard_stage,
            mock_pid_prepare_stage,
            mock_pid_run_protocol_stage,
        ):
            stage_factory().run = shared_run

        instance_id = "456"
        protocol = PIDProtocol.UNION_PID
        role = PIDRole.PARTNER
        shard_count = 50
        src_path = "abc.text"
        dst_path = "def.txt"

        dispatcher = PIDDispatcher(
            instance_id=instance_id,
            instance_repository=mock_instance_repo,
        )

        # Have the repository hand back a sample instance for this run.
        pid_instance = self._get_sample_pid_instance(
            instance_id=instance_id,
            protocol=protocol,
            pid_role=role,
            num_shards=shard_count,
            is_validating=False,
            input_path=src_path,
            output_path=dst_path,
        )
        dispatcher.instance_repository.read = MagicMock(return_value=pid_instance)

        dispatcher.build_stages(
            input_path=src_path,
            output_path=dst_path,
            num_shards=shard_count,
            protocol=protocol,
            role=role,
            storage_svc=mock_s3_storage_service,
            onedocker_svc=mock_onedocker_service,
            # pyre-fixme[6]: For 8th param expected `DefaultDict[str,
            #  OneDockerBinaryConfig]` but got `DefaultDict[Variable[_KT], str]`.
            onedocker_binary_config_map=defaultdict(lambda: "OD_CONFIG"),
        )

        # Three nodes before the run ...
        self.assertEqual(len(dispatcher.dag.nodes), 3)
        await dispatcher.run_all()
        # ... and none afterwards.
        self.assertEqual(len(dispatcher.dag.nodes), 0)
        # Each (mocked) node must have invoked the shared run() exactly once.
        self.assertEqual(shared_run.mock.call_count, 3)
Esempio n. 3
0
 async def test_pid_shard_stage_service(
     self,
     pc_role: PrivateComputationRole,
     test_num_containers: int,
     has_hmac_key: bool,
 ) -> None:
     """Run ``PIDShardStageService`` against mocked services and check it.

     Verifies that ``start_containers`` receives the expected package,
     version, arguments, timeout and environment, and that exactly one
     ``StageStateInstance`` holding the started containers is appended to
     the returned instance.

     Args:
         pc_role: Role used to build the sample instance and expected args.
         test_num_containers: Number of mocked containers to start.
         has_hmac_key: Whether the sample instance is given ``test_hmac_key``.
     """
     hmac_key_expected = self.test_hmac_key if has_hmac_key else None
     pc_instance = self.create_sample_pc_instance(
         pc_role,
         test_num_containers,
         hmac_key_expected,
     )
     stage_svc = PIDShardStageService(
         storage_svc=self.mock_storage_svc,
         onedocker_svc=self.mock_onedocker_svc,
         onedocker_binary_config_map=self.onedocker_binary_config_map,
         container_timeout=self.container_timeout,
     )
     containers = [
         self.create_container_instance() for _ in range(test_num_containers)
     ]
     # Both the start call and the wait call return the same containers.
     self.mock_onedocker_svc.start_containers = MagicMock(
         return_value=containers
     )
     self.mock_onedocker_svc.wait_for_pending_containers = AsyncMock(
         return_value=containers
     )
     updated_pc_instance = await stage_svc.run_async(pc_instance=pc_instance)
     env_vars = {
         "ONEDOCKER_REPOSITORY_PATH": self.onedocker_binary_config.repository_path
     }
     args_ls_expect = self.get_args_expect(
         pc_role, test_num_containers, has_hmac_key
     )
     # test the start_containers is called with expected parameters
     self.mock_onedocker_svc.start_containers.assert_called_with(
         package_name=self.binary_name,
         version=self.onedocker_binary_config.binary_version,
         cmd_args_list=args_ls_expect,
         timeout=self.container_timeout,
         env_vars=env_vars,
     )
     # test the return value is as expected
     self.assertEqual(
         len(updated_pc_instance.instances),
         1,
         "Failed to add the StageStateInstance into pc_instance",
     )
     stage_state_expect = StageStateInstance(
         pc_instance.instance_id,
         pc_instance.current_stage.name,
         containers=containers,
     )
     stage_state_actual = updated_pc_instance.instances[0]
     self.assertEqual(
         stage_state_actual,
         stage_state_expect,
         "Appended StageStateInstance is not as expected",
     )
Esempio n. 4
0
    async def test_pid_run_protocol_stage(
        self, pc_role: PrivateComputationRole, multikey_enabled: bool
    ) -> None:
        """Run ``PIDRunProtocolStageService`` against mocked services.

        Verifies that ``start_containers`` is called with the binary, version,
        arguments, timeout and environment derived from the selected protocol,
        and that the first appended ``StageStateInstance`` holds the started
        containers.

        Args:
            pc_role: Role used to build the sample instance and expected args.
            multikey_enabled: Passed to the service; together with a single
                container it selects the multikey protocol for the expected
                binary and arguments.
        """
        protocol = (
            PIDProtocol.UNION_PID_MULTIKEY
            if self.test_num_containers == 1 and multikey_enabled
            else PIDProtocol.UNION_PID
        )
        pc_instance = self.create_sample_pc_instance(pc_role)
        stage_svc = PIDRunProtocolStageService(
            storage_svc=self.mock_storage_svc,
            onedocker_svc=self.mock_onedocker_svc,
            onedocker_binary_config_map=self.onedocker_binary_config_map,
            multikey_enabled=multikey_enabled,
        )
        containers = [
            await self.create_container_instance()
            for _ in range(self.test_num_containers)
        ]
        # Both the start call and the wait call return the same containers.
        self.mock_onedocker_svc.start_containers = MagicMock(return_value=containers)
        self.mock_onedocker_svc.wait_for_pending_containers = AsyncMock(
            return_value=containers
        )
        updated_pc_instance = await stage_svc.run_async(
            pc_instance=pc_instance, server_ips=self.server_ips
        )

        binary_name = PIDRunProtocolBinaryService.get_binary_name(protocol, pc_role)
        binary_config = self.onedocker_binary_config_map[binary_name]
        env_vars = {ONEDOCKER_REPOSITORY_PATH: binary_config.repository_path}
        args_str_expect = self.get_args_expect(pc_role, protocol, self.use_row_numbers)
        # test the start_containers is called with expected parameters
        self.mock_onedocker_svc.start_containers.assert_called_with(
            package_name=binary_name,
            version=binary_config.binary_version,
            cmd_args_list=args_str_expect,
            timeout=DEFAULT_CONTAINER_TIMEOUT_IN_SEC,
            env_vars=env_vars,
        )
        # test the return value is as expected
        # NOTE(review): presumably one stage run appends a single
        # StageStateInstance -- confirm that comparing the count against
        # test_num_containers (rather than 1) is intended.
        self.assertEqual(
            len(updated_pc_instance.instances),
            self.test_num_containers,
            "Failed to add the StageStateInstance into pc_instance",
        )
        stage_state_expect = StageStateInstance(
            pc_instance.instance_id,
            pc_instance.current_stage.name,
            containers=containers,
        )
        stage_state_actual = updated_pc_instance.instances[0]
        self.assertEqual(
            stage_state_actual,
            stage_state_expect,
            "Appended StageStateInstance is not as expected",
        )
Esempio n. 5
0
 async def test_pid_prepare_stage_service(
     self,
     pc_role: PrivateComputationRole,
     multikey_enabled: bool,
     test_num_containers: int,
 ) -> None:
     """Run ``PIDPrepareStageService`` against mocked services and check it.

     Verifies the arguments handed to ``start_containers`` and that exactly
     one ``StageStateInstance`` holding the started containers is appended
     to the returned instance.

     Args:
         pc_role: Role used to build the sample instance and expected args.
         multikey_enabled: Passed to the service; with a single container it
             selects the multikey protocol for the expected column count.
         test_num_containers: Number of mocked containers to start.
     """
     # The multikey protocol is only expected for a single container.
     if test_num_containers == 1 and multikey_enabled:
         pid_protocol = PIDProtocol.UNION_PID_MULTIKEY
     else:
         pid_protocol = PIDProtocol.UNION_PID

     if pid_protocol is PIDProtocol.UNION_PID_MULTIKEY:
         max_col_cnt_expect = DEFAULT_MULTIKEY_PROTOCOL_MAX_COLUMN_COUNT
     else:
         max_col_cnt_expect = 1

     pc_instance = self.create_sample_pc_instance(pc_role, test_num_containers)
     prepare_svc = PIDPrepareStageService(
         storage_svc=self.mock_storage_svc,
         onedocker_svc=self.mock_onedocker_svc,
         onedocker_binary_config_map=self.onedocker_binary_config_map,
         multikey_enabled=multikey_enabled,
     )

     started = [
         self.create_container_instance() for _ in range(test_num_containers)
     ]
     # Both the start call and the wait call return the same containers.
     self.mock_onedocker_svc.start_containers = MagicMock(return_value=started)
     self.mock_onedocker_svc.wait_for_pending_containers = AsyncMock(
         return_value=started
     )

     result_instance = await prepare_svc.run_async(pc_instance=pc_instance)

     expected_env = {
         "ONEDOCKER_REPOSITORY_PATH": self.onedocker_binary_config.repository_path
     }
     expected_args = self.get_args_expected(
         pc_role, test_num_containers, max_col_cnt_expect
     )
     # start_containers must have received the expected parameters
     self.mock_onedocker_svc.start_containers.assert_called_with(
         package_name=self.binary_name,
         version=self.onedocker_binary_config.binary_version,
         cmd_args_list=expected_args,
         timeout=self.container_timeout,
         env_vars=expected_env,
     )

     # exactly one stage state must have been appended ...
     self.assertEqual(
         len(result_instance.instances),
         1,
         "Failed to add the StageStateInstance into pc_instance",
     )
     # ... and it must describe the current stage and the started containers
     expected_state = StageStateInstance(
         pc_instance.instance_id,
         pc_instance.current_stage.name,
         containers=started,
     )
     actual_state = result_instance.instances[0]
     self.assertEqual(
         actual_state,
         expected_state,
         "Appended StageStateInstance is not as expected",
     )
    async def test_valid_custom_flow(
        self,
        mock_instance_repo,
        mock_aws_container_service,
        mock_onedocker_service,
        mock_s3_storage_service,
        mock_get_execution_flow,
        mock_pid_run_protocol_stage,
        mock_pid_prepare_stage,
        mock_pid_shard_stage,
    ) -> None:
        """Run a custom flow whose prepare and run-PID stages both follow shard.

        Checks that every stage runs exactly once and that both downstream
        stages receive the shard stage's output as their input.
        """
        # custom flow with non-linear dependency
        mock_get_execution_flow.return_value = PIDFlow(
            name="union_pid_advertiser",
            base_flow="union_pid",
            # pyre-fixme[6]: For 3rd param expected `Dict[UnionPIDStage, List[str]]`
            #  but got `List[Variable[_T]]`.
            extra_args=[],
            flow={
                UnionPIDStage.ADV_SHARD: [
                    UnionPIDStage.ADV_PREPARE,
                    UnionPIDStage.ADV_RUN_PID,
                ],
                UnionPIDStage.ADV_PREPARE: [],
                UnionPIDStage.ADV_RUN_PID: [],
            },
        )
        mock_pid_prepare_stage().run = AsyncMock(
            return_value=PIDStageStatus.COMPLETED)
        mock_pid_shard_stage().stage_type = UnionPIDStage.PUBLISHER_SHARD
        mock_pid_shard_stage().run = AsyncMock(
            return_value=PIDStageStatus.COMPLETED)
        # Set ``stage_type`` (not ``type``): assigning an unknown attribute on
        # a mock silently succeeds and would leave ``stage_type`` unconfigured.
        mock_pid_run_protocol_stage().stage_type = UnionPIDStage.PUBLISHER_RUN_PID
        mock_pid_run_protocol_stage().run = AsyncMock(
            return_value=PIDStageStatus.COMPLETED)

        dispatcher = PIDDispatcher(instance_id="456",
                                   instance_repository=mock_instance_repo)
        dispatcher.build_stages(
            input_path="abc.text",
            output_path="def.txt",
            num_shards=50,
            protocol=PIDProtocol.UNION_PID,
            role=PIDRole.PARTNER,
            storage_svc=mock_s3_storage_service,
            onedocker_svc=mock_onedocker_service,
            # pyre-fixme[6]: For 8th param expected `DefaultDict[str,
            #  OneDockerBinaryConfig]` but got `DefaultDict[Variable[_KT], str]`.
            onedocker_binary_config_map=defaultdict(lambda: "OD_CONFIG"),
        )

        self.assertEqual(len(dispatcher.dag.nodes), 3)

        await dispatcher.run_all()
        # Make sure each stage is called exactly once
        mock_pid_shard_stage().run.mock.assert_called_once()
        mock_pid_prepare_stage().run.mock.assert_called_once()
        mock_pid_run_protocol_stage().run.mock.assert_called_once()

        # shard reads the raw input and writes the sharded output
        self.assertEqual(
            mock_pid_shard_stage().run.mock.call_args[0][0],
            PIDStageInput(
                input_paths=["abc.text"],
                output_paths=["def.txt_advertiser_sharded"],
                num_shards=50,
                instance_id="456",
            ),
        )
        # prepare consumes the sharded output
        self.assertEqual(
            mock_pid_prepare_stage().run.mock.call_args[0][0],
            PIDStageInput(
                input_paths=["def.txt_advertiser_sharded"],
                output_paths=["def.txt_advertiser_prepared"],
                num_shards=50,
                instance_id="456",
            ),
        )
        # run-PID also consumes the sharded output in this custom flow
        self.assertEqual(
            mock_pid_run_protocol_stage().run.mock.call_args[0][0],
            PIDStageInput(
                input_paths=["def.txt_advertiser_sharded"],
                output_paths=["def.txt_advertiser_pid_matched"],
                num_shards=50,
                instance_id="456",
            ),
        )
        self.assertEqual(len(dispatcher.dag.nodes), 0)  # all done
    async def test_union_pid_flow_valid_partner_with_data_path_spine_path(
        self,
        mock_instance_repo,
        mock_aws_container_service,
        mock_onedocker_service,
        mock_s3_storage_service,
        mock_pid_run_protocol_stage,
        mock_pid_prepare_stage,
        mock_pid_shard_stage,
    ) -> None:
        """Run the partner flow with explicit ``data_path`` and ``spine_path``.

        Checks that each stage runs once and that the stage inputs pick up
        ``data_path`` for the shard output and ``spine_path`` for the
        protocol-run output.
        """

        def captured_input(stage):
            # First positional argument of the last recorded run() call.
            return stage.run.mock.call_args[0][0]

        prepare_stage = mock_pid_prepare_stage()
        prepare_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)

        shard_stage = mock_pid_shard_stage()
        shard_stage.stage_type = UnionPIDStage.PUBLISHER_SHARD
        shard_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)

        protocol_stage = mock_pid_run_protocol_stage()
        protocol_stage.stage_type = UnionPIDStage.PUBLISHER_RUN_PID
        protocol_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)

        dispatcher = PIDDispatcher(
            instance_id="456",
            instance_repository=mock_instance_repo,
        )
        # Two extra parameters are supplied here:
        #   data_path  -- where the shard stage writes its output
        #   spine_path -- where the protocol run stage writes its output
        dispatcher.build_stages(
            input_path="abc.text",
            output_path="def.txt",
            num_shards=50,
            protocol=PIDProtocol.UNION_PID,
            role=PIDRole.PARTNER,
            storage_svc=mock_s3_storage_service,
            onedocker_svc=mock_onedocker_service,
            # pyre-fixme[6]: For 8th param expected `DefaultDict[str,
            #  OneDockerBinaryConfig]` but got `DefaultDict[Variable[_KT], str]`.
            onedocker_binary_config_map=defaultdict(lambda: "OD_CONFIG"),
            data_path="data.txt",
            spine_path="spine.txt",
        )

        self.assertEqual(len(dispatcher.dag.nodes), 3)

        await dispatcher.run_all()
        # Every stage must have run exactly once.
        shard_stage.run.mock.assert_called_once()
        prepare_stage.run.mock.assert_called_once()
        protocol_stage.run.mock.assert_called_once()

        # The shard stage writes under data_path.
        self.assertEqual(
            captured_input(shard_stage),
            PIDStageInput(
                input_paths=["abc.text"],
                output_paths=["data.txt_advertiser_sharded"],
                num_shards=50,
                instance_id="456",
            ),
        )
        # The prepare stage reads what the shard stage wrote under data_path.
        self.assertEqual(
            captured_input(prepare_stage),
            PIDStageInput(
                input_paths=["data.txt_advertiser_sharded"],
                output_paths=["def.txt_advertiser_prepared"],
                num_shards=50,
                instance_id="456",
            ),
        )
        # The protocol run stage writes under spine_path.
        self.assertEqual(
            captured_input(protocol_stage),
            PIDStageInput(
                input_paths=["def.txt_advertiser_prepared"],
                output_paths=["spine.txt_advertiser_pid_matched"],
                num_shards=50,
                instance_id="456",
            ),
        )
        self.assertEqual(len(dispatcher.dag.nodes), 0)  # all done
    async def test_union_pid_run_only_unfinished_stages(
        self,
        mock_instance_repo,
        mock_aws_container_service,
        mock_onedocker_service,
        mock_s3_storage_service,
        mock_pid_shard_stage,
        mock_pid_prepare_stage,
        mock_pid_run_protocol_stage,
    ) -> None:
        """Check that ``run_all`` skips stages already recorded as completed.

        The stored instance marks the shard stage COMPLETED and the prepare
        stage FAILED; only the prepare and protocol-run stages should run.
        """
        shard_stage = mock_pid_shard_stage()
        shard_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)
        prepare_stage = mock_pid_prepare_stage()
        prepare_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)
        protocol_stage = mock_pid_run_protocol_stage()
        protocol_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)

        instance_id = "456"
        protocol = PIDProtocol.UNION_PID
        role = PIDRole.PARTNER
        shard_count = 50
        src_path = "abc.text"
        dst_path = "def.txt"

        dispatcher = PIDDispatcher(
            instance_id=instance_id,
            instance_repository=mock_instance_repo,
        )

        pid_instance = self._get_sample_pid_instance(
            instance_id=instance_id,
            protocol=protocol,
            pid_role=role,
            num_shards=shard_count,
            is_validating=False,
            input_path=src_path,
            output_path=dst_path,
        )
        # Record the shard stage as already completed ...
        pid_instance.stages_status[shard_stage.stage_type] = (
            PIDStageStatus.COMPLETED
        )
        # ... and the prepare stage as attempted but failed.
        pid_instance.stages_status[prepare_stage.stage_type] = (
            PIDStageStatus.FAILED
        )
        dispatcher.instance_repository.read = MagicMock(return_value=pid_instance)

        dispatcher.build_stages(
            input_path=src_path,
            output_path=dst_path,
            num_shards=shard_count,
            protocol=protocol,
            role=role,
            storage_svc=mock_s3_storage_service,
            onedocker_svc=mock_onedocker_service,
            # pyre-fixme[6]: For 8th param expected `DefaultDict[str,
            #  OneDockerBinaryConfig]` but got `DefaultDict[Variable[_KT], str]`.
            onedocker_binary_config_map=defaultdict(lambda: "OD_CONFIG"),
        )

        # Only two nodes remain before the run: PID Shard is already finished.
        self.assertEqual(len(dispatcher.dag.nodes), 2)
        await dispatcher.run_all()
        # Nothing is left once the run is over.
        self.assertEqual(len(dispatcher.dag.nodes), 0)
        # The completed shard stage must not be executed again.
        shard_stage.run.mock.assert_not_called()
        # The previously failed prepare stage is executed again.
        prepare_stage.run.mock.assert_called_once()
        # The never-attempted protocol run stage is executed.
        protocol_stage.run.mock.assert_called_once()
    async def test_union_pid_run_stages_one_by_one(
        self,
        mock_instance_repo,
        mock_aws_container_service,
        mock_onedocker_service,
        mock_s3_storage_service,
        mock_pid_shard_stage,
        mock_pid_prepare_stage,
        mock_pid_run_protocol_stage,
    ) -> None:
        """Drive the stages individually through ``run_stage``.

        Running a stage out of order, or rerunning a finished one, must raise
        ``PIDStageFailureError``; running in order shrinks the DAG by one node
        per stage, and every stage's ``run`` is invoked exactly once.
        """
        shard_stage = mock_pid_shard_stage()
        shard_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)
        prepare_stage = mock_pid_prepare_stage()
        prepare_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)
        protocol_stage = mock_pid_run_protocol_stage()
        protocol_stage.run = AsyncMock(return_value=PIDStageStatus.COMPLETED)

        dispatcher = PIDDispatcher(
            instance_id="456",
            instance_repository=mock_instance_repo,
        )
        dispatcher.build_stages(
            input_path="abc.text",
            output_path="def.txt",
            num_shards=50,
            protocol=PIDProtocol.UNION_PID,
            role=PIDRole.PARTNER,
            storage_svc=mock_s3_storage_service,
            onedocker_svc=mock_onedocker_service,
            # pyre-fixme[6]: For 8th param expected `DefaultDict[str,
            #  OneDockerBinaryConfig]` but got `DefaultDict[Variable[_KT], str]`.
            onedocker_binary_config_map=defaultdict(lambda: "OD_CONFIG"),
        )

        # The DAG starts with three nodes; the shard stage removes one.
        self.assertEqual(len(dispatcher.dag.nodes), 3)
        await dispatcher.run_stage(shard_stage)
        self.assertEqual(len(dispatcher.dag.nodes), 2)

        # Skipping ahead to the protocol run stage is rejected ...
        with self.assertRaises(PIDStageFailureError):
            await dispatcher.run_stage(protocol_stage)
        # ... and leaves the DAG as it was.
        self.assertEqual(len(dispatcher.dag.nodes), 2)

        # Resume in the proper order.
        await dispatcher.run_stage(prepare_stage)
        self.assertEqual(len(dispatcher.dag.nodes), 1)

        await dispatcher.run_stage(protocol_stage)
        self.assertEqual(len(dispatcher.dag.nodes), 0)

        # A stage that already completed cannot be run a second time.
        with self.assertRaises(PIDStageFailureError):
            await dispatcher.run_stage(protocol_stage)

        # Despite the rejected attempts, every run() was invoked only once.
        shard_stage.run.mock.assert_called_once()
        prepare_stage.run.mock.assert_called_once()
        protocol_stage.run.mock.assert_called_once()