def plot_auprc(
    truth,
    preds,
    title_prefix: str = "Precision recall curve",
    fname: str = "",
):
    """
    Plot a precision-recall curve and report the AUPRC in the title.

    Both inputs are flattened before scoring, so any shapes work as long as
    they hold the same number of elements. The reported "AUPRC" is the value
    of metrics.average_precision_score.

    :param truth: ground-truth labels (anything utils.ensure_arr accepts)
    :param preds: predicted scores, same number of elements as truth
    :param title_prefix: text placed before "(AUPRC=...)" in the title
    :param fname: if non-empty, the figure is also saved to this path
    :return: the matplotlib Figure
    """
    y_true = utils.ensure_arr(truth).flatten()
    y_score = utils.ensure_arr(preds).flatten()

    precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
    auprc = metrics.average_precision_score(y_true, y_score)
    logging.info(f"Found AUPRC of {auprc:.4f}")

    fig, ax = plt.subplots(dpi=SAVEFIG_DPI, figsize=(7, 5))
    ax.plot(recall, precision)
    ax.set(
        xlabel="Recall",
        ylabel="Precision",
        title=f"{title_prefix} (AUPRC={auprc:.4f})",
    )
    if fname:
        fig.savefig(fname, bbox_inches="tight")
    return fig
def do_evaluation_atac_from_rna(
    spliced_net,
    sc_dual_full_dataset,
    gene_names: str,
    atac_names: str,
    outdir: str,
    ext: str,
    marker_genes: List[str],
    prefix: str = "",
):
    """
    Translate RNA into ATAC with spliced_net, save the result, and optionally
    plot an AUROC against the measured ATAC.

    Writes "<prefix>_rna_atac_adata.h5ad" into outdir (a leading "_" is
    stripped when prefix is empty). The AUROC plot is produced only when
    dataset_y has a `data_raw` attribute and ext is not None.

    gene_names and marker_genes are accepted but not used by this function.
    """
    ### RNA > ATAC
    logging.info("Inferring ATAC from RNA")
    atac_preds = spliced_net.translate_1_to_2(sc_dual_full_dataset)
    # Cell metadata is taken from the RNA (x) side of the dual dataset
    atac_preds_adata = sc.AnnData(
        scipy.sparse.csr_matrix(atac_preds),
        obs=sc_dual_full_dataset.dataset_x.data_raw.obs,
    )
    atac_preds_adata.var_names = atac_names

    logging.info("Writing ATAC from RNA")
    h5ad_path = os.path.join(outdir, f"{prefix}_rna_atac_adata.h5ad".strip("_"))
    atac_preds_adata.write(h5ad_path)

    # Nothing to compare against (or no plot format requested)
    if not (hasattr(sc_dual_full_dataset.dataset_y, "data_raw") and ext is not None):
        return

    logging.info("Plotting ATAC from RNA")
    measured = utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten()
    predicted = utils.ensure_arr(atac_preds).flatten()
    plot_utils.plot_auroc(
        measured,
        predicted,
        title_prefix=f"{DATASET_NAME} RNA > ATAC".strip(),
        fname=os.path.join(outdir, f"{prefix}_rna_atac_auroc.{ext}".strip("_")),
    )
def main():
    """
    Run script: load truth and prediction matrices and plot per-gene variance
    against explained variance via plot_utils.plot_var_vs_explained_var.

    truth and preds must contain exactly the same set of observations. preds
    is reordered to follow truth's obs order before plotting, so a paired
    comparison cannot be silently misaligned when the files list cells in a
    different order.

    :raises ValueError: if the obs names of truth and preds differ as sets
    """
    parser = build_parser()
    args = parser.parse_args()

    truth = load_file_flex_format(args.truth)
    truth.X = utils.ensure_arr(truth.X)
    logging.info(f"Loaded truth {args.truth}: {truth.shape}")
    preds = load_file_flex_format(args.preds)
    preds.X = utils.ensure_arr(preds.X)
    logging.info(f"Loaded preds {args.preds}: {preds.shape}")

    common_genes = set(truth.var_names).intersection(preds.var_names)
    logging.info(f"Shared genes: {len(common_genes)}")

    # All obs names should intersect between preds and truth. Use an explicit
    # check (not assert) so it still runs under `python -O`.
    common_obs = set(truth.obs_names).intersection(preds.obs_names)
    if not len(common_obs) == len(truth.obs_names) == len(preds.obs_names):
        raise ValueError(
            f"Truth and preds obs do not match: {len(common_obs)} shared, "
            f"{len(truth.obs_names)} in truth, {len(preds.obs_names)} in preds"
        )
    # Equal sets do not imply equal order; align preds to truth by name
    if not truth.obs_names.equals(preds.obs_names):
        logging.warning("Reordering preds obs to match truth")
        preds = preds[truth.obs_names].copy()

    plot_utils.plot_var_vs_explained_var(
        truth,
        preds,
        highlight_genes={k: REF_MARKER_GENES[k] for k in args.highlight},
        logscale=not args.linear,
        constrain_y_axis=not args.unconstriained,
        label_outliers=args.outliers,
        fname=args.plotname,
        fname_gene_list=args.genelist,
    )
