def test_merge_df(self):
        """Check that merging cluster coordinates with cell metadata keeps
        only cell names present in both files."""
        file_types = ["text/csv", "text/plain", "text/tab-separated-values"]
        cluster = Clusters(
            "../tests/data/test_1k_cluster_data.csv",
            "dec0dedfeed1111111111111",
            "addedfeed000000000000000",
            "testCluster",
        )
        cell_metadata_df = Annotations(self.CELL_METADATA_PATH, file_types)
        cell_metadata_df.preprocess()

        metadata_names = np.asarray(cell_metadata_df.file["NAME"])
        cluster_names = np.asarray(cluster.file["NAME"])
        # Keep the cluster cell names that also occur in the metadata file
        in_metadata = np.isin(cluster_names, metadata_names)
        common_cell_names = cluster_names[in_metadata]
        print(f"common cell names: {common_cell_names}")

        # Merge the coordinate columns of the cluster with the metadata
        coordinates = cluster.file[["NAME", "x", "y", "z"]]
        print(coordinates)
        cluster.merge_df(coordinates, cell_metadata_df.file)

        # After the merge, every remaining cell name must be a common one
        merged_names = cluster.file["NAME"].values
        only_common = all(row[0] in common_cell_names for row in merged_names)
        self.assertTrue(
            only_common,
            f"Merge was not performed correctly. Merge should be performed on 'NAME'",
        )
    def test_low_mem_artifact(self):
        """Check that values of a mixed-content group column are all strings.

        pandas' default of low_memory=True allows internal chunking during
        parsing, which caused inconsistent dtype coercion for larger
        annotation files.
        """
        file_types = ["text/csv", "text/plain", "text/tab-separated-values"]
        lmtest = Annotations("../tests/data/low_mem_unit.txt", file_types)
        lmtest.preprocess()
        group_values = lmtest.file["mixed_data"]["group"]

        # With low_memory=True the first row landed in the first chunk and its
        # numeric value was not coerced to a string
        first_value = group_values[0]
        assert isinstance(
            first_value, str
        ), "numeric value should be coerced to string"

        # Per SCP-2545 NA values become strings for group annotations.
        na_value = group_values[2]
        print(na_value)
        print(type(na_value))
        assert isinstance(
            na_value, str
        ), "expect empty cell conversion to NaN is string for group annotation"

        # A numeric value from the second chunk must be coerced to string too
        later_value = group_values[32800]
        assert isinstance(
            later_value, str
        ), "numeric value should be coerced to string"
 def test_leading_zeros(self):
     """Check that a donor_id value keeps its leading zero after preprocessing."""
     metadata_path = "../tests/data/metadata_convention_with_leading_0s.tsv"
     file_types = ["text/csv", "text/plain", "text/tab-separated-values"]
     annotation = Annotations(metadata_path, file_types)
     annotation.preprocess()
     # Boolean mask over the first header level selecting the donor_id column
     is_donor_id = annotation.file.columns.get_level_values(0) == "donor_id"
     donor_ids = annotation.file.iloc[:, is_donor_id]
     first_donor_id = donor_ids.values.item(0)
     self.assertTrue(first_donor_id.startswith("0"))
    def test_duplicate_headers(self):
        """Check that a file with duplicate annotation headers is rejected."""
        file_types = ["text/csv", "text/plain", "text/tab-separated-values"]
        dup_headers = Annotations("../tests/data/dup_headers_v2.0.0.tsv", file_types)

        headers_are_unique = dup_headers.validate_unique_header()
        self.assertFalse(
            headers_are_unique, "Duplicate headers should fail format validation"
        )

        # preprocess() must raise ValueError for this file
        self.assertRaises(ValueError, dup_headers.preprocess)
