Exemplo n.º 1
0
    def test_hdf5_load(self):
        hdf5_summary = "celltype_summary.hdf5"
        orig_ast = Astir(self.expr, self.marker_dict)
        orig_ast.fit_type(max_epochs=5, n_init=1, n_init_epochs=1)
        orig_ast.fit_state(max_epochs=5, n_init=1, n_init_epochs=1)
        orig_ast.save_models(hdf5_summary)
        new_ast = Astir()
        new_ast.load_model(hdf5_summary)

        orig_type_run_info = orig_ast.get_type_run_info()
        orig_state_run_info = orig_ast.get_state_run_info()
        new_type_run_info = new_ast.get_type_run_info()
        new_state_run_info = new_ast.get_state_run_info()
        for key, val in orig_type_run_info.items():
            if val != new_type_run_info[key]:
                raise AssertionError(
                    "variable " + key +
                    " is different in original model and loaded model")
        for key, val in orig_state_run_info.items():
            if val != new_state_run_info[key]:
                raise AssertionError(
                    "variable " + key +
                    " is different in original model and loaded model")

        orig_type_losses = orig_ast.get_type_losses()
        orig_state_losses = orig_ast.get_state_losses()
        new_type_losses = new_ast.get_type_losses()
        new_state_losses = new_ast.get_state_losses()
        if not (all(orig_type_losses == new_type_losses)
                and all(orig_state_losses == new_state_losses)):
            raise AssertionError(
                "loss is different in original model and loaded model")
