def interpret_atac_to_rna(
    model: nn.Module,
    inputs: torch.Tensor,
    target: int,
    baseline: torch.Tensor = None,
    interpret_class=IntegratedGradients,
    progressbar: bool = False,
    device=DEVICE,
    batch_size: int = 8,
) -> sparse.csr_matrix:
    """
    For each output gene in RNA space, in the given single data point, perform interpretation
    Target is given as RNA index (ATAC is ignored in output space)

    The model is switched to eval mode and moved to ``device``. Attributions are
    computed batch by batch: every batch of baseline rows is combined with every
    batch of input rows, and the per-batch results are stacked vertically in that
    order (baseline batches outer, input batches inner).

    Args:
        model: module whose ``forward(x, mode=(2, 1))`` returns a sequence; only
            element 0 of that return value is attributed.
        inputs: tensor whose first dimension indexes examples.
        target: output index passed through to ``explainer.attribute``.
        baseline: optional baseline tensor. When None, a single all-zero row
            shaped like ``inputs[0]`` is used.
            NOTE(review): each baseline batch is handed to the explainer together
            with each input batch, so presumably it must be broadcastable against
            an input batch (e.g. a single row) - confirm against the explainer.
        interpret_class: explainer class, constructed with ``forward_func=...``
            and expected to expose ``attribute(inputs, baselines, target=...)``.
        progressbar: show a progress bar over the baseline batches.
        device: device the model and each batch are moved to.
        batch_size: maximum number of rows per batch.

    Returns:
        Sparse CSR matrix of the vertically stacked attributions.
    """

    def batch_bounds(num_rows):
        """Return (start, stop) windows of at most batch_size rows covering every row"""
        # Pairing consecutive starts with zip(starts[:-1], starts[1:]) never emits
        # the final window, which silently dropped the trailing rows. Build each
        # stop explicitly so the last (possibly partial) batch is included.
        bounds = [(start, min(start + batch_size, num_rows))
                  for start in range(0, num_rows, batch_size)]
        # Keep the original fallback of one (0, num_rows) window when nothing
        # was generated (e.g. zero rows).
        return bounds or [(0, num_rows)]

    model.eval()
    model = model.to(device)
    if baseline is None:
        baseline = torch.zeros_like(inputs[0, :]).view(1, -1).to(device)

    def model_forward(x):
        # 2 > 1 is ATAC > RNA mode
        return model.forward(x, mode=(2, 1))[0]

    explainer = interpret_class(forward_func=model_forward)

    # Establish index pairs
    baseline_index_pairs = batch_bounds(baseline.shape[0])
    inputs_index_pairs = batch_bounds(inputs.shape[0])

    # Loop through the examples and do attributions
    attrs = []
    pbar = tqdm.tqdm_notebook if isnotebook() else tqdm.tqdm
    for baseline_start, baseline_stop in pbar(baseline_index_pairs,
                                              disable=not progressbar):
        baseline_batch = baseline[baseline_start:baseline_stop].to(device)
        for input_start, input_stop in inputs_index_pairs:
            attributions = explainer.attribute(
                inputs[input_start:input_stop].to(device),
                baseline_batch,
                target=target,
            ).cpu().numpy()
            # Store each batch sparsely so the accumulated result stays small
            attrs.append(sparse.csr_matrix(attributions))

    return sparse.vstack(attrs)
