Example #1
0
    def _update_instance(
        self, private_computation_instance: PrivateComputationInstance
    ) -> PrivateComputationInstance:
        """Refresh an instance's status from its current stage and persist it.

        The stage service for the instance's current stage is asked for the
        latest status; that status is applied to the instance, and the
        instance is then written through ``self.instance_repository``.

        Args:
            private_computation_instance: the instance whose status should be refreshed.

        Returns:
            The same instance object that was passed in, after its status has
            been updated and it has been handed to the repository.
        """
        stage = private_computation_instance.current_stage
        stage_svc = stage.get_stage_service(self.stage_service_args)
        # Label the value with the literal name "stage"; interpolating the stage
        # on both sides of "=" only repeated the value and hid what it was.
        self.logger.info(f"Updating instance | stage={stage!r}")
        new_status = stage_svc.get_status(private_computation_instance)
        private_computation_instance.update_status(new_status, self.logger)
        self.instance_repository.update(private_computation_instance)
        self.logger.info(
            f"Finished updating instance: {private_computation_instance.instance_id}"
        )

        return private_computation_instance
    async def run_async(
        self,
        pc_instance: PrivateComputationInstance,
        server_ips: Optional[List[str]] = None,
    ) -> PrivateComputationInstance:
        """Run every not-yet-completed post processing handler for pc_instance.

        Post processing handlers are meant to run once final results exist
        (for example to download results from cloud storage or send a
        notification). A new PostProcessingInstance is created, appended to
        pc_instance.instances, and the handlers are awaited concurrently.

        When the most recent entry in pc_instance.instances is a
        PostProcessingInstance covering exactly the same handler names, its
        handler statuses are copied into the new instance so that handlers
        already marked COMPLETED are not run again.

        Args:
            pc_instance: the private computation instance to run post processing handlers with
            server_ips: accepted as part of the stage interface (documented upstream as
                partner-only publisher container IPs); it is not read in this method.

        Returns:
            The same pc_instance, now holding the new post processing instance. Its
            status is advanced to the current stage's completed status unless the
            post processing instance ended up FAILED.
        """
        handlers = self._post_processing_handlers

        # Reuse handler statuses from the previous attempt, but only when it was
        # a post processing run over the very same set of handler names.
        carried_statuses = None
        previous = pc_instance.instances[-1] if pc_instance.instances else None
        if (
            isinstance(previous, PostProcessingInstance)
            and previous.handler_statuses.keys() == handlers.keys()
        ):
            self._logger.info("Copying statuses from last instance")
            carried_statuses = previous.handler_statuses.copy()

        # The retry counter is part of the id, so each attempt gets its own id.
        new_instance_id = (
            pc_instance.instance_id
            + "_post_processing"
            + str(pc_instance.retry_counter)
        )
        post_processing_instance = PostProcessingInstance.create_instance(
            instance_id=new_instance_id,
            handlers=handlers,
            handler_statuses=carried_statuses,
            status=PostProcessingInstanceStatus.STARTED,
        )
        pc_instance.instances.append(post_processing_instance)

        # Schedule only the handlers that have not completed yet. A failing
        # handler is expected to flag the post processing instance (and the
        # pc_instance) as failed while it runs.
        pending = []
        for name, handler in handlers.items():
            current = post_processing_instance.handler_statuses[name]
            if current != PostProcessingHandlerStatus.COMPLETED:
                pending.append(
                    self._run_post_processing_handler(
                        pc_instance,
                        post_processing_instance,
                        name,
                        handler,
                    )
                )
        await asyncio.gather(*pending)

        # A FAILED status means at least one handler failed; leave everything as is.
        if post_processing_instance.status is PostProcessingInstanceStatus.FAILED:
            return pc_instance

        # Nothing failed, so every handler completed: mark the stage as done.
        post_processing_instance.status = PostProcessingInstanceStatus.COMPLETED
        pc_instance.update_status(
            pc_instance.current_stage.completed_status, self._logger
        )
        return pc_instance