Example #1
from scvi.data import synthetic_iid
from scvi.dataloaders import DeviceBackedDataSplitter
from scvi.model import SCVI
from scvi.train import TrainingPlan, TrainRunner


def test_device_backed_data_splitter():
    a = synthetic_iid()
    # test that leaving validation_size empty works
    ds = DeviceBackedDataSplitter(a, train_size=1.0, use_gpu=None)
    ds.setup()
    train_dl = ds.train_dataloader()
    # requesting a validation loader should still work with an empty split
    ds.val_dataloader()
    assert len(next(iter(train_dl))["X"]) == a.shape[0]

    model = SCVI(a, n_latent=5)
    training_plan = TrainingPlan(model.module, len(ds.train_idx))
    runner = TrainRunner(
        model,
        training_plan=training_plan,
        data_splitter=ds,
        max_epochs=1,
        use_gpu=None,
    )
    runner()
Example #2
import numpy as np

from scvi.data import synthetic_iid
from scvi.dataloaders import DeviceBackedDataSplitter
from scvi.model import SCVI
from scvi.train import TrainingPlan, TrainRunner


def test_device_backed_data_splitter():
    a = synthetic_iid()
    SCVI.setup_anndata(a, batch_key="batch", labels_key="labels")
    model = SCVI(a, n_latent=5)
    adata_manager = model.adata_manager
    # test that leaving validation_size empty works
    ds = DeviceBackedDataSplitter(adata_manager, train_size=1.0, use_gpu=None)
    ds.setup()
    train_dl = ds.train_dataloader()
    ds.val_dataloader()
    loaded_x = next(iter(train_dl))["X"]
    assert len(loaded_x) == a.shape[0]
    np.testing.assert_array_equal(loaded_x.cpu().numpy(), a.X)

    training_plan = TrainingPlan(model.module, len(ds.train_idx))
    runner = TrainRunner(
        model,
        training_plan=training_plan,
        data_splitter=ds,
        max_epochs=1,
        use_gpu=None,
    )
    runner()
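
Both tests drive the splitter through SCVI's training stack. A minimal standalone sketch, assuming a scvi-tools version where DeviceBackedDataSplitter accepts an AnnDataManager (as in Example #2), shows the splitter yielding the full dataset as a single device-resident batch:

from scvi.data import synthetic_iid
from scvi.dataloaders import DeviceBackedDataSplitter
from scvi.model import SCVI

adata = synthetic_iid()
SCVI.setup_anndata(adata, batch_key="batch", labels_key="labels")
model = SCVI(adata, n_latent=5)

# train_size=1.0 puts every cell in the training split; with no batch_size,
# the train loader yields the whole dataset as one device-backed batch.
splitter = DeviceBackedDataSplitter(model.adata_manager, train_size=1.0, use_gpu=None)
splitter.setup()
batch = next(iter(splitter.train_dataloader()))
print(batch["X"].shape)  # (n_cells, n_genes)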
Example #3
File: _pyromixin.py  Project: saketkc/scVI
    def train(
        self,
        max_epochs: Optional[int] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 128,
        early_stopping: bool = False,
        lr: Optional[float] = None,
        plan_kwargs: Optional[dict] = None,
        **trainer_kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset. If `None`, defaults to
            `np.min([round((20000 / n_cells) * 1000), 1000])`.
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the validation set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training. If `None`, no minibatching occurs and all
            data is copied to device (e.g., GPU).
        early_stopping
            Perform early stopping. Additional arguments can be passed in `**kwargs`.
            See :class:`~scvi.train.Trainer` for further options.
        lr
            Learning rate for the optimiser (default: :class:`~pyro.optim.ClippedAdam`).
            Specifying an optimiser via `plan_kwargs` overrides this `lr`.
        plan_kwargs
            Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
            `train()` will overwrite values present in `plan_kwargs`, when appropriate.
        **trainer_kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.
        """
        if max_epochs is None:
            n_obs = self.adata.n_obs
            max_epochs = np.min([round((20000 / n_obs) * 1000), 1000])

        plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
        if lr is not None and "optim" not in plan_kwargs.keys():
            plan_kwargs.update({"optim_kwargs": {"lr": lr}})

        if batch_size is None:
            # no minibatching: use the splitter that moves all data to the device once
            data_splitter = DeviceBackedDataSplitter(
                self.adata,
                train_size=train_size,
                validation_size=validation_size,
                batch_size=batch_size,
                use_gpu=use_gpu,
            )
        else:
            data_splitter = DataSplitter(
                self.adata,
                train_size=train_size,
                validation_size=validation_size,
                batch_size=batch_size,
                use_gpu=use_gpu,
            )
        training_plan = PyroTrainingPlan(pyro_module=self.module, **plan_kwargs)

        trainer_kwargs.setdefault("early_stopping", early_stopping)

        data_splitter.setup()
        if "callbacks" not in trainer_kwargs.keys():
            trainer_kwargs["callbacks"] = []
        # warm up the guide on one training batch so Pyro parameters are
        # created before training begins
        trainer_kwargs["callbacks"].append(
            PyroJitGuideWarmup(data_splitter.train_dataloader())
        )

        runner = TrainRunner(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            **trainer_kwargs,
        )
        return runner()
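
A hedged usage sketch of this train() method; `MyPyroModel` is a hypothetical placeholder for any Pyro-based model class that inherits this mixin:

from scvi.data import synthetic_iid

adata = synthetic_iid()

# `MyPyroModel` is a placeholder, not a real scvi-tools class; substitute a
# Pyro-based model that uses this mixin and set up `adata` accordingly.
model = MyPyroModel(adata)

# batch_size=None takes the DeviceBackedDataSplitter branch above: the data is
# copied to the device once and each epoch is one full-batch SVI step.
model.train(max_epochs=500, batch_size=None, lr=0.01)

# Any concrete batch_size takes the minibatch DataSplitter branch instead.
model.train(max_epochs=500, batch_size=128, train_size=0.9, early_stopping=True)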