def get_metrics(profs_preds_logits, counts_preds, num_runs):
    """
    Compares the predictions of several runs against a reference run.

    Index 0 along axis 1 of both inputs is treated as the reference; indices
    1 through `num_runs` (inclusive) are each scored against it with
    `profile_performance.compute_performance_metrics`.
    Arguments:
        `profs_preds_logits`: array of profile logits; it is normalized with
            a log-softmax along axis 3 and indexed by run along axis 1
            (presumably batch x run x task x length x strand -- TODO confirm)
        `counts_preds`: array of predicted counts, indexed by run along
            axis 1; it must have one dimension fewer than the logits
        `num_runs`: number of non-reference runs to score
    Returns a dictionary mapping each metric name to a list with one entry
    per scored run, in run order. In addition to whatever
    `compute_performance_metrics` returns, the key "counts_diff" holds the
    run's counts minus the reference counts. The dictionary is empty when
    `num_runs` is 0.
    """
    # Log-softmax over axis 3 turns the logits into log probabilities
    log_normalizer = scipy.special.logsumexp(
        profs_preds_logits, axis=3, keepdims=True
    )
    log_probs = profs_preds_logits - log_normalizer

    # Reference run: scale its probabilities by its counts, so it can stand
    # in as the "true" profile when scoring the other runs
    ref_log_probs = log_probs[:, 0]
    ref_counts = counts_preds[:, 0]
    ref_counts_expanded = np.broadcast_to(
        ref_counts[:, :, np.newaxis, :], ref_log_probs.shape
    )
    ref_profs = np.exp(ref_log_probs) * ref_counts_expanded

    collected = {}
    for run_index in range(1, num_runs + 1):
        run_log_probs = log_probs[:, run_index]
        run_counts = counts_preds[:, run_index]

        run_metrics = profile_performance.compute_performance_metrics(
            ref_profs, run_log_probs, ref_counts, run_counts,
            print_updates=False, calc_counts=False
        )
        run_metrics["counts_diff"] = run_counts - ref_counts

        for name, value in run_metrics.items():
            collected.setdefault(name, []).append(value)
    return collected
