Example #1
0
def plot_auprc(
    truth,
    preds,
    title_prefix: str = "Precision recall curve",
    fname: str = "",
):
    """
    Plot a precision-recall curve and report its average precision (AUPRC).

    Both inputs are converted with utils.ensure_arr and flattened before
    scoring, so they only need the same total number of elements.

    truth: ground-truth labels
    preds: predicted scores
    title_prefix: text shown before the AUPRC value in the title
    fname: if non-empty, the figure is saved to this path
    Returns the matplotlib Figure.
    """
    flat_truth = utils.ensure_arr(truth).flatten()
    flat_preds = utils.ensure_arr(preds).flatten()

    prec, rec, _ = metrics.precision_recall_curve(flat_truth, flat_preds)
    ap_score = metrics.average_precision_score(flat_truth, flat_preds)
    logging.info(f"Found AUPRC of {ap_score:.4f}")

    fig, ax = plt.subplots(dpi=SAVEFIG_DPI, figsize=(7, 5))
    ax.plot(rec, prec)
    plot_title = f"{title_prefix} (AUPRC={ap_score:.4f})"
    ax.set(xlabel="Recall", ylabel="Precision", title=plot_title)

    if fname:
        fig.savefig(fname, bbox_inches="tight")
    return fig
Example #2
0
def do_evaluation_atac_from_rna(
    spliced_net,
    sc_dual_full_dataset,
    gene_names: str,
    atac_names: str,
    outdir: str,
    ext: str,
    marker_genes: List[str],
    prefix: str = "",
):
    """
    Translate RNA into ATAC with the model, save it, and optionally plot AUROC.

    Always writes ``{prefix}_rna_atac_adata.h5ad`` into outdir. The leading
    underscore is stripped when prefix is empty. If the dataset's ATAC side
    (dataset_y) has a ``data_raw`` attribute and ext is not None, it also writes
    an AUROC plot to ``{prefix}_rna_atac_auroc.{ext}``.

    gene_names and marker_genes are accepted but not used by this function.
    """
    ### RNA > ATAC
    logging.info("Inferring ATAC from RNA")
    atac_preds = spliced_net.translate_1_to_2(sc_dual_full_dataset)
    atac_preds_adata = sc.AnnData(
        scipy.sparse.csr_matrix(atac_preds),
        obs=sc_dual_full_dataset.dataset_x.data_raw.obs,
    )
    atac_preds_adata.var_names = atac_names

    logging.info("Writing ATAC from RNA")
    h5ad_path = os.path.join(outdir, f"{prefix}_rna_atac_adata.h5ad".strip("_"))
    atac_preds_adata.write(h5ad_path)

    # Plotting needs measured ATAC to compare against and an output extension
    if not hasattr(sc_dual_full_dataset.dataset_y, "data_raw") or ext is None:
        return

    logging.info("Plotting ATAC from RNA")
    measured = utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten()
    predicted = utils.ensure_arr(atac_preds).flatten()
    plot_utils.plot_auroc(
        measured,
        predicted,
        title_prefix=f"{DATASET_NAME} RNA > ATAC".strip(),
        fname=os.path.join(outdir, f"{prefix}_rna_atac_auroc.{ext}".strip("_")),
    )
Example #3
0
def main():
    """
    Run script.

    Loads the truth and prediction files named on the command line, densifies
    their X, and logs how many genes they share. It checks that both contain
    exactly the same set of obs names, then hands both to
    plot_utils.plot_var_vs_explained_var.
    """
    args = build_parser().parse_args()

    truth = load_file_flex_format(args.truth)
    truth.X = utils.ensure_arr(truth.X)
    logging.info(f"Loaded truth {args.truth}: {truth.shape}")
    preds = load_file_flex_format(args.preds)
    preds.X = utils.ensure_arr(preds.X)
    logging.info(f"Loaded preds {args.preds}: {preds.shape}")

    shared_genes = set(truth.var_names) & set(preds.var_names)
    logging.info(f"Shared genes: {len(shared_genes)}")

    # All obs names should intersect between preds and truth
    shared_obs = set(truth.obs_names) & set(preds.obs_names)
    assert len(shared_obs) == len(truth.obs_names) == len(preds.obs_names)

    highlighted = {name: REF_MARKER_GENES[name] for name in args.highlight}
    plot_utils.plot_var_vs_explained_var(
        truth,
        preds,
        highlight_genes=highlighted,
        logscale=not args.linear,
        # Attribute name must match the spelling used in build_parser
        constrain_y_axis=not args.unconstriained,
        label_outliers=args.outliers,
        fname=args.plotname,
        fname_gene_list=args.genelist,
    )
