Пример #1
0
 def test_update(self) -> None:
     """Updating a stored instance persists its new list of MPC instances.

     A private computation instance is written to the repository, its
     ``instances`` list is replaced with a two-element list and saved via
     ``repo.update``, and the re-read copy must carry the new list. The stored
     instance is deleted in a ``finally`` clause so that a failing assertion
     (or a failing update/read) does not leave a stale record behind in the
     repository.
     """
     instance_id = self._get_random_id()
     test_update_private_computation_instance = PrivateComputationInstance(
         instance_id=instance_id,
         role=PrivateComputationRole.PUBLISHER,
         instances=[self.test_mpc_instance],
         status=PrivateComputationInstanceStatus.CREATED,
         status_update_ts=1600000000,
         num_files_per_mpc_container=40,
         game_type=PrivateComputationGameType.LIFT,
         input_path="in",
         output_dir="out",
         num_pid_containers=4,
         num_mpc_containers=4,
         concurrency=1,
     )
     self.repo.create(test_update_private_computation_instance)
     # Everything after a successful create is guarded so the cleanup below
     # always runs; otherwise a failure would skip the delete and leak state.
     try:
         # Create a new MPC instance to be added to instances
         test_mpc_instance_new = PCSMPCInstance.create_instance(
             instance_id=instance_id,
             game_name="aggregation",
             mpc_party=MPCParty.SERVER,
             num_workers=1,
         )
         instances_new = [self.test_mpc_instance, test_mpc_instance_new]
         # Update instances
         test_update_private_computation_instance.instances = instances_new
         self.repo.update(test_update_private_computation_instance)
         # Assert instances is updated
         self.assertEqual(self.repo.read(instance_id).instances, instances_new)
     finally:
         self.repo.delete(instance_id)
    def test_cancel_current_stage(self) -> None:
        """Cancelling a running computation stage leaves the instance COMPUTATION_FAILED.

        The repository read is mocked to return an instance in
        COMPUTATION_STARTED holding a STARTED MPC child. The MPC service is
        mocked so that both stopping and re-reading the MPC instance yield a
        CANCELED copy. Cancelling must not raise, and the returned instance must
        report COMPUTATION_FAILED.
        """
        service = self.private_computation_service
        mpc_id = self.test_private_computation_id + "_compute_metrics"

        def build_mpc(status):
            # The two MPC fixtures are identical apart from their status.
            return PCSMPCInstance.create_instance(
                instance_id=mpc_id,
                game_name=GameNames.LIFT.value,
                mpc_party=MPCParty.CLIENT,
                num_workers=self.test_num_containers,
                status=status,
            )

        # Instance that cancel_current_stage loads from the repository first.
        running_mpc = build_mpc(MPCInstanceStatus.STARTED)
        stored_instance = self.create_sample_instance(
            status=PrivateComputationInstanceStatus.COMPUTATION_STARTED,
            role=PrivateComputationRole.PARTNER,
            instances=[running_mpc],
        )
        service.instance_repository.read = MagicMock(return_value=stored_instance)

        # What the MPC service reports once the stage has been stopped.
        stopped_mpc = build_mpc(MPCInstanceStatus.CANCELED)
        service.mpc_svc.stop_instance = MagicMock(return_value=stopped_mpc)
        service.mpc_svc.instance_repository.read = MagicMock(return_value=stopped_mpc)

        # No exception is expected from the cancel call itself.
        cancelled = service.cancel_current_stage(
            instance_id=self.test_private_computation_id,
        )

        expected_status = PrivateComputationInstanceStatus.COMPUTATION_FAILED
        self.assertEqual(expected_status, cancelled.status)
Пример #3
0
 def setUp(self) -> None:
     """Create a local repository rooted at "./" and a two-worker MPC fixture."""
     random_id = self._get_random_id()
     self.repo = LocalPrivateComputationInstanceRepository("./")
     mpc_params = dict(
         instance_id=random_id,
         game_name="conversion_lift",
         mpc_party=MPCParty.SERVER,
         num_workers=2,
     )
     self.test_mpc_instance = PCSMPCInstance.create_instance(**mpc_params)