def train_model(loaders, trans_id, params, _run):
    """
    Trains the network in two consecutive phases that share one model and
    one optimizer: a pre-training phase (`seq_mode=False`, logged under
    `<trans_id>_aux_...`) followed by the main phase (`seq_mode=True`, logged
    under `<trans_id>_...`). After each phase, the weights of the epoch with
    the best validation loss are restored and performance metrics are
    computed on every test loader.
    Arguments:
        `loaders` (dict): maps names to data loaders; the keys read are
            "train_1" and "val_1" (pre-training phase), "train_2" and
            "val_2" (main phase), and the test sets "test_summit_union",
            "test_summit_to_sig", "test_summit_from_sig",
            "test_summit_to_sig_from_sig", "test_summit_to_insig_from_sig",
            and "test_summit_to_sig_from_insig"
        `trans_id`: identifier used as the prefix of every logged scalar and
            as part of the output directory name
        `params` (dict): configuration; the keys read are "num_epochs",
            "num_epochs_prof", "learning_rate", "early_stopping",
            "early_stop_hist_len", "early_stop_min_delta", "train_seed", and
            "gpu_id"; the whole dictionary is also forwarded to
            `create_model` as keyword arguments
        `_run`: run object providing `_id` and `log_scalar`
    Note that all data loaders are expected to yield the 1-hot encoded
    sequences, profiles, statuses, source coordinates, and source peaks.
    Checkpoints for every epoch (and the "summit_union" metrics of each
    phase) are written under `MODEL_DIR/<trans_id>_<run id>`.
    """
    # Function-scope import: needed to snapshot the best model weights
    import copy

    num_epochs = params["num_epochs"]
    num_epochs_prof = params["num_epochs_prof"]
    learning_rate = params["learning_rate"]
    early_stopping = params["early_stopping"]
    early_stop_hist_len = params["early_stop_hist_len"]
    early_stop_min_delta = params["early_stop_min_delta"]
    train_seed = params["train_seed"]

    # Look everything up front, so a missing loader fails before training
    train_loader_1 = loaders["train_1"]
    val_loader_1 = loaders["val_1"]
    train_loader_2 = loaders["train_2"]
    val_loader_2 = loaders["val_2"]
    test_loaders = [
        (loaders["test_" + name], name) for name in (
            "summit_union",
            "summit_to_sig",
            "summit_from_sig",
            "summit_to_sig_from_sig",
            "summit_to_insig_from_sig",
            "summit_to_sig_from_insig",
        )
    ]

    run_num = _run._id
    output_dir = os.path.join(MODEL_DIR, f"{trans_id}_{run_num}")
    os.makedirs(output_dir, exist_ok=True)

    # A falsy seed (`None` or 0) leaves the random generator unseeded
    if train_seed:
        torch.manual_seed(train_seed)

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        device = torch.device(f"cuda:{params['gpu_id']}")
    else:
        device = torch.device("cpu")

    model = create_model(**params)
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    def log_losses(tag, losses):
        # Logs the five per-batch loss lists of one `run_epoch` call
        batch_losses, corr_losses, att_losses, prof_losses, count_losses = \
            losses
        _run.log_scalar(f"{tag}_batch_losses", batch_losses)
        _run.log_scalar(f"{tag}_corr_losses", corr_losses)
        _run.log_scalar(f"{tag}_att_losses", att_losses)
        _run.log_scalar(f"{tag}_prof_corr_losses", prof_losses)
        _run.log_scalar(f"{tag}_count_corr_losses", count_losses)

    def run_phase(
        train_loader, val_loader, phase_epochs, seq_mode, tag, epoch_label,
        test_desc, ckpt_template, metrics_filename
    ):
        # Trains for up to `phase_epochs` epochs, then evaluates the best
        # epoch's weights on every test loader
        val_epoch_loss_hist = []
        best_val_epoch_loss = np.inf
        best_model_state = None
        best_model_epoch = None

        for epoch in range(phase_epochs):
            if cuda_available:
                torch.cuda.empty_cache()  # Clear GPU memory

            t_batch_losses, t_corr_losses, t_att_losses, t_prof_losses, \
                t_count_losses = run_epoch(
                    train_loader, "train", model, epoch,
                    optimizer=optimizer, seq_mode=seq_mode
            )
            train_epoch_loss = np.nanmean(t_batch_losses)
            print(
                "%sTrain epoch %d: average loss = %6.10f" % (
                    epoch_label, epoch + 1, train_epoch_loss
                )
            )
            _run.log_scalar(f"{tag}_train_epoch_loss", train_epoch_loss)
            log_losses(
                f"{tag}_train",
                (
                    t_batch_losses, t_corr_losses, t_att_losses,
                    t_prof_losses, t_count_losses
                )
            )

            v_batch_losses, v_corr_losses, v_att_losses, v_prof_losses, \
                v_count_losses = run_epoch(
                    val_loader, "eval", model, epoch
            )
            val_epoch_loss = np.nanmean(v_batch_losses)
            print(
                "%sValid epoch %d: average loss = %6.10f" % (
                    epoch_label, epoch + 1, val_epoch_loss
                )
            )
            _run.log_scalar(f"{tag}_val_epoch_loss", val_epoch_loss)
            log_losses(
                f"{tag}_val",
                (
                    v_batch_losses, v_corr_losses, v_att_losses,
                    v_prof_losses, v_count_losses
                )
            )

            # Save trained model for the epoch
            savepath = os.path.join(output_dir, ckpt_template % (epoch + 1))
            util.save_model(model, savepath)

            # Save the model state dict of the epoch with the best validation
            # loss. `state_dict()` returns references to the live parameter
            # tensors, so it must be copied; otherwise later epochs would
            # silently overwrite the "best" weights.
            if val_epoch_loss < best_val_epoch_loss:
                best_val_epoch_loss = val_epoch_loss
                best_model_state = copy.deepcopy(model.state_dict())
                best_model_epoch = epoch

            # If losses are both NaN, then stop
            if np.isnan(train_epoch_loss) and np.isnan(val_epoch_loss):
                break

            # Check for early stopping
            if early_stopping:
                if len(val_epoch_loss_hist) < early_stop_hist_len + 1:
                    # Not enough history yet; tack on the loss
                    val_epoch_loss_hist = \
                        [val_epoch_loss] + val_epoch_loss_hist
                else:
                    # Tack on the new validation loss, kicking off the old one
                    val_epoch_loss_hist = \
                        [val_epoch_loss] + val_epoch_loss_hist[:-1]
                    # History is newest-first, so each difference is the
                    # decrease in loss achieved by one epoch
                    best_delta = np.max(np.diff(val_epoch_loss_hist))
                    if best_delta < early_stop_min_delta:
                        break  # Not improving enough

        _run.log_scalar(f"{tag}_best_epoch", best_model_epoch)

        # Compute evaluation metrics and log them
        for data_loader, prefix in test_loaders:
            print("Computing %s metrics, %s:" % (test_desc, prefix))

            # Load in the state of the epoch with the best validation loss
            # first. No state exists if the phase ran no epochs or never saw
            # a non-NaN validation loss; the current weights are used then.
            if best_model_state is not None:
                model.load_state_dict(best_model_state)

            batch_losses, corr_losses, att_losses, prof_losses, \
                count_losses, true_profs, log_pred_profs, true_counts, \
                log_pred_counts, coords, input_grads, input_seqs, \
                true_profs_trans, true_counts_trans = run_epoch(
                    data_loader, "eval", model, 0, return_data=True
            )
            log_losses(
                f"{tag}_test_{prefix}",
                (
                    batch_losses, corr_losses, att_losses, prof_losses,
                    count_losses
                )
            )

            metrics = profile_performance.compute_performance_metrics(
                true_profs, log_pred_profs, true_counts, log_pred_counts
            )
            # Only the union of summits has its metrics written to disk
            if prefix == "summit_union":
                metrics_savepath = os.path.join(output_dir, metrics_filename)
            else:
                metrics_savepath = None
            profile_performance.log_performance_metrics(
                metrics, f"{tag}_{prefix}", _run,
                savepath=metrics_savepath,
                counts=(true_counts, true_counts_trans), coords=coords
            )

    # Phase 1: pre-training
    run_phase(
        train_loader_1, val_loader_1, num_epochs_prof, False,
        f"{trans_id}_aux", "Pre-", "pretraining test",
        "model_aux_ckpt_epoch_%d.pt", "metrics_aux.pickle"
    )

    # Phase 2: main training, continuing from the model left by phase 1
    run_phase(
        train_loader_2, val_loader_2, num_epochs, True,
        f"{trans_id}", "", "test",
        "model_ckpt_epoch_%d.pt", "metrics.pickle"
    )
