Example #1
def load_rna_files(rna_counts_fnames: List[str],
                   model_dir: str,
                   transpose: bool = True) -> ad.AnnData:
    """Load the RNA files in, filling in unmeasured genes as necessary"""
    # Find the genes that the model understands
    rna_genes_list_fname = os.path.join(model_dir, "rna_genes.txt")
    assert os.path.isfile(
        rna_genes_list_fname
    ), f"Cannot find RNA genes file: {rna_genes_list_fname}"
    learned_rna_genes = utils.read_delimited_file(rna_genes_list_fname)
    assert isinstance(learned_rna_genes, list)
    assert utils.is_all_unique(
        learned_rna_genes), "Learned genes list contains duplicates"

    temp_ad = utils.sc_read_multi_files(
        rna_counts_fnames,
        feature_type="Gene Expression",
        transpose=transpose,
        join="outer",
    )
    logging.info(f"Read input RNA files for {temp_ad.shape}")
    temp_ad.X = utils.ensure_arr(temp_ad.X)

    # Drop mouse genes, then strip the "HUMAN_" prefix from the remaining genes
    temp_ad.var_names_make_unique()
    kept_var_names = [
        vname for vname in temp_ad.var_names if not vname.startswith("MOUSE_")
    ]
    if len(kept_var_names) != temp_ad.n_vars:
        temp_ad = temp_ad[:, kept_var_names]
    # Slice off the prefix; str.strip would remove characters, not a prefix
    temp_ad.var = pd.DataFrame(
        index=[v[len("HUMAN_"):] if v.startswith("HUMAN_") else v
               for v in kept_var_names])

    # Expand adata to span all genes
    # Initializing as a sparse matrix would not allow vectorized column assignment
    intersected_genes = set(temp_ad.var_names).intersection(learned_rna_genes)
    assert intersected_genes, "No overlap between learned and input genes!"
    expanded_mat = np.zeros((temp_ad.n_obs, len(learned_rna_genes)))
    skip_count = 0
    for gene in intersected_genes:
        dest_idx = learned_rna_genes.index(gene)
        src_idx = temp_ad.var_names.get_loc(gene)
        if not isinstance(src_idx, int):
            logging.warning(f"Got multiple source matches for {gene}, skipping")
            skip_count += 1
            continue
        v = utils.ensure_arr(temp_ad.X[:, src_idx]).flatten()
        expanded_mat[:, dest_idx] = v
    if skip_count:
        logging.warning(
            f"Skipped {skip_count}/{len(intersected_genes)} genes due to multiple matches"
        )
    expanded_mat = sparse.csr_matrix(expanded_mat)  # Compress
    retval = ad.AnnData(expanded_mat,
                        obs=temp_ad.obs,
                        var=pd.DataFrame(index=learned_rna_genes))
    return retval
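
A self-contained toy run of the zero-fill expansion step above, assuming only that anndata, numpy, and pandas are installed (the gene names are hypothetical):

import anndata as ad
import numpy as np
import pandas as pd

# Two cells measured on two genes; the "learned" list has a third, unmeasured gene
toy = ad.AnnData(
    np.array([[1.0, 2.0], [3.0, 4.0]]),
    var=pd.DataFrame(index=["GENE_A", "GENE_B"]),
)
learned = ["GENE_A", "GENE_B", "GENE_C"]  # hypothetical learned gene list
expanded = np.zeros((toy.n_obs, len(learned)))
for gene in set(toy.var_names) & set(learned):
    expanded[:, learned.index(gene)] = toy[:, gene].X.flatten()
assert expanded[:, 2].sum() == 0  # GENE_C was never measured, so it stays zero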
Example #2
def main():
    """Run the script"""
    parser = build_parser()
    args = parser.parse_args()

    cell_df = pd.read_csv(
        args.cell_info,
        delimiter=","
        if utils.get_file_extension_no_gz(args.cell_info) == "csv" else "\t",
        index_col=args.cellindexcol,
        header=None if args.noheader else "infer",  # 'infer' is default
    )
    if "Barcodes" in cell_df.columns and args.cellindexcol is not None:
        cell_df.index = cell_df["Barcodes"]
    cell_df.index = cell_df.index.rename("barcode")
    cell_df.columns = cell_df.columns.map(str)

    logging.info(f"Read cell metadata from {args.cell_info} {cell_df.shape}")
    logging.info(f"Cell metadata cols: {cell_df.columns}")
    logging.info(cell_df)

    var_df = pd.read_csv(
        args.var_info,
        delimiter=","
        if utils.get_file_extension_no_gz(args.var_info) == "csv" else "\t",
        index_col=args.varindexcol,
        header=None if args.noheader else "infer",  # 'infer' is default
    )
    if "Feature" in var_df.columns and args.varindexcol is not None:
        var_df.index = [ensure_sane_interval(s) for s in var_df["Feature"]]
    var_df.index = var_df.index.rename("ft")
    var_df.columns = var_df.columns.map(str)
    # var_df.index = var_df.index.map(str)
    logging.info(f"Read variable metadata from {args.var_info} {var_df.shape}")
    logging.info(f"Var metadata cols: {var_df.columns}")
    logging.info(var_df)

    # Transpose because bioinformatics convention stores features as rows (cells as columns)
    adata = ad.read_mtx(args.mat_file).T
    logging.info(f"Read matrix {args.mat_file} {adata.shape}")
    adata.obs = cell_df
    adata.var = var_df
    logging.info(f"Created AnnData object: {adata}")
    logging.info(f"Obs names: {adata.obs_names}")
    logging.info(f"Var names: {adata.var_names}")

    if args.reindexvar:
        assert args.varindexcol is not None, "Must provide var index col to reindex var"
        target_vars = utils.read_delimited_file(args.reindexvar)
        logging.info(
            f"Read {args.reindexvar} for {len(target_vars)} vars to reindex")
        adata = adata_utils.reindex_adata_vars(adata, target_vars)

    adata.X = csr_matrix(adata.X)
    logging.info(f"Writing to {args.out_h5ad}")
    adata.write_h5ad(args.out_h5ad, compression=None)
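
Assuming this is run as a standalone script (as the argparse usage suggests), the conventional entry-point guard would close the module:

if __name__ == "__main__":
    main()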
Example #3
def load_rna_files_for_eval(
    data, checkpoint: str, rna_genes_list_fname: str = "", no_filter: bool = False
):
    """Load RNA file(s) for evaluation, matched to the model's learned gene list"""
    if not rna_genes_list_fname:
        rna_genes_list_fname = os.path.join(checkpoint, "rna_genes.txt")
    assert os.path.isfile(
        rna_genes_list_fname
    ), f"Cannot find RNA genes file: {rna_genes_list_fname}"
    rna_genes = utils.read_delimited_file(rna_genes_list_fname)
    rna_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_RNA_DATA_KWARGS)
    if no_filter:
        rna_data_kwargs = {
            k: v for k, v in rna_data_kwargs.items() if not k.startswith("filt_")
        }
        # Always discard cells with no expressed genes
        rna_data_kwargs["filt_cell_min_genes"] = 1
    rna_data_kwargs["fname"] = data
    reader_func = functools.partial(
        utils.sc_read_multi_files,
        reader=lambda x: sc_data_loaders.repool_genes(
            utils.get_ad_reader(x, ft_type="Gene Expression")(x), rna_genes
        ),
    )
    rna_data_kwargs["reader"] = reader_func
    try:
        logging.info(f"Building RNA dataset with parameters: {rna_data_kwargs}")
        sc_rna_full_dataset = sc_data_loaders.SingleCellDataset(
            mode="skip",
            **rna_data_kwargs,
        )
        assert all(
            [x == y for x, y in zip(rna_genes, sc_rna_full_dataset.data_raw.var_names)]
        ), "Mismatched genes"
        _temp = sc_rna_full_dataset[0]  # Check that indexing works
        # adata_utils.find_marker_genes(sc_rna_full_dataset.data_raw, n_genes=25)
        # marker_genes = adata_utils.flatten_marker_genes(
        #     sc_rna_full_dataset.data_raw.uns["rank_genes_leiden"]
        # )
        marker_genes = []
        # Write out the truth
    except (AssertionError, IndexError) as e:
        logging.warning(f"Error when reading RNA gene expression data from {data}: {e}")
        logging.warning("Ignoring RNA data")
        # Update length later
        sc_rna_full_dataset = sc_data_loaders.DummyDataset(
            shape=len(rna_genes), length=-1
        )
        marker_genes = []
    return sc_rna_full_dataset, rna_genes, marker_genes
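
A hypothetical invocation, assuming a checkpoint directory containing the rna_genes.txt written at training time:

# dataset, genes, markers = load_rna_files_for_eval(
#     ["pbmc_rna_counts.h5"],     # hypothetical input file(s)
#     checkpoint="cv_model_dir",  # hypothetical checkpoint directory
#     no_filter=True,             # skip filters, but still drop cells with no expressed genes
# )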
Example #4
def load_protein_accessory_model(dirname: str) -> skorch.NeuralNet:
    """Loads the protein accessory model"""
    predicted_proteins = utils.read_delimited_file(
        os.path.join(dirname, "protein_proteins.txt"))
    with open(os.path.join(dirname, "params.json")) as source:
        model_params = json.load(source)

    encoded_to_protein_skorch = skorch.NeuralNet(
        module=autoencoders.Decoder,
        module__num_units=16,
        module__intermediate_dim=model_params["interdim"],
        module__num_outputs=len(predicted_proteins),
        module__final_activation=nn.Identity(),
        module__activation=ACT_DICT[model_params["act"]],
        # module__final_activation=nn.Linear(
        #     len(predicted_proteins), len(predicted_proteins), bias=True
        # ),  # Paper uses identity activation instead
        lr=model_params["lr"],
        criterion=LOSS_DICT[model_params["loss"]],  # Other works use L1 loss
        optimizer=OPTIM_DICT[model_params["optim"]],
        batch_size=model_params["bs"],
        max_epochs=500,
        callbacks=[
            skorch.callbacks.EarlyStopping(patience=25),
            skorch.callbacks.LRScheduler(
                policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
                **model_utils.REDUCE_LR_ON_PLATEAU_PARAMS,
            ),
            skorch.callbacks.GradientNormClipping(gradient_clip_value=5),
        ],
        iterator_train__num_workers=8,
        iterator_valid__num_workers=8,
        device="cpu",
    )
    encoded_to_protein_skorch_cp = skorch.callbacks.Checkpoint(
        dirname=dirname, fn_prefix="net_")
    encoded_to_protein_skorch.load_params(
        checkpoint=encoded_to_protein_skorch_cp)
    return encoded_to_protein_skorch
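
A hypothetical usage sketch, assuming dirname contains the protein_proteins.txt, params.json, and net_* checkpoint files produced by the training script:

# prot_net = load_protein_accessory_model("protein_model_dir")  # hypothetical path
# preds = prot_net.predict(encoded_cells)  # 16-dim encodings in, one column per predicted protein out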
Example #5
def load_atac_files_for_eval(
    data: List[str],
    checkpoint: str,
    atac_bins_list_fname: str = "",
    lift_hg19_to_hg38: bool = False,
    predefined_split=None,
):
    """Load the ATAC files for evaluation"""
    if not atac_bins_list_fname:
        atac_bins_list_fname = os.path.join(checkpoint, "atac_bins.txt")
        logging.info(f"Auto-set atac bins fname to {atac_bins_list_fname}")
    assert os.path.isfile(
        atac_bins_list_fname
    ), f"Cannot find ATAC bins file: {atac_bins_list_fname}"
    atac_bins = utils.read_delimited_file(
        atac_bins_list_fname
    )  # These are the bins we are using (i.e. the bins the model was trained on)
    atac_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_ATAC_DATA_KWARGS)
    atac_data_kwargs["fname"] = data
    atac_data_kwargs["cluster_res"] = 0  # Disable clustering
    filt_atac_keys = [k for k in atac_data_kwargs.keys() if k.startswith("filt")]
    for k in filt_atac_keys:  # Reset filtering
        atac_data_kwargs[k] = None
    atac_data_kwargs["pool_genomic_interval"] = atac_bins
    if not lift_hg19_to_hg38:
        atac_data_kwargs["reader"] = functools.partial(
            utils.sc_read_multi_files,
            reader=lambda x: sc_data_loaders.repool_atac_bins(
                infer_reader(data[0], mode="atac")(x),
                atac_bins,
            ),
        )
    else:  # Requires liftover
        # Read, liftover, then repool
        atac_data_kwargs["reader"] = functools.partial(
            utils.sc_read_multi_files,
            reader=lambda x: sc_data_loaders.repool_atac_bins(
                sc_data_loaders.liftover_atac_adata(
                    # utils.sc_read_10x_h5_ft_type(x, "Peaks")
                    infer_reader(data[0], mode="atac")(x)
                ),
                atac_bins,
            ),
        )

    try:
        sc_atac_full_dataset = sc_data_loaders.SingleCellDataset(
            mode="skip",
            predefined_split=predefined_split,
            **atac_data_kwargs,
        )
        _temp = sc_atac_full_dataset[0]  # Check that indexing works
        assert all(
            [x == y for x, y in zip(atac_bins, sc_atac_full_dataset.data_raw.var_names)]
        )
    except AssertionError as err:
        logging.warning(f"Error when reading ATAC data from {data}: {err}")
        logging.warning("Ignoring ATAC data, returning dummy dataset instead")
        sc_atac_full_dataset = sc_data_loaders.DummyDataset(
            shape=len(atac_bins), length=-1
        )
    return sc_atac_full_dataset, atac_bins
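
A hypothetical call, assuming a checkpoint directory containing atac_bins.txt:

# atac_dataset, bins = load_atac_files_for_eval(
#     ["atac_peaks.h5"],          # hypothetical input file(s)
#     checkpoint="cv_model_dir",  # hypothetical checkpoint directory
#     lift_hg19_to_hg38=False,    # set True only for hg19-aligned inputs
# )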
Example #6
def load_model(
    checkpoint: Optional[str] = None,
    input_dim1: int = -1,
    input_dim2: int = -1,
    prefix: str = "net_",
    device: str = "cpu",
    verbose: bool = False,
):
    """Load the primary model, flexible to hidden dim, for evaluation only"""
    # Load the model
    device_parsed = device
    try:
        device_parsed = utils.get_device(int(device))
    except (TypeError, ValueError):
        device_parsed = "cpu"

    # Download the model if we are not given a path
    if checkpoint is None:
        dl_path = gdown.cached_download(
            MODEL_URL,
            path=os.path.join(MODEL_CACHE_DIR, MODEL_FILE_BASENAME),
            md5=MODEL_MD5SUM,
            postprocess=gdown.extractall,
            quiet=not verbose,
        )
        logging.info(f"Model tarball at: {dl_path}")
        checkpoint = os.path.join(MODEL_CACHE_DIR, "cv_logsplit_01_model_only")
        assert os.path.isdir(
            checkpoint), f"Failed to find downloaded model in {checkpoint}"

    # Infer input dim sizes if they aren't given
    if input_dim1 is None or input_dim1 <= 0:
        rna_genes = utils.read_delimited_file(
            os.path.join(checkpoint, "rna_genes.txt"))
        input_dim1 = len(rna_genes)
        logging.info(f"Inferred RNA input dimension: {input_dim1}")
    if input_dim2 is None or (isinstance(input_dim2, int) and input_dim2 <= 0):
        atac_bins = utils.read_delimited_file(
            os.path.join(checkpoint, "atac_bins.txt"))
        chrom_counter = collections.defaultdict(int)
        for b in atac_bins:
            chrom = b.split(":")[0]
            chrom_counter[chrom] += 1
        # input_dim2 = list(chrom_counter.values())
        input_dim2 = [chrom_counter[c] for c in sorted(chrom_counter.keys())]
        logging.info(
            f"Inferred ATAC input dimension: {input_dim2} (sum={np.sum(input_dim2)})"
        )

    # Dynamically determine the model we are looking at based on name
    checkpoint_basename = os.path.basename(checkpoint)
    if checkpoint_basename.startswith("naive"):
        logging.info(
            f"Inferred model with basename {checkpoint_basename} to be naive")
        model_class = autoencoders.NaiveSplicedAutoEncoder
    else:
        logging.info(
            f"Inferred model with basename {checkpoint_basename} to be normal (non-naive)"
        )
        model_class = autoencoders.AssymSplicedAutoEncoder

    spliced_net = None
    for hidden_dim_size in [16, 32]:
        try:
            spliced_net_ = autoencoders.SplicedAutoEncoderSkorchNet(
                module=model_class,
                module__input_dim1=input_dim1,
                module__input_dim2=input_dim2,
                module__hidden_dim=hidden_dim_size,
                # These don't matter because we're not training
                lr=0.01,
                criterion=loss_functions.QuadLoss,
                optimizer=torch.optim.Adam,
                batch_size=128,  # Reduced for memory saving
                max_epochs=500,
                # iterator_train__num_workers=8,
                # iterator_valid__num_workers=8,
                device=device_parsed,
            )
            spliced_net_.initialize()
            if checkpoint:
                cp = skorch.callbacks.Checkpoint(dirname=checkpoint,
                                                 fn_prefix=prefix)
                spliced_net_.load_params(checkpoint=cp)
            else:
                logging.warning("Using untrained model")
            # Upon successfully finding the correct hidden size, break out of loop
            logging.info(f"Loaded model with hidden size {hidden_dim_size}")
            spliced_net = spliced_net_
            break
        except RuntimeError as e:
            logging.info(f"Failed to load with hidden size {hidden_dim_size}")
            if verbose:
                logging.info(e)
    if spliced_net is None:
        raise RuntimeError("Could not infer hidden size")

    spliced_net.module_.eval()
    return spliced_net
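
A small self-contained check of the per-chromosome bin counting used above to infer the ATAC input dimension (bin names are hypothetical):

import collections

atac_bins = ["chr1:0-1000", "chr1:1000-2000", "chr2:0-1000"]  # hypothetical bins
chrom_counter = collections.defaultdict(int)
for b in atac_bins:
    chrom_counter[b.split(":")[0]] += 1
# One entry per chromosome, ordered by sorted chromosome name
assert [chrom_counter[c] for c in sorted(chrom_counter)] == [2, 1]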
Example #7
def main():
    parser = build_parser()
    args = parser.parse_args()

    if args.x_rna.endswith(".h5ad"):
        x_rna = ad.read_h5ad(args.x_rna)
    elif args.x_rna.endswith(".h5"):
        x_rna = sc.read_10x_h5(args.x_rna, gex_only=False)
    else:
        raise ValueError(f"Unrecognized file extension: {args.x_rna}")
    x_rna.X = utils.ensure_arr(x_rna.X)
    x_rna.obs_names = sanitize_obs_names(x_rna.obs_names)
    x_rna.obs_names_make_unique()
    logging.info(f"Read in {args.x_rna} for {x_rna.shape}")

    if args.y_rna.endswith(".h5ad"):
        y_rna = ad.read_h5ad(args.y_rna)
    elif args.y_rna.endswith(".h5"):
        y_rna = sc.read_10x_h5(args.y_rna, gex_only=False)
    else:
        raise ValueError(f"Unrecognized file extension: {args.y_rna}")
    y_rna.X = utils.ensure_arr(y_rna.X)
    y_rna.obs_names = sanitize_obs_names(y_rna.obs_names)
    y_rna.obs_names_make_unique()
    logging.info(f"Read in {args.y_rna} for {y_rna.shape}")

    if not (len(x_rna.obs_names) == len(y_rna.obs_names)
            and np.all(x_rna.obs_names == y_rna.obs_names)):
        logging.warning("Rematching obs axis")
        shared_obs_names = sorted(
            list(set(x_rna.obs_names).intersection(y_rna.obs_names)))
        logging.info(f"Found {len(shared_obs_names)} shared obs")
        assert shared_obs_names, ("Got empty list of shared obs" + "\n" +
                                  str(x_rna.obs_names) + "\n" +
                                  str(y_rna.obs_names))
        x_rna = x_rna[shared_obs_names]
        y_rna = y_rna[shared_obs_names]
    assert np.all(x_rna.obs_names == y_rna.obs_names)
    if not (len(x_rna.var_names) == len(y_rna.var_names)
            and np.all(x_rna.var_names == y_rna.var_names)):
        logging.warning("Rematching variable axis")
        shared_var_names = sorted(
            list(set(x_rna.var_names).intersection(y_rna.var_names)))
        logging.info(f"Found {len(shared_var_names)} shared variables")
        assert shared_var_names, ("Got empty list of shared vars" + "\n" +
                                  str(x_rna.var_names) + "\n" +
                                  str(y_rna.var_names))
        x_rna = x_rna[:, shared_var_names]
        y_rna = y_rna[:, shared_var_names]
    assert np.all(x_rna.var_names == y_rna.var_names)

    # Subset by gene list if given
    if args.genelist:
        gene_list = utils.read_delimited_file(args.genelist)
        logging.info(f"Read {len(gene_list)} genes from {args.genelist}")
        x_rna = x_rna[:, gene_list]
        y_rna = y_rna[:, gene_list]

    assert x_rna.shape == y_rna.shape, f"Mismatched shapes {x_rna.shape} {y_rna.shape}"

    fig = plot_utils.plot_scatter_with_r(
        x_rna.X,
        y_rna.X,
        subset=args.subset,
        one_to_one=True,
        logscale=not args.linear,
        density_heatmap=args.density,
        density_logstretch=args.densitylogstretch,
        fname=args.outfname,
        title=args.title,
        xlabel=args.xlabel,
        ylabel=args.ylabel,
        figsize=args.figsize,
    )
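
A tiny self-contained illustration of the shared-axis rematching above (barcodes are hypothetical):

x_names = ["AAAC-1", "GGGT-1", "TTTC-1"]  # hypothetical obs names in x
y_names = ["GGGT-1", "AAAC-1", "CCCA-1"]  # hypothetical obs names in y
shared = sorted(set(x_names) & set(y_names))
assert shared == ["AAAC-1", "GGGT-1"]  # both AnnDatas are subset to this common sorted order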
Example #8
def main():
    """Train a protein predictor"""
    parser = build_parser()
    args = parser.parse_args()

    # Create output directory
    if not os.path.isdir(args.outdir):
        os.makedirs(args.outdir)

    # Specify output log file
    logger = logging.getLogger()
    fh = logging.FileHandler(os.path.join(args.outdir, "training.log"))
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Log parameters
    for arg in vars(args):
        logging.info(f"Parameter {arg}: {getattr(args, arg)}")
    with open(os.path.join(args.outdir, "params.json"), "w") as sink:
        json.dump(vars(args), sink, indent=4)

    # Load the model
    pretrained_net = model_utils.load_model(args.encoder, device=args.device)

    # Load in some files
    rna_genes = utils.read_delimited_file(
        os.path.join(args.encoder, "rna_genes.txt"))
    atac_bins = utils.read_delimited_file(
        os.path.join(args.encoder, "atac_bins.txt"))

    # Read in the RNA
    rna_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_RNA_DATA_KWARGS)
    rna_data_kwargs["cluster_res"] = args.clusterres
    rna_data_kwargs["fname"] = args.rnaCounts
    rna_data_kwargs["reader"] = lambda x: load_rna_files(
        x, args.encoder, transpose=not args.notrans)

    # Construct data folds
    full_sc_rna_dataset = sc_data_loaders.SingleCellDataset(
        valid_cluster_id=args.validcluster,
        test_cluster_id=args.testcluster,
        **rna_data_kwargs,
    )
    full_sc_rna_dataset.data_raw.write_h5ad(
        os.path.join(args.outdir, "full_rna.h5ad"))

    train_valid_test_dsets = []
    for mode in ["all", "train", "valid", "test"]:
        logging.info(f"Constructing {mode} dataset")
        sc_rna_dataset = sc_data_loaders.SingleCellDatasetSplit(
            full_sc_rna_dataset, split=mode)
        sc_rna_dataset.data_raw.write_h5ad(
            os.path.join(args.outdir, f"{mode}_rna.h5ad"))  # Write RNA input
        sc_atac_dummy_dataset = sc_data_loaders.DummyDataset(
            shape=len(atac_bins), length=len(sc_rna_dataset))
        # RNA and fake ATAC
        sc_dual_dataset = sc_data_loaders.PairedDataset(
            sc_rna_dataset,
            sc_atac_dummy_dataset,
            flat_mode=True,
        )
        # encoded(RNA) as "x" and RNA + fake ATAC as "y"
        sc_rna_encoded_dataset = sc_data_loaders.EncodedDataset(
            sc_dual_dataset, model=pretrained_net, input_mode="RNA")
        sc_rna_encoded_dataset.encoded.write_h5ad(
            os.path.join(args.outdir, f"{mode}_encoded.h5ad"))
        sc_protein_dataset = sc_data_loaders.SingleCellProteinDataset(
            args.proteinCounts,
            obs_names=sc_rna_dataset.obs_names,
            transpose=not args.notrans,
        )
        sc_protein_dataset.data_raw.write_h5ad(
            os.path.join(args.outdir, f"{mode}_protein.h5ad"))  # Write protein
        # x = 16 dimensional encoded layer, y = 25 dimensional protein array
        sc_rna_protein_dataset = sc_data_loaders.SplicedDataset(
            sc_rna_encoded_dataset, sc_protein_dataset)
        _temp = sc_rna_protein_dataset[0]  # ensure calling works
        train_valid_test_dsets.append(sc_rna_protein_dataset)

    # Unpack and do sanity checks
    _, sc_rna_prot_train, sc_rna_prot_valid, sc_rna_prot_test = train_valid_test_dsets
    x, y, z = sc_rna_prot_train[0], sc_rna_prot_valid[0], sc_rna_prot_test[0]
    assert (x[0].shape == y[0].shape == z[0].shape
            ), f"Got mismatched shapes: {x[0].shape} {y[0].shape} {z[0].shape}"
    assert (x[1].shape == y[1].shape == z[1].shape
            ), f"Got mismatched shapes: {x[1].shape} {y[1].shape} {z[1].shape}"

    protein_markers = list(sc_protein_dataset.data_raw.var_names)
    with open(os.path.join(args.outdir, "protein_proteins.txt"), "w") as sink:
        sink.write("\n".join(protein_markers) + "\n")
    assert len(
        utils.read_delimited_file(
            os.path.join(args.outdir,
                         "protein_proteins.txt"))) == len(protein_markers)
    logging.info(f"Predicting on {len(protein_markers)} proteins")

    if args.preprocessonly:
        return

    protein_decoder_skorch = skorch.NeuralNet(
        module=autoencoders.Decoder,
        module__num_units=16,
        module__intermediate_dim=args.interdim,
        module__num_outputs=len(protein_markers),
        module__activation=ACT_DICT[args.act],
        module__final_activation=nn.Identity(),
        # module__final_activation=nn.Linear(
        #     len(protein_markers), len(protein_markers), bias=True
        # ),  # Paper uses identity activation instead
        lr=args.lr,
        criterion=LOSS_DICT[args.loss],  # Other works use L1 loss
        optimizer=OPTIM_DICT[args.optim],
        batch_size=args.bs,
        max_epochs=args.epochs,
        callbacks=[
            skorch.callbacks.EarlyStopping(patience=15),
            skorch.callbacks.LRScheduler(
                policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
                patience=5,
                factor=0.1,
                min_lr=1e-6,
                # **model_utils.REDUCE_LR_ON_PLATEAU_PARAMS,
            ),
            skorch.callbacks.GradientNormClipping(gradient_clip_value=5),
            skorch.callbacks.Checkpoint(
                dirname=args.outdir,
                fn_prefix="net_",
                monitor="valid_loss_best",
            ),
        ],
        train_split=skorch.helper.predefined_split(sc_rna_prot_valid),
        iterator_train__num_workers=8,
        iterator_valid__num_workers=8,
        device=utils.get_device(args.device),
    )
    protein_decoder_skorch.fit(sc_rna_prot_train, y=None)

    # Plot the loss history
    fig = plot_loss_history(protein_decoder_skorch.history,
                            os.path.join(args.outdir, "loss.pdf"))
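
Assuming this trainer is likewise invoked as a standalone script, the conventional entry point would be:

if __name__ == "__main__":
    main()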
Example #9
def main():
    parser = build_parser()
    args = parser.parse_args()
    assert args.output.endswith(".csv")

    # Specify output log file
    logger = logging.getLogger()
    fh = logging.FileHandler(args.output + ".log")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Log parameters
    for arg in vars(args):
        logging.info(f"Parameter {arg}: {getattr(args, arg)}")

    # Load the model
    babel = model_utils.load_model(args.babel, device=args.device)
    # Load in some related files
    rna_genes = utils.read_delimited_file(
        os.path.join(args.babel, "rna_genes.txt"))
    atac_bins = utils.read_delimited_file(
        os.path.join(args.babel, "atac_bins.txt"))

    # Load in the protein accesory model
    babel_prot_acc_model = protein_utils.load_protein_accessory_model(
        args.protmodel)
    proteins = utils.read_delimited_file(
        os.path.join(args.protmodel, "protein_proteins.txt"))

    # Get the encoded layer based on input
    if args.rna:
        (
            sc_rna_dset,
            _rna_genes,
            _marker_genes,
        ) = load_rna_files_for_eval(args.rna,
                                    checkpoint=args.babel,
                                    no_filter=True)
        sc_atac_dummy_dset = sc_data_loaders.DummyDataset(
            shape=len(atac_bins), length=len(sc_rna_dset))
        sc_dual_dataset = sc_data_loaders.PairedDataset(
            sc_rna_dset,
            sc_atac_dummy_dset,
            flat_mode=True,
        )
        sc_dual_encoded_dataset = sc_data_loaders.EncodedDataset(
            sc_dual_dataset, model=babel, input_mode="RNA")
        cell_barcodes = list(sc_rna_dset.data_raw.obs_names)
        encoded = sc_dual_encoded_dataset.encoded
    else:
        sc_atac_dset, _loaded_atac_bins = load_atac_files_for_eval(
            args.atac,
            checkpoint=args.babel,
            lift_hg19_to_hg38=args.liftHg19toHg38)
        sc_rna_dummy_dset = sc_data_loaders.DummyDataset(
            shape=len(rna_genes), length=len(sc_atac_dset))
        sc_dual_dataset = sc_data_loaders.PairedDataset(sc_rna_dummy_dset,
                                                        sc_atac_dset,
                                                        flat_mode=True)
        sc_dual_encoded_dataset = sc_data_loaders.EncodedDataset(
            sc_dual_dataset, model=babel, input_mode="ATAC")
        cell_barcodes = list(sc_atac_dset.data_raw.obs_names)
        encoded = sc_dual_encoded_dataset.encoded

    # Array of preds
    prot_preds = babel_prot_acc_model.predict(encoded.X)
    prot_preds_df = pd.DataFrame(
        prot_preds,
        index=cell_barcodes,
        columns=proteins,
    )
    prot_preds_df.to_csv(args.output)
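
A minimal sketch of how the prediction array is framed for CSV output (values, barcodes, and protein names below are hypothetical):

import pandas as pd

preds = [[0.1, 0.9], [0.4, 0.2]]  # hypothetical model output, one row per cell
df = pd.DataFrame(preds, index=["AAAC-1", "GGGT-1"], columns=["CD3", "CD19"])
df.to_csv("predictions.csv")  # hypothetical path; rows are barcodes, columns proteins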
Example #10
import utils

from evaluate_bulk_rna_concordance import load_file_flex_format

REF_MARKER_GENES = {
    "PBMC": set(
        itertools.chain.from_iterable(interpretation.PBMC_MARKER_GENES.values())
    ),
    "PBMC_Seurat": set(
        itertools.chain.from_iterable(
            interpretation.SEURAT_PBMC_MARKER_GENES.values()
        )
    ),
    "Housekeeper": utils.read_delimited_file(
        os.path.join(os.path.dirname(SRC_DIR), "data", "housekeeper_genes.txt")
    ),
}


def build_parser():
    """Build basic CLI parser"""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("preds",
                        type=str,
                        help="File with predicted expression")
    parser.add_argument("truth",
                        type=str,
                        help="File with ground truth expression")