class SubSample(Annotations):
    """Subsamples the rows of a cluster file, binned per annotation.

    The cluster file is preprocessed on construction. A cell metadata file
    may optionally be supplied; prepare_cell_metadata() then merges it with
    the cluster's coordinate and cell-name columns.
    """

    ALLOWED_FILE_TYPES = [
        "text/csv", "text/plain", "text/tab-separated-values"
    ]
    MAX_THRESHOLD = 100_000
    # Candidate sample sizes; subsample() only uses those smaller than the
    # number of rows in the file.
    SUBSAMPLE_THRESHOLDS = [MAX_THRESHOLD, 20_000, 10_000, 1_000]

    def __init__(self, cluster_file, cell_metadata_file=None):
        """Load and preprocess the cluster file.

        Args:
            cluster_file: path of the cluster file to subsample.
            cell_metadata_file: optional path of a cell metadata file. When
                omitted, self.cell_metadata is None.
        """
        super().__init__(cluster_file, self.ALLOWED_FILE_TYPES)
        self.preprocess()
        self.determine_coordinates_and_cell_names()
        # Always define the attribute: prepare_cell_metadata() tests it and
        # previously raised AttributeError when no metadata file was given.
        self.cell_metadata = None
        if cell_metadata_file is not None:
            self.cell_metadata = Annotations(cell_metadata_file,
                                             CellMetadata.ALLOWED_FILE_TYPES)

    @staticmethod
    def has_cells_in_metadata_file(metadata_cells, cluster_cells):
        """Return True when every cluster cell also occurs in metadata_cells."""
        return set(cluster_cells).issubset(set(metadata_cells))

    def prepare_cell_metadata(self):
        """Merge the cell metadata into the cluster data (intended as an
        inner join on cell and cluster file).

        Does nothing when no cell metadata file was supplied.
        """
        if self.cell_metadata is None:
            return
        self.cell_metadata.preprocess()
        self.merge_df(self.file[self.coordinates_and_cell_headers],
                      self.cell_metadata.file)
        # Re-derive the headers after the merge has been performed
        self.determine_coordinates_and_cell_names()

    def bin(self, annotation: Tuple[str, str], scope: str):
        """Creates bins for a given annotation

        Args:
            annotation: Tuple[str, str]
                This is the annotation for a single column. For example annotation
                would look like ('annotation_name', 'numeric') or ('annotation_name', 'group')
            scope: str
                When "cluster", the annotation column itself is included in
                every bin next to the coordinates and cell names.

        Returns:
            bin: Tuple[Dict[str: dataframe]], Tuple[str, str]]
                The first value holds all the bins for the annotation. For
                group annotations it looks like
                {'unique_value1': dataframe of rows equal to unique_value1};
                for numeric annotations the rows are sorted by the annotation
                and split into 20 consecutive bins keyed '0'..'19' (bins may
                be empty when there are fewer than 20 rows).
                The second value is the annotation exactly as passed in.
        """
        bins = {}
        # sample the annotation along with coordinates and cell names
        columns_to_sample = copy.copy(self.coordinates_and_cell_headers)
        if scope == "cluster":
            columns_to_sample.append(annotation[0])

        # Test the type slot only; a membership test on the whole tuple would
        # also match a numeric annotation that happens to be named "group".
        if annotation[1] == "group":
            annotation_values = self.file[annotation]
            for col_val in annotation_values.unique():
                # get subset of data where row is equal to the unique value
                subset = self.file[annotation_values == col_val]
                bins[col_val] = subset[columns_to_sample]
        else:
            # coordinates, cell names and annotation name
            columns = copy.copy(self.coordinates_and_cell_headers)
            columns.append(annotation[0])
            # Subset of df where header is [cell_names, x, y, z, <annot_name>]
            ordered = self.file[columns].sort_values(by=[annotation])
            # Split row positions rather than the DataFrame itself:
            # np.array_split on a DataFrame depends on DataFrame.swapaxes,
            # which is deprecated in recent pandas. The partition is the same.
            row_positions = np.arange(len(ordered))
            for index, positions in enumerate(np.array_split(row_positions, 20)):
                bins[str(index)] = ordered.iloc[positions][columns_to_sample]
        return bins, annotation

    def subsample(self, scope):
        """Subsamples groups across a given file

        For every annotation column and every threshold in
        SUBSAMPLE_THRESHOLDS that is smaller than the file, rows are drawn
        from each bin of the annotation so that the total approaches the
        sample size.

        Args:
            scope: passed on to bin(); when "cluster" the annotation values
                are collected together with coordinates and cell names.

        Yields:
            Tuple of (points, annotation_name, sample_size) where points maps
            each sampled column name to the list of sampled values.
        """
        total_rows = len(self.file.index)
        sample_sizes = [
            sample_size for sample_size in self.SUBSAMPLE_THRESHOLDS
            if sample_size < total_rows
        ]
        # Bin every annotation before yielding anything, so a binning error
        # surfaces before the caller has consumed partial results.
        binned_annotations = [
            self.bin(col, scope) for col in self.annot_column_headers
        ]
        # annotation_dict looks like
        # {"Unique value #1" : dataframe, "Unique value #2": dataframe,...}
        for annotation_dict, annotation_name in binned_annotations:
            amount_of_bins = len(annotation_dict)
            for sample_size in sample_sizes:
                group_size = amount_of_bins
                # Dict of values for the x, y, and z coordinates
                points = {k: [] for k in self.coordinates_and_cell_headers}
                if scope == "cluster":
                    points[annotation_name[0]] = []
                num_per_group = sample_size // group_size
                cells_left = sample_size
                sorted_bins = self.return_sorted_bin(annotation_dict,
                                                     annotation_name)
                for idx, (_, bin_df) in enumerate(sorted_bins):
                    # Never ask for more rows than the bin holds
                    amount_picked_rows = min(num_per_group, len(bin_df.index))
                    shuffled_df = (bin_df.reindex(
                        np.random.permutation(
                            bin_df.index)).sample(n=amount_picked_rows))
                    for column in shuffled_df:
                        # column is a (name, type) header; points is keyed by name
                        points[column[0]].extend(
                            shuffled_df[column].values.tolist())
                    # Subtract number of cells 'subsampled' from the number of cells left
                    cells_left -= amount_picked_rows
                    if idx == (amount_of_bins - 2):
                        # The second-to-last bin was just handled (idx is
                        # 0-based): the last bin gets all cells still needed.
                        num_per_group = cells_left
                    else:
                        # Spread the remaining cells over the remaining bins
                        group_size -= 1
                        if group_size > 1:
                            num_per_group = cells_left // group_size
                # returns tuple = (subsampled values as dictionary, annotation name, sample size )
                yield (points, annotation_name, sample_size)

    def return_sorted_bin(self, bin, annot_name):
        """Return the (key, dataframe) items of bin.

        For group annotations the items are sorted by bin size from smallest
        to largest; otherwise they are returned in their existing order.
        """
        if "group" in annot_name:
            return sorted(bin.items(), key=lambda item: len(item[1]))
        return bin.items()

    def set_data_array(self, args, kwargs):
        """Forward to Clusters.set_data_array.

        Args:
            args: sequence unpacked as the positional arguments.
            kwargs: mapping unpacked as the keyword arguments.
        """
        return Clusters.set_data_array(*args, **kwargs)
