Example #1
0
def alignment_score_kernel(true_states: str,
                           pred_states: str,
                           kernel_widths: list,
                           query_offset: int = 0,
                           hit_offset: int = 0,
                           no_gaps=True):
    """ Computes ROC statistics on alignment for several kernel widths.

    Parameters
    ----------
    true_states : str
        Ground truth state string
    pred_states : str
        Predicted state string
    kernel_widths : list
        Widths passed one at a time to ``roc_edges_kernel_identity``.
    query_offset : int
        Added to the first coordinate of every predicted edge, to account
        for local alignments.
    hit_offset : int
        Added to the second coordinate of every predicted edge.
    no_gaps : bool
        If True, both edge lists are passed through ``filter_gaps`` before
        scoring.

    Returns
    -------
    list
        One ``roc_edges_kernel_identity`` result per entry of
        ``kernel_widths``, in the same order.
    """
    pred_states = list(map(tmstate_f, pred_states))
    true_states = list(map(tmstate_f, true_states))
    pred_edges = states2edges(pred_states)
    true_edges = states2edges(true_states)
    true_edges = list(map(tuple, np.array(true_edges)))
    # Add offsets to the predicted edges to account for local alignments.
    # An empty edge list becomes a 1-D ``np.array([])``, where column
    # indexing would raise IndexError, so the offset step is skipped then.
    if len(pred_edges) > 0:
        pred_edges = np.array(pred_edges)
        pred_edges[:, 0] += query_offset
        pred_edges[:, 1] += hit_offset
    pred_edges = list(map(tuple, pred_edges))
    if no_gaps:
        pred_edges = filter_gaps(pred_states, pred_edges)
        true_edges = filter_gaps(true_states, true_edges)

    return [roc_edges_kernel_identity(true_edges, pred_edges, k)
            for k in kernel_widths]
Example #2
0
def alignment_score(true_states: str, pred_states: str):
    """
    Computes ROC statistics on alignment

    Each state string is mapped character-by-character through
    ``tmstate_f``, converted to edges with ``states2edges``, and the two
    edge lists are compared with ``roc_edges``.

    Parameters
    ----------
    true_states : str
        Ground truth state string
    pred_states : str
        Predicted state string

    Returns
    -------
    The value returned by ``roc_edges(true_edges, pred_edges)``.
    """
    pred_codes = [tmstate_f(ch) for ch in pred_states]
    true_codes = [tmstate_f(ch) for ch in true_states]
    pred_path, true_path = (states2edges(codes)
                            for codes in (pred_codes, true_codes))
    return roc_edges(true_path, pred_path)
Example #3
0
 def validation_stats(self, x, y, xlen, ylen, gen, states, A, predA, theta,
                      gap, batch_idx):
     """ Computes ROC statistics for every alignment in a validation batch.

     For a random fraction (``self.hparams.visualization_fraction``) of the
     examples, the alignment matrices and a text rendering of the alignment
     are also logged through ``self.logger.experiment``.

     Parameters
     ----------
     x, y : tensor
         Encoded query / hit sequences, sliced as ``x[b, :xlen[b]]``.
     xlen, ylen : sequence
         Per-example sequence lengths; ``len(xlen)`` is the batch size.
     gen : iterator
         Yields one ``(decoded, _)`` pair per example, where ``decoded`` is
         a sequence of 3-tuples whose last entry is the predicted state.
     states : tensor
         Ground-truth states, indexed per example as ``states[b]``.
     A, predA, theta, gap : tensor
         Per-example matrices used only for visualization / error dumps.
     batch_idx : int
         Batch index, used in the logged tag names.

     Returns
     -------
     list
         One ``roc_edges`` result per example, in batch order.
     """
     statistics = []
     for b in range(len(xlen)):
         decoded, _ = next(gen)
         # Only the predicted states are needed; the other two columns of
         # ``decoded`` are discarded.
         _, _, pred_states = zip(*decoded)
         pred_states = np.array(list(pred_states))
         true_states = states[b].cpu().detach().numpy()
         pred_edges = filter_gaps(pred_states, states2edges(pred_states))
         true_edges = filter_gaps(true_states, states2edges(true_states))
         stats = roc_edges(true_edges, pred_edges)
         if random.random() < self.hparams.visualization_fraction:
             # The decoded strings are only needed for the text rendering,
             # so decoding is skipped for examples that are not visualized.
             # TODO: Issue #47
             # NOTE(review): if x is (batch, length), .squeeze() turns a
             # length-1 slice into a 0-d array and list() would fail —
             # confirm the expected shape of x / y.
             x_str = decode(
                 list(x[b, :xlen[b]].squeeze().cpu().detach().numpy()),
                 self.tokenizer.alphabet)
             y_str = decode(
                 list(y[b, :ylen[b]].squeeze().cpu().detach().numpy()),
                 self.tokenizer.alphabet)
             Av = A[b].cpu().detach().numpy().squeeze()
             pv = predA[b].cpu().detach().numpy().squeeze()
             tv = theta[b].cpu().detach().numpy().squeeze()
             gv = gap[b].cpu().detach().numpy().squeeze()
             fig, _ = alignment_visualization(Av, pv, tv, gv, xlen[b],
                                              ylen[b])
             self.logger.experiment.add_figure(
                 f'alignment-matrix/{batch_idx}/{b}',
                 fig,
                 self.global_step,
                 close=True)
             try:
                 text = alignment_text(x_str, y_str, pred_states,
                                       true_states, stats)
                 self.logger.experiment.add_text(
                     f'alignment/{batch_idx}/{b}', text, self.global_step)
             except Exception:
                 # Dump the inputs that produced the failure, then re-raise
                 # with the original traceback intact.
                 print(predA[b])
                 print(A[b])
                 print(theta[b])
                 print(xlen[b], ylen[b])
                 raise
         statistics.append(stats)
     return statistics