Example #4
0
def plot_auroc(
    truth,
    preds,
    title_prefix: str = "Receiver operating characteristic",
    fname: str = "",
):
    """
    Plot a ROC curve (after flattening inputs) and report the area under it.

    truth: ground-truth labels, converted with utils.ensure_arr and flattened
    preds: predicted scores, converted with utils.ensure_arr and flattened
    title_prefix: text shown before the AUROC value in the title
    fname: if non-empty, the figure is saved to this path at SAVEFIG_DPI
    Returns the matplotlib Figure.
    """
    y_true = utils.ensure_arr(truth).flatten()
    y_score = utils.ensure_arr(preds).flatten()
    false_pos, true_pos, _ = metrics.roc_curve(y_true, y_score)
    roc_area = metrics.auc(false_pos, true_pos)
    logging.info(f"Found AUROC of {roc_area:.4f}")

    fig, ax = plt.subplots(dpi=300, figsize=(7, 5))
    ax.plot(false_pos, true_pos)
    axis_props = dict(
        xlim=(0, 1.0),
        ylim=(0.0, 1.05),
        xlabel="False positive rate",
        ylabel="True positive rate",
        title=f"{title_prefix} (AUROC={roc_area:.2f})",
    )
    ax.set(**axis_props)

    if fname:
        fig.savefig(fname, dpi=SAVEFIG_DPI, bbox_inches="tight")
    return fig
Example #5
0
def plot_bulk_scatter(
    x: AnnData,
    y: AnnData,
    x_subset: dict = None,
    y_subset: dict = None,
    logscale: bool = True,
    corr_func: Callable = scipy.stats.pearsonr,
    title: str = "",
    xlabel: str = "Average measured counts",
    ylabel: str = "Average inferred counts",
    fname: str = "",
):
    """
    Create a bulk signature for each input and scatter them against each other.

    Steps:
    1. Each input is optionally filtered to a subset of cells; the subset spec
       is passed as ``filt_cells`` to adata_utils.filter_adata.
    2. Both are restricted to their shared variables, sorted by name.
    3. Both are optionally log1p-transformed.
    4. Each is averaged over cells, giving one value per variable.

    The first value returned by corr_func on the two profiles is shown in the
    title. If fname is non-empty, the figure is saved there.
    Returns the matplotlib Figure.
    """

    def restrict_cells(adata, filt, label):
        # Only filter (and log) when a subset spec was actually given
        if filt is None:
            return adata
        n_before = adata.n_obs
        adata = adata_utils.filter_adata(adata, filt_cells=filt)
        logging.info(f"Subsetted {label} from {n_before} to {adata.n_obs}")
        return adata

    x = restrict_cells(x, x_subset, "x")
    y = restrict_cells(y, y_subset, "y")

    # Make sure variables match
    shared_var_names = sorted(set(x.var_names).intersection(y.var_names))
    logging.info(f"Found {len(shared_var_names)} shared variables")
    x = x[:, shared_var_names]
    y = y[:, shared_var_names]

    x_vals, y_vals = x.X, y.X
    if logscale:
        x_vals, y_vals = np.log1p(x_vals), np.log1p(y_vals)

    # Densify, then average over cells to get one value per variable
    x_vals = utils.ensure_arr(x_vals).mean(axis=0)
    y_vals = utils.ensure_arr(y_vals).mean(axis=0)
    assert not np.any(np.isnan(x_vals))
    assert not np.any(np.isnan(y_vals))

    corr_r, _corr_p = corr_func(x_vals, y_vals)
    logging.info(f"Found pearson's correlation of {corr_r:.4f}")

    axis_suffix = " (log)" if logscale else ""
    fig, ax = plt.subplots(dpi=SAVEFIG_DPI, figsize=(7, 5))
    ax.scatter(x_vals, y_vals, alpha=0.4)
    ax.set(
        xlabel=xlabel + axis_suffix,
        ylabel=ylabel + axis_suffix,
        title=(title + f" ($r={corr_r:.2f}$)").strip(),
    )
    if fname:
        fig.savefig(fname, bbox_inches="tight")
    return fig
