Example #1
0
 def test_mpc_deserialiation(self) -> None:
     """Old serialized fields (and instances) must still deserialize.

     NOTE: the method name contains a typo ("deserialiation"); kept as-is so
     test selection by name keeps working.
     """
     with open(LIFT_MPC_PATH) as handle:
         serialized = handle.read().strip()
     try:
         PCSMPCInstance.loads_schema(serialized)
     except Exception as err:
         # Surface a descriptive failure instead of a raw schema error.
         raise RuntimeError(ERR_MSG) from err
Example #2
0
 def test_update(self) -> None:
     """Updating an instance's `instances` list must round-trip through the repo."""
     instance_id = self._get_random_id()
     pc_instance = PrivateComputationInstance(
         instance_id=instance_id,
         role=PrivateComputationRole.PUBLISHER,
         instances=[self.test_mpc_instance],
         status=PrivateComputationInstanceStatus.CREATED,
         status_update_ts=1600000000,
         num_files_per_mpc_container=40,
         game_type=PrivateComputationGameType.LIFT,
         input_path="in",
         output_dir="out",
         num_pid_containers=4,
         num_mpc_containers=4,
         concurrency=1,
     )
     self.repo.create(pc_instance)
     # Build a second MPC instance and replace the stored instances list with it.
     extra_mpc_instance = PCSMPCInstance.create_instance(
         instance_id=instance_id,
         game_name="aggregation",
         mpc_party=MPCParty.SERVER,
         num_workers=1,
     )
     expected_instances = [self.test_mpc_instance, extra_mpc_instance]
     pc_instance.instances = expected_instances
     self.repo.update(pc_instance)
     # Reading back must reflect the updated instances list.
     self.assertEqual(self.repo.read(instance_id).instances, expected_instances)
     self.repo.delete(instance_id)
    def test_cancel_current_stage(self) -> None:
        """Cancelling a running MPC stage must mark the computation as failed."""
        mpc_id = self.test_private_computation_id + "_compute_metrics"
        game_name = GameNames.LIFT.value
        party = MPCParty.CLIENT

        # Instance the service reads from the repository at the beginning of
        # cancel_current_stage: an MPC stage that is currently running.
        running_mpc = PCSMPCInstance.create_instance(
            instance_id=mpc_id,
            game_name=game_name,
            mpc_party=party,
            num_workers=self.test_num_containers,
            status=MPCInstanceStatus.STARTED,
        )
        pc_instance = self.create_sample_instance(
            status=PrivateComputationInstanceStatus.COMPUTATION_STARTED,
            role=PrivateComputationRole.PARTNER,
            instances=[running_mpc],
        )
        self.private_computation_service.instance_repository.read = MagicMock(
            return_value=pc_instance
        )

        # MPC instance reported back by mpc_service.stop_instance().
        canceled_mpc = PCSMPCInstance.create_instance(
            instance_id=mpc_id,
            game_name=game_name,
            mpc_party=party,
            num_workers=self.test_num_containers,
            status=MPCInstanceStatus.CANCELED,
        )
        self.private_computation_service.mpc_svc.stop_instance = MagicMock(
            return_value=canceled_mpc
        )
        self.private_computation_service.mpc_svc.instance_repository.read = MagicMock(
            return_value=canceled_mpc
        )

        # Cancelling must complete without raising.
        pc_instance = self.private_computation_service.cancel_current_stage(
            instance_id=self.test_private_computation_id,
        )

        # The returned PC instance carries the failed computation status.
        self.assertEqual(
            PrivateComputationInstanceStatus.COMPUTATION_FAILED,
            pc_instance.status,
        )
Example #4
0
 def setUp(self) -> None:
     """Create a repository rooted at the CWD plus one sample MPC instance.

     `self.test_mpc_instance` is shared by the other tests of this case.
     """
     instance_id = self._get_random_id()
     self.repo = LocalPrivateComputationInstanceRepository("./")
     self.test_mpc_instance = PCSMPCInstance.create_instance(
         instance_id=instance_id,
         game_name="conversion_lift",
         mpc_party=MPCParty.SERVER,
         num_workers=2,
     )
