def test_incompatible_target(self) -> None:
        """Selecting a column loaded as a bad feature as the target yields no ranking.

        The data is loaded with ``bad_features=[4]`` (presumably making column 4
        unusable -- see ``_load_data``), and column 4 is then chosen as the
        target, so every output column is expected to be empty.
        """
        dataframe = self._load_data(bad_features=[4])

        primitive_code = MIRanking.metadata.query()["primitive_code"]
        hyperparams_class = primitive_code["class_type_arguments"]["Hyperparams"]
        ranker = MIRanking(
            hyperparams=hyperparams_class.defaults().replace({"target_col_index": 4})
        )
        output = ranker.produce(inputs=dataframe).value

        # each of the ranking output columns should contain no rows
        for column in ("idx", "name", "rank"):
            self.assertListEqual(list(output[column]), [])

    def test_continuous_target(self) -> None:
        """Rank the remaining features against the continuous target in column 2."""
        dataframe = self._load_data()

        primitive_code = MIRanking.metadata.query()["primitive_code"]
        hyperparams_class = primitive_code["class_type_arguments"]["Hyperparams"]
        ranker = MIRanking(
            hyperparams=hyperparams_class.defaults().replace({"target_col_index": 2})
        )
        output = ranker.produce(inputs=dataframe).value

        # features are expected in descending order of rank
        self.assertListEqual(list(output["idx"]), [1, 5, 4, 3])
        self.assertListEqual(
            list(output["name"]), ["alpha", "echo", "delta", "charlie"]
        )
        # the idx assertion above already pins the row count to 4, so zip
        # compares every rank
        expected_ranks = [1.0, 0.930536, 7.316753e-16, 0.0]
        for actual, expected in zip(output["rank"], expected_ranks):
            self.assertAlmostEqual(actual, expected, places=6)

    def test_unique_categorical_removed(self) -> None:
        """An appended all-unique column must not show up in the ranking.

        A column named ``removed`` holding the distinct values 1..9 is appended,
        column 1 (the target) is tagged as categorical, and the ranking is
        expected to contain only the original feature columns.
        """
        dataframe = self._load_data()
        dataframe.insert(dataframe.shape[1], "removed", list(range(1, 10)))
        # NOTE(review): this tags column 1 -- the target -- rather than the
        # appended "removed" column; confirm that is the intended column.
        dataframe.metadata = dataframe.metadata.add_semantic_type(
            (metadata_base.ALL_ELEMENTS, 1),
            "https://metadata.datadrivendiscovery.org/types/CategoricalData",
        )

        primitive_code = MIRanking.metadata.query()["primitive_code"]
        hyperparams_class = primitive_code["class_type_arguments"]["Hyperparams"]
        ranker = MIRanking(
            hyperparams=hyperparams_class.defaults().replace({"target_col_index": 1})
        )
        output = ranker.produce(inputs=dataframe).value

        self.assertListEqual(list(output["idx"]), [2, 5, 4, 3])
        self.assertListEqual(
            list(output["name"]), ["bravo", "echo", "delta", "charlie"]
        )

    def test_discrete_target_rank_in_metadata(self) -> None:
        """With ``return_as_metadata`` enabled, ranks are stored in column metadata.

        Column 1 is used as the target; each ranked feature column's metadata
        is expected to carry a ``rank`` entry with the listed value.
        """
        dataframe = self._load_data()

        hyperparams_class = MIRanking.metadata.query()["primitive_code"][
            "class_type_arguments"
        ]["Hyperparams"]
        hyperparams = hyperparams_class.defaults().replace(
            {"target_col_index": 1, "return_as_metadata": True}
        )
        mi_ranking = MIRanking(hyperparams=hyperparams)
        result_dataframe = mi_ranking.produce(inputs=dataframe).value

        # column index -> rank expected in that column's metadata
        expected_ranks = {2: 1.0, 5: 1.0, 4: 1.0, 3: 0.0}
        for col_idx, expected_rank in expected_ranks.items():
            with self.subTest(column=col_idx):
                rank = result_dataframe.metadata.query(
                    (metadata_base.ALL_ELEMENTS, col_idx)
                ).get("rank")
                # report a missing rank as a test failure rather than letting
                # assertAlmostEqual raise TypeError on None - float
                self.assertIsNotNone(rank, f"no rank metadata for column {col_idx}")
                self.assertAlmostEqual(rank, expected_rank, places=6)

    def test_full_mutual_info(self) -> None:
        """A feature identical to the target receives the top rank of 1.0.

        ``alpha`` is overwritten with exactly the values of the target ``bravo``
        (column 2), and ``charlie`` with a similar but different sequence.
        ``alpha`` is expected first with rank 1.0, then ``charlie``.
        """
        dataframe = self._load_data(alpha_class=float, charlie_class=float)
        dataframe["bravo"] = [1, 0.5, 10, 6.7, 6.9, 2.3, 5.5, 7.3, 9]
        dataframe["alpha"] = [1, 0.5, 10, 6.7, 6.9, 2.3, 5.5, 7.3, 9]
        dataframe["charlie"] = [2, 2.5, 9, 5.7, 7.9, 1.3, 6.5, 8.3, 8]

        hyperparams_class = MIRanking.metadata.query()["primitive_code"][
            "class_type_arguments"
        ]["Hyperparams"]
        hyperparams = hyperparams_class.defaults().replace({"target_col_index": 2})
        mi_ranking = MIRanking(hyperparams=hyperparams)
        result_dataframe = mi_ranking.produce(inputs=dataframe).value

        # verify the output
        self.assertListEqual(list(result_dataframe["idx"]), [1, 3, 4, 5])
        self.assertListEqual(
            list(result_dataframe["name"]), ["alpha", "charlie", "delta", "echo"]
        )
        expected_ranks = [1.0, 0.665361, 0.0, 0.0]
        ranks = list(result_dataframe["rank"])
        # guard the zip below so extra or missing ranks fail explicitly
        self.assertEqual(len(ranks), len(expected_ranks))
        for actual, expected in zip(ranks, expected_ranks):
            self.assertAlmostEqual(actual, expected, places=6)