class TestAnnotations(unittest.TestCase):
    """Unit tests for the Annotations class."""

    # Cluster file loaded into self.df by setUp()
    CLUSTER_PATH = "../tests/data/test_1k_cluster_data.csv"
    # Cell metadata file merged with a cluster file in test_merge_df
    CELL_METADATA_PATH = "data/annotation/metadata/convention/valid_no_array_v2.0.0.txt"

    # File types handed to Annotations in test_validate_numeric_annots
    ALLOWED_FILE_TYPES = ["text/csv", "text/plain", "text/tab-separated-values"]

    # Decimal exponent bound compared against in test_coerce_numeric_values
    EXPONENT = -3

    def setUp(self):
        """Build an Annotations object for the cluster file as self.df."""
        file_types = ["text/csv", "text/plain", "text/tab-separated-values"]
        self.df = Annotations(self.CLUSTER_PATH, file_types)

    def test_create_columns(self):
        """Check that create_columns pairs each header with its annotation type."""
        headers = ["Intensity", "donor_id", "species__ontology_label"]
        annotation_types = ["numeric", "group", "group"]
        columns = Annotations.create_columns(headers, annotation_types)
        # Expected: one (header, type) tuple per column, in order
        expected = list(zip(headers, annotation_types))
        self.assertEqual(columns, expected)

    def test_duplicate_headers(self):
        """Check that a file with duplicate annotation headers is rejected."""
        file_types = ["text/csv", "text/plain", "text/tab-separated-values"]
        dup_headers = Annotations("../tests/data/dup_headers_v2.0.0.tsv", file_types)

        headers_are_unique = dup_headers.validate_unique_header()
        self.assertFalse(
            headers_are_unique, "Duplicate headers should fail format validation"
        )

        # preprocess() must raise ValueError for this file
        self.assertRaises(ValueError, dup_headers.preprocess)

    def test_get_dtypes_for_group_annots(self):
        """Check that the NAME column and group annotations map to the str
        dtype while numeric annotations are left out."""
        headers = ["NAME", "cell_type", "organism_age"]
        annot_types = ["TYPE", "group", "numeric"]
        # np.str was only an alias of the builtin str and was removed in
        # NumPy 1.24; the builtin compares equal wherever np.str still exists.
        expected_dtypes = {"NAME": str, "cell_type": str}
        dtypes = Annotations.get_dtypes_for_group_annots(headers, annot_types)
        self.assertEqual(expected_dtypes, dtypes)

    def test_convert_header_to_multiIndex(self):
        """Check that convert_header_to_multi_index installs the given
        (name, type) tuples as the dataframe's columns."""
        expected = [
            ("Name", "TYPE"),
            ("X", "numeric"),
            ("Y", "numeric"),
            ("Z", "numeric"),
            ("Average Intensity", "numeric"),
        ]
        path = "../tests/data/good_subsample_cluster.csv"
        file_types = ["text/csv", "text/plain", "text/tab-separated-values"]
        annotation = Annotations(path, file_types)
        # open_file returns a sequence; the dataframe is its first element
        opened = annotation.open_file(
            path, open_as="dataframe", skiprows=2, names=annotation.headers
        )
        new_df = Annotations.convert_header_to_multi_index(opened[0], expected)
        # Remove white spaces around every level of every column label
        stripped_columns = []
        for column in new_df.columns:
            stripped_columns.append(tuple(level.strip() for level in column))
        self.assertEqual(stripped_columns, expected)

    def test_leading_zeros(self):
        """Check that a donor_id value keeps its leading zero after preprocessing."""
        metadata_path = "../tests/data/metadata_convention_with_leading_0s.tsv"
        file_types = ["text/csv", "text/plain", "text/tab-separated-values"]
        annotation = Annotations(metadata_path, file_types)
        annotation.preprocess()
        # Boolean mask over the first header level selecting the donor_id column
        is_donor_id = annotation.file.columns.get_level_values(0) == "donor_id"
        donor_ids = annotation.file.iloc[:, is_donor_id]
        first_donor_id = donor_ids.values.item(0)
        self.assertTrue(first_donor_id.startswith("0"))

    def test_header_format(self):
        """Check that each header validator rejects a file with malformed
        header rows."""
        file_types = ["text/csv", "text/plain", "text/tab-separated-values"]
        error_headers = Annotations(
            "../tests/data/error_headers_v2.0.0.tsv", file_types
        )

        # (validator, failure message) pairs, run in this order
        checks = [
            (
                error_headers.validate_header_keyword,
                "Missing NAME keyword should fail format validation",
            ),
            (
                error_headers.validate_type_keyword,
                "Missing TYPE keyword should fail format validation",
            ),
            (
                error_headers.validate_type_annotations,
                "Invalid type annotations should fail format validation",
            ),
        ]
        for validate, message in checks:
            self.assertFalse(validate(), message)

    def test_low_mem_artifact(self):
        """Check that values of a mixed-content group column are all strings.

        pandas' default of low_memory=True allows internal chunking during
        parsing, which caused inconsistent dtype coercion for larger
        annotation files.
        """
        lmtest = Annotations(
            "../tests/data/low_mem_unit.txt",
            ["text/csv", "text/plain", "text/tab-separated-values"],
        )
        lmtest.preprocess()
        group_values = lmtest.file["mixed_data"]["group"]

        # unittest assertions are used instead of bare asserts so the checks
        # are not stripped under `python -O` and match the rest of the class.

        # when low memory=True, the first row in the file would be in the first chunk
        # and the numeric value was not properly coerced to become a string
        self.assertIsInstance(
            group_values[0], str, "numeric value should be coerced to string"
        )

        # Per SCP-2545 NA values become strings for group annotations.
        print(group_values[2])
        print(type(group_values[2]))
        self.assertIsInstance(
            group_values[2],
            str,
            "expect empty cell conversion to NaN is string for group annotation",
        )

        # numeric value in second chunk should still properly be coerced to string type
        self.assertIsInstance(
            group_values[32800], str, "numeric value should be coerced to string"
        )

    def test_coerce_numeric_values(self):
        """Check coerce_numeric_values: numeric columns become floats rounded
        to at most 3 decimal places, and a string in a numeric column raises
        ValueError."""
        cm = Annotations(
            "../tests/data/metadata_example.txt",
            ["text/csv", "text/plain", "text/tab-separated-values"],
        )
        cm.create_data_frame()
        cm.file = Annotations.coerce_numeric_values(cm.file, cm.annot_types)
        dtype = cm.file.dtypes[("Average Intensity", "numeric")]
        # np.float was only an alias of the builtin float and was removed in
        # NumPy 1.24; the builtin compares equal to the float64 dtype.
        self.assertEqual(dtype, float)

        # Test that numeric values were rounded to 3 or fewer decimal places.
        # Spot-check one random row; the position is drawn from the actual row
        # count so it can never point past the end of the file.
        row_position = random.randrange(len(cm.file.index))
        for column in cm.file.columns:
            annot_type = column[1]
            if annot_type != "numeric":
                continue
            value = str(cm.file[column].iloc[row_position])
            exponent = Decimal(value).as_tuple().exponent
            print(exponent)
            # Compare the signed exponent: wrapping it in abs() made the
            # check against the negative bound pass for every value.
            self.assertGreaterEqual(
                exponent,
                self.EXPONENT,
                "Numbers did not round to 3 or less decimals places",
            )

        # Test for string in numeric column
        cm_has_bad_value = Annotations(
            "../tests/data/metadata_bad_contains_str_in_numeric_column.txt",
            ["text/csv", "text/plain", "text/tab-separated-values"],
        )
        cm_has_bad_value.create_data_frame()
        self.assertRaises(
            ValueError,
            Annotations.coerce_numeric_values,
            cm_has_bad_value.file,
            cm_has_bad_value.annot_types,
        )

    def test_group_annotations(self):
        """Check that header labels are strings and that no group annotation
        column has a numeric dtype after preprocessing."""
        self.df.preprocess()
        for column in self.df.file.columns:
            header, annot_type = column[0], column[1]
            # Ensure labels are strings
            assert isinstance(header, str)
            if annot_type != "group":
                continue
            # issubdtype is used because comparing the dtype with != always
            # returns True
            column_is_numeric = np.issubdtype(
                self.df.file[column].dtypes, np.number
            )
            self.assertFalse(
                column_is_numeric, "Group annotations must be string values"
            )

    def test_merge_df(self):
        """Check that merging cluster coordinates with cell metadata keeps
        only cell names present in both files."""
        file_types = ["text/csv", "text/plain", "text/tab-separated-values"]
        cluster = Clusters(
            "../tests/data/test_1k_cluster_data.csv",
            "dec0dedfeed1111111111111",
            "addedfeed000000000000000",
            "testCluster",
        )
        cell_metadata_df = Annotations(self.CELL_METADATA_PATH, file_types)
        cell_metadata_df.preprocess()

        metadata_names = np.asarray(cell_metadata_df.file["NAME"])
        cluster_names = np.asarray(cluster.file["NAME"])
        # Keep the cluster cell names that also occur in the metadata file
        in_metadata = np.isin(cluster_names, metadata_names)
        common_cell_names = cluster_names[in_metadata]
        print(f"common cell names: {common_cell_names}")

        # Merge the coordinate columns of the cluster with the metadata
        coordinates = cluster.file[["NAME", "x", "y", "z"]]
        print(coordinates)
        cluster.merge_df(coordinates, cell_metadata_df.file)

        # After the merge, every remaining cell name must be a common one
        merged_names = cluster.file["NAME"].values
        only_common = all(row[0] in common_cell_names for row in merged_names)
        self.assertTrue(
            only_common,
            f"Merge was not performed correctly. Merge should be performed on 'NAME'",
        )

    def test_validate_numeric_annots(self):
        """Load a cluster file with a missing coordinate and reference its
        numeric-annotation validator."""
        cluster = Annotations(
            "../tests/data/cluster_bad_missing_coordinate.txt",
            TestAnnotations.ALLOWED_FILE_TYPES,
        )
        cluster.create_data_frame()
        # NOTE(review): validate_numeric_annots is referenced here but not
        # called. Unless it is a property, this asserts the truthiness of the
        # method object and passes regardless of the file's content.
        # Presumably the assertion should be on validate_numeric_annots()'s
        # return value -- TODO confirm the expected result for this file.
        self.assertTrue(cluster.validate_numeric_annots)

    def test_get_cell_names(self):
        """Check that get_cell_names returns the 20 expected cell names,
        leading spaces included."""
        import pandas as pd

        # "CELL_0001", then 0002-0009 prefixed with two spaces and
        # 00010-00020 prefixed with one space
        expected_cell_names = (
            ["CELL_0001"]
            + [f"  CELL_000{number}" for number in range(2, 10)]
            + [f" CELL_000{number}" for number in range(10, 21)]
        )
        header = pd.MultiIndex.from_tuples(
            [
                ("NAME", "TYPE"),
                ("Cluster", "group"),
                ("Sub-Cluster", "group"),
                ("Average Intensity", "numeric"),
            ]
        )

        # Skip the two header rows of the file and use the MultiIndex instead
        metadata = pd.read_csv(
            "../tests/data/metadata_example.txt", sep="\t", names=header, skiprows=2
        )
        cell_names = Annotations.get_cell_names(metadata)
        self.assertEqual(cell_names, expected_cell_names)