Example #5
0
    async def run_async(
        self,
        pc_instance: PrivateComputationInstance,
        server_ips: Optional[List[str]] = None,
    ) -> PrivateComputationInstance:
        """Runs the pcf2.0 based private aggregation stage

        Args:
            pc_instance: the private computation instance to run aggregation stage
            server_ips: only used by the partner role. These are the ip addresses of the publisher's containers.

        Returns:
            An updated version of pc_instance that stores an MPCInstance
        """
        game_args = self._get_compute_metrics_game_args(pc_instance)

        # One game-arg entry is spawned per container, so a rerun must supply
        # exactly one server ip per entry; len(game_args) can vary by run.
        if server_ips and len(server_ips) != len(game_args):
            raise ValueError(
                f"Unable to rerun MPC pcf2.0 based aggregation because there is a mismatch between the number of server ips given ({len(server_ips)}) and the number of containers ({len(game_args)}) to be spawned."
            )

        logging.info(
            "Starting to run MPC instance for pcf2.0 based aggregation stage.")

        stage_data = PrivateComputationServiceData.PCF2_AGGREGATION_STAGE_DATA
        binary_config = self._onedocker_binary_config_map[
            OneDockerBinaryNames.PCF2_AGGREGATION.value]
        mpc_instance = await create_and_start_mpc_instance(
            mpc_svc=self._mpc_service,
            instance_id=(
                pc_instance.instance_id
                + "_"
                + GameNames.PCF2_AGGREGATION.value
                + str(pc_instance.retry_counter)
            ),
            game_name=checked_cast(str, stage_data.game_name),
            mpc_party=map_private_computation_role_to_mpc_party(
                pc_instance.role),
            num_containers=len(game_args),
            binary_version=binary_config.binary_version,
            server_ips=server_ips,
            game_args=game_args,
            container_timeout=self._container_timeout,
            repository_path=binary_config.repository_path,
        )

        logging.info(
            "MPC instance started running for pcf2.0 based aggregation stage.")

        # Record the new MPC instance so later stages / status updates find it.
        pc_instance.instances.append(
            PCSMPCInstance.from_mpc_instance(mpc_instance))
        return pc_instance
Example #6
0
 def test_read_existing_instance(self):
     """read() must return the serialized instance stored on disk."""
     self.local_instance_repo._exist = MagicMock(return_value=True)
     serialized = self.mpc_instance.dumps_schema()
     expected_path = TEST_BASE_DIR.joinpath(TEST_INSTANCE_ID)
     with patch("builtins.open", mock_open(read_data=serialized)) as mock_file:
         # Sanity-check the mock_open wiring itself.
         self.assertEqual(open(expected_path).read().strip(), serialized)
         round_tripped = PCSMPCInstance.loads_schema(
             self.local_instance_repo.read(TEST_INSTANCE_ID))
         self.assertEqual(self.mpc_instance, round_tripped)
         mock_file.assert_called_with(expected_path, "r")
