예제 #1
0
    def __getitem__(self, i):
        """ Gets alignment pair.

        Parameters
        ----------
        i : int
           Positional index of the item (row ``i`` of ``self.pairs``).

        Returns
        -------
        gene : torch.Tensor
           Encoded representation of protein of interest, with gap
           characters ('-') removed before tokenization.
        pos : torch.Tensor
           Encoded representation of protein that aligns with `gene`,
           with gap characters ('-') removed before tokenization.
        states : torch.Tensor
           Alignment states, one per aligned column, obtained by applying
           ``state_f`` to each (gene, pos) character pair.
        alignment_matrix : torch.Tensor
           Ground truth alignment matrix built from ``states``.

        Raises
        ------
        ValueError
           If the two aligned strings have different lengths.
        """
        # Positional row lookup so indices 0..len-1 work even when
        # ``self.pairs`` has a non-default index; columns are still
        # addressed by their labels 0 and 1.
        row = self.pairs.iloc[i]
        gene = row[0]
        pos = row[1]
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if len(gene) != len(pos):
            raise ValueError(
                f"Aligned sequences differ in length at index {i}: "
                f"{len(gene)} != {len(pos)}")
        # States are derived column-by-column from the gapped alignment,
        # before the gaps are stripped below.
        states = torch.Tensor(list(map(state_f, zip(gene, pos))))
        gene = self.tokenizer(str.encode(gene.replace('-', '')))
        pos = self.tokenizer(str.encode(pos.replace('-', '')))
        gene = torch.Tensor(gene).long()
        pos = torch.Tensor(pos).long()
        alignment_matrix = torch.from_numpy(states2matrix(states))
        return gene, pos, states, alignment_matrix
예제 #2
0
 def test_states2matrix_swap_y(self):
     # Encoding checked below: ':' -> 1, '1' -> 0, '2' -> 2.
     encoded = np.array([tmstate_f(ch) for ch in "::1122::"])
     npt.assert_allclose(encoded, np.array([1, 1, 0, 0, 2, 2, 1, 1]))
     coo = states2matrix(encoded, sparse=True)
     # Pair up the sparse matrix's row/col indices in storage order.
     observed = [(r, c) for r, c in zip(coo.row, coo.col)]
     expected = [
         (0, 0), (1, 1), (2, 1), (3, 1),
         (3, 2), (3, 3), (4, 4), (5, 5),
     ]
     self.assertListEqual(observed, expected)
예제 #3
0
 def test_states2matrix_only_matches(self):
     # Only ':' and '1' appear; they encode to 1 and 0 respectively.
     alignment = ":11::11:"
     encoded = np.array([tmstate_f(ch) for ch in alignment])
     npt.assert_allclose(encoded, np.array([1, 0, 0, 1, 1, 0, 0, 1]))
     coo = states2matrix(encoded, sparse=True)
     # Pair up the sparse matrix's row/col indices in storage order.
     observed = [(r, c) for r, c in zip(coo.row, coo.col)]
     expected = [
         (0, 0), (1, 0), (2, 0), (3, 1),
         (4, 2), (5, 2), (6, 2), (7, 3),
     ]
     self.assertListEqual(observed, expected)
예제 #4
0
    def __getitem__(self, i):
        """ Gets alignment pair.

        Parameters
        ----------
        i : int
           Positional index of item in ``self.pairs``.

        Returns
        -------
        gene : torch.Tensor
           Encoded representation of protein of interest
        pos : torch.Tensor
           Encoded representation of protein that aligns with `gene`.
        states : torch.Tensor
           Alignment states (long), one per character of the row's
           ``alignment`` string; clipped by ``clip_boundaries`` when
           ``self.clip_ends`` is set and padded with ``m`` at both ends
           when ``self.pad_ends`` is set.
        alignment_matrix : torch.Tensor
           Ground truth alignment matrix
        path_matrix : torch.Tensor
           Pairwise path distances, where the smallest distance
           to the path is computed for every element in the matrix.
           Only computed when ``self.construct_paths`` is set; otherwise
           a zero-filled placeholder with the alignment matrix's shape.
        g_mask : torch.Tensor
           Gap mask from ``gap_mask`` (cast to bool) when
           ``self.mask_gaps`` is set; otherwise a tensor of ones.
        gene_name, pos_name
           Only returned when ``self.return_names`` is set: the row's
           ``chain1_name`` and ``chain2_name`` values.
        """
        # Fetch the row once instead of re-indexing the frame per column.
        row = self.pairs.iloc[i]
        gene = row['chain1']
        pos = row['chain2']
        st = row['alignment']

        states = list(map(tmstate_f, st))
        if self.clip_ends:
            gene, pos, states, st = clip_boundaries(gene, pos, states, st)

        if self.pad_ends:
            # ``m`` is a module-level state constant defined elsewhere
            # (presumably the match state) -- confirm.
            states = [m] + states + [m]

        states = torch.Tensor(states).long()
        gene = self.tokenizer(str.encode(gene))
        pos = self.tokenizer(str.encode(pos))
        gene = torch.Tensor(gene).long()
        pos = torch.Tensor(pos).long()
        alignment_matrix = torch.from_numpy(states2matrix(states))
        # Placeholder for when paths are not constructed. ``torch.zeros``
        # rather than ``torch.empty``: ``empty`` returns uninitialized memory,
        # which would leak arbitrary, non-deterministic values into the batch.
        path_matrix = torch.zeros(*alignment_matrix.shape)
        g_mask = torch.ones(*alignment_matrix.shape)
        if self.construct_paths:
            pi = states2edges(states)
            path_matrix = torch.from_numpy(path_distance_matrix(pi))
            path_matrix = reshape(path_matrix, len(gene), len(pos))
        if self.mask_gaps:
            g_mask = torch.from_numpy(gap_mask(st)).bool()

        alignment_matrix = reshape(alignment_matrix, len(gene), len(pos))
        g_mask = reshape(g_mask, len(gene), len(pos))
        if not self.return_names:
            return gene, pos, states, alignment_matrix, path_matrix, g_mask
        gene_name = row['chain1_name']
        pos_name = row['chain2_name']
        return (gene, pos, states, alignment_matrix, path_matrix, g_mask,
                gene_name, pos_name)
예제 #5
0
 def test_states2matrix_zinc(self):
     # Previously this only called states2matrix and asserted nothing;
     # it now checks the encoding and the size of the resulting path.
     s = ':1111::::1:'
     # Notes kept from the original test (not used by the assertions):
     # x = 'RGCFH '
     # y = 'YGSVHASERH'
     s = np.array(list(map(tmstate_f, s)))
     # ':' encodes to 1 and '1' to 0, as in the other states2matrix tests.
     npt.assert_allclose(s, np.array([1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1]))
     M = states2matrix(s, sparse=True)
     # Each state contributes exactly one path coordinate (as in the
     # expected coordinate lists of the other states2matrix tests).
     self.assertEqual(len(M.row), len(s))
     self.assertEqual(len(M.col), len(s))