Example #1
    def init_train_tqdm(self, trainer):
        """Override this to customize the tqdm bar for training."""
        # A None sequence yields a manually advanced bar (tqdm-style
        # semantics); the trainer updates it as batches complete.
        bar = track(
            None,
            total=trainer.max_epochs,
            description="Training",
            style=settings.progress_bar_style,
            initial=self.train_batch_idx,
            disable=self.is_disabled,
        )
        return bar
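
All five examples call a shared `track` progress helper that is not shown here. A minimal stand-in, assuming it simply forwards to tqdm and accepts the keyword arguments used in these examples (this signature is an assumption, not the helper's actual API):

from tqdm import tqdm

def track(sequence, total=None, description="", style="tqdm",
          initial=0, disable=False):
    # Assumed shim: forward everything to tqdm. `style` is accepted for
    # call compatibility and ignored here. A None sequence gives a
    # manually advanced bar, matching the usage in Example #1.
    return tqdm(sequence, desc=description, total=total,
                initial=initial, disable=disable)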
Example #2
    def train(self,
              n_epochs=400,
              lr=1e-3,
              eps=0.01,
              params=None,
              **extras_kwargs):
        begin = time.time()
        self.model.train()

        if params is None:
            params = filter(lambda p: p.requires_grad, self.model.parameters())

        self.optimizer = torch.optim.Adam(params,
                                          lr=lr,
                                          eps=eps,
                                          weight_decay=self.weight_decay)

        # Initialize optimizers for any auxiliary model components.
        self.training_extras_init(**extras_kwargs)

        self.compute_metrics_time = 0
        self.n_epochs = n_epochs
        self.compute_metrics()

        self.on_training_begin()

        for self.epoch in track(range(n_epochs),
                                description="Training...",
                                disable=self.silent):
            self.on_epoch_begin()
            for tensors_dict in self.data_loaders_loop():
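                # Skip minibatches with fewer than 3 cells; such tiny
                # batches give unstable per-batch statistics.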
                if tensors_dict[0][_CONSTANTS.X_KEY].shape[0] < 3:
                    continue
                self.on_iteration_begin()
                # Update the model's parameters after seeing the data
                self.on_training_loop(tensors_dict)
                # Check training status and ensure the loss is not NaN
                self.on_iteration_end()

            # Compute metrics; on_epoch_end returns False when early stopping triggers
            if not self.on_epoch_end():
                break
        if self.early_stopping.save_best_state_metric is not None:
            self.model.load_state_dict(self.best_state_dict)
            self.compute_metrics()

        self.model.eval()
        self.training_extras_end()

        self.training_time += (time.time() - begin) - self.compute_metrics_time
        self.on_training_end()
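
The loop above stops as soon as `on_epoch_end` returns a falsy value. A minimal sketch of that hook, assuming the early-stopping object exposes an `update` method that reports whether training should continue (`update` and the `history` lookup are hypothetical names for illustration):

    def on_epoch_end(self):
        # Record metrics for this epoch, then ask the early-stopping
        # criterion whether training should continue.
        self.compute_metrics()
        continue_training = self.early_stopping.update(
            self.history["elbo_test_set"][-1]  # hypothetical metric key
        )
        return continue_training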
Example #3
def _download(url: str, save_path: str, filename: str):
    """Writes data from url to file."""
    if os.path.exists(os.path.join(save_path, filename)):
        logger.info("File %s already downloaded" %
                    (os.path.join(save_path, filename)))
        return
    req = urllib.request.Request(url, headers={"User-Agent": "Magic Browser"})
    r = urllib.request.urlopen(req)
    logger.info("Downloading file at %s" % os.path.join(save_path, filename))

    def read_iter(file, block_size=1000):
        """
        Iterates through file.

        Given a file 'file', returns an iterator that yields bytes of
        size 'block_size' from the file, using read().
        """
        while True:
            block = file.read(block_size)
            if not block:
                break
            yield block

    # Create the path to save the data
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    block_size = 1000

    # Content-Length is the size in bytes; convert it to a block count so
    # the progress total matches the number of chunks yielded by read_iter.
    filesize = int(r.getheader("Content-Length"))
    filesize = np.rint(filesize / block_size)
    with open(os.path.join(save_path, filename), "wb") as f:
        iterator = read_iter(r, block_size=block_size)
        for data in track(iterator,
                          style="tqdm",
                          total=filesize,
                          description="Downloading..."):
            f.write(data)
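
A hypothetical invocation of `_download` (the URL and paths are placeholders, not real endpoints):

_download(
    url="https://example.com/datasets/pbmc.h5ad",  # placeholder URL
    save_path="data",
    filename="pbmc.h5ad",
)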
Example #4
def _de_core(
    adata,
    model_fn,
    groupby,
    group1,
    group2,
    idx1,
    idx2,
    all_stats,
    all_stats_fn,
    col_names,
    mode,
    batchid1,
    batchid2,
    delta,
    batch_correction,
    fdr,
    **kwargs
):
    """Internal function for DE interface."""
    if group1 is None and idx1 is None:
        group1 = adata.obs[groupby].astype("category").cat.categories.tolist()
        if len(group1) == 1:
            raise ValueError(
                "Only a single group in the data. Can't run DE on a single group."
            )

    if not isinstance(group1, IterableClass) or isinstance(group1, str):
        group1 = [group1]

    # Make a temporary obs key from the provided cell indices
    temp_key = None
    if idx1 is not None:
        idx1 = np.asarray(idx1).ravel()
        g1_key = "one"
        obs_col = np.array(["None"] * adata.shape[0], dtype=str)
        obs_col[idx1] = g1_key
        group2 = None if idx2 is None else "two"
        if idx2 is not None:
            idx2 = np.asarray(idx2).ravel()
            obs_col[idx2] = group2
        temp_key = "_scvi_temp_de"
        adata.obs[temp_key] = obs_col
        groupby = temp_key
        group1 = [g1_key]

    df_results = []
    dc = DifferentialComputation(model_fn, adata)
    for g1 in track(
        group1,
        description="DE...",
    ):
        cell_idx1 = (adata.obs[groupby] == g1).to_numpy().ravel()
        if group2 is None:
            cell_idx2 = ~cell_idx1
        else:
            cell_idx2 = (adata.obs[groupby] == group2).to_numpy().ravel()

        all_info = dc.get_bayes_factors(
            cell_idx1,
            cell_idx2,
            mode=mode,
            delta=delta,
            batchid1=batchid1,
            batchid2=batchid2,
            use_observed_batches=not batch_correction,
            **kwargs,
        )

        if all_stats:
            genes_properties_dict = all_stats_fn(adata, cell_idx1, cell_idx2)
            all_info = {**all_info, **genes_properties_dict}

        res = pd.DataFrame(all_info, index=col_names)
        sort_key = "proba_de" if mode == "change" else "bayes_factor"
        res = res.sort_values(by=sort_key, ascending=False)
        if mode == "change":
            res["is_de_fdr_{}".format(fdr)] = _fdr_de_prediction(
                res["proba_de"], fdr=fdr
            )
        if idx1 is None:
            g2 = "Rest" if group2 is None else group2
            res["comparison"] = "{} vs {}".format(g1, g2)
        df_results.append(res)

    if temp_key is not None:
        del adata.obs[temp_key]

    result = pd.concat(df_results, axis=0)

    return result
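
`_fdr_de_prediction` is referenced above but not shown. A plausible sketch of the posterior-expected-FDR rule the name suggests (flag the largest prefix of genes, sorted by posterior DE probability, whose cumulative expected FDR stays under the target; an illustration, not necessarily the library's implementation):

import numpy as np
import pandas as pd

def _fdr_de_prediction(posterior_probas: pd.Series, fdr: float = 0.05) -> pd.Series:
    # Sort genes by posterior probability of being DE, highest first.
    sorted_pp = posterior_probas.sort_values(ascending=False)
    # Expected FDR of calling the top-k genes DE: mean of (1 - p) over that set.
    cumulative_fdr = (1.0 - sorted_pp).cumsum() / np.arange(1, len(sorted_pp) + 1)
    n_de = int((cumulative_fdr <= fdr).sum())
    is_de = pd.Series(False, index=sorted_pp.index)
    is_de.iloc[:n_de] = True
    # Return in the caller's original gene order.
    return is_de.loc[posterior_probas.index]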
Example #5
def poisson_gene_selection(
    adata,
    layer: Optional[str] = None,
    n_top_genes: int = 4000,
    use_cuda: bool = True,
    subset: bool = False,
    inplace: bool = True,
    n_samples: int = 10000,
    batch_key: Optional[str] = None,
    silent: bool = False,
    minibatch_size: int = 5000,
    **kwargs,
):
    """
    Rank and select genes based on the enrichment of zero counts in data compared to a Poisson count model.

    This is based on M3Drop: https://github.com/tallulandrews/M3Drop

    The method accounts for library size internally; a raw count matrix should be provided.

    Instead of a Z-test, enrichment of zeros is quantified by posterior
    probabilities from a binomial model, computed through sampling.

    Parameters
    ----------
    adata
        AnnData object (with sparse X matrix).
    layer
        If provided, use `adata.layers[layer]` for expression values instead of `adata.X`.
    n_top_genes
        How many variable genes to select.
    use_cuda
        Whether to use the GPU, if available.
    subset
        Inplace subset to highly-variable genes if ``True``; otherwise merely
        indicate highly variable genes.
    inplace
        Whether to place calculated metrics in `.var` or return them.
    n_samples
        The number of Binomial samples to use to estimate posterior probability
        of enrichment of zeros for each gene.
    batch_key
        Key in ``adata.obs`` that contains batch information. If ``None``, batch
        information is not used. Default: ``None``.
    silent
        If ``True``, disables the progress bar.
    minibatch_size
        Size of temporary matrix for incremental calculation. Larger is faster but
        requires more RAM or GPU memory. (The default should be fine unless
        there are hundreds of millions of cells or millions of genes.)

    Returns
    -------
    Depending on `inplace`, returns calculated metrics (:class:`~pd.DataFrame`) or
    updates `.var` with the following fields:

    highly_variable : bool
        Boolean indicator of highly-variable genes.
    observed_fraction_zeros : float
        Fraction of observed zeros per gene.
    expected_fraction_zeros : float
        Expected fraction of observed zeros per gene.
    prob_zero_enrichment : float
        Probability of zero enrichment; median across batches in the case of
        multiple batches.
    prob_zero_enrichment_rank : float
        Rank of the gene according to probability of zero enrichment; median
        rank in the case of multiple batches.
    prob_zero_enriched_nbatches : int
        If batch_key is given, denotes in how many batches genes are detected
        as zero enriched.

    """
    data = adata.layers[layer] if layer is not None else adata.X
    if not _check_nonnegative_integers(data):
        raise ValueError("`poisson_gene_selection` expects raw count data.")

    use_cuda = use_cuda and torch.cuda.is_available()

    if batch_key is None:
        batch_info = pd.Categorical(np.zeros(adata.shape[0], dtype=int))
    else:
        batch_info = adata.obs[batch_key]

    prob_zero_enrichments = []
    obs_frac_zeross = []
    exp_frac_zeross = []
    for b in np.unique(batch_info):

        ad = adata[batch_info == b]
        data = ad.layers[layer] if layer is not None else ad.X

        # Calculate empirical statistics.
        scaled_means = torch.from_numpy(
            np.asarray(data.sum(0) / data.sum()).ravel())
        if use_cuda:
            scaled_means = scaled_means.cuda()
        dev = scaled_means.device
        total_counts = torch.from_numpy(np.asarray(
            data.sum(1)).ravel()).to(dev)

        observed_fraction_zeros = torch.from_numpy(
            np.asarray(1.0 -
                       (data > 0).sum(0) / data.shape[0]).ravel()).to(dev)

        # Calculate probability of zero for a Poisson model: for gene g and
        # cell c the rate is scaled_means[g] * total_counts[c], so the chance
        # of a zero is exp(-scaled_means[g] * total_counts[c]).
        # Perform in batches to save memory.
        minibatch_size = min(total_counts.shape[0], minibatch_size)
        n_batches = total_counts.shape[0] // minibatch_size

        expected_fraction_zeros = torch.zeros(scaled_means.shape).to(dev)

        for i in range(n_batches):
            total_counts_batch = total_counts[i * minibatch_size:(i + 1) *
                                              minibatch_size]
            # Use einsum for outer product.
            expected_fraction_zeros += torch.exp(-torch.einsum(
                "i,j->ij", [scaled_means, total_counts_batch])).sum(1)

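        # Handle the remaining cells that did not fill a complete minibatch.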
        total_counts_batch = total_counts[(i + 1) * minibatch_size:]
        expected_fraction_zeros += torch.exp(-torch.einsum(
            "i,j->ij", [scaled_means, total_counts_batch])).sum(1)
        expected_fraction_zeros /= data.shape[0]

        # Compute probability of enriched zeros by sampling from Binomial
        # distributions (total_count defaults to 1, i.e. per-gene Bernoulli draws).
        observed_zero = torch.distributions.Binomial(
            probs=observed_fraction_zeros)
        expected_zero = torch.distributions.Binomial(
            probs=expected_fraction_zeros)

        extra_zeros = torch.zeros(expected_fraction_zeros.shape).to(dev)
        for i in track(
                range(n_samples),
                description="Sampling from binomial...",
                disable=silent,
                style="tqdm",  # do not change
        ):
            extra_zeros += observed_zero.sample() > expected_zero.sample()

        prob_zero_enrichment = (extra_zeros / n_samples).cpu().numpy()

        obs_frac_zeros = observed_fraction_zeros.cpu().numpy()
        exp_frac_zeros = expected_fraction_zeros.cpu().numpy()

        # Clean up memory (tensors seem to stay on the GPU unless actively deleted).
        del scaled_means
        del total_counts
        del expected_fraction_zeros
        del observed_fraction_zeros
        del extra_zeros

        if use_cuda:
            torch.cuda.empty_cache()

        prob_zero_enrichments.append(prob_zero_enrichment.reshape(1, -1))
        obs_frac_zeross.append(obs_frac_zeros.reshape(1, -1))
        exp_frac_zeross.append(exp_frac_zeros.reshape(1, -1))

    # Combine per-batch results

    prob_zero_enrichments = np.concatenate(prob_zero_enrichments, axis=0)
    obs_frac_zeross = np.concatenate(obs_frac_zeross, axis=0)
    exp_frac_zeross = np.concatenate(exp_frac_zeross, axis=0)

    ranked_prob_zero_enrichments = prob_zero_enrichments.argsort(
        axis=1).argsort(axis=1)
    median_prob_zero_enrichments = np.median(prob_zero_enrichments, axis=0)

    median_obs_frac_zeross = np.median(obs_frac_zeross, axis=0)
    median_exp_frac_zeross = np.median(exp_frac_zeross, axis=0)

    median_ranked = np.median(ranked_prob_zero_enrichments, axis=0)

    num_batches_zero_enriched = np.sum(ranked_prob_zero_enrichments >=
                                       (adata.shape[1] - n_top_genes),
                                       axis=0)

    df = pd.DataFrame(index=np.array(adata.var_names))
    df["observed_fraction_zeros"] = median_obs_frac_zeross
    df["expected_fraction_zeros"] = median_exp_frac_zeross
    df["prob_zero_enriched_nbatches"] = num_batches_zero_enriched
    df["prob_zero_enrichment"] = median_prob_zero_enrichments
    df["prob_zero_enrichment_rank"] = median_ranked

    df["highly_variable"] = False
    sort_columns = ["prob_zero_enriched_nbatches", "prob_zero_enrichment_rank"]
    top_genes = df.nlargest(n_top_genes, sort_columns).index
    df.loc[top_genes, "highly_variable"] = True

    if inplace or subset:
        adata.uns["hvg"] = {"flavor": "poisson_zeros"}
        logger.debug(
            "added\n"
            "    'highly_variable', boolean vector (adata.var)\n"
            "    'prob_zero_enrichment_rank', float vector (adata.var)\n"
            "    'prob_zero_enrichment' float vector (adata.var)\n"
            "    'observed_fraction_zeros', float vector (adata.var)\n"
            "    'expected_fraction_zeros', float vector (adata.var)\n")
        adata.var["highly_variable"] = df["highly_variable"].values
        adata.var["observed_fraction_zeros"] = df[
            "observed_fraction_zeros"].values
        adata.var["expected_fraction_zeros"] = df[
            "expected_fraction_zeros"].values
        adata.var["prob_zero_enriched_nbatches"] = df[
            "prob_zero_enriched_nbatches"].values
        adata.var["prob_zero_enrichment"] = df["prob_zero_enrichment"].values
        adata.var["prob_zero_enrichment_rank"] = df[
            "prob_zero_enrichment_rank"].values

        if subset:
            adata._inplace_subset_var(df["highly_variable"].values)
    else:
        if batch_key is None:
            df = df.drop(["prob_zero_enriched_nbatches"], axis=1)
        return df
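
A hypothetical end-to-end call (reading the AnnData via scanpy is an assumption; the file path is a placeholder):

import scanpy as sc

adata = sc.read_h5ad("raw_counts.h5ad")  # placeholder path; must hold raw counts
df = poisson_gene_selection(
    adata,
    n_top_genes=2000,
    inplace=False,  # return the metrics DataFrame instead of writing to .var
)
print(df.sort_values("prob_zero_enrichment_rank", ascending=False).head())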