Example #7
0
 def setUp(self):
     """Build one CREATED MPC instance plus a local repository for the tests."""
     self.mpc_instance = PCSMPCInstance.create_instance(
         instance_id=TEST_INSTANCE_ID,
         game_name=TEST_GAME_NAME,
         mpc_party=TEST_MPC_PARTY,
         num_workers=TEST_NUM_WORKERS,
         server_ips=TEST_SERVER_IPS,
         status=MPCInstanceStatus.CREATED,
         game_args=TEST_GAME_ARGS,
     )
     self.local_instance_repo = LocalInstanceRepository(TEST_BASE_DIR)
    def test_get_status_from_stage(self) -> None:
        """_update_instance derives the PC status from the latest stage instance."""
        # A FAILED shard-aggregator MPC stage must map to AGGREGATION_FAILED.
        failed_mpc = PCSMPCInstance.create_instance(
            instance_id="test_mpc_id",
            game_name=GameNames.SHARD_AGGREGATOR.value,
            mpc_party=MPCParty.SERVER,
            num_workers=2,
            status=MPCInstanceStatus.FAILED,
        )
        pc_instance = self.create_sample_instance(
            PrivateComputationInstanceStatus.AGGREGATION_STARTED,
            instances=[failed_mpc],
        )
        self.private_computation_service.mpc_svc.update_instance = MagicMock(
            return_value=failed_mpc
        )
        self.assertEqual(
            PrivateComputationInstanceStatus.AGGREGATION_FAILED,
            self.private_computation_service._update_instance(
                pc_instance).status,
        )

        # A COMPLETED PID stage must map to ID_MATCHING_COMPLETED.
        completed_pid = PIDInstance(
            instance_id="test_pid_id",
            protocol=DEFAULT_PID_PROTOCOL,
            pid_role=PIDRole.PUBLISHER,
            num_shards=4,
            input_path="input",
            output_path="output",
            stages_containers={},
            stages_status={
                UnionPIDStage.PUBLISHER_RUN_PID: PIDStageStatus.COMPLETED
            },
            current_stage=UnionPIDStage.PUBLISHER_RUN_PID,
            status=PIDInstanceStatus.COMPLETED,
        )
        pc_instance = self.create_sample_instance(
            PrivateComputationInstanceStatus.ID_MATCHING_STARTED,
            instances=[completed_pid],
        )
        self.private_computation_service.pid_svc.update_instance = MagicMock(
            return_value=completed_pid
        )
        self.assertEqual(
            PrivateComputationInstanceStatus.ID_MATCHING_COMPLETED,
            self.private_computation_service._update_instance(
                pc_instance).status,
        )
    async def test_attribution_stage(self) -> None:
        """run_async must append the started MPC instance to the PC instance."""
        pc_instance = self._create_pc_instance()
        expected_mpc = PCSMPCInstance.create_instance(
            instance_id=pc_instance.instance_id + "_decoupled_attribution0",
            game_name=GameNames.DECOUPLED_ATTRIBUTION.value,
            mpc_party=MPCParty.CLIENT,
            num_workers=pc_instance.num_mpc_containers,
        )
        self.mock_mpc_svc.start_instance_async = AsyncMock(
            return_value=expected_mpc
        )

        # One (fake, TEST-NET) publisher ip per MPC container.
        server_ips = [
            f"192.0.2.{i}" for i in range(pc_instance.num_mpc_containers)
        ]
        await self.stage_svc.run_async(pc_instance, server_ips)

        self.assertEqual(expected_mpc, pc_instance.instances[0])
def gen_dummy_mpc_instance() -> PCSMPCInstance:
    """Creates a dummy mpc instance to be used in unit tests"""
    # Single-container game arguments, hoisted for readability.
    lift_game_args = {
        "input_base_path": "https://bucket.s3.us-west-2.amazonaws.com/lift/partner/partner_instance_1638998680_0_out_dir/data_processing_stage/out.csv",
        "output_base_path": "https://bucket.s3.us-west-2.amazonaws.com/lift/partner/partner_instance_1638998680_0_out_dir/compute_stage/out.json",
        "num_files": 40,
        "concurrency": 4,
        "file_start_index": 0,
    }
    return PCSMPCInstance.create_instance(
        instance_id="mpc_instance_id",
        game_name="lift",
        mpc_party=MPCParty.SERVER,
        num_workers=1,
        server_ips=["10.0.10.242"],
        containers=[gen_dummy_container_instance()],
        status=MPCInstanceStatus.COMPLETED,
        game_args=[lift_game_args],
    )
Example #11
0
def get_updated_pc_status_mpc_game(
    private_computation_instance: PrivateComputationInstance,
    mpc_svc: MPCService,
) -> PrivateComputationInstanceStatus:
    """Updates the MPCInstances and gets latest PrivateComputationInstance status

    Arguments:
        private_computation_instance: The PC instance that is being updated
        mpc_svc: Used to update MPC instances stored on private_computation_instance

    Returns:
        The latest status for private_computation_instance
    """
    status = private_computation_instance.status
    instances = private_computation_instance.instances
    # No stage has run yet: nothing to refresh.
    if not instances:
        return status
    # Only the last stage/instance needs updating.
    last_instance = instances[-1]
    if not isinstance(last_instance, MPCInstance):
        return status

    # MPC service must call update_instance to fetch the newest container
    # information in case the containers are still running.
    refreshed = PCSMPCInstance.from_mpc_instance(
        mpc_svc.update_instance(last_instance.instance_id))
    instances[-1] = refreshed

    current_stage = private_computation_instance.current_stage
    mpc_status = refreshed.status
    if mpc_status is MPCInstanceStatus.STARTED:
        status = current_stage.started_status
    elif mpc_status is MPCInstanceStatus.COMPLETED:
        status = current_stage.completed_status
    elif mpc_status in (MPCInstanceStatus.FAILED, MPCInstanceStatus.CANCELED):
        status = current_stage.failed_status
    # Any other MPC status leaves the PC status untouched.
    return status
    async def test_compute_metrics(self) -> None:
        """run_async must append the started pcf2-lift MPC instance."""
        pc_instance = self._create_pc_instance()
        expected_mpc = PCSMPCInstance.create_instance(
            instance_id=pc_instance.instance_id + "_pcf2_lift0",
            game_name=GameNames.PCF2_LIFT.value,
            mpc_party=MPCParty.CLIENT,
            num_workers=pc_instance.num_mpc_containers,
        )
        self.mock_mpc_svc.start_instance_async = AsyncMock(
            return_value=expected_mpc
        )

        # One (fake, TEST-NET) publisher ip per MPC container.
        server_ips = [
            f"192.0.2.{i}" for i in range(pc_instance.num_mpc_containers)
        ]
        await self.stage_svc.run_async(pc_instance, server_ips)

        self.assertEqual(expected_mpc, pc_instance.instances[0])
