Example #1
def create_bolt_runner(runner_config: Dict[str, Any],
                       logger: logging.Logger) -> BoltRunner:
    publisher_client_config = ConfigYamlDict.from_file(
        runner_config["publisher_client_config"])
    partner_client_config = ConfigYamlDict.from_file(
        runner_config["partner_client_config"])
    publisher_client = BoltPCSClient(
        _build_private_computation_service(
            publisher_client_config["private_computation"],
            publisher_client_config["mpc"],
            publisher_client_config["pid"],
            publisher_client_config.get("post_processing_handlers", {}),
            publisher_client_config.get("pid_post_processing_handlers", {}),
        ))
    partner_client = BoltPCSClient(
        _build_private_computation_service(
            partner_client_config["private_computation"],
            partner_client_config["mpc"],
            partner_client_config["pid"],
            partner_client_config.get("post_processing_handlers", {}),
            partner_client_config.get("pid_post_processing_handlers", {}),
        ))

    runner = BoltRunner(publisher_client=publisher_client,
                        partner_client=partner_client,
                        logger=logger)
    return runner
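
A minimal usage sketch (the YAML file paths are hypothetical; the two runner_config keys mirror the ones read above):

import logging
from typing import Any, Dict

runner_config: Dict[str, Any] = {
    # Hypothetical paths; each must point at a client config YAML like the ones loaded above.
    "publisher_client_config": "/path/to/publisher_config.yml",
    "partner_client_config": "/path/to/partner_config.yml",
}
runner = create_bolt_runner(runner_config, logging.getLogger(__name__))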
Example #2
    def test_load_from_invalid_file(self, mock_file) -> None:
        self.assertEqual(open(self.test_filename).read(), self.invalid_data)

        with self.assertRaises(ConfigYamlFileParsingError) as error_context:
            ConfigYamlDict.from_file(self.test_filename)
        # Assert outside the `with` block; inside it, the raised exception would
        # prevent the assertion from ever running.
        self.assertTrue(
            str(error_context.exception).startswith(f"""
                    {self.test_filename} is not a valid YAML file.
                    Please make sure that the content of your config is a valid YAML.
                    \nCause:"""))
Example #3
    def _get_graph_api_token(self, config: Dict[str, Any]) -> str:
        f"""Get graph API token from config.yml or the {FBPCS_GRAPH_API_TOKEN} env var

        Arguments:
            config: dictionary representation of config.yml file

        Returns:
            the graph api token

        Raises:
            GraphAPITokenNotFound: graph api token not in config.yml and not in env var
        """
        try:
            if not isinstance(config, ConfigYamlDict):
                config = ConfigYamlDict.from_dict(config)
            self.logger.info(
                "attempting to read graph api token from config.yml file")
            token = config["graphapi"]["access_token"]
            self.logger.info(
                "successfuly read graph api token from config.yml file")
        except ConfigYamlBaseException:
            self.logger.info(
                f"attempting to read graph api token from {FBPCS_GRAPH_API_TOKEN} env var"
            )
            token = os.getenv(FBPCS_GRAPH_API_TOKEN)
            if not token:
                no_token_exception = GraphAPITokenNotFound.make_error()
                self.logger.exception(no_token_exception)
                raise no_token_exception from None
            self.logger.info(
                f"successfully read graph api token from {FBPCS_GRAPH_API_TOKEN} env var"
            )
        return token
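
A sketch of the two sources the method checks, based on the lookups above (token values are placeholders; FBPCS_GRAPH_API_TOKEN is the module constant naming the env var, assumed imported as in the code above):

import os

# Source 1: the config.yml dict, under graphapi -> access_token.
config = {"graphapi": {"access_token": "<placeholder-token>"}}

# Source 2: the environment-variable fallback, consulted when the config
# lookup raises ConfigYamlBaseException.
os.environ[FBPCS_GRAPH_API_TOKEN] = "<placeholder-token>"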
Example #4
    def test_pre_validate_with_minimal_input_paths_args(
            self, getLoggerMock, pre_validate_service_mock) -> None:
        getLoggerMock.return_value = getLoggerMock
        expected_config = ConfigYamlDict.from_file(self.temp_filename)
        argv = [
            "pre_validate",
            f"--config={self.temp_filename}",
            f"--input_paths={','.join(self.temp_files_paths)}",
        ]

        pc_cli.main(argv)

        pre_validate_service_mock.pre_validate.assert_called_once_with(
            config=expected_config,
            input_paths=self.temp_files_paths,
            logger=getLoggerMock,
        )
Example #5
    def test_pre_validate_with_pa_args(self, getLoggerMock,
                                       pre_validate_service_mock) -> None:
        getLoggerMock.return_value = getLoggerMock
        expected_config = ConfigYamlDict.from_file(self.temp_filename)
        argv = [
            "pre_validate",
            f"--config={self.temp_filename}",
            "--dataset_id=123",
            f"--input_path={self.temp_files_paths[0]}",
            "--timestamp=1651847976",
            "--attribution_rule=last_click_1d",
            "--aggregation_type=measurement",
            "--concurrency=1",
            "--num_files_per_mpc_container=1",
            "--k_anonymity_threshold=10",
        ]

        pc_cli.main(argv)

        pre_validate_service_mock.pre_validate.assert_called_once_with(
            config=expected_config,
            input_paths=[self.temp_files_paths[0]],
            logger=getLoggerMock,
        )
Example #6
    def test_load_from_file_success(self, mock_file) -> None:
        self.assertEqual(open(self.test_filename).read(), self.valid_data)

        load_data = ConfigYamlDict.from_file(self.test_filename)
        self.assertEqual(load_data, self.test_dict)
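
Putting Examples #2 and #6 together, a small sketch of loading a config and handling a parse failure (the "private_computation" key appears in Example #1; the file path is hypothetical, and ConfigYamlDict / ConfigYamlFileParsingError are assumed imported as in the examples above):

try:
    config = ConfigYamlDict.from_file("/path/to/config.yml")
    # On success, the result is used like a plain dict of the parsed YAML.
    pc_config = config["private_computation"]
except ConfigYamlFileParsingError as err:
    # Raised when the file is not valid YAML, as asserted in Example #2.
    print(f"Invalid config: {err}")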
def main(argv: Optional[List[str]] = None) -> None:
    s = schema.Schema(
        {
            "create_instance": bool,
            "validate": bool,
            "run_next": bool,
            "run_stage": bool,
            "get_instance": bool,
            "get_server_ips": bool,
            "get_pid": bool,
            "get_mpc": bool,
            "run_instance": bool,
            "run_instances": bool,
            "run_study": bool,
            "pre_validate": bool,
            "run_attribution": bool,
            "cancel_current_stage": bool,
            "print_instance": bool,
            "print_current_status": bool,
            "print_log_urls": bool,
            "get_attribution_dataset_info": bool,
            "bolt_e2e": bool,
            "<instance_id>": schema.Or(None, str),
            "<instance_ids>": schema.Or(None, schema.Use(lambda arg: arg.split(","))),
            "<study_id>": schema.Or(None, str),
            "--config": schema.Or(
                None, schema.And(schema.Use(PurePath), os.path.exists)
            ),
            "--bolt_config": schema.Or(
                None, schema.And(schema.Use(PurePath), os.path.exists)
            ),
            "--role": schema.Or(
                None,
                schema.And(
                    schema.Use(str.upper),
                    lambda s: s in ("PUBLISHER", "PARTNER"),
                    schema.Use(PrivateComputationRole),
                ),
            ),
            "--game_type": schema.Or(
                None,
                schema.And(
                    schema.Use(str.upper),
                    lambda s: s in ("LIFT", "ATTRIBUTION"),
                    schema.Use(PrivateComputationGameType),
                ),
            ),
            "--objective_ids": schema.Or(None, schema.Use(lambda arg: arg.split(","))),
            "--dataset_id": schema.Or(None, str),
            "--input_path": schema.Or(None, transform_path),
            "--input_paths": schema.Or(None, schema.Use(transform_many_paths)),
            "--output_dir": schema.Or(None, transform_path),
            "--aggregated_result_path": schema.Or(None, str),
            "--expected_result_path": schema.Or(None, str),
            "--num_pid_containers": schema.Or(None, schema.Use(int)),
            "--num_mpc_containers": schema.Or(None, schema.Use(int)),
            "--aggregation_type": schema.Or(None, schema.Use(AggregationType)),
            "--attribution_rule": schema.Or(None, schema.Use(AttributionRule)),
            "--timestamp": schema.Or(None, str),
            "--num_files_per_mpc_container": schema.Or(None, schema.Use(int)),
            "--num_shards": schema.Or(None, schema.Use(int)),
            "--num_shards_list": schema.Or(
                None, schema.Use(lambda arg: arg.split(","))
            ),
            "--server_ips": schema.Or(None, schema.Use(lambda arg: arg.split(","))),
            "--concurrency": schema.Or(None, schema.Use(int)),
            "--padding_size": schema.Or(None, schema.Use(int)),
            "--k_anonymity_threshold": schema.Or(None, schema.Use(int)),
            "--hmac_key": schema.Or(None, str),
            "--tries_per_stage": schema.Or(None, schema.Use(int)),
            "--dry_run": bool,
            "--logging_service": schema.Or(
                None,
                schema.And(
                    schema.Use(str),
                    lambda arg: parse_host_port(arg)[1] > 0,
                ),
            ),
            "--log_path": schema.Or(None, schema.Use(Path)),
            "--stage_flow": schema.Or(
                None,
                schema.Use(
                    lambda arg: PrivateComputationBaseStageFlow.cls_name_to_cls(arg)
                ),
            ),
            "--result_visibility": schema.Or(
                None,
                schema.Use(lambda arg: ResultVisibility[arg.upper()]),
            ),
            "--stage": schema.Or(None, str),
            "--verbose": bool,
            "--help": bool,
        }
    )
    arguments = s.validate(docopt(__doc__, argv))

    config = {}
    if arguments["--config"]:
        config = ConfigYamlDict.from_file(arguments["--config"])
    # If no --config is given and the endpoint isn't bolt_e2e, raise an exception;
    # the bolt_e2e endpoint takes the --bolt_config argument instead.
    elif not arguments["bolt_e2e"]:
        raise ValueError("--config is a required argument")

    log_path = arguments["--log_path"]
    instance_id = arguments["<instance_id>"]

    # If log_path is specified, log to a FileHandler; otherwise, log to the console via a StreamHandler.
    log_handler = logging.FileHandler(log_path) if log_path else logging.StreamHandler()
    logging.Formatter.converter = time.gmtime
    logging.basicConfig(
        # Root log level must be INFO or higher to avoid logging debug data, which
        # might contain PII.
        level=logging.INFO,
        handlers=[log_handler],
        format="%(asctime)sZ %(levelname)s t:%(threadName)s n:%(name)s ! %(message)s",
    )
    logger = logging.getLogger(__name__)
    log_level = logging.DEBUG if arguments["--verbose"] else logging.INFO
    logger.setLevel(log_level)
    # Concatenate all arguments into a string, with every argument wrapped in quotes.
    all_options = f"{sys.argv[1:]}"[1:-1].replace("', '", "' '")
    # E.g. Command line: private_computation_cli 'create_instance' 'partner_15464380' '--config=/tmp/tmp21ari0i6/config_local.yml' ...
    logging.info(f"Command line: {Path(__file__).stem} {all_options}")

    # When the logging service argument is specified, its value looks like "localhost:9090".
    # When the argument is missing, the logging service client is disabled, i.e. a no-op.
    (logging_service_host, logging_service_port) = parse_host_port(
        arguments["--logging_service"]
    )
    logger.info(
        f"Client using logging service host: {logging_service_host}, port: {logging_service_port}."
    )
    logging_service_client = ClientManager(logging_service_host, logging_service_port)

    if arguments["create_instance"]:
        logger.info(f"Create instance: {instance_id}")
        put_log_metadata(
            logging_service_client, arguments["--game_type"], "create_instance"
        )
        create_instance(
            config=config,
            instance_id=instance_id,
            role=arguments["--role"],
            game_type=arguments["--game_type"],
            logger=logger,
            input_path=arguments["--input_path"],
            output_dir=arguments["--output_dir"],
            num_pid_containers=arguments["--num_pid_containers"],
            num_mpc_containers=arguments["--num_mpc_containers"],
            attribution_rule=arguments["--attribution_rule"],
            aggregation_type=arguments["--aggregation_type"],
            concurrency=arguments["--concurrency"],
            num_files_per_mpc_container=arguments["--num_files_per_mpc_container"],
            hmac_key=arguments["--hmac_key"],
            padding_size=arguments["--padding_size"],
            k_anonymity_threshold=arguments["--k_anonymity_threshold"],
            stage_flow_cls=arguments["--stage_flow"],
            result_visibility=arguments["--result_visibility"],
        )
    elif arguments["run_next"]:
        logger.info(f"run_next instance: {instance_id}")
        run_next(
            config=config,
            instance_id=instance_id,
            logger=logger,
            server_ips=arguments["--server_ips"],
        )
    elif arguments["run_stage"]:
        stage_name = arguments["--stage"]
        logger.info(f"run_stage: {instance_id=}, {stage_name=}")
        instance = get_instance(config, instance_id, logger)
        stage = instance.stage_flow.get_stage_from_str(stage_name)
        run_stage(
            config=config,
            instance_id=instance_id,
            stage=stage,
            logger=logger,
            server_ips=arguments["--server_ips"],
            dry_run=arguments["--dry_run"],
        )
    elif arguments["get_instance"]:
        logger.info(f"Get instance: {instance_id}")
        instance = get_instance(config, instance_id, logger)
        logger.info(instance)
    elif arguments["get_server_ips"]:
        get_server_ips(config, instance_id, logger)
    elif arguments["get_pid"]:
        logger.info(f"Get PID instance: {instance_id}")
        get_pid(config, instance_id, logger)
    elif arguments["get_mpc"]:
        logger.info(f"Get MPC instance: {instance_id}")
        get_mpc(config, instance_id, logger)
    elif arguments["validate"]:
        logger.info(f"Validate instance: {instance_id}")
        validate(
            config=config,
            instance_id=instance_id,
            aggregated_result_path=arguments["--aggregated_result_path"],
            expected_result_path=arguments["--expected_result_path"],
            logger=logger,
        )
    elif arguments["run_instance"]:
        stage_flow = PrivateComputationStageFlow
        logger.info(f"Running instance: {instance_id}")
        run_instance(
            config=config,
            instance_id=instance_id,
            input_path=arguments["--input_path"],
            game_type=arguments["--game_type"],
            num_mpc_containers=arguments["--num_shards"],
            num_pid_containers=arguments["--num_shards"],
            stage_flow=stage_flow,
            logger=logger,
            num_tries=arguments["--tries_per_stage"],
            dry_run=arguments["--dry_run"],
        )
    elif arguments["run_instances"]:
        stage_flow = PrivateComputationStageFlow
        run_instances(
            config=config,
            instance_ids=arguments["<instance_ids>"],
            input_paths=arguments["--input_paths"],
            num_shards_list=arguments["--num_shards_list"],
            stage_flow=stage_flow,
            logger=logger,
            num_tries=arguments["--tries_per_stage"],
            dry_run=arguments["--dry_run"],
        )
    elif arguments["run_study"]:
        stage_flow = PrivateComputationStageFlow
        run_study(
            config=config,
            study_id=arguments["<study_id>"],
            objective_ids=arguments["--objective_ids"],
            input_paths=arguments["--input_paths"],
            logger=logger,
            stage_flow=stage_flow,
            num_tries=arguments["--tries_per_stage"],
            dry_run=arguments["--dry_run"],
            result_visibility=arguments["--result_visibility"],
        )
    elif arguments["run_attribution"]:
        stage_flow = PrivateComputationPCF2StageFlow
        run_attribution(
            config=config,
            dataset_id=arguments["--dataset_id"],
            input_path=arguments["--input_path"],
            timestamp=arguments["--timestamp"],
            attribution_rule=arguments["--attribution_rule"],
            aggregation_type=arguments["--aggregation_type"],
            concurrency=arguments["--concurrency"],
            num_files_per_mpc_container=arguments["--num_files_per_mpc_container"],
            k_anonymity_threshold=arguments["--k_anonymity_threshold"],
            logger=logger,
            stage_flow=stage_flow,
            num_tries=2,
        )

    elif arguments["cancel_current_stage"]:
        logger.info(f"Canceling the current running stage of instance: {instance_id}")
        cancel_current_stage(
            config=config,
            instance_id=instance_id,
            logger=logger,
        )
    elif arguments["print_instance"]:
        print_instance(
            config=config,
            instance_id=instance_id,
            logger=logger,
        )
    elif arguments["print_current_status"]:
        print("print_current_status")
        print_current_status(
            config=config,
            instance_id=instance_id,
            logger=logger,
        )
    elif arguments["print_log_urls"]:
        print_log_urls(
            config=config,
            instance_id=instance_id,
            logger=logger,
        )
    elif arguments["get_attribution_dataset_info"]:
        print(
            get_attribution_dataset_info(
                config=config, dataset_id=arguments["--dataset_id"], logger=logger
            )
        )
    elif arguments["pre_validate"]:
        input_paths = (
            [arguments["--input_path"]]
            if arguments["--input_path"]
            else arguments["--input_paths"]
        )
        PreValidateService.pre_validate(
            config=config,
            input_paths=input_paths,
            logger=logger,
        )
    elif arguments["bolt_e2e"]:
        bolt_config = ConfigYamlDict.from_file(arguments["--bolt_config"])
        bolt_runner, jobs = parse_bolt_config(config=bolt_config, logger=logger)
        run_results = asyncio.run(bolt_runner.run_async(jobs))
        if not all(run_results):
            failed_job_names = [
                job.job_name for job, result in zip(jobs, run_results) if not result
            ]
            raise RuntimeError(f"Jobs failed: {failed_job_names}")
        else:
            print("Jobs succeeded")