def predict_func(input_seq_batch):
    # Return the set of outputs for the batch of input sequences, all
    # using the same set of control profiles, if needed
    num_in_batch = len(input_seq_batch)
    if cont_profs is not None:
        cont_profs_batch = np.stack(
            [cont_profs[seq_index]] * num_in_batch, axis=0)
    else:
        cont_profs_batch = None  # No controls
    output_vals = np.empty(num_in_batch)
    num_batches = int(np.ceil(num_in_batch / batch_size))
    for i in range(num_batches):
        batch_slice = slice(i * batch_size, (i + 1) * batch_size)
        if cont_profs_batch is not None:
            cont_profs_slice = place_tensor(torch.tensor(
                cont_profs_batch[batch_slice])).float()
        else:
            cont_profs_slice = None  # Avoid slicing a None control batch
        logit_pred_profs, _ = model(
            place_tensor(torch.tensor(
                input_seq_batch[batch_slice])).float(),
            cont_profs_slice)
        logit_pred_profs = logit_pred_profs.detach().cpu().numpy()
        if task_index is None:
            output_vals[batch_slice] = np.sum(
                logit_pred_profs, axis=(1, 2, 3))
        else:
            output_vals[batch_slice] = np.sum(
                logit_pred_profs[:, task_index], axis=(1, 2))
    return output_vals
def run_model(model, seqs, profs_ctrls, fps, gpu_id, profs_trans=None):
    num_runs = seqs.shape[1]
    profs_preds_shape = (
        profs_ctrls.shape[0], num_runs, profs_ctrls.shape[1],
        profs_ctrls.shape[2], profs_ctrls.shape[3])
    profs_preds_logits = np.empty(profs_preds_shape)
    counts_preds_shape = (
        profs_ctrls.shape[0], num_runs, profs_ctrls.shape[1],
        profs_ctrls.shape[3])
    counts_preds = np.empty(counts_preds_shape)

    profs_ctrls_b = place_tensor(
        torch.tensor(profs_ctrls), index=gpu_id).float()
    if profs_trans is not None:
        profs_trans_b = place_tensor(
            torch.tensor(profs_trans), index=gpu_id).float()

    for run in range(num_runs):
        seqs_b = place_tensor(
            torch.tensor(seqs[:, run]), index=gpu_id).float()
        if profs_trans is not None:
            profs_preds_logits_b, counts_preds_b = model(
                seqs_b, profs_ctrls_b, profs_trans_b)
        else:
            profs_preds_logits_b, counts_preds_b = model(
                seqs_b, profs_ctrls_b)
        profs_preds_logits[:, run] = \
            profs_preds_logits_b.detach().cpu().numpy()
        counts_preds[:, run] = counts_preds_b.detach().cpu().numpy()

    return profs_preds_logits, counts_preds
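# Hedged sketch (an assumption, not taken from the original util module):
# place_tensor, used throughout these snippets, appears to move a tensor onto
# a CUDA device when one is available, with an optional `index` selecting the
# GPU. A minimal stand-in consistent with the calls above:
import torch

def place_tensor(tensor, index=None):
    # Move the tensor to the requested GPU if CUDA is available; otherwise
    # leave it on the CPU
    if torch.cuda.is_available():
        device = torch.device("cuda" if index is None else "cuda:%d" % index)
        return tensor.to(device)
    return tensor

# e.g. place_tensor(torch.zeros(4, 1000, 2), index=0).float()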
def explain_fn(
        input_seqs, cont_profs=None, batch_size=128, hide_shap_output=False):
    """
    Given input sequences and control profiles, returns hypothetical scores
    for the input sequences.
    Arguments:
        `input_seqs`: a B x I x 4 array
        `cont_profs`: a B x (T or 1) x O x S array, or None
        `batch_size`: batch size for computation
        `hide_shap_output`: if True, do not show any warnings from DeepSHAP
    Returns a B x I x 4 array containing hypothetical importance scores for
    each of the B input sequences.
    """
    input_seqs_t = place_tensor(torch.tensor(input_seqs)).float()
    try:
        if hide_shap_output:
            hide_stdout()
        if cont_profs is None:
            return explainer.shap_values(
                [input_seqs_t], progress_message=None)[0]
        else:
            cont_profs_t = place_tensor(torch.tensor(cont_profs)).float()
            return explainer.shap_values(
                [input_seqs_t, cont_profs_t], progress_message=None)[0]
    finally:
        show_stdout()
def main():
    global model, prof, count
    device = torch.device("cuda") if torch.cuda.is_available() \
        else torch.device("cpu")

    model = create_model()
    model = model.to(device)

    np.random.seed(20191013)
    x = np.random.randint(2, size=[10, 4, 1346])
    y = (
        np.random.randint(5, size=[10, 3, 2, 1000]),
        np.random.randint(5, size=[10, 3, 2, 1000])
    )

    # dataset = TestDataset([x], [y])
    # data_loader = torch.utils.data.DataLoader(
    #     dataset, batch_size=None, collate_fn=lambda x: x
    # )

    torch.autograd.set_detect_anomaly(True)

    input_seq = util.place_tensor(torch.tensor(x)).float()
    tf_prof = util.place_tensor(torch.tensor(y[0])).float()
    cont_prof = util.place_tensor(torch.tensor(y[1])).float()

    pred_prof, pred_count = model(input_seq, cont_prof)
    loss = model.correctness_loss(tf_prof, pred_prof, pred_count, 1)
    loss.backward()
def sparsity_att_prior_loss(self, status, input_grads):
    """
    Computes an attribution prior loss for some given training examples, by
    rewarding sparsity of the attributions.
    Arguments:
        `status`: a B-tensor, where B is the batch size; each entry is 1 if
            that example is to be treated as a positive example, and 0
            otherwise
        `input_grads`: a B x L x D tensor, where B is the batch size, L is
            the length of the input, and D is the dimensionality of each
            input base; this needs to be the gradients of the input with
            respect to the output (for multiple tasks, this gradient needs
            to be aggregated); this should be *gradient times input*
    Returns a single scalar Tensor consisting of the attribution loss for
    the batch.
    """
    abs_grads = torch.sum(torch.abs(input_grads), dim=2)

    # Only do the positives
    pos_grads = abs_grads[status == 1]  # Shape: B' x L

    # Loss for positives
    if pos_grads.nelement():
        # Compute all pairwise differences
        seq_len = pos_grads.size(1)
        seq_matrix = pos_grads.repeat(1, seq_len).view(-1, seq_len, seq_len)
        diffs = torch.abs(seq_matrix - seq_matrix.transpose(1, 2))
        norm = seq_len * torch.sum(pos_grads, dim=1)
        return -torch.mean(torch.sum(diffs, dim=(1, 2)) / norm)
    else:
        return place_tensor(torch.zeros(1))
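# Hedged toy check (illustrative only, not from the original repo): the
# sparsity prior above is a negated Gini-style statistic -- the mean absolute
# pairwise difference of per-position attribution magnitudes, normalized by
# L times their sum -- so minimizing the loss pushes mass onto a few
# positions. A sparse profile scores higher than a flat one:
import torch

def _gini_like(pos_grads):
    seq_len = pos_grads.size(1)
    seq_matrix = pos_grads.repeat(1, seq_len).view(-1, seq_len, seq_len)
    diffs = torch.abs(seq_matrix - seq_matrix.transpose(1, 2))
    norm = seq_len * torch.sum(pos_grads, dim=1)
    return torch.sum(diffs, dim=(1, 2)) / norm

sparse = torch.zeros(1, 10)
sparse[0, 3] = 1.0               # all attribution mass on one position
flat = torch.full((1, 10), 0.1)  # mass spread evenly
print(_gini_like(sparse), _gini_like(flat))  # sparse score > flat score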
def smoothness_att_prior_loss(self, status, input_grads):
    """
    Computes an attribution prior loss for some given training examples, by
    rewarding smoothness between neighbors' attributions.
    Arguments:
        `status`: a B-tensor, where B is the batch size; each entry is 1 if
            that example is to be treated as a positive example, and 0
            otherwise
        `input_grads`: a B x L x D tensor, where B is the batch size, L is
            the length of the input, and D is the dimensionality of each
            input base; this needs to be the gradients of the input with
            respect to the output (for multiple tasks, this gradient needs
            to be aggregated); this should be *gradient times input*
    Returns a single scalar Tensor consisting of the attribution loss for
    the batch.
    """
    abs_grads = torch.sum(torch.abs(input_grads), dim=2)

    # Only do the positives
    pos_grads = abs_grads[status == 1]

    # Loss for positives
    if pos_grads.nelement():
        # Compute neighbor differences
        diffs = torch.abs(pos_grads[:, 1:] - pos_grads[:, :-1])
        return torch.mean(torch.sum(diffs, dim=1))
    else:
        return place_tensor(torch.zeros(1))
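# Hedged toy check (illustrative only): the smoothness prior above penalizes
# the total variation of the attribution magnitudes, i.e. the summed absolute
# difference between neighboring positions, so a jagged profile incurs a
# larger loss than a smooth ramp:
import torch

def _total_variation(pos_grads):
    return torch.sum(torch.abs(pos_grads[:, 1:] - pos_grads[:, :-1]), dim=1)

jagged = torch.tensor([[0., 1., 0., 1., 0., 1.]])
smooth = torch.tensor([[0., 0.2, 0.4, 0.6, 0.8, 1.0]])
print(_total_variation(jagged), _total_variation(smooth))  # 5.0 vs 1.0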
def create_input_seq_background(
        input_seq, input_length, bg_size=10, seed=20200219):
    """
    From the input sequence to a model, generates a set of background
    sequences to perform interpretation against.
    Arguments:
        `input_seq`: I x 4 tensor of one-hot encoded input sequence, or None
        `input_length`: length of input, I
        `bg_size`: the number of background examples to generate, G
    Returns a G x I x 4 tensor containing random dinucleotide shuffles of the
    original input sequence. If `input_seq` is None, then a G x I x 4 tensor
    of all 0s is returned.
    """
    if input_seq is None:
        input_seq_bg_shape = (bg_size, input_length, 4)
        return place_tensor(torch.zeros(input_seq_bg_shape)).float()

    # Do dinucleotide shuffles
    input_seq_np = input_seq.cpu().numpy()
    rng = np.random.RandomState(seed)
    input_seq_bg_np = dinuc_shuffle(input_seq_np, bg_size, rng=rng)
    return place_tensor(torch.tensor(input_seq_bg_np)).float()
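# Hedged shape/usage sketch: the real background uses dinuc_shuffle (assumed
# to come from an external package such as deeplift) so that shuffles
# preserve dinucleotide content. The stand-in below only permutes positions
# independently -- it does NOT preserve dinucleotides -- and exists purely to
# illustrate the expected G x I x 4 background shape.
import numpy as np
import torch

def naive_background(input_seq, bg_size=10, seed=20200219):
    rng = np.random.RandomState(seed)
    input_seq_np = input_seq.cpu().numpy()  # I x 4
    bg = np.stack(
        [input_seq_np[rng.permutation(len(input_seq_np))]
         for _ in range(bg_size)], axis=0)  # G x I x 4
    return torch.tensor(bg).float()

one_hot = torch.eye(4)[torch.randint(0, 4, (20,))]  # random 20 x 4 one-hot
print(naive_background(one_hot, bg_size=3).shape)   # torch.Size([3, 20, 4])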
def predict_func(input_seq_batch):
    # Return the set of outputs for the batch of input sequences
    num_in_batch = len(input_seq_batch)
    output_vals = np.empty(num_in_batch)
    num_batches = int(np.ceil(num_in_batch / batch_size))
    for i in range(num_batches):
        batch_slice = slice(i * batch_size, (i + 1) * batch_size)
        logit_preds = model(
            place_tensor(torch.tensor(
                input_seq_batch[batch_slice])).float())
        logit_preds = logit_preds.detach().cpu().numpy()
        if task_index is None:
            output_vals[batch_slice] = np.sum(logit_preds, axis=1)
        else:
            output_vals[batch_slice] = logit_preds[:, task_index]
    return output_vals
def explain_fn(input_seqs, hide_shap_output):
    """
    Given input sequences, returns hypothetical importance scores for them.
    Arguments:
        `input_seqs`: a B x I x 4 array
        `hide_shap_output`: if True, do not show any warnings from DeepSHAP
    Returns a B x I x 4 array containing hypothetical importance scores for
    each of the B input sequences.
    """
    input_seqs_t = place_tensor(torch.tensor(input_seqs)).float()
    try:
        if hide_shap_output:
            hide_stdout()
        return explainer.shap_values(
            [input_seqs_t], progress_message=None)[0]
    finally:
        show_stdout()
def create_profile_control_background(
        control_profs, profile_length, num_tasks, num_strands,
        controls="matched", bg_size=10):
    """
    Generates a background for a set of profile controls. In general, this is
    the given control profiles, copied a number of times (i.e. the background
    for controls should always be the same). Note this is only used for
    profile models.
    Arguments:
        `control_profs`: (T or 1) x O x S tensor of control profiles, or None
        `profile_length`: length of profile, O
        `num_tasks`: number of tasks, T
        `num_strands`: number of strands, S
        `controls`: the kind of controls used: "matched" or "shared"; if
            "matched", the control profiles taken in and returned are
            T x O x S; if "shared", the profiles are 1 x O x S
        `bg_size`: the number of background examples to generate, G
    Returns the tensor of `control_profs`, replicated G times. If `controls`
    is "matched", this becomes a G x T x O x S tensor; if `controls` is
    "shared", this is a G x 1 x O x S tensor. If `control_profs` is None,
    then a tensor of all 0s is returned, whose shape is determined by
    `controls`.
    """
    assert controls in ("matched", "shared")
    if controls == "matched":
        control_profs_bg_shape = (
            bg_size, num_tasks, profile_length, num_strands)
    else:
        control_profs_bg_shape = (bg_size, 1, profile_length, num_strands)

    if control_profs is None:
        return place_tensor(torch.zeros(control_profs_bg_shape)).float()

    # Replicate `control_profs`
    return torch.stack([control_profs] * bg_size, dim=0)
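# Hedged usage sketch: replicating a matched-control profile tensor into a
# background of G identical copies, the same torch.stack replication used
# above (shapes here are illustrative toy values).
import torch

T, O, S, G = 3, 1000, 2, 10
control_profs = torch.rand(T, O, S)
bg = torch.stack([control_profs] * G, dim=0)
print(bg.shape)                    # torch.Size([10, 3, 1000, 2])
assert torch.equal(bg[0], bg[-1])  # every background copy is identical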
def _get_profile_model_predictions_batch(
        model, coords, num_tasks, input_func, controls=None,
        fourier_att_prior_freq_limit=200,
        fourier_att_prior_freq_limit_softness=0.2,
        att_prior_grad_smooth_sigma=3, return_losses=False,
        return_gradients=False):
    """
    Fetches the necessary data from the given coordinates or bin indices and
    runs it through a profile or binary model. This will perform computation
    in a single batch.
    Arguments:
        `model`: a trained `ProfilePredictorWithMatchedControls`,
            `ProfilePredictorWithSharedControls`, or
            `ProfilePredictorWithoutControls`
        `coords`: a B x 3 array of coordinates to compute outputs for
        `num_tasks`: number of tasks for the model
        `input_func`: a function that takes in `coords` and returns the
            B x I x 4 array of one-hot sequences and the
            B x (T or T + 1 or 2T) x O x S array of profiles (perhaps with
            controls)
        `controls`: the type of control profiles (if any) used in model; can
            be "matched" (each task has a matched control), "shared" (all
            tasks share a control), or None (no controls); must match the
            model class
        `fourier_att_prior_freq_limit`: limit for frequencies in the Fourier
            prior loss
        `fourier_att_prior_freq_limit_softness`: degree of softness for limit
        `att_prior_grad_smooth_sigma`: width of smoothing kernel for gradients
        `return_losses`: if True, compute/return the loss values
        `return_gradients`: if True, compute/return the input gradients and
            sequences
    Returns a dictionary of the following structure:
        true_profs: true profile raw counts (B x T x O x S)
        log_pred_profs: predicted profile log probabilities (B x T x O x S)
        true_counts: true total counts (B x T x S)
        log_pred_counts: predicted log counts (B x T x S)
        prof_losses: profile NLL losses (B-array), if `return_losses` is True
        count_losses: counts MSE losses (B-array), if `return_losses` is True
        att_losses: prior losses (B-array), if `return_losses` is True
        input_seqs: one-hot input sequences (B x I x 4), if
            `return_gradients` is True
        input_grads: "hypothetical" input gradients (B x I x 4), if
            `return_gradients` is True
    """
    result = {}
    input_seqs, profiles = input_func(coords)
    if return_gradients:
        input_seqs_np = input_seqs
        model.zero_grad()  # Zero out gradients because we will compute them

    input_seqs = model_util.place_tensor(torch.tensor(input_seqs)).float()
    profiles = model_util.place_tensor(torch.tensor(profiles)).float()

    if controls is not None:
        tf_profs = profiles[:, :num_tasks, :, :]
        cont_profs = profiles[:, num_tasks:, :, :]  # Last half or just one
    else:
        tf_profs, cont_profs = profiles, None

    if return_losses or return_gradients:
        input_seqs.requires_grad = True  # Set gradient required
        logit_pred_profs, log_pred_counts = model(input_seqs, cont_profs)

        # Subtract mean along output profile dimension; this wouldn't change
        # softmax probabilities, but normalizes the magnitude of gradients
        norm_logit_pred_profs = logit_pred_profs - \
            torch.mean(logit_pred_profs, dim=2, keepdim=True)

        # Weight by post-softmax probabilities, but do not take the gradients
        # of these probabilities; this upweights important regions
        # exponentially
        pred_prof_probs = profile_models.profile_logits_to_log_probs(
            logit_pred_profs).detach()
        weighted_norm_logits = norm_logit_pred_profs * pred_prof_probs

        input_grads, = torch.autograd.grad(
            weighted_norm_logits, input_seqs,
            grad_outputs=model_util.place_tensor(
                torch.ones(weighted_norm_logits.size())),
            retain_graph=True, create_graph=True
            # We'll be operating on the gradient itself, so we need to
            # create the graph
            # Gradients are summed across strands and tasks
        )
        input_grads_np = input_grads.detach().cpu().numpy()
        input_seqs.requires_grad = False  # Reset gradient required
    else:
        logit_pred_profs, log_pred_counts = model(input_seqs, cont_profs)

    result["true_profs"] = tf_profs.detach().cpu().numpy()
    result["true_counts"] = np.sum(result["true_profs"], axis=2)
    logit_pred_profs_np = logit_pred_profs.detach().cpu().numpy()
    result["log_pred_profs"] = profile_models.profile_logits_to_log_probs(
        logit_pred_profs_np)
    result["log_pred_counts"] = log_pred_counts.detach().cpu().numpy()

    if return_losses:
        log_pred_profs = profile_models.profile_logits_to_log_probs(
            logit_pred_profs)
        num_samples = log_pred_profs.size(0)
        result["prof_losses"] = np.empty(num_samples)
        result["count_losses"] = np.empty(num_samples)
        result["att_losses"] = np.empty(num_samples)

        # Compute losses separately for each example
        for i in range(num_samples):
            _, prof_loss, count_loss = model.correctness_loss(
                tf_profs[i:i + 1], log_pred_profs[i:i + 1],
                log_pred_counts[i:i + 1], 1, return_separate_losses=True)
            att_loss = model.fourier_att_prior_loss(
                model_util.place_tensor(torch.ones(1)), input_grads[i:i + 1],
                fourier_att_prior_freq_limit,
                fourier_att_prior_freq_limit_softness,
                att_prior_grad_smooth_sigma)
            result["prof_losses"][i] = prof_loss
            result["count_losses"][i] = count_loss
            result["att_losses"][i] = att_loss

    if return_gradients:
        result["input_seqs"] = input_seqs_np
        result["input_grads"] = input_grads_np

    return result
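# Hedged sketch (an assumption about profile_models.profile_logits_to_log_probs,
# which is not shown here): turning profile logits into per-position log
# probabilities is a log-softmax along the output-profile dimension, computed
# independently for each task and strand. A minimal stand-in consistent with
# the B x T x O x S shapes used above:
import torch
import torch.nn.functional as F

def logits_to_log_probs_sketch(logit_pred_profs, axis=2):
    # Normalize over the O (profile length) dimension
    return F.log_softmax(logit_pred_profs, dim=axis)

logits = torch.randn(4, 3, 1000, 2)  # B x T x O x S
log_probs = logits_to_log_probs_sketch(logits)
print(torch.allclose(log_probs.exp().sum(dim=2),  # probabilities sum to 1
                     torch.ones(4, 3, 2)))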
def _get_binary_model_predictions_batch(
        model, bins, input_func, fourier_att_prior_freq_limit=150,
        fourier_att_prior_freq_limit_softness=0.2,
        att_prior_grad_smooth_sigma=3, return_losses=False,
        return_gradients=False):
    """
    Fetches the necessary data for the given bin indices and runs it through
    a binary model, in a single batch.
    Arguments:
        `model`: a trained `BinaryPredictor`
        `bins`: an N-array of bin indices to compute outputs for
        `input_func`: a function that takes in `bins` and returns the
            B x I x 4 array of one-hot sequences, the B x T array of output
            values, and the B x 3 array of underlying coordinates for the
            input sequence
        `fourier_att_prior_freq_limit`: limit for frequencies in the Fourier
            prior loss
        `fourier_att_prior_freq_limit_softness`: degree of softness for limit
        `att_prior_grad_smooth_sigma`: width of smoothing kernel for gradients
        `return_losses`: if True, compute/return the loss values
        `return_gradients`: if True, compute/return the input gradients and
            sequences
    Returns a dictionary of the following structure:
        true_vals: true binary values (B x T)
        pred_vals: predicted probabilities (B x T)
        coords: coordinates used for prediction (B x 3 object array)
        corr_losses: correctness losses (B-array), if `return_losses` is True
        att_losses: prior losses (B-array), if `return_losses` is True
        input_seqs: one-hot input sequences (B x I x 4), if
            `return_gradients` is True
        input_grads: "hypothetical" input gradients (B x I x 4), if
            `return_gradients` is True
    """
    result = {}
    input_seqs, output_vals, coords = input_func(bins)
    output_vals_np = output_vals
    if return_gradients:
        input_seqs_np = input_seqs
        model.zero_grad()

    input_seqs = model_util.place_tensor(torch.tensor(input_seqs)).float()
    output_vals = model_util.place_tensor(torch.tensor(output_vals)).float()

    if return_losses or return_gradients:
        input_seqs.requires_grad = True  # Set gradient required
        logit_pred_vals = model(input_seqs)

        # Compute the gradients of the output with respect to the input
        input_grads, = torch.autograd.grad(
            logit_pred_vals, input_seqs,
            grad_outputs=model_util.place_tensor(
                torch.ones(logit_pred_vals.size())),
            retain_graph=True, create_graph=True
            # We'll be operating on the gradient itself, so we need to
            # create the graph
            # Gradients are summed across tasks
        )
        input_grads_np = input_grads.detach().cpu().numpy()
        input_seqs.requires_grad = False  # Reset gradient required
    else:
        logit_pred_vals = model(input_seqs)

    result["true_vals"] = output_vals_np
    logit_pred_vals_np = logit_pred_vals.detach().cpu().numpy()
    result["pred_vals"] = binary_models.binary_logits_to_probs(
        logit_pred_vals_np)
    result["coords"] = coords

    if return_losses:
        num_samples = logit_pred_vals.size(0)
        result["corr_losses"] = np.empty(num_samples)
        result["att_losses"] = np.empty(num_samples)

        # Compute losses separately for each example
        for i in range(num_samples):
            corr_loss = model.correctness_loss(
                output_vals[i:i + 1], logit_pred_vals[i:i + 1], True)
            att_loss = model.fourier_att_prior_loss(
                model_util.place_tensor(torch.ones(1)), input_grads[i:i + 1],
                fourier_att_prior_freq_limit,
                fourier_att_prior_freq_limit_softness,
                att_prior_grad_smooth_sigma)
            result["corr_losses"][i] = corr_loss
            result["att_losses"][i] = att_loss

    if return_gradients:
        result["input_seqs"] = input_seqs_np
        result["input_grads"] = input_grads_np

    return result
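# Hedged toy sketch of the gradient-times-input computation used above: sum
# the model outputs (via grad_outputs of ones), differentiate with respect to
# the one-hot input, and multiply by the input to get per-base attributions.
# The tiny linear "model" here is purely illustrative.
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(8 * 4, 3))
input_seqs = torch.eye(4)[torch.randint(0, 4, (2, 8))].float()  # B x I x 4
input_seqs.requires_grad = True

logit_pred_vals = model(input_seqs)                              # B x T
input_grads, = torch.autograd.grad(
    logit_pred_vals, input_seqs,
    grad_outputs=torch.ones(logit_pred_vals.size()),
    retain_graph=True, create_graph=True)                        # B x I x 4

grad_times_input = input_grads * input_seqs
print(grad_times_input.shape)                                    # torch.Size([2, 8, 4])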
def run_epoch(
        data_loader, mode, model, epoch_num, params, optimizer=None,
        return_data=False, seq_mode=True, att_prior_loss_weight=None):
    """
    Runs the data from the data loader once through the model, to train,
    validate, or predict.
    Arguments:
        `data_loader`: an instantiated `DataLoader` instance that gives
            batches of data; each batch must yield the input sequences,
            profiles, statuses, coordinates, and peaks; if `controls` is
            "matched", profiles must be such that the first half are
            prediction (target) profiles, and the second half are control
            profiles; if `controls` is "shared", the last set of profiles is
            a shared control; otherwise, all tasks are prediction profiles
        `mode`: one of "train", "eval"; if "train", run the epoch and perform
            backpropagation; if "eval", only do evaluation
        `model`: the current PyTorch model being trained/evaluated
        `epoch_num`: 0-indexed integer representing the current epoch
        `optimizer`: an instantiated PyTorch optimizer, for training mode
        `return_data`: if specified, returns the following as NumPy arrays:
            true profile raw counts (N x T x O x S), predicted profile log
            probabilities (N x T x O x S), true total counts (N x T x S),
            predicted log counts (N x T x S), coordinates used (N x 3 object
            array), input gradients (N x I x 4), and input sequences
            (N x I x 4); if the attribution prior is not used, the gradients
            will be garbage
    Returns lists of overall losses, correctness losses, attribution prior
    losses, profile losses, and count losses, where each list is over all
    batches. If the attribution prior loss is not computed, then it will be
    all 0s. If `return_data` is True, then more things will be returned after
    these.
    """
    num_tasks = params["num_tasks"]
    controls = params["controls"]
    if att_prior_loss_weight is None:
        att_prior_loss_weight = params["att_prior_loss_weight"]
    batch_size = params["batch_size"]
    revcomp = params["revcomp"]
    input_length = params["input_length"]
    input_depth = params["input_depth"]
    profile_length = params["profile_length"]
    gpu_id = params.get("gpu_id")

    assert mode in ("train", "eval")
    if mode == "train":
        assert optimizer is not None
    else:
        assert optimizer is None

    data_loader.dataset.on_epoch_start()  # Set up the epoch
    num_batches = len(data_loader.dataset)
    t_iter = tqdm.tqdm(data_loader, total=num_batches, desc="\tLoss: ---")

    if mode == "train":
        model.train()  # Switch to training mode
        torch.set_grad_enabled(True)

    batch_losses, corr_losses, att_losses = [], [], []
    prof_losses, count_losses = [], []
    if return_data:
        # Allocate empty NumPy arrays to hold the results
        num_samples_exp = num_batches * batch_size
        num_samples_exp *= 2 if revcomp else 1
        # Real number of samples can be smaller because of partial last batch
        profile_shape = (num_samples_exp, num_tasks, profile_length, 2)
        count_shape = (num_samples_exp, num_tasks, 2)
        all_log_pred_profs = np.empty(profile_shape)
        all_log_pred_counts = np.empty(count_shape)
        all_true_profs = np.empty(profile_shape)
        all_true_counts = np.empty(count_shape)
        all_true_profs_trans = np.empty(profile_shape)
        all_true_counts_trans = np.empty(count_shape)
        all_input_seqs = np.empty(
            (num_samples_exp, input_length, input_depth))
        all_input_grads = np.empty(
            (num_samples_exp, input_length, input_depth))
        all_coords = np.empty((num_samples_exp, 3), dtype=object)
        num_samples_seen = 0  # Real number of samples seen

    if seq_mode:
        model.set_seq_mode()
    else:
        model.set_aux_mode()

    for input_seqs, profiles, profiles_trans, statuses, coords, peaks in \
            t_iter:
        if return_data:
            input_seqs_np = input_seqs
        input_seqs = util.place_tensor(
            torch.tensor(input_seqs), index=gpu_id).float()
        profiles = util.place_tensor(
            torch.tensor(profiles), index=gpu_id).float()
        profiles_trans = util.place_tensor(
            torch.tensor(profiles_trans), index=gpu_id).float()

        if controls is not None:
            tf_profs = profiles[:, :num_tasks, :, :].contiguous()
            cont_profs = profiles[:, num_tasks:, :, :].contiguous()
            # Last half or just one
            tf_profs_trans = \
                profiles_trans[:, :num_tasks, :, :].contiguous()
            cont_profs_trans = \
                profiles_trans[:, num_tasks:, :, :].contiguous()
            # Last half or just one
        else:
            tf_profs, cont_profs = profiles, None
            tf_profs_trans, cont_profs_trans = profiles_trans, None

        # Clear gradients from last batch if training
        if mode == "train":
            optimizer.zero_grad()
        elif att_prior_loss_weight > 0:
            # Not training mode, but we still need to zero out the gradients
            # because we are computing the input gradients
            model.zero_grad()

        if att_prior_loss_weight > 0:
            input_seqs.requires_grad = True  # Set gradient required
            logit_pred_profs, log_pred_counts = model(
                input_seqs, cont_profs, tf_profs_trans)

            # Subtract mean along output profile dimension; this wouldn't
            # change softmax probabilities, but normalizes the magnitude of
            # gradients
            norm_logit_pred_profs = logit_pred_profs - \
                torch.mean(logit_pred_profs, dim=2, keepdim=True)

            # Weight by post-softmax probabilities, but do not take the
            # gradients of these probabilities; this upweights important
            # regions exponentially
            pred_prof_probs = profile_models.profile_logits_to_log_probs(
                logit_pred_profs).detach()
            weighted_norm_logits = norm_logit_pred_profs * pred_prof_probs

            if seq_mode:
                model.set_seq_mode()

            input_grads, = torch.autograd.grad(
                weighted_norm_logits, input_seqs,
                grad_outputs=util.place_tensor(
                    torch.ones(weighted_norm_logits.size()), index=gpu_id),
                retain_graph=True, create_graph=True
                # We'll be operating on the gradient itself, so we need to
                # create the graph
                # Gradients are summed across strands and tasks
            )
            if return_data:
                input_grads_np = input_grads.detach().cpu().numpy()
            input_grads = input_grads * input_seqs  # Gradient * input
            status = util.place_tensor(torch.tensor(statuses), index=gpu_id)
            status[status != 0] = 1  # Set to 1 if not a negative example
            input_seqs.requires_grad = False  # Reset gradient required
        else:
            logit_pred_profs, log_pred_counts = model(
                input_seqs, cont_profs, tf_profs_trans)
            status, input_grads = None, None

        loss, (corr_loss, att_loss), (prof_loss, count_loss) = model_loss(
            model, tf_profs, logit_pred_profs, log_pred_counts, epoch_num,
            status=status, input_grads=input_grads)

        if mode == "train":
            loss.backward()  # Compute gradient
            optimizer.step()  # Update weights through backprop

        batch_losses.append(loss.item())
        corr_losses.append(corr_loss.item())
        att_losses.append(att_loss.item())
        prof_losses.append(prof_loss.item())
        count_losses.append(count_loss.item())
        t_iter.set_description("\tLoss: %6.4f" % loss.item())

        if return_data:
            logit_pred_profs_np = logit_pred_profs.detach().cpu().numpy()
            log_pred_counts_np = log_pred_counts.detach().cpu().numpy()
            true_profs_np = tf_profs.detach().cpu().numpy()
            true_counts = np.sum(true_profs_np, axis=2)
            true_profs_trans_np = tf_profs_trans.detach().cpu().numpy()
            true_counts_trans = np.sum(true_profs_trans_np, axis=2)

            num_in_batch = true_counts.shape[0]

            # Turn logit profile predictions into log probabilities
            log_pred_profs = profile_models.profile_logits_to_log_probs(
                logit_pred_profs_np, axis=2)

            # Fill in the batch data/outputs into the preallocated arrays
            start, end = num_samples_seen, num_samples_seen + num_in_batch
            all_log_pred_profs[start:end] = log_pred_profs
            all_log_pred_counts[start:end] = log_pred_counts_np
            all_true_profs[start:end] = true_profs_np
            all_true_counts[start:end] = true_counts
            all_true_profs_trans[start:end] = true_profs_trans_np
            all_true_counts_trans[start:end] = true_counts_trans
            all_input_seqs[start:end] = input_seqs_np
            if att_prior_loss_weight:
                all_input_grads[start:end] = input_grads_np
            all_coords[start:end] = coords

            num_samples_seen += num_in_batch

    if return_data:
        # Truncate the saved data to the proper size, based on how many
        # samples were actually seen
        all_log_pred_profs = all_log_pred_profs[:num_samples_seen]
        all_log_pred_counts = all_log_pred_counts[:num_samples_seen]
        all_true_profs = all_true_profs[:num_samples_seen]
        all_true_counts = all_true_counts[:num_samples_seen]
        all_true_profs_trans = all_true_profs_trans[:num_samples_seen]
        all_true_counts_trans = all_true_counts_trans[:num_samples_seen]
        all_input_seqs = all_input_seqs[:num_samples_seen]
        all_input_grads = all_input_grads[:num_samples_seen]
        all_coords = all_coords[:num_samples_seen]
        return batch_losses, corr_losses, att_losses, prof_losses, \
            count_losses, all_true_profs, all_log_pred_profs, \
            all_true_counts, all_log_pred_counts, all_coords, \
            all_input_grads, all_input_seqs, all_true_profs_trans, \
            all_true_counts_trans
    else:
        return batch_losses, corr_losses, att_losses, prof_losses, \
            count_losses
def fourier_att_prior_loss(
        self, status, input_grads, freq_limit, limit_softness,
        att_prior_grad_smooth_sigma, gpu_id=None):
    """
    Computes an attribution prior loss for some given training examples,
    using a Fourier transform form.
    Arguments:
        `status`: a B-tensor, where B is the batch size; each entry is 1 if
            that example is to be treated as a positive example, and 0
            otherwise
        `input_grads`: a B x L x D tensor, where B is the batch size, L is
            the length of the input, and D is the dimensionality of each
            input base; this needs to be the gradients of the input with
            respect to the output (for multiple tasks, this gradient needs
            to be aggregated); this should be *gradient times input*
        `freq_limit`: the maximum integer frequency index, k, to consider for
            the loss; this corresponds to a frequency cut-off of pi * k / L;
            k should be less than L / 2
        `limit_softness`: amount to soften the limit by, using a hill
            function; None means no softness
        `att_prior_grad_smooth_sigma`: amount to smooth the gradient before
            computing the loss
    Returns a single scalar Tensor consisting of the attribution loss for
    the batch.
    """
    abs_grads = torch.sum(torch.abs(input_grads), dim=2)

    # Smooth the gradients
    grads_smooth = smooth_tensor_1d(
        abs_grads, att_prior_grad_smooth_sigma, gpu_id=gpu_id)

    # Only do the positives
    pos_grads = grads_smooth[status == 1]

    # Loss for positives
    if pos_grads.nelement():
        pos_fft = torch.rfft(pos_grads, 1)
        pos_mags = torch.norm(pos_fft, dim=2)
        pos_mag_sum = torch.sum(pos_mags, dim=1, keepdim=True)
        pos_mag_sum[pos_mag_sum == 0] = 1  # Avoid division by 0; the
                                           # magnitudes stay 0 in that case
        pos_mags = pos_mags / pos_mag_sum

        # Cut off DC
        pos_mags = pos_mags[:, 1:]

        # Construct weight vector
        weights = place_tensor(torch.ones_like(pos_mags), index=gpu_id)
        if limit_softness is None:
            weights[:, freq_limit:] = 0
        else:
            x = place_tensor(
                torch.arange(1, pos_mags.size(1) - freq_limit + 1),
                index=gpu_id).float()
            weights[:, freq_limit:] = 1 / (1 + torch.pow(x, limit_softness))

        # Multiply frequency magnitudes by weights
        pos_weighted_mags = pos_mags * weights

        # Add up along frequency axis to get score
        pos_score = torch.sum(pos_weighted_mags, dim=1)
        pos_loss = 1 - pos_score
        return torch.mean(pos_loss)
    else:
        return place_tensor(torch.zeros(1), index=gpu_id)
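# Hedged standalone sketch of the Fourier prior score above, written with the
# newer torch.fft.rfft API (the function itself uses the older torch.rfft).
# Smoothing and the soft hill-function roll-off are omitted; the point is only
# that attributions concentrated in low frequencies keep more of their
# normalized magnitude below the cut-off, so they score higher (lower loss)
# than high-frequency noise.
import math
import torch

def fourier_score(abs_grads, freq_limit):
    mags = torch.abs(torch.fft.rfft(abs_grads, dim=1))  # B x (L//2 + 1)
    mags = mags / torch.sum(mags, dim=1, keepdim=True)  # normalize magnitudes
    mags = mags[:, 1:]                                   # drop the DC component
    return torch.sum(mags[:, :freq_limit], dim=1)        # hard frequency cut-off

L = 1000
t = torch.arange(L).float()
smooth_attr = torch.sin(2 * math.pi * 3 * t / L).abs().unsqueeze(0)  # low-frequency
noisy_attr = torch.rand(1, L)                                         # broadband noise
print(fourier_score(smooth_attr, 150), fourier_score(noisy_attr, 150))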
def run_epoch(
        data_loader, mode, model, epoch_num, num_tasks,
        att_prior_loss_weight, batch_size, revcomp, input_length,
        input_depth, optimizer=None, return_data=False):
    """
    Runs the data from the data loader once through the model, to train,
    validate, or predict.
    Arguments:
        `data_loader`: an instantiated `DataLoader` instance that gives
            batches of data; each batch must yield the input sequences, the
            output values, the status, and coordinates
        `mode`: one of "train", "eval"; if "train", run the epoch and perform
            backpropagation; if "eval", only do evaluation
        `model`: the current PyTorch model being trained/evaluated
        `epoch_num`: 0-indexed integer representing the current epoch
        `optimizer`: an instantiated PyTorch optimizer, for training mode
        `return_data`: if specified, returns the following as NumPy arrays:
            true binding values (0, 1, or -1) (N x T), predicted binding
            probabilities (N x T), the underlying sequence coordinates
            (N x 3 object array), input gradients (N x I x 4), and the input
            sequences (N x I x 4); if the attribution prior is not used, the
            gradients will be garbage
    Returns lists of overall losses, correctness losses, and attribution
    prior losses, where each list is over all batches. If the attribution
    prior loss is not computed, then it will be all 0s. If `return_data` is
    True, then more things will be returned after these.
    """
    assert mode in ("train", "eval")
    if mode == "train":
        assert optimizer is not None
    else:
        assert optimizer is None

    data_loader.dataset.on_epoch_start()  # Set up the epoch
    num_batches = len(data_loader.dataset)
    t_iter = tqdm.tqdm(data_loader, total=num_batches, desc="\tLoss: ---")

    if mode == "train":
        model.train()  # Switch to training mode
        torch.set_grad_enabled(True)

    batch_losses, corr_losses, att_losses = [], [], []
    if return_data:
        # Allocate empty NumPy arrays to hold the results
        num_samples_exp = num_batches * batch_size
        num_samples_exp *= 2 if revcomp else 1
        # Real number of samples can be smaller because of partial last batch
        all_true_vals = np.empty((num_samples_exp, num_tasks))
        all_pred_vals = np.empty((num_samples_exp, num_tasks))
        all_input_seqs = np.empty(
            (num_samples_exp, input_length, input_depth))
        all_input_grads = np.empty(
            (num_samples_exp, input_length, input_depth))
        all_coords = np.empty((num_samples_exp, 3), dtype=object)
        num_samples_seen = 0  # Real number of samples seen

    for input_seqs, output_vals, statuses, coords in t_iter:
        if return_data:
            input_seqs_np = input_seqs
            output_vals_np = output_vals
        input_seqs = util.place_tensor(torch.tensor(input_seqs)).float()
        output_vals = util.place_tensor(torch.tensor(output_vals)).float()

        # Clear gradients from last batch if training
        if mode == "train":
            optimizer.zero_grad()
        elif att_prior_loss_weight > 0:
            # Not training mode, but we still need to zero out the gradients
            # because we are computing the input gradients
            model.zero_grad()

        if att_prior_loss_weight > 0:
            input_seqs.requires_grad = True  # Set gradient required
            logit_pred_vals = model(input_seqs)

            # Compute the gradients of the output with respect to the input
            input_grads, = torch.autograd.grad(
                logit_pred_vals, input_seqs,
                grad_outputs=util.place_tensor(
                    torch.ones(logit_pred_vals.size())),
                retain_graph=True, create_graph=True
                # We'll be operating on the gradient itself, so we need to
                # create the graph
                # Gradients are summed across tasks
            )
            if return_data:
                input_grads_np = input_grads.detach().cpu().numpy()
            input_grads = input_grads * input_seqs  # Gradient * input
            status = util.place_tensor(torch.tensor(statuses))
            input_seqs.requires_grad = False  # Reset gradient required
        else:
            logit_pred_vals = model(input_seqs)
            status, input_grads = None, None

        loss, (corr_loss, att_loss) = model_loss(
            model, output_vals, logit_pred_vals, epoch_num, status=status,
            input_grads=input_grads)

        if mode == "train":
            loss.backward()  # Compute gradient
            optimizer.step()  # Update weights through backprop

        batch_losses.append(loss.item())
        corr_losses.append(corr_loss.item())
        att_losses.append(att_loss.item())
        t_iter.set_description("\tLoss: %6.4f" % loss.item())

        if return_data:
            logit_pred_vals_np = logit_pred_vals.detach().cpu().numpy()

            # Turn logits into probabilities
            pred_vals = binary_models.binary_logits_to_probs(
                logit_pred_vals_np)
            num_in_batch = pred_vals.shape[0]

            # Fill in the batch data/outputs into the preallocated arrays
            start, end = num_samples_seen, num_samples_seen + num_in_batch
            all_true_vals[start:end] = output_vals_np
            all_pred_vals[start:end] = pred_vals
            all_input_seqs[start:end] = input_seqs_np
            if att_prior_loss_weight:
                all_input_grads[start:end] = input_grads_np
            all_coords[start:end] = coords

            num_samples_seen += num_in_batch

    if return_data:
        # Truncate the saved data to the proper size, based on how many
        # samples were actually seen
        all_true_vals = all_true_vals[:num_samples_seen]
        all_pred_vals = all_pred_vals[:num_samples_seen]
        all_input_seqs = all_input_seqs[:num_samples_seen]
        all_input_grads = all_input_grads[:num_samples_seen]
        all_coords = all_coords[:num_samples_seen]
        return batch_losses, corr_losses, att_losses, all_true_vals, \
            all_pred_vals, all_coords, all_input_grads, all_input_seqs
    else:
        return batch_losses, corr_losses, att_losses
def model_loss(
        model, true_vals, logit_pred_vals, epoch_num, avg_class_loss,
        att_prior_loss_weight, att_prior_loss_weight_anneal_type,
        att_prior_loss_weight_anneal_speed, att_prior_grad_smooth_sigma,
        fourier_att_prior_freq_limit, fourier_att_prior_freq_limit_softness,
        att_prior_loss_only, l2_reg_loss_weight, input_grads=None,
        status=None):
    """
    Computes the loss for the model.
    Arguments:
        `model`: the model being trained
        `true_vals`: a B x T tensor, where B is the batch size and T is the
            number of output tasks, containing the true binary values
        `logit_pred_vals`: a B x T tensor containing the predicted logits
        `epoch_num`: a 0-indexed integer representing the current epoch
        `input_grads`: a B x I x D tensor, where I is the input length and D
            is the input depth; this is the gradient of the output with
            respect to the input, times the input itself; only needed when
            the attribution prior loss weight is positive
        `status`: a B-tensor, where B is the batch size; each entry is 1 if
            that example is to be treated as a positive example, 0 if
            negative, and -1 if ambiguous; only needed when the attribution
            prior loss weight is positive
    Returns a scalar Tensor containing the loss for the given batch, and a
    pair consisting of the correctness loss and the attribution prior loss.
    If the attribution prior loss is not computed at all, then 0 will be in
    its place, instead.
    """
    corr_loss = model.correctness_loss(
        true_vals, logit_pred_vals, avg_class_loss)
    final_loss = corr_loss

    if att_prior_loss_weight > 0:
        att_prior_loss = model.fourier_att_prior_loss(
            status, input_grads, fourier_att_prior_freq_limit,
            fourier_att_prior_freq_limit_softness,
            att_prior_grad_smooth_sigma)
        # att_prior_loss = model.smoothness_att_prior_loss(status, input_grads)
        # att_prior_loss = model.sparsity_att_prior_loss(status, input_grads)

        if att_prior_loss_weight_anneal_type is None:
            weight = att_prior_loss_weight
        elif att_prior_loss_weight_anneal_type == "inflate":
            exp = np.exp(-att_prior_loss_weight_anneal_speed * epoch_num)
            weight = att_prior_loss_weight * ((2 / (1 + exp)) - 1)
        elif att_prior_loss_weight_anneal_type == "deflate":
            exp = np.exp(-att_prior_loss_weight_anneal_speed * epoch_num)
            weight = att_prior_loss_weight * exp

        if att_prior_loss_only:
            final_loss = att_prior_loss
        else:
            final_loss = final_loss + (weight * att_prior_loss)
    else:
        att_prior_loss = torch.zeros(1)

    # If necessary, add the L2 penalty
    if l2_reg_loss_weight > 0:
        l2_loss = util.place_tensor(torch.tensor(0).float())
        for param in model.parameters():
            if param.requires_grad:
                l2_loss = l2_loss + torch.sum(param * param)
        final_loss = final_loss + (l2_reg_loss_weight * l2_loss)

    return final_loss, (corr_loss, att_prior_loss)
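# Hedged sketch of the two annealing schedules above: "inflate" ramps the
# attribution-prior weight from 0 toward its full value over epochs (a shifted
# sigmoid of the epoch number), while "deflate" decays it exponentially.
# Toy values only, to show the shape of each curve.
import numpy as np

def annealed_weight(base_weight, anneal_type, speed, epoch_num):
    if anneal_type is None:
        return base_weight
    exp = np.exp(-speed * epoch_num)
    if anneal_type == "inflate":
        return base_weight * ((2 / (1 + exp)) - 1)
    if anneal_type == "deflate":
        return base_weight * exp
    raise ValueError("Unknown anneal type: %s" % anneal_type)

for epoch in range(0, 10, 3):
    print(epoch,
          round(annealed_weight(1.0, "inflate", 0.5, epoch), 3),
          round(annealed_weight(1.0, "deflate", 0.5, epoch), 3))
# Epoch 0: inflate 0.0, deflate 1.0; later epochs move toward 1.0 and 0.0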