Example #6
0
def load_rna_files(rna_counts_fnames: List[str],
                   model_dir: str,
                   transpose: bool = True) -> ad.AnnData:
    """
    Load the RNA files in, filling in unmeasured genes as necessary.

    Steps:
    1. Read the genes the model knows from ``rna_genes.txt`` in model_dir.
    2. Read and outer-join the count files via utils.sc_read_multi_files.
    3. Drop ``MOUSE_``-prefixed genes and remove a leading ``HUMAN_`` from the
       rest.
    4. Lay the counts out in the model's gene order.

    Known genes absent from the input become all-zero columns, and input genes
    the model does not know are discarded. A gene name that matches more than
    one input column after prefix removal is skipped (left as zeros) with a
    warning.

    rna_counts_fnames: count files, forwarded to utils.sc_read_multi_files
    model_dir: directory containing rna_genes.txt
    transpose: forwarded to utils.sc_read_multi_files
    Returns an AnnData of shape (n_obs, n_learned_genes) whose X is a CSR
    matrix and whose var index is the learned gene list.
    """
    human_prefix = "HUMAN_"

    # Find the genes that the model understands
    rna_genes_list_fname = os.path.join(model_dir, "rna_genes.txt")
    assert os.path.isfile(
        rna_genes_list_fname
    ), f"Cannot find RNA genes file: {rna_genes_list_fname}"
    learned_rna_genes = utils.read_delimited_file(rna_genes_list_fname)
    assert isinstance(learned_rna_genes, list)
    assert utils.is_all_unique(
        learned_rna_genes), "Learned genes list contains duplicates"
    # Genes are unique (asserted above), so this dict gives the same column
    # as list.index() but in O(1) per lookup
    learned_gene_to_col = {g: i for i, g in enumerate(learned_rna_genes)}

    temp_ad = utils.sc_read_multi_files(
        rna_counts_fnames,
        feature_type="Gene Expression",
        transpose=transpose,
        join="outer",
    )
    logging.info(f"Read input RNA files for {temp_ad.shape}")
    temp_ad.X = utils.ensure_arr(temp_ad.X)

    # Drop mouse genes and remove the human prefix from the remaining ones
    temp_ad.var_names_make_unique()
    kept_var_names = [
        vname for vname in temp_ad.var_names if not vname.startswith("MOUSE_")
    ]
    if len(kept_var_names) != temp_ad.n_vars:
        temp_ad = temp_ad[:, kept_var_names]
    # Remove only a leading "HUMAN_". str.strip("HUMAN_") would instead strip
    # any of the characters H/U/M/A/N/_ from both ends (e.g. "NANOG" -> "OG").
    temp_ad.var = pd.DataFrame(index=[
        v[len(human_prefix):] if v.startswith(human_prefix) else v
        for v in kept_var_names
    ])

    # Expand adata to span all genes
    # Initiating as a sparse matrix doesn't allow vectorized building
    intersected_genes = set(temp_ad.var_names).intersection(learned_rna_genes)
    assert intersected_genes, "No overlap between learned and input genes!"
    expanded_mat = np.zeros((temp_ad.n_obs, len(learned_rna_genes)))
    skip_count = 0
    for gene in intersected_genes:
        dest_idx = learned_gene_to_col[gene]
        # get_loc returns an int only when the name is unique; after prefix
        # removal duplicates can reappear, yielding a slice/mask instead
        src_idx = temp_ad.var_names.get_loc(gene)
        if not isinstance(src_idx, int):
            logging.warning(f"Got multiple source matches for {gene}, skipping")
            skip_count += 1
            continue
        expanded_mat[:, dest_idx] = utils.ensure_arr(temp_ad.X[:, src_idx]).flatten()
    if skip_count:
        logging.warning(
            f"Skipped {skip_count}/{len(intersected_genes)} genes due to multiple matches"
        )
    expanded_mat = sparse.csr_matrix(expanded_mat)  # Compress
    retval = ad.AnnData(expanded_mat,
                        obs=temp_ad.obs,
                        var=pd.DataFrame(index=learned_rna_genes))
    return retval