Example #13
0
 def update(self, instance: MPCInstance) -> None:
     """Persist `instance` after converting it to the PCS flavor."""
     pcs_instance = PCSMPCInstance.from_mpc_instance(instance)
     self.repo.update(pcs_instance)
Example #14
0
 def read(self, instance_id: str) -> PCSMPCInstance:
     """Load and deserialize the instance stored under `instance_id`."""
     serialized = self.repo.read(instance_id)
     return PCSMPCInstance.loads_schema(serialized)
Example #15
0
    async def run_async(
        self,
        pc_instance: PrivateComputationInstance,
        server_ips: Optional[List[str]] = None,
    ) -> PrivateComputationInstance:
        """Runs the private computation aggregate metrics stage

        Args:
            pc_instance: the private computation instance to run aggregate metrics with
            server_ips: only used by the partner role. These are the ip addresses of the publisher's containers.

        Returns:
            An updated version of pc_instance that stores an MPCInstance
        """

        num_shards = (pc_instance.num_mpc_containers *
                      pc_instance.num_files_per_mpc_container)

        # TODO T101225989: map aggregation_type from the compute stage to metrics_format_type
        metrics_format_type = (
            "lift" if pc_instance.game_type is PrivateComputationGameType.LIFT
            else "ad_object")

        binary_name = OneDockerBinaryNames.SHARD_AGGREGATOR.value
        binary_config = self._onedocker_binary_config_map[binary_name]

        # Get output path of previous stage depending on what stage flow we are using
        # Using "PrivateComputationDecoupledStageFlow" instead of PrivateComputationDecoupledStageFlow.get_cls_name() to avoid
        # circular import error.
        if pc_instance.get_flow_cls_name in [
                "PrivateComputationDecoupledStageFlow",
                "PrivateComputationDecoupledLocalTestStageFlow",
        ]:
            input_stage_path = pc_instance.decoupled_aggregation_stage_output_base_path
        elif pc_instance.get_flow_cls_name in [
                "PrivateComputationPCF2StageFlow",
                "PrivateComputationPCF2LocalTestStageFlow",
        ]:
            input_stage_path = pc_instance.pcf2_aggregation_stage_output_base_path
        elif pc_instance.get_flow_cls_name == "PrivateComputationPCF2LiftStageFlow":
            input_stage_path = pc_instance.pcf2_lift_stage_output_base_path
        else:
            input_stage_path = pc_instance.compute_stage_output_base_path

        if self._log_cost_to_s3:
            run_name = pc_instance.instance_id

            if pc_instance.post_processing_data:
                pc_instance.post_processing_data.s3_cost_export_output_paths.add(
                    f"sa-logs/{run_name}_{pc_instance.role.value.title()}.json",
                )
        else:
            run_name = ""

        if self._is_validating:
            # num_containers_real_data is the number of containers processing real data
            # synthetic data is processed by a dedicated extra container, and this container is always the last container,
            # hence synthetic_data_shard_start_index = num_real_data_shards
            # each of the containers, processing real or synthetic data, processes the same number of shards due to our resharding mechanism
            # num_shards representing the total number of shards which is equal to num_real_data_shards + num_synthetic_data_shards
            # hence, when num_containers_real_data and num_shards are given, num_synthetic_data_shards = num_shards / (num_containers_real_data + 1)
            num_containers_real_data = pc_instance.num_pid_containers
            if num_containers_real_data is None:
                raise ValueError("num_containers_real_data is None")
            num_synthetic_data_shards = num_shards // (
                num_containers_real_data + 1)
            num_real_data_shards = num_shards - num_synthetic_data_shards
            synthetic_data_shard_start_index = num_real_data_shards

            # One container for the real data shards, one for the synthetic shards.
            game_args = [
                {
                    "input_base_path": input_stage_path,
                    "num_shards": num_real_data_shards,
                    "metrics_format_type": metrics_format_type,
                    "output_path":
                    pc_instance.shard_aggregate_stage_output_path,
                    "first_shard_index": 0,
                    "threshold": pc_instance.k_anonymity_threshold,
                    "run_name": run_name,
                    "log_cost": self._log_cost_to_s3,
                },
                {
                    "input_base_path": input_stage_path,
                    "num_shards": num_synthetic_data_shards,
                    "metrics_format_type": metrics_format_type,
                    "output_path":
                    pc_instance.shard_aggregate_stage_output_path +
                    "_synthetic_data_shards",
                    "first_shard_index": synthetic_data_shard_start_index,
                    "threshold": pc_instance.k_anonymity_threshold,
                    "run_name": run_name,
                    "log_cost": self._log_cost_to_s3,
                },
            ]
        else:
            # Single container aggregating all shards.
            game_args = [
                {
                    "input_base_path": input_stage_path,
                    "metrics_format_type": metrics_format_type,
                    "num_shards": num_shards,
                    "output_path":
                    pc_instance.shard_aggregate_stage_output_path,
                    "threshold": pc_instance.k_anonymity_threshold,
                    "run_name": run_name,
                    "log_cost": self._log_cost_to_s3,
                },
            ]

        # We should only export visibility to scribe when it's set
        # (previously duplicated in both branches above).
        if pc_instance.result_visibility is not ResultVisibility.PUBLIC:
            result_visibility = int(pc_instance.result_visibility)
            for arg in game_args:
                arg["visibility"] = result_visibility

        # Create and start MPC instance. One container per game_args entry.
        # BUGFIX: the validating branch previously omitted
        # repository_path=binary_config.repository_path while the
        # non-validating branch passed it; both paths now pass it.
        mpc_instance = await create_and_start_mpc_instance(
            mpc_svc=self._mpc_service,
            instance_id=pc_instance.instance_id + "_aggregate_shards" +
            str(pc_instance.retry_counter),
            game_name=GameNames.SHARD_AGGREGATOR.value,
            mpc_party=map_private_computation_role_to_mpc_party(
                pc_instance.role),
            num_containers=len(game_args),
            binary_version=binary_config.binary_version,
            server_ips=server_ips,
            game_args=game_args,
            container_timeout=self._container_timeout,
            repository_path=binary_config.repository_path,
        )
        # Push MPC instance to PrivateComputationInstance.instances and update PL Instance status
        pc_instance.instances.append(
            PCSMPCInstance.from_mpc_instance(mpc_instance))
        return pc_instance
    def test_update_instance(self) -> None:
        """End-to-end check of PrivateComputationService.update_instance:
        a PID-stage instance, then an MPC-stage instance, plus the
        end_ts/elapsed_time bookkeeping at the end.

        NOTE(review): the exact-equality assertions on time.time() arithmetic
        (the "+ 1" and "+ 2" offsets near the bottom) can only hold if the
        clock is patched to advance deterministically per call — presumably
        done in setUp/class fixtures; confirm before relying on this.
        """
        test_pid_id = self.test_private_computation_id + "_id_match"
        test_pid_role = PIDRole.PUBLISHER
        test_input_path = "pid_in"
        test_output_path = "pid_out"
        # create one PID instance to be put into PrivateComputationInstance
        pid_instance = PIDInstance(
            instance_id=test_pid_id,
            protocol=DEFAULT_PID_PROTOCOL,
            pid_role=test_pid_role,
            num_shards=self.test_num_containers,
            input_path=test_input_path,
            output_path=test_output_path,
            status=PIDInstanceStatus.STARTED,
        )

        private_computation_instance = self.create_sample_instance(
            status=PrivateComputationInstanceStatus.ID_MATCHING_STARTED,
            instances=[pid_instance],
        )

        # Simulate the PID stage finishing: same object, mutated in place.
        updated_pid_instance = pid_instance
        updated_pid_instance.status = PIDInstanceStatus.COMPLETED
        updated_pid_instance.current_stage = UnionPIDStage.PUBLISHER_RUN_PID
        updated_pid_instance.stages_status = {
            UnionPIDStage.PUBLISHER_RUN_PID: PIDStageStatus.COMPLETED
        }

        self.private_computation_service.pid_svc.update_instance = MagicMock(
            return_value=updated_pid_instance)

        self.private_computation_service.instance_repository.read = MagicMock(
            return_value=private_computation_instance)

        # end_ts should not be calculated until the instance run is complete.
        self.assertEqual(0, private_computation_instance.end_ts)

        # call update on the PrivateComputationInstance
        updated_instance = self.private_computation_service.update_instance(
            instance_id=self.test_private_computation_id)

        # check update instance called on the right pid instance
        # pyre-fixme[16]: Callable `update_instance` has no attribute `assert_called`.
        self.private_computation_service.pid_svc.update_instance.assert_called(
        )
        self.assertEqual(
            test_pid_id,
            # pyre-fixme[16]: Callable `update_instance` has no attribute `call_args`.
            self.private_computation_service.pid_svc.update_instance.
            call_args[0][0],
        )

        # check update instance called on the right private lift instance
        # pyre-fixme[16]: Callable `update` has no attribute `assert_called`.
        self.private_computation_service.instance_repository.update.assert_called(
        )
        self.assertEqual(
            private_computation_instance,
            # pyre-fixme[16]: Callable `update` has no attribute `call_args`.
            self.private_computation_service.instance_repository.update.
            call_args[0][0],
        )

        # check updated_instance has new status
        self.assertEqual(
            PrivateComputationInstanceStatus.ID_MATCHING_COMPLETED,
            updated_instance.status,
        )

        # create one MPC instance to be put into PrivateComputationInstance
        test_mpc_id = "test_mpc_id"
        mpc_instance = PCSMPCInstance.create_instance(
            instance_id=test_mpc_id,
            game_name=GameNames.LIFT.value,
            mpc_party=MPCParty.SERVER,
            num_workers=2,
        )

        private_computation_instance = self.create_sample_instance(
            status=PrivateComputationInstanceStatus.COMPUTATION_STARTED,
            instances=[mpc_instance],
        )

        # Simulate the MPC stage finishing: same object, mutated in place.
        updated_mpc_instance = mpc_instance
        updated_mpc_instance.status = MPCInstanceStatus.COMPLETED
        self.private_computation_service.mpc_svc.update_instance = MagicMock(
            return_value=updated_mpc_instance)

        self.private_computation_service.instance_repository.read = MagicMock(
            return_value=private_computation_instance)
        # call update on the PrivateComputationInstance
        updated_instance = self.private_computation_service.update_instance(
            instance_id=self.test_private_computation_id)

        # check update instance called on the right mpc instance
        # pyre-fixme[16]: Callable `update_instance` has no attribute `assert_called`.
        self.private_computation_service.mpc_svc.update_instance.assert_called(
        )
        self.assertEqual(
            test_mpc_id,
            # pyre-fixme[16]: Callable `update_instance` has no attribute `call_args`.
            self.private_computation_service.mpc_svc.update_instance.
            call_args[0][0],
        )

        # check update instance called on the right private lift instance
        self.private_computation_service.instance_repository.update.assert_called(
        )
        self.assertEqual(
            private_computation_instance,
            self.private_computation_service.instance_repository.update.
            call_args[0][0],
        )

        # check updated_instance has new status
        self.assertEqual(
            PrivateComputationInstanceStatus.COMPUTATION_COMPLETED,
            updated_instance.status,
        )

        # elapsed_time should report current running time if the run is incomplete.
        self.assertEqual(
            time.time() - private_computation_instance.creation_ts + 1,
            private_computation_instance.elapsed_time,
        )

        # Drive the instance to the terminal status of its stage flow;
        # end_ts should then be stamped from the (patched) clock.
        expected_end_ts = time.time() + 2
        private_computation_instance.update_status(
            private_computation_instance.stage_flow.get_last_stage().
            completed_status,
            logging.getLogger(),
        )
        self.assertEqual(expected_end_ts, private_computation_instance.end_ts)
        expected_elapsed_time = (private_computation_instance.end_ts -
                                 private_computation_instance.creation_ts)
        self.assertEqual(
            expected_elapsed_time,
            private_computation_instance.elapsed_time,
        )