Example #1
0
    def test_svr_invalid_weight_name_no_raise_fit(self, adata_cflare: AnnData):
        """With ``ignore_raise=True`` an invalid weight name only fails at ``fit``.

        Construction and ``prepare`` succeed, but fitting raises ``TypeError``
        (presumably because the estimator's ``fit`` rejects the unknown ``w``
        keyword -- verify against ``SKLearnModel.fit``).
        """
        wrapper = SKLearnModel(
            adata_cflare, SVR(), weight_name="w", ignore_raise=True
        )
        first_gene = adata_cflare.var_names[0]
        prepared = wrapper.prepare(first_gene, "0")

        with pytest.raises(TypeError):
            prepared.fit()
Example #2
0
    def test_svr_correct_no_weights(self, adata_cflare: AnnData):
        """An empty ``weight_name`` gives different predictions than the default.

        Two SVR-backed models are fitted for the first gene and lineage ``"0"``:
        one with ``weight_name=""`` and one with the default weight name, which
        resolves to ``"sample_weight"``. Their predictions must not be close.
        """

        def _fit_first_gene(**model_kwargs):
            # Build, prepare and fit a fresh SVR-backed model for the first gene.
            unfitted = SKLearnModel(adata_cflare, SVR(), **model_kwargs)
            return unfitted.prepare(adata_cflare.var_names[0], "0").fit()

        unweighted = _fit_first_gene(weight_name="")
        weighted = _fit_first_gene()

        assert unweighted._weight_name == ""
        assert weighted._weight_name == "sample_weight"

        assert not np.allclose(unweighted.predict(), weighted.predict())
Example #3
0
    def maybe_sanity_check(callbacks: Dict[str, Dict[str, Callable]]) -> None:
        """Validate that each callback prepares the model it is given, in place.

        Does nothing unless ``perform_sanity_check`` (from the enclosing scope) is
        truthy. For every ``(gene, lineage)`` pair, a fresh ``SKLearnModel`` wrapping
        an ``SVR`` is built on ``adata`` and passed to the callback together with
        ``gene``, ``lineage`` and ``kwargs`` (both from the enclosing scope). The
        callback must return that same object, prepared, with gene and lineage
        left unchanged.

        Parameters
        ----------
        callbacks
            Nested mapping ``gene -> lineage -> callback``.

        Raises
        ------
        RuntimeError
            If a callback raises or its result fails any check; the underlying
            error is chained as ``__cause__``.
        """
        if not perform_sanity_check:
            return

        from sklearn.svm import SVR

        logg.debug("Performing callback sanity checks")
        for gene, lineage_callbacks in callbacks.items():
            for lineage, cb in lineage_callbacks.items():
                # create the model here because the callback can search the attribute
                dummy_model = SKLearnModel(adata, model=SVR())
                try:
                    model = cb(dummy_model, gene=gene, lineage=lineage, **kwargs)
                    # Explicit checks rather than `assert`: assertions are stripped
                    # under `python -O`, which would silently disable validation.
                    if model is not dummy_model:
                        raise ValueError(
                            "Creation of new models is not allowed. "
                            "Ensure that callback returns the same model."
                        )
                    if not model.prepared:
                        raise ValueError(
                            "Model is not prepared. Ensure that callback calls `.prepare()`."
                        )
                    if model._gene != gene:
                        raise ValueError(
                            f"Callback modified the gene from `{gene!r}` to `{model._gene!r}`."
                        )
                    if model._lineage != lineage:
                        raise ValueError(
                            f"Callback modified the lineage from `{lineage!r}` "
                            f"to `{model._lineage!r}`."
                        )
                except Exception as e:
                    # Any failure (from the callback itself or from the checks above)
                    # is reported uniformly, naming the offending gene/lineage pair.
                    raise RuntimeError(
                        f"Callback validation failed for gene `{gene!r}` and lineage `{lineage!r}`."
                    ) from e
Example #4
0
    def test_svr_invalid_weight_name_no_raise(self, adata_cflare: AnnData):
        """With ``ignore_raise=True`` an invalid weight name is stored as given."""
        bogus_name = "foobar"
        wrapper = SKLearnModel(
            adata_cflare, SVR(), weight_name=bogus_name, ignore_raise=True
        )

        assert wrapper._weight_name == bogus_name
Example #5
0
    def test_svr_correct_weight_name(self, adata_cflare: AnnData):
        """Without an explicit ``weight_name``, an SVR model uses ``"sample_weight"``."""
        expected_name = "sample_weight"

        assert SKLearnModel(adata_cflare, SVR())._weight_name == expected_name
Example #6
0
 def test_svr_invalid_weight_name(self, adata_cflare: AnnData):
     """An unknown ``weight_name`` is rejected with ``ValueError`` at construction."""
     invalid_weight = "foobar"
     with pytest.raises(ValueError):
         SKLearnModel(adata_cflare, SVR(), weight_name=invalid_weight)
Example #7
0
 def test_wrong_model_type(self, adata_cflare: AnnData):
     """Wrapping the model returned by ``create_model`` raises ``TypeError``."""
     wrong_model = create_model(adata_cflare)

     with pytest.raises(TypeError):
         SKLearnModel(adata_cflare, wrong_model)
Example #8
0
 def test_wrong_type(self):
     """Passing ``0`` instead of an ``AnnData`` object raises ``TypeError``."""
     not_adata = 0

     with pytest.raises(TypeError):
         SKLearnModel(not_adata, SVR())