Пример #4
0
 def setUp(self):
     """Build a CREATED MPC instance from the TEST_* constants and a local repo."""
     mpc_params = dict(
         instance_id=TEST_INSTANCE_ID,
         game_name=TEST_GAME_NAME,
         mpc_party=TEST_MPC_PARTY,
         num_workers=TEST_NUM_WORKERS,
         server_ips=TEST_SERVER_IPS,
         status=MPCInstanceStatus.CREATED,
         game_args=TEST_GAME_ARGS,
     )
     self.mpc_instance = PCSMPCInstance.create_instance(**mpc_params)
     self.local_instance_repo = LocalInstanceRepository(TEST_BASE_DIR)
    def test_get_status_from_stage(self) -> None:
        """_update_instance derives the PC status from its current child instance.

        A FAILED shard-aggregator MPC instance must map to AGGREGATION_FAILED.
        A COMPLETED publisher PID run must map to ID_MATCHING_COMPLETED.
        """
        service = self.private_computation_service

        # --- MPC stage: a failed aggregation run ---
        failed_aggregation = PCSMPCInstance.create_instance(
            instance_id="test_mpc_id",
            game_name=GameNames.SHARD_AGGREGATOR.value,
            mpc_party=MPCParty.SERVER,
            num_workers=2,
            status=MPCInstanceStatus.FAILED,
        )
        aggregating_pc = self.create_sample_instance(
            PrivateComputationInstanceStatus.AGGREGATION_STARTED,
            instances=[failed_aggregation],
        )
        service.mpc_svc.update_instance = MagicMock(return_value=failed_aggregation)
        mpc_derived_status = service._update_instance(aggregating_pc).status
        self.assertEqual(
            PrivateComputationInstanceStatus.AGGREGATION_FAILED,
            mpc_derived_status,
        )

        # --- PID stage: a completed publisher run ---
        run_pid = UnionPIDStage.PUBLISHER_RUN_PID
        finished_pid = PIDInstance(
            instance_id="test_pid_id",
            protocol=DEFAULT_PID_PROTOCOL,
            pid_role=PIDRole.PUBLISHER,
            num_shards=4,
            input_path="input",
            output_path="output",
            stages_containers={},
            stages_status={run_pid: PIDStageStatus.COMPLETED},
            current_stage=run_pid,
            status=PIDInstanceStatus.COMPLETED,
        )
        matching_pc = self.create_sample_instance(
            PrivateComputationInstanceStatus.ID_MATCHING_STARTED,
            instances=[finished_pid],
        )
        service.pid_svc.update_instance = MagicMock(return_value=finished_pid)
        pid_derived_status = service._update_instance(matching_pc).status
        self.assertEqual(
            PrivateComputationInstanceStatus.ID_MATCHING_COMPLETED,
            pid_derived_status,
        )
    async def test_attribution_stage(self) -> None:
        """Running the decoupled-attribution stage records the MPC instance it starts."""
        pc_instance = self._create_pc_instance()
        worker_count = pc_instance.num_mpc_containers
        attribution_id = pc_instance.instance_id + "_decoupled_attribution0"

        expected_mpc = PCSMPCInstance.create_instance(
            instance_id=attribution_id,
            game_name=GameNames.DECOUPLED_ATTRIBUTION.value,
            mpc_party=MPCParty.CLIENT,
            num_workers=worker_count,
        )
        # The stage service will receive this instance when it starts the game.
        self.mock_mpc_svc.start_instance_async = AsyncMock(return_value=expected_mpc)

        # One documentation-range (192.0.2.0/24) address per MPC container.
        server_ips = [f"192.0.2.{host}" for host in range(worker_count)]
        await self.stage_svc.run_async(pc_instance, server_ips)

        first_child = pc_instance.instances[0]
        self.assertEqual(expected_mpc, first_child)
