Exemplo n.º 1
0
def test_split_metagene(pretrained_toy_model, toydata):
    r"""Test that metagenes are split correctly

    Splits the first metagene of the "ST" experiment, then runs the guide
    and the model (replaying the guide's trace) on a single batch. Checks
    that

    - the rim mean of the split metagene equals the rim mean of the newly
      created metagene, and
    - the guide's rate means for the original and the new metagene agree.
    """
    st_experiment = pretrained_toy_model.get_experiment("ST")
    metagene = next(iter(st_experiment.metagenes.keys()))
    metagene_new = st_experiment.split_metagene(metagene)

    with Session(
            model=pretrained_toy_model,
            dataloader=toydata,
            covariates=extract_covariates(toydata.dataset.data.design),
    ):
        x = to_device(next(iter(toydata)))
        with pyro.poutine.trace() as guide_tr:
            get("model").guide(x)
        # Replay the guide's samples in the model so both traces are
        # evaluated at the same latent values.
        with pyro.poutine.trace() as model_tr:
            with pyro.poutine.replay(trace=guide_tr.trace):
                get("model").model(x)

    rim_mean = model_tr.trace.nodes["rim"]["fn"].mean
    # NOTE(review): assumes the rim mean is laid out as (batch, metagene, ...)
    # with the split metagene first (it was taken from the front of
    # `metagenes`) and the new metagene appended last -- confirm.
    # Compare both metagene slices for every sample in the batch, using the
    # same number of indices on each side so no rank-changing broadcast occurs.
    assert (rim_mean[:, 0] == rim_mean[:, -1]).all()

    rate_mg = guide_tr.trace.nodes[_encode_metagene_name(metagene)]["fn"].mean
    rate_mg_new = guide_tr.trace.nodes[_encode_metagene_name(
        metagene_new)]["fn"].mean
    assert (rate_mg == rate_mg_new).all()
Exemplo n.º 2
0
def test_compute_differential_expression(
    pretrained_toy_model, toydata, tmp_path
):
    r"""Check that differential expression analysis writes its output table.

    Runs ``compute_differential_expression("annotation1", "annotation2")``
    inside an evaluation-mode session rooted at ``tmp_path``. Then verifies
    that ``differential_expression/data.csv.gz`` was created there.
    """
    session_kwargs = dict(
        model=pretrained_toy_model,
        genes=toydata.dataset.genes,
        dataloader=toydata,
        covariates=extract_covariates(toydata.dataset.data.design),
        save_path=tmp_path,
        eval=True,
    )
    with Session(**session_kwargs):
        compute_differential_expression("annotation1", "annotation2")

    expected_output = tmp_path / "differential_expression" / "data.csv.gz"
    assert os.path.exists(expected_output)
Exemplo n.º 3
0
def test_metagenes(pretrained_toy_model, toydata, tmp_path):
    r"""Check that a metagene summary plot is written for every section.

    Runs ``compute_metagene_summary()`` inside an evaluation-mode session
    rooted at ``tmp_path``. Then verifies that
    ``metagenes/<section>/summary.png`` exists for each section in the
    design.
    """
    with Session(
            model=pretrained_toy_model,
            genes=toydata.dataset.genes,
            dataloader=toydata,
            covariates=extract_covariates(toydata.dataset.data.design),
            save_path=tmp_path,
            eval=True,
    ):
        compute_metagene_summary()

    # The expected output directory for each section is its base name.
    for section in map(os.path.basename, toydata.dataset.data.design):
        assert os.path.exists(
            tmp_path / "metagenes" / section / "summary.png"
        )
Exemplo n.º 4
0
def test_gene_maps(pretrained_toy_model, toydata, tmp_path):
    r"""Check that a gene map image is written for every (section, gene).

    Runs ``compute_gene_maps()`` inside an evaluation-mode session rooted at
    ``tmp_path``. Then verifies that ``gene_maps/<section>/<gene>.png``
    exists for each section in the design and each gene in the dataset.
    """
    with Session(
        model=pretrained_toy_model,
        genes=toydata.dataset.genes,
        dataloader=toydata,
        covariates=extract_covariates(toydata.dataset.data.design),
        save_path=tmp_path,
        eval=True,
    ):
        compute_gene_maps()

    gene_maps_dir = tmp_path / "gene_maps"
    # NOTE(review): `section` is joined verbatim here, whereas the metagene
    # and imputation tests take os.path.basename of the section name --
    # confirm which convention compute_gene_maps uses.
    expected_files = (
        gene_maps_dir / section / f"{gene}.png"
        for section in toydata.dataset.data.design
        for gene in toydata.dataset.genes
    )
    for expected_file in expected_files:
        assert os.path.exists(expected_file)
