Code example #1
    def _posterior_quantile(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = True):
        """
        Compute a quantile of the posterior distribution of each parameter, for Pyro models trained without amortised inference.

        Parameters
        ----------
        q
            Quantile to compute (default 0.5, the median).
        batch_size
            Number of observations per batch.
        use_gpu
            Whether to use the GPU.

        Returns
        -------
        dictionary {variable_name: posterior quantile}

        """

        self.module.eval()
        gpus, device = parse_use_gpu_arg(use_gpu)

        train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size)
        # sample global parameters
        tensor_dict = next(iter(train_dl))
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        means = self.module.guide.quantiles([q], *args, **kwargs)
        means = {k: means[k].cpu().detach().numpy() for k in means.keys()}

        return means
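A hypothetical usage sketch for the method above; `model` is a placeholder for a trained Pyro-based scvi-tools model exposing `_posterior_quantile`, and is not part of the snippet itself:

# Hypothetical usage sketch: `model` stands for any trained Pyro-based
# scvi-tools model with the _posterior_quantile method shown above.
medians = model._posterior_quantile(q=0.5, batch_size=2048, use_gpu=False)
q95 = model._posterior_quantile(q=0.95, batch_size=2048, use_gpu=False)
for name, value in medians.items():
    print(name, value.shape)  # one numpy array per latent variable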
Code example #2
File: _data_splitting.py  Project: vals/scVI
    def __call__(self, remake_splits=False):
        if remake_splits:
            self.train_idx, self.test_idx, self.val_idx = self.make_splits()

        gpus = parse_use_gpu_arg(self.use_gpu, return_device=False)
        pin_memory = settings.dl_pin_memory_gpu_training and gpus != 0

        # do not remove drop_last=3: it drops the final minibatch when it has fewer than 3 observations
        return (
            AnnDataLoader(
                self.adata,
                indices=self.train_idx,
                shuffle=True,
                drop_last=3,
                pin_memory=pin_memory,
                **self.data_loader_kwargs,
            ),
            AnnDataLoader(
                self.adata,
                indices=self.val_idx,
                shuffle=True,
                drop_last=3,
                pin_memory=pin_memory,
                **self.data_loader_kwargs,
            ),
            AnnDataLoader(
                self.adata,
                indices=self.test_idx,
                shuffle=True,
                drop_last=3,
                pin_memory=pin_memory,
                **self.data_loader_kwargs,
            ),
        )
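Every snippet in this section funnels device selection through `parse_use_gpu_arg`. A minimal sketch of the contract these callers assume (not the actual scvi-tools implementation): `use_gpu` is mapped to a `gpus` value for the Lightning `Trainer` and, when `return_device=True`, additionally to a `torch.device`.

from typing import Optional, Union
import torch

def parse_use_gpu_arg_sketch(
    use_gpu: Optional[Union[str, int, bool]] = None,
    return_device: bool = True,
):
    # Sketch only: maps use_gpu to (gpus-for-Trainer, torch.device).
    cuda = torch.cuda.is_available()
    if use_gpu is None or use_gpu is True:
        gpus = 1 if cuda else 0
        device = torch.device("cuda:0" if cuda else "cpu")
    elif use_gpu is False:
        gpus, device = 0, torch.device("cpu")
    elif isinstance(use_gpu, int):
        gpus, device = [use_gpu], torch.device(f"cuda:{use_gpu}")
    else:  # a device string such as "cuda:0"
        gpus, device = [int(use_gpu.split(":")[-1])], torch.device(use_gpu)
    return (gpus, device) if return_device else gpus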
Code example #3
    def setup(self, stage: Optional[str] = None):
        """Split indices in train/test/val sets."""
        n = self.adata.n_obs
        n_train, n_val = validate_data_split(n, self.train_size,
                                             self.validation_size)
        random_state = np.random.RandomState(seed=settings.seed)
        permutation = random_state.permutation(n)
        self.val_idx = permutation[:n_val]
        self.train_idx = permutation[n_val:(n_val + n_train)]
        self.test_idx = permutation[(n_val + n_train):]

        gpus, self.device = parse_use_gpu_arg(self.use_gpu, return_device=True)
        self.pin_memory = settings.dl_pin_memory_gpu_training and gpus != 0
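`validate_data_split` is not shown in these snippets; a plausible sketch of how it might turn the `train_size`/`validation_size` fractions into integer counts (hypothetical, not the library code):

from math import ceil, floor
from typing import Optional

def validate_data_split_sketch(
    n_samples: int, train_size: float, validation_size: Optional[float]
):
    # Hypothetical: ceil for train, floor for validation; the remainder
    # (permutation[(n_val + n_train):] above) becomes the test set.
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size must be in (0.0, 1.0]")
    n_train = ceil(train_size * n_samples)
    if validation_size is None:
        n_val = n_samples - n_train
    else:
        n_val = floor(validation_size * n_samples)
    return n_train, n_val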
Code example #4
File: _trainrunner.py  Project: vitkl/scvi-tools
 def __init__(
     self,
     model: BaseModelClass,
     training_plan: pl.LightningModule,
     data_splitter: Union[SemiSupervisedDataSplitter, DataSplitter],
     max_epochs: int,
     use_gpu: Optional[Union[str, int, bool]] = None,
     **trainer_kwargs,
 ):
     self.training_plan = training_plan
     self.data_splitter = data_splitter
     self.model = model
     gpus, device = parse_use_gpu_arg(use_gpu)
     self.gpus = gpus
     self.device = device
     self.trainer = Trainer(max_epochs=max_epochs, gpus=gpus, **trainer_kwargs)
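A hypothetical instantiation of the runner above; `model`, `plan`, and `splitter` are placeholders for a BaseModelClass instance, a LightningModule, and a DataSplitter. In scvi-tools the runner is then invoked as a callable to execute the fit:

# Hypothetical usage sketch with placeholder objects.
runner = TrainRunner(
    model,
    training_plan=plan,
    data_splitter=splitter,
    max_epochs=400,
    use_gpu=None,  # auto-detect GPU
)
runner()  # runs trainer.fit and hands results back to the model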
Code example #5
File: _data_splitting.py  Project: vals/scVI
    def __call__(self, remake_splits=False):
        if remake_splits:
            self.train_idx, self.test_idx, self.val_idx = self.make_splits()

        gpus = parse_use_gpu_arg(self.use_gpu, return_device=False)
        pin_memory = settings.dl_pin_memory_gpu_training and gpus != 0

        if len(self._labeled_indices) != 0:
            data_loader_class = SemiSupervisedDataLoader
            dl_kwargs = {
                "unlabeled_category": self.unlabeled_category,
                "n_samples_per_label": self.n_samples_per_label,
            }
        else:
            data_loader_class = AnnDataLoader
            dl_kwargs = {}

        dl_kwargs.update(self.data_loader_kwargs)

        scanvi_train_dl = data_loader_class(
            self.adata,
            indices=self.train_idx,
            shuffle=True,
            drop_last=3,
            pin_memory=pin_memory,
            **dl_kwargs,
        )
        scanvi_val_dl = data_loader_class(
            self.adata,
            indices=self.val_idx,
            shuffle=True,
            drop_last=3,
            pin_memory=pin_memory,
            **dl_kwargs,
        )
        scanvi_test_dl = data_loader_class(
            self.adata,
            indices=self.test_idx,
            shuffle=True,
            drop_last=3,
            pin_memory=pin_memory,
            **dl_kwargs,
        )

        return scanvi_train_dl, scanvi_val_dl, scanvi_test_dl
Code example #6
    def _posterior_quantile(self,
                            q: float = 0.5,
                            batch_size: Optional[int] = None,
                            use_gpu: Optional[bool] = None,
                            use_median: bool = False):
        """
        Compute a quantile of the posterior distribution of each parameter, for Pyro models trained without amortised inference.

        Parameters
        ----------
        q
            Quantile to compute
        batch_size
            Number of observations per batch; defaults to the full dataset.
        use_gpu
            Whether to use the GPU.
        use_median
            When q=0.5, use the guide's median method rather than its quantiles method.

        Returns
        -------
        dictionary {variable_name: posterior quantile}

        """

        self.module.eval()
        gpus, device = parse_use_gpu_arg(use_gpu)
        if batch_size is None:
            batch_size = self.adata_manager.adata.n_obs
        train_dl = AnnDataLoader(self.adata_manager,
                                 shuffle=False,
                                 batch_size=batch_size)
        # sample global parameters
        tensor_dict = next(iter(train_dl))
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        if use_median and q == 0.5:
            means = self.module.guide.median(*args, **kwargs)
        else:
            means = self.module.guide.quantiles([q], *args, **kwargs)
        means = {k: means[k].cpu().detach().numpy() for k in means.keys()}

        return means
Code example #7
    def _posterior_quantile_minibatch(self,
                                      q: float = 0.5,
                                      batch_size: int = 2048,
                                      use_gpu: Optional[bool] = None,
                                      use_median: bool = False):
        """
        Compute a quantile of the posterior distribution of each parameter, separating local (minibatch) variables
        and global variables, which is necessary when performing amortised inference.

        Note for developers: requires model class method which lists observation/minibatch plate
        variables (self.module.model.list_obs_plate_vars()).

        Parameters
        ----------
        q
            Quantile to compute
        batch_size
            Number of observations per batch
        use_gpu
            Whether to use the GPU.
        use_median
            When q=0.5, use the guide's median method rather than its quantiles method.

        Returns
        -------
        dictionary {variable_name: posterior quantile}

        """

        gpus, device = parse_use_gpu_arg(use_gpu)

        self.module.eval()

        train_dl = AnnDataLoader(self.adata_manager,
                                 shuffle=False,
                                 batch_size=batch_size)

        # sample local parameters
        i = 0
        for tensor_dict in train_dl:

            args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
            args = [a.to(device) for a in args]
            kwargs = {k: v.to(device) for k, v in kwargs.items()}
            self.to_device(device)

            if i == 0:
                # find plate sites
                obs_plate_sites = self._get_obs_plate_sites(
                    args, kwargs, return_observed=True)
                if len(obs_plate_sites) == 0:
                    # if no local variables - don't sample
                    break
                # find plate dimension
                obs_plate_dim = list(obs_plate_sites.values())[0]
                if use_median and q == 0.5:
                    means = self.module.guide.median(*args, **kwargs)
                else:
                    means = self.module.guide.quantiles([q], *args, **kwargs)
                means = {
                    k: means[k].cpu().numpy()
                    for k in means.keys() if k in obs_plate_sites
                }

            else:
                if use_median and q == 0.5:
                    means_ = self.module.guide.median(*args, **kwargs)
                else:
                    means_ = self.module.guide.quantiles([q], *args, **kwargs)

                means_ = {
                    k: means_[k].cpu().numpy()
                    for k in means_.keys() if k in obs_plate_sites
                }
                means = {
                    k: np.concatenate([means[k], means_[k]],
                                      axis=obs_plate_dim)
                    for k in means.keys()
                }
            i += 1

        # sample global parameters
        tensor_dict = next(iter(train_dl))
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        if use_median and q == 0.5:
            global_means = self.module.guide.median(*args, **kwargs)
        else:
            global_means = self.module.guide.quantiles([q], *args, **kwargs)
        global_means = {
            k: global_means[k].cpu().numpy()
            for k in global_means.keys() if k not in obs_plate_sites
        }

        for k in global_means.keys():
            means[k] = global_means[k]

        self.module.to(device)

        return means
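The stitching step above concatenates each minibatch result along the observation plate dimension, which Pyro reports as a negative axis. A toy illustration with assumed shapes:

import numpy as np

# Toy shapes: two minibatches of 5 and 4 observations, 3 features each.
batch1 = np.zeros((5, 3))
batch2 = np.ones((4, 3))
obs_plate_dim = -2  # Pyro plate dims count from the right
full = np.concatenate([batch1, batch2], axis=obs_plate_dim)
assert full.shape == (9, 3)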
Code example #8
File: _base_model.py  Project: yf817/scvi-tools
    def load(
        cls,
        dir_path: str,
        adata: Optional[AnnData] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        adata
            AnnData organized in the same way as data used to train model.
            It is not necessary to run :func:`~scvi.data.setup_anndata`,
            as AnnData is validated against the saved `scvi` setup dictionary.
            If None, will check for and load anndata saved with the model.
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> vae = SCVI.load(save_path, adata)
        >>> vae.get_latent_representation()
        """
        load_adata = adata is None
        use_gpu, device = parse_use_gpu_arg(use_gpu)

        (
            scvi_setup_dict,
            attr_dict,
            var_names,
            model_state_dict,
            new_adata,
        ) = _load_saved_files(dir_path, load_adata, map_location=device)
        adata = new_adata if new_adata is not None else adata

        _validate_var_names(adata, var_names)
        transfer_anndata_setup(scvi_setup_dict, adata)
        model = _initialize_model(cls, adata, attr_dict)

        # set saved attrs for loaded model
        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        # some Pyro modules with AutoGuides may need one training step
        try:
            model.module.load_state_dict(model_state_dict)
        except RuntimeError as err:
            if isinstance(model.module, PyroBaseModuleClass):
                logger.info("Preparing underlying module for load")
                model.train(max_steps=1)
                pyro.clear_param_store()
                model.module.load_state_dict(model_state_dict)
            else:
                raise err

        model.to_device(device)
        model.module.eval()
        model._validate_anndata(adata)
        return model
Code example #9
File: _archesmixin.py  Project: vitkl/scvi-tools
    def load_query_data(
        cls,
        adata: AnnData,
        reference_model: Union[str, BaseModelClass],
        inplace_subset_query_vars: bool = False,
        use_gpu: Optional[Union[str, int, bool]] = None,
        unfrozen: bool = False,
        freeze_dropout: bool = False,
        freeze_expression: bool = True,
        freeze_decoder_first_layer: bool = True,
        freeze_batchnorm_encoder: bool = True,
        freeze_batchnorm_decoder: bool = False,
        freeze_classifier: bool = True,
    ):
        """
        Online update of a reference model with scArches algorithm [Lotfollahi21]_.

        Parameters
        ----------
        adata
            AnnData organized in the same way as data used to train model.
            It is not necessary to run setup_anndata,
            as AnnData is validated against the saved `scvi` setup dictionary.
        reference_model
            Either an already instantiated model of the same class, or a path to
            saved outputs for reference model.
        inplace_subset_query_vars
            Whether to subset and rearrange query vars inplace based on vars used to
            train reference model.
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).
        unfrozen
            Override all other freeze options for a fully unfrozen model
        freeze_dropout
            Whether to freeze dropout during training
        freeze_expression
            Freeze neurons corresponding to expression in first layer
        freeze_decoder_first_layer
            Freeze neurons corresponding to first layer in decoder
        freeze_batchnorm_encoder
            Whether to freeze batchnorm weight and bias during training for encoder
        freeze_batchnorm_decoder
            Whether to freeze batchnorm weight and bias during training for decoder
        freeze_classifier
            Whether to freeze classifier completely. Only applies to `SCANVI`.
        """
        use_gpu, device = parse_use_gpu_arg(use_gpu)
        if isinstance(reference_model, str):
            (
                attr_dict,
                var_names,
                load_state_dict,
                _,
            ) = _load_saved_files(reference_model,
                                  load_adata=False,
                                  map_location=device)
        else:
            attr_dict = reference_model._get_user_attributes()
            attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"}
            var_names = reference_model.adata.var_names
            load_state_dict = deepcopy(reference_model.module.state_dict())

        scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")

        if inplace_subset_query_vars:
            logger.debug("Subsetting query vars to reference vars.")
            adata._inplace_subset_var(var_names)
        _validate_var_names(adata, var_names)

        version_split = scvi_setup_dict["scvi_version"].split(".")
        if int(version_split[0]) == 0 and int(version_split[1]) < 8:
            warnings.warn(
                "Query integration should be performed using models trained with version >= 0.8"
            )

        transfer_anndata_setup(scvi_setup_dict, adata, extend_categories=True)

        model = _initialize_model(cls, adata, attr_dict)

        # set saved attrs for loaded model
        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        model.to_device(device)

        # model tweaking
        new_state_dict = model.module.state_dict()
        for key, load_ten in load_state_dict.items():
            new_ten = new_state_dict[key]
            if new_ten.size() == load_ten.size():
                continue
            # new categoricals changed size
            else:
                dim_diff = new_ten.size()[-1] - load_ten.size()[-1]
                fixed_ten = torch.cat([load_ten, new_ten[..., -dim_diff:]],
                                      dim=-1)
                load_state_dict[key] = fixed_ten

        model.module.load_state_dict(load_state_dict)
        model.module.eval()

        _set_params_online_update(
            model.module,
            unfrozen=unfrozen,
            freeze_decoder_first_layer=freeze_decoder_first_layer,
            freeze_batchnorm_encoder=freeze_batchnorm_encoder,
            freeze_batchnorm_decoder=freeze_batchnorm_decoder,
            freeze_dropout=freeze_dropout,
            freeze_expression=freeze_expression,
            freeze_classifier=freeze_classifier,
        )
        model.is_trained_ = False

        return model
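The "model tweaking" loop above pads each reference tensor whose categorical dimension grew in the query model, keeping the trained reference weights and using freshly initialised columns only for the new categories. A toy illustration with assumed shapes:

import torch

load_ten = torch.zeros(4, 3)  # reference weights: 3 batch categories
new_ten = torch.randn(4, 5)   # query model: 2 extra categories
dim_diff = new_ten.size(-1) - load_ten.size(-1)
# Keep the reference columns, append the newly initialised ones.
fixed_ten = torch.cat([load_ten, new_ten[..., -dim_diff:]], dim=-1)
assert fixed_ten.shape == (4, 5)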
Code example #10
    def load(
        cls,
        dir_path: str,
        prefix: Optional[str] = None,
        adata_seq: Optional[AnnData] = None,
        adata_spatial: Optional[AnnData] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        prefix
            Prefix of saved file names.
        adata_seq
            AnnData organized in the same way as data used to train model.
            It is not necessary to run :meth:`~scvi.external.GIMVI.setup_anndata`,
            as AnnData is validated against the saved `scvi` setup dictionary.
            If None, will check for and load anndata saved with the model.
        adata_spatial
            AnnData organized in the same way as data used to train model.
            If None, will check for and load anndata saved with the model.
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> vae = GIMVI.load(save_path, adata_seq=adata_seq, adata_spatial=adata_spatial)
        >>> vae.get_latent_representation()
        """
        _, device = parse_use_gpu_arg(use_gpu)

        (
            attr_dict,
            seq_var_names,
            spatial_var_names,
            model_state_dict,
            loaded_adata_seq,
            loaded_adata_spatial,
        ) = _load_saved_gimvi_files(
            dir_path,
            adata_seq is None,
            adata_spatial is None,
            prefix=prefix,
            map_location=device,
        )
        adata_seq = loaded_adata_seq or adata_seq
        adata_spatial = loaded_adata_spatial or adata_spatial
        adatas = [adata_seq, adata_spatial]
        var_names = [seq_var_names, spatial_var_names]

        for i, adata in enumerate(adatas):
            saved_var_names = var_names[i]
            user_var_names = adata.var_names.astype(str)
            if not np.array_equal(saved_var_names, user_var_names):
                warnings.warn(
                    "var_names for adata passed in does not match var_names of "
                    "adata used to train the model. For valid results, the vars "
                    "need to be the same and in the same order as the adata used to train the model."
                )

        if "scvi_setup_dicts_" in attr_dict:
            scvi_setup_dicts = attr_dict.pop("scvi_setup_dicts_")
            for adata, scvi_setup_dict in zip(adatas, scvi_setup_dicts):
                cls.register_manager(
                    manager_from_setup_dict(cls, adata, scvi_setup_dict)
                )
        else:
            registries = attr_dict.pop("registries_")
            for adata, registry in zip(adatas, registries):
                if (
                    _MODEL_NAME_KEY in registry
                    and registry[_MODEL_NAME_KEY] != cls.__name__
                ):
                    raise ValueError(
                        "It appears you are loading a model from a different class."
                    )

                if _SETUP_KWARGS_KEY not in registry:
                    raise ValueError(
                        "Saved model does not contain original setup inputs. "
                        "Cannot load the original setup."
                    )

                cls.setup_anndata(
                    adata, source_registry=registry, **registry[_SETUP_KWARGS_KEY]
                )

        # get the parameters for the class init signature
        init_params = attr_dict.pop("init_params_")

        # new saving and loading, enable backwards compatibility
        if "non_kwargs" in init_params.keys():
            # grab all the parameters except for kwargs (which is a dict)
            non_kwargs = init_params["non_kwargs"]
            kwargs = init_params["kwargs"]

            # expand out kwargs
            kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
        else:
            # grab all the parameters except for kwargs (which is a dict)
            non_kwargs = {
                k: v for k, v in init_params.items() if not isinstance(v, dict)
            }
            kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
            kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
        model = cls(adata_seq, adata_spatial, **non_kwargs, **kwargs)

        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        model.module.load_state_dict(model_state_dict)
        model.module.eval()
        model.to_device(device)
        return model
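The `init_params_` handling above stores grouped kwargs as a dict of dicts and flattens them one level before splatting into `cls(...)`. A toy illustration (the group and parameter names are made up):

# Hypothetical grouped kwargs as stored in init_params_["kwargs"].
grouped = {"model_kwargs": {"n_latent": 10}, "module_kwargs": {"dropout_rate": 0.1}}
flat = {k: v for (_, group) in grouped.items() for (k, v) in group.items()}
assert flat == {"n_latent": 10, "dropout_rate": 0.1}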
Code example #11
File: _model.py  Project: saketkc/scVI
    def load(
        cls,
        dir_path: str,
        adata_seq: Optional[AnnData] = None,
        adata_spatial: Optional[AnnData] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        adata_seq
            AnnData organized in the same way as data used to train model.
            It is not necessary to run :func:`~scvi.data.setup_anndata`,
            as AnnData is validated against the saved `scvi` setup dictionary.
            If None, will check for and load anndata saved with the model.
        adata_spatial
            AnnData organized in the same way as data used to train model.
            If None, will check for and load anndata saved with the model.
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> vae = GIMVI.load(save_path, adata_seq, adata_spatial)
        >>> vae.get_latent_representation()
        """
        model_path = os.path.join(dir_path, "model_params.pt")
        setup_dict_path = os.path.join(dir_path, "attr.pkl")
        seq_data_path = os.path.join(dir_path, "adata_seq.h5ad")
        spatial_data_path = os.path.join(dir_path, "adata_spatial.h5ad")
        seq_var_names_path = os.path.join(dir_path, "var_names_seq.csv")
        spatial_var_names_path = os.path.join(dir_path, "var_names_spatial.csv")

        if adata_seq is None and os.path.exists(seq_data_path):
            adata_seq = read(seq_data_path)
        elif adata_seq is None and not os.path.exists(seq_data_path):
            raise ValueError(
                "Save path contains no saved anndata and no adata was passed."
            )
        if adata_spatial is None and os.path.exists(spatial_data_path):
            adata_spatial = read(spatial_data_path)
        elif adata_spatial is None and not os.path.exists(spatial_data_path):
            raise ValueError(
                "Save path contains no saved anndata and no adata was passed."
            )
        adatas = [adata_seq, adata_spatial]

        seq_var_names = np.genfromtxt(seq_var_names_path, delimiter=",", dtype=str)
        spatial_var_names = np.genfromtxt(
            spatial_var_names_path, delimiter=",", dtype=str
        )
        var_names = [seq_var_names, spatial_var_names]

        for i, adata in enumerate(adatas):
            saved_var_names = var_names[i]
            user_var_names = adata.var_names.astype(str)
            if not np.array_equal(saved_var_names, user_var_names):
                warnings.warn(
                    "var_names for adata passed in does not match var_names of "
                    "adata used to train the model. For valid results, the vars "
                    "need to be the same and in the same order as the adata used to train the model."
                )

        with open(setup_dict_path, "rb") as handle:
            attr_dict = pickle.load(handle)

        scvi_setup_dicts = attr_dict.pop("scvi_setup_dicts_")
        transfer_anndata_setup(scvi_setup_dicts["seq"], adata_seq)
        transfer_anndata_setup(scvi_setup_dicts["spatial"], adata_spatial)

        # get the parameters for the class init signature
        init_params = attr_dict.pop("init_params_")

        # new saving and loading, enable backwards compatibility
        if "non_kwargs" in init_params.keys():
            # grab all the parameters except for kwargs (which is a dict)
            non_kwargs = init_params["non_kwargs"]
            kwargs = init_params["kwargs"]

            # expand out kwargs
            kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
        else:
            # grab all the parameters except for kwargs (which is a dict)
            non_kwargs = {
                k: v for k, v in init_params.items() if not isinstance(v, dict)
            }
            kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
            kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
        model = cls(adata_seq, adata_spatial, **non_kwargs, **kwargs)

        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        _, device = parse_use_gpu_arg(use_gpu)
        model.module.load_state_dict(torch.load(model_path, map_location=device))
        model.module.eval()
        model.to_device(device)
        return model
Code example #12
File: _pyromixin.py  Project: vitkl/scvi-tools
    def _posterior_samples_minibatch(
        self, use_gpu: bool = None, batch_size: Optional[int] = None, **sample_kwargs
    ):
        """
        Generate samples of the posterior distribution in minibatches.

        Generate samples of the posterior distribution of each parameter, separating local (minibatch) variables
        and global variables, which is necessary when performing minibatch inference.

        Parameters
        ----------
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        dictionary {variable_name: array with samples stacked along dimension 0}
        """
        samples = dict()

        _, device = parse_use_gpu_arg(use_gpu)

        batch_size = batch_size if batch_size is not None else settings.batch_size

        train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size)
        # sample local parameters
        i = 0
        for tensor_dict in track(
            train_dl,
            style="tqdm",
            description="Sampling local variables, batch: ",
        ):
            args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
            args = [a.to(device) for a in args]
            kwargs = {k: v.to(device) for k, v in kwargs.items()}
            self.to_device(device)

            if i == 0:
                return_observed = sample_kwargs.get("return_observed", False)
                obs_plate_sites = self._get_obs_plate_sites(
                    args, kwargs, return_observed=return_observed
                )
                if len(obs_plate_sites) == 0:
                    # if no local variables - don't sample
                    break
                obs_plate_dim = list(obs_plate_sites.values())[0]

                sample_kwargs_obs_plate = sample_kwargs.copy()
                sample_kwargs_obs_plate[
                    "return_sites"
                ] = self._get_obs_plate_return_sites(
                    sample_kwargs["return_sites"], list(obs_plate_sites.keys())
                )
                sample_kwargs_obs_plate["show_progress"] = False

                samples = self._get_posterior_samples(
                    args, kwargs, **sample_kwargs_obs_plate
                )
            else:
                samples_ = self._get_posterior_samples(
                    args, kwargs, **sample_kwargs_obs_plate
                )

                samples = {
                    k: np.array(
                        [
                            np.concatenate(
                                [samples[k][j], samples_[k][j]],
                                axis=obs_plate_dim,
                            )
                            for j in range(
                                len(samples[k])
                            )  # for each sample (in dimension 0)
                        ]
                    )
                    for k in samples.keys()  # for each variable
                }
            i += 1

        # sample global parameters
        global_samples = self._get_posterior_samples(args, kwargs, **sample_kwargs)
        global_samples = {
            k: v
            for k, v in global_samples.items()
            if k not in list(obs_plate_sites.keys())
        }

        for k in global_samples.keys():
            samples[k] = global_samples[k]

        self.module.to(device)

        return samples
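A hypothetical call for the method above; the `model` object and the extra `**sample_kwargs` such as `num_samples` are assumptions about the downstream `_get_posterior_samples` helper, which is not shown:

# Hypothetical usage sketch; `model` is a trained Pyro-based scvi-tools model.
samples = model._posterior_samples_minibatch(
    use_gpu=False,
    batch_size=512,
    num_samples=100,         # assumed to be forwarded to _get_posterior_samples
    return_sites=["alpha"],  # required by the return_sites handling above
)
print({k: v.shape for k, v in samples.items()})  # samples stacked along dim 0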
Code example #13
    def load(
        cls,
        dir_path: str,
        prefix: Optional[str] = None,
        adata: Optional[AnnData] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        prefix
            Prefix of saved file names.
        adata
            AnnData organized in the same way as data used to train model.
            It is not necessary to run setup_anndata,
            as AnnData is validated against the saved `scvi` setup dictionary.
            If None, will check for and load anndata saved with the model.
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> model = ModelClass.load(save_path, adata) # use the name of the model class used to save
        >>> model.get_....
        """
        load_adata = adata is None
        use_gpu, device = parse_use_gpu_arg(use_gpu)

        (
            attr_dict,
            var_names,
            model_state_dict,
            new_adata,
        ) = _load_saved_files(dir_path,
                              load_adata,
                              map_location=device,
                              prefix=prefix)
        adata = new_adata if new_adata is not None else adata
        _validate_var_names(adata, var_names)

        # Legacy support for old setup dict format.
        if "scvi_setup_dict_" in attr_dict:
            scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")
            cls.register_manager(
                manager_from_setup_dict(cls, adata, scvi_setup_dict))
        else:
            registry = attr_dict.pop("registry_")
            if (_MODEL_NAME_KEY in registry
                    and registry[_MODEL_NAME_KEY] != cls.__name__):
                raise ValueError(
                    "It appears you are loading a model from a different class."
                )

            if _SETUP_KWARGS_KEY not in registry:
                raise ValueError(
                    "Saved model does not contain original setup inputs. "
                    "Cannot load the original setup.")

            # Calling ``setup_anndata`` method with the original arguments passed into
            # the saved model. This enables simple backwards compatibility in the case of
            # newly introduced fields or parameters.
            cls.setup_anndata(adata,
                              source_registry=registry,
                              **registry[_SETUP_KWARGS_KEY])

        model = _initialize_model(cls, adata, attr_dict)

        # some Pyro modules with AutoGuides may need one training step
        try:
            model.module.load_state_dict(model_state_dict)
        except RuntimeError as err:
            if isinstance(model.module, PyroBaseModuleClass):
                old_history = model.history_.copy()
                logger.info("Preparing underlying module for load")
                model.train(max_steps=1)
                model.history_ = old_history
                pyro.clear_param_store()
                model.module.load_state_dict(model_state_dict)
            else:
                raise err

        model.to_device(device)
        model.module.eval()
        model._validate_anndata(adata)
        return model
Code example #14
    def setup(self, stage: Optional[str] = None):
        """Split indices in train/test/val sets."""
        n_labeled_idx = len(self._labeled_indices)
        n_unlabeled_idx = len(self._unlabeled_indices)

        if n_labeled_idx != 0:
            n_labeled_train, n_labeled_val = validate_data_split(
                n_labeled_idx, self.train_size, self.validation_size)
            rs = np.random.RandomState(seed=settings.seed)
            labeled_permutation = rs.choice(self._labeled_indices,
                                            len(self._labeled_indices),
                                            replace=False)
            labeled_idx_val = labeled_permutation[:n_labeled_val]
            labeled_idx_train = labeled_permutation[n_labeled_val:(
                n_labeled_val + n_labeled_train)]
            labeled_idx_test = labeled_permutation[(n_labeled_val +
                                                    n_labeled_train):]
        else:
            labeled_idx_test = []
            labeled_idx_train = []
            labeled_idx_val = []

        if n_unlabeled_idx != 0:
            n_unlabeled_train, n_unlabeled_val = validate_data_split(
                n_unlabeled_idx, self.train_size, self.validation_size)
            rs = np.random.RandomState(seed=settings.seed)
            unlabeled_permutation = rs.choice(self._unlabeled_indices,
                                              len(self._unlabeled_indices),
                                              replace=False)
            unlabeled_idx_val = unlabeled_permutation[:n_unlabeled_val]
            unlabeled_idx_train = unlabeled_permutation[n_unlabeled_val:(
                n_unlabeled_val + n_unlabeled_train)]
            unlabeled_idx_test = unlabeled_permutation[(n_unlabeled_val +
                                                        n_unlabeled_train):]
        else:
            unlabeled_idx_train = []
            unlabeled_idx_val = []
            unlabeled_idx_test = []

        indices_train = np.concatenate(
            (labeled_idx_train, unlabeled_idx_train))
        indices_val = np.concatenate((labeled_idx_val, unlabeled_idx_val))
        indices_test = np.concatenate((labeled_idx_test, unlabeled_idx_test))

        self.train_idx = indices_train.astype(int)
        self.val_idx = indices_val.astype(int)
        self.test_idx = indices_test.astype(int)

        gpus = parse_use_gpu_arg(self.use_gpu, return_device=False)
        self.pin_memory = settings.dl_pin_memory_gpu_training and gpus != 0

        if len(self._labeled_indices) != 0:
            self.data_loader_class = SemiSupervisedDataLoader
            dl_kwargs = {
                "unlabeled_category": self.unlabeled_category,
                "n_samples_per_label": self.n_samples_per_label,
            }
        else:
            self.data_loader_class = AnnDataLoader
            dl_kwargs = {}

        self.data_loader_kwargs.update(dl_kwargs)
Code example #15
    def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = True):
        """
        Compute a quantile of the posterior distribution of each parameter, separating local (minibatch) variables
        and global variables, which is necessary when performing amortised inference.

        Note for developers: requires model class method which lists observation/minibatch plate
        variables (self.module.model.list_obs_plate_vars()).

        Parameters
        ----------
        q
            Quantile to compute
        batch_size
            Number of observations per batch
        use_gpu
            Whether to use the GPU.

        Returns
        -------
        dictionary {variable_name: posterior quantile}

        """

        gpus, device = parse_use_gpu_arg(use_gpu)

        self.module.eval()

        train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size)

        # sample local parameters
        i = 0
        for tensor_dict in train_dl:

            args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
            args = [a.to(device) for a in args]
            kwargs = {k: v.to(device) for k, v in kwargs.items()}
            self.to_device(device)

            if i == 0:

                means = self.module.guide.quantiles([q], *args, **kwargs)
                means = {
                    k: means[k].cpu().numpy()
                    for k in means.keys()
                    if k in self.module.model.list_obs_plate_vars()["sites"]
                }

                # find plate dimension
                trace = poutine.trace(self.module.model).get_trace(*args, **kwargs)
                obs_plate = {
                    name: site["cond_indep_stack"][0].dim
                    for name, site in trace.nodes.items()
                    if site["type"] == "sample"
                    if any(f.name == self.module.model.list_obs_plate_vars()["name"] for f in site["cond_indep_stack"])
                }

            else:

                means_ = self.module.guide.quantiles([q], *args, **kwargs)
                means_ = {
                    k: means_[k].cpu().numpy()
                    for k in means_.keys()
                    if k in list(self.module.model.list_obs_plate_vars()["sites"].keys())
                }
                means = {
                    k: np.concatenate([means[k], means_[k]], axis=list(obs_plate.values())[0]) for k in means.keys()
                }
            i += 1

        # sample global parameters
        tensor_dict = next(iter(train_dl))
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        global_means = self.module.guide.quantiles([q], *args, **kwargs)
        global_means = {
            k: global_means[k].cpu().numpy()
            for k in global_means.keys()
            if k not in list(self.module.model.list_obs_plate_vars()["sites"].keys())
        }

        for k in global_means.keys():
            means[k] = global_means[k]

        self.module.to(device)

        return means
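The plate-dimension discovery above can be exercised on a toy model; this sketch mirrors the `poutine.trace` comprehension with a made-up plate name and variable:

import pyro
import pyro.distributions as dist
import torch
from pyro import poutine

def toy_model():
    # One observation plate of 10 cells on dim -2 (made-up example).
    with pyro.plate("obs_plate", 10, dim=-2):
        pyro.sample("x", dist.Normal(torch.zeros(10, 1), 1.0))

trace = poutine.trace(toy_model).get_trace()
obs_plate = {
    name: site["cond_indep_stack"][0].dim
    for name, site in trace.nodes.items()
    if site["type"] == "sample"
    and any(f.name == "obs_plate" for f in site["cond_indep_stack"])
}
print(obs_plate)  # {'x': -2}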
Code example #16
    def train(
        self,
        max_epochs: int = 200,
        use_gpu: Optional[Union[str, int, bool]] = None,
        kappa: int = 5,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 128,
        plan_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset.
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
        kappa
            Scaling parameter for the discriminator loss.
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the validation set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training.
        plan_kwargs
            Keyword args for model-specific PyTorch Lightning task. Keyword arguments passed to
            `train()` will overwrite values present in `plan_kwargs`, when appropriate.
        **kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.
        """
        gpus, device = parse_use_gpu_arg(use_gpu)

        self.trainer = Trainer(
            max_epochs=max_epochs,
            gpus=gpus,
            **kwargs,
        )
        self.train_indices_, self.test_indices_, self.validation_indices_ = [], [], []
        train_dls, test_dls, val_dls = [], [], []
        for i, adm in enumerate(self.adata_managers.values()):
            ds = DataSplitter(
                adm,
                train_size=train_size,
                validation_size=validation_size,
                batch_size=batch_size,
                use_gpu=use_gpu,
            )
            ds.setup()
            train_dls.append(ds.train_dataloader())
            test_dls.append(ds.test_dataloader())
            val = ds.val_dataloader()
            val_dls.append(val)
            val.mode = i
            self.train_indices_.append(ds.train_idx)
            self.test_indices_.append(ds.test_idx)
            self.validation_indices_.append(ds.val_idx)
        train_dl = TrainDL(train_dls)

        plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
        self._training_plan = GIMVITrainingPlan(
            self.module,
            adversarial_classifier=True,
            scale_adversarial_loss=kappa,
            **plan_kwargs,
        )

        if train_size == 1.0:
            # circumvent the empty data loader problem if all dataset used for training
            self.trainer.fit(self._training_plan, train_dl)
        else:
            # accepts list of val dataloaders
            self.trainer.fit(self._training_plan, train_dl, val_dls)
        try:
            self.history_ = self.trainer.logger.history
        except AttributeError:
            self.history_ = None
        self.module.eval()

        self.to_device(device)
        self.is_trained_ = True
Code example #17
    def load(
        cls,
        dir_path: str,
        prefix: Optional[str] = None,
        adata: Optional[AnnData] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        prefix
            Prefix of saved file names.
        adata
            AnnData organized in the same way as data used to train model.
            It is not necessary to run setup_anndata,
            as AnnData is validated against the saved `scvi` setup dictionary.
            If None, will check for and load anndata saved with the model.
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> model = ModelClass.load(save_path, adata) # use the name of the model class used to save
        >>> model.get_....
        """
        load_adata = adata is None
        use_gpu, device = parse_use_gpu_arg(use_gpu)

        (
            attr_dict,
            var_names,
            model_state_dict,
            new_adata,
        ) = _load_saved_files(dir_path, load_adata, map_location=device, prefix=prefix)
        adata = new_adata if new_adata is not None else adata

        scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")

        # Filter out keys that are no longer populated by setup_anndata.
        # TODO(jhong): remove hack with setup_anndata refactor.
        deprecated_keys = {"local_l_mean", "local_l_var"}
        scvi_setup_dict["data_registry"] = {
            k: v
            for k, v in scvi_setup_dict["data_registry"].items()
            if k not in deprecated_keys
        }

        _validate_var_names(adata, var_names)
        transfer_anndata_setup(scvi_setup_dict, adata)
        model = _initialize_model(cls, adata, attr_dict)

        # set saved attrs for loaded model
        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        # some Pyro modules with AutoGuides may need one training step
        try:
            model.module.load_state_dict(model_state_dict)
        except RuntimeError as err:
            if isinstance(model.module, PyroBaseModuleClass):
                old_history = model.history_.copy()
                logger.info("Preparing underlying module for load")
                model.train(max_steps=1)
                model.history_ = old_history
                pyro.clear_param_store()
                model.module.load_state_dict(model_state_dict)
            else:
                raise err

        model.to_device(device)
        model.module.eval()
        model._validate_anndata(adata)
        return model