Example #4
0
    def __getitem__(self, i):
        """ Gets alignment pair.

        Parameters
        ----------
        i : int
           Index of item

        Returns
        -------
        gene : torch.Tensor
           Encoded representation of protein of interest
        pos : torch.Tensor
           Encoded representation of protein that aligns with `gene`.
        states : torch.Tensor
           Alignment string
        alignment_matrix : torch.Tensor
           Ground truth alignment matrix
        path_matrix : torch.Tensor
           Pairwise path distances, where the smallest distance
           to the path is computed for every element in the matrix.
           All zeros when ``self.construct_paths`` is False.
        g_mask : torch.Tensor
           Gap mask from ``gap_mask``; all ones when ``self.mask_gaps``
           is False.
        gene_name, pos_name
           Only returned when ``self.return_names`` is True; taken from
           the ``chain1_name`` / ``chain2_name`` columns.
        """
        # Look the row up once instead of re-indexing the frame per column.
        row = self.pairs.iloc[i]
        gene = row['chain1']
        pos = row['chain2']
        st = row['alignment']

        states = list(map(tmstate_f, st))
        if self.clip_ends:
            gene, pos, states, st = clip_boundaries(gene, pos, states, st)

        if self.pad_ends:
            states = [m] + states + [m]

        states = torch.Tensor(states).long()
        gene = torch.Tensor(self.tokenizer(str.encode(gene))).long()
        pos = torch.Tensor(self.tokenizer(str.encode(pos))).long()
        alignment_matrix = torch.from_numpy(states2matrix(states))
        # Placeholders for the optional outputs. zeros (not empty) so the
        # returned tensor never exposes uninitialized memory.
        path_matrix = torch.zeros(*alignment_matrix.shape)
        g_mask = torch.ones(*alignment_matrix.shape)
        if self.construct_paths:
            pi = states2edges(states)
            path_matrix = torch.from_numpy(path_distance_matrix(pi))
        if self.mask_gaps:
            g_mask = torch.from_numpy(gap_mask(st)).bool()

        # Reshape every matrix the same way, placeholders included, so the
        # returned tensors have consistent shapes in both configurations.
        alignment_matrix = reshape(alignment_matrix, len(gene), len(pos))
        path_matrix = reshape(path_matrix, len(gene), len(pos))
        g_mask = reshape(g_mask, len(gene), len(pos))
        if not self.return_names:
            return gene, pos, states, alignment_matrix, path_matrix, g_mask
        gene_name = row['chain1_name']
        pos_name = row['chain2_name']
        return (gene, pos, states, alignment_matrix, path_matrix, g_mask,
                gene_name, pos_name)
Example #5
0
def alignment_score(true_states: str, pred_states: str, no_gaps=True):
    """ Computes ROC statistics on alignment.

    Both state strings are mapped through ``tmstate_f`` and turned into
    edge lists with ``states2edges``; the edge lists are then compared
    with ``roc_edges``.

    Parameters
    ----------
    true_states : str
        Ground truth state string
    pred_states : str
        Predicted state string
    no_gaps : bool
        Removes all gaps before computing ROC stats.

    Returns
    -------
    The value returned by ``roc_edges(true_edges, pred_edges)``.
    """
    pred_codes = [tmstate_f(ch) for ch in pred_states]
    true_codes = [tmstate_f(ch) for ch in true_states]
    pred_path = states2edges(pred_codes)
    true_path = states2edges(true_codes)
    if not no_gaps:
        return roc_edges(true_path, pred_path)
    pred_path = filter_gaps(pred_codes, pred_path)
    true_path = filter_gaps(true_codes, true_path)
    return roc_edges(true_path, pred_path)