def plot_auroc(
    truth,
    preds,
    title_prefix: str = "Receiver operating characteristic",
    fname: str = "",
):
    """
    Plot a ROC curve for the flattened inputs and report AUROC in the title.

    :param truth: ground-truth labels (anything utils.ensure_arr accepts)
    :param preds: predicted scores, same number of elements as truth
    :param title_prefix: text placed before "(AUROC=...)" in the title
    :param fname: if non-empty, the figure is also saved (at SAVEFIG_DPI)
    :return: the matplotlib Figure
    """
    labels = utils.ensure_arr(truth).flatten()
    scores = utils.ensure_arr(preds).flatten()

    false_pos, true_pos, _ = metrics.roc_curve(labels, scores)
    area = metrics.auc(false_pos, true_pos)
    logging.info(f"Found AUROC of {area:.4f}")

    fig, ax = plt.subplots(dpi=300, figsize=(7, 5))
    ax.plot(false_pos, true_pos)
    ax.set(
        xlim=(0, 1.0),
        ylim=(0.0, 1.05),
        xlabel="False positive rate",
        ylabel="True positive rate",
        title=f"{title_prefix} (AUROC={area:.2f})",
    )
    if fname:
        fig.savefig(fname, dpi=SAVEFIG_DPI, bbox_inches="tight")
    return fig
def plot_bulk_scatter(
    x: AnnData,
    y: AnnData,
    x_subset: dict = None,
    y_subset: dict = None,
    logscale: bool = True,
    corr_func: Callable = scipy.stats.pearsonr,
    title: str = "",
    xlabel: str = "Average measured counts",
    ylabel: str = "Average inferred counts",
    fname: str = "",
):
    """
    Scatter the per-variable average ("bulk") signature of x against y.

    Steps:
    1. Optionally filter the cells of x and/or y with
       adata_utils.filter_adata(filt_cells=...).
    2. Restrict both to their shared variables, sorted.
    3. Optionally apply log1p elementwise (before averaging).
    4. Average over cells.
    The first value returned by corr_func on the two averaged vectors is
    logged and shown in the title as r.

    :return: the matplotlib Figure (also saved to fname when non-empty)
    """
    if x_subset is not None:
        n_before = x.n_obs
        x = adata_utils.filter_adata(x, filt_cells=x_subset)
        logging.info(f"Subsetted x from {n_before} to {x.n_obs}")
    if y_subset is not None:
        n_before = y.n_obs
        y = adata_utils.filter_adata(y, filt_cells=y_subset)
        logging.info(f"Subsetted y from {n_before} to {y.n_obs}")

    # Make sure variables match
    shared = sorted(set(x.var_names) & set(y.var_names))
    logging.info(f"Found {len(shared)} shared variables")
    x = x[:, shared]
    y = y[:, shared]

    transform = np.log1p if logscale else (lambda mat: mat)
    x_means = utils.ensure_arr(transform(x.X)).mean(axis=0)
    y_means = utils.ensure_arr(transform(y.X)).mean(axis=0)
    assert not np.any(np.isnan(x_means))
    assert not np.any(np.isnan(y_means))

    corr, _ = corr_func(x_means, y_means)
    logging.info(f"Found pearson's correlation of {corr:.4f}")

    axis_suffix = " (log)" if logscale else ""
    fig, ax = plt.subplots(dpi=SAVEFIG_DPI, figsize=(7, 5))
    ax.scatter(x_means, y_means, alpha=0.4)
    ax.set(
        xlabel=xlabel + axis_suffix,
        ylabel=ylabel + axis_suffix,
        title=(title + f" ($r={corr:.2f}$)").strip(),
    )
    if fname:
        fig.savefig(fname, bbox_inches="tight")
    return fig