Example #7
0
def do_evaluation_atac_from_atac(
    spliced_net,
    sc_dual_full_dataset,
    gene_names: str,
    atac_names: str,
    outdir: str,
    ext: str,
    marker_genes: List[str],
    prefix: str = "",
):
    """
    Reconstruct ATAC from ATAC with the model, save it, and optionally plot AUROC.

    Always writes ``{prefix}_atac_atac_adata.h5ad`` into outdir. The leading
    underscore is stripped when prefix is empty. If dataset_y has a
    ``data_raw`` attribute and ext is not None, it also writes an AUROC plot to
    ``{prefix}_atac_atac_auroc.{ext}``. The prediction objects are deleted at
    the end to drop these references early.

    gene_names and marker_genes are accepted but not used by this function.
    """
    ### ATAC > ATAC
    logging.info("Inferring ATAC from ATAC")
    sc_atac_full_preds = spliced_net.translate_2_to_2(sc_dual_full_dataset)
    atac_dataset = sc_dual_full_dataset.dataset_y
    # Deep-copy obs so the predictions' metadata is independent of the input's
    preds_obs = atac_dataset.data_raw.obs.copy(deep=True)
    sc_atac_full_preds_anndata = sc.AnnData(sc_atac_full_preds, obs=preds_obs)
    sc_atac_full_preds_anndata.var_names = atac_names
    logging.info("Writing ATAC from ATAC")

    # Infer marker bins (currently disabled)
    # logging.info("Getting marker bins for ATAC from ATAC")
    # plot_utils.preprocess_anndata(sc_atac_full_preds_anndata)
    # adata_utils.find_marker_genes(sc_atac_full_preds_anndata)
    # inferred_marker_bins = adata_utils.flatten_marker_genes(
    #     sc_atac_full_preds_anndata.uns["rank_genes_leiden"]
    # )
    # logging.info(f"Found {len(inferred_marker_bins)} marker bins for ATAC from ATAC")
    # with open(
    #     os.path.join(outdir, f"{prefix}_atac_atac_marker_bins.txt".strip("_")), "w"
    # ) as sink:
    #     sink.write("\n".join(inferred_marker_bins) + "\n")

    h5ad_path = os.path.join(outdir, f"{prefix}_atac_atac_adata.h5ad".strip("_"))
    sc_atac_full_preds_anndata.write(h5ad_path)

    can_plot = hasattr(atac_dataset, "data_raw") and ext is not None
    if can_plot:
        logging.info("Plotting ATAC from ATAC")
        measured = utils.ensure_arr(atac_dataset.data_raw.X).flatten()
        predicted = utils.ensure_arr(sc_atac_full_preds).flatten()
        plot_utils.plot_auroc(
            measured,
            predicted,
            title_prefix=f"{DATASET_NAME} ATAC > ATAC".strip(),
            fname=os.path.join(outdir, f"{prefix}_atac_atac_auroc.{ext}".strip("_")),
        )
        # AUPRC plotting (currently disabled)
        # plot_utils.plot_auprc(
        #     utils.ensure_arr(sc_dual_full_dataset.dataset_y.data_raw.X).flatten(),
        #     utils.ensure_arr(sc_atac_full_preds).flatten(),
        #     title_prefix=f"{DATASET_NAME} ATAC > ATAC".strip(),
        #     fname=os.path.join(outdir, f"{prefix}_atac_atac_auprc.{ext}".strip("_")),
        # )

    # Remove some objects to free memory
    del sc_atac_full_preds
    del sc_atac_full_preds_anndata
Example #8
0
def main():
    """
    Run script.

    Reads an ``.h5ad`` or 10x ``.h5`` input and maps its gene names through
    utils.read_gtf_gene_symbol_to_id(). It writes the matrix as a tab-separated
    table with features as rows and observations as columns.
    """
    args = build_parser().parse_args()

    readers = {".h5ad": ad.read_h5ad, ".h5": sc.read_10x_h5}
    input_ext = os.path.splitext(args.input)[1]
    if input_ext not in readers:
        raise ValueError(f"Unrecognized file extension: {args.input}")
    x = readers[input_ext](args.input)
    logging.info(f"Read input: {x}")

    logging.info("Reading gtf for gene name map")
    gene_name_map = utils.read_gtf_gene_symbol_to_id()

    # Transpose because BIRD wants features x obs
    obs_by_feature = pd.DataFrame(
        utils.ensure_arr(x.X), index=x.obs_names, columns=x.var_names
    )
    x_df = obs_by_feature.T
    assert np.all(x_df.values >= 0.0)
    # Raises KeyError for any gene missing from the GTF map
    x_df.index = [gene_name_map[gene] for gene in x_df.index]

    # Write output (tab-separated table)
    logging.info(f"Writing output to {args.output_table_txt}")
    x_df.to_csv(args.output_table_txt, sep="\t")
Example #9
0
def reindex_adata_vars(adata: AnnData, target_vars: List[str]) -> AnnData:
    """
    Reindex the adata to match the given var list, verbatim.

    For each name in target_vars, the output column is copied from the
    same-named column of adata. If adata has no such column, the output column
    is all zeros. Columns of adata not listed in target_vars are dropped. An
    empty target_vars yields an (n_obs, 0) result.

    Note: if adata's var names are not unique, adata itself is de-duplicated
    in place (var_names_make_unique) before reindexing.

    adata: source AnnData
    target_vars: unique var names defining the output columns and their order
    Returns a new AnnData of shape (adata.n_obs, len(target_vars)) carrying
    adata's obs names and target_vars as var names. Other obs/var annotations
    are not carried over.
    """
    assert len(adata.var_names) == adata.n_vars
    if not utils.is_all_unique(adata.var_names):
        logging.warning("De-duping variable names before reindexing")
        adata.var_names_make_unique()
    assert utils.is_all_unique(target_vars), "Target vars are not all unique"
    intersected = set(adata.var_names).intersection(target_vars)
    logging.info(
        f"Overlap of {len(intersected)}/{adata.n_vars}, 0 vector will be filled in for {len(target_vars) - len(intersected)} 'missing' features"
    )
    # Map each var name to its (dense) column of values across observations
    vars_to_cols = dict(zip(adata.var_names, utils.ensure_arr(adata.X).T))
    assert (len(vars_to_cols) == adata.n_vars
            ), f"Size mismatch: {len(vars_to_cols)} {adata.n_vars}"

    target_shape = (adata.n_obs, len(target_vars))
    # len() rather than truthiness so a pandas Index is also accepted
    if len(target_vars) > 0:
        default_null = np.zeros(adata.n_obs)
        # column_stack copies its inputs, so sharing default_null is safe
        mat = np.column_stack(
            [vars_to_cols.get(v, default_null) for v in target_vars])
    else:
        # np.vstack/column_stack reject an empty list
        mat = np.zeros(target_shape)
    assert mat.shape == target_shape, f"Size mismatch: {mat.shape} {target_shape}"

    retval = AnnData(mat)
    retval.obs_names = adata.obs_names
    retval.var_names = target_vars
    return retval