Exemplo n.º 2
0
class TestAstir(TestCase):
    def __init__(self, *args, **kwargs):
        super(TestAstir, self).__init__(*args, **kwargs)
        warnings.filterwarnings("ignore", category=UserWarning)
        self.expr_csv_file = os.path.join(os.path.dirname(__file__),
                                          "test-data/test_data.csv")
        self.marker_yaml_file = os.path.join(
            os.path.dirname(__file__), "test-data/jackson-2020-markers.yml")
        self.design_file = os.path.join(os.path.dirname(__file__),
                                        "test-data/design.csv")
        self.test_dir = os.path.join(os.path.dirname(__file__),
                                     "test-data/test-dir-read")
        self.adata_file = os.path.join(os.path.dirname(__file__),
                                       "test-data/adata_small.h5ad")

        self.expr = pd.read_csv(self.expr_csv_file, index_col=0)
        with open(self.marker_yaml_file, "r") as stream:
            self.marker_dict = yaml.safe_load(stream)

        self.a = Astir(self.expr, self.marker_dict)
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

    def test_basic_instance_creation(self):
        """Tests basic instance creation"""

        self.assertIsInstance(self.a, Astir)
        # self.assertTrue(isinstance(a, str))

    def test_csv_reading(self):
        """Test from_csv_yaml function"""
        a = from_csv_yaml(self.expr_csv_file, self.marker_yaml_file)

        self.assertIsInstance(a, Astir)

    def test_dir_reading(self):

        a = from_csv_dir_yaml(self.test_dir, self.marker_yaml_file)

        self.assertIsInstance(a, Astir)

        ## Make sure the design matrix has been constructed correctly
        self.assertTrue(a._type_dset._design.shape[0] == len(a._type_dset))
        files = os.listdir(self.test_dir)
        files = [f for f in files if f.endswith(".csv")]
        self.assertTrue(a._type_dset._design.shape[1] == len(files))

    def test_csv_reading_with_design(self):

        a = from_csv_yaml(self.expr_csv_file,
                          self.marker_yaml_file,
                          design_csv=self.design_file)

        self.assertIsInstance(a, Astir)

    def test_fitting_type(self):

        epochs = 2
        with open(os.devnull, "w") as devnull:
            with contextlib.redirect_stdout(devnull):
                self.a.fit_type(max_epochs=epochs)

        # Check probability matrix looks ok
        probabilities = self.a.get_celltype_probabilities()

        self.assertTrue(probabilities.shape[0] == self.expr.shape[0])
        self.assertIsInstance(probabilities, pd.DataFrame)

        # Check assignments look ok
        assignments = self.a.get_celltypes()
        self.assertIsInstance(assignments, pd.DataFrame)
        self.assertTrue(assignments.shape[0] == self.expr.shape[0])
        self.assertTrue(assignments.columns[0] == "cell_type")

        # Check diagnostics look ok
        type_diagnostics = self.a.diagnostics_celltype(threshold=0.2, alpha=0)
        self.assertIsInstance(type_diagnostics, pd.DataFrame)
        self.assertTrue(type_diagnostics.shape[1] ==
                        7)  # make sure we have the standard 6 columns

    def test_no_overlap(self):
        bad_file = os.path.join(os.path.dirname(__file__),
                                "test-data/bad_data.csv")
        bad_data = pd.read_csv(bad_file)
        raised = False
        try:
            test = Astir(bad_data, self.marker_dict)
        except (RuntimeError):
            raised = True
        self.assertTrue(raised == True)

    def test_missing_marker(self):
        bad_marker = os.path.join(os.path.dirname(__file__),
                                  "test-data/bad_marker.yml")
        with open(bad_marker, "r") as stream:
            bad_dict = yaml.safe_load(stream)
        raised = False
        try:
            test = Astir(self.expr, bad_dict)
            test.fit_type()
        except (RuntimeError):
            raised = True
        self.assertTrue(raised)

    # # Uncomment below test functions to test private variables
    # # Commented it out because these tests can be highly overlapping with
    # # future unittests

    def test_sanitize_dict_state(self):
        """Testing the method _sanitize_dict"""
        expected_state_dict = self.marker_dict["cell_states"]
        (_, actual_state_dict, _) = self.a._sanitize_dict(self.marker_dict)

        expected_state_dict = {
            i: sorted(j) if isinstance(j, list) else j
            for i, j in expected_state_dict.items()
        }
        actual_state_dict = {
            i: sorted(j) if isinstance(j, list) else j
            for i, j in actual_state_dict.items()
        }

        self.assertDictEqual(
            expected_state_dict,
            actual_state_dict,
            "state_dict is different from its expected value",
        )

    def test_state_names(self):
        """Test _state_names field"""
        expected_state_names = sorted(self.marker_dict["cell_states"].keys())
        (_, actual_state_dict, _) = self.a._sanitize_dict(self.marker_dict)
        actual_state_names = sorted(actual_state_dict.keys())

        self.assertListEqual(expected_state_names, actual_state_names,
                             "unexpected state_names value")

    def test_celltype_same_seed_same_result(self):
        """Test whether the loss after one epoch one two different models
        with the same random seed have the same losses after one epochs
        """
        warnings.filterwarnings("ignore", category=UserWarning)
        model1 = Astir(
            input_expr=self.expr,
            marker_dict=self.marker_dict,
            design=None,
            random_seed=42,
        )
        model2 = Astir(
            input_expr=self.expr,
            marker_dict=self.marker_dict,
            design=None,
            random_seed=42,
        )

        model1.fit_type(max_epochs=10)
        model1_loss = model1.get_type_losses()
        model2.fit_type(max_epochs=10)
        model2_loss = model2.get_type_losses()

        self.assertTrue(np.abs(model1_loss - model2_loss)[-1] < 1e-6)

    def test_celltype_diff_seed_diff_result(self):
        """Test whether the loss after one epoch one two different models
        with the different random seed have different losses after one epoch
        """
        warnings.filterwarnings("ignore", category=UserWarning)
        model1 = Astir(
            input_expr=self.expr,
            marker_dict=self.marker_dict,
            design=None,
            random_seed=42,
        )
        model2 = Astir(
            input_expr=self.expr,
            marker_dict=self.marker_dict,
            design=None,
            random_seed=1234,
        )

        model1.fit_type(max_epochs=10)
        model1_loss = model1.get_type_losses()
        model2.fit_type(max_epochs=10)
        model2_loss = model2.get_type_losses()

        self.assertFalse(np.abs(model1_loss - model2_loss)[-1] < 1e-6)

    def test_cellstate_same_seed_same_result(self):
        """Test whether the loss after one epoch one two different models
        with the same random seed have the same losses after one epochs
        """
        warnings.filterwarnings("ignore", category=UserWarning)
        model1 = Astir(
            input_expr=self.expr,
            marker_dict=self.marker_dict,
            design=None,
            random_seed=42,
        )
        model2 = Astir(
            input_expr=self.expr,
            marker_dict=self.marker_dict,
            design=None,
            random_seed=42,
        )

        model1.fit_state(max_epochs=5)
        model1_loss = model1.get_state_losses()
        model2.fit_state(max_epochs=5)
        model2_loss = model2.get_state_losses()

        self.assertTrue(np.abs(model1_loss - model2_loss)[-1] < 1e-6)

    # @pytest.mark.filterwarnings("ignore")
    def test_cellstate_diff_seed_diff_result(self):
        """Test whether the loss after one epoch one two different models
        with the different random seed have different losses after one epoch
        """
        warnings.filterwarnings("ignore", category=UserWarning)
        model1 = Astir(
            input_expr=self.expr,
            marker_dict=self.marker_dict,
            design=None,
            random_seed=42,
        )
        model2 = Astir(
            input_expr=self.expr,
            marker_dict=self.marker_dict,
            design=None,
            random_seed=1234,
        )

        model1.fit_state(max_epochs=5)
        model1_loss = model1.get_state_losses()
        model2.fit_state(max_epochs=5)
        model2_loss = model2.get_state_losses()

        self.assertFalse(np.abs(model1_loss - model2_loss)[-1] < 1e-6)

    def test_cellstate_assignment(self):
        warnings.filterwarnings("ignore", category=UserWarning)
        self.a.fit_state(max_epochs=50, n_init=1)

        state_assignments = self.a.get_cellstates()

        n_classes = len(list(self.marker_dict.keys()))

        self.assertTrue(state_assignments.shape, (len(self.expr), n_classes))

    def test_cellstate_predicted_assignment(self):
        warnings.filterwarnings("ignore", category=UserWarning)
        dset = SCDataset(
            expr_input=self.expr,
            marker_dict=self.marker_dict["cell_states"],
            design=None,
            include_other_column=False,
            device=self._device,
        )

        self.a.fit_state(max_epochs=50, n_init=1)

        state_assignments = self.a.predict_cellstates(dset)

        self.assertTrue(state_assignments.shape,
                        (len(dset), dset.get_n_classes()))

    def test_celltype_assignment(self):
        warnings.filterwarnings("ignore", category=UserWarning)
        self.a.fit_type(max_epochs=50, n_init=1)

        type_assignments = self.a.get_celltypes()

        n_classes = len(list(self.marker_dict.keys()))

        self.assertTrue(type_assignments.shape,
                        (len(self.expr), n_classes + 1))

    def test_celltype_predicted_assignment(self):
        warnings.filterwarnings("ignore", category=UserWarning)

        self.a.fit_type(max_epochs=50, n_init=1)

        type_predict = self.a.predict_celltypes()
        type_assignment = self.a.get_celltype_probabilities()
        comp = type_predict == type_assignment
        self.assertTrue(comp.all().all())

    # def test_adata_reading(self):
    #     ast = from_anndata_yaml(
    #         self.adata_file,
    #         self.marker_yaml_file,
    #         protein_name="protein",
    #         cell_name="cell_name",
    #         batch_name="batch",
    #     )

    #     self.assertTrue(ast.get_type_dataset().get_n_features() == 14)
    #     self.assertTrue(ast.get_type_dataset().get_n_classes() == 6)
    #     self.assertIsInstance(ast.get_type_dataset().get_exprs(), torch.Tensor)
    #     self.assertEqual(ast.get_type_dataset().get_exprs().shape[0], 10)

    def test_cellstate_diagnostics(self):
        warnings.filterwarnings("ignore", category=UserWarning)
        self.a.fit_state(max_epochs=50, n_init=1)

        state_diagnostics = self.a.diagnostics_cellstate()
        self.assertIsInstance(state_diagnostics, pd.DataFrame)

    def test_type_hdf5_summary(self):
        hdf5_summary = "celltype_summary.hdf5"
        info = {
            "max_epochs": 5,
            "learning_rate": 0.001,
            "batch_size": 24,
            "delta_loss": 0.001,
            "n_init": 1,
            "n_init_epochs": 1,
        }
        self.a.fit_type(
            max_epochs=info["max_epochs"],
            learning_rate=info["learning_rate"],
            batch_size=info["batch_size"],
            delta_loss=info["delta_loss"],
            n_init=info["n_init"],
            n_init_epochs=info["n_init_epochs"],
        )
        self.a.save_models(hdf5_summary)
        params = list(self.a.get_type_model().get_data().items()) + list(
            self.a.get_type_model().get_variables().items())

        recog_params = []
        for key, val in self.a.get_type_model().get_recognet(
        ).named_parameters():
            recog_params.append((key, val.detach().cpu().numpy()))
        same = True
        with h5py.File(hdf5_summary, "r") as f:
            f_params = f["/celltype_model/parameters"]
            for key, val in params:
                if not (val.detach().cpu().numpy()
                        == f_params[key][()]).all().all():
                    same = False
            f_recog = f["/celltype_model/recog_net"]
            for key, val in recog_params:
                if not (f_recog[key][()] == val).all():
                    same = False
            f_info = f["/celltype_model/run_info"]
            for key, val in info.items():
                if val != f_info[key][()]:
                    same = False
            if not (self.a.get_type_model().get_losses().cpu().numpy()
                    == f["/celltype_model/losses"]["losses"][()]).all():
                same = False
        self.assertTrue(same)

    def test_state_summary(self):
        hdf5_summary = "cellstate_summary.hdf5"
        info = {
            "max_epochs": 5,
            "learning_rate": 0.001,
            "batch_size": 24,
            "delta_loss": 0.001,
            "n_init": 1,
            "n_init_epochs": 1,
            "delta_loss_batch": 2,
        }
        self.a.fit_state(
            max_epochs=info["max_epochs"],
            learning_rate=info["learning_rate"],
            batch_size=info["batch_size"],
            delta_loss=info["delta_loss"],
            n_init=info["n_init"],
            n_init_epochs=info["n_init_epochs"],
            delta_loss_batch=info["delta_loss_batch"],
        )
        self.a.save_models(hdf5_summary)
        params = list(self.a.get_state_model().get_data().items()) + list(
            self.a.get_state_model().get_variables().items())
        recog_params = []
        for key, val in self.a.get_state_model().get_recognet(
        ).named_parameters():
            recog_params.append((key, val.detach().cpu().numpy()))
        same = True
        with h5py.File(hdf5_summary, "r") as f:
            f_params = f["/cellstate_model/parameters"]
            for key, val in params:
                if not (val.detach().cpu().numpy()
                        == f_params[key][()]).all().all():
                    same = False
            f_recog = f["/cellstate_model/recog_net"]
            for key, val in recog_params:
                if not (f_recog[key][()] == val).all():
                    same = False
            f_info = f["/cellstate_model/run_info"]
            for key, val in info.items():
                if val != f_info[key][()]:
                    same = False
            if not (self.a.get_state_model().get_losses().cpu().numpy()
                    == f["/cellstate_model/losses"]["losses"][()]).all():
                same = False
        self.assertTrue(same)

    def test_hierarchy_assignment(self):
        self.a.fit_type(max_epochs=5, n_init=1, n_init_epochs=1)
        original_assignment = self.a.get_celltype_probabilities()
        hier_dict = self.a.get_hierarchy_dict()
        # expected_assignment = pd.DataFrame()
        # for key, cells in hier_dict.items():
        #     expected_assignment[key] = original_assignment[cells].sum(axis=1)
        actual_assignment = self.a.assign_celltype_hierarchy(depth=3)
        # self.assertTrue((original_assignment == actual_assignment).all().all())
        for cell in actual_assignment.columns:
            self.assertTrue(
                (actual_assignment[cell] == original_assignment[cell]).all())

    def test_hdf5_load(self):
        hdf5_summary = "celltype_summary.hdf5"
        orig_ast = Astir(self.expr, self.marker_dict)
        orig_ast.fit_type(max_epochs=5, n_init=1, n_init_epochs=1)
        orig_ast.fit_state(max_epochs=5, n_init=1, n_init_epochs=1)
        orig_ast.save_models(hdf5_summary)
        new_ast = Astir()
        new_ast.load_model(hdf5_summary)

        orig_type_run_info = orig_ast.get_type_run_info()
        orig_state_run_info = orig_ast.get_state_run_info()
        new_type_run_info = new_ast.get_type_run_info()
        new_state_run_info = new_ast.get_state_run_info()
        for key, val in orig_type_run_info.items():
            if val != new_type_run_info[key]:
                raise AssertionError(
                    "variable " + key +
                    " is different in original model and loaded model")
        for key, val in orig_state_run_info.items():
            if val != new_state_run_info[key]:
                raise AssertionError(
                    "variable " + key +
                    " is different in original model and loaded model")

        orig_type_losses = orig_ast.get_type_losses()
        orig_state_losses = orig_ast.get_state_losses()
        new_type_losses = new_ast.get_type_losses()
        new_state_losses = new_ast.get_state_losses()
        if not (all(orig_type_losses == new_type_losses)
                and all(orig_state_losses == new_state_losses)):
            raise AssertionError(
                "loss is different in original model and loaded model")