def load_rna_files(rna_counts_fnames: List[str],
                   model_dir: str,
                   transpose: bool = True) -> ad.AnnData:
    """
    Load the RNA files in, filling in unmeasured genes as necessary.

    The output's var axis is exactly the gene list in
    <model_dir>/rna_genes.txt, in that order. Columns for genes present in
    the input are copied over. Genes the input lacks stay all-zero. If the
    input contains "MOUSE_"-prefixed genes, those are dropped and a literal
    leading "HUMAN_" prefix is removed from the remaining names.

    :param rna_counts_fnames: input count files, read and outer-joined with
        utils.sc_read_multi_files
    :param model_dir: directory containing rna_genes.txt
    :param transpose: forwarded to utils.sc_read_multi_files
    :return: AnnData with a CSR matrix of shape (n_obs, n_learned_genes)
    """
    # Find the genes that the model understands
    rna_genes_list_fname = os.path.join(model_dir, "rna_genes.txt")
    assert os.path.isfile(
        rna_genes_list_fname
    ), f"Cannot find RNA genes file: {rna_genes_list_fname}"
    learned_rna_genes = utils.read_delimited_file(rna_genes_list_fname)
    assert isinstance(learned_rna_genes, list)
    assert utils.is_all_unique(
        learned_rna_genes), "Learned genes list contains duplicates"

    temp_ad = utils.sc_read_multi_files(
        rna_counts_fnames,
        feature_type="Gene Expression",
        transpose=transpose,
        join="outer",
    )
    logging.info(f"Read input RNA files for {temp_ad.shape}")
    temp_ad.X = utils.ensure_arr(temp_ad.X)

    # Drop mouse genes and remove the human prefix from the remaining names
    temp_ad.var_names_make_unique()
    kept_var_names = [
        vname for vname in temp_ad.var_names if not vname.startswith("MOUSE_")
    ]
    if len(kept_var_names) != temp_ad.n_vars:
        # Explicit copy: replacing .var on a view would otherwise trigger an
        # implicit copy with a warning
        temp_ad = temp_ad[:, kept_var_names].copy()
        # Remove only the literal prefix. str.strip("HUMAN_") would strip any
        # of the characters H/U/M/A/N/_ from both ends (e.g. "MAN1A1" -> "1A1").
        human_prefix = "HUMAN_"
        temp_ad.var = pd.DataFrame(index=[
            v[len(human_prefix):] if v.startswith(human_prefix) else v
            for v in kept_var_names
        ])

    # Expand adata to span all genes
    # Initiating as a sparse matrix doesn't allow vectorized building
    intersected_genes = set(temp_ad.var_names).intersection(learned_rna_genes)
    assert intersected_genes, "No overlap between learned and input genes!"

    # Destination column per learned gene, built once (avoids O(n^2) list.index)
    learned_gene_to_idx = {g: i for i, g in enumerate(learned_rna_genes)}
    expanded_mat = np.zeros((temp_ad.n_obs, len(learned_rna_genes)))
    skip_count = 0
    for gene in intersected_genes:
        src_idx = temp_ad.var_names.get_loc(gene)
        # A non-integer result from get_loc means the name occurs more than
        # once (possible after prefix removal); such genes are skipped
        if not isinstance(src_idx, (int, np.integer)):
            logging.warning(f"Got multiple source matches for {gene}, skipping")
            skip_count += 1
            continue
        v = utils.ensure_arr(temp_ad.X[:, src_idx]).flatten()
        expanded_mat[:, learned_gene_to_idx[gene]] = v
    if skip_count:
        logging.warning(
            f"Skipped {skip_count}/{len(intersected_genes)} genes due to multiple matches"
        )

    expanded_mat = sparse.csr_matrix(expanded_mat)  # Compress
    retval = ad.AnnData(expanded_mat,
                        obs=temp_ad.obs,
                        var=pd.DataFrame(index=learned_rna_genes))
    return retval
