Example #1
        def predict_func(input_seq_batch):
            # Return the set of outputs for the batch of input sequences, all
            # using the same set of control profiles, if needed
            num_in_batch = len(input_seq_batch)
            if cont_profs is not None:
                cont_profs_batch = np.stack([cont_profs[seq_index]] *
                                            num_in_batch,
                                            axis=0)
            else:
                cont_profs_batch = None  # No controls

            output_vals = np.empty(num_in_batch)
            num_batches = int(np.ceil(num_in_batch / batch_size))
            for i in range(num_batches):
                batch_slice = slice(i * batch_size, (i + 1) * batch_size)
                input_slice = place_tensor(torch.tensor(
                    input_seq_batch[batch_slice])).float()
                if cont_profs_batch is not None:
                    cont_slice = place_tensor(torch.tensor(
                        cont_profs_batch[batch_slice])).float()
                else:
                    cont_slice = None  # Call the model without controls
                logit_pred_profs, _ = model(input_slice, cont_slice)
                logit_pred_profs = logit_pred_profs.detach().cpu().numpy()

                if task_index is None:
                    output_vals[batch_slice] = np.sum(logit_pred_profs,
                                                      axis=(1, 2, 3))
                else:
                    output_vals[batch_slice] = np.sum(
                        logit_pred_profs[:, task_index], axis=(1, 2))

            return output_vals
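A quick toy check (not part of the original code) of the aggregation used in `predict_func` above: summing the B x T x O x S logits over axes (1, 2, 3) yields one scalar per input sequence, while selecting a single task and summing over axes (1, 2) does the same for that task alone.

import numpy as np

# Hypothetical shapes: batch of 2, 3 tasks, profile length 5, 2 strands
logit_pred_profs = np.random.randn(2, 3, 5, 2)
print(np.sum(logit_pred_profs, axis=(1, 2, 3)).shape)     # (2,): all tasks
print(np.sum(logit_pred_profs[:, 0], axis=(1, 2)).shape)  # (2,): one task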
Example #2
def run_model(model, seqs, profs_ctrls, fps, gpu_id, profs_trans=None):
    num_runs = seqs.shape[1]
    profs_preds_shape = (profs_ctrls.shape[0], seqs.shape[1],
                         profs_ctrls.shape[1], profs_ctrls.shape[2],
                         profs_ctrls.shape[3])
    profs_preds_logits = np.empty(profs_preds_shape)
    counts_preds_shape = (profs_ctrls.shape[0], seqs.shape[1],
                          profs_ctrls.shape[1], profs_ctrls.shape[3])
    counts_preds = np.empty(counts_preds_shape)

    profs_ctrls_b = place_tensor(torch.tensor(profs_ctrls),
                                 index=gpu_id).float()
    if profs_trans is not None:
        profs_trans_b = place_tensor(torch.tensor(profs_trans),
                                     index=gpu_id).float()

    for run in range(num_runs):
        seqs_b = place_tensor(torch.tensor(seqs[:, run]), index=gpu_id).float()
        if profs_trans is not None:
            profs_preds_logits_b, counts_preds_b = model(
                seqs_b, profs_ctrls_b, profs_trans_b)
        else:
            profs_preds_logits_b, counts_preds_b = model(seqs_b, profs_ctrls_b)

        profs_preds_logits_b = profs_preds_logits_b.detach().cpu().numpy()
        counts_preds_b = counts_preds_b.detach().cpu().numpy()
        profs_preds_logits[:, run] = profs_preds_logits_b
        counts_preds[:, run] = counts_preds_b

    return profs_preds_logits, counts_preds
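For orientation, a hypothetical call to `run_model`; the shapes below are illustrative only and assume a compatible `model` and `place_tensor` are in scope. `seqs` is indexed as B x R x I x 4 (R runs per sequence) and `profs_ctrls` as B x T x O x S, so the outputs are B x R x T x O x S and B x R x T x S.

import numpy as np

seqs = np.random.randint(2, size=(8, 3, 1346, 4)).astype(float)  # B x R x I x 4
profs_ctrls = np.random.rand(8, 4, 1000, 2)                      # B x T x O x S
# profs_preds_logits, counts_preds = run_model(model, seqs, profs_ctrls,
#                                              fps=None, gpu_id=None)
# profs_preds_logits.shape == (8, 3, 4, 1000, 2)
# counts_preds.shape == (8, 3, 4, 2)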
Example #3
 def explain_fn(input_seqs,
                cont_profs=None,
                batch_size=128,
                hide_shap_output=False):
     """
     Given input sequences and control profiles, returns hypothetical scores
     for the input sequences.
     Arguments:
         `input_seqs`: a B x I x 4 array
         `cont_profs`: a B x (T or 1) x O x S array, or None
         `batch_size`: batch size for computation
         `hide_shap_output`: if True, do not show any warnings from DeepSHAP
     Returns a B x I x 4 array containing hypothetical importance scores for
     each of the B input sequences.
     """
     input_seqs_t = place_tensor(torch.tensor(input_seqs)).float()
     try:
         if hide_shap_output:
             hide_stdout()
         if cont_profs is None:
             return explainer.shap_values([input_seqs_t],
                                          progress_message=None)[0]
         else:
             cont_profs_t = place_tensor(torch.tensor(cont_profs)).float()
             return explainer.shap_values([input_seqs_t, cont_profs_t],
                                          progress_message=None)[0]
     except Exception as e:
         raise e
     finally:
         show_stdout()
Example #4
def main():
    global model, prof, count
    device = torch.device("cuda") if torch.cuda.is_available() \
        else torch.device("cpu")

    model = create_model()

    model = model.to(device)

    np.random.seed(20191013)
    x = np.random.randint(2, size=[10, 4, 1346])
    y = (np.random.randint(5, size=[10, 3, 2, 1000]),
         np.random.randint(5, size=[10, 3, 2, 1000]))

    # dataset = TestDataset([x], [y])
    # data_loader = torch.utils.data.DataLoader(
    #     dataset, batch_size=None, collate_fn=lambda x: x
    # )

    torch.autograd.set_detect_anomaly(True)

    input_seq = util.place_tensor(torch.tensor(x)).float()
    tf_prof = util.place_tensor(torch.tensor(y[0])).float()
    cont_prof = util.place_tensor(torch.tensor(y[1])).float()

    pred_prof, pred_count = model(input_seq, cont_prof)

    loss = model.correctness_loss(tf_prof, pred_prof, pred_count, 1)

    loss.backward()
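These examples lean on a small helper, `util.place_tensor`, that is not shown here. A minimal sketch of what such a helper might look like (the actual implementation in the codebase may differ):

import torch

def place_tensor(tensor, index=None):
    # Move the tensor to a GPU if one is available, optionally selecting the
    # device by index; otherwise leave it on the CPU
    if torch.cuda.is_available():
        return tensor.cuda(index) if index is not None else tensor.cuda()
    return tensor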
    def sparsity_att_prior_loss(self, status, input_grads):
        """
        Computes an attribution prior loss for some given training examples,
        by rewarding sparsity of the attributions.
        Arguments:
            `status`: a B-tensor, where B is the batch size; each entry is 1 if
                that example is to be treated as a positive example, and 0
                otherwise
            `input_grads`: a B x L x D tensor, where B is the batch size, L is
                the length of the input, and D is the dimensionality of each
                input base; this needs to be the gradients of the input with
                respect to the output (for multiple tasks, this gradient needs
                to be aggregated); this should be *gradient times input*
        Returns a single scalar Tensor consisting of the attribution loss for
        the batch.
        """
        abs_grads = torch.sum(torch.abs(input_grads), dim=2)

        # Only do the positives
        pos_grads = abs_grads[status == 1]  # Shape: B' x L

        # Loss for positives
        if pos_grads.nelement():
            # Compute all pairwise differences
            seq_len = pos_grads.size(1)
            seq_matrix = pos_grads.repeat(1,
                                          seq_len).view(-1, seq_len, seq_len)
            diffs = torch.abs(seq_matrix - seq_matrix.transpose(1, 2))

            norm = seq_len * torch.sum(pos_grads, dim=1)

            return -torch.mean(torch.sum(diffs, dim=(1, 2)) / norm)
        else:
            return place_tensor(torch.zeros(1))
    def smoothness_att_prior_loss(self, status, input_grads):
        """
        Computes an attribution prior loss for some given training examples,
        by rewarding smoothness between neighbors' attributions.
        Arguments:
            `status`: a B-tensor, where B is the batch size; each entry is 1 if
                that example is to be treated as a positive example, and 0
                otherwise
            `input_grads`: a B x L x D tensor, where B is the batch size, L is
                the length of the input, and D is the dimensionality of each
                input base; this needs to be the gradients of the input with
                respect to the output (for multiple tasks, this gradient needs
                to be aggregated); this should be *gradient times input*
        Returns a single scalar Tensor consisting of the attribution loss for
        the batch.
        """
        abs_grads = torch.sum(torch.abs(input_grads), dim=2)

        # Only do the positives
        pos_grads = abs_grads[status == 1]

        # Loss for positives
        if pos_grads.nelement():
            # Compute neighbor differences
            diffs = torch.abs(pos_grads[:, 1:] - pos_grads[:, :-1])
            return torch.mean(torch.sum(diffs, dim=1))
        else:
            return place_tensor(torch.zeros(1))
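To make the pairwise-difference construction in `sparsity_att_prior_loss` concrete, here is a small standalone illustration with toy numbers (not from the original code): repeating a length-L row L times and reshaping to L x L gives a matrix whose (i, j) entry is x[j], so subtracting its transpose produces every pairwise difference x[j] - x[i].

import torch

x = torch.tensor([[1.0, 3.0, 6.0]])       # B' = 1 example, L = 3 positions
m = x.repeat(1, 3).view(-1, 3, 3)         # m[0, i, j] == x[0, j]
diffs = torch.abs(m - m.transpose(1, 2))  # |x[0, j] - x[0, i]|
print(diffs[0])
# tensor([[0., 2., 5.],
#         [2., 0., 3.],
#         [5., 3., 0.]])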
Example #7
def create_input_seq_background(input_seq,
                                input_length,
                                bg_size=10,
                                seed=20200219):
    """
    From the input sequence to a model, generates a set of background
    sequences to perform interpretation against.
    Arguments:
        `input_seq`: I x 4 tensor of one-hot encoded input sequence, or None
        `input_length`: length of input, I
        `bg_size`: the number of background examples to generate, G
        `seed`: seed for the random number generator used to do the shuffles
    Returns a G x I x 4 tensor containing random dinucleotide shuffles of the
    original input sequence. If `input_seq` is None, then a G x I x 4 tensor of
    all 0s is returned.
    """
    if input_seq is None:
        input_seq_bg_shape = (bg_size, input_length, 4)
        return place_tensor(torch.zeros(input_seq_bg_shape)).float()

    # Do dinucleotide shuffles
    input_seq_np = input_seq.cpu().numpy()
    rng = np.random.RandomState(seed)
    input_seq_bg_np = dinuc_shuffle(input_seq_np, bg_size, rng=rng)
    return place_tensor(torch.tensor(input_seq_bg_np)).float()
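A minimal usage sketch for `create_input_seq_background`, assuming `place_tensor` and `dinuc_shuffle` from the surrounding codebase are importable; the input sequence below is randomly generated purely for illustration.

import numpy as np
import torch

one_hot = np.eye(4)[np.random.randint(4, size=1346)]  # I x 4 one-hot sequence
input_seq = torch.tensor(one_hot).float()
# bg = create_input_seq_background(input_seq, input_length=1346, bg_size=10)
# bg.shape == torch.Size([10, 1346, 4])
# Passing None instead returns an all-zero background of the same shape:
# zeros_bg = create_input_seq_background(None, input_length=1346, bg_size=10)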
Example #8
    def predict_func(input_seq_batch):
        # Return the set of outputs for the batch of input sequences
        num_in_batch = len(input_seq_batch)
        output_vals = np.empty(num_in_batch)
        num_batches = int(np.ceil(num_in_batch / batch_size))
        for i in range(num_batches):
            batch_slice = slice(i * batch_size, (i + 1) * batch_size)
            logit_preds = model(
                place_tensor(torch.tensor(
                    input_seq_batch[batch_slice])).float())
            logit_preds = logit_preds.detach().cpu().numpy()

            if task_index is None:
                output_vals[batch_slice] = np.sum(logit_preds, axis=1)
            else:
                output_vals[batch_slice] = logit_preds[:, task_index]

        return output_vals
Example #9
 def explain_fn(input_seqs, hide_shap_output):
     """
     Given input sequences, returns hypothetical importance scores for them.
     Arguments:
         `input_seqs`: a B x I x 4 array
         `hide_shap_output`: if True, do not show any warnings from DeepSHAP
     Returns a B x I x 4 array containing hypothetical importance scores for
     each of the B input sequences.
     """
     input_seqs_t = place_tensor(torch.tensor(input_seqs)).float()
     try:
         if hide_shap_output:
             hide_stdout()
         return explainer.shap_values([input_seqs_t],
                                      progress_message=None)[0]
     except Exception as e:
         raise e
     finally:
         show_stdout()
Example #10
def create_profile_control_background(control_profs,
                                      profile_length,
                                      num_tasks,
                                      num_strands,
                                      controls="matched",
                                      bg_size=10):
    """
    Generates a background for a set of profile controls. In general, this is
    the given control profiles, copied a number of times (i.e. the background
    for controls should always be the same). Note this is only used for profile
    models.
    Arguments:
        `control_profs`: (T or 1) x O x S tensor of control profiles,
            or None
        `profile_length`: length of profile, O
        `num_tasks`: number of tasks, T
        `num_strands`: number of strands, S
        `controls`: the kind of controls used: "matched" or "shared"; if
            "matched", the control profiles taken in and returned are
            T x O x S; if "shared", the profiles are 1 x O x S
        `bg_size`: the number of background examples to generate, G
    Returns the tensor of `control_profs`, replicated G times. If `controls` is
    "matched", this becomes a G x T x O x S tensor; if `controls` is "shared",
    this is a G x 1 x O x S tensor. If `control_profs` is None, then a tensor of
    all 0s is returned, whose shape is determined by `controls`.
    """
    assert controls in ("matched", "shared")

    if controls == "matched":
        control_profs_bg_shape = (bg_size, num_tasks, profile_length,
                                  num_strands)
    else:
        control_profs_bg_shape = (bg_size, 1, profile_length, num_strands)
    if control_profs is None:
        return place_tensor(torch.zeros(control_profs_bg_shape)).float()

    # Replicate `control_profs`
    return torch.stack([control_profs] * bg_size, dim=0)
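A quick standalone check (toy shapes, not from the original code) of the replication at the end of `create_profile_control_background`: stacking G copies of a T x O x S tensor along a new leading dimension yields the expected G x T x O x S background.

import torch

control_profs = torch.rand(4, 1000, 2)         # T x O x S
bg = torch.stack([control_profs] * 10, dim=0)  # G x T x O x S
print(bg.shape)  # torch.Size([10, 4, 1000, 2])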
Example #11
def _get_profile_model_predictions_batch(
        model,
        coords,
        num_tasks,
        input_func,
        controls=None,
        fourier_att_prior_freq_limit=200,
        fourier_att_prior_freq_limit_softness=0.2,
        att_prior_grad_smooth_sigma=3,
        return_losses=False,
        return_gradients=False):
    """
    Fetches the necessary data from the given coordinates or bin indices and
    runs it through a profile or binary model. This will perform computation
    in a single batch.
    Arguments:
        `model`: a trained `ProfilePredictorWithMatchedControls`,
            `ProfilePredictorWithSharedControls`, or
            `ProfilePredictorWithoutControls`
        `coords`: a B x 3 array of coordinates to compute outputs for
        `num_tasks`: number of tasks for the model
        `input_func`: a function that takes in `coords` and returns the
            B x I x 4 array of one-hot sequences and the
            B x (T or T + 1 or 2T) x O x S array of profiles (perhaps with
            controls)
        `controls`: the type of control profiles (if any) used in model; can be
            "matched" (each task has a matched control), "shared" (all tasks
            share a control), or None (no controls); must match the model class
        `fourier_att_prior_freq_limit`: limit for frequencies in Fourier prior
            loss
        `fourier_att_prior_freq_limit_softness`: degree of softness for limit
        `att_prior_grad_smooth_sigma`: width of smoothing kernel for gradients
        `return_losses`: if True, compute/return the loss values
        `return_gradients`: if True, compute/return the input gradients and
            sequences
    Returns a dictionary of the following structure:
        true_profs: true profile raw counts (B x T x O x S)
        log_pred_profs: predicted profile log probabilities (B x T x O x S)
        true_counts: true total counts (B x T x S)
        log_pred_counts: predicted log counts (B x T x S)
        prof_losses: profile NLL losses (B-array), if `return_losses` is True
        count_losses: counts MSE losses (B-array) if `return_losses` is True
        att_losses: prior losses (B-array), if `return_losses` is True
        input_seqs: one-hot input sequences (B x I x 4), if `return_gradients`
            is true
        input_grads: "hypothetical" input gradients (B x I x 4), if
            `return_gradients` is true
    """
    result = {}
    input_seqs, profiles = input_func(coords)
    if return_gradients:
        input_seqs_np = input_seqs
        # Zero out gradients because we are computing input gradients
        model.zero_grad()
    input_seqs = model_util.place_tensor(torch.tensor(input_seqs)).float()
    profiles = model_util.place_tensor(torch.tensor(profiles)).float()

    if controls is not None:
        tf_profs = profiles[:, :num_tasks, :, :]
        cont_profs = profiles[:, num_tasks:, :, :]  # Last half or just one
    else:
        tf_profs, cont_profs = profiles, None

    if return_losses or return_gradients:
        input_seqs.requires_grad = True  # Set gradient required
        logit_pred_profs, log_pred_counts = model(input_seqs, cont_profs)

        # Subtract mean along output profile dimension; this wouldn't change
        # softmax probabilities, but normalizes the magnitude of gradients
        norm_logit_pred_profs = logit_pred_profs - \
            torch.mean(logit_pred_profs, dim=2, keepdim=True)

        # Weight by post-softmax probabilities, but do not take the
        # gradients of these probabilities; this upweights important regions
        # exponentially
        pred_prof_probs = profile_models.profile_logits_to_log_probs(
            logit_pred_profs).detach()
        weighted_norm_logits = norm_logit_pred_profs * pred_prof_probs

        input_grads, = torch.autograd.grad(
            weighted_norm_logits,
            input_seqs,
            grad_outputs=model_util.place_tensor(
                torch.ones(weighted_norm_logits.size())),
            retain_graph=True,
            create_graph=True
            # We'll be operating on the gradient itself, so we need to
            # create the graph
            # Gradients are summed across strands and tasks
        )
        input_grads_np = input_grads.detach().cpu().numpy()
        input_seqs.requires_grad = False  # Reset gradient required
    else:
        logit_pred_profs, log_pred_counts = model(input_seqs, cont_profs)

    result["true_profs"] = tf_profs.detach().cpu().numpy()
    result["true_counts"] = np.sum(result["true_profs"], axis=2)
    logit_pred_profs_np = logit_pred_profs.detach().cpu().numpy()
    result["log_pred_profs"] = profile_models.profile_logits_to_log_probs(
        logit_pred_profs_np)
    result["log_pred_counts"] = log_pred_counts.detach().cpu().numpy()

    if return_losses:
        log_pred_profs = profile_models.profile_logits_to_log_probs(
            logit_pred_profs)
        num_samples = log_pred_profs.size(0)
        result["prof_losses"] = np.empty(num_samples)
        result["count_losses"] = np.empty(num_samples)
        result["att_losses"] = np.empty(num_samples)

        # Compute losses separately for each example
        for i in range(num_samples):
            _, prof_loss, count_loss = model.correctness_loss(
                tf_profs[i:i + 1],
                log_pred_profs[i:i + 1],
                log_pred_counts[i:i + 1],
                1,
                return_separate_losses=True)
            att_loss = model.fourier_att_prior_loss(
                model_util.place_tensor(torch.ones(1)), input_grads[i:i + 1],
                fourier_att_prior_freq_limit,
                fourier_att_prior_freq_limit_softness,
                att_prior_grad_smooth_sigma)
            result["prof_losses"][i] = prof_loss
            result["count_losses"][i] = count_loss
            result["att_losses"][i] = att_loss

    if return_gradients:
        result["input_seqs"] = input_seqs_np
        result["input_grads"] = input_grads_np

    return result
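A standalone sanity check (not from the original code) of the claim in the comments above: subtracting the mean of the logits along the profile dimension leaves the softmax probabilities unchanged, so the normalization only affects the scale of the gradients.

import torch

logits = torch.randn(10)
shifted = logits - logits.mean()
print(torch.allclose(torch.softmax(logits, dim=0),
                     torch.softmax(shifted, dim=0)))  # True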
Example #12
def _get_binary_model_predictions_batch(
        model,
        bins,
        input_func,
        fourier_att_prior_freq_limit=150,
        fourier_att_prior_freq_limit_softness=0.2,
        att_prior_grad_smooth_sigma=3,
        return_losses=False,
        return_gradients=False):
    """
    Arguments:
        `model`: a trained `BinaryPredictor`,
        `bins`: an N-array of bin indices to compute outputs for
        `input_func`: a function that takes in `bins` and returns the B x I x 4
            array of one-hot sequences, the B x T array of output values, and
            B x 3 array of underlying coordinates for the input sequence
        `fourier_att_prior_freq_limit`: limit for frequencies in Fourier prior
            loss
        `fourier_att_prior_freq_limit_softness`: degree of softness for limit
        `att_prior_grad_smooth_sigma`: width of smoothing kernel for gradients
        `return_losses`: if True, compute/return the loss values
        `return_gradients`: if True, compute/return the input gradients and
            sequences
    Returns a dictionary of the following structure:
        true_vals: true binary values (B x T)
        pred_vals: predicted probabilities (B x T)
        coords: coordinates used for prediction (B x 3 object array)
        corr_losses: correctness losses (B-array) if `return_losses` is True
        att_losses: prior losses (B-array), if `return_losses` is True
        input_seqs: one-hot input sequences (B x I x 4), if `return_gradients`
            is True
        input_grads: "hypothetical" input gradients (B x I x 4), if
            `return_gradients` is true
    """
    result = {}
    input_seqs, output_vals, coords = input_func(bins)
    output_vals_np = output_vals
    if return_gradients:
        input_seqs_np = input_seqs
        model.zero_grad()
    input_seqs = model_util.place_tensor(torch.tensor(input_seqs)).float()
    output_vals = model_util.place_tensor(torch.tensor(output_vals)).float()

    if return_losses or return_gradients:
        input_seqs.requires_grad = True  # Set gradient required
        logit_pred_vals = model(input_seqs)
        # Compute the gradients of the output with respect to the input
        input_grads, = torch.autograd.grad(
            logit_pred_vals,
            input_seqs,
            grad_outputs=model_util.place_tensor(
                torch.ones(logit_pred_vals.size())),
            retain_graph=True,
            create_graph=True
            # We'll be operating on the gradient itself, so we need to
            # create the graph
            # Gradients are summed across tasks
        )
        input_grads_np = input_grads.detach().cpu().numpy()
        input_seqs.requires_grad = False  # Reset gradient required
    else:
        logit_pred_vals = model(input_seqs)
        status, input_grads = None, None

    result["true_vals"] = output_vals_np
    logit_pred_vals_np = logit_pred_vals.detach().cpu().numpy()
    result["pred_vals"] = binary_models.binary_logits_to_probs(
        logit_pred_vals_np)
    result["coords"] = coords

    if return_losses:
        num_samples = logit_pred_vals.size(0)
        result["corr_losses"] = np.empty(num_samples)
        result["att_losses"] = np.empty(num_samples)

        # Compute losses separately for each example
        for i in range(num_samples):
            corr_loss = model.correctness_loss(output_vals[i:i + 1],
                                               logit_pred_vals[i:i + 1], True)
            att_loss = model.fourier_att_prior_loss(
                model_util.place_tensor(torch.ones(1)), input_grads[i:i + 1],
                fourier_att_prior_freq_limit,
                fourier_att_prior_freq_limit_softness,
                att_prior_grad_smooth_sigma)
            result["corr_losses"][i] = corr_loss
            result["att_losses"][i] = att_loss

    if return_gradients:
        result["input_seqs"] = input_seqs_np
        result["input_grads"] = input_grads_np

    return result
def run_epoch(
    data_loader,
    mode,
    model,
    epoch_num,
    params,
    optimizer=None,
    return_data=False,
    seq_mode=True,
    att_prior_loss_weight=None,
):
    """
    Runs the data from the data loader once through the model, to train,
    validate, or predict.
    Arguments:
        `data_loader`: an instantiated `DataLoader` instance that gives batches
            of data; each batch must yield the input sequences, profiles,
            statuses, coordinates, and peaks; if `controls` is "matched",
            profiles must be such that the first half are prediction (target)
            profiles, and the second half are control profiles; if `controls` is
            "shared", the last set of profiles is a shared control; otherwise,
            all tasks are prediction profiles
        `mode`: one of "train", "eval"; if "train", run the epoch and perform
            backpropagation; if "eval", only do evaluation
        `model`: the current PyTorch model being trained/evaluated
        `epoch_num`: 0-indexed integer representing the current epoch
        `params`: a dictionary of hyperparameters used by this function,
            including "num_tasks", "controls", "att_prior_loss_weight",
            "batch_size", "revcomp", "input_length", "input_depth",
            "profile_length", and (optionally) "gpu_id"
        `optimizer`: an instantiated PyTorch optimizer, for training mode
        `return_data`: if specified, returns the following as NumPy arrays:
            true profile raw counts (N x T x O x S), predicted profile log
            probabilities (N x T x O x S), true total counts (N x T x S),
            predicted log counts (N x T x S), coordinates used (N x 3 object
            array), input gradients (N x I x 4), and input sequences
            (N x I x 4); if the attribution prior is not used, the gradients
            will be garbage
        `seq_mode`: if True, run the model in sequence mode via
            `model.set_seq_mode()`; otherwise use `model.set_aux_mode()`
        `att_prior_loss_weight`: if given, overrides the attribution prior loss
            weight in `params`
    Returns lists of overall losses, correctness losses, attribution prior
    losses, profile losses, and count losses, where each list is over all
    batches. If the attribution prior loss is not computed, then it will be all
    0s. If `return_data` is True, then more things will be returned after these.
    """
    num_tasks = params["num_tasks"]
    controls = params["controls"]
    if att_prior_loss_weight is None:
        att_prior_loss_weight = params["att_prior_loss_weight"]
    batch_size = params["batch_size"]
    revcomp = params["revcomp"]
    input_length = params["input_length"]
    input_depth = params["input_depth"]
    profile_length = params["profile_length"]
    gpu_id = params.get("gpu_id")

    assert mode in ("train", "eval")
    if mode == "train":
        assert optimizer is not None
    else:
        assert optimizer is None

    data_loader.dataset.on_epoch_start()  # Set-up the epoch
    num_batches = len(data_loader.dataset)
    t_iter = tqdm.tqdm(data_loader, total=num_batches, desc="\tLoss: ---")

    if mode == "train":
        model.train()  # Switch to training mode
        torch.set_grad_enabled(True)

    batch_losses, corr_losses, att_losses = [], [], []
    prof_losses, count_losses = [], []
    if return_data:
        # Allocate empty NumPy arrays to hold the results
        num_samples_exp = num_batches * batch_size
        num_samples_exp *= 2 if revcomp else 1
        # Real number of samples can be smaller because of partial last batch
        profile_shape = (num_samples_exp, num_tasks, profile_length, 2)
        count_shape = (num_samples_exp, num_tasks, 2)
        all_log_pred_profs = np.empty(profile_shape)
        all_log_pred_counts = np.empty(count_shape)
        all_true_profs = np.empty(profile_shape)
        all_true_counts = np.empty(count_shape)
        all_true_profs_trans = np.empty(profile_shape)
        all_true_counts_trans = np.empty(count_shape)
        all_input_seqs = np.empty((num_samples_exp, input_length, input_depth))
        all_input_grads = np.empty(
            (num_samples_exp, input_length, input_depth))
        all_coords = np.empty((num_samples_exp, 3), dtype=object)
        num_samples_seen = 0  # Real number of samples seen

    if seq_mode:
        model.set_seq_mode()
    else:
        model.set_aux_mode()

    for input_seqs, profiles, profiles_trans, statuses, coords, peaks in t_iter:
        if return_data:
            input_seqs_np = input_seqs
        input_seqs = util.place_tensor(torch.tensor(input_seqs),
                                       index=gpu_id).float()
        profiles = util.place_tensor(torch.tensor(profiles),
                                     index=gpu_id).float()
        profiles_trans = util.place_tensor(torch.tensor(profiles_trans),
                                           index=gpu_id).float()

        if controls is not None:
            tf_profs = profiles[:, :num_tasks, :, :].contiguous()
            cont_profs = profiles[:, num_tasks:, :, :].contiguous(
            )  # Last half or just one
            tf_profs_trans = profiles_trans[:, :num_tasks, :, :].contiguous()
            cont_profs_trans = profiles_trans[:, num_tasks:, :, :].contiguous(
            )  # Last half or just one
        else:
            tf_profs, cont_profs = profiles, None
            tf_profs_trans, cont_profs_trans = profiles_trans, None

        # Clear gradients from last batch if training
        if mode == "train":
            optimizer.zero_grad()
        elif att_prior_loss_weight > 0:
            # Not training mode, but we still need to zero out the gradients
            # because we are computing the input gradients
            model.zero_grad()

        if att_prior_loss_weight > 0:
            input_seqs.requires_grad = True  # Set gradient required
            logit_pred_profs, log_pred_counts = model(input_seqs, cont_profs,
                                                      tf_profs_trans)

            # Subtract mean along output profile dimension; this wouldn't change
            # softmax probabilities, but normalizes the magnitude of gradients
            norm_logit_pred_profs = logit_pred_profs - \
                torch.mean(logit_pred_profs, dim=2, keepdim=True)

            # Weight by post-softmax probabilities, but do not take the
            # gradients of these probabilities; this upweights important regions
            # exponentially
            pred_prof_probs = profile_models.profile_logits_to_log_probs(
                logit_pred_profs).detach()
            weighted_norm_logits = norm_logit_pred_profs * pred_prof_probs

            if seq_mode:
                model.set_seq_mode()

            input_grads, = torch.autograd.grad(
                weighted_norm_logits,
                input_seqs,
                grad_outputs=util.place_tensor(torch.ones(
                    weighted_norm_logits.size()),
                                               index=gpu_id),
                retain_graph=True,
                create_graph=True
                # We'll be operating on the gradient itself, so we need to
                # create the graph
                # Gradients are summed across strands and tasks
            )
            if return_data:
                input_grads_np = input_grads.detach().cpu().numpy()
            input_grads = input_grads * input_seqs  # Gradient * input
            status = util.place_tensor(torch.tensor(statuses), index=gpu_id)
            status[status != 0] = 1  # Set to 1 if not negative example
            input_seqs.requires_grad = False  # Reset gradient required
        else:
            logit_pred_profs, log_pred_counts = model(input_seqs, cont_profs,
                                                      tf_profs_trans)
            status, input_grads = None, None

        loss, (corr_loss,
               att_loss), (prof_loss,
                           count_loss) = model_loss(model,
                                                    tf_profs,
                                                    logit_pred_profs,
                                                    log_pred_counts,
                                                    epoch_num,
                                                    status=status,
                                                    input_grads=input_grads)

        if mode == "train":
            loss.backward()  # Compute gradient
            optimizer.step()  # Update weights through backprop

        batch_losses.append(loss.item())
        corr_losses.append(corr_loss.item())
        att_losses.append(att_loss.item())
        prof_losses.append(prof_loss.item())
        count_losses.append(count_loss.item())
        t_iter.set_description("\tLoss: %6.4f" % loss.item())

        if return_data:
            logit_pred_profs_np = logit_pred_profs.detach().cpu().numpy()
            log_pred_counts_np = log_pred_counts.detach().cpu().numpy()
            true_profs_np = tf_profs.detach().cpu().numpy()
            true_counts = np.sum(true_profs_np, axis=2)
            true_profs_trans_np = tf_profs_trans.detach().cpu().numpy()
            true_counts_trans = np.sum(true_profs_trans_np, axis=2)

            num_in_batch = true_counts.shape[0]

            # Turn logit profile predictions into log probabilities
            log_pred_profs = profile_models.profile_logits_to_log_probs(
                logit_pred_profs_np, axis=2)

            # Fill in the batch data/outputs into the preallocated arrays
            start, end = num_samples_seen, num_samples_seen + num_in_batch
            all_log_pred_profs[start:end] = log_pred_profs
            all_log_pred_counts[start:end] = log_pred_counts_np
            all_true_profs[start:end] = true_profs_np
            all_true_counts[start:end] = true_counts
            all_true_profs_trans[start:end] = true_profs_trans_np
            all_true_counts_trans[start:end] = true_counts_trans
            all_input_seqs[start:end] = input_seqs_np
            if att_prior_loss_weight:
                all_input_grads[start:end] = input_grads_np
            all_coords[start:end] = coords

            num_samples_seen += num_in_batch

    if return_data:
        # Truncate the saved data to the proper size, based on how many
        # samples were actually seen
        all_log_pred_profs = all_log_pred_profs[:num_samples_seen]
        all_log_pred_counts = all_log_pred_counts[:num_samples_seen]
        all_true_profs = all_true_profs[:num_samples_seen]
        all_true_counts = all_true_counts[:num_samples_seen]
        all_true_profs_trans = all_true_profs_trans[:num_samples_seen]
        all_true_counts_trans = all_true_counts_trans[:num_samples_seen]
        all_input_seqs = all_input_seqs[:num_samples_seen]
        all_input_grads = all_input_grads[:num_samples_seen]
        all_coords = all_coords[:num_samples_seen]
        return batch_losses, corr_losses, att_losses, prof_losses, \
            count_losses, all_true_profs, all_log_pred_profs, \
            all_true_counts, all_log_pred_counts, all_coords, all_input_grads, \
            all_input_seqs, all_true_profs_trans, all_true_counts_trans
    else:
        return batch_losses, corr_losses, att_losses, prof_losses, count_losses
Example #14
    def fourier_att_prior_loss(self,
                               status,
                               input_grads,
                               freq_limit,
                               limit_softness,
                               att_prior_grad_smooth_sigma,
                               gpu_id=None):
        """
        Computes an attribution prior loss for some given training examples,
        using a Fourier transform form.
        Arguments:
            `status`: a B-tensor, where B is the batch size; each entry is 1 if
                that example is to be treated as a positive example, and 0
                otherwise
            `input_grads`: a B x L x D tensor, where B is the batch size, L is
                the length of the input, and D is the dimensionality of each
                input base; this needs to be the gradients of the input with
                respect to the output (for multiple tasks, this gradient needs
                to be aggregated); this should be *gradient times input*
            `freq_limit`: the maximum integer frequency index, k, to consider
                for the loss; this corresponds to a frequency cut-off of
                pi * k / L; k should be less than L / 2
            `limit_softness`: amount to soften the limit by, using a hill
                function; None means no softness
            `att_prior_grad_smooth_sigma`: amount to smooth the gradient before
                computing the loss
            `gpu_id`: index of the GPU to place newly created tensors on, if
                any
        Returns a single scalar Tensor consisting of the attribution loss for
        the batch.
        """
        abs_grads = torch.sum(torch.abs(input_grads), dim=2)

        # Smooth the gradients
        grads_smooth = smooth_tensor_1d(abs_grads,
                                        att_prior_grad_smooth_sigma,
                                        gpu_id=gpu_id)

        # Only do the positives
        pos_grads = grads_smooth[status == 1]

        # Loss for positives
        if pos_grads.nelement():
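            # Note: `torch.rfft(x, 1)` is the pre-1.8 PyTorch API; it returns
            # a trailing real/imaginary dimension, which is why the magnitudes
            # are taken with `torch.norm(..., dim=2)` below. On newer PyTorch
            # the equivalent is `torch.fft.rfft`, which returns a complex
            # tensor.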
            pos_fft = torch.rfft(pos_grads, 1)
            pos_mags = torch.norm(pos_fft, dim=2)
            pos_mag_sum = torch.sum(pos_mags, dim=1, keepdim=True)
            # Avoid dividing by 0; rows whose sum is 0 stay all 0
            pos_mag_sum[pos_mag_sum == 0] = 1
            pos_mags = pos_mags / pos_mag_sum

            # Cut off DC
            pos_mags = pos_mags[:, 1:]

            # Construct weight vector
            weights = place_tensor(torch.ones_like(pos_mags), index=gpu_id)
            if limit_softness is None:
                weights[:, freq_limit:] = 0
            else:
                x = place_tensor(torch.arange(
                    1,
                    pos_mags.size(1) - freq_limit + 1),
                                 index=gpu_id).float()
                weights[:,
                        freq_limit:] = 1 / (1 + torch.pow(x, limit_softness))

            # Multiply frequency magnitudes by weights
            pos_weighted_mags = pos_mags * weights

            # Add up along frequency axis to get score
            pos_score = torch.sum(pos_weighted_mags, dim=1)
            pos_loss = 1 - pos_score
            return torch.mean(pos_loss)
        else:
            return place_tensor(torch.zeros(1), index=gpu_id)
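To visualize the frequency weighting used in `fourier_att_prior_loss`, here is a small standalone computation with made-up numbers: with 10 non-DC frequency bins, `freq_limit=4`, and `limit_softness=2`, the weights stay at 1 below the limit and then decay as 1 / (1 + x^softness) instead of dropping straight to 0.

import torch

num_bins, freq_limit, limit_softness = 10, 4, 2.0
weights = torch.ones(num_bins)
x = torch.arange(1, num_bins - freq_limit + 1).float()
weights[freq_limit:] = 1 / (1 + torch.pow(x, limit_softness))
print(weights)
# -> 1, 1, 1, 1, 0.5, 0.2, 0.1, 0.0588, 0.0385, 0.0270 (approximately)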
def run_epoch(data_loader,
              mode,
              model,
              epoch_num,
              num_tasks,
              att_prior_loss_weight,
              batch_size,
              revcomp,
              input_length,
              input_depth,
              optimizer=None,
              return_data=False):
    """
    Runs the data from the data loader once through the model, to train,
    validate, or predict.
    Arguments:
        `data_loader`: an instantiated `DataLoader` instance that gives batches
            of data; each batch must yield the input sequences, the output
            values, the status, and coordinates
        `mode`: one of "train", "eval"; if "train", run the epoch and perform
            backpropagation; if "eval", only do evaluation
        `model`: the current PyTorch model being trained/evaluated
        `epoch_num`: 0-indexed integer representing the current epoch
        `num_tasks`: number of output tasks, T
        `att_prior_loss_weight`: weight of the attribution prior loss; if 0,
            the prior loss is not computed
        `batch_size`: batch size used by the data loader
        `revcomp`: whether the data loader applies reverse-complement
            augmentation (doubling the number of samples per batch)
        `input_length`: length of each input sequence, I
        `input_depth`: depth of each one-hot encoded base (e.g. 4 for DNA)
        `optimizer`: an instantiated PyTorch optimizer, for training mode
        `return_data`: if specified, returns the following as NumPy arrays:
            true binding values (0, 1, or -1) (N x T), predicted binding
            probabilities (N x T), the underlying sequence coordinates (N x 3
            object array), input gradients (N x I x 4), and the input sequences
            (N x I x 4); if the attribution prior is not used, the gradients
            will be garbage
    Returns lists of overall losses, correctness losses, and attribution prior
    losses, where each list is over all batches. If the attribution prior loss
    is not computed, then it will be all 0s. If `return_data` is True, then more
    things will be returned after these.
    """
    assert mode in ("train", "eval")
    if mode == "train":
        assert optimizer is not None
    else:
        assert optimizer is None

    data_loader.dataset.on_epoch_start()  # Set-up the epoch
    num_batches = len(data_loader.dataset)
    t_iter = tqdm.tqdm(data_loader, total=num_batches, desc="\tLoss: ---")

    if mode == "train":
        model.train()  # Switch to training mode
        torch.set_grad_enabled(True)

    batch_losses, corr_losses, att_losses = [], [], []
    if return_data:
        # Allocate empty NumPy arrays to hold the results
        num_samples_exp = num_batches * batch_size
        num_samples_exp *= 2 if revcomp else 1
        # Real number of samples can be smaller because of partial last batch
        all_true_vals = np.empty((num_samples_exp, num_tasks))
        all_pred_vals = np.empty((num_samples_exp, num_tasks))
        all_input_seqs = np.empty((num_samples_exp, input_length, input_depth))
        all_input_grads = np.empty(
            (num_samples_exp, input_length, input_depth))
        all_coords = np.empty((num_samples_exp, 3), dtype=object)
        num_samples_seen = 0  # Real number of samples seen

    for input_seqs, output_vals, statuses, coords in t_iter:
        if return_data:
            input_seqs_np = input_seqs
            output_vals_np = output_vals
        input_seqs = util.place_tensor(torch.tensor(input_seqs)).float()
        output_vals = util.place_tensor(torch.tensor(output_vals)).float()

        # Clear gradients from last batch if training
        if mode == "train":
            optimizer.zero_grad()
        elif att_prior_loss_weight > 0:
            # Not training mode, but we still need to zero out the gradients
            # because we are computing the input gradients
            model.zero_grad()

        if att_prior_loss_weight > 0:
            input_seqs.requires_grad = True  # Set gradient required
            logit_pred_vals = model(input_seqs)
            # Compute the gradients of the output with respect to the input
            input_grads, = torch.autograd.grad(
                logit_pred_vals,
                input_seqs,
                grad_outputs=util.place_tensor(
                    torch.ones(logit_pred_vals.size())),
                retain_graph=True,
                create_graph=True
                # We'll be operating on the gradient itself, so we need to
                # create the graph
                # Gradients are summed across tasks
            )
            if return_data:
                input_grads_np = input_grads.detach().cpu().numpy()
            input_grads = input_grads * input_seqs  # Gradient * input
            status = util.place_tensor(torch.tensor(statuses))
            input_seqs.requires_grad = False  # Reset gradient required
        else:
            logit_pred_vals = model(input_seqs)
            status, input_grads = None, None

        loss, (corr_loss, att_loss) = model_loss(model,
                                                 output_vals,
                                                 logit_pred_vals,
                                                 epoch_num,
                                                 status=status,
                                                 input_grads=input_grads)

        if mode == "train":
            loss.backward()  # Compute gradient
            optimizer.step()  # Update weights through backprop

        batch_losses.append(loss.item())
        corr_losses.append(corr_loss.item())
        att_losses.append(att_loss.item())
        t_iter.set_description("\tLoss: %6.4f" % loss.item())

        if return_data:
            logit_pred_vals_np = logit_pred_vals.detach().cpu().numpy()

            # Turn logits into probabilities
            pred_vals = binary_models.binary_logits_to_probs(
                logit_pred_vals_np)
            num_in_batch = pred_vals.shape[0]

            # Fill in the batch data/outputs into the preallocated arrays
            start, end = num_samples_seen, num_samples_seen + num_in_batch
            all_true_vals[start:end] = output_vals_np
            all_pred_vals[start:end] = pred_vals
            all_input_seqs[start:end] = input_seqs_np
            if att_prior_loss_weight:
                all_input_grads[start:end] = input_grads_np
            all_coords[start:end] = coords

            num_samples_seen += num_in_batch

    if return_data:
        # Truncate the saved data to the proper size, based on how many
        # samples were actually seen
        all_true_vals = all_true_vals[:num_samples_seen]
        all_pred_vals = all_pred_vals[:num_samples_seen]
        all_input_seqs = all_input_seqs[:num_samples_seen]
        all_input_grads = all_input_grads[:num_samples_seen]
        all_coords = all_coords[:num_samples_seen]
        return batch_losses, corr_losses, att_losses, all_true_vals, \
            all_pred_vals, all_coords, all_input_grads, all_input_seqs
    else:
        return batch_losses, corr_losses, att_losses
def model_loss(model,
               true_vals,
               logit_pred_vals,
               epoch_num,
               avg_class_loss,
               att_prior_loss_weight,
               att_prior_loss_weight_anneal_type,
               att_prior_loss_weight_anneal_speed,
               att_prior_grad_smooth_sigma,
               fourier_att_prior_freq_limit,
               fourier_att_prior_freq_limit_softness,
               att_prior_loss_only,
               l2_reg_loss_weight,
               input_grads=None,
               status=None):
    """
    Computes the loss for the model.
    Arguments:
        `model`: the model being trained
        `true_vals`: a B x T tensor, where B is the batch size and T is the
            number of output tasks, containing the true binary values
        `logit_pred_vals`: a B x T tensor containing the predicted logits
        `epoch_num`: a 0-indexed integer representing the current epoch
        `avg_class_loss`: passed through to `model.correctness_loss`
        `att_prior_loss_weight`: weight of the attribution prior loss; if 0,
            the prior loss is not computed
        `att_prior_loss_weight_anneal_type`: one of None, "inflate", or
            "deflate", controlling how the prior loss weight changes with the
            epoch number
        `att_prior_loss_weight_anneal_speed`: speed of the annealing schedule
        `att_prior_grad_smooth_sigma`: amount to smooth the gradient before
            computing the prior loss
        `fourier_att_prior_freq_limit`: limit for frequencies in the Fourier
            prior loss
        `fourier_att_prior_freq_limit_softness`: degree of softness for the
            limit
        `att_prior_loss_only`: if True, use only the attribution prior loss as
            the final loss
        `l2_reg_loss_weight`: weight of the L2 penalty on the model parameters;
            if 0, no penalty is added
        `input_grads`: a B x I x D tensor, where I is the input length and D is
            the input depth; this is the gradient of the output with respect to
            the input, times the input itself; only needed when attribution
            prior loss weight is positive
        `status`: a B-tensor, where B is the batch size; each entry is 1 if
            that example is to be treated as a positive example, 0 if negative,
            and -1 if ambiguous; only needed when attribution prior loss weight
            is positive
    Returns a scalar Tensor containing the loss for the given batch, and a pair
    consisting of the correctness loss and the attribution prior loss.
    If the attribution prior loss is not computed at all, then 0 will be in its
    place, instead.
    """
    corr_loss = model.correctness_loss(true_vals, logit_pred_vals,
                                       avg_class_loss)
    final_loss = corr_loss

    if att_prior_loss_weight > 0:
        att_prior_loss = model.fourier_att_prior_loss(
            status, input_grads, fourier_att_prior_freq_limit,
            fourier_att_prior_freq_limit_softness, att_prior_grad_smooth_sigma)

        # att_prior_loss = model.smoothness_att_prior_loss(status, input_grads)
        # att_prior_loss = model.sparsity_att_prior_loss(status, input_grads)

        if att_prior_loss_weight_anneal_type is None:
            weight = att_prior_loss_weight
        elif att_prior_loss_weight_anneal_type == "inflate":
            exp = np.exp(-att_prior_loss_weight_anneal_speed * epoch_num)
            weight = att_prior_loss_weight * ((2 / (1 + exp)) - 1)
        elif att_prior_loss_weight_anneal_type == "deflate":
            exp = np.exp(-att_prior_loss_weight_anneal_speed * epoch_num)
            weight = att_prior_loss_weight * exp

        if att_prior_loss_only:
            final_loss = att_prior_loss
        else:
            final_loss = final_loss + (weight * att_prior_loss)
    else:
        att_prior_loss = torch.zeros(1)

    # If necessary, add the L2 penalty
    if l2_reg_loss_weight > 0:
        l2_loss = util.place_tensor(torch.tensor(0).float())
        for param in model.parameters():
            if param.requires_grad:
                l2_loss = l2_loss + torch.sum(param * param)
        final_loss = final_loss + (l2_reg_loss_weight * l2_loss)

    return final_loss, (corr_loss, att_prior_loss)
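For intuition about the two annealing schedules in `model_loss`, a tiny standalone computation with illustrative numbers: "inflate" ramps the effective prior weight up from 0 toward its full value as epochs progress, while "deflate" decays it from the full value toward 0.

import numpy as np

att_prior_loss_weight, anneal_speed = 1.0, 0.5
for epoch_num in range(4):
    exp = np.exp(-anneal_speed * epoch_num)
    inflate = att_prior_loss_weight * ((2 / (1 + exp)) - 1)
    deflate = att_prior_loss_weight * exp
    print(epoch_num, round(inflate, 3), round(deflate, 3))
# 0 0.0 1.0
# 1 0.245 0.607
# 2 0.462 0.368
# 3 0.635 0.223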