Example #1
def get_attribution_dataset_info(config: Dict[str, Any], dataset_id: str,
                                 logger: logging.Logger) -> Any:
    client = PCGraphAPIClient(config, logger)

    return json.loads(
        client.get_attribution_dataset_info(
            dataset_id,
            [DATASETS_INFORMATION],
        ).text)
Example #2
def _get_attribution_dataset_info(client: PCGraphAPIClient, dataset_id: str,
                                  logger: logging.Logger) -> Any:
    return json.loads(
        client.get_attribution_dataset_info(
            dataset_id,
            [DATASETS_INFORMATION],
        ).text)
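
The two wrappers above share one pattern: issue a Graph API request through PCGraphAPIClient and decode the JSON body of the HTTP response. A minimal, self-contained sketch of that pattern, using a hypothetical FakeResponse in place of the real response object:

import json
from typing import Any

class FakeResponse:
    # Stand-in for the HTTP response returned by PCGraphAPIClient methods.
    def __init__(self, text: str) -> None:
        self.text = text

def parse_response(response: FakeResponse) -> Any:
    # The client returns a raw response; the caller decodes its JSON body.
    return json.loads(response.text)

print(parse_response(FakeResponse('{"datasets_information": []}')))
# -> {'datasets_information': []}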
Example #3
def _check_versions(
    cell_obj_instances: Dict[str, Dict[str, Dict[str, Any]]],
    config: Dict[str, Any],
    client: PCGraphAPIClient,
) -> None:
    """Checks that the publisher version (graph api) and the partner version (config.yml) are the same

    Arguments:
        cell_obj_instances: theoretically is dict mapping cell->obj->instance.
        config: The dict representation of a config.yml file
        client: Interface for submitting graph API requests

    Raises:
        IncorrectVersionError: the publisher and partner are running with different versions
    """

    config_tier = get_tier(config)

    for cell_id in cell_obj_instances:
        for objective_id in cell_obj_instances[cell_id]:
            instance_data = cell_obj_instances[cell_id][objective_id]
            instance_id = instance_data["instance_id"]
            # if there is no tier for some reason (e.g. old study?), let's just assume
            # the tier is correct
            tier_str = json.loads(client.get_instance(instance_id).text).get("tier")
            if tier_str:
                expected_tier = PCSTier.from_str(tier_str)
                if expected_tier is not config_tier:
                    raise IncorrectVersionError.make_error(
                        instance_id, expected_tier, config_tier
                    )
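
_check_versions compares tiers with "is not", which works because enum members are singletons. A minimal sketch of that comparison, assuming PCSTier behaves like a standard Enum with a from_str constructor (the Tier class and its values below are stand-ins, not the real PCSTier):

from enum import Enum

class Tier(Enum):  # stand-in for PCSTier
    RC = "rc"
    CANARY = "canary"
    PROD = "prod"

    @classmethod
    def from_str(cls, s: str) -> "Tier":
        return cls(s)

config_tier = Tier.PROD
expected_tier = Tier.from_str("rc")
# Enum members are singletons, so identity comparison is safe here.
if expected_tier is not config_tier:
    print(f"version mismatch: publisher={expected_tier}, partner={config_tier}")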
Example #4
def test_get_graph_api_token_dict_and_env(self) -> None:
    expected_token = "from_dict"
    with patch.dict("os.environ", {FBPCS_GRAPH_API_TOKEN: "from_env"}):
        config = {"graphapi": {"access_token": expected_token}}
        actual_token = PCGraphAPIClient(config, self.mock_logger).access_token
        self.assertEqual(expected_token, actual_token)
Example #5
def _create_instance_retry(
    client: PCGraphAPIClient,
    study_id: str,
    cell_id: str,
    objective_id: str,
    logger: logging.Logger,
) -> str:
    tries = 0
    while tries < CREATE_INSTANCE_TRIES:
        tries += 1
        try:
            instance_id = json.loads(
                client.create_instance(
                    study_id, {"cell_id": cell_id, "objective_id": objective_id}
                ).text
            )["id"]
            logger.info(
                f"Created instance {instance_id} for cell {cell_id} and objective {objective_id}"
            )
            return instance_id
        except GraphAPIGenericException as err:
            if tries >= CREATE_INSTANCE_TRIES:
                logger.error(
                    f"Error: Instance not created for cell {cell_id} and objective {objective_id}"
                )
                raise err
            logger.info(
                f"Instance not created for cell {cell_id} and objective {objective_id}. Retrying."
            )
    return ""  # this is to make pyre happy
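
The loop above is a bounded-retry pattern: attempt up to CREATE_INSTANCE_TRIES times, log and re-raise on the final failure. The same shape in isolation, as a minimal sketch with a generic callable (all names here are hypothetical, not fbpcs APIs):

import logging
from typing import Callable, TypeVar

T = TypeVar("T")
MAX_ATTEMPTS = 3  # stand-in for CREATE_INSTANCE_TRIES

def with_retries(fn: Callable[[], T], logger: logging.Logger) -> T:
    for attempt in range(1, MAX_ATTEMPTS + 1):
        try:
            return fn()
        except Exception:
            if attempt >= MAX_ATTEMPTS:
                logger.error("giving up after %d attempts", attempt)
                raise
            logger.info("attempt %d failed, retrying", attempt)
    raise AssertionError("unreachable")  # mirrors the 'make pyre happy' return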
Example #6
def test_get_graph_api_token_from_env_config_no_field(self) -> None:
    expected_token = "from_env"
    with patch.dict("os.environ", {FBPCS_GRAPH_API_TOKEN: expected_token}):
        config = {"graphapi": {"random_field": "not_a_token"}}
        actual_token = PCGraphAPIClient(config, self.mock_logger).access_token
        self.assertEqual(expected_token, actual_token)
Example #7
def _get_study_data(study_id: str, client: PCGraphAPIClient) -> Any:
    return json.loads(
        client.get_study_data(
            study_id,
            [
                TYPE,
                START_TIME,
                OBSERVATION_END_TIME,
                OBJECTIVES,
                OPP_DATA_INFORMATION,
                INSTANCES,
            ],
        ).text
    )
Example #8
def run_instance(
    *,
    config: Dict[str, Any],
    instance_id: str,
    input_path: str,
    num_mpc_containers: int,
    num_pid_containers: int,
    stage_flow: Type[PrivateComputationBaseStageFlow],
    logger: logging.Logger,
    game_type: PrivateComputationGameType,
    attribution_rule: Optional[AttributionRule] = None,
    aggregation_type: Optional[AggregationType] = None,
    concurrency: Optional[int] = None,
    num_files_per_mpc_container: Optional[int] = None,
    k_anonymity_threshold: Optional[int] = None,
    num_tries: Optional[int] = 2,  # this is number of tries per stage
    dry_run: Optional[bool] = False,
    result_visibility: Optional[ResultVisibility] = None,
) -> None:
    num_tries = num_tries if num_tries is not None else MAX_TRIES
    if num_tries < MIN_TRIES or num_tries > MAX_TRIES:
        raise PCStudyValidationException(
            "Number of retries not allowed",
            f"num_tries must be between {MIN_TRIES} and {MAX_TRIES}.",
        )
    client = PCGraphAPIClient(config, logger)
    instance_runner = PLInstanceRunner(
        config,
        instance_id,
        input_path,
        num_mpc_containers,
        num_pid_containers,
        logger,
        client,
        num_tries,
        game_type,
        dry_run,
        stage_flow,
        attribution_rule,
        aggregation_type,
        concurrency,
        num_files_per_mpc_container,
        k_anonymity_threshold,
        result_visibility,
    )
    logger.info(
        f"Running private {game_type.name.lower()} for instance {instance_id}")
    instance_runner.run()
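
Note the bare * at the top of run_instance's signature: every argument is keyword-only, which is why run_attribution in example #10 below calls run_instance(**instance_parameters). A minimal sketch of the same Python feature:

def greet(*, name: str, excited: bool = False) -> str:
    # The bare * makes every parameter after it keyword-only.
    return f"Hello, {name}{'!' if excited else '.'}"

print(greet(name="world", excited=True))  # OK: "Hello, world!"
# greet("world")  # TypeError: greet() takes 0 positional arguments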
Example #9
def _create_new_instance(
    dataset_id: str,
    timestamp: int,
    attribution_rule: str,
    client: PCGraphAPIClient,
    logger: logging.Logger,
) -> str:
    instance_id = json.loads(
        client.create_pa_instance(
            dataset_id,
            timestamp,
            attribution_rule,
            2,
        ).text)["id"]
    logger.info(
        f"Created instance {instance_id} for dataset {dataset_id} and attribution rule {attribution_rule}"
    )
    return instance_id
Example #10
def run_attribution(
        config: Dict[str, Any],
        dataset_id: str,
        input_path: str,
        timestamp: str,
        attribution_rule: AttributionRule,
        aggregation_type: AggregationType,
        concurrency: int,
        num_files_per_mpc_container: int,
        k_anonymity_threshold: int,
        stage_flow: Type[PrivateComputationBaseStageFlow],
        logger: logging.Logger,
        num_tries: Optional[int] = 2,  # this is number of tries per stage
) -> None:

    ## Step 1: Validation. Function arguments and dataset metadata must be valid for a private attribution run.
    # Obtain the values in the dataset info vector.
    client = PCGraphAPIClient(config, logger)
    datasets_info = _get_attribution_dataset_info(client, dataset_id, logger)
    datasets = datasets_info[DATASETS_INFORMATION]
    matched_data = {}
    attribution_rule_str = attribution_rule.name
    attribution_rule_val = attribution_rule.value
    instance_id = None
    pacific_timezone = pytz.timezone("US/Pacific")
    # Determine whether the input is an ISO date or a unix timestamp
    is_date_format = _iso_date_validator(timestamp)
    if is_date_format:
        dt = pacific_timezone.localize(datetime.strptime(
            timestamp, "%Y-%m-%d"))
    else:
        dt = datetime.fromtimestamp(int(timestamp), tz=timezone.utc)

    # Convert the normalized datetime back to a unix timestamp
    dt_arg = int(datetime.timestamp(dt))

    # Verify that input has matching dataset info:
    # a. attribution rule
    # b. timestamp
    if len(datasets) == 0:
        raise ValueError("No datasets found for the given dataset ID")
    matched_attr = []  # avoid a NameError below if no attribution rule matches
    for data in datasets:
        if data["key"] == attribution_rule_str:
            matched_attr = data["value"]

    for m_data in matched_attr:
        m_time = dateutil.parser.parse(m_data[TIMESTAMP])
        if m_time == dt:
            matched_data = m_data
            break
    if len(matched_data) == 0:
        raise ValueError("No dataset matches the information provided")
    # Step 2: Decide which instances need to be created vs. which already exist
    # An existing instance is reused only if it:
    # 1. is not in a terminal status
    # 2. was created within the last day (i.e., has not yet expired)
    dataset_instance_data = _get_existing_pa_instances(client, dataset_id)
    existing_instances = dataset_instance_data["data"]
    for inst in existing_instances:
        inst_time = dateutil.parser.parse(inst[TIMESTAMP])
        creation_time = dateutil.parser.parse(inst[CREATED_TIME])
        exp_time = datetime.now(tz=timezone.utc) - timedelta(days=1)
        expired = exp_time > creation_time
        if (inst[ATTRIBUTION_RULE] == attribution_rule_val and inst_time == dt
                and inst[STATUS] not in TERMINAL_STATUSES and not expired):
            instance_id = inst["id"]
            break

    if instance_id is None:
        instance_id = _create_new_instance(
            dataset_id,
            int(dt_arg),
            attribution_rule_val,
            client,
            logger,
        )
    instance_data = _get_pa_instance_info(client, instance_id, logger)
    _check_version(instance_data, config)
    num_pid_containers = instance_data[NUM_SHARDS]
    num_mpc_containers = instance_data[NUM_CONTAINERS]

    ## Step 3. Run Instances. Run maximum number of instances in parallel
    logger.info(f"Start running instance {instance_id}.")
    instance_parameters = {
        "config": config,
        "instance_id": instance_id,
        "input_path": input_path,
        "num_mpc_containers": num_mpc_containers,
        "num_pid_containers": num_pid_containers,
        "stage_flow": stage_flow,
        "logger": logger,
        "game_type": PrivateComputationGameType.ATTRIBUTION,
        "attribution_rule": attribution_rule,
        "aggregation_type": AggregationType.MEASUREMENT,
        "concurrency": concurrency,
        "num_files_per_mpc_container": num_files_per_mpc_container,
        "k_anonymity_threshold": k_anonymity_threshold,
        "num_tries": num_tries,
    }
    run_instance(**instance_parameters)
    logger.info(f"Finished running instance {instance_id}.")
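
run_attribution accepts the timestamp argument either as an ISO date (YYYY-MM-DD, interpreted as midnight US/Pacific) or as a unix timestamp, and normalizes both to an aware datetime before matching against the dataset info. A minimal sketch of that normalization; _iso_date_validator is not shown in this section, so the validator below is a plausible stand-in:

from datetime import datetime, timezone

import pytz

def looks_like_iso_date(value: str) -> bool:
    # Stand-in for _iso_date_validator: accept date strings in YYYY-MM-DD form.
    try:
        datetime.strptime(value, "%Y-%m-%d")
        return True
    except ValueError:
        return False

def normalize_timestamp(value: str) -> datetime:
    if looks_like_iso_date(value):
        # Date strings are interpreted as midnight US/Pacific.
        return pytz.timezone("US/Pacific").localize(
            datetime.strptime(value, "%Y-%m-%d")
        )
    # Otherwise treat the input as a unix timestamp in UTC.
    return datetime.fromtimestamp(int(value), tz=timezone.utc)

print(normalize_timestamp("2023-01-15"))  # 2023-01-15 00:00:00-08:00
print(normalize_timestamp("1673769600"))  # 2023-01-15 08:00:00+00:00 (same instant)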
Example #11
def _get_existing_pa_instances(client: PCGraphAPIClient,
                               dataset_id: str) -> Any:
    return json.loads(client.get_existing_pa_instances(dataset_id).text)
Example #12
def _get_pa_instance_info(client: PCGraphAPIClient, instance_id: str,
                          logger: logging.Logger) -> Any:
    return json.loads(client.get_instance(instance_id).text)
Example #13
def run_study(
    config: Dict[str, Any],
    study_id: str,
    objective_ids: List[str],
    input_paths: List[str],
    logger: logging.Logger,
    stage_flow: Type[PrivateComputationBaseStageFlow],
    num_tries: Optional[int] = 2,  # this is number of tries per stage
    dry_run: Optional[bool] = False,  # if set to true, it will only run one stage
    result_visibility: Optional[ResultVisibility] = None,
) -> None:

    ## Step 1: Validation. Function arguments and study metadata must be valid for private lift run.
    _validate_input(objective_ids, input_paths)

    # obtain study information
    client = PCGraphAPIClient(config, logger)
    study_data = _get_study_data(study_id, client)

    # Verify study can run private lift:
    _verify_study_type(study_data)

    # verify mpc objectives
    _verify_mpc_objs(study_data, objective_ids)

    # verify study opp_data_information is non-empty
    if OPP_DATA_INFORMATION not in study_data:
        raise PCStudyValidationException(
            f"Study {study_id} has no opportunity datasets.",
            f"Check {study_id} study data to include {OPP_DATA_INFORMATION}",
        )

    ## Step 2. Preparation. Find which cell-obj pairs need new instances created and which can reuse existing
    ## valid ones. If a valid instance exists for a particular cell-obj pair, use it. Otherwise, try to create one.

    cell_obj_instance = _get_cell_obj_instance(
        study_data,
        objective_ids,
        input_paths,
    )
    _print_json(
        "Existing valid instances for cell-obj pairs", cell_obj_instance, logger
    )
    # create new instances
    _create_new_instances(cell_obj_instance, study_id, client, logger)
    _print_json("Instances to run for cell-obj pairs", cell_obj_instance, logger)
    # create a dict with {instance_id, input_path} pairs
    instances_input_path = _instance_to_input_path(cell_obj_instance)
    _print_json(
        "Instances will be calculated with corresponding input paths",
        instances_input_path,
        logger,
    )

    # check that the version in config.yml is the same as the one from the Graph API
    _check_versions(cell_obj_instance, config, client)

    ## Step 3. Run Instances. Run maximum number of instances in parallel

    all_instance_ids = []
    chunks = _get_chunks(instances_input_path, MAX_NUM_INSTANCES)
    for chunk in chunks:
        instance_ids = list(chunk.keys())
        all_instance_ids.extend(instance_ids)
        chunk_input_paths = list(map(lambda x: x["input_path"], chunk.values()))
        chunk_num_shards = list(map(lambda x: x["num_shards"], chunk.values()))
        logger.info(f"Start running instances {instance_ids}.")
        run_instances(
            config,
            instance_ids,
            chunk_input_paths,
            chunk_num_shards,
            stage_flow,
            logger,
            num_tries,
            dry_run,
            result_visibility,
        )
        logger.info(f"Finished running instances {instance_ids}.")

    ## Step 4: Print out the initial and end states
    new_cell_obj_instances = _get_cell_obj_instance(
        _get_study_data(study_id, client), objective_ids, input_paths
    )
    _print_json(
        "Pre-run statuses for instance of each cell-objective pair",
        cell_obj_instance,
        logger,
    )
    _print_json(
        "Post-run statuses for instance of each cell-objective pair",
        new_cell_obj_instances,
        logger,
    )

    for instance_id in all_instance_ids:
        if (
            get_instance(config, instance_id, logger).status
            is not PrivateComputationInstanceStatus.AGGREGATION_COMPLETED
        ):
            raise OneCommandRunnerBaseException(
                f"{instance_id=} FAILED.",
                "Status is not aggregation completed",
                "Check logs for more information",
            )
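
run_study batches work through _get_chunks so that at most MAX_NUM_INSTANCES instances run in parallel per chunk. _get_chunks is not shown in this section; a plausible minimal sketch, assuming it splits a dict into ordered sub-dicts of bounded size:

from typing import Any, Dict, List

def get_chunks(data: Dict[str, Any], size: int) -> List[Dict[str, Any]]:
    # Preserve insertion order while slicing the dict into chunks of <= size.
    items = list(data.items())
    return [dict(items[i : i + size]) for i in range(0, len(items), size)]

instances = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
print(get_chunks(instances, 2))
# [{'a': 1, 'b': 2}, {'c': 3, 'd': 4}, {'e': 5}]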
Example #14
def test_get_graph_api_token_no_field(self) -> None:
    config = {"graphapi": {"random_field": "not_a_token"}}
    with self.assertRaises(GraphAPITokenNotFound):
        PCGraphAPIClient(config, self.mock_logger).access_token
Example #15
def test_get_graph_api_token_no_token_todo(self) -> None:
    config = {"graphapi": {"access_token": "TODO"}}
    with self.assertRaises(GraphAPITokenNotFound):
        PCGraphAPIClient(config, self.mock_logger).access_token
Example #16
def test_get_graph_api_token_from_dict(self) -> None:
    expected_token = "from_dict"
    config = {"graphapi": {"access_token": expected_token}}
    actual_token = PCGraphAPIClient(config, self.mock_logger).access_token
    self.assertEqual(expected_token, actual_token)
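
Taken together, the token tests above imply a resolution order: a usable access_token in the config dict wins, the FBPCS_GRAPH_API_TOKEN environment variable is the fallback, and a missing field or a literal "TODO" placeholder with no environment variable raises GraphAPITokenNotFound. A minimal sketch of that precedence as the tests describe it (the helper and exception below are stand-ins, not the real PCGraphAPIClient internals):

import os
from typing import Any, Dict

FBPCS_GRAPH_API_TOKEN = "FBPCS_GRAPH_API_TOKEN"

class GraphAPITokenNotFound(Exception):
    pass

def resolve_access_token(config: Dict[str, Any]) -> str:
    token = config.get("graphapi", {}).get("access_token")
    if token and token != "TODO":
        return token  # the config dict takes precedence over the environment
    env_token = os.environ.get(FBPCS_GRAPH_API_TOKEN)
    if env_token:
        return env_token
    raise GraphAPITokenNotFound("no Graph API token in config or environment")

print(resolve_access_token({"graphapi": {"access_token": "from_dict"}}))  # from_dict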