def do_evaluation_atac_from_atac(
    spliced_net,
    sc_dual_full_dataset,
    gene_names: str,
    atac_names: str,
    outdir: str,
    ext: str,
    marker_genes: List[str],
    prefix: str = "",
):
    """
    Reconstruct ATAC from ATAC with spliced_net, save the result, and
    optionally plot an AUROC against the measured ATAC.

    Writes "<prefix>_atac_atac_adata.h5ad" into outdir (a leading "_" is
    stripped when prefix is empty). The AUROC plot is produced only when
    dataset_y has a `data_raw` attribute and ext is not None.

    gene_names and marker_genes are accepted but not used by this function.
    """
    ### ATAC > ATAC
    logging.info("Inferring ATAC from ATAC")
    atac_recon = spliced_net.translate_2_to_2(sc_dual_full_dataset)
    # Deep-copy obs so the saved AnnData does not share its metadata frame
    # with the dataset
    atac_recon_adata = sc.AnnData(
        atac_recon,
        obs=sc_dual_full_dataset.dataset_y.data_raw.obs.copy(deep=True),
    )
    atac_recon_adata.var_names = atac_names
    logging.info("Writing ATAC from ATAC")
    # NOTE: marker-bin inference (adata_utils.find_marker_genes) and AUPRC
    # plotting (plot_utils.plot_auprc) are not performed for this direction.
    h5ad_path = os.path.join(outdir, f"{prefix}_atac_atac_adata.h5ad".strip("_"))
    atac_recon_adata.write(h5ad_path)

    if hasattr(sc_dual_full_dataset.dataset_y, "data_raw") and ext is not None:
        logging.info("Plotting ATAC from ATAC")
        measured = utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten()
        reconstructed = utils.ensure_arr(atac_recon).flatten()
        plot_utils.plot_auroc(
            measured,
            reconstructed,
            title_prefix=f"{DATASET_NAME} ATAC > ATAC".strip(),
            fname=os.path.join(outdir, f"{prefix}_atac_atac_auroc.{ext}".strip("_")),
        )

    # Remove some objects to free memory
    del atac_recon
    del atac_recon_adata
def main():
    """
    Convert an .h5ad or 10x .h5 expression file into a tab-separated table
    laid out features x observations, as BIRD expects.

    Row labels are mapped through utils.read_gtf_gene_symbol_to_id()
    (presumably gene symbol -> gene ID; confirm against that helper). A
    gene missing from that map raises KeyError.

    :raises ValueError: on an unrecognized input extension, or if the matrix
        contains negative (or NaN) values
    """
    parser = build_parser()
    args = parser.parse_args()

    _, input_ext = os.path.splitext(args.input)
    if input_ext == ".h5ad":
        x = ad.read_h5ad(args.input)
    elif input_ext == ".h5":
        x = sc.read_10x_h5(args.input)
    else:
        raise ValueError(f"Unrecognized file extension: {args.input}")
    logging.info(f"Read input: {x}")

    logging.info("Reading gtf for gene name map")
    gene_name_map = utils.read_gtf_gene_symbol_to_id()

    # Transpose because BIRD wants features x obs
    x_df = pd.DataFrame(utils.ensure_arr(x.X),
                        index=x.obs_names,
                        columns=x.var_names).T
    # Explicit check (not assert) so it still runs under `python -O`.
    # `>= 0` is also False for NaN, so NaNs are rejected as before.
    if not np.all(x_df.values >= 0.0):
        raise ValueError("Input matrix must be non-negative and NaN-free")
    x_df.index = [gene_name_map[g] for g in x_df.index]

    # Write output (tab-separated table)
    logging.info(f"Writing output to {args.output_table_txt}")
    x_df.to_csv(args.output_table_txt, sep="\t")
def reindex_adata_vars(adata: AnnData, target_vars: List[str]) -> AnnData:
    """
    Reindex adata to match the given target_vars, verbatim.

    The result has exactly the columns of target_vars, in order. Columns for
    variables present in adata are copied. Variables missing from adata are
    filled with zeros. Only obs_names are carried over (no other obs/var
    metadata). The input adata is not modified.

    Duplicate variable names in adata are de-duplicated first, on a copy of
    the names, with anndata's make_index_unique. This is the scheme
    AnnData.var_names_make_unique uses, so the first occurrence keeps its
    name.

    :param adata: source AnnData
    :param target_vars: unique variable names for the output, in order
    :return: new AnnData of shape (adata.n_obs, len(target_vars))
    """
    from anndata.utils import make_index_unique

    assert len(adata.var_names) == adata.n_vars
    var_names = adata.var_names
    if not utils.is_all_unique(var_names):
        logging.warning("De-duping variable names before reindexing")
        # De-dupe a copy of the names instead of mutating the caller's adata
        var_names = make_index_unique(var_names)
    assert utils.is_all_unique(target_vars), "Target vars are not all unique"

    intersected = set(var_names).intersection(target_vars)
    logging.info(
        f"Overlap of {len(intersected)}/{adata.n_vars}, 0 vector will be filled in for {len(target_vars) - len(intersected)} 'missing' features"
    )

    vars_to_cols = dict(zip(var_names, utils.ensure_arr(adata.X).T))
    assert (len(vars_to_cols) == adata.n_vars
            ), f"Size mismatch: {len(vars_to_cols)} {adata.n_vars}"

    target_shape = (adata.n_obs, len(target_vars))
    if len(target_vars) > 0:
        default_null = np.zeros(adata.n_obs)
        mat = np.vstack([
            vars_to_cols[v] if v in vars_to_cols else np.copy(default_null)
            for v in target_vars
        ]).T
    else:
        # np.vstack cannot stack an empty list
        mat = np.zeros(target_shape)
    assert mat.shape == target_shape, f"Size mismatch: {mat.shape} {target_shape}"

    retval = AnnData(mat)
    retval.obs_names = adata.obs_names
    retval.var_names = target_vars
    return retval