Exemplo n.º 5
0
def pretrained_toy_model(toydata):
    r"""Pretrained toy model

    Builds an ``XFuse`` model around a single ``ST`` experiment (depth 2,
    four channels, one default metagene). Trains it on ``toydata`` with Adam
    (lr 0.001) by calling ``train`` with 100 plus the session's current
    training epoch, and returns the trained model.
    """
    # pylint: disable=redefined-outer-name
    experiment = ST(
        depth=2,
        num_channels=4,
        metagenes=[MetageneDefault(0.0, None)],
    )
    model = XFuse(experiments=[experiment])
    session = Session(
        model=model,
        optimizer=pyro.optim.Adam({"lr": 0.001}),
        dataloader=toydata,
        covariates=extract_covariates(toydata.dataset.data.design),
    )
    with session:
        train(100 + get("training_data").epoch)
    return model
Exemplo n.º 6
0
def test_toydata(mocker, toydata):
    r"""Integration test on toy dataset

    Trains a fresh three-metagene ``XFuse`` model on ``toydata`` with an
    ``RMSE`` messenger attached. The messenger's ``log_scalar`` is patched
    so the logged values can be inspected. The test checks that the RMSE
    from the last logged call is below 6.0.
    """
    st_experiment = ST(
        depth=2,
        num_channels=4,
        metagenes=[MetageneDefault(0.0, None) for _ in range(3)],
    )
    xfuse = XFuse(experiments=[st_experiment])
    rmse = RMSE()
    mock_log_scalar = mocker.patch("xfuse.messengers.stats.rmse.log_scalar")
    with Session(
        model=xfuse,
        optimizer=pyro.optim.Adam({"lr": 0.0001}),
        dataloader=toydata,
        covariates=extract_covariates(toydata.dataset.data.design),
        messengers=[rmse],
    ):
        train(100 + get("training_data").epoch)

    # Fail with a clear message rather than an IndexError if nothing was
    # logged.
    assert mock_log_scalar.called, "RMSE was never logged during training"
    # `call_args` is the most recent direct call to the patched function
    # (unlike `mock_calls`, it excludes calls on child mocks). The RMSE
    # value is its second positional argument.
    final_rmse = mock_log_scalar.call_args[0][1]
    assert final_rmse < 6.0
Exemplo n.º 7
0
def test_metagene_expansion(
    # pylint: disable=redefined-outer-name
    toydata,
    pretrained_toy_model,
    expansion_strategies,
    compute_expected_metagenes,
):
    r"""Test metagene expansion dynamics

    Pairs each strategy in ``expansion_strategies`` with the matching
    expected count from ``compute_expected_metagenes(<initial count>)``.
    For each pair, runs ``purge_metagenes(num_samples=10)`` in a session
    using that strategy, then checks the "ST" experiment's metagene count.
    The same model is reused across iterations with no reset in between.
    """
    experiment = pretrained_toy_model.get_experiment("ST")
    expected_counts = compute_expected_metagenes(len(experiment.metagenes))

    for strategy, expected_count in zip(expansion_strategies, expected_counts):
        with Session(
            covariates=extract_covariates(toydata.dataset.data.design),
            dataloader=toydata,
            metagene_expansion_strategy=strategy,
            model=pretrained_toy_model,
        ):
            purge_metagenes(num_samples=10)
        assert len(experiment.metagenes) == expected_count
Exemplo n.º 8
0
def test_compute_imputation(pretrained_toy_model, toydata, tmp_path):
    r"""Check the imputation output written for every slide.

    Runs ``compute_imputation("annotation1")`` inside an evaluation-mode
    session rooted at ``tmp_path``. For each slide, verifies that
    ``imputation-annotation1/<slide>/imputed_counts.csv.gz`` exists. Also
    verifies that the set of values in its ``label`` column equals the set
    of strictly positive labels in the slide's "annotation1" annotation.
    """
    with Session(
            model=pretrained_toy_model,
            genes=toydata.dataset.genes,
            dataloader=toydata,
            covariates=extract_covariates(toydata.dataset.data.design),
            save_path=tmp_path,
            eval=True,
    ):
        compute_imputation("annotation1")

    for key, slide in toydata.dataset.data.slides.items():
        slide_name = os.path.basename(key)
        output_file = (tmp_path / "imputation-annotation1" / slide_name /
                       "imputed_counts.csv.gz")
        assert os.path.exists(output_file)

        # np.unique returns its values already sorted, so no extra sort is
        # needed.
        output_data_labels = np.unique(pd.read_csv(output_file).label)
        annotation_labels = np.unique(slide.data.annotation("annotation1"))
        # Only strictly positive annotation labels are expected in the output.
        annotation_labels = annotation_labels[annotation_labels > 0]
        # array_equal also requires matching shapes. A plain `==` would
        # broadcast, e.g. an empty array vs. one label compares as all-True.
        assert np.array_equal(annotation_labels, output_data_labels)