def test_python_compatability(self):
        """
        Check that the public functions accept native Python containers.

        Scores and p-values are handed over as ``set`` and then as ``tuple``
        objects. Nothing is asserted about the results; the test only verifies
        that none of the calls raises.
        """
        for container in (set, tuple):
            # A fresh container is built for every call, as each call site
            # should receive its own object.
            bootstrap_test(container(self.scores_a), container(self.scores_b))
            permutation_test(container(self.scores_a), container(self.scores_b))
            aso(
                container(self.scores_a),
                container(self.scores_b),
                num_samples=1,
                num_bootstrap_iterations=1,
                show_progress=False,
            )
            bonferroni_correction(container(self.p_values))
    def test_bonferroni_correction(self):
        """
        Test whether the Bonferroni correction works as expected.

        Expected values are recomputed here by multiplying ascending p-values
        with the factors 5, 4, 3, 2, 1. Results are compared with a
        floating-point tolerance (``np.testing.assert_allclose``) rather than
        exact equality. The implementation may order its floating-point
        operations differently, which can change the last bits of the result.
        """
        # Multipliers expected for 5 ascending p-values.
        multipliers = np.arange(5, 0, -1)

        # 1. A single p-value is left unchanged
        p_values1 = np.random.rand(1)
        np.testing.assert_allclose(bonferroni_correction(p_values1), p_values1)

        # 2. Identical p-values
        p_values2 = np.ones(5) * np.random.rand(1) / 5
        np.testing.assert_allclose(
            bonferroni_correction(p_values2), p_values2 * multipliers
        )

        # 3. Different p-values. Sort them here already so that the
        # multiplication with the ascending-rank multipliers lines up.
        p_values3 = np.random.rand(5) / 5
        p_values3.sort()
        np.testing.assert_allclose(
            bonferroni_correction(p_values3), p_values3 * multipliers
        )

        # 4. Make sure absurdly high p-values don't get corrected to above 1
        p_values4 = np.ones(4) - 1e-4
        for p_corrected in bonferroni_correction(p_values4):
            self.assertAlmostEqual(p_corrected, 1, delta=0.01)
    def test_numpy_array_shapes(self):
        """
        Check which NumPy array shapes the functions accept.

        1-D arrays and column vectors of shape (N, 1) must be accepted.
        Inputs that are genuinely two-dimensional must raise a ``TypeError``.
        """
        # Accepted shapes: flat arrays and their (N, 1) column counterparts
        flat_a = np.array(self.scores_a)
        flat_b = np.array(self.scores_b)
        accepted_pairs = (
            (flat_a, flat_b),
            (flat_a[..., np.newaxis], flat_b[..., np.newaxis]),
        )
        significance_tests = (bootstrap_test, permutation_test, self._aso_wrapper)

        for significance_test in significance_tests:
            for first, second in accepted_pairs:
                significance_test(first, second)

        flat_p = np.array(self.p_values)
        for p_values in (flat_p, flat_p[..., np.newaxis]):
            bonferroni_correction(p_values)

        # Rejected shapes: broadcasting each 1-D sequence against a matrix of
        # ones produces two identical rows, i.e. a real 2-D input
        matrix_a = np.ones((2, 4)) * self.scores_a
        matrix_b = np.ones((2, 4)) * self.scores_b
        matrix_p = np.ones((2, 3)) * self.p_values

        for significance_test in significance_tests:
            with self.assertRaises(TypeError):
                significance_test(matrix_a, matrix_b)

        with self.assertRaises(TypeError):
            bonferroni_correction(matrix_p)
    def test_assertions(self):
        """
        Check that ``bonferroni_correction`` rejects invalid p-values.

        Each of these must raise an ``AssertionError``:
        - an empty input;
        - a value below 0;
        - a value above 1.
        """
        invalid_inputs = (
            [],
            [-0.4, 0.5],
            [0.3, 1.2],
        )
        for invalid in invalid_inputs:
            with self.assertRaises(AssertionError):
                bonferroni_correction(invalid)
    def test_pytorch_compatibility(self):
        """
        Test compatibility of functions with PyTorch tensors.

        1-D tensors and (N, 1) column tensors must be accepted. Genuinely 2-D
        tensors must raise a ``TypeError``. The test is reported as skipped,
        not silently passed, when PyTorch is not installed.
        """
        # Guard only the import itself. An ImportError raised while running the
        # functions under test is a real failure and must not be swallowed.
        try:
            import torch
        except ImportError:
            self.skipTest("PyTorch is not installed.")

        # These should work: 1-D tensors and their (N, 1) unsqueezed versions
        correct1_scores_a = torch.FloatTensor(self.scores_a)
        correct1_scores_b = torch.FloatTensor(self.scores_b)
        correct2_scores_a = correct1_scores_a.unsqueeze(dim=1)
        correct2_scores_b = correct1_scores_b.unsqueeze(dim=1)

        for test_func in [bootstrap_test, permutation_test, self._aso_wrapper]:
            for scores_a, scores_b in [
                (correct1_scores_a, correct1_scores_b),
                (correct2_scores_a, correct2_scores_b),
            ]:
                test_func(scores_a, scores_b)

        correct1_p_values = torch.FloatTensor(self.p_values)
        correct2_p_values = correct1_p_values.unsqueeze(dim=1)

        bonferroni_correction(correct1_p_values)
        bonferroni_correction(correct2_p_values)

        # These should fail: broadcasting against a matrix of ones yields
        # genuinely 2-D inputs with two rows
        incorrect_scores_a = torch.ones((2, 4)) * correct1_scores_a
        incorrect_scores_b = torch.ones((2, 4)) * correct1_scores_b
        incorrect_p_values = torch.ones((2, 3)) * correct1_p_values

        for test_func in [bootstrap_test, permutation_test, self._aso_wrapper]:
            with self.assertRaises(TypeError):
                test_func(incorrect_scores_a, incorrect_scores_b)

        with self.assertRaises(TypeError):
            bonferroni_correction(incorrect_p_values)
    def test_jax_compatibility(self):
        """
        Test compatibility of functions with Jax arrays.

        1-D arrays and (N, 1) column arrays must be accepted. Genuinely 2-D
        arrays must raise a ``TypeError``. The test is reported as skipped,
        not silently passed, when Jax is not installed.
        """
        # Guard only the import itself. An ImportError raised while running the
        # functions under test is a real failure and must not be swallowed.
        try:
            import jax.numpy as jnp
        except ImportError:
            self.skipTest("Jax is not installed.")

        # These should work: 1-D arrays and their (N, 1) column versions
        correct1_scores_a = jnp.asarray(self.scores_a)
        correct1_scores_b = jnp.asarray(self.scores_b)
        correct2_scores_a = correct1_scores_a[..., jnp.newaxis]
        correct2_scores_b = correct1_scores_b[..., jnp.newaxis]

        for test_func in [bootstrap_test, permutation_test, self._aso_wrapper]:
            for scores_a, scores_b in [
                (correct1_scores_a, correct1_scores_b),
                (correct2_scores_a, correct2_scores_b),
            ]:
                test_func(scores_a, scores_b)

        correct1_p_values = jnp.asarray(self.p_values)
        correct2_p_values = correct1_p_values[..., jnp.newaxis]

        bonferroni_correction(correct1_p_values)
        bonferroni_correction(correct2_p_values)

        # These should fail: broadcasting against a matrix of ones yields
        # genuinely 2-D inputs with two rows
        incorrect_scores_a = jnp.ones((2, 4)) * correct1_scores_a
        incorrect_scores_b = jnp.ones((2, 4)) * correct1_scores_b
        incorrect_p_values = jnp.ones((2, 3)) * correct1_p_values

        for test_func in [bootstrap_test, permutation_test, self._aso_wrapper]:
            with self.assertRaises(TypeError):
                test_func(incorrect_scores_a, incorrect_scores_b)

        with self.assertRaises(TypeError):
            bonferroni_correction(incorrect_p_values)
    def test_tensorflow_compatibility(self):
        """
        Test compatibility of functions with Tensorflow tensors.

        1-D tensors and (N, 1) column tensors must be accepted. Genuinely 2-D
        tensors must raise a ``TypeError``. The test is reported as skipped,
        not silently passed, when Tensorflow is not installed.
        """
        # Guard only the import itself. An ImportError raised while running the
        # functions under test is a real failure and must not be swallowed.
        try:
            import tensorflow as tf
        except ImportError:
            self.skipTest("Tensorflow is not installed.")

        # These should work: 1-D tensors and their (N, 1) expanded versions
        correct1_scores_a = tf.convert_to_tensor(self.scores_a)
        correct1_scores_b = tf.convert_to_tensor(self.scores_b)
        correct2_scores_a = tf.expand_dims(correct1_scores_a, axis=1)
        correct2_scores_b = tf.expand_dims(correct1_scores_b, axis=1)

        for test_func in [bootstrap_test, permutation_test, self._aso_wrapper]:
            for scores_a, scores_b in [
                (correct1_scores_a, correct1_scores_b),
                (correct2_scores_a, correct2_scores_b),
            ]:
                test_func(scores_a, scores_b)

        correct1_p_values = tf.convert_to_tensor(self.p_values)
        correct2_p_values = tf.expand_dims(correct1_p_values, axis=1)

        bonferroni_correction(correct1_p_values)
        bonferroni_correction(correct2_p_values)

        # These should fail: broadcasting against a matrix of ones yields
        # genuinely 2-D inputs with two rows
        incorrect_scores_a = tf.ones((2, 4)) * correct1_scores_a
        incorrect_scores_b = tf.ones((2, 4)) * correct1_scores_b
        incorrect_p_values = tf.ones((2, 3)) * correct1_p_values

        for test_func in [bootstrap_test, permutation_test, self._aso_wrapper]:
            with self.assertRaises(TypeError):
                test_func(incorrect_scores_a, incorrect_scores_b)

        with self.assertRaises(TypeError):
            bonferroni_correction(incorrect_p_values)