def test_minimal_pert_tuple_test(self) -> None:
        model = BasicModel_MultiLayer()
        text_inp = (
            [["abc", "zyd", "ghi"], ["abc", "uyd", "ghi"]],
            torch.ones((2, 3)) * -97,
        )

        model = BasicModel_MultiLayer()
        minimal_pert = MinParamPerturbation(
            forward_func=lambda x: model(*x),
            attack=lambda x, ind, char_val: (add_char_batch(x[0], ind, char_val), x[1]),
            arg_name="char_val",
            arg_min=0,
            arg_max=26,
            arg_step=1,
            preproc_fn=lambda x: (batch_text_preproc_fn(x[0]), x[1]),
            apply_before_preproc=True,
            correct_fn=alt_correct_fn,
        )
        expected_list = [["abc", "uzyd", "ghi"], ["abc", "uuyd", "ghi"]]

        target_inp, pert = minimal_pert.evaluate(
            text_inp,
            target=1,
            attack_kwargs={"ind": 1},
            correct_fn_kwargs={"threshold": 100},
            perturbations_per_eval=15,
        )
        self.assertEqual(pert, 5)
        self.assertListEqual(target_inp[0], expected_list)
    def test_minimal_pert_alt_correct(self) -> None:
        model = BasicModel_MultiLayer()
        text_inp = ["abc", "zyd", "ghi"]
        minimal_pert = MinParamPerturbation(
            forward_func=model,
            attack=add_char,
            arg_name="char_val",
            arg_min=0,
            arg_max=26,
            arg_step=1,
            preproc_fn=text_preproc_fn,
            apply_before_preproc=True,
            correct_fn=alt_correct_fn,
            num_attempts=5,
        )
        expected_list = ["abc", "ezyd", "ghi"]

        target_inp, pert = minimal_pert.evaluate(
            text_inp,
            target=1,
            attack_kwargs={"ind": 1},
            correct_fn_kwargs={"threshold": 1200},
            perturbations_per_eval=5,
        )
        self.assertEqual(pert, 21)
        self.assertListEqual(target_inp, expected_list)

        target_inp_single, pert_single = minimal_pert.evaluate(
            text_inp,
            target=1,
            attack_kwargs={"ind": 1},
            correct_fn_kwargs={"threshold": 1200},
        )
        self.assertEqual(pert_single, 21)
        self.assertListEqual(target_inp_single, expected_list)
 def test_minimal_pert_preproc(self) -> None:
     model = BasicModel_MultiLayer()
     text_inp = ["abc", "zyd", "ghi"]
     minimal_pert = MinParamPerturbation(
         forward_func=model,
         attack=add_char,
         arg_name="char_val",
         arg_min=0,
         arg_max=26,
         arg_step=1,
         preproc_fn=text_preproc_fn,
         apply_before_preproc=True,
     )
     target_inp, pert = minimal_pert.evaluate(
         text_inp, target=1, attack_kwargs={"ind": 1}
     )
     self.assertEqual(pert, None)
     self.assertEqual(target_inp, None)
 def test_minimal_pert_basic_linear(self) -> None:
     model = BasicModel()
     inp = torch.tensor([[2.0, -9.0, 9.0, 1.0, -3.0]])
     minimal_pert = MinParamPerturbation(
         forward_func=lambda x: model(x) + torch.tensor(
             [[0.000001, 0.0, 0.0, 0.0, 0.0]]),
         attack=inp_subtract,
         arg_name="add_arg",
         arg_min=0.0,
         arg_max=1000.0,
         arg_step=1.0,
     )
     target_inp, pert = minimal_pert.evaluate(inp,
                                              target=0,
                                              attack_kwargs={"ind": 0})
     self.assertAlmostEqual(cast(float, pert), 2.0)
     assertTensorAlmostEqual(self, target_inp,
                             torch.tensor([[0.0, -9.0, 9.0, 1.0, -3.0]]))
    def test_minimal_pert_additional_forward_args(self) -> None:
        model = BasicModel_MultiLayer()
        text_inp = [["abc", "zyd", "ghi"], ["abc", "uyd", "ghi"]]
        additional_forward_args = torch.ones((2, 3)) * -97

        model = BasicModel_MultiLayer()
        minimal_pert = MinParamPerturbation(
            forward_func=model,
            attack=add_char_batch,
            arg_name="char_val",
            arg_min=0,
            arg_max=26,
            arg_step=1,
            preproc_fn=batch_text_preproc_fn,
            apply_before_preproc=True,
            correct_fn=alt_correct_fn,
        )
        expected_list = [["abc", "uzyd", "ghi"], ["abc", "uuyd", "ghi"]]

        target_inp, pert = minimal_pert.evaluate(
            text_inp,
            target=1,
            attack_kwargs={"ind": 1},
            correct_fn_kwargs={"threshold": 100},
            perturbations_per_eval=15,
            additional_forward_args=(additional_forward_args,),
        )
        self.assertEqual(pert, 5)
        self.assertListEqual(target_inp, expected_list)

        target_inp_single, pert_single = minimal_pert.evaluate(
            text_inp,
            target=1,
            attack_kwargs={"ind": 1},
            correct_fn_kwargs={"threshold": 100},
            additional_forward_args=(additional_forward_args,),
        )
        self.assertEqual(pert_single, 5)
        self.assertListEqual(target_inp_single, expected_list)