Exemplo n.º 1
0
def do_evaluation_rna_from_atac(
    spliced_net,
    sc_dual_full_dataset,
    gene_names: str,
    atac_names: str,
    outdir: str,
    ext: str,
    marker_genes: List[str],
    prefix: str = "",
):
    """
    Translate ATAC into RNA on the full paired dataset, save it, and plot it.

    The predictions from ``spliced_net.translate_2_to_1`` are wrapped in an
    AnnData whose obs are a deep copy of ``dataset_y.data_raw.obs`` and whose
    var_names are ``gene_names``. ``.raw`` is filled with a copy of the object
    because Seurat expects it. The result is written to
    ``<outdir>/<prefix>_atac_rna_adata.h5ad``. Underscores are stripped from
    the ends of the file name, so an empty prefix yields
    ``atac_rna_adata.h5ad``.

    A log-scale density scatter plot of predictions against
    ``dataset_x.size_norm_counts.X`` is drawn only when ``dataset_x`` has a
    ``size_norm_counts`` attribute AND ``ext`` is not None.

    ``atac_names`` and ``marker_genes`` are accepted but unused here.
    """
    ### ATAC > RNA
    logging.info("Inferring RNA from ATAC")
    atac_to_rna = spliced_net.translate_2_to_1(sc_dual_full_dataset)

    # Seurat expects everything to be sparse
    # https://github.com/satijalab/seurat/issues/2228
    atac_obs = sc_dual_full_dataset.dataset_y.data_raw.obs.copy(deep=True)
    atac_to_rna_adata = sc.AnnData(atac_to_rna, obs=atac_obs)
    atac_to_rna_adata.var_names = gene_names
    logging.info("Writing RNA from ATAC")

    # Seurat also expects the raw attribute to be populated
    atac_to_rna_adata.raw = atac_to_rna_adata.copy()
    h5ad_name = f"{prefix}_atac_rna_adata.h5ad".strip("_")
    atac_to_rna_adata.write(os.path.join(outdir, h5ad_name))

    # Only plot when a ground-truth RNA matrix exists and an image format was given
    truth_rna = sc_dual_full_dataset.dataset_x
    if hasattr(truth_rna, "size_norm_counts") and ext is not None:
        logging.info("Plotting RNA from ATAC")
        plot_name = f"{prefix}_atac_rna_log.{ext}".strip("_")
        plot_utils.plot_scatter_with_r(
            truth_rna.size_norm_counts.X,
            atac_to_rna,
            one_to_one=True,
            logscale=True,
            density_heatmap=True,
            title=f"{DATASET_NAME} ATAC > RNA".strip(),
            fname=os.path.join(outdir, plot_name),
        )

    # Release the prediction objects explicitly before returning
    del atac_to_rna, atac_to_rna_adata
Exemplo n.º 2
0
def do_evaluation_rna_from_rna(
    spliced_net,
    sc_dual_full_dataset,
    gene_names: str,
    atac_names: str,
    outdir: str,
    ext: str,
    marker_genes: List[str],
    prefix: str = "",
):
    """
    Evaluate the given network's RNA > RNA reconstruction on the full dataset.

    The predictions from ``spliced_net.translate_1_to_1`` are wrapped in an
    AnnData whose obs are a deep copy of ``dataset_x.data_raw.obs`` and whose
    var_names are ``gene_names``. The result is written to
    ``<outdir>/<prefix>_rna_rna_adata.h5ad``. Underscores are stripped from
    the ends of the file name, so an empty prefix yields
    ``rna_rna_adata.h5ad``.

    A log-scale density scatter plot of predictions against
    ``dataset_x.size_norm_counts.X`` is drawn only when ``dataset_x`` has a
    ``size_norm_counts`` attribute AND ``ext`` is not None.

    ``atac_names`` and ``marker_genes`` are accepted but unused here.
    """
    # Do inference and plotting
    ### RNA > RNA
    logging.info("Inferring RNA from RNA...")
    sc_rna_full_preds = spliced_net.translate_1_to_1(sc_dual_full_dataset)
    # Deep-copy obs, as do_evaluation_rna_from_atac does, so the new AnnData
    # does not share its obs DataFrame with the dataset. AnnData can modify
    # obs in place (e.g. converting string columns to categoricals on write),
    # which would otherwise leak back into sc_dual_full_dataset.
    sc_rna_full_preds_anndata = sc.AnnData(
        sc_rna_full_preds,
        obs=sc_dual_full_dataset.dataset_x.data_raw.obs.copy(deep=True),
    )
    sc_rna_full_preds_anndata.var_names = gene_names

    logging.info("Writing RNA from RNA")
    sc_rna_full_preds_anndata.write(
        os.path.join(outdir, f"{prefix}_rna_rna_adata.h5ad".strip("_"))
    )
    # Only plot when a ground-truth RNA matrix exists and an image format was given
    if hasattr(sc_dual_full_dataset.dataset_x, "size_norm_counts") and ext is not None:
        logging.info("Plotting RNA from RNA")
        plot_utils.plot_scatter_with_r(
            sc_dual_full_dataset.dataset_x.size_norm_counts.X,
            sc_rna_full_preds,
            one_to_one=True,
            logscale=True,
            density_heatmap=True,
            title=f"{DATASET_NAME} RNA > RNA".strip(),
            fname=os.path.join(outdir, f"{prefix}_rna_rna_log.{ext}".strip("_")),
        )
