Пример #1
0
def main(ref_fasta, model_type, input_length, profile_length, use_controls,
         chrom_set, out_dir, files_spec_path, model_path):
    """
    Restores a trained model, computes input gradients with it, clusters those
    gradients with `cluster_gradients`, and saves the clustering results as
    NumPy arrays under `out_dir`.
    Arguments:
        `ref_fasta`: path to the reference FASTA
        `model_type`: "binary" selects the binary model; anything else selects
            a profile model
        `input_length`: length of input sequences; if falsy, defaults to 1000
            for binary models and 1346 otherwise
        `profile_length`: passed through to the gradient computation
        `use_controls`: if truthy and `model_type` is "profile", the profile
            model class with controls is used
        `chrom_set`: comma-delimited string of chromosomes
        `out_dir`: directory to save results to; created if needed
        `files_spec_path`: path to the files spec
        `model_path`: path to the saved model
    Saves `hyp_seqlets`, `act_seqlets`, `seqlet_seqs`, `cluster_assignments`,
    and `cluster_ids` into `out_dir` using `np.save`.
    """
    is_binary = model_type == "binary"

    # Fall back to a model-type-specific default input length
    if not input_length:
        input_length = 1000 if is_binary else 1346

    # Pick the model class to restore
    if is_binary:
        model_class = binary_models.BinaryPredictor
    elif model_type == "profile" and use_controls:
        model_class = profile_models.ProfilePredictorWithControls
    else:
        model_class = profile_models.ProfilePredictorWithoutControls

    chrom_set = chrom_set.split(",")

    # Echo the supplied inputs
    print("Inputs supplied:")
    echoed_inputs = (
        ("\tReference fasta: %s", ref_fasta),
        ("\tModel type: %s", model_type),
        ("\tChromosome set: %s", chrom_set),
        ("\tFiles spec: %s", files_spec_path),
        ("\tModel: %s", model_path),
    )
    for template, value in echoed_inputs:
        print(template % value)

    # Import model
    torch.set_grad_enabled(True)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model = model_util.restore_model(model_class, model_path)
    model.eval()
    model = model.to(device)

    print("Computing gradients...")
    grad_kwargs = {
        "chrom_set": chrom_set,
        "profile_length": profile_length,
        "use_controls": use_controls,
    }
    input_grads, input_seqs = compute_gradients.get_input_grads(
        model, model_type, files_spec_path, input_length, ref_fasta,
        **grad_kwargs
    )

    clustering = cluster_gradients(input_grads, input_seqs)
    hyp_seqlets, act_seqlets, seqlet_seqs, cluster_assignments, cluster_ids = \
        clustering

    print("Saving output...")
    os.makedirs(out_dir, exist_ok=True)
    # `np.save` is given extension-less paths, one per result array
    named_outputs = (
        ("hyp_seqlets", hyp_seqlets),
        ("act_seqlets", act_seqlets),
        ("seqlet_seqs", seqlet_seqs),
        ("cluster_assignments", cluster_assignments),
        ("cluster_ids", cluster_ids),
    )
    for stem, array in named_outputs:
        np.save(os.path.join(out_dir, stem), array)
Пример #2
0
def load_model(model_path, model_classname, gpu_id, model_args_extras=None):
    """
    Restores a profile model from `model_path`, puts it in eval mode, and
    moves it onto the selected device.
    Arguments:
        `model_path`: path to the saved model
        `model_classname`: "prof_trans" selects the transfer profile model;
            any other value selects the matched-controls profile model
        `gpu_id`: index of the CUDA device to use when CUDA is available;
            otherwise the CPU is used
        `model_args_extras`: forwarded to `restore_model` as a keyword
            argument
    Returns the restored model on the chosen device.
    """
    torch.set_grad_enabled(True)

    # Only format the CUDA device string when CUDA is actually available
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{gpu_id}")
    else:
        device = torch.device("cpu")
    print(f"Running on {device}")

    model_class = (
        profile_models.ProfilePredictorTransfer
        if model_classname == "prof_trans"
        else profile_models.ProfilePredictorWithMatchedControls
    )

    restored = restore_model(
        model_class, model_path, model_args_extras=model_args_extras
    )
    restored.eval()
    return restored.to(device)
Пример #3
0
    model_path = "/users/amtseng/att_priors/models/trained_models/profile/SPI1/1/model_ckpt_epoch_1.pt"

    input_func = data_loading.get_profile_input_func(
        files_spec_path,
        input_length,
        profile_length,
        reference_fasta,
    )
    pos_coords = data_loading.get_positive_profile_coords(files_spec_path,
                                                          chrom_set=chrom_set)

    print("Loading model...")
    torch.set_grad_enabled(True)
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model_util.restore_model(model_class, model_path)
    model.eval()
    model = model.to(device)

    print("Running predictions...")
    x = get_profile_model_predictions(model,
                                      pos_coords,
                                      num_tasks,
                                      input_func,
                                      controls=controls,
                                      return_losses=True,
                                      return_gradients=True,
                                      show_progress=True)

    print("")