def main():
    """
    Run script: log the joint entropy of an AnnData matrix.

    Only the discrete estimator (pyitlib, natural-log base) is implemented.
    If args.discrete is falsy, NotImplementedError is raised after loading.
    """
    args = build_parser().parse_args()

    adata = ad.read_h5ad(args.h5ad)
    logging.info(f"Read {args.h5ad} for adata of {adata.shape}")
    if not args.discrete:
        raise NotImplementedError

    # pyitlib's entropy_joint indexes successive realisations along the LAST
    # axis and distinct random variables along the preceding axes, i.e.
    # (variables, samples):
    # https://pafoster.github.io/pyitlib/#discrete_random_variable.entropy_joint
    # https://github.com/pafoster/pyitlib/blob/master/pyitlib/discrete_random_variable.py#L3535
    # AnnData holds (samples, variables), so transpose before calling it.
    variables_by_samples = utils.ensure_arr(adata.X).T
    h = drv.entropy_joint(variables_by_samples, base=np.e)
    logging.info(f"Found discrete joint entropy of {h:.6f}")
def _read_rna_for_comparison(fname: str):
    """Read an .h5ad or 10x .h5 file, densify X, and sanitize/de-dupe obs names."""
    if fname.endswith(".h5ad"):
        adata = ad.read_h5ad(fname)
    elif fname.endswith(".h5"):
        adata = sc.read_10x_h5(fname, gex_only=False)
    else:
        raise ValueError(f"Unrecognized file extension: {fname}")
    adata.X = utils.ensure_arr(adata.X)
    adata.obs_names = sanitize_obs_names(adata.obs_names)
    adata.obs_names_make_unique()
    logging.info(f"Read in {fname} for {adata.shape}")
    return adata


def main():
    """
    Run script: scatter-plot two RNA matrices (x vs y) with their correlation.

    When the obs or var names of the two inputs differ, both are restricted to
    the sorted intersection. An optional gene list further subsets the vars.
    Plotting is delegated to plot_utils.plot_scatter_with_r.
    """
    parser = build_parser()
    args = parser.parse_args()

    x_rna = _read_rna_for_comparison(args.x_rna)
    y_rna = _read_rna_for_comparison(args.y_rna)

    obs_already_match = (len(x_rna.obs_names) == len(y_rna.obs_names)
                         and np.all(x_rna.obs_names == y_rna.obs_names))
    if not obs_already_match:
        logging.warning("Rematching obs axis")
        shared_obs_names = sorted(
            set(x_rna.obs_names).intersection(y_rna.obs_names))
        logging.info(f"Found {len(shared_obs_names)} shared obs")
        assert shared_obs_names, "\n".join([
            "Got empty list of shared obs",
            str(x_rna.obs_names),
            str(y_rna.obs_names),
        ])
        x_rna = x_rna[shared_obs_names]
        y_rna = y_rna[shared_obs_names]
    assert np.all(x_rna.obs_names == y_rna.obs_names)

    vars_already_match = (len(x_rna.var_names) == len(y_rna.var_names)
                          and np.all(x_rna.var_names == y_rna.var_names))
    if not vars_already_match:
        logging.warning("Rematching variable axis")
        shared_var_names = sorted(
            set(x_rna.var_names).intersection(y_rna.var_names))
        logging.info(f"Found {len(shared_var_names)} shared variables")
        assert shared_var_names, "\n".join([
            "Got empty list of shared vars",
            str(x_rna.var_names),
            str(y_rna.var_names),
        ])
        x_rna = x_rna[:, shared_var_names]
        y_rna = y_rna[:, shared_var_names]
    assert np.all(x_rna.var_names == y_rna.var_names)

    # Subset by gene list if given
    if args.genelist:
        gene_list = utils.read_delimited_file(args.genelist)
        logging.info(f"Read {len(gene_list)} genes from {args.genelist}")
        x_rna = x_rna[:, gene_list]
        y_rna = y_rna[:, gene_list]

    assert x_rna.shape == y_rna.shape, f"Mismatched shapes {x_rna.shape} {y_rna.shape}"

    plot_utils.plot_scatter_with_r(
        x_rna.X,
        y_rna.X,
        subset=args.subset,
        one_to_one=True,
        logscale=not args.linear,
        density_heatmap=args.density,
        density_logstretch=args.densitylogstretch,
        fname=args.outfname,
        title=args.title,
        xlabel=args.xlabel,
        ylabel=args.ylabel,
        figsize=args.figsize,
    )