Exemplo n.º 3
0
def _read_and_tidy_adata(fname: str):
    """Read one expression file and tidy its obs axis for comparison.

    ``.h5ad`` files are read with ``ad.read_h5ad``. ``.h5`` files are read with
    ``sc.read_10x_h5`` keeping all feature types (``gex_only=False``). ``X`` is
    passed through ``utils.ensure_arr`` (presumably converting it to a dense
    array — confirm). Obs names are sanitized with ``sanitize_obs_names`` and
    then made unique.

    Raises:
        ValueError: if ``fname`` ends in neither ``.h5ad`` nor ``.h5``.
    """
    if fname.endswith(".h5ad"):
        adata = ad.read_h5ad(fname)
    elif fname.endswith(".h5"):
        adata = sc.read_10x_h5(fname, gex_only=False)
    else:
        raise ValueError(f"Unrecognized file extension: {fname}")
    adata.X = utils.ensure_arr(adata.X)
    adata.obs_names = sanitize_obs_names(adata.obs_names)
    adata.obs_names_make_unique()
    logging.info(f"Read in {fname} for {adata.shape}")
    return adata


def main():
    """Scatter-plot two expression matrices against each other.

    Reads ``args.x_rna`` and ``args.y_rna``. If their obs (then var) names
    differ in length or order, both are restricted to the sorted intersection.
    If ``args.genelist`` is given, both are subset to those genes. Finally,
    ``x`` is plotted against ``y`` with ``plot_utils.plot_scatter_with_r``.
    """
    parser = build_parser()
    args = parser.parse_args()

    x_rna = _read_and_tidy_adata(args.x_rna)
    y_rna = _read_and_tidy_adata(args.y_rna)

    # Align observations (cells) when the two files disagree on them or their order
    if not (len(x_rna.obs_names) == len(y_rna.obs_names)
            and np.all(x_rna.obs_names == y_rna.obs_names)):
        logging.warning("Rematching obs axis")
        shared_obs_names = sorted(
            set(x_rna.obs_names).intersection(y_rna.obs_names))
        logging.info(f"Found {len(shared_obs_names)} shared obs")
        assert shared_obs_names, ("Got empty list of shared obs" + "\n" +
                                  str(x_rna.obs_names) + "\n" +
                                  str(y_rna.obs_names))
        x_rna = x_rna[shared_obs_names]
        y_rna = y_rna[shared_obs_names]
    assert np.all(x_rna.obs_names == y_rna.obs_names)

    # Align variables (features) the same way
    if not (len(x_rna.var_names) == len(y_rna.var_names)
            and np.all(x_rna.var_names == y_rna.var_names)):
        logging.warning("Rematching variable axis")
        shared_var_names = sorted(
            set(x_rna.var_names).intersection(y_rna.var_names))
        logging.info(f"Found {len(shared_var_names)} shared variables")
        assert shared_var_names, ("Got empty list of shared vars" + "\n" +
                                  str(x_rna.var_names) + "\n" +
                                  str(y_rna.var_names))
        x_rna = x_rna[:, shared_var_names]
        y_rna = y_rna[:, shared_var_names]
    assert np.all(x_rna.var_names == y_rna.var_names)

    # Subset by gene list if given
    if args.genelist:
        gene_list = utils.read_delimited_file(args.genelist)
        logging.info(f"Read {len(gene_list)} genes from {args.genelist}")
        x_rna = x_rna[:, gene_list]
        y_rna = y_rna[:, gene_list]

    assert x_rna.shape == y_rna.shape, f"Mismatched shapes {x_rna.shape} {y_rna.shape}"

    plot_utils.plot_scatter_with_r(
        x_rna.X,
        y_rna.X,
        subset=args.subset,
        one_to_one=True,
        logscale=not args.linear,
        density_heatmap=args.density,
        density_logstretch=args.densitylogstretch,
        fname=args.outfname,
        title=args.title,
        xlabel=args.xlabel,
        ylabel=args.ylabel,
        figsize=args.figsize,
    )