import time import pickle from path import Path import numpy as np import pandas as pd from sklearn.metrics import precision_recall_fscore_support, classification_report, confusion_matrix import torch import torch.nn as nn from utils import LABEL_NAME, isnotebook, set_seed, format_time if isnotebook(): from tqdm.notebook import tqdm else: from tqdm import tqdm set_seed(seed=228) def model_train(model, train_data_loader, valid_data_loader, test_data_loader, logger, optimizer, scheduler, num_epochs, seed, out_dir): # move model to gpu device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.to(device) if torch.cuda.device_count() > 1: model = nn.DataParallel(model) num_gpus = torch.cuda.device_count() logger.info("Let's use {} GPUs!".format(num_gpus))
def evaluate_pwm_importance(model,
                            dataset,
                            ppms,
                            prop=0.9,
                            ablation_strat="N",
                            progressbar=True,
                            device=DEVICE):
    """
    Given a dataset and PWMs, evaluate which PPMs, when ablated, impact the model the most
    Works for both pytorch and sklearn models. For sklearn models, assumes the dataset is
    per-transcript-part kmer featurization with kmer sizes of 3, 4, 5

    For every sample, every transcript part (u5, cds, u3) and every PPM, the part
    is passed through ``ablate_ppm``. If that changed the part, the model is re-run
    on the modified sample and ``ablated prediction - original prediction`` is
    recorded for each compartment whose true label is truthy.

    Args:
        model: either an ``nn.Module`` (called directly; its outputs are passed
            through a sigmoid, so presumably logits - confirm) or an object
            exposing ``predict_proba``.
        dataset: must provide ``__len__``, ``compartments``,
            ``get_ith_trans_parts(i)`` and ``get_ith_labels(i)``.
        ppms: mapping of PPM name -> PPM matrix.
        prop: forwarded unchanged to ``ablate_ppm``.
        ablation_strat: forwarded unchanged to ``ablate_ppm``.
        progressbar: show a progress bar over the samples.
        device: device used for torch models and their input tensors.

    Returns:
        Dict mapping each entry of ``dataset.compartments`` to a DataFrame indexed
        by sample number with one ``"{part}_{ppm_name}"`` column per part/PPM
        pair. Cells remain NaN where nothing was ablated or the label was falsy.
    """
    part_names = ('u5', 'cds', 'u3')

    # Tuple default instead of a list: avoids a shared mutable default argument
    def seq_parts_to_ft_vec(transcript_parts, kmer_sizes=(3, 4, 5)):
        """Helper function to translate seq parts to feature vector of one sample"""
        return np.atleast_2d(
            np.hstack([
                kmer.sequence_to_kmer_freqs(part, kmer_size=s)
                for s in kmer_sizes for part in transcript_parts
            ]))

    def seq_parts_to_ft_vec_torch(transcript_parts, with_len=False):
        """Helper function to translate seq parts to feature vector of one sample"""
        full_seq = ''.join(transcript_parts)
        seq_encoded = np.array([seq.BASE_TO_INT[b] for b in full_seq])
        encoded = torch.from_numpy(
            seq_encoded[:, np.newaxis, np.newaxis]).type(
                torch.LongTensor).to(device)
        if with_len:
            # The length tensor is deliberately left where LongTensor creates it
            # (it is not moved to `device`), matching the original behaviour
            return encoded, torch.LongTensor([len(full_seq)])
        return encoded

    def sigmoid(x):
        """Calculate the sigmoid function"""
        return 1.0 / (1.0 + np.exp(-x))

    is_torch_model = isinstance(model, nn.Module)
    # A forward() taking (self, x, lengths) is treated as the batched variant
    is_batched_torch_model = (is_torch_model and
                              model.forward.__code__.co_argcount == 3)

    def predict_probs(transcript_parts):
        """Run the model on one sample and return its predictions as a numpy array"""
        if not is_torch_model:
            return model_utils.list_preds_to_array_preds(
                model.predict_proba(seq_parts_to_ft_vec(transcript_parts)))
        # Inference only: outputs are detached immediately, so there is no need
        # to record an autograd graph for the forward pass
        with torch.no_grad():
            if is_batched_torch_model:
                raw_preds = model(*seq_parts_to_ft_vec_torch(
                    transcript_parts, with_len=True))
            else:
                raw_preds = model(seq_parts_to_ft_vec_torch(transcript_parts))
        return sigmoid(raw_preds.detach().cpu().numpy())  # Make it a numpy array

    retval = {
        localization: pd.DataFrame(  # Create empty dataframe that we'll fill in later
            np.nan,
            index=np.arange(len(dataset)),
            columns=[
                f"{trans_part}_{ppm_name}"
                for trans_part, ppm_name in itertools.product(
                    part_names, ppms.keys())
            ])
        for localization in dataset.compartments
    }

    if is_torch_model:
        model.to(device)

    ppm_names = list(ppms.keys())
    ppm_mats = [ppms[k] for k in ppm_names]

    pbar = tqdm.tqdm if not utils.isnotebook() else tqdm.tqdm_notebook
    for i in pbar(range(len(dataset)), disable=not progressbar):
        seq_parts = dataset.get_ith_trans_parts(i)
        orig_preds = predict_probs(seq_parts)
        assert np.all(orig_preds <= 1.0) and np.all(
            orig_preds >= 0
        ), "Preds are expected to be probabilties, so cannot exceed [0, 1]"

        for j, (part_name, part) in enumerate(zip(part_names, seq_parts)):
            ablations = [
                ablate_ppm(part, p, ablation_strat, prop) for p in ppm_mats
            ]
            for ppm_name, ablated_part in zip(ppm_names, ablations):
                if ablated_part == part:
                    continue  # Nothing was ablated for this PPM
                # Make a copy of the original and sub in the part we just ablated
                seq_parts_ablated = list(seq_parts)
                seq_parts_ablated[j] = ablated_part

                ablated_preds = predict_probs(seq_parts_ablated)
                assert np.all(ablated_preds <= 1.0) and np.all(
                    ablated_preds >= 0.0
                ), "Preds are expected to be probabilities, so cannot exceed [0, 1]"

                delta = ablated_preds - orig_preds
                for localization, impact, true_label in zip(
                        dataset.compartments, delta.flatten(),
                        np.squeeze(dataset.get_ith_labels(i))):
                    if true_label:  # only insert if it's suppose to be positive localization
                        retval[localization].loc[
                            i, f"{part_name}_{ppm_name}"] = impact
    return retval
def __call__(self, i, solver, fitness_fn, fitnesses_fn, best_params_fn):
    """Render the solver's current best parameters and show the image inline.

    Does nothing unless isnotebook() reports a notebook session. Only
    ``solver`` and ``best_params_fn`` are used; the remaining arguments are
    accepted but ignored.
    """
    if not isnotebook():
        return
    # pylint:disable=undefined-variable
    current_best = best_params_fn(solver)
    picture = arr2img(self.render_fn(current_best))
    display(picture)  # type: ignore
from datetime import datetime # from collections import defaultdict import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from torch.utils.data import DataLoader import numpy as np import wandb import utils from utils import accuracy, RunningAverageMeter, MovingAverageMeter, get_dataset from data_loader import get_test_loader if utils.isnotebook(): from tqdm.notebook import tqdm else: from tqdm import tqdm # get rid of the torch.nn.KLDivLoss(reduction='batchmean') warning warnings.filterwarnings( action="ignore", message= "reduction: 'mean' divides the total loss by both the batch size and the support size." ) class Trainer(object): """ Trainer encapsulates all the logic necessary for