def _run_attack(attack_input: AttackInputData,
                attack_type: AttackType,
                balance_attacker_training: bool = True,
                min_num_samples: int = 1):
  """Dispatches one membership inference attack of the requested type.

  Args:
    attack_input: input data for running an attack
    attack_type: the attack to run
    balance_attacker_training: Whether the training and test sets for the
      membership inference attacker should have a balanced (roughly equal)
      number of samples from the training and test sets used to develop the
      model under attack.
    min_num_samples: minimum number of examples in either training or test
      data.

  Returns:
    the attack result, or None when the smaller of the training and test
    splits has fewer than `min_num_samples` examples.
  """
  attack_input.validate()

  # Skip slices that are too small on either side to attack meaningfully.
  smallest_split = min(attack_input.get_train_size(),
                       attack_input.get_test_size())
  if smallest_split < min_num_samples:
    return None

  if attack_type.is_trained_attack:
    return _run_trained_attack(attack_input, attack_type,
                               balance_attacker_training)
  elif attack_type == AttackType.THRESHOLD_ENTROPY_ATTACK:
    return _run_threshold_entropy_attack(attack_input)
  else:
    # Any remaining (non-trained, non-entropy) type falls through to the
    # loss-based threshold attack.
    return _run_threshold_attack(attack_input)
def test_get_probs_sizes(self):
  """Checks reported split sizes for a 2-example train / 1-example test input."""
  train_probs = np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]])
  test_probs = np.array([[0, 0.0001, 0.9999]])
  train_labels = np.array([1, 0])
  test_labels = np.array([0])

  attack_input = AttackInputData(
      probs_train=train_probs,
      probs_test=test_probs,
      labels_train=train_labels,
      labels_test=test_labels)

  train_size = attack_input.get_train_size()
  np.testing.assert_equal(train_size, 2)
  test_size = attack_input.get_test_size()
  np.testing.assert_equal(test_size, 1)
def _run_threshold_entropy_attack(attack_input: AttackInputData):
  """Runs a threshold attack on entropy.

  The ROC curve is computed with training examples labelled 0 and test
  examples labelled 1, scored by their entropy. The reported membership
  scores are the negated entropies, so lower entropy gives a higher score.

  Args:
    attack_input: input data for running an attack.

  Returns:
    A SingleAttackResult tagged as AttackType.THRESHOLD_ENTROPY_ATTACK.
  """
  ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
  # Fetch each split's entropy once and reuse it for both the ROC curve and
  # the membership scores, rather than calling the getters twice.
  entropy_train = attack_input.get_entropy_train()
  entropy_test = attack_input.get_entropy_test()

  fpr, tpr, thresholds = metrics.roc_curve(
      np.concatenate((np.zeros(ntrain), np.ones(ntest))),
      np.concatenate((entropy_train, entropy_test)))

  roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)

  return SingleAttackResult(
      slice_spec=_get_slice_spec(attack_input),
      data_size=DataSize(ntrain=ntrain, ntest=ntest),
      attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK,
      membership_scores_train=-entropy_train,
      membership_scores_test=-entropy_test,
      roc_curve=roc_curve)
def _run_threshold_attack(attack_input: AttackInputData):
  """Runs a threshold attack on loss.

  The ROC curve is computed with training examples labelled 0 and test
  examples labelled 1, scored by their loss, so a higher loss counts as
  evidence for the test (non-member) label. The reported membership scores
  are therefore the negated losses: lower loss gives a higher membership
  score. This matches the negated-entropy scores of the entropy attack.

  Args:
    attack_input: input data for running an attack.

  Returns:
    A SingleAttackResult tagged as AttackType.THRESHOLD_ATTACK.

  Raises:
    ValueError: if either the training or the test losses are unavailable.
  """
  ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
  loss_train = attack_input.get_loss_train()
  loss_test = attack_input.get_loss_test()
  if loss_train is None or loss_test is None:
    raise ValueError(
        'Not possible to run threshold attack without losses.')

  fpr, tpr, thresholds = metrics.roc_curve(
      np.concatenate((np.zeros(ntrain), np.ones(ntest))),
      np.concatenate((loss_train, loss_test)))

  roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)

  return SingleAttackResult(
      slice_spec=_get_slice_spec(attack_input),
      data_size=DataSize(ntrain=ntrain, ntest=ntest),
      attack_type=AttackType.THRESHOLD_ATTACK,
      # Negate so that higher scores mean "more likely a training member";
      # raw loss would rank non-members highest.
      membership_scores_train=-loss_train,
      membership_scores_test=-loss_test,
      roc_curve=roc_curve)