def make_shap_scores(model_path,
                     model_type,
                     files_spec_path,
                     input_length,
                     num_tasks,
                     out_path,
                     reference_fasta,
                     chrom_sizes,
                     task_index=None,
                     profile_length=1000,
                     controls=None,
                     num_strands=2,
                     chrom_set=None,
                     batch_size=128):
    """
    Computes SHAP scores over an entire dataset, and saves them as an HDF5 file.
    The SHAP scores are computed for all positive input sequences (i.e. peaks or
    positive bins).
    Arguments:
        `model_path`: path to saved model
        `model_type`: either "binary" or "profile"
        `files_spec_path`: path to files specs JSON
        `input_length`: length of input sequences
        `num_tasks`: number of tasks in the model
        `out_path`: path to HDF5 to save SHAP scores and input sequences; may
            be a bare filename with no directory component
        `reference_fasta`: path to reference FASTA
        `chrom_sizes`: path to chromosome sizes TSV; accepted for interface
            compatibility but not read by this function
        `task_index`: index of task to explain; if None, explain all tasks in
            aggregate
        `profile_length`: for profile models, the length of output profiles
        `controls`: for profile models, the kind of controls used: "matched",
            "shared", or None; this also determines the class of the model
        `num_strands`: for profile models, passed through to the profile
            explainer
        `chrom_set`: the set of chromosomes to compute SHAP scores for; if None,
            defaults to all chromosomes
        `batch_size`: batch size for SHAP score computation
    Creates/saves an HDF5 containing the SHAP scores and the input sequences.
    The HDF5 has the following keys:
        `coords_chrom`: an N-array of the coordinate chromosomes
        `coords_start`: an N-array of the coordinate starts
        `coords_end`: an N-array of the coordinate ends
        `one_hot_seqs`: an N x I x 4 array of one-hot encoded input sequences
        `hyp_scores`: an N x I x 4 array of hypothetical SHAP contribution
            scores
        `model`: a placeholder scalar dataset whose `model` attribute holds
            the path to the model, `model_path`
    Raises an AssertionError if `model_type` is not "binary" or "profile", and
    a ValueError if a profile model is requested with an unrecognized
    `controls` value.
    """
    assert model_type in ("binary", "profile")

    # Determine the model class and import the model
    if model_type == "binary":
        model_class = binary_models.BinaryPredictor
    elif controls == "matched":
        model_class = profile_models.ProfilePredictorWithMatchedControls
    elif controls == "shared":
        model_class = profile_models.ProfilePredictorWithSharedControls
    elif controls is None:
        model_class = profile_models.ProfilePredictorWithoutControls
    else:
        # Fail early with a clear message rather than hitting an unbound
        # `model_class` further down
        raise ValueError(
            "Unrecognized controls %r; expected \"matched\", \"shared\", or "
            "None" % (controls,)
        )
    torch.set_grad_enabled(True)
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model_util.restore_model(model_class, model_path)
    model.eval()
    model = model.to(device)

    # Create the data loaders
    if model_type == "binary":
        input_func = data_loading.get_binary_input_func(
            files_spec_path, input_length, reference_fasta)
        pos_samples = data_loading.get_positive_binary_bins(
            files_spec_path, chrom_set=chrom_set)
    else:
        input_func = data_loading.get_profile_input_func(
            files_spec_path,
            input_length,
            profile_length,
            reference_fasta,
        )
        pos_samples = data_loading.get_positive_profile_coords(
            files_spec_path, chrom_set=chrom_set)

    num_pos = len(pos_samples)
    num_batches = int(np.ceil(num_pos / batch_size))

    # Allocate arrays to hold the results
    coords_chrom = np.empty(num_pos, dtype=object)
    coords_start = np.empty(num_pos, dtype=int)
    coords_end = np.empty(num_pos, dtype=int)
    one_hot_seqs = np.empty((num_pos, input_length, 4))
    hyp_scores = np.empty((num_pos, input_length, 4))

    # Create the explainer
    if model_type == "binary":
        explainer = compute_shap.create_binary_explainer(model,
                                                         input_length,
                                                         task_index=task_index)
    else:
        explainer = compute_shap.create_profile_explainer(
            model,
            input_length,
            profile_length,
            num_tasks,
            num_strands,
            controls,
            task_index=task_index)

    # Compute the importance scores
    for i in tqdm.trange(num_batches):
        # The final slice may run past `num_pos`; slicing clips it
        batch_slice = slice(i * batch_size, (i + 1) * batch_size)
        # Compute scores
        if model_type == "binary":
            input_seqs, _, coords = input_func(pos_samples[batch_slice])
            scores = explainer(input_seqs, hide_shap_output=True)
        else:
            coords = pos_samples[batch_slice]
            input_seqs, profiles = input_func(coords)
            scores = explainer(
                input_seqs, profiles[:, num_tasks:], hide_shap_output=True
            )  # Regardless of the type of controls, we can always put this in

        # Fill in data
        coords_chrom[batch_slice] = coords[:, 0]
        coords_start[batch_slice] = coords[:, 1]
        coords_end[batch_slice] = coords[:, 2]
        one_hot_seqs[batch_slice] = input_seqs
        hyp_scores[batch_slice] = scores

    # Write to HDF5
    print("Saving result to HDF5...")
    out_dir = os.path.dirname(out_path)
    if out_dir:
        # `os.makedirs("")` raises, so only create a directory when the output
        # path actually has a directory component
        os.makedirs(out_dir, exist_ok=True)
    with h5py.File(out_path, "w") as f:
        f.create_dataset("coords_chrom", data=coords_chrom.astype("S"))
        f.create_dataset("coords_start", data=coords_start)
        f.create_dataset("coords_end", data=coords_end)
        f.create_dataset("hyp_scores", data=hyp_scores)
        f.create_dataset("one_hot_seqs", data=one_hot_seqs)
        # The model path is stored as an attribute on a placeholder dataset
        model_dset = f.create_dataset("model", data=0)
        model_dset.attrs["model"] = model_path