def get_attribution_dataset_info(
    config: Dict[str, Any], dataset_id: str, logger: logging.Logger
) -> Any:
    # Public wrapper: builds its own Graph API client, then fetches and
    # parses the dataset info as JSON.
    client = PCGraphAPIClient(config, logger)

    return json.loads(
        client.get_attribution_dataset_info(
            dataset_id,
            [DATASETS_INFORMATION],
        ).text
    )

def _get_attribution_dataset_info(
    client: PCGraphAPIClient, dataset_id: str, logger: logging.Logger
) -> Any:
    return json.loads(
        client.get_attribution_dataset_info(
            dataset_id,
            [DATASETS_INFORMATION],
        ).text
    )

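# A hedged sketch of the parsed response shape _get_attribution_dataset_info
# is assumed to return, inferred from how run_attribution reads it below. The
# key strings for DATASETS_INFORMATION and TIMESTAMP are assumed to match
# their constant names; the concrete values are illustrative only.
_EXAMPLE_DATASETS_INFO: Dict[str, Any] = {
    "datasets_information": [
        {
            "key": "LAST_CLICK_1D",  # matched against AttributionRule.name
            "value": [
                # one entry per uploaded dataset; "timestamp" is matched
                # against the user-supplied timestamp after parsing
                {"timestamp": "2023-01-02T00:00:00+0000"},
            ],
        },
    ],
}
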
def _check_versions(
    cell_obj_instances: Dict[str, Dict[str, Dict[str, Any]]],
    config: Dict[str, Any],
    client: PCGraphAPIClient,
) -> None:
    """Checks that the publisher version (Graph API) and the partner version (config.yml) are the same.

    Arguments:
        cell_obj_instances: a dict mapping cell_id -> objective_id -> instance data
        config: the dict representation of a config.yml file
        client: interface for submitting Graph API requests

    Raises:
        IncorrectVersionError: the publisher and partner are running different versions
    """
    config_tier = get_tier(config)

    for cell_id in cell_obj_instances:
        for objective_id in cell_obj_instances[cell_id]:
            instance_data = cell_obj_instances[cell_id][objective_id]
            instance_id = instance_data["instance_id"]
            # If there is no tier for some reason (e.g. an old study), assume
            # the tier is correct.
            tier_str = json.loads(client.get_instance(instance_id).text).get("tier")
            if tier_str:
                expected_tier = PCSTier.from_str(tier_str)
                if expected_tier is not config_tier:
                    raise IncorrectVersionError.make_error(
                        instance_id, expected_tier, config_tier
                    )

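# A hedged sketch of the cell_obj_instances shape _check_versions expects,
# inferred from how the function indexes it; all ids are placeholders.
_EXAMPLE_CELL_OBJ_INSTANCES: Dict[str, Dict[str, Dict[str, Any]]] = {
    "cell_123": {
        "objective_456": {
            "instance_id": "instance_789",
            # ...other per-instance fields used elsewhere in the runner...
        },
    },
}
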
def test_get_graph_api_token_dict_and_env(self) -> None:
    expected_token = "from_dict"
    with patch.dict("os.environ", {FBPCS_GRAPH_API_TOKEN: "from_env"}):
        config = {"graphapi": {"access_token": expected_token}}
        actual_token = PCGraphAPIClient(config, self.mock_logger).access_token
        self.assertEqual(expected_token, actual_token)

def _create_instance_retry(
    client: PCGraphAPIClient,
    study_id: str,
    cell_id: str,
    objective_id: str,
    logger: logging.Logger,
) -> str:
    tries = 0
    while tries < CREATE_INSTANCE_TRIES:
        tries += 1
        try:
            instance_id = json.loads(
                client.create_instance(
                    study_id, {"cell_id": cell_id, "objective_id": objective_id}
                ).text
            )["id"]
            logger.info(
                f"Created instance {instance_id} for cell {cell_id} and objective {objective_id}"
            )
            return instance_id
        except GraphAPIGenericException as err:
            if tries >= CREATE_INSTANCE_TRIES:
                logger.error(
                    f"Error: Instance not created for cell {cell_id} and objective {objective_id}"
                )
                raise err
            logger.info(
                f"Instance not created for cell {cell_id} and objective {objective_id}. Retrying..."
            )
    return ""  # this is to make pyre happy

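# A hedged usage sketch for _create_instance_retry (not in the original
# source); the study, cell, and objective ids are placeholders. The call
# returns the new instance id, or re-raises GraphAPIGenericException once
# CREATE_INSTANCE_TRIES attempts have failed.
def _example_create_instance_retry(
    config: Dict[str, Any], logger: logging.Logger
) -> None:
    client = PCGraphAPIClient(config, logger)
    instance_id = _create_instance_retry(
        client, "study_123", "cell_456", "objective_789", logger
    )
    logger.info(f"Created instance id: {instance_id}")
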
def test_get_graph_api_token_from_env_config_no_field(self) -> None:
    expected_token = "from_env"
    with patch.dict("os.environ", {FBPCS_GRAPH_API_TOKEN: expected_token}):
        config = {"graphapi": {"random_field": "not_a_token"}}
        actual_token = PCGraphAPIClient(config, self.mock_logger).access_token
        self.assertEqual(expected_token, actual_token)

def _get_study_data(study_id: str, client: PCGraphAPIClient) -> Any:
    return json.loads(
        client.get_study_data(
            study_id,
            [
                TYPE,
                START_TIME,
                OBSERVATION_END_TIME,
                OBJECTIVES,
                OPP_DATA_INFORMATION,
                INSTANCES,
            ],
        ).text
    )

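# A hedged sketch of the study metadata _get_study_data returns, inferred from
# the fields requested above and the checks in run_study. Key strings are
# assumed to match the constant names; values are elided placeholders.
_EXAMPLE_STUDY_DATA: Dict[str, Any] = {
    "type": "...",                    # checked by _verify_study_type
    "start_time": "...",
    "observation_end_time": "...",
    "objectives": "...",              # checked by _verify_mpc_objs
    "opp_data_information": ["..."],  # must be present for run_study
    "instances": "...",
}
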
def run_instance(
    *,
    config: Dict[str, Any],
    instance_id: str,
    input_path: str,
    num_mpc_containers: int,
    num_pid_containers: int,
    stage_flow: Type[PrivateComputationBaseStageFlow],
    logger: logging.Logger,
    game_type: PrivateComputationGameType,
    attribution_rule: Optional[AttributionRule] = None,
    aggregation_type: Optional[AggregationType] = None,
    concurrency: Optional[int] = None,
    num_files_per_mpc_container: Optional[int] = None,
    k_anonymity_threshold: Optional[int] = None,
    num_tries: Optional[int] = 2,  # this is number of tries per stage
    dry_run: Optional[bool] = False,
    result_visibility: Optional[ResultVisibility] = None,
) -> None:
    num_tries = num_tries if num_tries is not None else MAX_TRIES
    if num_tries < MIN_TRIES or num_tries > MAX_TRIES:
        raise PCStudyValidationException(
            "Number of retries not allowed",
            f"num_tries must be between {MIN_TRIES} and {MAX_TRIES}.",
        )
    client = PCGraphAPIClient(config, logger)
    instance_runner = PLInstanceRunner(
        config,
        instance_id,
        input_path,
        num_mpc_containers,
        num_pid_containers,
        logger,
        client,
        num_tries,
        game_type,
        dry_run,
        stage_flow,
        attribution_rule,
        aggregation_type,
        concurrency,
        num_files_per_mpc_container,
        k_anonymity_threshold,
        result_visibility,
    )
    logger.info(f"Running private {game_type.name.lower()} for instance {instance_id}")
    instance_runner.run()

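# A hedged usage sketch for run_instance (not in the original source). All
# argument values are hypothetical placeholders; stage_flow would be a
# concrete PrivateComputationBaseStageFlow subclass in real use.
def _example_run_instance(
    config: Dict[str, Any],
    stage_flow: Type[PrivateComputationBaseStageFlow],
    logger: logging.Logger,
) -> None:
    run_instance(
        config=config,
        instance_id="1234567890",
        input_path="https://bucket.s3.us-west-2.amazonaws.com/lift/partner_input.csv",
        num_mpc_containers=2,
        num_pid_containers=2,
        stage_flow=stage_flow,
        logger=logger,
        game_type=PrivateComputationGameType.LIFT,
        num_tries=2,
    )
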
def _create_new_instance(
    dataset_id: str,
    timestamp: int,
    attribution_rule: str,
    client: PCGraphAPIClient,
    logger: logging.Logger,
) -> str:
    instance_id = json.loads(
        client.create_pa_instance(
            dataset_id,
            timestamp,
            attribution_rule,
            2,
        ).text
    )["id"]
    logger.info(
        f"Created instance {instance_id} for dataset {dataset_id} and attribution rule {attribution_rule}"
    )
    return instance_id

def run_attribution(
    config: Dict[str, Any],
    dataset_id: str,
    input_path: str,
    timestamp: str,
    attribution_rule: AttributionRule,
    aggregation_type: AggregationType,
    concurrency: int,
    num_files_per_mpc_container: int,
    k_anonymity_threshold: int,
    stage_flow: Type[PrivateComputationBaseStageFlow],
    logger: logging.Logger,
    num_tries: Optional[int] = 2,  # this is number of tries per stage
) -> None:
    ## Step 1: Validation. Validate function arguments for the private attribution run.
    # Obtain the values in the dataset info vector.
    client = PCGraphAPIClient(config, logger)
    datasets_info = _get_attribution_dataset_info(client, dataset_id, logger)
    datasets = datasets_info[DATASETS_INFORMATION]
    matched_data = {}
    attribution_rule_str = attribution_rule.name
    attribution_rule_val = attribution_rule.value
    instance_id = None
    pacific_timezone = pytz.timezone("US/Pacific")
    # Validate whether the input is an ISO date or a unix timestamp.
    is_date_format = _iso_date_validator(timestamp)
    if is_date_format:
        dt = pacific_timezone.localize(datetime.strptime(timestamp, "%Y-%m-%d"))
    else:
        dt = datetime.fromtimestamp(int(timestamp), tz=timezone.utc)

    # Compute the timestamp argument from the parsed datetime.
    dt_arg = int(datetime.timestamp(dt))

    # Verify that the input has matching dataset info:
    # a. attribution rule
    # b. timestamp
    if len(datasets) == 0:
        raise ValueError("No datasets found for the given parameters")
    for data in datasets:
        if data["key"] == attribution_rule_str:
            matched_attr = data["value"]
            for m_data in matched_attr:
                m_time = dateutil.parser.parse(m_data[TIMESTAMP])
                if m_time == dt:
                    matched_data = m_data
                    break

    if len(matched_data) == 0:
        raise ValueError("No dataset matching the information provided")

    ## Step 2: Validate which instances need to be created vs. which already exist.
    # Conditions for reusing an existing instance:
    # 1. Not in a terminal status
    # 2. Created less than 1 day ago
    dataset_instance_data = _get_existing_pa_instances(client, dataset_id)
    existing_instances = dataset_instance_data["data"]
    for inst in existing_instances:
        inst_time = dateutil.parser.parse(inst[TIMESTAMP])
        creation_time = dateutil.parser.parse(inst[CREATED_TIME])
        exp_time = datetime.now(tz=timezone.utc) - timedelta(days=1)
        expired = exp_time > creation_time
        if (
            inst[ATTRIBUTION_RULE] == attribution_rule_val
            and inst_time == dt
            and inst[STATUS] not in TERMINAL_STATUSES
            and not expired
        ):
            instance_id = inst["id"]
            break

    if instance_id is None:
        instance_id = _create_new_instance(
            dataset_id,
            int(dt_arg),
            attribution_rule_val,
            client,
            logger,
        )
    instance_data = _get_pa_instance_info(client, instance_id, logger)
    _check_version(instance_data, config)
    num_pid_containers = instance_data[NUM_SHARDS]
    num_mpc_containers = instance_data[NUM_CONTAINERS]

    ## Step 3: Run the instance.
    logger.info(f"Start running instance {instance_id}.")
    instance_parameters = {
        "config": config,
        "instance_id": instance_id,
        "input_path": input_path,
        "num_mpc_containers": num_mpc_containers,
        "num_pid_containers": num_pid_containers,
        "stage_flow": stage_flow,
        "logger": logger,
        "game_type": PrivateComputationGameType.ATTRIBUTION,
        "attribution_rule": attribution_rule,
        "aggregation_type": AggregationType.MEASUREMENT,
        "concurrency": concurrency,
        "num_files_per_mpc_container": num_files_per_mpc_container,
        "k_anonymity_threshold": k_anonymity_threshold,
        "num_tries": num_tries,
    }
    run_instance(**instance_parameters)
    logger.info(f"Finished running instance {instance_id}.")

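# A hedged usage sketch for run_attribution (not in the original source); the
# dataset id, input path, and timestamp are placeholders, and the attribution
# rule member shown is illustrative. Per the validation above, timestamp may
# be an ISO date ("2023-01-02", interpreted in US/Pacific) or a unix
# timestamp string.
def _example_run_attribution(
    config: Dict[str, Any],
    stage_flow: Type[PrivateComputationBaseStageFlow],
    logger: logging.Logger,
) -> None:
    run_attribution(
        config=config,
        dataset_id="1234567890",
        input_path="https://bucket.s3.us-west-2.amazonaws.com/attribution/partner_input.csv",
        timestamp="2023-01-02",
        attribution_rule=AttributionRule.LAST_CLICK_1D,
        aggregation_type=AggregationType.MEASUREMENT,
        concurrency=4,
        num_files_per_mpc_container=40,
        k_anonymity_threshold=100,
        stage_flow=stage_flow,
        logger=logger,
    )
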
def _get_existing_pa_instances(client: PCGraphAPIClient, dataset_id: str) -> Any:
    return json.loads(client.get_existing_pa_instances(dataset_id).text)

def _get_pa_instance_info(
    client: PCGraphAPIClient, instance_id: str, logger: logging.Logger
) -> Any:
    return json.loads(client.get_instance(instance_id).text)

def run_study(
    config: Dict[str, Any],
    study_id: str,
    objective_ids: List[str],
    input_paths: List[str],
    logger: logging.Logger,
    stage_flow: Type[PrivateComputationBaseStageFlow],
    num_tries: Optional[int] = 2,  # this is number of tries per stage
    dry_run: Optional[bool] = False,  # if set to true, it will only run one stage
    result_visibility: Optional[ResultVisibility] = None,
) -> None:
    ## Step 1: Validation. Function arguments and study metadata must be valid for a private lift run.
    _validate_input(objective_ids, input_paths)

    # Obtain study information.
    client = PCGraphAPIClient(config, logger)
    study_data = _get_study_data(study_id, client)

    # Verify that the study can run private lift.
    _verify_study_type(study_data)

    # Verify MPC objectives.
    _verify_mpc_objs(study_data, objective_ids)

    # Verify that study opp_data_information is non-empty.
    if OPP_DATA_INFORMATION not in study_data:
        raise PCStudyValidationException(
            f"Study {study_id} has no opportunity datasets.",
            f"Check {study_id} study data to include {OPP_DATA_INFORMATION}",
        )

    ## Step 2: Preparation. Find which cell-obj pairs need new instances created and which can
    ## use existing valid ones. If a valid instance exists for a particular cell-obj pair, use
    ## it. Otherwise, try to create one.
    cell_obj_instance = _get_cell_obj_instance(
        study_data,
        objective_ids,
        input_paths,
    )
    _print_json(
        "Existing valid instances for cell-obj pairs", cell_obj_instance, logger
    )
    # Create new instances.
    _create_new_instances(cell_obj_instance, study_id, client, logger)
    _print_json("Instances to run for cell-obj pairs", cell_obj_instance, logger)
    # Create a dict of {instance_id: input_path} pairs.
    instances_input_path = _instance_to_input_path(cell_obj_instance)
    _print_json(
        "Instances will be calculated with corresponding input paths",
        instances_input_path,
        logger,
    )

    # Check that the version in config.yml matches the version from the Graph API.
    _check_versions(cell_obj_instance, config, client)

    ## Step 3: Run instances. Run the maximum number of instances in parallel.
    all_instance_ids = []
    chunks = _get_chunks(instances_input_path, MAX_NUM_INSTANCES)
    for chunk in chunks:
        instance_ids = list(chunk.keys())
        all_instance_ids.extend(instance_ids)
        chunk_input_paths = list(map(lambda x: x["input_path"], chunk.values()))
        chunk_num_shards = list(map(lambda x: x["num_shards"], chunk.values()))
        logger.info(f"Start running instances {instance_ids}.")
        run_instances(
            config,
            instance_ids,
            chunk_input_paths,
            chunk_num_shards,
            stage_flow,
            logger,
            num_tries,
            dry_run,
            result_visibility,
        )
        logger.info(f"Finished running instances {instance_ids}.")

    ## Step 4: Print the initial and final states.
    new_cell_obj_instances = _get_cell_obj_instance(
        _get_study_data(study_id, client), objective_ids, input_paths
    )
    _print_json(
        "Pre-run statuses for instance of each cell-objective pair",
        cell_obj_instance,
        logger,
    )
    _print_json(
        "Post-run statuses for instance of each cell-objective pair",
        new_cell_obj_instances,
        logger,
    )

    for instance_id in all_instance_ids:
        if (
            get_instance(config, instance_id, logger).status
            is not PrivateComputationInstanceStatus.AGGREGATION_COMPLETED
        ):
            raise OneCommandRunnerBaseException(
                f"{instance_id=} FAILED.",
                "Status is not aggregation completed",
                "Check logs for more information",
            )

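# A hedged usage sketch for run_study (not in the original source); the study
# and objective ids and input paths are placeholders. objective_ids and
# input_paths are parallel lists: input_paths[i] feeds objective_ids[i].
def _example_run_study(
    config: Dict[str, Any],
    stage_flow: Type[PrivateComputationBaseStageFlow],
    logger: logging.Logger,
) -> None:
    run_study(
        config=config,
        study_id="1234567890",
        objective_ids=["111", "222"],
        input_paths=[
            "https://bucket.s3.us-west-2.amazonaws.com/lift/objective_111.csv",
            "https://bucket.s3.us-west-2.amazonaws.com/lift/objective_222.csv",
        ],
        logger=logger,
        stage_flow=stage_flow,
    )
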
def test_get_graph_api_token_no_field(self) -> None:
    config = {"graphapi": {"random_field": "not_a_token"}}
    with self.assertRaises(GraphAPITokenNotFound):
        PCGraphAPIClient(config, self.mock_logger).access_token

def test_get_graph_api_token_no_token_todo(self) -> None:
    config = {"graphapi": {"access_token": "TODO"}}
    with self.assertRaises(GraphAPITokenNotFound):
        PCGraphAPIClient(config, self.mock_logger).access_token

def test_get_graph_api_token_from_dict(self) -> None:
    expected_token = "from_dict"
    config = {"graphapi": {"access_token": expected_token}}
    actual_token = PCGraphAPIClient(config, self.mock_logger).access_token
    self.assertEqual(expected_token, actual_token)

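# The four token tests above pin down a resolution order that, as an assumed
# reconstruction (not the actual PCGraphAPIClient code), could look like the
# helper below: a usable config["graphapi"]["access_token"] wins even when the
# FBPCS_GRAPH_API_TOKEN environment variable is set; the "TODO" placeholder
# counts as missing; the environment variable is the fallback; otherwise
# GraphAPITokenNotFound is raised. Assumes `import os`; make_error() is
# assumed here, mirroring the IncorrectVersionError.make_error pattern used
# elsewhere in this code.
def _resolve_access_token_sketch(config: Dict[str, Any]) -> str:
    token = config.get("graphapi", {}).get("access_token")
    if token and token != "TODO":
        return token
    env_token = os.getenv(FBPCS_GRAPH_API_TOKEN)
    if env_token:
        return env_token
    raise GraphAPITokenNotFound.make_error()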