def test_split_metagene(pretrained_toy_model, toydata):
    r"""Test that metagenes are split correctly

    Splits the first metagene of the pretrained model's ``"ST"`` experiment,
    then runs the guide and the guide-replayed model on one batch of
    ``toydata``. It checks that the new metagene matches the one it was split
    from in two places:

    * the mean of the model's ``"rim"`` distribution;
    * the mean of the guide's distribution for the metagene's site.
    """
    st_experiment = pretrained_toy_model.get_experiment("ST")
    metagene = next(iter(st_experiment.metagenes.keys()))
    metagene_new = st_experiment.split_metagene(metagene)

    with Session(
        model=pretrained_toy_model,
        dataloader=toydata,
        covariates=extract_covariates(toydata.dataset.data.design),
    ):
        x = to_device(next(iter(toydata)))
        with pyro.poutine.trace() as guide_tr:
            get("model").guide(x)
        # Replay the guide's samples so the model trace is conditioned on them.
        with pyro.poutine.trace() as model_tr:
            with pyro.poutine.replay(trace=guide_tr.trace):
                get("model").model(x)

    rim_mean = model_tr.trace.nodes["rim"]["fn"].mean
    # NOTE(review): assumes axis 0 of ``rim_mean`` is the sample, axis 1 the
    # metagene, and that ``split_metagene`` appends the new metagene last.
    # Compare metagene 0 with the newly added (last) metagene within the same
    # sample. Indexing ``rim_mean[-1][0, -1]`` would instead select metagene 0
    # again (of the last sample), so the split itself would go untested.
    assert (rim_mean[0, 0] == rim_mean[0, -1]).all()

    rate_mg = guide_tr.trace.nodes[_encode_metagene_name(metagene)]["fn"].mean
    rate_mg_new = guide_tr.trace.nodes[
        _encode_metagene_name(metagene_new)
    ]["fn"].mean
    assert (rate_mg == rate_mg_new).all()
def _mock_run(*_args, **_kwargs):
    r"""Check the current session state against a reference ``state_dict``.

    All positional and keyword arguments are ignored (presumably so this can
    be patched in place of a run function; confirm at the patch site).

    Inside a ``Session`` opened with ``panic=Unset()``, it asserts that:

    * the ``"training_data"`` step is greater than 1;
    * every module parameter and every top-level param returned by
      ``get_state_dict()`` compares equal, element-wise, to its counterpart
      in ``state_dict``.

    ``state_dict`` is not defined in this function; it is read from the
    enclosing (or module) scope.
    """
    with Session(panic=Unset()):
        assert get("training_data").step > 1
        current = get_state_dict()

        # Each per-module parameter must match the reference tensor exactly.
        modules_equal = all(
            (current.modules[module][name] == expected).all()
            for module, expected_params in state_dict.modules.items()
            for name, expected in expected_params.items()
        )
        assert modules_equal

        # Same check for the top-level params.
        params_equal = all(
            (current.params[name] == expected).all()
            for name, expected in state_dict.params.items()
        )
        assert params_equal
def pretrained_toy_model(toydata):
    r"""Pretrained toy model

    Builds an ``XFuse`` model around a single-metagene ``ST`` experiment
    (``depth=2``, ``num_channels=4``). It then trains the model on
    ``toydata`` with ``pyro.optim.Adam`` (``lr=0.001``), calling ``train``
    with the session's current epoch plus 100, and returns the model.
    """
    # pylint: disable=redefined-outer-name
    experiment = ST(
        depth=2,
        num_channels=4,
        metagenes=[MetageneDefault(0.0, None)],
    )
    model = XFuse(experiments=[experiment])
    optimizer = pyro.optim.Adam({"lr": 0.001})

    with Session(model=model, optimizer=optimizer, dataloader=toydata):
        # ``get`` reads from the active session, so this must stay inside it.
        target_epoch = 100 + get("training_data").epoch
        train(target_epoch)

    return model
def test_toydata(mocker, toydata):
    r"""Integration test on toy dataset

    Trains a three-metagene ``XFuse`` model on ``toydata`` with an ``RMSE``
    messenger attached, while ``xfuse.messengers.stats.rmse.log_scalar`` is
    patched. It asserts that the second positional argument of the last
    recorded call (the logged RMSE value) is below 6.0.
    """
    experiment = ST(
        depth=2,
        num_channels=4,
        metagenes=[MetageneDefault(0.0, None) for _ in range(3)],
    )
    model = XFuse(experiments=[experiment])
    rmse_messenger = RMSE()
    log_scalar_mock = mocker.patch("xfuse.messengers.stats.rmse.log_scalar")

    with Session(
        model=model,
        optimizer=pyro.optim.Adam({"lr": 0.0001}),
        dataloader=toydata,
        covariates=extract_covariates(toydata.dataset.data.design),
        messengers=[rmse_messenger],
    ):
        train(100 + get("training_data").epoch)

    # Each recorded call is a (name, args, kwargs) triple; args[1] is the value.
    logged_values = [args[1] for _, args, _ in log_scalar_mock.mock_calls]
    final_rmse = logged_values[-1]
    assert final_rmse < 6.0
def pretrained_toy_model(toydata):
    r"""Pretrained toy model

    Builds an ``XFuse`` model around a single-metagene ``ST`` experiment
    (``depth=2``, ``num_channels=4``). It then trains the model on
    ``toydata`` with ``pyro.optim.Adam`` (``lr=0.001``), calling ``train``
    with the session's current epoch plus 100, and returns the model.

    The training session receives:

    * the dataset's gene list;
    * a covariate mapping from each column of the design table to that
      column's category labels.
    """
    # pylint: disable=redefined-outer-name
    st_experiment = ST(
        depth=2,
        num_channels=4,
        metagenes=[MetageneDefault(0.0, None) for _ in range(1)],
    )
    xfuse = XFuse(experiments=[st_experiment])
    with Session(
        model=xfuse,
        optimizer=pyro.optim.Adam({"lr": 0.001}),
        dataloader=toydata,
        genes=toydata.dataset.genes,
        # ``DataFrame.items`` replaces ``iteritems``, which was removed in
        # pandas 2.0; both yield (column name, Series) pairs. Each column is
        # expected to be categorical (``.cat`` accessor).
        covariates={
            covariate: values.cat.categories.values.tolist()
            for covariate, values in toydata.dataset.data.design.items()
        },
    ):
        train(100 + get("training_data").epoch)
    return xfuse