def make_shap_scores(model_path,
                     model_type,
                     files_spec_path,
                     input_length,
                     num_tasks,
                     out_path,
                     reference_fasta,
                     chrom_sizes,
                     task_index=None,
                     profile_length=1000,
                     controls=None,
                     num_strands=2,
                     chrom_set=None,
                     batch_size=128):
    """
    Computes SHAP scores over an entire dataset, and saves them as an HDF5 file.
    The SHAP scores are computed for all positive input sequences (i.e. peaks or
    positive bins).
    Arguments:
        `model_path`: path to saved model
        `model_type`: either "binary" or "profile"
        `files_spec_path`: path to files specs JSON
        `input_length`: length of input sequences
        `num_tasks`: number of tasks in the model
        `out_path`: path to HDF5 to save SHAP scores and input sequences
        `reference_fasta`: path to reference FASTA
        `chrom_sizes`: path to chromosome sizes TSV; accepted for interface
            compatibility, but not read by this function
        `task_index`: index of task to explain; if None, explain all tasks in
            aggregate
        `profile_length`: for profile models, the length of output profiles
        `controls`: for profile models, the kind of controls used: "matched",
            "shared", or None; this also determines the class of the model
        `num_strands`: for profile models, the number of strands; passed
            through to the profile explainer
        `chrom_set`: the set of chromosomes to compute SHAP scores for; if None,
            defaults to all chromosomes
        `batch_size`: batch size for SHAP score computation
    Creates/saves an HDF5 containing the SHAP scores and the input sequences.
    The HDF5 has the following keys:
        `coords_chrom`: an N-array of the coordinate chromosomes
        `coords_start`: an N-array of the coordinate starts
        `coords_end`: an N-array of the coordinate ends
        `one_hot_seqs`: an N x I x 4 array of one-hot encoded input sequences
        `hyp_scores`: an N x I x 4 array of hypothetical SHAP contribution
            scores
        `model`: a placeholder dataset whose `model` attribute holds the path
            to the model, `model_path`
    Raises a ValueError if `model_type` is "profile" and `controls` is not one
    of "matched", "shared", or None.
    """
    assert model_type in ("binary", "profile")
    is_binary = model_type == "binary"

    # Determine the model class and import the model
    if is_binary:
        model_class = binary_models.BinaryPredictor
    elif controls == "matched":
        model_class = profile_models.ProfilePredictorWithMatchedControls
    elif controls == "shared":
        model_class = profile_models.ProfilePredictorWithSharedControls
    elif controls is None:
        model_class = profile_models.ProfilePredictorWithoutControls
    else:
        # Fail early and clearly rather than hitting an unbound `model_class`
        raise ValueError(
            "controls must be \"matched\", \"shared\", or None; got %r" %
            (controls,))
    torch.set_grad_enabled(True)
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model_util.restore_model(model_class, model_path)
    model.eval()
    model = model.to(device)

    # Create the data loaders
    if is_binary:
        input_func = data_loading.get_binary_input_func(
            files_spec_path, input_length, reference_fasta)
        pos_samples = data_loading.get_positive_binary_bins(
            files_spec_path, chrom_set=chrom_set)
    else:
        input_func = data_loading.get_profile_input_func(
            files_spec_path,
            input_length,
            profile_length,
            reference_fasta,
        )
        pos_samples = data_loading.get_positive_profile_coords(
            files_spec_path, chrom_set=chrom_set)

    num_pos = len(pos_samples)
    num_batches = int(np.ceil(num_pos / batch_size))

    # Allocate arrays to hold the results
    coords_chrom = np.empty(num_pos, dtype=object)
    coords_start = np.empty(num_pos, dtype=int)
    coords_end = np.empty(num_pos, dtype=int)
    one_hot_seqs = np.empty((num_pos, input_length, 4))
    hyp_scores = np.empty((num_pos, input_length, 4))

    # Create the explainer
    if is_binary:
        explainer = compute_shap.create_binary_explainer(model,
                                                         input_length,
                                                         task_index=task_index)
    else:
        explainer = compute_shap.create_profile_explainer(
            model,
            input_length,
            profile_length,
            num_tasks,
            num_strands,
            controls,
            task_index=task_index)

    # Compute the importance scores
    for i in tqdm.trange(num_batches):
        # The final slice may run past `num_pos`; slicing clamps it
        batch_slice = slice(i * batch_size, (i + 1) * batch_size)
        # Compute scores
        if is_binary:
            input_seqs, _, coords = input_func(pos_samples[batch_slice])
            scores = explainer(input_seqs, hide_shap_output=True)
        else:
            coords = pos_samples[batch_slice]
            input_seqs, profiles = input_func(coords)
            scores = explainer(
                input_seqs, profiles[:, num_tasks:], hide_shap_output=True
            )  # Regardless of the type of controls, we can always put this in

        # Fill in data
        coords_chrom[batch_slice] = coords[:, 0]
        coords_start[batch_slice] = coords[:, 1]
        coords_end[batch_slice] = coords[:, 2]
        one_hot_seqs[batch_slice] = input_seqs
        hyp_scores[batch_slice] = scores

    # Write to HDF5
    print("Saving result to HDF5...")
    out_dir = os.path.dirname(out_path)
    if out_dir:
        # `os.makedirs("")` raises, so only create a directory when the output
        # path actually has a directory component
        os.makedirs(out_dir, exist_ok=True)
    with h5py.File(out_path, "w") as f:
        f.create_dataset("coords_chrom", data=coords_chrom.astype("S"))
        f.create_dataset("coords_start", data=coords_start)
        f.create_dataset("coords_end", data=coords_end)
        f.create_dataset("hyp_scores", data=hyp_scores)
        f.create_dataset("one_hot_seqs", data=one_hot_seqs)
        # The model path is stored as an attribute on a placeholder dataset
        model_dset = f.create_dataset("model", data=0)
        model_dset.attrs["model"] = model_path