def gen_dummy_mpc_instance() -> PCSMPCInstance:
    """Return a COMPLETED, single-worker "lift" MPC instance for unit tests.

    The instance has one server IP, one dummy container (from
    ``gen_dummy_container_instance``) and a single game-args entry pointing at
    example S3 input/output paths.
    """
    lift_game_args = {
        "input_base_path":
        "https://bucket.s3.us-west-2.amazonaws.com/lift/partner/partner_instance_1638998680_0_out_dir/data_processing_stage/out.csv",
        "output_base_path":
        "https://bucket.s3.us-west-2.amazonaws.com/lift/partner/partner_instance_1638998680_0_out_dir/compute_stage/out.json",
        "num_files": 40,
        "concurrency": 4,
        "file_start_index": 0,
    }
    dummy_containers = [gen_dummy_container_instance()]

    instance = PCSMPCInstance.create_instance(
        instance_id="mpc_instance_id",
        game_name="lift",
        mpc_party=MPCParty.SERVER,
        num_workers=1,
        server_ips=["10.0.10.242"],
        containers=dummy_containers,
        status=MPCInstanceStatus.COMPLETED,
        game_args=[lift_game_args],
    )
    return instance
    async def test_compute_metrics(self) -> None:
        """Running the PCF2 lift compute stage stores the MPC instance it started."""
        pc_instance = self._create_pc_instance()
        container_count = pc_instance.num_mpc_containers

        started_mpc = PCSMPCInstance.create_instance(
            instance_id=pc_instance.instance_id + "_pcf2_lift0",
            game_name=GameNames.PCF2_LIFT.value,
            mpc_party=MPCParty.CLIENT,
            num_workers=container_count,
        )
        self.mock_mpc_svc.start_instance_async = AsyncMock(return_value=started_mpc)

        # Build one placeholder server address per container.
        server_ips = []
        for index in range(container_count):
            server_ips.append(f"192.0.2.{index}")

        await self.stage_svc.run_async(pc_instance, server_ips)

        recorded_mpc = pc_instance.instances[0]
        self.assertEqual(started_mpc, recorded_mpc)
    def test_update_instance(self) -> None:
        """Check that update_instance refreshes the current child instance and PC status.

        Two scenarios run back to back:
          1. a PID child: ID_MATCHING_STARTED -> ID_MATCHING_COMPLETED;
          2. an MPC child: COMPUTATION_STARTED -> COMPUTATION_COMPLETED.
        For each, the child service's ``update_instance`` and the repository's
        ``update`` must be called with the expected objects. The test then checks
        the ``end_ts`` / ``elapsed_time`` bookkeeping.

        NOTE(review): the time-based assertions at the end compare against fresh
        ``time.time()`` calls using exact equality and fixed +1/+2 offsets. They
        can only pass if ``time.time`` is patched to a deterministic counter
        (e.g. ``mock.patch`` with an incrementing ``side_effect``). No such patch
        is visible here — confirm it is applied where this test is defined.
        Because of this, the number and order of ``time.time()`` calls in this
        test matter.
        """
        test_pid_id = self.test_private_computation_id + "_id_match"
        test_pid_role = PIDRole.PUBLISHER
        test_input_path = "pid_in"
        test_output_path = "pid_out"
        # create one PID instance to be put into PrivateComputationInstance
        pid_instance = PIDInstance(
            instance_id=test_pid_id,
            protocol=DEFAULT_PID_PROTOCOL,
            pid_role=test_pid_role,
            num_shards=self.test_num_containers,
            input_path=test_input_path,
            output_path=test_output_path,
            status=PIDInstanceStatus.STARTED,
        )

        private_computation_instance = self.create_sample_instance(
            status=PrivateComputationInstanceStatus.ID_MATCHING_STARTED,
            instances=[pid_instance],
        )

        # NOTE: this is an alias, not a copy. The mutations below also change the
        # PID instance already stored inside private_computation_instance.
        updated_pid_instance = pid_instance
        updated_pid_instance.status = PIDInstanceStatus.COMPLETED
        updated_pid_instance.current_stage = UnionPIDStage.PUBLISHER_RUN_PID
        updated_pid_instance.stages_status = {
            UnionPIDStage.PUBLISHER_RUN_PID: PIDStageStatus.COMPLETED
        }

        # The PID service's refresh returns the completed PID instance.
        self.private_computation_service.pid_svc.update_instance = MagicMock(
            return_value=updated_pid_instance)

        # The repository hands back the prepared PC instance on read.
        self.private_computation_service.instance_repository.read = MagicMock(
            return_value=private_computation_instance)

        # end_ts should not be calculated until the instance run is complete.
        self.assertEqual(0, private_computation_instance.end_ts)

        # call update on the PrivateComputationInstance
        updated_instance = self.private_computation_service.update_instance(
            instance_id=self.test_private_computation_id)

        # check update instance called on the right pid instance
        # pyre-fixme[16]: Callable `update_instance` has no attribute `assert_called`.
        self.private_computation_service.pid_svc.update_instance.assert_called(
        )
        self.assertEqual(
            test_pid_id,
            # pyre-fixme[16]: Callable `update_instance` has no attribute `call_args`.
            self.private_computation_service.pid_svc.update_instance.
            call_args[0][0],
        )

        # check update instance called on the right private lift instance
        # NOTE(review): instance_repository.update is not replaced in this test.
        # This assumes it is already a mock (presumably installed in setUp).
        # pyre-fixme[16]: Callable `update` has no attribute `assert_called`.
        self.private_computation_service.instance_repository.update.assert_called(
        )
        self.assertEqual(
            private_computation_instance,
            # pyre-fixme[16]: Callable `update` has no attribute `call_args`.
            self.private_computation_service.instance_repository.update.
            call_args[0][0],
        )

        # check updated_instance has new status
        self.assertEqual(
            PrivateComputationInstanceStatus.ID_MATCHING_COMPLETED,
            updated_instance.status,
        )

        # create one MPC instance to be put into PrivateComputationInstance
        test_mpc_id = "test_mpc_id"
        mpc_instance = PCSMPCInstance.create_instance(
            instance_id=test_mpc_id,
            game_name=GameNames.LIFT.value,
            mpc_party=MPCParty.SERVER,
            num_workers=2,
        )

        private_computation_instance = self.create_sample_instance(
            status=PrivateComputationInstanceStatus.COMPUTATION_STARTED,
            instances=[mpc_instance],
        )

        # NOTE: alias again. Marking it COMPLETED also changes the MPC instance
        # held by the PC instance created just above.
        updated_mpc_instance = mpc_instance
        updated_mpc_instance.status = MPCInstanceStatus.COMPLETED
        self.private_computation_service.mpc_svc.update_instance = MagicMock(
            return_value=updated_mpc_instance)

        self.private_computation_service.instance_repository.read = MagicMock(
            return_value=private_computation_instance)
        # call update on the PrivateComputationInstance
        updated_instance = self.private_computation_service.update_instance(
            instance_id=self.test_private_computation_id)

        # check update instance called on the right mpc instance
        # pyre-fixme[16]: Callable `update_instance` has no attribute `assert_called`.
        self.private_computation_service.mpc_svc.update_instance.assert_called(
        )
        self.assertEqual(
            test_mpc_id,
            # pyre-fixme[16]: Callable `update_instance` has no attribute `call_args`.
            self.private_computation_service.mpc_svc.update_instance.
            call_args[0][0],
        )

        # check update instance called on the right private lift instance
        self.private_computation_service.instance_repository.update.assert_called(
        )
        self.assertEqual(
            private_computation_instance,
            self.private_computation_service.instance_repository.update.
            call_args[0][0],
        )

        # check updated_instance has new status
        self.assertEqual(
            PrivateComputationInstanceStatus.COMPUTATION_COMPLETED,
            updated_instance.status,
        )

        # elapsed_time should report current running time if the run is incomplete.
        # The "+ 1" presumably accounts for the extra time.time() call made inside
        # elapsed_time under a counting time mock — confirm.
        self.assertEqual(
            time.time() - private_computation_instance.creation_ts + 1,
            private_computation_instance.elapsed_time,
        )

        # Move the instance to the last stage's completed status, which should
        # stamp end_ts. The "+ 2" presumably offsets the further time.time()
        # calls made before/inside update_status under a counting mock — confirm.
        expected_end_ts = time.time() + 2
        private_computation_instance.update_status(
            private_computation_instance.stage_flow.get_last_stage().
            completed_status,
            logging.getLogger(),
        )
        self.assertEqual(expected_end_ts, private_computation_instance.end_ts)
        # Once end_ts is set, elapsed_time is expected to be the fixed span
        # end_ts - creation_ts.
        expected_elapsed_time = (private_computation_instance.end_ts -
                                 private_computation_instance.creation_ts)
        self.assertEqual(
            expected_elapsed_time,
            private_computation_instance.elapsed_time,
        )