Example #10
0
def main():
    """
    Run script.

    Reads the h5ad given as args.h5ad and, when args.discrete is set, logs the
    discrete joint entropy of its matrix in nats (base e), using pyitlib. Any
    other mode raises NotImplementedError.
    """
    args = build_parser().parse_args()

    adata = ad.read_h5ad(args.h5ad)
    logging.info(f"Read {args.h5ad} for adata of {adata.shape}")

    if not args.discrete:
        raise NotImplementedError

    # Use the discrete algorithm from pyitlib
    # https://pafoster.github.io/pyitlib/#discrete_random_variable.entropy_joint
    # https://github.com/pafoster/pyitlib/blob/master/pyitlib/discrete_random_variable.py#L3535
    # pyitlib indexes successive realisations of a variable along the LAST axis
    # and distinct variables along preceding axes (variables x samples). That is
    # the transpose of the usual ML layout (samples x variables), hence .T here.
    variables_by_samples = utils.ensure_arr(adata.X).T
    joint_h = drv.entropy_joint(variables_by_samples, base=np.e)
    logging.info(f"Found discrete joint entropy of {joint_h:.6f}")
Example #11
0
def main():
    """
    Run script.

    Steps:
    1. Read the two RNA inputs (``.h5ad`` or 10x ``.h5``), densify X, and
       sanitize and de-duplicate obs names.
    2. Align both on their shared obs and var names, only when these do not
       already match exactly.
    3. Optionally restrict both to a gene list.
    4. Draw the scatter via plot_utils.plot_scatter_with_r.
    """
    args = build_parser().parse_args()

    def read_rna(fname):
        """Read one RNA file, densify X, and sanitize + de-duplicate obs names."""
        if fname.endswith(".h5ad"):
            adata = ad.read_h5ad(fname)
        elif fname.endswith(".h5"):
            adata = sc.read_10x_h5(fname, gex_only=False)
        else:
            raise ValueError(f"Unrecognized file extension: {fname}")
        adata.X = utils.ensure_arr(adata.X)
        adata.obs_names = sanitize_obs_names(adata.obs_names)
        adata.obs_names_make_unique()
        logging.info(f"Read in {fname} for {adata.shape}")
        return adata

    x_rna = read_rna(args.x_rna)
    y_rna = read_rna(args.y_rna)

    # Length check first: comparing Indexes of different lengths would raise
    obs_match = len(x_rna.obs_names) == len(y_rna.obs_names) and np.all(
        x_rna.obs_names == y_rna.obs_names)
    if not obs_match:
        logging.warning("Rematching obs axis")
        shared_obs_names = sorted(set(x_rna.obs_names) & set(y_rna.obs_names))
        logging.info(f"Found {len(shared_obs_names)} shared obs")
        assert shared_obs_names, "\n".join([
            "Got empty list of shared obs",
            str(x_rna.obs_names),
            str(y_rna.obs_names),
        ])
        x_rna, y_rna = x_rna[shared_obs_names], y_rna[shared_obs_names]
    assert np.all(x_rna.obs_names == y_rna.obs_names)

    var_match = len(x_rna.var_names) == len(y_rna.var_names) and np.all(
        x_rna.var_names == y_rna.var_names)
    if not var_match:
        logging.warning("Rematching variable axis")
        shared_var_names = sorted(set(x_rna.var_names) & set(y_rna.var_names))
        logging.info(f"Found {len(shared_var_names)} shared variables")
        assert shared_var_names, "\n".join([
            "Got empty list of shared vars",
            str(x_rna.var_names),
            str(y_rna.var_names),
        ])
        x_rna, y_rna = x_rna[:, shared_var_names], y_rna[:, shared_var_names]
    assert np.all(x_rna.var_names == y_rna.var_names)

    # Subset by gene list if given
    if args.genelist:
        gene_list = utils.read_delimited_file(args.genelist)
        logging.info(f"Read {len(gene_list)} genes from {args.genelist}")
        x_rna, y_rna = x_rna[:, gene_list], y_rna[:, gene_list]

    assert x_rna.shape == y_rna.shape, f"Mismatched shapes {x_rna.shape} {y_rna.shape}"

    plot_utils.plot_scatter_with_r(
        x_rna.X,
        y_rna.X,
        subset=args.subset,
        one_to_one=True,
        logscale=not args.linear,
        density_heatmap=args.density,
        density_logstretch=args.densitylogstretch,
        fname=args.outfname,
        title=args.title,
        xlabel=args.xlabel,
        ylabel=args.ylabel,
        figsize=args.figsize,
    )
