def train_model(train_loader, val_loader, test_loader, num_epochs,
                learning_rate, early_stopping, early_stop_hist_len,
                early_stop_min_delta, train_seed, _run):
    """
    Trains the network for the given training and validation data, then
    evaluates the best-validation-loss checkpoint on the test data.
    Arguments:
        `train_loader` (DataLoader): a data loader for the training data
        `val_loader` (DataLoader): a data loader for the validation data
        `test_loader` (DataLoader): a data loader for the test data
        `num_epochs` (int): maximum number of epochs to train for
        `learning_rate` (float): learning rate for the Adam optimizer
        `early_stopping` (bool): if True, stop once the validation loss
            stops improving over the history window
        `early_stop_hist_len` (int): number of past validation losses kept
            for the early-stopping check
        `early_stop_min_delta` (float): minimum improvement across the
            history window required to keep training
        `train_seed` (int): if truthy, seed for torch's RNG
        `_run`: Sacred run object used for logging metrics
    Note that all data loaders are expected to yield the 1-hot encoded
    sequences, output values, source coordinates, and source peaks.
    NOTE(review): this module defines a second `train_model` further down,
    which shadows this definition at import time — confirm which one is
    intended to be live.
    """
    run_num = _run._id
    output_dir = os.path.join(MODEL_DIR, str(run_num))

    if train_seed:
        torch.manual_seed(train_seed)

    device = torch.device("cuda") if torch.cuda.is_available() \
        else torch.device("cpu")

    model = create_model()
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if early_stopping:
        val_epoch_loss_hist = []

    best_val_epoch_loss, best_model_state = float("inf"), None

    for epoch in range(num_epochs):
        # Bug fix: `torch.cuda.is_available` without parentheses is the
        # function object itself (always truthy); call it so we only touch
        # CUDA when a device actually exists
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear GPU memory

        t_batch_losses, t_corr_losses, t_att_losses = run_epoch(
            train_loader, "train", model, epoch, optimizer=optimizer)

        # nanmean: batches whose loss came out NaN are excluded from the
        # epoch average
        train_epoch_loss = np.nanmean(t_batch_losses)
        print("Train epoch %d: average loss = %6.10f" %
              (epoch + 1, train_epoch_loss))
        _run.log_scalar("train_epoch_loss", train_epoch_loss)
        _run.log_scalar("train_batch_losses", t_batch_losses)
        _run.log_scalar("train_corr_losses", t_corr_losses)
        _run.log_scalar("train_att_losses", t_att_losses)

        v_batch_losses, v_corr_losses, v_att_losses = run_epoch(
            val_loader, "eval", model, epoch)
        val_epoch_loss = np.nanmean(v_batch_losses)
        print("Valid epoch %d: average loss = %6.10f" %
              (epoch + 1, val_epoch_loss))
        _run.log_scalar("val_epoch_loss", val_epoch_loss)
        _run.log_scalar("val_batch_losses", v_batch_losses)
        _run.log_scalar("val_corr_losses", v_corr_losses)
        _run.log_scalar("val_att_losses", v_att_losses)

        # Save trained model for the epoch
        savepath = os.path.join(output_dir,
                                "model_ckpt_epoch_%d.pt" % (epoch + 1))
        util.save_model(model, savepath)

        # Save the model state dict of the epoch with the best validation loss
        if val_epoch_loss < best_val_epoch_loss:
            best_val_epoch_loss = val_epoch_loss
            # Bug fix: state_dict() returns tensors that alias the live
            # parameters, so holding a bare reference means later epochs
            # silently overwrite the "best" snapshot; store detached copies
            best_model_state = {
                key: val.detach().clone()
                for key, val in model.state_dict().items()
            }

        # If losses are both NaN, then stop
        if np.isnan(train_epoch_loss) and np.isnan(val_epoch_loss):
            break

        # Check for early stopping
        if early_stopping:
            if len(val_epoch_loss_hist) < early_stop_hist_len + 1:
                # Not enough history yet; tack on the loss
                val_epoch_loss_hist = [val_epoch_loss] + val_epoch_loss_hist
            else:
                # Tack on the new validation loss, kicking off the old one
                val_epoch_loss_hist = \
                    [val_epoch_loss] + val_epoch_loss_hist[:-1]
                # History is newest-first, so diff() yields (older - newer);
                # the max diff is the best single-step improvement seen
                best_delta = np.max(np.diff(val_epoch_loss_hist))
                if best_delta < early_stop_min_delta:
                    break  # Not improving enough

    # Compute evaluation metrics and log them
    print("Computing test metrics:")
    # Load in the state of the epoch with the best validation loss first;
    # robustness fix: if no epoch ever produced a finite validation loss
    # (e.g. NaN on epoch 1), keep the current weights instead of crashing
    # on load_state_dict(None)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    batch_losses, corr_losses, att_losses, true_vals, pred_vals, coords, \
        input_grads, input_seqs = run_epoch(
            test_loader, "eval", model, 0, return_data=True
    )
    _run.log_scalar("test_batch_losses", batch_losses)
    _run.log_scalar("test_corr_losses", corr_losses)
    _run.log_scalar("test_att_losses", att_losses)

    neg_upsample_factor = test_loader.dataset.bins_batcher.neg_to_pos_imbalance
    metrics = binary_performance.compute_performance_metrics(
        true_vals, pred_vals, neg_upsample_factor)
    binary_performance.log_performance_metrics(metrics, "test", _run)
