Example #1
0
    def forward(self, input_seqs, profs_trans, cont_profs=None):
        """
        Reduces the wrapped model's profile prediction to one scalar per
        example, so that DeepSHAP can explain it.
        The inner model's predicted counts are discarded. The profile logits
        are mean-normalized along dim 2, weighted by the detached predicted
        probabilities, optionally subset to `self.task_index`, and summed over
        dims 1-3.
        Arguments:
            `input_seqs`: input sequences, passed to the inner model
            `profs_trans`: passed to the inner model as its third argument
            `cont_profs`: control profiles passed to the inner model, or None
        Returns a B x 1 tensor, as DeepSHAP requires.
        """
        # Run through inner model, disregarding the predicted counts
        logit_pred_profs, _ = self.inner_model(input_seqs, cont_profs,
                                               profs_trans)

        # As with the computation of the gradients, instead of explaining the
        # logits, explain the mean-normalized logits, weighted by the final
        # probabilities after passing through the softmax; this exponentially
        # increases the weight for high-probability positions, and exponentially
        # reduces the weight for low-probability positions, resulting in a
        # cleaner signal

        # Subtract mean along output profile dimension; this wouldn't change
        # softmax probabilities, but normalizes the magnitude of the logits
        norm_logit_pred_profs = logit_pred_profs - \
            torch.mean(logit_pred_profs, dim=2, keepdim=True)

        # Weight by post-softmax probabilities, but detach them from the graph
        # to avoid explaining those. `profile_logits_to_log_probs` returns LOG
        # probabilities, so exponentiate them; weighting by the (non-positive)
        # log probabilities would give the largest weight to the least likely
        # positions, the opposite of the intent described above
        pred_prof_probs = torch.exp(
            profile_models.profile_logits_to_log_probs(logit_pred_profs)
        ).detach()
        weighted_norm_logits = norm_logit_pred_profs * pred_prof_probs

        if self.task_index is not None:
            # Subset to specific task; slicing keeps the task dimension so the
            # sum below still reduces over three dimensions
            weighted_norm_logits = \
                weighted_norm_logits[:, self.task_index:(self.task_index + 1)]
        prof_sum = torch.sum(weighted_norm_logits, dim=(1, 2, 3))

        # DeepSHAP requires the shape to be B x 1
        return torch.unsqueeze(prof_sum, dim=1)
Example #2
0
def test_vectorized_multinomial_nll():
    """
    Checks that three implementations of the profile multinomial NLL agree on
    random data, and prints how long each one takes:
      1. `profile_performance.profile_multinomial_nll` (NumPy, vectorized)
      2. `profile_models.multinomial_log_probs` (PyTorch, vectorized)
      3. `torch.distributions.Multinomial`, looping over every example, task
         and strand
    Each method yields one NLL per (example, task), averaged over the two
    strands. Fails if the three results are not all close.
    """
    np.random.seed(20191110)
    batch_size, num_tasks, prof_len = 500, 10, 1000
    prof_shape = (batch_size, num_tasks, prof_len, 2)
    true_profs_np = np.random.randint(5, size=prof_shape)
    logit_pred_profs_np = np.random.randn(*prof_shape)
    log_pred_profs_np = profile_models.profile_logits_to_log_probs(
        logit_pred_profs_np, axis=2)
    true_counts_np = np.sum(true_profs_np, axis=2)

    print("Testing Multinomial NLL...")
    start = datetime.now()
    # Using the profile performance function:
    nll_vec_np = profile_performance.profile_multinomial_nll(
        true_profs_np, log_pred_profs_np, true_counts_np)
    # `timedelta.seconds` drops fractional seconds (and whole days), so fast
    # timings printed as "0s"; `total_seconds()` keeps full precision
    print("\tTime to compute (NumPy vectorization): %.2fs" %
          (datetime.now() - start).total_seconds())

    # Using the profile models function:
    # Convert to tensors and swap axes to make profile dimension last
    true_profs_tc = torch.tensor(true_profs_np).transpose(2, 3)
    log_pred_profs_tc = torch.tensor(log_pred_profs_np).transpose(2, 3)
    true_counts_tc = torch.tensor(true_counts_np)

    start = datetime.now()
    nll_vec_tc = -profile_models.multinomial_log_probs(
        log_pred_profs_tc, true_counts_tc, true_profs_tc)
    # Average across strands
    nll_vec_tc = torch.mean(nll_vec_tc, dim=2).numpy()
    print("\tTime to compute (PyTorch vectorization): %.2fs" %
          (datetime.now() - start).total_seconds())

    # Using PyTorch's distribution class, one profile at a time
    logit_pred_profs_tc = torch.tensor(logit_pred_profs_np).transpose(2, 3)

    start = datetime.now()
    nll_torch = np.empty((batch_size, num_tasks))
    for i in range(batch_size):
        for j in range(num_tasks):
            strand_nlls = []
            for strand in range(2):
                dist = torch.distributions.Multinomial(
                    true_counts_tc[i, j, strand].item(),
                    logits=logit_pred_profs_tc[i, j, strand, :])
                strand_nlls.append(-dist.log_prob(
                    true_profs_tc[i, j, strand, :].float()).item())
            nll_torch[i, j] = np.mean(strand_nlls)
    print("\tTime to compute (PyTorch distributions): %.2fs" %
          (datetime.now() - start).total_seconds())

    same_result = np.allclose(nll_vec_np, nll_vec_tc) and \
        np.allclose(nll_vec_tc, nll_torch)
    print("\tSame result? %s" % same_result)
    # A test that only prints its verdict can never fail; enforce it
    assert same_result, "Multinomial NLL implementations disagree"