def test_all_metrics_on_different_predictions():
    """
    Runs all performance metrics on three sets of synthetic predictions
    ("perfect", "good", and "bad") made against randomly generated true
    profiles, logging the results through a `FakeLogger`.
    """
    np.random.seed(20191110)

    batch_size, num_tasks, prof_len = 50, 2, 1000
    positions = np.arange(prof_len)
    midpoint = prof_len / 2

    def random_bump(direction):
        # Gaussian bump with a random center (offset from the midpoint in the
        # given direction) and a random width. The draw order (offset, then
        # width) is part of what keeps the seeded data reproducible.
        center = midpoint + (direction * np.random.randint(50))
        sigma = np.random.random() * 5
        return np.exp(-((positions - center) ** 2) / (2 * (sigma ** 2)))

    # Make some random true profiles that have some "peaks"
    true_profs = np.empty((batch_size, num_tasks, prof_len, 2))
    for example, task in np.ndindex(batch_size, num_tasks):
        pos_prof = random_bump(-1)
        neg_prof = random_bump(1)
        total = np.random.randint(50, 500)
        true_profs[example, task, :, 0] = neg_prof / np.sum(neg_prof) * total
        true_profs[example, task, :, 1] = pos_prof / np.sum(pos_prof) * total
    true_profs = np.nan_to_num(true_profs)  # NaN to 0
    true_counts = np.sum(true_profs, axis=2)

    _run = FakeLogger()
    epsilon = 1e-50  # The smaller this is, the better Spearman correlation is

    def normalize(profs):
        # Turns profiles into probabilities along the profile-length axis
        return profs / np.sum(profs, axis=2, keepdims=True)

    def score(label, pred_prof_probs, pred_counts):
        # Computes and logs the metrics for one set of predictions
        log_pred_profs = np.log(pred_prof_probs + epsilon)
        log_pred_counts = np.log(pred_counts + 1)
        metrics = profile_performance.compute_performance_metrics(
            true_profs, log_pred_profs, true_counts, log_pred_counts
        )
        profile_performance.log_performance_metrics(metrics, label, _run)

    # Make some "perfect" predicted profiles, which are identical to truth
    print("Testing all metrics on some perfect predictions...")
    perfect_probs = np.nan_to_num(normalize(true_profs))
    score("Perfect", perfect_probs, true_counts)

    # Make some "good" predicted profiles by adding Gaussian noise to true
    print("Testing all metrics on some good predictions...")
    good_profs = np.abs(true_profs + (np.random.randn(*true_profs.shape) * 3))
    good_probs = normalize(good_profs)
    good_counts = np.abs(
        true_counts + (np.random.randn(*true_counts.shape) * 10)
    )
    score("Good", good_probs, good_counts)

    # Make some "bad" predicted profiles which are just Gaussian noise
    print("Testing all metrics on some bad predictions...")
    bad_profs = np.abs(np.random.randn(*true_profs.shape) * 3)
    bad_probs = normalize(bad_profs)
    bad_counts = np.abs(np.random.randint(200, size=true_counts.shape))
    score("Bad", bad_probs, bad_counts)

    print(
        "Warning: note that profile Spearman correlation is not so high, "
        "even in the perfect case. This is because while the true profile "
        "has 0 probability (or close to it) in most places, the predicted "
        "profile will never have exactly 0 probability due to the logits. "
    )