def plot_scatter_with_r(
    x: Union[np.ndarray, scipy.sparse.csr_matrix],
    y: Union[np.ndarray, scipy.sparse.csr_matrix],
    color=None,
    subset: int = 0,
    logscale: bool = False,
    density_heatmap: bool = False,
    density_dpi: int = 150,
    density_logstretch: int = 1000,
    title: str = "",
    xlabel: str = "Original norm counts",
    ylabel: str = "Inferred norm counts",
    xlim: Tuple[int, int] = None,
    ylim: Tuple[int, int] = None,
    one_to_one: bool = False,
    corr_func: Callable = scipy.stats.pearsonr,
    figsize: Tuple[float, float] = (7, 5),
    fname: str = "",
    ax=None,
):
    """
    Plot the given x y coordinates, appending a correlation coefficient.

    Every element of x/y is one point. Setting xlim/ylim affects both the
    plot and the correlation. Only points with x inside xlim AND y inside
    ylim (inclusive) are kept, so the plot view mirrors the range over which
    the correlation is calculated.

    :param x, y: arrays or sparse matrices of identical shape
    :param color: optional per-point colors, one per element of x. They are
        filtered and subsampled together with the points.
    :param subset: if 0 < subset < x.size, plot a reproducible random sample
        of this many points (seed 1234, global `random` state untouched)
    :param logscale: apply log1p to both coordinates before plotting/correlating
    :param density_heatmap: draw with mpl-scatter-density instead of scatter
    :param density_logstretch: if non-zero, LogStretch `a` for the density norm
    :param xlim, ylim: optional inclusive (low, high) ranges. They must be
        equal when one_to_one is set.
    :param one_to_one: draw a dashed y=x reference line
    :param corr_func: returns (statistic, p-value). Its statistic is shown in
        the title as r. Pearson's and Spearman's values are always logged.
    :param fname: if non-empty and a new figure was created, save it here
    :param ax: existing axes to draw on. If given, nothing is saved and
        None is returned.
    :return: the created Figure, or None when ax was supplied
    """
    assert x.shape == y.shape, f"Mismatched shapes: {x.shape} {y.shape}"
    if color is not None:
        color = np.asarray(color)
        assert color.size == x.size
    if one_to_one and (xlim is not None or ylim is not None):
        assert xlim == ylim

    if xlim or ylim:
        # Densify first: ndarray comparisons have no .multiply, and boolean
        # masks behave uniformly on dense arrays
        x = utils.ensure_arr(x)
        y = utils.ensure_arr(y)
        keep_idx = np.ones(x.shape, dtype=bool)
        if xlim:
            keep_idx &= (x >= xlim[0]) & (x <= xlim[1])
        if ylim:
            keep_idx &= (y >= ylim[0]) & (y <= ylim[1])
        if color is not None:
            color = color.reshape(keep_idx.shape)[keep_idx]
        x = x[keep_idx]
        y = y[keep_idx]

    # x and y may or may not be sparse at this point
    assert x.shape == y.shape
    if 0 < subset < x.size:
        logging.info(f"Subsetting to {subset} points")
        full_shape = x.shape
        # Local RNG: same draws as random.seed(1234) + random.sample, without
        # resetting the global random state for the caller
        rng = random.Random(1234)
        flat_idx = np.array(rng.sample(range(int(np.prod(full_shape))), k=subset))
        # Converts flat index to coordinates
        indices = np.unravel_index(flat_idx, shape=full_shape)
        x = utils.ensure_arr(x[indices])
        y = utils.ensure_arr(y[indices])
        if color is not None:
            color = color.reshape(full_shape)[indices]

    if logscale:
        x = np.log1p(x)
        y = np.log1p(y)

    # Ensure correct format
    x = utils.ensure_arr(x).flatten()
    y = utils.ensure_arr(y).flatten()
    if color is not None:
        color = color.reshape(-1)
    assert not np.any(np.isnan(x))
    assert not np.any(np.isnan(y))

    pearson_r, pearson_p = scipy.stats.pearsonr(x, y)
    logging.info(
        f"Found pearson's correlation/p of {pearson_r:.4f}/{pearson_p:.4g}")
    spearman_corr, spearman_p = scipy.stats.spearmanr(x, y)
    logging.info(
        f"Found spearman's correlation/p of {spearman_corr:.4f}/{spearman_p:.4g}"
    )
    # Title statistic comes from corr_func; reuse Pearson if that's what it is
    if corr_func is scipy.stats.pearsonr:
        title_r = pearson_r
    else:
        title_r, _ = corr_func(x, y)

    if ax is None:
        fig = plt.figure(dpi=300, figsize=figsize)
        if density_heatmap:
            # https://github.com/astrofrog/mpl-scatter-density
            ax = fig.add_subplot(1, 1, 1, projection="scatter_density")
        else:
            ax = fig.add_subplot(1, 1, 1)
    else:
        fig = None

    if density_heatmap:
        norm = None
        if density_logstretch:
            norm = ImageNormalize(vmin=0,
                                  vmax=100,
                                  stretch=LogStretch(a=density_logstretch))
        ax.scatter_density(x, y, dpi=density_dpi, norm=norm, color="tab:blue")
    else:
        ax.scatter(x, y, alpha=0.2, c=color)

    if one_to_one:
        unit = np.linspace(*ax.get_xlim())
        ax.plot(unit,
                unit,
                linestyle="--",
                alpha=0.5,
                label="$y=x$",
                color="grey")
        ax.legend()
    ax.set(
        xlabel=xlabel + (" (log)" if logscale else ""),
        ylabel=ylabel + (" (log)" if logscale else ""),
        title=(title + f" ($r={title_r:.2f}$)").strip(),
    )
    if xlim:
        ax.set(xlim=xlim)
    if ylim:
        ax.set(ylim=ylim)

    if fig is not None and fname:
        fig.savefig(fname, dpi=SAVEFIG_DPI, bbox_inches="tight")

    return fig