Example #3
0
def _get_profile_model_predictions_batch(
        model,
        coords,
        num_tasks,
        input_func,
        controls=None,
        fourier_att_prior_freq_limit=200,
        fourier_att_prior_freq_limit_softness=0.2,
        att_prior_grad_smooth_sigma=3,
        return_losses=False,
        return_gradients=False):
    """
    Fetches the necessary data from the given coordinates and runs it through
    a profile model. This will perform computation in a single batch.
    Arguments:
        `model`: a trained `ProfilePredictorWithMatchedControls`,
            `ProfilePredictorWithSharedControls`, or
            `ProfilePredictorWithoutControls`
        `coords`: a B x 3 array of coordinates to compute outputs for
        `num_tasks`: number of tasks for the model
        `input_func`: a function that takes in `coords` and returns the
            B x I x 4 array of one-hot sequences and the
            B x (T or T + 1 or 2T) x O x S array of profiles (perhaps with
            controls)
        `controls`: the type of control profiles (if any) used in model; can be
            "matched" (each task has a matched control), "shared" (all tasks
            share a control), or None (no controls); must match the model class
        `fourier_att_prior_freq_limit`: limit for frequencies in Fourier prior
            loss
        `fourier_att_prior_freq_limit_softness`: degree of softness for limit
        `att_prior_grad_smooth_sigma`: width of smoothing kernel for gradients
        `return_losses`: if True, compute/return the loss values
        `return_gradients`: if True, compute/return the input gradients and
            sequences
    Returns a dictionary of the following structure:
        true_profs: true profile raw counts (B x T x O x S)
        log_pred_profs: predicted profile log probabilities (B x T x O x S)
        true_counts: true total counts (B x T x S)
        log_pred_counts: predicted log counts (B x T x S)
        prof_losses: profile NLL losses (B-array), if `return_losses` is True
        count_losses: counts MSE losses (B-array) if `return_losses` is True
        att_losses: prior losses (B-array), if `return_losses` is True
        input_seqs: one-hot input sequences (B x I x 4), if `return_gradients`
            is true
        input_grads: "hypothetical" input gradients (B x I x 4), if
            `return_gradients` is true
    """
    result = {}
    input_seqs, profiles = input_func(coords)
    if return_gradients:
        input_seqs_np = input_seqs
        # Zero out weights because we are computing gradients
        model.zero_grad()
    input_seqs = model_util.place_tensor(torch.tensor(input_seqs)).float()
    profiles = model_util.place_tensor(torch.tensor(profiles)).float()

    if controls is not None:
        tf_profs = profiles[:, :num_tasks, :, :]
        cont_profs = profiles[:, num_tasks:, :, :]  # Last half or just one
    else:
        tf_profs, cont_profs = profiles, None

    if return_losses or return_gradients:
        input_seqs.requires_grad = True  # Set gradient required
        logit_pred_profs, log_pred_counts = model(input_seqs, cont_profs)

        # Subtract mean along output profile dimension; this wouldn't change
        # softmax probabilities, but normalizes the magnitude of gradients
        norm_logit_pred_profs = logit_pred_profs - \
            torch.mean(logit_pred_profs, dim=2, keepdim=True)

        # Weight by post-softmax probabilities, but do not take the
        # gradients of these probabilities; this upweights important regions
        # exponentially. `profile_logits_to_log_probs` returns LOG
        # probabilities, so exponentiate them; the (non-positive) log
        # probabilities would instead favor the least likely positions
        pred_prof_probs = torch.exp(
            profile_models.profile_logits_to_log_probs(logit_pred_profs)
        ).detach()
        weighted_norm_logits = norm_logit_pred_profs * pred_prof_probs

        # Gradients are summed across strands and tasks; we'll be operating on
        # the gradient itself (prior loss), so the graph must be created
        input_grads, = torch.autograd.grad(
            weighted_norm_logits,
            input_seqs,
            grad_outputs=model_util.place_tensor(
                torch.ones(weighted_norm_logits.size())),
            retain_graph=True,
            create_graph=True
        )
        input_grads_np = input_grads.detach().cpu().numpy()
        input_seqs.requires_grad = False  # Reset gradient required
    else:
        logit_pred_profs, log_pred_counts = model(input_seqs, cont_profs)

    result["true_profs"] = tf_profs.detach().cpu().numpy()
    result["true_counts"] = np.sum(result["true_profs"], axis=2)
    logit_pred_profs_np = logit_pred_profs.detach().cpu().numpy()
    result["log_pred_profs"] = profile_models.profile_logits_to_log_probs(
        logit_pred_profs_np)
    result["log_pred_counts"] = log_pred_counts.detach().cpu().numpy()

    if return_losses:
        log_pred_profs = profile_models.profile_logits_to_log_probs(
            logit_pred_profs)
        num_samples = log_pred_profs.size(0)
        result["prof_losses"] = np.empty(num_samples)
        result["count_losses"] = np.empty(num_samples)
        result["att_losses"] = np.empty(num_samples)

        # Compute losses separately for each example
        for i in range(num_samples):
            _, prof_loss, count_loss = model.correctness_loss(
                tf_profs[i:i + 1],
                log_pred_profs[i:i + 1],
                log_pred_counts[i:i + 1],
                1,
                return_separate_losses=True)
            # NOTE(review): the prior loss here uses the raw input gradients;
            # confirm whether gradient * input is expected, as in training
            att_loss = model.fourier_att_prior_loss(
                model_util.place_tensor(torch.ones(1)), input_grads[i:i + 1],
                fourier_att_prior_freq_limit,
                fourier_att_prior_freq_limit_softness,
                att_prior_grad_smooth_sigma)
            result["prof_losses"][i] = prof_loss
            result["count_losses"][i] = count_loss
            result["att_losses"][i] = att_loss

    if return_gradients:
        result["input_seqs"] = input_seqs_np
        result["input_grads"] = input_grads_np

    return result