Example #12
0
def plot_scatter_with_r(
        x: Union[np.ndarray, scipy.sparse.csr_matrix],
        y: Union[np.ndarray, scipy.sparse.csr_matrix],
        color=None,
        subset: int = 0,
        logscale: bool = False,
        density_heatmap: bool = False,
        density_dpi: int = 150,
        density_logstretch: int = 1000,
        title: str = "",
        xlabel: str = "Original norm counts",
        ylabel: str = "Inferred norm counts",
        xlim: Tuple[int, int] = None,
        ylim: Tuple[int, int] = None,
        one_to_one: bool = False,
        corr_func: Callable = scipy.stats.pearsonr,
        figsize: Tuple[float, float] = (7, 5),
        fname: str = "",
        ax=None,
):
    """
    Plot the given x y coordinates, appending the correlation r to the title.

    Each element of x is paired with the element of y at the same position.
    Setting xlim/ylim affects both the plot and the correlation calculation.
    In other words, the plot view mirrors the range over which correlation is
    calculated.

    x, y: same-shape dense arrays or sparse matrices
    color: optional array with one entry per element of x. It is filtered and
        subsampled together with x and y.
    subset: if 0 < subset < number of elements, plot a fixed-seed (1234)
        random sample of that many points
    logscale: apply log1p to x and y before correlating and plotting
    density_heatmap: draw with ax.scatter_density (mpl-scatter-density)
        instead of ax.scatter
    xlim, ylim: inclusive (low, high) ranges. Points outside them are dropped
        and the axis limits are set. With one_to_one they must be equal.
    corr_func: correlation function returning (statistic, p-value). Its
        statistic is shown in the title (default: Pearson).
    fname: if non-empty and a new figure was created, save it there
    ax: existing axes to draw into; if given, no figure is created
    Returns the new Figure, or None when ax was supplied.
    """
    assert x.shape == y.shape, f"Mismatched shapes: {x.shape} {y.shape}"
    # Count every position: for sparse matrices .size is only the stored-entry count
    n_points = int(np.prod(x.shape))
    if color is not None:
        assert color.size == n_points
        # Lay color out like x so it can be masked/indexed in lockstep with x and y
        color = np.asarray(color).reshape(x.shape)
    if one_to_one and (xlim is not None or ylim is not None):
        assert xlim == ylim

    # Densify before masking: plain numpy bool arrays have no .multiply, and
    # each range mask must be built from its own axis (ylim from y, not x)
    if xlim:
        x_arr = np.asarray(utils.ensure_arr(x))
        keep_idx = (x_arr >= xlim[0]) & (x_arr <= xlim[1])
        x = x_arr[keep_idx]
        y = np.asarray(utils.ensure_arr(y))[keep_idx]
        if color is not None:
            color = color[keep_idx]
    if ylim:
        y_arr = np.asarray(utils.ensure_arr(y))
        keep_idx = (y_arr >= ylim[0]) & (y_arr <= ylim[1])
        x = np.asarray(utils.ensure_arr(x))[keep_idx]
        y = y_arr[keep_idx]
        if color is not None:
            color = color[keep_idx]
    # x and y may or may not be sparse at this point
    assert x.shape == y.shape
    n_points = int(np.prod(x.shape))
    if 0 < subset < n_points:
        logging.info(f"Subsetting to {subset} points")
        # A private generator seeded like random.seed(1234) draws the same
        # sample without resetting the caller's global `random` state
        rng = random.Random(1234)
        # Converts flat index to coordinates
        indices = np.unravel_index(
            np.array(rng.sample(range(n_points), k=subset)), shape=x.shape)
        x = utils.ensure_arr(x[indices])
        y = utils.ensure_arr(y[indices])
        if color is not None:
            color = color[indices]

    if logscale:
        x = np.log1p(x)
        y = np.log1p(y)

    # Ensure correct format
    x = utils.ensure_arr(x).flatten()
    y = utils.ensure_arr(y).flatten()
    assert not np.any(np.isnan(x))
    assert not np.any(np.isnan(y))
    if color is not None:
        # One color per flattened point, in the same order as x and y
        color = np.asarray(color).flatten()

    # Use the caller-supplied correlation (previously ignored in favour of pearsonr)
    corr_r, corr_p = corr_func(x, y)
    logging.info(
        f"Found pearson's correlation/p of {corr_r:.4f}/{corr_p:.4g}")
    spearman_corr, spearman_p = scipy.stats.spearmanr(x, y)
    logging.info(
        f"Found spearman's correlation/p of {spearman_corr:.4f}/{spearman_p:.4g}"
    )

    if ax is None:
        fig = plt.figure(dpi=300, figsize=figsize)
        if density_heatmap:
            # https://github.com/astrofrog/mpl-scatter-density
            ax = fig.add_subplot(1, 1, 1, projection="scatter_density")
        else:
            ax = fig.add_subplot(1, 1, 1)
    else:
        fig = None

    if density_heatmap:
        norm = None
        if density_logstretch:
            norm = ImageNormalize(vmin=0,
                                  vmax=100,
                                  stretch=LogStretch(a=density_logstretch))
        ax.scatter_density(x, y, dpi=density_dpi, norm=norm, color="tab:blue")
    else:
        ax.scatter(x, y, alpha=0.2, c=color)

    if one_to_one:
        unit = np.linspace(*ax.get_xlim())
        ax.plot(unit,
                unit,
                linestyle="--",
                alpha=0.5,
                label="$y=x$",
                color="grey")
        ax.legend()
    ax.set(
        xlabel=xlabel + (" (log)" if logscale else ""),
        ylabel=ylabel + (" (log)" if logscale else ""),
        title=(title + f" ($r={corr_r:.2f}$)").strip(),
    )
    if xlim:
        ax.set(xlim=xlim)
    if ylim:
        ax.set(ylim=ylim)

    if fig is not None and fname:
        fig.savefig(fname, dpi=SAVEFIG_DPI, bbox_inches="tight")

    return fig