def archr_gene_activity_matrix_from_adata(
    adata: ad.AnnData,
    annotation: str = sc_data_loaders.HG38_GTF,
    gene_model: Callable = lambda x: np.exp(-np.abs(x) / 5000) + np.exp(-1),
    extension: int = 100000,
    gene_upstream: int = 5000,
    use_gene_boundaries: bool = True,
    use_TSS: bool = False,
    gene_scale_factor: float = 5.0,  # Numeric scaling factor to weight genes based on inverse of length
    ceiling: float = 4.0,
    scale_to: float = 10000.0,
) -> ad.AnnData:
    """
    Use the more sophisticated gene activity scoring method described here:
    https://www.archrproject.com/bookdown/calculating-gene-scores-in-archr.html
    https://github.com/GreenleafLab/ArchR/blob/ddcaae4a6093685875052219141e5ea41030fc55/R/MatrixGeneScores.R

    Note this isn't a FULL reimplementation since we do slightly different
    handling of distances and tiling and such, but this should be a very
    close approximation.

    Some other notes:
    - By default, ArchR appears to be doing distance calculations based on
      the entire gene body, which we do
    - use_gene_boundaries and use_TSS are accepted but not read here
    - adata is not modified; the ceiling is applied to a local copy of X

    :param adata: ATAC AnnData whose var_names are bins of the form
        "chrom:start-stop" (a leading "chr" on chrom is removed)
    :param annotation: GTF passed to utils.read_gtf_gene_to_pos
    :param gene_model: maps bin-to-gene distance to a weight
    :param extension: bins up to this far beyond either end of a gene are
        considered for that gene
    :param gene_upstream: forwarded as extend_upstream when reading the GTF
    :param gene_scale_factor: ArchR-style weight for short genes
        (1 + (1/width) * (factor - 1) / (max(1/width) - min(1/width)))
    :param ceiling: if > 0, cap each bin's count at ceiling per 500bp of bin
    :param scale_to: if > 0, rescale each cell's scores to sum to this value
        (all-zero cells stay zero)
    :return: AnnData of shape (n_obs, n_genes) with genes as var_names. The
        matrix is CSR sparse when scale_to > 0.
    """
    gene_to_pos = utils.read_gtf_gene_to_pos(annotation, extend_upstream=gene_upstream)
    genes = list(gene_to_pos.keys())
    # Map each gene and bin to a corresponding index in axis
    gene_to_idx = {g: i for i, g in enumerate(genes)}
    bin_to_idx = {b: i for i, b in enumerate(adata.var_names)}

    # Map of chrom without chr prefix to intervaltree
    chrom_to_gene_intervals = utils.gene_pos_dict_to_range(gene_to_pos)

    # Create a mapping of where atac bins are, so we can easily grep for overlap later
    chrom_to_atac_intervals = collections.defaultdict(itree.IntervalTree)
    for atac_bin in adata.var_names:
        chrom, span = atac_bin.split(":")
        # Remove only a literal "chr" prefix. str.strip("chr") would also
        # remove any leading/trailing 'c', 'h', 'r' characters.
        if chrom.startswith("chr"):
            chrom = chrom[len("chr"):]
        start, stop = map(int, span.split("-"))
        chrom_to_atac_intervals[chrom][start:stop] = atac_bin

    # Create a matrix that maps each feature to overlapping genes with weight
    weights = scipy.sparse.lil_matrix((adata.shape[1], len(genes)), dtype=float)

    # Create per-gene weights based on size of the gene
    logging.info("Determining gene weights based on gene size")
    gene_widths = np.array([gene_to_pos[g][2] - gene_to_pos[g][1] for g in genes])
    assert np.all(gene_widths > 0)
    inv_gene_widths = 1 / gene_widths
    gene_weight = 1 + inv_gene_widths * (gene_scale_factor - 1) / (
        np.max(inv_gene_widths) - np.min(inv_gene_widths)
    )
    assert np.all(gene_weight >= 1)
    if not np.all(gene_weight <= gene_scale_factor):
        logging.warning(
            f"Found values exceeding gene scale factor {gene_scale_factor}: {gene_weight[np.where(gene_weight > gene_scale_factor)]}"
        )
    if not np.all(gene_weight >= 1.0):
        logging.warning(
            f"Found values below minimum expected value of 1: {gene_weight[np.where(gene_weight < 1.0)]}"
        )

    logging.info("Constructing bin to gene matrix")
    for gene in genes:  # Compute weight for each gene
        gene_gi = genomic_interval.GenomicInterval(
            gene_to_pos[gene], metadata_dict={"gene": gene}
        )
        chrom, start, stop = gene_to_pos[gene]
        assert start < stop
        # Get all ATAC bins
        gene_overlap_atac_bins = chrom_to_atac_intervals[chrom][
            start - extension : stop + extension
        ]
        # Drop the ATAC bins that overlap a gene that isn't this current gene
        filtered_gene_overlap_atac_bins = []
        for o in gene_overlap_atac_bins:
            atac_bin_gene_overlaps = chrom_to_gene_intervals[chrom][o.begin:o.end]
            if any(g.data != gene for g in atac_bin_gene_overlaps):
                continue
            filtered_gene_overlap_atac_bins.append(o)

        # Calculate the distance and the corresponding weight
        for o in filtered_gene_overlap_atac_bins:
            bin_gi = genomic_interval.GenomicInterval((chrom, o.begin, o.end))
            d = gene_gi.difference(bin_gi)
            assert d >= 0
            w = gene_model(d)
            assert weights[bin_to_idx[o.data], gene_to_idx[gene]] == 0
            # Note, ArchR works on 500bp bins, so may count a little differently
            # To adjust multiply by the size of the bin divided by 500
            # This approximates how many times each bin might be counted
            # as fragments
            v = w * gene_weight[gene_to_idx[gene]] * max(bin_gi.size / 500, 1)
            weights[bin_to_idx[o.data], gene_to_idx[gene]] = v
    weights = scipy.sparse.csc_matrix(weights)

    # Work on a local matrix so the caller's adata.X is not overwritten
    atac_mat = adata.X
    if ceiling > 0:
        logging.info(
            f"Calculating maximum capped counts per bin at {ceiling} per 500bp"
        )
        per_bin_cap = np.array([
            max(genomic_interval.GenomicInterval(b).size / 500, 1) * ceiling
            for b in adata.var_names
        ])
        assert np.all(per_bin_cap >= ceiling)
        atac_mat = np.minimum(utils.ensure_arr(adata.X), per_bin_cap)

    # Map the ATAC bins to features
    # Converting to an array is necessary for correct broadcasting
    logging.info("Calculating gene activity scores")
    mat = utils.ensure_arr(atac_mat @ weights)

    # Normalize depths
    if scale_to > 0:
        logging.info(f"Depth normalizing gene activity scores to {scale_to}")
        per_cell_depths = np.array(mat.sum(axis=1)).flatten()
        # Zero-depth cells get a scaling of 0, without a divide-by-zero warning
        per_cell_scaling = np.divide(
            scale_to,
            per_cell_depths,
            out=np.zeros(per_cell_depths.shape, dtype=float),
            where=per_cell_depths != 0,
        )
        mat = scipy.sparse.csr_matrix(mat * per_cell_scaling[:, np.newaxis])
        row_sums = mat.sum(axis=1)
        assert np.all(
            np.logical_or(np.isclose(scale_to, row_sums), np.isclose(0, row_sums))
        )  # Either is the correct size, or is all 0

    retval = ad.AnnData(mat)
    if hasattr(adata, "obs_names"):
        retval.obs_names = adata.obs_names
    retval.var_names = genes
    return retval