# Example #2
# 0
    reference_fasta = "/users/amtseng/genomes/hg38.fasta"
    chrom_set = ["chr21"]

    print("Testing profile model")
    input_length = 1346
    profile_length = 1000
    controls = "matched"
    num_tasks = 4

    files_spec_path = "/users/amtseng/att_priors/data/processed/ENCODE_TFChIP/profile/config/SPI1/SPI1_training_paths.json"
    model_class = profile_models.ProfilePredictorWithMatchedControls
    model_path = "/users/amtseng/att_priors/models/trained_models/profile/SPI1/1/model_ckpt_epoch_1.pt"

    input_func = data_loading.get_profile_input_func(
        files_spec_path,
        input_length,
        profile_length,
        reference_fasta,
    )
    pos_coords = data_loading.get_positive_profile_coords(files_spec_path,
                                                          chrom_set=chrom_set)

    print("Loading model...")
    torch.set_grad_enabled(True)
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model_util.restore_model(model_class, model_path)
    model.eval()
    model = model.to(device)

    print("Running predictions...")
    x = get_profile_model_predictions(model,
# Example #3
# 0
def run(files_spec,
        model_path,
        reference_fasta,
        model_class,
        out_path,
        num_runs,
        chrom_set,
        num_tasks,
        prof_size,
        center_size_to_use,
        batch_size,
        model_args_extras=None):
    """
    Runs a model on ablated inputs for every footprint, computes metrics batch
    by batch, and pickles the results to `out_path`.
    Arguments:
        `files_spec`: files spec passed to the data-loading functions
        `model_path`: path of the model, passed to `load_model`
        `reference_fasta`: reference FASTA passed to the input function
        `model_class`: passed to `load_model`; the value "prof_trans" also
            selects the input function that additionally returns `profs_trans`
        `out_path`: path of the pickle file to write
        `num_runs`: passed to `get_ablated_inputs` and `get_metrics`
        `chrom_set`: chromosomes to load footprints for
        `num_tasks`: number of leading profile tracks; the tracks after the
            first `num_tasks` are passed on as controls
        `prof_size`: profile size passed to the data-loading functions
        `center_size_to_use`: passed as the footprint region size and as the
            input size of the input function
        `batch_size`: number of footprints processed per batch
        `model_args_extras`: optional extras passed to `load_model`
    Saves a pickled dictionary with the following keys:
        `results`: a list with one entry per successfully processed batch,
            each a dictionary with keys `footprints`, `peaks`, and `metrics`
        `index`: a dictionary mapping each footprint to a pair
            (position of its batch in `results`, position within that batch)
    A batch that raises an exception has its traceback printed and is left out
    of the output; the remaining batches are still processed.
    """
    print("Loading footprints...")

    _, peak_to_fp_prof, peak_to_fp_reg = data_loading.get_profile_footprint_coords(
        files_spec,
        prof_size=prof_size,
        region_size=center_size_to_use,
        chrom_set=chrom_set)
    masks = {k: create_mask(k, v) for k, v in peak_to_fp_reg.items()}
    fp_to_peak = get_fp_to_peak(peak_to_fp_prof)
    fps = list(fp_to_peak)

    print("Loading model...")

    is_trans = model_class == "prof_trans"
    if is_trans:
        input_func = data_loading.get_profile_trans_input_func(
            files_spec,
            center_size_to_use,
            prof_size,
            reference_fasta,
        )
    else:
        input_func = data_loading.get_profile_input_func(
            files_spec,
            center_size_to_use,
            prof_size,
            reference_fasta,
        )

    model = load_model(model_path,
                       model_class,
                       model_args_extras=model_args_extras)

    print("Computing metrics...")

    results = []
    fp_idx = {}
    for start in tqdm.tqdm(range(0, len(fps), batch_size)):
        # Slicing clamps at the end of the list, so the last batch may be short
        fps_slice = fps[start:start + batch_size]
        try:
            # Unique peaks of this batch, in first-seen order so that the
            # sequence ordering is reproducible from run to run
            peaks_slice = list(dict.fromkeys(
                fp_to_peak[fp] for fp in fps_slice))
            peak_to_seq_idx = {
                peak: seq_ind for seq_ind, peak in enumerate(peaks_slice)
            }
            fp_to_seq_slice = get_fp_to_seq_slice(fps_slice, fp_to_peak,
                                                  peak_to_seq_idx,
                                                  center_size_to_use)

            # NOTE(review): `run_model` is given the full footprint list `fps`
            # rather than `fps_slice` -- confirm that this is what it expects
            if is_trans:
                seqs, profiles, profs_trans = input_func(peaks_slice)
                profs_trans = profs_trans[:, :num_tasks]
                profs_ctrls = profiles[:, num_tasks:]
                seqs_in, profs_ctrls_in, profs_trans_in = get_ablated_inputs(
                    fps_slice,
                    seqs,
                    profs_ctrls,
                    fp_to_seq_slice,
                    fp_to_peak,
                    masks,
                    num_runs,
                    profs_trans=profs_trans)
                profs_preds_logits, counts_preds = run_model(
                    model,
                    seqs_in,
                    profs_ctrls_in,
                    fps,
                    profs_trans=profs_trans_in)
            else:
                seqs, profiles = input_func(peaks_slice)
                profs_ctrls = profiles[:, num_tasks:]
                seqs_in, profs_ctrls_in = get_ablated_inputs(
                    fps_slice, seqs, profs_ctrls, fp_to_seq_slice, fp_to_peak,
                    masks, num_runs)
                profs_preds_logits, counts_preds = run_model(
                    model, seqs_in, profs_ctrls_in, fps)

            metrics = get_metrics(profs_preds_logits, counts_preds, num_runs)
        except Exception:
            # Best effort: report the failed batch and carry on with the rest
            traceback.print_exc()
            continue

        # Index by the position in `results`, not by the batch counter: failed
        # batches are skipped, so the two diverge after the first failure
        result_pos = len(results)
        results.append({
            "footprints": fps_slice,
            "peaks": [fp_to_peak[fp] for fp in fps_slice],
            "metrics": metrics,
        })
        for ind, fp in enumerate(fps_slice):
            fp_idx[fp] = (result_pos, ind)

    export = {"results": results, "index": fp_idx}

    print(f"Saving to {out_path}")
    out_dir = os.path.dirname(out_path)
    if out_dir:
        # `os.makedirs("")` raises, so skip it for a bare file name
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "wb") as out_file:
        pickle.dump(export, out_file)