def run_epoch(
    data_loader,
    mode,
    model,
    epoch_num,
    params,
    optimizer=None,
    return_data=False,
    seq_mode=True,
    att_prior_loss_weight=None,
):
    """
    Runs the data from the data loader once through the model, to train,
    validate, or predict.
    Arguments:
        `data_loader`: an instantiated `DataLoader` instance that gives batches
            of data; each batch must yield the input sequences, profiles,
            a second set of profiles (`profiles_trans`, split like `profiles`
            and passed to the model as its third input), statuses,
            coordinates, and peaks; if `controls` is "matched", profiles must
            be such that the first half are prediction (target) profiles, and
            the second half are control profiles; if `controls` is "shared",
            the last set of profiles is a shared control; otherwise, all tasks
            are prediction profiles
        `mode`: one of "train", "eval"; if "train", run the epoch and perform
            backpropagation; if "eval", only do evaluation
        `model`: the current PyTorch model being trained/evaluated
        `epoch_num`: 0-indexed integer representing the current epoch
        `params`: dictionary of hyperparameters; reads "num_tasks",
            "controls", "batch_size", "revcomp", "input_length",
            "input_depth", "profile_length", optionally "gpu_id", and
            "att_prior_loss_weight" when `att_prior_loss_weight` is None
        `optimizer`: an instantiated PyTorch optimizer, for training mode
        `return_data`: if specified, returns the following as NumPy arrays:
            true profile raw counts (N x T x O x S), predicted profile log
            probabilities (N x T x O x S), true total counts (N x T x S),
            predicted log counts (N x T x S), coordinates used (N x 3 object
            array), raw input gradients (N x I x D), input sequences
            (N x I x D), true `profiles_trans` targets (N x T x O x S) and
            their total counts (N x T x S), where D is `input_depth`; if the
            attribution prior is not used, the gradients will be garbage
        `seq_mode`: if True, calls `model.set_seq_mode()` before the epoch;
            otherwise calls `model.set_aux_mode()`
        `att_prior_loss_weight`: weight of the attribution prior loss; if None,
            it is read from `params`; input gradients are computed only when
            it is positive
    Returns lists of overall losses, correctness losses, attribution prior
    losses, profile losses, and count losses, where each list is over all
    batches. If the attribution prior loss is not computed, then it will be all
    0s. If `return_data` is True, then the arrays listed above are returned
    after these, in that order.
    """
    num_tasks = params["num_tasks"]
    controls = params["controls"]
    if att_prior_loss_weight is None:
        att_prior_loss_weight = params["att_prior_loss_weight"]
    batch_size = params["batch_size"]
    revcomp = params["revcomp"]
    input_length = params["input_length"]
    input_depth = params["input_depth"]
    profile_length = params["profile_length"]
    gpu_id = params.get("gpu_id")
    # Single source of truth for whether input gradients are computed; every
    # branch below must agree, or `input_grads_np` may be used while unset
    use_att_prior = att_prior_loss_weight > 0

    assert mode in ("train", "eval")
    if mode == "train":
        assert optimizer is not None
    else:
        assert optimizer is None

    data_loader.dataset.on_epoch_start()  # Set-up the epoch
    num_batches = len(data_loader.dataset)
    t_iter = tqdm.tqdm(data_loader, total=num_batches, desc="\tLoss: ---")

    if mode == "train":
        model.train()  # Switch to training mode
        torch.set_grad_enabled(True)
    else:
        # Switch to evaluation mode; otherwise a model left in training mode
        # by a previous epoch would be evaluated with training-time behavior.
        # Gradient tracking stays on: the attribution prior needs gradients
        model.eval()

    batch_losses, corr_losses, att_losses = [], [], []
    prof_losses, count_losses = [], []
    if return_data:
        # Allocate empty NumPy arrays to hold the results
        num_samples_exp = num_batches * batch_size
        num_samples_exp *= 2 if revcomp else 1
        # Real number of samples can be smaller because of partial last batch
        profile_shape = (num_samples_exp, num_tasks, profile_length, 2)
        count_shape = (num_samples_exp, num_tasks, 2)
        all_log_pred_profs = np.empty(profile_shape)
        all_log_pred_counts = np.empty(count_shape)
        all_true_profs = np.empty(profile_shape)
        all_true_counts = np.empty(count_shape)
        all_true_profs_trans = np.empty(profile_shape)
        all_true_counts_trans = np.empty(count_shape)
        all_input_seqs = np.empty((num_samples_exp, input_length, input_depth))
        all_input_grads = np.empty(
            (num_samples_exp, input_length, input_depth))
        all_coords = np.empty((num_samples_exp, 3), dtype=object)
        num_samples_seen = 0  # Real number of samples seen

    if seq_mode:
        model.set_seq_mode()
    else:
        model.set_aux_mode()

    for input_seqs, profiles, profiles_trans, statuses, coords, peaks in t_iter:
        if return_data:
            input_seqs_np = input_seqs
        input_seqs = util.place_tensor(torch.tensor(input_seqs),
                                       index=gpu_id).float()
        profiles = util.place_tensor(torch.tensor(profiles),
                                     index=gpu_id).float()
        profiles_trans = util.place_tensor(torch.tensor(profiles_trans),
                                           index=gpu_id).float()

        if controls is not None:
            tf_profs = profiles[:, :num_tasks, :, :].contiguous()
            # Last half or just one
            cont_profs = profiles[:, num_tasks:, :, :].contiguous()
            tf_profs_trans = profiles_trans[:, :num_tasks, :, :].contiguous()
            # Last half or just one
            cont_profs_trans = \
                profiles_trans[:, num_tasks:, :, :].contiguous()
        else:
            # No controls: every track is a prediction target. Previously the
            # untransformed `profiles` were used here by mistake
            tf_profs, cont_profs = profiles, None
            tf_profs_trans, cont_profs_trans = profiles_trans, None

        # Clear gradients from last batch if training
        if mode == "train":
            optimizer.zero_grad()
        elif use_att_prior:
            # Not training mode, but we still need to zero out weights because
            # we are computing the input gradients
            model.zero_grad()

        if use_att_prior:
            input_seqs.requires_grad = True  # Set gradient required
            logit_pred_profs, log_pred_counts = model(input_seqs, cont_profs,
                                                      tf_profs_trans)

            # Subtract mean along output profile dimension; this wouldn't change
            # softmax probabilities, but normalizes the magnitude of gradients
            norm_logit_pred_profs = logit_pred_profs - \
                torch.mean(logit_pred_profs, dim=2, keepdim=True)

            # Weight by post-softmax probabilities, but do not take the
            # gradients of these probabilities; this upweights important regions
            # exponentially. `profile_logits_to_log_probs` returns LOG
            # probabilities, so exponentiate them; the (non-positive) log
            # probabilities would instead favor the least likely positions
            pred_prof_probs = torch.exp(
                profile_models.profile_logits_to_log_probs(logit_pred_profs)
            ).detach()
            weighted_norm_logits = norm_logit_pred_profs * pred_prof_probs

            if seq_mode:
                model.set_seq_mode()

            # Gradients are summed across strands and tasks; we'll be operating
            # on the gradient itself, so the graph must be created
            input_grads, = torch.autograd.grad(
                weighted_norm_logits,
                input_seqs,
                grad_outputs=util.place_tensor(
                    torch.ones(weighted_norm_logits.size()), index=gpu_id),
                retain_graph=True,
                create_graph=True
            )
            if return_data:
                # Saved before the gradient * input step below
                input_grads_np = input_grads.detach().cpu().numpy()
            input_grads = input_grads * input_seqs  # Gradient * input
            status = util.place_tensor(torch.tensor(statuses), index=gpu_id)
            status[status != 0] = 1  # Set to 1 if not negative example
            input_seqs.requires_grad = False  # Reset gradient required
        else:
            logit_pred_profs, log_pred_counts = model(input_seqs, cont_profs,
                                                      tf_profs_trans)
            status, input_grads = None, None

        loss, (corr_loss, att_loss), (prof_loss, count_loss) = model_loss(
            model,
            tf_profs,
            logit_pred_profs,
            log_pred_counts,
            epoch_num,
            status=status,
            input_grads=input_grads)

        if mode == "train":
            loss.backward()  # Compute gradient
            optimizer.step()  # Update weights through backprop

        batch_losses.append(loss.item())
        corr_losses.append(corr_loss.item())
        att_losses.append(att_loss.item())
        prof_losses.append(prof_loss.item())
        count_losses.append(count_loss.item())
        t_iter.set_description("\tLoss: %6.4f" % loss.item())

        if return_data:
            logit_pred_profs_np = logit_pred_profs.detach().cpu().numpy()
            log_pred_counts_np = log_pred_counts.detach().cpu().numpy()
            true_profs_np = tf_profs.detach().cpu().numpy()
            true_counts = np.sum(true_profs_np, axis=2)
            true_profs_trans_np = tf_profs_trans.detach().cpu().numpy()
            true_counts_trans = np.sum(true_profs_trans_np, axis=2)

            num_in_batch = true_counts.shape[0]

            # Turn logit profile predictions into log probabilities
            log_pred_profs = profile_models.profile_logits_to_log_probs(
                logit_pred_profs_np, axis=2)

            # Fill in the batch data/outputs into the preallocated arrays
            start, end = num_samples_seen, num_samples_seen + num_in_batch
            all_log_pred_profs[start:end] = log_pred_profs
            all_log_pred_counts[start:end] = log_pred_counts_np
            all_true_profs[start:end] = true_profs_np
            all_true_counts[start:end] = true_counts
            all_true_profs_trans[start:end] = true_profs_trans_np
            all_true_counts_trans[start:end] = true_counts_trans
            all_input_seqs[start:end] = input_seqs_np
            if use_att_prior:
                all_input_grads[start:end] = input_grads_np
            all_coords[start:end] = coords

            num_samples_seen += num_in_batch

    if return_data:
        # Truncate the saved data to the proper size, based on how many
        # samples actually seen
        all_log_pred_profs = all_log_pred_profs[:num_samples_seen]
        all_log_pred_counts = all_log_pred_counts[:num_samples_seen]
        all_true_profs = all_true_profs[:num_samples_seen]
        all_true_counts = all_true_counts[:num_samples_seen]
        all_true_profs_trans = all_true_profs_trans[:num_samples_seen]
        all_true_counts_trans = all_true_counts_trans[:num_samples_seen]
        all_input_seqs = all_input_seqs[:num_samples_seen]
        all_input_grads = all_input_grads[:num_samples_seen]
        all_coords = all_coords[:num_samples_seen]
        return batch_losses, corr_losses, att_losses, prof_losses, \
            count_losses, all_true_profs, all_log_pred_profs, \
            all_true_counts, all_log_pred_counts, all_coords, all_input_grads, \
            all_input_seqs, all_true_profs_trans, all_true_counts_trans
    else:
        return batch_losses, corr_losses, att_losses, prof_losses, count_losses