Exemplo n.º 4
0
def main():
    """Run the script

    Train spliced RNA/ATAC autoencoder(s) and evaluate them on the test split.

    1. Parse CLI args. Attach a ``{outdir}_training.log`` file handler to the
       root logger.
    2. Build the RNA dataset and its train/valid/test splits. The source is
       SNARE-seq defaults, SHARE-seq tissues, or the 10x files in
       ``args.data``.
    3. Build the ATAC dataset reusing the RNA dataset's split
       (``predefined_split``), then pair the two modalities.
    4. For every combination of hidden dim, loss weight, learning rate, batch
       size and seed:
       - write the gene/bin lists and the dataset splits to an output directory;
       - fit the network (or warm-start it from ``args.pretrain``);
       - write test-set predictions and plots for RNA>RNA, ATAC>ATAC, ATAC>RNA
         and RNA>ATAC.
    """
    parser = build_parser()
    args = parser.parse_args()
    args.outdir = os.path.abspath(args.outdir)

    # The log file sits next to (not inside) outdir, so ensure the parent exists
    if not os.path.isdir(os.path.dirname(args.outdir)):
        os.makedirs(os.path.dirname(args.outdir))

    # Specify output log file
    logger = logging.getLogger()
    fh = logging.FileHandler(f"{args.outdir}_training.log", "w")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Log parameters and pytorch version
    if torch.cuda.is_available():
        logging.info(f"PyTorch CUDA version: {torch.version.cuda}")
    for arg in vars(args):
        logging.info(f"Parameter {arg}: {getattr(args, arg)}")

    # Borrow parameters
    logging.info("Reading RNA data")
    if args.snareseq:
        rna_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_RNA_DATA_KWARGS)
    elif args.shareseq:
        logging.info(f"Loading in SHAREseq RNA data for: {args.shareseq}")
        rna_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_RNA_DATA_KWARGS)
        # Data is supplied directly via raw_adata, so disable file-based loading
        rna_data_kwargs["fname"] = None
        rna_data_kwargs["reader"] = None
        rna_data_kwargs["cell_info"] = None
        rna_data_kwargs["gene_info"] = None
        rna_data_kwargs["transpose"] = False
        # Load in the datasets
        shareseq_rna_adatas = []
        for tissuetype in args.shareseq:
            shareseq_rna_adatas.append(
                adata_utils.load_shareseq_data(
                    tissuetype,
                    dirname="/data/wukevin/commonspace_data/GSE140203_SHAREseq",
                    mode="RNA",
                ))
        shareseq_rna_adata = shareseq_rna_adatas[0]
        if len(shareseq_rna_adatas) > 1:
            shareseq_rna_adata = shareseq_rna_adata.concatenate(
                *shareseq_rna_adatas[1:],
                join="inner",
                batch_key="tissue",
                batch_categories=args.shareseq,
            )
        rna_data_kwargs["raw_adata"] = shareseq_rna_adata
    else:
        rna_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_RNA_DATA_KWARGS)
        rna_data_kwargs["fname"] = args.data
        if args.nofilter:
            # Drop every filtering option (keys prefixed "filt_")
            rna_data_kwargs = {
                k: v
                for k, v in rna_data_kwargs.items()
                if not k.startswith("filt_")
            }
    rna_data_kwargs["data_split_by_cluster_log"] = not args.linear
    rna_data_kwargs["data_split_by_cluster"] = args.clustermethod

    sc_rna_dataset = sc_data_loaders.SingleCellDataset(
        valid_cluster_id=args.validcluster,
        test_cluster_id=args.testcluster,
        **rna_data_kwargs,
    )

    sc_rna_train_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_rna_dataset,
        split="train",
    )
    sc_rna_valid_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_rna_dataset,
        split="valid",
    )
    sc_rna_test_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_rna_dataset,
        split="test",
    )

    # ATAC
    logging.info("Aggregating ATAC clusters")
    if args.snareseq:
        atac_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_ATAC_DATA_KWARGS)
    elif args.shareseq:
        logging.info(f"Loading in SHAREseq ATAC data for {args.shareseq}")
        atac_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_ATAC_DATA_KWARGS)
        # Data is supplied directly via raw_adata, so disable file-based loading
        atac_data_kwargs["reader"] = None
        atac_data_kwargs["fname"] = None
        atac_data_kwargs["cell_info"] = None
        atac_data_kwargs["gene_info"] = None
        atac_data_kwargs["transpose"] = False
        atac_adatas = []
        for tissuetype in args.shareseq:
            atac_adatas.append(
                adata_utils.load_shareseq_data(
                    tissuetype,
                    dirname="/data/wukevin/commonspace_data/GSE140203_SHAREseq",
                    mode="ATAC",
                ))
        if len(atac_adatas) > 1:
            # Fold harmonize_atac_intervals pairwise over all tissues, the same
            # way the 10x branch below does. It is only called with two
            # interval sets there, so this also works for 3+ tissues.
            atac_bins_harmonized = functools.reduce(
                sc_data_loaders.harmonize_atac_intervals,
                [a.var_names for a in atac_adatas],
            )
            atac_adatas = [
                sc_data_loaders.repool_atac_bins(a, atac_bins_harmonized)
                for a in atac_adatas
            ]
        shareseq_atac_adata = atac_adatas[0]
        if len(atac_adatas) > 1:
            shareseq_atac_adata = shareseq_atac_adata.concatenate(
                *atac_adatas[1:],
                join="inner",
                batch_key="tissue",
                batch_categories=args.shareseq,
            )
        atac_data_kwargs["raw_adata"] = shareseq_atac_adata
    else:
        atac_parsed = [
            utils.sc_read_10x_h5_ft_type(fname, "Peaks") for fname in args.data
        ]
        if len(atac_parsed) > 1:
            # Harmonize the peak intervals pairwise across all input files
            atac_bins = functools.reduce(
                sc_data_loaders.harmonize_atac_intervals,
                [a.var_names for a in atac_parsed],
            )
            logging.info(f"Aggregated {len(atac_bins)} bins")
        else:
            atac_bins = list(atac_parsed[0].var_names)

        atac_data_kwargs = copy.copy(
            sc_data_loaders.TENX_PBMC_ATAC_DATA_KWARGS)
        atac_data_kwargs["fname"] = rna_data_kwargs["fname"]
        atac_data_kwargs["pool_genomic_interval"] = 0  # Do not pool
        # Re-read each file's peaks and repool them onto the shared bins
        atac_data_kwargs["reader"] = functools.partial(
            utils.sc_read_multi_files,
            reader=lambda x: sc_data_loaders.repool_atac_bins(
                utils.sc_read_10x_h5_ft_type(x, "Peaks"),
                atac_bins,
            ),
        )
    atac_data_kwargs["cluster_res"] = 0  # Do not bother clustering ATAC data

    # Reuse the RNA split so both modalities share the same train/valid/test cells
    sc_atac_dataset = sc_data_loaders.SingleCellDataset(
        predefined_split=sc_rna_dataset, **atac_data_kwargs)
    sc_atac_train_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_atac_dataset,
        split="train",
    )
    sc_atac_valid_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_atac_dataset,
        split="valid",
    )
    sc_atac_test_dataset = sc_data_loaders.SingleCellDatasetSplit(
        sc_atac_dataset,
        split="test",
    )

    sc_dual_train_dataset = sc_data_loaders.PairedDataset(
        sc_rna_train_dataset,
        sc_atac_train_dataset,
        flat_mode=True,
    )
    sc_dual_valid_dataset = sc_data_loaders.PairedDataset(
        sc_rna_valid_dataset,
        sc_atac_valid_dataset,
        flat_mode=True,
    )
    sc_dual_test_dataset = sc_data_loaders.PairedDataset(
        sc_rna_test_dataset,
        sc_atac_test_dataset,
        flat_mode=True,
    )
    sc_dual_full_dataset = sc_data_loaders.PairedDataset(
        sc_rna_dataset,
        sc_atac_dataset,
        flat_mode=True,
    )

    # Model
    param_combos = list(
        itertools.product(args.hidden, args.lossweight, args.lr,
                          args.batchsize, args.seed))
    for h_dim, lw, lr, bs, rand_seed in param_combos:
        # Only suffix the output dir with hyperparameters when sweeping >1 combo
        outdir_name = (
            f"{args.outdir}_hidden_{h_dim}_lossweight_{lw}_lr_{lr}_batchsize_{bs}_seed_{rand_seed}"
            if len(param_combos) > 1 else args.outdir)
        if not os.path.isdir(outdir_name):
            assert not os.path.exists(outdir_name)
            os.makedirs(outdir_name)
        assert os.path.isdir(outdir_name)
        with open(os.path.join(outdir_name, "rna_genes.txt"), "w") as sink:
            for gene in sc_rna_dataset.data_raw.var_names:
                sink.write(gene + "\n")
        with open(os.path.join(outdir_name, "atac_bins.txt"), "w") as sink:
            for atac_bin in sc_atac_dataset.data_raw.var_names:
                sink.write(atac_bin + "\n")

        # Write dataset
        ### Full
        sc_rna_dataset.size_norm_counts.write_h5ad(
            os.path.join(outdir_name, "full_rna.h5ad"))
        sc_rna_dataset.size_norm_log_counts.write_h5ad(
            os.path.join(outdir_name, "full_rna_log.h5ad"))
        sc_atac_dataset.data_raw.write_h5ad(
            os.path.join(outdir_name, "full_atac.h5ad"))
        ### Train
        sc_rna_train_dataset.size_norm_counts.write_h5ad(
            os.path.join(outdir_name, "train_rna.h5ad"))
        sc_atac_train_dataset.data_raw.write_h5ad(
            os.path.join(outdir_name, "train_atac.h5ad"))
        ### Valid
        sc_rna_valid_dataset.size_norm_counts.write_h5ad(
            os.path.join(outdir_name, "valid_rna.h5ad"))
        sc_atac_valid_dataset.data_raw.write_h5ad(
            os.path.join(outdir_name, "valid_atac.h5ad"))
        ### Test
        # (full_atac.h5ad is written once above; it used to be rewritten here)
        sc_rna_test_dataset.size_norm_counts.write_h5ad(
            os.path.join(outdir_name, "truth_rna.h5ad"))
        sc_atac_test_dataset.data_raw.write_h5ad(
            os.path.join(outdir_name, "truth_atac.h5ad"))

        # Instantiate and train model
        model_class = (autoencoders.NaiveSplicedAutoEncoder
                       if args.naive else autoencoders.AssymSplicedAutoEncoder)
        spliced_net = autoencoders.SplicedAutoEncoderSkorchNet(
            module=model_class,
            module__hidden_dim=h_dim,  # Based on hyperparam tuning
            module__input_dim1=sc_rna_dataset.data_raw.shape[1],
            module__input_dim2=sc_atac_dataset.get_per_chrom_feature_count(),
            module__final_activations1=[
                activations.Exp(),
                activations.ClippedSoftplus(),
            ],
            module__final_activations2=nn.Sigmoid(),
            module__flat_mode=True,
            module__seed=rand_seed,
            lr=lr,  # Based on hyperparam tuning
            criterion=loss_functions.QuadLoss,
            # handle output of encoded layer
            criterion__loss2=loss_functions.BCELoss,
            # numerically balance the two losses with different magnitudes
            criterion__loss2_weight=lw,
            criterion__record_history=True,
            optimizer=OPTIMIZER_DICT[args.optim],
            iterator_train__shuffle=True,
            device=utils.get_device(args.device),
            batch_size=bs,  # Based on  hyperparam tuning
            max_epochs=500,
            callbacks=[
                skorch.callbacks.EarlyStopping(patience=args.earlystop),
                skorch.callbacks.LRScheduler(
                    policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
                    **model_utils.REDUCE_LR_ON_PLATEAU_PARAMS,
                ),
                skorch.callbacks.GradientNormClipping(gradient_clip_value=5),
                skorch.callbacks.Checkpoint(
                    dirname=outdir_name,
                    fn_prefix="net_",
                    monitor="valid_loss_best",
                ),
            ],
            train_split=skorch.helper.predefined_split(sc_dual_valid_dataset),
            iterator_train__num_workers=8,
            iterator_valid__num_workers=8,
        )
        if args.pretrain:
            # Load in the warm start parameters
            spliced_net.load_params(f_params=args.pretrain)
            spliced_net.partial_fit(sc_dual_train_dataset, y=None)
        else:
            spliced_net.fit(sc_dual_train_dataset, y=None)

        fig = plot_loss_history(spliced_net.history,
                                os.path.join(outdir_name, f"loss.{args.ext}"))
        plt.close(fig)

        logging.info("Evaluating on test set")
        logging.info("Evaluating RNA > RNA")
        sc_rna_test_preds = spliced_net.translate_1_to_1(sc_dual_test_dataset)
        sc_rna_test_preds_anndata = sc.AnnData(
            sc_rna_test_preds,
            var=sc_rna_test_dataset.data_raw.var,
            obs=sc_rna_test_dataset.data_raw.obs,
        )
        sc_rna_test_preds_anndata.write_h5ad(
            os.path.join(outdir_name, "rna_rna_test_preds.h5ad"))
        fig = plot_utils.plot_scatter_with_r(
            sc_rna_test_dataset.size_norm_counts.X,
            sc_rna_test_preds,
            one_to_one=True,
            logscale=True,
            density_heatmap=True,
            title="RNA > RNA (test set)",
            fname=os.path.join(outdir_name, f"rna_rna_scatter_log.{args.ext}"),
        )
        plt.close(fig)

        logging.info("Evaluating ATAC > ATAC")
        sc_atac_test_preds = spliced_net.translate_2_to_2(sc_dual_test_dataset)
        sc_atac_test_preds_anndata = sc.AnnData(
            sc_atac_test_preds,
            var=sc_atac_test_dataset.data_raw.var,
            obs=sc_atac_test_dataset.data_raw.obs,
        )
        sc_atac_test_preds_anndata.write_h5ad(
            os.path.join(outdir_name, "atac_atac_test_preds.h5ad"))
        fig = plot_utils.plot_auroc(
            sc_atac_test_dataset.data_raw.X,
            sc_atac_test_preds,
            title_prefix="ATAC > ATAC",
            fname=os.path.join(outdir_name, f"atac_atac_auroc.{args.ext}"),
        )
        plt.close(fig)

        logging.info("Evaluating ATAC > RNA")
        sc_atac_rna_test_preds = spliced_net.translate_2_to_1(
            sc_dual_test_dataset)
        sc_atac_rna_test_preds_anndata = sc.AnnData(
            sc_atac_rna_test_preds,
            var=sc_rna_test_dataset.data_raw.var,
            obs=sc_rna_test_dataset.data_raw.obs,
        )
        sc_atac_rna_test_preds_anndata.write_h5ad(
            os.path.join(outdir_name, "atac_rna_test_preds.h5ad"))
        fig = plot_utils.plot_scatter_with_r(
            sc_rna_test_dataset.size_norm_counts.X,
            sc_atac_rna_test_preds,
            one_to_one=True,
            logscale=True,
            density_heatmap=True,
            title="ATAC > RNA (test set)",
            fname=os.path.join(outdir_name,
                               f"atac_rna_scatter_log.{args.ext}"),
        )
        plt.close(fig)

        logging.info("Evaluating RNA > ATAC")
        sc_rna_atac_test_preds = spliced_net.translate_1_to_2(
            sc_dual_test_dataset)
        sc_rna_atac_test_preds_anndata = sc.AnnData(
            sc_rna_atac_test_preds,
            var=sc_atac_test_dataset.data_raw.var,
            obs=sc_atac_test_dataset.data_raw.obs,
        )
        sc_rna_atac_test_preds_anndata.write_h5ad(
            os.path.join(outdir_name, "rna_atac_test_preds.h5ad"))
        fig = plot_utils.plot_auroc(
            sc_atac_test_dataset.data_raw.X,
            sc_rna_atac_test_preds,
            title_prefix="RNA > ATAC",
            fname=os.path.join(outdir_name, f"rna_atac_auroc.{args.ext}"),
        )
        plt.close(fig)

        del spliced_net