Example #13
0
def archr_gene_activity_matrix_from_adata(
    adata: ad.AnnData,
    annotation: str = sc_data_loaders.HG38_GTF,
    gene_model: Callable = lambda x: np.exp(-np.abs(x) / 5000) + np.exp(-1),
    extension: int = 100000,
    gene_upstream: int = 5000,
    use_gene_boundaries: bool = True,
    use_TSS: bool = False,
    gene_scale_factor: float = 5.0,  # Numeric scaling factor to weight genes based on inverse of length
    ceiling: float = 4.0,
    scale_to: float = 10000.0,
) -> ad.AnnData:
    """
    Use the more sophisticated gene activity scoring method described here:
    https://www.archrproject.com/bookdown/calculating-gene-scores-in-archr.html
    https://github.com/GreenleafLab/ArchR/blob/ddcaae4a6093685875052219141e5ea41030fc55/R/MatrixGeneScores.R

    Note this isn't a FULL reimplementation since we do slightly different handling of distances
    and tiling and such, but this should be a very close approximation

    Some other notes:
    - By default, ArchR appears to be doing distance calculations based on the entire gene body, which we do

    adata: ATAC AnnData whose var names look like "chrom:start-stop"; a
        leading "chr" on chrom is ignored. adata itself is not modified.
    annotation: GTF passed to utils.read_gtf_gene_to_pos
    gene_model: maps the gene-to-bin value from GenomicInterval.difference
        (presumably a distance in bp) to a weight
    extension: bins within this many bp of either end of a gene are considered.
        Bins that also overlap a different gene are ignored for that gene.
    gene_upstream: passed as extend_upstream to utils.read_gtf_gene_to_pos
    use_gene_boundaries, use_TSS: accepted for API compatibility; not used here
    gene_scale_factor: weights genes by 1 + (1/width) * (factor - 1) / (range of 1/width)
    ceiling: if > 0, each bin's counts are capped at
        ceiling * max(bin width / 500, 1)
    scale_to: if > 0, each cell's scores are rescaled to sum to this (all-zero
        cells stay zero) and returned as a CSR matrix; otherwise returned dense
    Returns an AnnData of shape (n_obs, n_genes) with var names = genes.
    """
    gene_to_pos = utils.read_gtf_gene_to_pos(annotation, extend_upstream=gene_upstream)
    genes = list(gene_to_pos.keys())
    # Map each gene and bin to a corresponding index in axis
    gene_to_idx = {g: i for i, g in enumerate(genes)}
    bin_to_idx = {b: i for i, b in enumerate(adata.var_names)}

    # Map of chrom without chr prefix to intervaltree
    chrom_to_gene_intervals = utils.gene_pos_dict_to_range(gene_to_pos)
    # Create a mapping of where atac bins are, so we can easily grep for overlap later
    chrom_to_atac_intervals = collections.defaultdict(itree.IntervalTree)
    for atac_bin in adata.var_names:
        chrom, span = atac_bin.split(":")
        # Remove only a leading "chr"; str.strip("chr") would also remove any
        # trailing 'c'/'h'/'r' characters from the chromosome name
        if chrom.startswith("chr"):
            chrom = chrom[len("chr"):]
        start, stop = map(int, span.split("-"))
        chrom_to_atac_intervals[chrom][start:stop] = atac_bin

    # Create a matrix that maps each feature to overlapping genes with weight
    weights = scipy.sparse.lil_matrix((adata.shape[1], len(genes)), dtype=float)

    # Create per-gene weights based on size of the gene
    logging.info("Determining gene weights based on gene size")
    gene_widths = np.array([gene_to_pos[g][2] - gene_to_pos[g][1] for g in genes])
    assert np.all(gene_widths > 0)
    inv_gene_widths = 1 / gene_widths
    gene_weight = 1 + inv_gene_widths * (gene_scale_factor - 1) / (
        np.max(inv_gene_widths) - np.min(inv_gene_widths)
    )
    assert np.all(gene_weight >= 1)
    if not np.all(gene_weight <= gene_scale_factor):
        logging.warning(
            f"Found values exceeding gene scale factor {gene_scale_factor}: {gene_weight[np.where(gene_weight > gene_scale_factor)]}"
        )
    # Only reachable when asserts are disabled (python -O)
    if not np.all(gene_weight >= 1.0):
        logging.warning(
            f"Found values below minimum expected value of 1: {gene_weight[np.where(gene_weight < 1.0)]}"
        )

    logging.info("Constructing bin to gene matrix")
    for gene in genes:  # Compute weight for each gene
        gene_idx = gene_to_idx[gene]
        gene_gi = genomic_interval.GenomicInterval(
            gene_to_pos[gene], metadata_dict={"gene": gene}
        )
        chrom, start, stop = gene_to_pos[gene]
        assert start < stop
        # Get all ATAC bins within `extension` bp of the gene
        gene_overlap_atac_bins = chrom_to_atac_intervals[chrom][
            start - extension : stop + extension
        ]
        # Drop the ATAC bins that overlap a gene that isn't this current gene
        filtered_gene_overlap_atac_bins = [
            o
            for o in gene_overlap_atac_bins
            if not any(
                g.data != gene for g in chrom_to_gene_intervals[chrom][o.begin:o.end]
            )
        ]
        # Calculate the distance and the corresponding weight
        for o in filtered_gene_overlap_atac_bins:
            bin_gi = genomic_interval.GenomicInterval((chrom, o.begin, o.end))
            d = gene_gi.difference(bin_gi)
            assert d >= 0
            w = gene_model(d)
            bin_idx = bin_to_idx[o.data]
            assert weights[bin_idx, gene_idx] == 0
            # Note, ArchR works on 500bp bins, so may count a little differently
            # To adjust multiply by the size of the bin divided by 500
            # This approximates how many times each bin might be counted
            # as fragments
            v = w * gene_weight[gene_idx] * max(bin_gi.size / 500, 1)
            weights[bin_idx, gene_idx] = v
    weights = scipy.sparse.csc_matrix(weights)

    # Work on a local reference so the caller's adata.X is never overwritten
    counts = adata.X
    if ceiling > 0:
        logging.info(
            f"Calculating maximum capped counts per bin at {ceiling} per 500bp"
        )
        bin_to_width = lambda x: genomic_interval.GenomicInterval(x).size
        per_bin_cap = np.array(
            [max(bin_to_width(b) / 500, 1) * ceiling for b in adata.var_names]
        )
        assert np.all(per_bin_cap >= ceiling)
        counts = np.minimum(utils.ensure_arr(counts), per_bin_cap)

    # Map the ATAC bins to features
    # Converting to an array is necessary for correct broadcasting
    logging.info("Calculating gene activity scores")
    mat = utils.ensure_arr(counts @ weights)

    # Normalize depths
    if scale_to > 0:
        logging.info(f"Depth normalizing gene activity scores to {scale_to}")
        per_cell_depths = np.array(mat.sum(axis=1)).flatten()
        # Cells with zero depth get a scaling of 0 instead of dividing by zero
        per_cell_scaling = np.zeros_like(per_cell_depths, dtype=float)
        np.divide(
            scale_to, per_cell_depths, out=per_cell_scaling, where=per_cell_depths != 0
        )
        mat = scipy.sparse.csr_matrix(mat * per_cell_scaling[:, np.newaxis])
        assert np.all(
            np.logical_or(
                np.isclose(scale_to, mat.sum(axis=1)), np.isclose(0, mat.sum(axis=1))
            )  # Either is the correct size, or is all 0
        )

    retval = ad.AnnData(mat)
    if hasattr(adata, "obs_names"):
        retval.obs_names = adata.obs_names
    retval.var_names = genes
    return retval