def forward(self, input_seqs, profs_trans, cont_profs=None):
    # Run through the inner model, disregarding the predicted counts
    logit_pred_profs, _ = self.inner_model(input_seqs, cont_profs, profs_trans)

    # As with the computation of the gradients, instead of explaining the
    # logits, explain the mean-normalized logits, weighted by the final
    # probabilities after passing through the softmax; this exponentially
    # increases the weight of high-probability positions and exponentially
    # reduces the weight of low-probability positions, resulting in a
    # cleaner signal

    # Subtract the mean along the output profile dimension; this does not
    # change the softmax probabilities, but it normalizes the magnitude of
    # the logits
    norm_logit_pred_profs = logit_pred_profs - \
        torch.mean(logit_pred_profs, dim=2, keepdim=True)

    # Weight by the post-softmax probabilities, but detach them from the
    # graph to avoid explaining them
    pred_prof_probs = profile_models.profile_logits_to_log_probs(
        logit_pred_profs
    ).detach()
    weighted_norm_logits = norm_logit_pred_profs * pred_prof_probs

    if self.task_index is not None:
        # Subset to the specified task only
        weighted_norm_logits = \
            weighted_norm_logits[:, self.task_index : (self.task_index + 1)]

    # Sum over the task, profile-length, and strand dimensions
    prof_sum = torch.sum(weighted_norm_logits, dim=(1, 2, 3))

    # DeepSHAP requires the output shape to be B x 1
    return torch.unsqueeze(prof_sum, dim=1)
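
# Example usage (a minimal, hypothetical sketch, not from this codebase):
# handing a wrapper module whose `forward` is the one above to SHAP's
# DeepExplainer. `ProfileExplainerWrapper`, `trained_model`, and the
# background/input tensors are illustrative stand-ins; the B x 1 output
# shape returned by `forward` is what DeepSHAP expects to explain.
#
# class ProfileExplainerWrapper(torch.nn.Module):
#     def __init__(self, inner_model, task_index=None):
#         super().__init__()
#         self.inner_model = inner_model
#         self.task_index = task_index
#     forward = forward  # the method defined above
#
# wrapper = ProfileExplainerWrapper(trained_model, task_index=0)
# explainer = shap.DeepExplainer(
#     wrapper, [background_seqs, background_profs_trans]
# )
# shap_values = explainer.shap_values([input_seqs, profs_trans])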
# Imports needed by this test; `profile_models` and `profile_performance`
# are this project's own modules, assumed importable at module level
from datetime import datetime

import numpy as np
import torch


def test_vectorized_multinomial_nll():
    np.random.seed(20191110)
    batch_size, num_tasks, prof_len = 500, 10, 1000
    prof_shape = (batch_size, num_tasks, prof_len, 2)
    true_profs_np = np.random.randint(5, size=prof_shape)
    logit_pred_profs_np = np.random.randn(*prof_shape)
    log_pred_profs_np = profile_models.profile_logits_to_log_probs(
        logit_pred_profs_np, axis=2)
    true_counts_np = np.sum(true_profs_np, axis=2)

    print("Testing Multinomial NLL...")

    # Using the profile performance function (NumPy vectorization):
    a = datetime.now()
    nll_vec_np = profile_performance.profile_multinomial_nll(
        true_profs_np, log_pred_profs_np, true_counts_np)
    b = datetime.now()
    print("\tTime to compute (NumPy vectorization): %ds" % (b - a).seconds)

    # Using the profile models function (PyTorch vectorization):
    # Convert to tensors and swap axes to make the profile dimension last
    true_profs_tc = torch.tensor(true_profs_np).transpose(2, 3)
    log_pred_profs_tc = torch.tensor(log_pred_profs_np).transpose(2, 3)
    true_counts_tc = torch.tensor(true_counts_np)
    a = datetime.now()
    nll_vec_tc = -profile_models.multinomial_log_probs(
        log_pred_profs_tc, true_counts_tc, true_profs_tc)
    nll_vec_tc = torch.mean(nll_vec_tc, dim=2).numpy()  # Average across strands
    b = datetime.now()
    print("\tTime to compute (PyTorch vectorization): %ds" % (b - a).seconds)

    # Using PyTorch's Multinomial distribution class, one track at a time:
    logit_pred_profs_tc = torch.tensor(logit_pred_profs_np).transpose(2, 3)
    a = datetime.now()
    nll_torch = np.empty((batch_size, num_tasks))
    for i in range(batch_size):
        for j in range(num_tasks):
            dist_0 = torch.distributions.Multinomial(
                true_counts_tc[i, j, 0].item(),
                logits=logit_pred_profs_tc[i, j, 0, :])
            dist_1 = torch.distributions.Multinomial(
                true_counts_tc[i, j, 1].item(),
                logits=logit_pred_profs_tc[i, j, 1, :])
            nll_0 = -dist_0.log_prob(true_profs_tc[i, j, 0, :].float()).item()
            nll_1 = -dist_1.log_prob(true_profs_tc[i, j, 1, :].float()).item()
            nll_torch[i, j] = np.mean([nll_0, nll_1])
    b = datetime.now()
    print("\tTime to compute (PyTorch distributions): %ds" % (b - a).seconds)

    print("\tSame result? %s" % (
        np.allclose(nll_vec_np, nll_vec_tc) and
        np.allclose(nll_vec_tc, nll_torch)
    ))
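
# For reference (a minimal sketch, not part of the test above): all three
# implementations compute the multinomial negative log likelihood
#     NLL = -[log(N!) - sum_i log(x_i!) + sum_i x_i * log(p_i)]
# for each strand's track, where N is the total read count, x_i is the true
# count at position i, and log(p_i) is the predicted log probability. A
# direct SciPy version for a single track (the function name is
# illustrative):
from scipy.special import gammaln  # log-gamma, for the factorial terms


def multinomial_nll_single_track(true_prof, log_pred_prof):
    # true_prof: length-L array of counts; log_pred_prof: length-L array of
    # log probabilities over positions
    n = true_prof.sum()
    log_n_fact = gammaln(n + 1)                # log(N!)
    log_x_fact = gammaln(true_prof + 1).sum()  # sum_i log(x_i!)
    return -(log_n_fact - log_x_fact + (true_prof * log_pred_prof).sum())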
def _get_profile_model_predictions_batch(
    model, coords, num_tasks, input_func, controls=None,
    fourier_att_prior_freq_limit=200,
    fourier_att_prior_freq_limit_softness=0.2,
    att_prior_grad_smooth_sigma=3, return_losses=False, return_gradients=False
):
    """
    Fetches the necessary data from the given coordinates or bin indices and
    runs it through a profile or binary model. This will perform computation
    in a single batch.
    Arguments:
        `model`: a trained `ProfilePredictorWithMatchedControls`,
            `ProfilePredictorWithSharedControls`, or
            `ProfilePredictorWithoutControls`
        `coords`: a B x 3 array of coordinates to compute outputs for
        `num_tasks`: number of tasks for the model
        `input_func`: a function that takes in `coords` and returns the
            B x I x 4 array of one-hot sequences and the
            B x (T or T + 1 or 2T) x O x S array of profiles (perhaps with
            controls)
        `controls`: the type of control profiles (if any) used in the model;
            can be "matched" (each task has a matched control), "shared" (all
            tasks share a control), or None (no controls); must match the
            model class
        `fourier_att_prior_freq_limit`: limit for frequencies in the Fourier
            prior loss
        `fourier_att_prior_freq_limit_softness`: degree of softness for the
            limit
        `att_prior_grad_smooth_sigma`: width of the smoothing kernel for the
            gradients
        `return_losses`: if True, compute/return the loss values
        `return_gradients`: if True, compute/return the input gradients and
            sequences
    Returns a dictionary of the following structure:
        `true_profs`: true profile raw counts (B x T x O x S)
        `log_pred_profs`: predicted profile log probabilities (B x T x O x S)
        `true_counts`: true total counts (B x T x S)
        `log_pred_counts`: predicted log counts (B x T x S)
        `prof_losses`: profile NLL losses (B-array), if `return_losses` is
            True
        `count_losses`: counts MSE losses (B-array), if `return_losses` is
            True
        `att_losses`: attribution prior losses (B-array), if `return_losses`
            is True
        `input_seqs`: one-hot input sequences (B x I x 4), if
            `return_gradients` is True
        `input_grads`: "hypothetical" input gradients (B x I x 4), if
            `return_gradients` is True
    """
    result = {}
    input_seqs, profiles = input_func(coords)
    if return_gradients:
        input_seqs_np = input_seqs
        # Zero out the stored gradients, since we are computing input
        # gradients
        model.zero_grad()
    input_seqs = model_util.place_tensor(torch.tensor(input_seqs)).float()
    profiles = model_util.place_tensor(torch.tensor(profiles)).float()

    if controls is not None:
        tf_profs = profiles[:, :num_tasks, :, :]
        cont_profs = profiles[:, num_tasks:, :, :]  # Last half or just one
    else:
        tf_profs, cont_profs = profiles, None

    if return_losses or return_gradients:
        input_seqs.requires_grad = True  # Gradients are required
        logit_pred_profs, log_pred_counts = model(input_seqs, cont_profs)

        # Subtract the mean along the output profile dimension; this does not
        # change the softmax probabilities, but it normalizes the magnitude
        # of the gradients
        norm_logit_pred_profs = logit_pred_profs - \
            torch.mean(logit_pred_profs, dim=2, keepdim=True)

        # Weight by the post-softmax probabilities, but do not take gradients
        # of these probabilities; this upweights important regions
        # exponentially
        pred_prof_probs = profile_models.profile_logits_to_log_probs(
            logit_pred_profs
        ).detach()
        weighted_norm_logits = norm_logit_pred_profs * pred_prof_probs

        # Gradients are summed across strands and tasks; we will be operating
        # on the gradient itself, so we need to create the graph
        input_grads, = torch.autograd.grad(
            weighted_norm_logits, input_seqs,
            grad_outputs=model_util.place_tensor(
                torch.ones(weighted_norm_logits.size())
            ),
            retain_graph=True, create_graph=True
        )
        input_grads_np = input_grads.detach().cpu().numpy()
        input_seqs.requires_grad = False  # Reset gradient requirement
    else:
        logit_pred_profs, log_pred_counts = model(input_seqs, cont_profs)

    result["true_profs"] = tf_profs.detach().cpu().numpy()
    result["true_counts"] = np.sum(result["true_profs"], axis=2)
    logit_pred_profs_np = logit_pred_profs.detach().cpu().numpy()
    result["log_pred_profs"] = profile_models.profile_logits_to_log_probs(
        logit_pred_profs_np)
    result["log_pred_counts"] = log_pred_counts.detach().cpu().numpy()

    if return_losses:
        log_pred_profs = profile_models.profile_logits_to_log_probs(
            logit_pred_profs)
        num_samples = log_pred_profs.size(0)
        result["prof_losses"] = np.empty(num_samples)
        result["count_losses"] = np.empty(num_samples)
        result["att_losses"] = np.empty(num_samples)

        # Compute the losses separately for each example
        for i in range(num_samples):
            _, prof_loss, count_loss = model.correctness_loss(
                tf_profs[i:i + 1], log_pred_profs[i:i + 1],
                log_pred_counts[i:i + 1], 1, return_separate_losses=True)
            att_loss = model.fourier_att_prior_loss(
                model_util.place_tensor(torch.ones(1)), input_grads[i:i + 1],
                fourier_att_prior_freq_limit,
                fourier_att_prior_freq_limit_softness,
                att_prior_grad_smooth_sigma)
            result["prof_losses"][i] = prof_loss
            result["count_losses"][i] = count_loss
            result["att_losses"][i] = att_loss

    if return_gradients:
        result["input_seqs"] = input_seqs_np
        result["input_grads"] = input_grads_np
    return result
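
# Example usage (a minimal sketch): computing per-example losses and input
# gradients for a handful of coordinates. `trained_model` and
# `make_input_func` are hypothetical stand-ins; `input_func` must return the
# one-hot sequences and profiles for the given coordinates, following the
# contract in the docstring above.
#
# input_func = make_input_func(...)  # -> (B x I x 4 seqs, B x ? x O x S profiles)
# out = _get_profile_model_predictions_batch(
#     trained_model, coords, num_tasks=4, input_func=input_func,
#     controls="matched", return_losses=True, return_gradients=True
# )
# print(out["prof_losses"].mean(), out["att_losses"].mean())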
def run_epoch(
    data_loader, mode, model, epoch_num, params, optimizer=None,
    return_data=False, seq_mode=True, att_prior_loss_weight=None
):
    """
    Runs the data from the data loader once through the model, to train,
    validate, or predict.
    Arguments:
        `data_loader`: an instantiated `DataLoader` instance that gives
            batches of data; each batch must yield the input sequences,
            profiles, statuses, coordinates, and peaks; if `controls` is
            "matched", the profiles must be such that the first half are
            prediction (target) profiles and the second half are control
            profiles; if `controls` is "shared", the last set of profiles is
            a shared control; otherwise, all tasks are prediction profiles
        `mode`: one of "train", "eval"; if "train", run the epoch and perform
            backpropagation; if "eval", only do evaluation
        `model`: the current PyTorch model being trained/evaluated
        `epoch_num`: 0-indexed integer representing the current epoch
        `params`: dictionary of hyperparameters, including "num_tasks",
            "controls", "att_prior_loss_weight", "batch_size", "revcomp",
            "input_length", "input_depth", "profile_length", and optionally
            "gpu_id"
        `optimizer`: an instantiated PyTorch optimizer, for training mode
        `return_data`: if specified, returns the following as NumPy arrays:
            true profile raw counts (N x T x O x S), predicted profile log
            probabilities (N x T x O x S), true total counts (N x T x S),
            predicted log counts (N x T x S), coordinates used (N x 3 object
            array), input gradients (N x I x 4), and input sequences
            (N x I x 4); if the attribution prior is not used, the gradients
            will be garbage
        `seq_mode`: if True, run the model in sequence mode (via
            `set_seq_mode`); otherwise, run it in auxiliary mode (via
            `set_aux_mode`)
        `att_prior_loss_weight`: if given, overrides
            `params["att_prior_loss_weight"]`
    Returns lists of overall losses, correctness losses, attribution prior
    losses, profile losses, and count losses, where each list is over all
    batches. If the attribution prior loss is not computed, then it will be
    all 0s. If `return_data` is True, then more things will be returned after
    these.
    """
    num_tasks = params["num_tasks"]
    controls = params["controls"]
    if att_prior_loss_weight is None:
        att_prior_loss_weight = params["att_prior_loss_weight"]
    batch_size = params["batch_size"]
    revcomp = params["revcomp"]
    input_length = params["input_length"]
    input_depth = params["input_depth"]
    profile_length = params["profile_length"]
    gpu_id = params.get("gpu_id")

    assert mode in ("train", "eval")
    if mode == "train":
        assert optimizer is not None
    else:
        assert optimizer is None

    data_loader.dataset.on_epoch_start()  # Set up the epoch
    num_batches = len(data_loader.dataset)
    t_iter = tqdm.tqdm(data_loader, total=num_batches, desc="\tLoss: ---")

    if mode == "train":
        model.train()  # Switch to training mode
        torch.set_grad_enabled(True)

    batch_losses, corr_losses, att_losses = [], [], []
    prof_losses, count_losses = [], []
    if return_data:
        # Allocate empty NumPy arrays to hold the results; the real number of
        # samples can be smaller because of a partial last batch
        num_samples_exp = num_batches * batch_size
        num_samples_exp *= 2 if revcomp else 1
        profile_shape = (num_samples_exp, num_tasks, profile_length, 2)
        count_shape = (num_samples_exp, num_tasks, 2)
        all_log_pred_profs = np.empty(profile_shape)
        all_log_pred_counts = np.empty(count_shape)
        all_true_profs = np.empty(profile_shape)
        all_true_counts = np.empty(count_shape)
        all_true_profs_trans = np.empty(profile_shape)
        all_true_counts_trans = np.empty(count_shape)
        all_input_seqs = np.empty(
            (num_samples_exp, input_length, input_depth))
        all_input_grads = np.empty(
            (num_samples_exp, input_length, input_depth))
        all_coords = np.empty((num_samples_exp, 3), dtype=object)
        num_samples_seen = 0  # Real number of samples seen

    if seq_mode:
        model.set_seq_mode()
    else:
        model.set_aux_mode()

    for input_seqs, profiles, profiles_trans, statuses, coords, peaks in \
            t_iter:
        if return_data:
            input_seqs_np = input_seqs
        input_seqs = util.place_tensor(
            torch.tensor(input_seqs), index=gpu_id).float()
        profiles = util.place_tensor(
            torch.tensor(profiles), index=gpu_id).float()
        profiles_trans = util.place_tensor(
            torch.tensor(profiles_trans), index=gpu_id).float()

        if controls is not None:
            tf_profs = profiles[:, :num_tasks, :, :].contiguous()
            # Controls are the last half, or just the one shared control
            cont_profs = profiles[:, num_tasks:, :, :].contiguous()
            tf_profs_trans = \
                profiles_trans[:, :num_tasks, :, :].contiguous()
            cont_profs_trans = \
                profiles_trans[:, num_tasks:, :, :].contiguous()
        else:
            tf_profs, cont_profs = profiles, None
            tf_profs_trans, cont_profs_trans = profiles_trans, None

        # Clear the gradients from the last batch if training
        if mode == "train":
            optimizer.zero_grad()
        elif att_prior_loss_weight > 0:
            # Not training mode, but we still need to zero out the stored
            # gradients, because we are computing input gradients
            model.zero_grad()

        if att_prior_loss_weight > 0:
            input_seqs.requires_grad = True  # Gradients are required
            logit_pred_profs, log_pred_counts = model(
                input_seqs, cont_profs, tf_profs_trans)

            # Subtract the mean along the output profile dimension; this does
            # not change the softmax probabilities, but it normalizes the
            # magnitude of the gradients
            norm_logit_pred_profs = logit_pred_profs - \
                torch.mean(logit_pred_profs, dim=2, keepdim=True)

            # Weight by the post-softmax probabilities, but do not take
            # gradients of these probabilities; this upweights important
            # regions exponentially
            pred_prof_probs = profile_models.profile_logits_to_log_probs(
                logit_pred_profs
            ).detach()
            weighted_norm_logits = norm_logit_pred_profs * pred_prof_probs

            if seq_mode:
                model.set_seq_mode()

            # Gradients are summed across strands and tasks; we will be
            # operating on the gradient itself, so we need to create the
            # graph
            input_grads, = torch.autograd.grad(
                weighted_norm_logits, input_seqs,
                grad_outputs=util.place_tensor(
                    torch.ones(weighted_norm_logits.size()), index=gpu_id
                ),
                retain_graph=True, create_graph=True
            )
            if return_data:
                input_grads_np = input_grads.detach().cpu().numpy()
            input_grads = input_grads * input_seqs  # Gradient * input

            status = util.place_tensor(torch.tensor(statuses), index=gpu_id)
            status[status != 0] = 1  # Set to 1 if not a negative example
            input_seqs.requires_grad = False  # Reset gradient requirement
        else:
            logit_pred_profs, log_pred_counts = model(
                input_seqs, cont_profs, tf_profs_trans)
            status, input_grads = None, None

        loss, (corr_loss, att_loss), (prof_loss, count_loss) = model_loss(
            model, tf_profs, logit_pred_profs, log_pred_counts, epoch_num,
            status=status, input_grads=input_grads)

        if mode == "train":
            loss.backward()  # Compute the gradient
            optimizer.step()  # Update the weights through backprop

        batch_losses.append(loss.item())
        corr_losses.append(corr_loss.item())
        att_losses.append(att_loss.item())
        prof_losses.append(prof_loss.item())
        count_losses.append(count_loss.item())
        t_iter.set_description("\tLoss: %6.4f" % loss.item())

        if return_data:
            logit_pred_profs_np = logit_pred_profs.detach().cpu().numpy()
            log_pred_counts_np = log_pred_counts.detach().cpu().numpy()
            true_profs_np = tf_profs.detach().cpu().numpy()
            true_counts = np.sum(true_profs_np, axis=2)
            true_profs_trans_np = tf_profs_trans.detach().cpu().numpy()
            true_counts_trans = np.sum(true_profs_trans_np, axis=2)

            num_in_batch = true_counts.shape[0]

            # Turn the logit profile predictions into log probabilities
            log_pred_profs = profile_models.profile_logits_to_log_probs(
                logit_pred_profs_np, axis=2)

            # Fill the batch data/outputs into the preallocated arrays
            start, end = num_samples_seen, num_samples_seen + num_in_batch
            all_log_pred_profs[start:end] = log_pred_profs
            all_log_pred_counts[start:end] = log_pred_counts_np
            all_true_profs[start:end] = true_profs_np
            all_true_counts[start:end] = true_counts
            all_true_profs_trans[start:end] = true_profs_trans_np
            all_true_counts_trans[start:end] = true_counts_trans
            all_input_seqs[start:end] = input_seqs_np
            if att_prior_loss_weight:
                all_input_grads[start:end] = input_grads_np
            all_coords[start:end] = coords

            num_samples_seen += num_in_batch

    if return_data:
        # Truncate the saved data to the proper size, based on how many
        # samples were actually seen
        all_log_pred_profs = all_log_pred_profs[:num_samples_seen]
        all_log_pred_counts = all_log_pred_counts[:num_samples_seen]
        all_true_profs = all_true_profs[:num_samples_seen]
        all_true_counts = all_true_counts[:num_samples_seen]
        all_true_profs_trans = all_true_profs_trans[:num_samples_seen]
        all_true_counts_trans = all_true_counts_trans[:num_samples_seen]
        all_input_seqs = all_input_seqs[:num_samples_seen]
        all_input_grads = all_input_grads[:num_samples_seen]
        all_coords = all_coords[:num_samples_seen]
        return batch_losses, corr_losses, att_losses, prof_losses, \
            count_losses, all_true_profs, all_log_pred_profs, \
            all_true_counts, all_log_pred_counts, all_coords, \
            all_input_grads, all_input_seqs, all_true_profs_trans, \
            all_true_counts_trans
    else:
        return batch_losses, corr_losses, att_losses, prof_losses, \
            count_losses
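
# Example usage (a minimal sketch): one training epoch followed by one
# evaluation epoch. `train_loader`, `val_loader`, `model`, and `params` are
# hypothetical stand-ins that must follow the contracts described in the
# docstring above; the learning rate is illustrative.
#
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# t_batch_losses, _, _, _, _ = run_epoch(
#     train_loader, "train", model, epoch_num=0, params=params,
#     optimizer=optimizer
# )
# v_batch_losses, _, _, _, _ = run_epoch(
#     val_loader, "eval", model, epoch_num=0, params=params
# )
# print("Train/val loss: %6.4f / %6.4f" %
#       (np.mean(t_batch_losses), np.mean(v_batch_losses)))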