def run_membership_probability_analysis(
        attack_input: AttackInputData,
        slicing_spec: SlicingSpec = None) -> MembershipProbabilityResults:
    """Perform membership probability analysis on all given slice types.

  Args:
    attack_input: input data for compute membership probabilities
    slicing_spec: specifies attack_input slices

  Returns:
    the membership probability results.
  """
    attack_input.validate()
    membership_prob_results = []

    if slicing_spec is None:
        slicing_spec = SlicingSpec(entire_dataset=True)
    num_classes = None
    if slicing_spec.by_class:
        num_classes = attack_input.num_classes
    input_slice_specs = get_single_slice_specs(slicing_spec, num_classes)
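    # Compute membership probabilities separately for each requested slice.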
    for single_slice_spec in input_slice_specs:
        attack_input_slice = get_slice(attack_input, single_slice_spec)
        membership_prob_results.append(
            _compute_membership_probability(attack_input_slice))

    return MembershipProbabilityResults(
        membership_prob_results=membership_prob_results)
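
# Example usage (illustrative sketch): the AttackInputData keyword names
# (logits_train, logits_test, labels_train, labels_test) are assumed from the
# surrounding library's data classes, not defined in this module.
#
#   attack_input = AttackInputData(
#       logits_train=train_logits, logits_test=test_logits,
#       labels_train=train_labels, labels_test=test_labels)
#   prob_results = run_membership_probability_analysis(
#       attack_input,
#       slicing_spec=SlicingSpec(entire_dataset=True, by_class=True))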


def _run_attack(attack_input: AttackInputData,
                attack_type: AttackType,
                balance_attacker_training: bool = True,
                min_num_samples: int = 1):
    """Runs membership inference attacks for specified input and type.

  Args:
    attack_input: input data for running an attack
    attack_type: the attack to run
    balance_attacker_training: Whether the training and test sets for the
      membership inference attacker should have a balanced (roughly equal)
      number of samples from the training and test sets used to develop the
      model under attack.
    min_num_samples: minimum number of examples in either training or test data.

  Returns:
    the attack result.
  """
    attack_input.validate()
    # Skip this attack if either split has fewer than min_num_samples examples.
    if min(attack_input.get_train_size(),
           attack_input.get_test_size()) < min_num_samples:
        return None

    # Trained attacks fit an attacker model on the slice's features; untrained
    # attacks simply threshold a per-example score.
    if attack_type.is_trained_attack:
        return _run_trained_attack(attack_input, attack_type,
                                   balance_attacker_training)
    if attack_type == AttackType.THRESHOLD_ENTROPY_ATTACK:
        return _run_threshold_entropy_attack(attack_input)
    return _run_threshold_attack(attack_input)
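
# Illustrative internal call (hypothetical values): attacks a single slice and
# yields None when either split of the slice has fewer than min_num_samples
# examples.
#
#   result = _run_attack(attack_input_slice, AttackType.THRESHOLD_ATTACK,
#                        balance_attacker_training=True, min_num_samples=10)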


def run_attacks(attack_input: AttackInputData,
                slicing_spec: SlicingSpec = None,
                attack_types: Iterable[AttackType] = (
                    AttackType.THRESHOLD_ATTACK,),
                privacy_report_metadata: PrivacyReportMetadata = None,
                balance_attacker_training: bool = True,
                min_num_samples: int = 1) -> AttackResults:
    """Runs membership inference attacks on a classification model.

  It runs attacks specified by attack_types on each attack_input slice which is
   specified by slicing_spec.

  Args:
    attack_input: input data for running an attack
    slicing_spec: specifies attack_input slices to run attack on
    attack_types: attacks to run
    privacy_report_metadata: the metadata of the model under attack.
    balance_attacker_training: Whether the training and test sets for the
      membership inference attacker should have a balanced (roughly equal)
      number of samples from the training and test sets used to develop the
      model under attack.
    min_num_samples: minimum number of examples in either training or test data.

  Returns:
    the attack result.
  """
    attack_input.validate()
    attack_results = []

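    # Default to a single slice covering the entire dataset when no slicing
    # spec is provided.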
    if slicing_spec is None:
        slicing_spec = SlicingSpec(entire_dataset=True)
    num_classes = None
    if slicing_spec.by_class:
        num_classes = attack_input.num_classes
    input_slice_specs = get_single_slice_specs(slicing_spec, num_classes)
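    # Run every requested attack type on every slice; attacks that are skipped
    # (e.g. because a slice is too small) return None and are not recorded.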
    for single_slice_spec in input_slice_specs:
        attack_input_slice = get_slice(attack_input, single_slice_spec)
        for attack_type in attack_types:
            attack_result = _run_attack(attack_input_slice, attack_type,
                                        balance_attacker_training,
                                        min_num_samples)
            if attack_result is not None:
                attack_results.append(attack_result)

    privacy_report_metadata = _compute_missing_privacy_report_metadata(
        privacy_report_metadata, attack_input)

    return AttackResults(single_attack_results=attack_results,
                         privacy_report_metadata=privacy_report_metadata)
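
# Example usage (illustrative sketch): run a threshold attack plus a trained
# attack on the whole dataset and on each class slice. The AttackInputData
# keyword names and the AttackType.LOGISTIC_REGRESSION member are assumptions
# about the surrounding library, not defined in this module.
#
#   attack_input = AttackInputData(
#       logits_train=train_logits, logits_test=test_logits,
#       labels_train=train_labels, labels_test=test_labels)
#   results = run_attacks(
#       attack_input,
#       slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
#       attack_types=(AttackType.THRESHOLD_ATTACK,
#                     AttackType.LOGISTIC_REGRESSION))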