def train_model(
    train_loader, val_loader, test_summit_loader, test_peak_loader,
    test_genome_loader, num_epochs, learning_rate, early_stopping,
    early_stop_hist_len, early_stop_min_delta, train_seed, _run
):
    """
    Trains the network for the given training and validation data, then
    restores the weights of the epoch with the best validation loss and logs
    performance metrics on the summit-centered test data.
    Arguments:
        `train_loader` (DataLoader): a data loader for the training data
        `val_loader` (DataLoader): a data loader for the validation data
        `test_summit_loader` (DataLoader): a data loader for the test data,
            with coordinates centered at summits
        `test_peak_loader` (DataLoader): a data loader for the test data,
            with coordinates tiled across peaks; currently not evaluated
        `test_genome_loader` (DataLoader): a data loader for the test data,
            with summit-centered coordinates augmented with sampled
            negatives; currently not evaluated
        `num_epochs`: maximum number of epochs to train for
        `learning_rate`: learning rate given to the Adam optimizer
        `early_stopping`: if truthy, stop once validation loss stalls
        `early_stop_hist_len`: the early-stopping window holds this many
            validation losses plus one
        `early_stop_min_delta`: once the window is full, training stops when
            the largest epoch-to-epoch decrease in validation loss within
            the window is below this value
        `train_seed`: seed for `torch.manual_seed`; a falsy value (`None` or
            0) leaves the random generator unseeded
        `_run`: run object providing `_id` and `log_scalar`
    Note that all data loaders are expected to yield the 1-hot encoded
    sequences, profiles, statuses, source coordinates, and source peaks.
    A checkpoint is written for every epoch under `MODEL_DIR/<run id>`.
    """
    # Function-scope import: needed to snapshot the best model weights
    import copy

    run_num = _run._id
    output_dir = os.path.join(MODEL_DIR, str(run_num))
    # Make sure the checkpoint directory exists before the first save
    os.makedirs(output_dir, exist_ok=True)

    if train_seed:
        torch.manual_seed(train_seed)

    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda") if cuda_available \
        else torch.device("cpu")

    model = create_model()
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if early_stopping:
        val_epoch_loss_hist = []

    best_val_epoch_loss, best_model_state = float("inf"), None

    for epoch in range(num_epochs):
        if cuda_available:
            torch.cuda.empty_cache()  # Clear GPU memory

        t_batch_losses, t_corr_losses, t_att_losses, t_prof_losses, \
            t_count_losses = run_epoch(
                train_loader, "train", model, epoch, optimizer=optimizer
        )
        train_epoch_loss = np.nanmean(t_batch_losses)
        print(
            "Train epoch %d: average loss = %6.10f" % (
                epoch + 1, train_epoch_loss
            )
        )
        _run.log_scalar("train_epoch_loss", train_epoch_loss)
        _run.log_scalar("train_batch_losses", t_batch_losses)
        _run.log_scalar("train_corr_losses", t_corr_losses)
        _run.log_scalar("train_att_losses", t_att_losses)
        _run.log_scalar("train_prof_corr_losses", t_prof_losses)
        _run.log_scalar("train_count_corr_losses", t_count_losses)

        v_batch_losses, v_corr_losses, v_att_losses, v_prof_losses, \
            v_count_losses = run_epoch(
                val_loader, "eval", model, epoch
        )
        val_epoch_loss = np.nanmean(v_batch_losses)
        print(
            "Valid epoch %d: average loss = %6.10f" % (
                epoch + 1, val_epoch_loss
            )
        )
        _run.log_scalar("val_epoch_loss", val_epoch_loss)
        _run.log_scalar("val_batch_losses", v_batch_losses)
        _run.log_scalar("val_corr_losses", v_corr_losses)
        _run.log_scalar("val_att_losses", v_att_losses)
        _run.log_scalar("val_prof_corr_losses", v_prof_losses)
        _run.log_scalar("val_count_corr_losses", v_count_losses)

        # Save trained model for the epoch
        savepath = os.path.join(
            output_dir, "model_ckpt_epoch_%d.pt" % (epoch + 1)
        )
        util.save_model(model, savepath)

        # Save the model state dict of the epoch with the best validation
        # loss. `state_dict()` returns references to the live parameter
        # tensors, so it must be copied; otherwise later epochs would
        # silently overwrite the "best" weights.
        if val_epoch_loss < best_val_epoch_loss:
            best_val_epoch_loss = val_epoch_loss
            best_model_state = copy.deepcopy(model.state_dict())

        # If losses are both NaN, then stop
        if np.isnan(train_epoch_loss) and np.isnan(val_epoch_loss):
            break

        # Check for early stopping
        if early_stopping:
            if len(val_epoch_loss_hist) < early_stop_hist_len + 1:
                # Not enough history yet; tack on the loss
                val_epoch_loss_hist = [val_epoch_loss] + val_epoch_loss_hist
            else:
                # Tack on the new validation loss, kicking off the old one
                val_epoch_loss_hist = \
                    [val_epoch_loss] + val_epoch_loss_hist[:-1]
                # History is newest-first, so each difference is the
                # decrease in loss achieved by one epoch
                best_delta = np.max(np.diff(val_epoch_loss_hist))
                if best_delta < early_stop_min_delta:
                    break  # Not improving enough

    # Compute evaluation metrics and log them; only the summit-centered test
    # set is evaluated (the peak and genome-wide loaders are left out)
    for data_loader, prefix in [
        (test_summit_loader, "summit"),
    ]:
        print("Computing test metrics, %s:" % prefix)

        # Load in the state of the epoch with the best validation loss first.
        # No state exists if no epoch ran or no validation loss was ever
        # non-NaN; the current weights are evaluated in that case.
        if best_model_state is not None:
            model.load_state_dict(best_model_state)

        batch_losses, corr_losses, att_losses, prof_losses, count_losses, \
            true_profs, log_pred_profs, true_counts, log_pred_counts, \
            coords, input_grads, input_seqs = run_epoch(
                data_loader, "eval", model, 0, return_data=True
        )
        _run.log_scalar("test_%s_batch_losses" % prefix, batch_losses)
        _run.log_scalar("test_%s_corr_losses" % prefix, corr_losses)
        _run.log_scalar("test_%s_att_losses" % prefix, att_losses)
        _run.log_scalar("test_%s_prof_corr_losses" % prefix, prof_losses)
        _run.log_scalar("test_%s_count_corr_losses" % prefix, count_losses)

        metrics = profile_performance.compute_performance_metrics(
            true_profs, log_pred_profs, true_counts, log_pred_counts
        )
        profile_performance.log_performance_metrics(metrics, prefix, _run)