def main():
    """Run the script

    Compare pseudo-bulk (per-gene summed) expression of a reference and a
    predicted expression file.

    Both files are loaded with ``load_file_flex_format``. They are optionally
    size-factor normalized, depending on whether ``"x"``/``"y"`` appear in
    ``args.normalize``. Counts are summed over cells for each gene, restricted
    to genes present in both files, and scatter-plotted. With
    ``args.mode == "random"``, a reproducible random subset of at most 5000
    common genes is plotted; with ``"all"``, every common gene is plotted.

    Raises:
        ValueError: if ``args.mode`` is neither ``"random"`` nor ``"all"``.
    """
    parser = build_parser()
    args = parser.parse_args()
    # Validate the gene-selection mode before doing any (potentially slow) I/O
    if args.mode not in ("random", "all"):
        raise ValueError(f"Unrecognized value for mode: {args.mode}")

    logging.info(f"Truth RNA file: {args.rna_x}")
    logging.info(f"Preds RNA file: {args.rna_y}")

    truth = load_file_flex_format(args.rna_x)
    preds = load_file_flex_format(args.rna_y)

    logging.info(f"Truth shape: {truth.shape}")
    logging.info(f"Preds shape: {preds.shape}")

    if "y" in args.normalize:
        logging.info("Normalizing y inferred input")
        preds = adata_utils.normalize_count_table(preds,
                                                  size_factors=True,
                                                  normalize=False,
                                                  log_trans=False)
    if "x" in args.normalize:
        logging.info("Normalizing x inferred input")
        truth = adata_utils.normalize_count_table(truth,
                                                  size_factors=True,
                                                  normalize=False,
                                                  log_trans=False)

    # Sum over cells (axis 0); np.array(...).flatten() handles sparse and dense X
    truth_bulk = pd.Series(np.array(truth.X.sum(axis=0)).flatten(),
                           index=truth.var_names)
    preds_bulk = pd.Series(np.array(preds.X.sum(axis=0)).flatten(),
                           index=preds.var_names)

    common_genes = sorted(set(truth_bulk.index).intersection(preds_bulk.index))
    assert common_genes
    logging.info(f"{len(common_genes)} genes in common")

    # Work on a copy so shuffling never reorders common_genes itself
    plot_genes = list(common_genes)
    if args.mode == "random":
        # A dedicated, seeded RNG gives the same shuffle as seeding the global
        # `random` module, without altering global RNG state.
        random.Random(1234).shuffle(plot_genes)
        plot_genes = plot_genes[:5000]

    truth_bulk = truth_bulk[plot_genes]
    preds_bulk = preds_bulk[plot_genes]

    plot_utils.plot_scatter_with_r(
        truth_bulk,
        preds_bulk,
        subset=0,
        logscale=args.log,
        density_heatmap=args.density,
        xlabel="Reference",
        ylabel="Predicted",
        one_to_one=True,
        title=args.title + f" ({args.mode}, n={len(truth_bulk)})",
        fname=args.plotname,
    )