def test_run_attack_threshold_calculates_correct_auc(self):
    """Checks the threshold attack's AUC on a small hand-picked loss set.

    Runs ``mia.run_attack`` with ``AttackType.THRESHOLD_ATTACK`` on six
    training and six test losses. Asserts that the resulting ROC curve's AUC
    equals 0.83 to two decimal places.
    """
    # NOTE(review): presumably 0.83 ~= 30/36, the share of (train, test) loss
    # pairs in which the train loss is lower, with the two tied pairs
    # (1.3, 1.3) and (0.4, 0.4) counted as half. Confirm this against how
    # mia builds the ROC curve for the threshold attack.
    result = mia.run_attack(
        AttackInputData(
            loss_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]),
            loss_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])),
        AttackType.THRESHOLD_ATTACK)

    np.testing.assert_almost_equal(
        result.roc_curve.get_auc(), 0.83, decimal=2)

# The two tests below must sit at the same level as the test above. If they
# are indented inside its body, they become local functions that the test
# runner never collects or executes.

def test_run_attack_threshold_sets_attack_type(self):
    """Checks that a threshold-attack result records its attack type."""
    result = mia.run_attack(get_test_input(100, 100),
                            AttackType.THRESHOLD_ATTACK)

    self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)

def test_run_attack_trained_sets_attack_type(self):
    """Checks that a logistic-regression attack result records its type."""
    result = mia.run_attack(get_test_input(100, 100),
                            AttackType.LOGISTIC_REGRESSION)

    self.assertEqual(result.attack_type, AttackType.LOGISTIC_REGRESSION)