def test_get_instance(self, mock_build_pcs) -> None:
    mock_build_pcs.return_value = self.mock_pcs
    get_instance(
        config=self.config,
        instance_id=self.test_instance_id,
        logger=MagicMock(),
    )
    self.mock_pcs.get_instance.assert_called_once_with(self.test_instance_id)
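
# A minimal sketch of the unittest scaffolding the test above assumes; the
# TestCase name, the patch target, and the attribute values are hypothetical,
# since the snippet shows only the test body.
from unittest import TestCase
from unittest.mock import MagicMock, patch


class TestPrivateComputationCli(TestCase):  # hypothetical class name
    def setUp(self) -> None:
        self.config = {"fake": "config"}  # placeholder config dict
        self.test_instance_id = "test_instance_123"  # placeholder id
        self.mock_pcs = MagicMock()  # stands in for the private computation service

    @patch("fbpcs.private_computation_cli.build_private_computation_service")  # hypothetical patch target
    def test_get_instance(self, mock_build_pcs) -> None:
        ...  # body as shown above
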
    def __init__(
        self,
        instance_id: str,
        config: Dict[str, Any],
        input_path: str,
        num_mpc_containers: int,
        num_pid_containers: int,
        logger: logging.Logger,
        game_type: PrivateComputationGameType,
        attribution_rule: Optional[AttributionRule] = None,
        aggregation_type: Optional[AggregationType] = None,
        concurrency: Optional[int] = None,
        num_files_per_mpc_container: Optional[int] = None,
        k_anonymity_threshold: Optional[int] = None,
        result_visibility: Optional[ResultVisibility] = None,
    ) -> None:
        super().__init__(instance_id, logger, PrivateComputationRole.PARTNER)
        self.config: Dict[str, Any] = config
        self.input_path: str = input_path
        self.output_dir: str = self.get_output_dir_from_input_path(input_path)
        # Try to fetch the instance from the instance repo; if it does not exist, create a new one.
        self.status: PrivateComputationInstanceStatus
        pc_instance: PrivateComputationInstance
        try:
            pc_instance = get_instance(self.config, self.instance_id,
                                       self.logger)
        except RuntimeError:
            self.logger.info(
                f"Creating new partner instance {self.instance_id}")
            pc_instance = create_instance(
                config=self.config,
                instance_id=self.instance_id,
                role=PrivateComputationRole.PARTNER,
                game_type=game_type,
                logger=self.logger,
                input_path=self.input_path,
                output_dir=self.output_dir,
                num_pid_containers=num_pid_containers,
                num_mpc_containers=num_mpc_containers,
                attribution_rule=attribution_rule,
                aggregation_type=aggregation_type,
                concurrency=concurrency,
                num_files_per_mpc_container=num_files_per_mpc_container,
                k_anonymity_threshold=k_anonymity_threshold,
                result_visibility=result_visibility,
            )

        self.status = pc_instance.status
        if self._need_override_input_path(pc_instance):
            update_input_path(self.config, self.instance_id, self.input_path,
                              self.logger)

        self.wait_valid_status(WAIT_VALID_STATUS_TIMEOUT)
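
# A hedged usage sketch for the partner wrapper whose __init__ is shown above.
# The class name PCPartnerInstance is hypothetical (the snippet omits the class
# header), and the instance id and input path below are placeholders.
import logging

config = ConfigYamlDict.from_file("/path/to/config.yml")  # as the CLI below loads it
partner = PCPartnerInstance(  # hypothetical class name
    instance_id="partner_instance_1",
    config=config,
    input_path="https://bucket/path/partner_input.csv",
    num_mpc_containers=2,
    num_pid_containers=2,
    logger=logging.getLogger(__name__),
    game_type=PrivateComputationGameType("LIFT"),  # conversion mirrors the CLI schema below
)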
def run_study(
    config: Dict[str, Any],
    study_id: str,
    objective_ids: List[str],
    input_paths: List[str],
    logger: logging.Logger,
    stage_flow: Type[PrivateComputationBaseStageFlow],
    num_tries: Optional[int] = 2,  # number of tries per stage
    dry_run: Optional[bool] = False,  # if True, run only one stage
    result_visibility: Optional[ResultVisibility] = None,
) -> None:

    ## Step 1: Validation. Function arguments and study metadata must be valid for a private lift run.
    _validate_input(objective_ids, input_paths)

    # obtain study information
    client = PCGraphAPIClient(config, logger)
    study_data = _get_study_data(study_id, client)

    # verify the study can run private lift
    _verify_study_type(study_data)

    # verify MPC objectives
    _verify_mpc_objs(study_data, objective_ids)

    # verify that the study's opp_data_information is non-empty
    if OPP_DATA_INFORMATION not in study_data:
        raise PCStudyValidationException(
            f"Study {study_id} has no opportunity datasets.",
            f"Check that {study_id} study data includes {OPP_DATA_INFORMATION}",
        )

    ## Step 2: Preparation. Determine which cell-obj pairs need new instances and which can
    ## reuse existing valid ones. If a valid instance exists for a particular cell-obj pair,
    ## use it; otherwise, create one.

    cell_obj_instance = _get_cell_obj_instance(
        study_data,
        objective_ids,
        input_paths,
    )
    _print_json(
        "Existing valid instances for cell-obj pairs", cell_obj_instance, logger
    )
    # create new instances
    _create_new_instances(cell_obj_instance, study_id, client, logger)
    _print_json("Instances to run for cell-obj pairs", cell_obj_instance, logger)
    # create a dict mapping instance_id to input_path
    instances_input_path = _instance_to_input_path(cell_obj_instance)
    _print_json(
        "Instances will be calculated with corresponding input paths",
        instances_input_path,
        logger,
    )

    # check that the version in config.yml matches the one from the Graph API
    _check_versions(cell_obj_instance, config, client)

    ## Step 3: Run instances. Run up to MAX_NUM_INSTANCES instances in parallel.

    all_instance_ids = []
    chunks = _get_chunks(instances_input_path, MAX_NUM_INSTANCES)
    for chunk in chunks:
        instance_ids = list(chunk.keys())
        all_instance_ids.extend(instance_ids)
        chunk_input_paths = list(map(lambda x: x["input_path"], chunk.values()))
        chunk_num_shards = list(map(lambda x: x["num_shards"], chunk.values()))
        logger.info(f"Start running instances {instance_ids}.")
        run_instances(
            config,
            instance_ids,
            chunk_input_paths,
            chunk_num_shards,
            stage_flow,
            logger,
            num_tries,
            dry_run,
            result_visibility,
        )
        logger.info(f"Finished running instances {instance_ids}.")

    ## Step 4: Print out the initial and end states
    new_cell_obj_instances = _get_cell_obj_instance(
        _get_study_data(study_id, client), objective_ids, input_paths
    )
    _print_json(
        "Pre-run statuses for the instance of each cell-objective pair",
        cell_obj_instance,
        logger,
    )
    _print_json(
        "Post-run statuses for the instance of each cell-objective pair",
        new_cell_obj_instances,
        logger,
    )

    for instance_id in all_instance_ids:
        if (
            get_instance(config, instance_id, logger).status
            is not PrivateComputationInstanceStatus.AGGREGATION_COMPLETED
        ):
            raise OneCommandRunnerBaseException(
                f"{instance_id=} FAILED.",
                "Status is not aggregation completed",
                "Check logs for more information",
            )
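
# A hedged usage sketch for run_study. The study and objective IDs and the
# input paths are placeholders; PrivateComputationStageFlow is the stage flow
# the CLI's run_study endpoint passes below.
import logging

run_study(
    config=config,  # assumes a config dict loaded via ConfigYamlDict.from_file
    study_id="1234567890",
    objective_ids=["111", "222"],
    input_paths=[
        "https://bucket/path/cell1_input.csv",
        "https://bucket/path/cell2_input.csv",
    ],
    logger=logging.getLogger(__name__),
    stage_flow=PrivateComputationStageFlow,
)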
def update_instance(self) -> None:
    self.status = get_instance(self.config, self.instance_id, self.logger).status
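
# A hedged sketch of polling update_instance until the run reaches the same
# terminal status run_study checks for; `instance` (an object exposing the
# method above), the poll interval, and the single success status are
# assumptions here; real code would also bail out on failure statuses.
import time

while instance.status is not PrivateComputationInstanceStatus.AGGREGATION_COMPLETED:
    time.sleep(60)  # hypothetical poll interval
    instance.update_instance()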
def main(argv: Optional[List[str]] = None) -> None:
    s = schema.Schema(
        {
            "create_instance": bool,
            "validate": bool,
            "run_next": bool,
            "run_stage": bool,
            "get_instance": bool,
            "get_server_ips": bool,
            "get_pid": bool,
            "get_mpc": bool,
            "run_instance": bool,
            "run_instances": bool,
            "run_study": bool,
            "pre_validate": bool,
            "run_attribution": bool,
            "cancel_current_stage": bool,
            "print_instance": bool,
            "print_current_status": bool,
            "print_log_urls": bool,
            "get_attribution_dataset_info": bool,
            "bolt_e2e": bool,
            "<instance_id>": schema.Or(None, str),
            "<instance_ids>": schema.Or(None, schema.Use(lambda arg: arg.split(","))),
            "<study_id>": schema.Or(None, str),
            "--config": schema.Or(
                None, schema.And(schema.Use(PurePath), os.path.exists)
            ),
            "--bolt_config": schema.Or(
                None, schema.And(schema.Use(PurePath), os.path.exists)
            ),
            "--role": schema.Or(
                None,
                schema.And(
                    schema.Use(str.upper),
                    lambda s: s in ("PUBLISHER", "PARTNER"),
                    schema.Use(PrivateComputationRole),
                ),
            ),
            "--game_type": schema.Or(
                None,
                schema.And(
                    schema.Use(str.upper),
                    lambda s: s in ("LIFT", "ATTRIBUTION"),
                    schema.Use(PrivateComputationGameType),
                ),
            ),
            "--objective_ids": schema.Or(None, schema.Use(lambda arg: arg.split(","))),
            "--dataset_id": schema.Or(None, str),
            "--input_path": schema.Or(None, transform_path),
            "--input_paths": schema.Or(None, schema.Use(transform_many_paths)),
            "--output_dir": schema.Or(None, transform_path),
            "--aggregated_result_path": schema.Or(None, str),
            "--expected_result_path": schema.Or(None, str),
            "--num_pid_containers": schema.Or(None, schema.Use(int)),
            "--num_mpc_containers": schema.Or(None, schema.Use(int)),
            "--aggregation_type": schema.Or(None, schema.Use(AggregationType)),
            "--attribution_rule": schema.Or(None, schema.Use(AttributionRule)),
            "--timestamp": schema.Or(None, str),
            "--num_files_per_mpc_container": schema.Or(None, schema.Use(int)),
            "--num_shards": schema.Or(None, schema.Use(int)),
            "--num_shards_list": schema.Or(
                None, schema.Use(lambda arg: arg.split(","))
            ),
            "--server_ips": schema.Or(None, schema.Use(lambda arg: arg.split(","))),
            "--concurrency": schema.Or(None, schema.Use(int)),
            "--padding_size": schema.Or(None, schema.Use(int)),
            "--k_anonymity_threshold": schema.Or(None, schema.Use(int)),
            "--hmac_key": schema.Or(None, str),
            "--tries_per_stage": schema.Or(None, schema.Use(int)),
            "--dry_run": bool,
            "--logging_service": schema.Or(
                None,
                schema.And(
                    schema.Use(str),
                    lambda arg: parse_host_port(arg)[1] > 0,
                ),
            ),
            "--log_path": schema.Or(None, schema.Use(Path)),
            "--stage_flow": schema.Or(
                None,
                schema.Use(
                    lambda arg: PrivateComputationBaseStageFlow.cls_name_to_cls(arg)
                ),
            ),
            "--result_visibility": schema.Or(
                None,
                schema.Use(lambda arg: ResultVisibility[arg.upper()]),
            ),
            "--stage": schema.Or(None, str),
            "--verbose": bool,
            "--help": bool,
        }
    )
    arguments = s.validate(docopt(__doc__, argv))
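    # Note: chains like schema.And(schema.Use(str.upper), <membership check>,
    # schema.Use(PrivateComputationRole)) both validate and transform, so values
    # such as arguments["--role"] and arguments["--game_type"] come back as enum
    # members rather than raw strings.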

    config = {}
    if arguments["--config"]:
        config = ConfigYamlDict.from_file(arguments["--config"])
    # if no --config is given and the endpoint isn't bolt_e2e, raise an exception;
    # the bolt_e2e endpoint takes --bolt_config instead
    elif not arguments["bolt_e2e"]:
        raise ValueError("--config is a required argument")

    log_path = arguments["--log_path"]
    instance_id = arguments["<instance_id>"]

    # if log_path is specified, log to a file via FileHandler; otherwise, log to the console via StreamHandler
    log_handler = logging.FileHandler(log_path) if log_path else logging.StreamHandler()
    logging.Formatter.converter = time.gmtime
    logging.basicConfig(
        # Root log level must be INFO or above, to avoid logging debug data which
        # might contain PII.
        level=logging.INFO,
        handlers=[log_handler],
        format="%(asctime)sZ %(levelname)s t:%(threadName)s n:%(name)s ! %(message)s",
    )
    logger = logging.getLogger(__name__)
    log_level = logging.DEBUG if arguments["--verbose"] else logging.INFO
    logger.setLevel(log_level)
    # Concatenate all arguments into a single string, with each argument wrapped in quotes.
    all_options = f"{sys.argv[1:]}"[1:-1].replace("', '", "' '")
    # E.g. Command line: private_computation_cli 'create_instance' 'partner_15464380' '--config=/tmp/tmp21ari0i6/config_local.yml' ...
    logging.info(f"Command line: {Path(__file__).stem} {all_options}")

    # When the logging service argument is specified, its value looks like "localhost:9090".
    # When the argument is missing, the logging service client is disabled, i.e. a no-op.
    (logging_service_host, logging_service_port) = parse_host_port(
        arguments["--logging_service"]
    )
    logger.info(
        f"Client using logging service host: {logging_service_host}, port: {logging_service_port}."
    )
    logging_service_client = ClientManager(logging_service_host, logging_service_port)

    if arguments["create_instance"]:
        logger.info(f"Create instance: {instance_id}")
        put_log_metadata(
            logging_service_client, arguments["--game_type"], "create_instance"
        )
        create_instance(
            config=config,
            instance_id=instance_id,
            role=arguments["--role"],
            game_type=arguments["--game_type"],
            logger=logger,
            input_path=arguments["--input_path"],
            output_dir=arguments["--output_dir"],
            num_pid_containers=arguments["--num_pid_containers"],
            num_mpc_containers=arguments["--num_mpc_containers"],
            attribution_rule=arguments["--attribution_rule"],
            aggregation_type=arguments["--aggregation_type"],
            concurrency=arguments["--concurrency"],
            num_files_per_mpc_container=arguments["--num_files_per_mpc_container"],
            hmac_key=arguments["--hmac_key"],
            padding_size=arguments["--padding_size"],
            k_anonymity_threshold=arguments["--k_anonymity_threshold"],
            stage_flow_cls=arguments["--stage_flow"],
            result_visibility=arguments["--result_visibility"],
        )
    elif arguments["run_next"]:
        logger.info(f"run_next instance: {instance_id}")
        run_next(
            config=config,
            instance_id=instance_id,
            logger=logger,
            server_ips=arguments["--server_ips"],
        )
    elif arguments["run_stage"]:
        stage_name = arguments["--stage"]
        logger.info(f"run_stage: {instance_id=}, {stage_name=}")
        instance = get_instance(config, instance_id, logger)
        stage = instance.stage_flow.get_stage_from_str(stage_name)
        run_stage(
            config=config,
            instance_id=instance_id,
            stage=stage,
            logger=logger,
            server_ips=arguments["--server_ips"],
            dry_run=arguments["--dry_run"],
        )
    elif arguments["get_instance"]:
        logger.info(f"Get instance: {instance_id}")
        instance = get_instance(config, instance_id, logger)
        logger.info(instance)
    elif arguments["get_server_ips"]:
        get_server_ips(config, instance_id, logger)
    elif arguments["get_pid"]:
        logger.info(f"Get PID instance: {instance_id}")
        get_pid(config, instance_id, logger)
    elif arguments["get_mpc"]:
        logger.info(f"Get MPC instance: {instance_id}")
        get_mpc(config, instance_id, logger)
    elif arguments["validate"]:
        logger.info(f"Validate instance: {instance_id}")
        validate(
            config=config,
            instance_id=instance_id,
            aggregated_result_path=arguments["--aggregated_result_path"],
            expected_result_path=arguments["--expected_result_path"],
            logger=logger,
        )
    elif arguments["run_instance"]:
        stage_flow = PrivateComputationStageFlow
        logger.info(f"Running instance: {instance_id}")
        run_instance(
            config=config,
            instance_id=instance_id,
            input_path=arguments["--input_path"],
            game_type=arguments["--game_type"],
            num_mpc_containers=arguments["--num_shards"],
            num_pid_containers=arguments["--num_shards"],
            stage_flow=stage_flow,
            logger=logger,
            num_tries=arguments["--tries_per_stage"],
            dry_run=arguments["--dry_run"],
        )
    elif arguments["run_instances"]:
        stage_flow = PrivateComputationStageFlow
        run_instances(
            config=config,
            instance_ids=arguments["<instance_ids>"],
            input_paths=arguments["--input_paths"],
            num_shards_list=arguments["--num_shards_list"],
            stage_flow=stage_flow,
            logger=logger,
            num_tries=arguments["--tries_per_stage"],
            dry_run=arguments["--dry_run"],
        )
    elif arguments["run_study"]:
        stage_flow = PrivateComputationStageFlow
        run_study(
            config=config,
            study_id=arguments["<study_id>"],
            objective_ids=arguments["--objective_ids"],
            input_paths=arguments["--input_paths"],
            logger=logger,
            stage_flow=stage_flow,
            num_tries=arguments["--tries_per_stage"],
            dry_run=arguments["--dry_run"],
            result_visibility=arguments["--result_visibility"],
        )
    elif arguments["run_attribution"]:
        stage_flow = PrivateComputationPCF2StageFlow
        run_attribution(
            config=config,
            dataset_id=arguments["--dataset_id"],
            input_path=arguments["--input_path"],
            timestamp=arguments["--timestamp"],
            attribution_rule=arguments["--attribution_rule"],
            aggregation_type=arguments["--aggregation_type"],
            concurrency=arguments["--concurrency"],
            num_files_per_mpc_container=arguments["--num_files_per_mpc_container"],
            k_anonymity_threshold=arguments["--k_anonymity_threshold"],
            logger=logger,
            stage_flow=stage_flow,
            num_tries=2,
        )

    elif arguments["cancel_current_stage"]:
        logger.info(f"Canceling the current running stage of instance: {instance_id}")
        cancel_current_stage(
            config=config,
            instance_id=instance_id,
            logger=logger,
        )
    elif arguments["print_instance"]:
        print_instance(
            config=config,
            instance_id=instance_id,
            logger=logger,
        )
    elif arguments["print_current_status"]:
        print("print_current_status")
        print_current_status(
            config=config,
            instance_id=instance_id,
            logger=logger,
        )
    elif arguments["print_log_urls"]:
        print_log_urls(
            config=config,
            instance_id=instance_id,
            logger=logger,
        )
    elif arguments["get_attribution_dataset_info"]:
        print(
            get_attribution_dataset_info(
                config=config, dataset_id=arguments["--dataset_id"], logger=logger
            )
        )
    elif arguments["pre_validate"]:
        input_paths = (
            [arguments["--input_path"]]
            if arguments["--input_path"]
            else arguments["--input_paths"]
        )
        PreValidateService.pre_validate(
            config=config,
            input_paths=input_paths,
            logger=logger,
        )
    elif arguments["bolt_e2e"]:
        bolt_config = ConfigYamlDict.from_file(arguments["--bolt_config"])
        bolt_runner, jobs = parse_bolt_config(config=bolt_config, logger=logger)
        run_results = asyncio.run(bolt_runner.run_async(jobs))
        if not all(run_results):
            failed_job_names = [
                job.job_name for job, result in zip(jobs, run_results) if not result
            ]
            raise RuntimeError(f"Jobs failed: {failed_job_names}")
        else:
            print("Jobs succeeded")