def train_model(loaders, trans_id, params, _run):
    """
    Trains the network in two phases and logs test metrics after each phase.
    Phase 1 ("aux"/pretraining) trains on the phase-1 loaders for
    `num_epochs_prof` epochs with `seq_mode=False`; phase 2 trains on the
    phase-2 loaders for `num_epochs` epochs with `seq_mode=True`. After
    each phase, the epoch with the best validation loss is restored and
    evaluated on every test loader.
    Arguments:
        `loaders` (dict): DataLoaders keyed by name; must contain
            "train_1"/"val_1" (phase-1 data), "train_2"/"val_2" (phase-2
            data), "test_genome", and the "test_summit_*" loaders listed
            in `test_loaders` below
        `trans_id` (str): identifier prefixed to all logged metric names
            and to the output directory name
        `params` (dict): hyperparameters; must contain "num_epochs",
            "num_epochs_prof", "learning_rate", "early_stopping",
            "early_stop_hist_len", "early_stop_min_delta", "train_seed",
            and "gpu_id", plus whatever `create_model` consumes
        `_run`: Sacred run object used for logging metrics
    Note that all data loaders are expected to yield the 1-hot encoded
    sequences, profiles, statuses, source coordinates, and source peaks.
    NOTE(review): this definition shadows the earlier `train_model` in this
    module — confirm the earlier one is intentionally dead.
    """
    num_epochs = params["num_epochs"]
    num_epochs_prof = params["num_epochs_prof"]
    learning_rate = params["learning_rate"]
    early_stopping = params["early_stopping"]
    early_stop_hist_len = params["early_stop_hist_len"]
    early_stop_min_delta = params["early_stop_min_delta"]
    train_seed = params["train_seed"]

    train_loader_1 = loaders["train_1"]
    val_loader_1 = loaders["val_1"]
    train_loader_2 = loaders["train_2"]
    val_loader_2 = loaders["val_2"]
    # NOTE(review): currently unused, but kept so a missing key still fails
    # loudly here rather than later
    test_genome_loader = loaders["test_genome"]
    test_loaders = [
        (loaders["test_summit_union"], "summit_union"),
        (loaders["test_summit_to_sig"], "summit_to_sig"),
        (loaders["test_summit_from_sig"], "summit_from_sig"),
        (loaders["test_summit_to_sig_from_sig"], "summit_to_sig_from_sig"),
        (loaders["test_summit_to_insig_from_sig"], "summit_to_insig_from_sig"),
        (loaders["test_summit_to_sig_from_insig"], "summit_to_sig_from_insig"),
    ]

    run_num = _run._id
    output_dir = os.path.join(MODEL_DIR, f"{trans_id}_{run_num}")
    os.makedirs(output_dir, exist_ok=True)

    if train_seed:
        torch.manual_seed(train_seed)

    device = torch.device(f"cuda:{params['gpu_id']}") if torch.cuda.is_available() \
        else torch.device("cpu")

    model = create_model(**params)
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if early_stopping:
        val_epoch_loss_hist = []

    best_val_epoch_loss = np.inf
    best_model_state = None
    best_model_epoch = None

    # Phase 1: auxiliary/pretraining epochs (seq_mode=False)
    for epoch in range(num_epochs_prof):
        # Bug fix: `torch.cuda.is_available` without parentheses is the
        # function object itself (always truthy); call it so we only touch
        # CUDA when a device actually exists
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear GPU memory

        t_batch_losses, t_corr_losses, t_att_losses, t_prof_losses, \
            t_count_losses = run_epoch(
                train_loader_1, "train", model, epoch, optimizer=optimizer, seq_mode=False
        )
        # nanmean: batches whose loss came out NaN are excluded
        train_epoch_loss = np.nanmean(t_batch_losses)
        print("Pre-Train epoch %d: average loss = %6.10f" %
              (epoch + 1, train_epoch_loss))
        _run.log_scalar(f"{trans_id}_aux_train_epoch_loss", train_epoch_loss)
        _run.log_scalar(f"{trans_id}_aux_train_batch_losses", t_batch_losses)
        _run.log_scalar(f"{trans_id}_aux_train_corr_losses", t_corr_losses)
        _run.log_scalar(f"{trans_id}_aux_train_att_losses", t_att_losses)
        _run.log_scalar(f"{trans_id}_aux_train_prof_corr_losses",
                        t_prof_losses)
        _run.log_scalar(f"{trans_id}_aux_train_count_corr_losses",
                        t_count_losses)

        v_batch_losses, v_corr_losses, v_att_losses, v_prof_losses, \
            v_count_losses = run_epoch(
                val_loader_1, "eval", model, epoch
        )
        val_epoch_loss = np.nanmean(v_batch_losses)
        print("Pre-Valid epoch %d: average loss = %6.10f" %
              (epoch + 1, val_epoch_loss))
        _run.log_scalar(f"{trans_id}_aux_val_epoch_loss", val_epoch_loss)
        _run.log_scalar(f"{trans_id}_aux_val_batch_losses", v_batch_losses)
        _run.log_scalar(f"{trans_id}_aux_val_corr_losses", v_corr_losses)
        _run.log_scalar(f"{trans_id}_aux_val_att_losses", v_att_losses)
        _run.log_scalar(f"{trans_id}_aux_val_prof_corr_losses", v_prof_losses)
        _run.log_scalar(f"{trans_id}_aux_val_count_corr_losses",
                        v_count_losses)

        # Save trained model for the epoch
        savepath = os.path.join(output_dir,
                                "model_aux_ckpt_epoch_%d.pt" % (epoch + 1))
        util.save_model(model, savepath)

        # Save the model state dict of the epoch with the best validation loss
        if val_epoch_loss < best_val_epoch_loss:
            best_val_epoch_loss = val_epoch_loss
            # Bug fix: state_dict() returns tensors that alias the live
            # parameters, so holding a bare reference means later epochs
            # silently overwrite the "best" snapshot; store detached copies
            best_model_state = {
                key: val.detach().clone()
                for key, val in model.state_dict().items()
            }
            best_model_epoch = epoch

        # If losses are both NaN, then stop
        if np.isnan(train_epoch_loss) and np.isnan(val_epoch_loss):
            break

        # Check for early stopping
        if early_stopping:
            if len(val_epoch_loss_hist) < early_stop_hist_len + 1:
                # Not enough history yet; tack on the loss
                val_epoch_loss_hist = [val_epoch_loss] + val_epoch_loss_hist
            else:
                # Tack on the new validation loss, kicking off the old one
                val_epoch_loss_hist = \
                    [val_epoch_loss] + val_epoch_loss_hist[:-1]
                # History is newest-first, so diff() yields (older - newer);
                # the max diff is the best single-step improvement seen
                best_delta = np.max(np.diff(val_epoch_loss_hist))
                if best_delta < early_stop_min_delta:
                    break  # Not improving enough

    _run.log_scalar(f"{trans_id}_aux_best_epoch", best_model_epoch)

    # Compute evaluation metrics for phase 1 and log them
    for data_loader, prefix in test_loaders:
        print("Computing pretraining test metrics, %s:" % prefix)
        # Load in the state of the epoch with the best validation loss first;
        # robustness fix: if no epoch ever produced a finite validation loss,
        # keep the current weights instead of crashing on load_state_dict(None)
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        batch_losses, corr_losses, att_losses, prof_losses, count_losses, \
            true_profs, log_pred_profs, true_counts, log_pred_counts, coords, \
            input_grads, input_seqs, true_profs_trans, true_counts_trans = run_epoch(
                data_loader, "eval", model, 0, return_data=True
        )
        _run.log_scalar(f"{trans_id}_aux_test_{prefix}_batch_losses",
                        batch_losses)
        _run.log_scalar(f"{trans_id}_aux_test_{prefix}_corr_losses",
                        corr_losses)
        _run.log_scalar(f"{trans_id}_aux_test_{prefix}_att_losses", att_losses)
        _run.log_scalar(f"{trans_id}_aux_test_{prefix}_prof_corr_losses",
                        prof_losses)
        _run.log_scalar(f"{trans_id}_aux_test_{prefix}_count_corr_losses",
                        count_losses)

        metrics = profile_performance.compute_performance_metrics(
            true_profs, log_pred_profs, true_counts, log_pred_counts)
        # Only the canonical test set gets its metrics pickled to disk
        if prefix == "summit_union":
            metrics_savepath = os.path.join(output_dir, "metrics_aux.pickle")
        else:
            metrics_savepath = None

        profile_performance.log_performance_metrics(
            metrics, f"{trans_id}_aux_{prefix}", _run, \
            savepath=metrics_savepath, counts=(true_counts, true_counts_trans), coords=coords
        )

    # Reset best-model tracking (and early-stopping history) for phase 2
    if early_stopping:
        val_epoch_loss_hist = []

    best_val_epoch_loss = np.inf
    best_model_state = None
    best_model_epoch = None

    # Phase 2: main training epochs (seq_mode=True)
    for epoch in range(num_epochs):
        # Bug fix: call is_available() rather than testing the function object
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear GPU memory

        t_batch_losses, t_corr_losses, t_att_losses, t_prof_losses, \
            t_count_losses = run_epoch(
                train_loader_2, "train", model, epoch, optimizer=optimizer, seq_mode=True
        )
        train_epoch_loss = np.nanmean(t_batch_losses)
        print("Train epoch %d: average loss = %6.10f" %
              (epoch + 1, train_epoch_loss))
        _run.log_scalar(f"{trans_id}_train_epoch_loss", train_epoch_loss)
        _run.log_scalar(f"{trans_id}_train_batch_losses", t_batch_losses)
        _run.log_scalar(f"{trans_id}_train_corr_losses", t_corr_losses)
        _run.log_scalar(f"{trans_id}_train_att_losses", t_att_losses)
        _run.log_scalar(f"{trans_id}_train_prof_corr_losses", t_prof_losses)
        _run.log_scalar(f"{trans_id}_train_count_corr_losses", t_count_losses)

        v_batch_losses, v_corr_losses, v_att_losses, v_prof_losses, \
            v_count_losses = run_epoch(
                val_loader_2, "eval", model, epoch
        )
        val_epoch_loss = np.nanmean(v_batch_losses)
        print("Valid epoch %d: average loss = %6.10f" %
              (epoch + 1, val_epoch_loss))
        _run.log_scalar(f"{trans_id}_val_epoch_loss", val_epoch_loss)
        _run.log_scalar(f"{trans_id}_val_batch_losses", v_batch_losses)
        _run.log_scalar(f"{trans_id}_val_corr_losses", v_corr_losses)
        _run.log_scalar(f"{trans_id}_val_att_losses", v_att_losses)
        _run.log_scalar(f"{trans_id}_val_prof_corr_losses", v_prof_losses)
        _run.log_scalar(f"{trans_id}_val_count_corr_losses", v_count_losses)

        # Save trained model for the epoch
        savepath = os.path.join(output_dir,
                                "model_ckpt_epoch_%d.pt" % (epoch + 1))
        util.save_model(model, savepath)

        # Save the model state dict of the epoch with the best validation loss
        if val_epoch_loss < best_val_epoch_loss:
            best_val_epoch_loss = val_epoch_loss
            # Bug fix: snapshot detached copies (see phase-1 note above)
            best_model_state = {
                key: val.detach().clone()
                for key, val in model.state_dict().items()
            }
            best_model_epoch = epoch

        # If losses are both NaN, then stop
        if np.isnan(train_epoch_loss) and np.isnan(val_epoch_loss):
            break

        # Check for early stopping
        if early_stopping:
            if len(val_epoch_loss_hist) < early_stop_hist_len + 1:
                # Not enough history yet; tack on the loss
                val_epoch_loss_hist = [val_epoch_loss] + val_epoch_loss_hist
            else:
                # Tack on the new validation loss, kicking off the old one
                val_epoch_loss_hist = \
                    [val_epoch_loss] + val_epoch_loss_hist[:-1]
                best_delta = np.max(np.diff(val_epoch_loss_hist))
                if best_delta < early_stop_min_delta:
                    break  # Not improving enough

    _run.log_scalar(f"{trans_id}_best_epoch", best_model_epoch)

    # Compute evaluation metrics for phase 2 and log them
    for data_loader, prefix in test_loaders:
        print("Computing test metrics, %s:" % prefix)
        # Load in the state of the epoch with the best validation loss first
        # (robustness fix: skip when no finite-loss epoch was ever recorded)
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        batch_losses, corr_losses, att_losses, prof_losses, count_losses, \
            true_profs, log_pred_profs, true_counts, log_pred_counts, coords, \
            input_grads, input_seqs, true_profs_trans, true_counts_trans = run_epoch(
                data_loader, "eval", model, 0, return_data=True
        )
        _run.log_scalar(f"{trans_id}_test_{prefix}_batch_losses", batch_losses)
        _run.log_scalar(f"{trans_id}_test_{prefix}_corr_losses", corr_losses)
        _run.log_scalar(f"{trans_id}_test_{prefix}_att_losses", att_losses)
        _run.log_scalar(f"{trans_id}_test_{prefix}_prof_corr_losses",
                        prof_losses)
        _run.log_scalar(f"{trans_id}_test_{prefix}_count_corr_losses",
                        count_losses)

        metrics = profile_performance.compute_performance_metrics(
            true_profs, log_pred_profs, true_counts, log_pred_counts)

        if prefix == "summit_union":
            metrics_savepath = os.path.join(output_dir, "metrics.pickle")
        else:
            metrics_savepath = None

        profile_performance.log_performance_metrics(
            metrics, f"{trans_id}_{prefix}", _run, \
            savepath=metrics_savepath, counts=(true_counts, true_counts_trans), coords=coords
        )
# Example #3
# 0
 def save_model(self, dir_path, identifier):
     """
     Persists this agent by delegating to the module-level `save_model`
     helper, handing it the actor, critic, and replay buffer along with
     the destination directory and checkpoint identifier.
     """
     components = (self.actor, self.critic, self.replay_buffer)
     save_model(dir_path, *components, identifier=identifier)