Beispiel #1
0
    def __getitem__(self, index):
        """Build model inputs for one protein, truncated to at most 64 residues.

        The residue sequence is perturbed with random substitutions. The
        entries 0, 1 and 3 along the first axis of ``self.ters`` (presumably
        backbone atoms -- confirm) are passed through ``self.backrub``, and the
        pairwise distances between the ``[:, 1]`` rows of its first output
        (scaled by 1/100) are computed with ``scatter.pairwise_no_pad``.

        Returns:
            ``(PackedTensor(angles^T), PackedTensor(distances[:, None]),
            protein)``.
        """
        begin = self.index[index]
        window = slice(begin, min(self.index[index + 1], begin + 64))
        inds = self.inds[window]  # not used further
        primary = self.pris[window] - 1

        # Random substitutions. Replacement residues are drawn before the
        # positions, mirroring Python's right-hand-side-first evaluation of
        # ``a[i] = v``, so the random streams are consumed in the same order.
        length = primary.size(0)
        n_noisy = random.randrange(max(1, length // 100))
        replacement = torch.randint(0, 20, (n_noisy, ))
        where = torch.randint(0, length, (n_noisy, ))
        primary[where] = replacement

        backbone = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
        coords, angles = self.backrub(backbone)
        coords = coords[:, 1] / 100

        protein = SubgraphStructure(
            torch.zeros(coords.size(0), dtype=torch.long))
        distances, _ = scatter.pairwise_no_pad(
            lambda a, b: (a - b).norm(dim=1), coords, protein.indices)
        distances = distances.unsqueeze(1)

        # One-hot sequence encoding; computed but not part of the output.
        primary_onehot = torch.zeros(length, 20, dtype=torch.float)
        primary_onehot[torch.arange(length), primary] = 1
        primary_onehot = primary_onehot.clamp(0, 1)

        return (PackedTensor(angles.permute(1, 0).contiguous()),
                PackedTensor(distances), protein)
Beispiel #2
0
    def __getitem__(self, index):
        """Return ``(PackedTensor(coords), PackedTensor(onehot), protein)``.

        Uses the complete residue range of protein ``index`` (no truncation)
        and the raw ``self.ters`` values, permuted to residue-first order and
        scaled by 1/100. The sequence gets random substitutions before being
        one-hot encoded.
        """
        first, last = self.index[index], self.index[index + 1]
        window = slice(first, last)
        inds = self.inds[window]
        primary = self.pris[window] - 1

        # Replacement values are drawn before positions (RHS-first order).
        length = primary.size(0)
        n_noisy = random.randrange(max(1, length // 100))
        replacement = torch.randint(0, 20, (n_noisy, ))
        where = torch.randint(0, length, (n_noisy, ))
        primary[where] = replacement

        coords = self.ters[:, :, window].permute(2, 0, 1) / 100

        protein = SubgraphStructure(
            torch.zeros(coords.size(0), dtype=torch.long))
        # Neighbour indices re-based to this protein's first residue.
        neighbours = ConstantStructure(0, 0, (inds - first).to(torch.long))

        primary_onehot = torch.zeros(length, 20, dtype=torch.float)
        primary_onehot[torch.arange(length), primary] = 1
        primary_onehot = primary_onehot.clamp(0, 1)

        # Every neighbour must point inside this protein's window.
        assert neighbours.connections.max() < primary_onehot.size(0)
        return (PackedTensor(coords), PackedTensor(primary_onehot), protein)
Beispiel #3
0
    def __getitem__(self, index):
        """Return autoregressive-model inputs plus the clean angle target.

        The window covers at most 500 residues of protein ``index``.

        Returns:
            ``(inputs, target)`` where ``inputs`` is
            ``(PackedTensor(angles), PackedTensor(distances), neighbours,
            protein)`` and ``target`` is an unsplit ``PackedTensor`` holding a
            clone of ``angles``.
        """
        head = self.index[index]
        window = slice(head, min(self.index[index + 1], head + 500))
        inds = self.inds[window]  # not used further
        primary = self.pris[window] - 1

        # Replacement values are drawn before positions (RHS-first order).
        length = primary.size(0)
        n_noisy = random.randrange(max(1, length // 100))
        replacement = torch.randint(0, 20, (n_noisy, ))
        where = torch.randint(0, length, (n_noisy, ))
        primary[where] = replacement

        backbone = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
        distances, angles = self.backrub(backbone)
        angles = angles.permute(1, 0).contiguous()
        # Target copy, independent of the tensor handed to the model.
        angles_gt = angles.clone()
        distances = distances / 100

        protein = SubgraphStructure(
            torch.zeros(distances.size(0), dtype=torch.long))
        neighbours = self.autoregressive_structure(distances)

        primary_onehot = torch.zeros(length, 20, dtype=torch.float)
        primary_onehot[torch.arange(length), primary] = 1
        primary_onehot = primary_onehot.clamp(0, 1)

        assert neighbours.connections.max() < primary_onehot.size(0)
        inputs = (PackedTensor(angles), PackedTensor(distances), neighbours,
                  protein)
        target = PackedTensor(angles_gt, split=False)
        return inputs, target
Beispiel #4
0
    def __getitem__(self, index):
        """Return ``(pairwise_distances / 100,)`` for dataset item ``index``.

        ``index`` is mapped through ``self.valid_indices``; the window covers at
        most ``self.N`` residues. The first output of ``self.backrub`` is
        flattened into a ``(-1, 3)`` point list and scaled by 1/100. The
        returned map has shape ``(1, P, P)`` for ``P`` points and is divided by
        100 once more.
        """
        limit = self.N
        protein_id = self.valid_indices[index]
        lo = self.index[protein_id]
        window = slice(lo, min(self.index[protein_id + 1], lo + limit))
        inds = self.inds[window]  # not used further
        primary = self.pris[window] - 1

        # Replacement values are drawn before positions (RHS-first order).
        length = primary.size(0)
        n_noisy = random.randrange(max(1, length // 100))
        replacement = torch.randint(0, 20, (n_noisy, ))
        where = torch.randint(0, length, (n_noisy, ))
        primary[where] = replacement

        backbone = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
        points, angles = self.backrub(backbone)
        points = points.reshape(-1, 3) / 100

        protein = SubgraphStructure(
            torch.zeros(points.size(0), dtype=torch.long))
        diff = points.unsqueeze(0) - points.unsqueeze(1)
        distances = diff.norm(dim=-1).unsqueeze(0)

        # One-hot sequence encoding; computed but not part of the output.
        primary_onehot = torch.zeros(length, 20, dtype=torch.float)
        primary_onehot[torch.arange(length), primary] = 1
        primary_onehot = primary_onehot.clamp(0, 1)

        return (distances / 100, )
    def __getitem__(self, index):
        """Return GAN inputs and targets for protein ``index``.

        Proteins with fewer than 30 residues are skipped in favour of the next
        protein (wrapping around at ``len(self)``).

        Returns:
            ``(inputs, target, structure)`` where
            ``inputs = (PackedTensor(angles), protein)``, ``target`` is
            ``angles`` as an unsplit ``PackedTensor`` and
            ``structure = (PackedTensor(distances), PackedTensor(ones),
            protein)`` with both tensors unsplit.

        Raises:
            RuntimeError: if no protein in the dataset has at least 30
                residues.
        """
        min_length = 30
        # Scan forward iteratively. The previous code recursed via
        # ``GANNet.__getitem__``, which could exceed the recursion limit on a
        # long run of short proteins and never terminated normally when no
        # protein was long enough.
        # NOTE(review): this assumes the method being defined here *is*
        # GANNet.__getitem__ (the old code delegated to that class by name).
        for _ in range(len(self)):
            window = slice(self.index[index], self.index[index + 1])
            inds = self.inds[window]
            primary = self.pris[window] - 1
            if primary.size(0) >= min_length:
                break
            index = (index + 1) % len(self)
        else:
            raise RuntimeError(
                f"no protein with at least {min_length} residues")

        # Random residue substitutions (the right-hand side -- replacement
        # values -- is drawn before the positions).
        n_positions = random.randrange(max(1, primary.size(0) // 100))
        primary[torch.randint(0, primary.size(0),
                              (n_positions, ))] = torch.randint(
                                  0, 20, (n_positions, ))

        tertiary = self.ters[:, :, window]
        distances, angles = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
        angles = angles.permute(1, 0).contiguous()
        distances = distances / 100

        protein = SubgraphStructure(
            torch.zeros(distances.size(0), dtype=torch.long))
        # Neighbour indices re-based to this protein's first residue.
        neighbours = ConstantStructure(0, 0, (inds - self.index[index]).to(
            torch.long))

        primary_onehot = torch.zeros(primary.size(0), 20, dtype=torch.float)
        primary_onehot[torch.arange(primary.size(0)), primary] = 1
        primary_onehot = primary_onehot.clamp(0, 1)

        assert neighbours.connections.max() < primary_onehot.size(0)
        inputs = (PackedTensor(angles), protein)
        target = PackedTensor(angles, split=False)
        structure = (PackedTensor(distances, split=False),
                     PackedTensor(torch.ones(distances.size(0)), split=False),
                     protein)
        return inputs, target, structure
Beispiel #6
0
    def __getitem__(self, index):
        """Return sequence-feature inputs and structure targets for ``index``.

        Every per-residue quantity from the parent dataset item is truncated to
        its first 500 entries.

        Returns:
            ``(inputs, outputs, angle_targets)`` where
            ``inputs = (PackedTensor(onehot ++ evolutionary), membership)``,
            ``outputs = (PackedTensor(coords), PackedTensor(mask[:, None]),
            membership)`` and
            ``angle_targets = (PackedTensor(angles), PackedTensor(mask))``;
            the tensors in ``outputs`` and ``angle_targets`` are unsplit.
        """
        max_len = 500
        result = super().__getitem__(index)
        primary = result["primary"][:max_len]
        evolutionary = result["evolutionary"][:, :max_len].t()
        tertiary = result["tertiary"] / 100
        # Entries 0, 1 and 3 of the first axis (presumably backbone atoms --
        # confirm), residue-first, flattened to one (x, y, z) row per entry.
        tertiary = tertiary[[0, 1, 3], :, :].permute(
            2, 0, 1).contiguous()[:max_len].view(-1, 3)
        angles = result["angles"].contiguous().view(-1,
                                                     3)[:max_len].contiguous()
        mask = result["mask"][:max_len].view(-1)

        membership = SubgraphStructure(
            torch.zeros(primary.size(0), dtype=torch.long))
        primary_onehot = one_hot_encode(primary - 1, range(20)).t()
        primary_onehot = torch.cat((primary_onehot, evolutionary), dim=1)

        inputs = (PackedTensor(primary_onehot), membership)
        outputs = (PackedTensor(tertiary, split=False),
                   PackedTensor(mask.unsqueeze(1), split=False), membership)
        angle_targets = (PackedTensor(angles, split=False),
                         PackedTensor(mask.view(-1), split=False))
        return inputs, outputs, angle_targets
    def __getitem__(self, index):
        """Return ``(PackedTensor(noisy_angles^T), PackedTensor(onehot), protein)``.

        ``index`` is mapped through ``self.valid_indices`` and the window is
        capped at ``self.N`` residues. Gaussian noise with standard deviation
        0.01 is added to the ``self.backrub`` angles before they are packed.
        """
        cap = self.N
        pid = self.valid_indices[index]
        start = self.index[pid]
        window = slice(start, min(self.index[pid + 1], start + cap))
        inds = self.inds[window]
        primary = self.pris[window] - 1

        # Replacement values are drawn before positions (RHS-first order).
        length = primary.size(0)
        n_noisy = random.randrange(max(1, length // 100))
        replacement = torch.randint(0, 20, (n_noisy, ))
        where = torch.randint(0, length, (n_noisy, ))
        primary[where] = replacement

        backbone = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
        positions, angles = self.backrub(backbone)
        positions = positions[:, 1] / 100

        protein = SubgraphStructure(
            torch.zeros(positions.size(0), dtype=torch.long))
        # Built but only referenced by a disabled sanity check upstream.
        neighbours = ConstantStructure(0, 0, (inds - start).to(torch.long))

        primary_onehot = torch.zeros(length, 20, dtype=torch.float)
        primary_onehot[torch.arange(length), primary] = 1
        primary_onehot = primary_onehot.clamp(0, 1)

        jittered = angles + 0.01 * torch.randn_like(angles)
        return (PackedTensor(jittered.permute(1, 0)),
                PackedTensor(primary_onehot), protein)
  def __getitem__(self, index):
    """Return a stacked pairwise feature map for dataset item ``index``.

    ``index`` is mapped through ``self.valid_indices`` and the window is capped
    at ``self.N`` residues. Channel 0 holds the pairwise distances between the
    ``[:, 1]`` rows of the ``self.backrub`` output (scaled by 1/100 before and
    after the distance computation). The remaining channels are the
    ``rotations`` and then the ``directions`` returned by
    ``relative_orientation`` for every ordered residue pair. They are
    symmetrised so that entry ``[i, j]`` with ``i <= j`` is taken from
    ``[j, i]``.

    Returns:
      A 1-tuple ``(features,)`` with ``features`` of shape ``(C, L, L)`` for a
      window of ``L`` residues.
    """
    N = self.N
    index = self.valid_indices[index]
    window = slice(self.index[index], min(self.index[index + 1], self.index[index] + N))
    primary = self.pris[window] - 1

    # Sequence noise. NOTE(review): `primary` does not reach this method's
    # output; the step is kept only so the global `random`/`torch` RNG draws
    # (and hence anything sampled afterwards) stay as before.
    n_positions = random.randrange(max(1, primary.size(0) // 100))
    primary[torch.randint(0, primary.size(0), (n_positions,))] = torch.randint(0, 20, (n_positions,))

    tertiary = self.ters[:, :, window]
    tertiary, angles = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
    tertiary = tertiary[:, 1] / 100

    ors = self.orientations(tertiary)
    count = tertiary.size(0)

    # Enumerate all ordered residue pairs (i, j) in row-major order.
    x_range = torch.arange(count).repeat_interleave(count)  # i of each pair
    y_range = torch.arange(count).repeat(count)  # j of each pair
    _, directions, rotations = relative_orientation(
        tertiary[x_range], tertiary[y_range], ors[x_range], ors[y_range])
    directions = directions.reshape(count, count, *directions.shape[1:]).permute(2, 0, 1).contiguous()
    rotations = rotations.reshape(count, count, *rotations.shape[1:]).permute(2, 0, 1).contiguous()

    # Keep [i, j] for i > j and mirror [j, i] otherwise. The mask is sized by
    # the actual window length: the window is truncated at the protein's end,
    # so an (N, N) mask would not broadcast against (C, count, count) maps of
    # proteins shorter than N.
    positions = torch.arange(count)
    mask = ((positions[:, None] - positions[None, :]) > 0).float()

    directions = mask * directions + (1 - mask) * directions.permute(0, 2, 1).contiguous()
    rotations = mask * rotations + (1 - mask) * rotations.permute(0, 2, 1).contiguous()

    distances = (tertiary[None, :, :] - tertiary[:, None, :]).norm(dim=-1)
    distances = distances.unsqueeze(0)

    result = distances / 100
    result = torch.cat((result, rotations, directions), dim=0)

    return (result,)
    def __getitem__(self, index):
        """Return a random crop of at most ``self.N`` residues with features.

        ``index`` is mapped through ``self.valid_indices``. For proteins longer
        than ``self.N`` the crop start is drawn uniformly over every offset
        that keeps the crop inside the protein; shorter proteins are used
        whole.

        Returns:
            ``((distances, onehot^T), keeps)`` where ``distances`` is
            ``self.ters`` for the crop, permuted residue-first, indexed
            ``[:, 1]`` and scaled by 1/100; ``onehot`` is a noisy ``(L, 20)``
            one-hot sequence encoding clamped to ``[0, 1]``; and ``keeps`` is
            ``self.keeps`` for the crop, shifted so its first entry is 0.
        """
        N = self.N
        index = self.valid_indices[index]
        start = self.index[index]
        end = self.index[index + 1]
        rstart = start
        if end - start > N:
            # randrange excludes its upper bound: +1 so the crop ending
            # exactly at `end` can be drawn as well.
            rstart = random.randrange(start, end - N + 1)
        # Never run past this protein's end (`rstart + N` alone spilled into
        # the next protein's residues when the protein is shorter than N).
        rend = min(rstart + N, end)
        window = slice(rstart, rend)
        primary = self.pris[window] - 1

        # get sequence positions, relative to the start of the crop
        keeps = self.keeps[window]
        keeps = keeps - keeps[0]

        # Random substitutions at a random number (< max(1, L // 5)) of
        # positions; replacement values are drawn before positions.
        n_positions = random.randrange(max(1, primary.size(0) // 5))
        primary[torch.randint(0, primary.size(0),
                              (n_positions, ))] = torch.randint(
                                  0, 20, (n_positions, ))

        tertiary = self.ters[:, :, window]
        # NOTE(review): the outputs of self.backrub are discarded (the old code
        # overwrote them immediately); the call is kept in case it is relied
        # on for RNG consumption or side effects -- confirm and drop if not.
        self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
        distances = tertiary.permute(2, 0, 1)[:, 1] / 100

        primary_onehot = torch.zeros(primary.size(0), 20, dtype=torch.float)
        primary_onehot[torch.arange(primary.size(0)), primary] = 1
        primary_onehot = primary_onehot + 0.1 * torch.rand_like(primary_onehot)
        primary_onehot = primary_onehot.clamp(0, 1)

        inputs = ((distances, primary_onehot.transpose(0, 1)), keeps)

        return inputs
    def __getitem__(self, index):
        """Return sequence-design inputs and the (noisy) residue labels.

        Uses the complete residue range of protein ``index``.

        Returns:
            ``(inputs, PackedTensor(primary, split=False))`` where ``inputs``
            is ``(onehot, onehot, angle_features, orientation, neighbours,
            protein)``. ``angle_features`` concatenates sin and cos of the
            ``self.angs`` entries. ``orientation`` concatenates
            ``self.ters[1]`` (scaled by 1/100), the flattened ``self.ors``
            entries and each residue's offset into the concatenated
            per-residue arrays.
        """
        lo, hi = self.index[index], self.index[index + 1]
        window = slice(lo, hi)
        inds = self.inds[window]
        primary = self.pris[window] - 1

        # Replacement values are drawn before positions (RHS-first order).
        length = primary.size(0)
        n_noisy = random.randrange(max(1, length // 100))
        replacement = torch.randint(0, 20, (n_noisy, ))
        where = torch.randint(0, length, (n_noisy, ))
        primary[where] = replacement

        evolutionary = self.evos[:, window]  # not used further
        tertiary = self.ters[:, :, window]  # not used further
        n_res = window.stop - window.start
        orientation = self.ors[window, :, :].view(n_res, -1)
        distances = self.ters[1, :, window].transpose(0, 1) / 100
        # NOTE(review): absolute offset into the dataset arrays, not relative
        # to this protein's first residue -- confirm this is intended.
        indices = torch.tensor(range(window.start, window.stop),
                               dtype=torch.float).view(-1, 1)

        orientation = torch.cat((distances, orientation, indices), dim=1)
        angles = self.angs[:, window].transpose(0, 1)

        protein = SubgraphStructure(
            torch.zeros(indices.size(0), dtype=torch.long))
        neighbours = ConstantStructure(0, 0, (inds - lo).to(torch.long))

        angle_features = torch.cat((angles.sin(), angles.cos()), dim=1)

        primary_onehot = torch.zeros(length, 20, dtype=torch.float)
        primary_onehot[torch.arange(length), primary] = 1
        primary_onehot = primary_onehot.clamp(0, 1)

        assert neighbours.connections.max() < primary_onehot.size(0)
        model_input = PackedTensor(primary_onehot)
        ignored_gt = PackedTensor(primary_onehot)  # gt_ignore
        inputs = (model_input, ignored_gt, PackedTensor(angle_features),
                  PackedTensor(orientation), neighbours, protein)

        return inputs, PackedTensor(primary, split=False)
    def __getitem__(self, index):
        """Return ``(distances, sequence)`` map inputs for dataset item ``index``.

        ``index`` is mapped through ``self.valid_indices`` and the window is
        capped at ``self.N`` residues.

        Returns:
            ``distances``: ``(1, L, L)`` pairwise distances between the
            ``[:, 1]`` rows of the ``self.backrub`` output (scaled by 1/100
            before and after), blended 99:1 with uniform ``[0, 1)`` noise.
            ``sequence``: ``(42, L, L)`` with its last channel set to 1 and all
            other channels 0.
        """
        N = self.N
        index = self.valid_indices[index]
        window = slice(self.index[index],
                       min(self.index[index + 1], self.index[index] + N))
        primary = self.pris[window] - 1

        # Sequence noise. NOTE(review): `primary` does not reach this method's
        # output; the step is kept only so the global `random`/`torch` RNG
        # draws (and hence anything sampled afterwards) stay as before.
        n_positions = random.randrange(max(1, primary.size(0) // 5))
        primary[torch.randint(0, primary.size(0),
                              (n_positions, ))] = torch.randint(
                                  0, 20, (n_positions, ))

        tertiary = self.ters[:, :, window]
        coords, angles = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
        coords = coords[:, 1] / 100

        distances = (coords[None, :] -
                     coords[:, None]).norm(dim=-1).unsqueeze(0) / 100
        distances = 0.99 * distances + 0.01 * torch.rand_like(distances)

        # Sized by the actual window length so it matches `distances`; an
        # (N, N) canvas disagreed with it for proteins shorter than N.
        length = distances.size(-1)
        sequence = torch.zeros(42, length, length)
        sequence[-1] = 1

        inputs = (distances, sequence)

        return inputs
    def __getitem__(self, index):
        """Return ``(distances / 100,)``: nine inter-atom distance maps.

        ``index`` is mapped through ``self.valid_indices`` and the window is
        capped at ``self.N`` residues. The per-residue rows of the
        ``self.backrub`` output each hold three points (scaled by 1/100).
        Channel ``k = 3 * a + b`` of the result holds, for residues ``i > j``,
        the distance between point ``b`` of residue ``j`` and point ``a`` of
        residue ``i``. For ``i <= j`` the entry is mirrored from ``[j, i]``.

        Returns:
            A 1-tuple with a ``(9, L, L)`` tensor, divided by 100 once more.
        """
        N = self.N
        index = self.valid_indices[index]
        window = slice(self.index[index],
                       min(self.index[index + 1], self.index[index] + N))
        primary = self.pris[window] - 1

        # Sequence noise. NOTE(review): `primary` does not reach this method's
        # output; the step is kept only so the global `random`/`torch` RNG
        # draws (and hence anything sampled afterwards) stay as before.
        n_positions = random.randrange(max(1, primary.size(0) // 100))
        primary[torch.randint(0, primary.size(0),
                              (n_positions, ))] = torch.randint(
                                  0, 20, (n_positions, ))

        tertiary = self.ters[:, :, window]
        tertiary, angles = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
        tertiary = tertiary / 100

        # All 9 ordered (left, right) point pairs out of 3 points per residue.
        left = torch.arange(3).repeat_interleave(3)  # 0,0,0,1,1,1,2,2,2
        right = torch.arange(3).repeat(3)  # 0,1,2,0,1,2,0,1,2

        distances = (tertiary[None, :, right, :] -
                     tertiary[:, None, left, :]).norm(dim=-1)
        distances = distances.permute(2, 0, 1).contiguous()

        # Keep [i, j] for i > j and mirror [j, i] otherwise. The mask is sized
        # from the maps themselves: the window is truncated at the protein's
        # end, so an (N, N) mask did not broadcast for proteins shorter than N.
        positions = torch.arange(distances.size(-1))
        mask = ((positions[:, None] - positions[None, :]) > 0).float()

        distances = mask * distances + (1 - mask) * distances.permute(0, 2, 1)

        return (distances / 100, )
    def __getitem__(self, index):
        """Build masked sequence-recovery inputs and targets for protein ``index``.

        A random subset of residue positions is masked: in the 21-way one-hot
        input those positions are encoded as class 20 instead of their residue.
        Structural features come from the ``[:, 1]`` rows of the
        ``self.backrub`` output (scaled by 1/100). The pairwise features are
        gathered for each residue's neighbours taken from ``self.inds``.

        Returns:
            ``(inputs, targets)`` where
            ``inputs = (PackedTensor(angle_features),
            PackedTensor(primary_onehot), PackedTensor(orientation_slice),
            neighbours, protein)`` and
            ``targets = (PackedTensor(primary), PackedTensor(mask_binary))``,
            both unsplit.
        """
        # Extract the boundaries of a whole protein
        window = slice(self.index[index], self.index[index + 1])
        seq_len = window.stop - window.start

        # Choose the positions to mask, without replacement: between 1 and
        # seq_len - 1 of them (np.random.randint's upper bound is exclusive).
        # max(2, ...) keeps the range non-empty, so a single-residue protein
        # gets its only residue masked instead of raising ValueError.
        n_masked = np.random.randint(1, max(2, seq_len))
        mask = np.random.choice(seq_len, size=n_masked, replace=False)
        mask = torch.tensor(mask, dtype=torch.long)
        mask_binary = torch.zeros(seq_len, dtype=torch.uint8)
        mask_binary[mask] = 1

        # Sequence info; class 20 marks a masked position.
        primary = self.pris[window] - 1

        primary_masked = primary.clone()
        primary_masked[mask] = 20
        primary_onehot = torch.zeros((seq_len, 21), dtype=torch.float)
        primary_onehot[torch.arange(seq_len), primary_masked] = 1

        # Neighbourhood structure, re-based to this protein's first residue.
        inds = self.inds[window]
        neighbours = ConstantStructure(0, 0, (inds - self.index[index]).to(
            torch.long))

        # Backbone positions after backrub; the returned angles are unused.
        tertiary = self.ters[:, :, window]
        tertiary, _ = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
        tertiary = tertiary[:, 1] / 100

        # Sinusoidal encoding of the sequence separation j - i.
        indices = torch.tensor(range(window.start, window.stop),
                               dtype=torch.float)
        relative_indices = indices[None, :] - indices[:, None]
        relative_sin = (relative_indices / 10).sin()
        relative_cos = (relative_indices / 10).cos()

        # (1, L, L) pairwise distance map.
        distances = (tertiary[None, :, :] - tertiary[:, None, :]).norm(dim=-1)
        distances = distances.unsqueeze(0)

        residues = torch.arange(distances.size(-1))
        idx_a = residues[:, None]
        idy_a = residues[None, :]
        # NOTE(review): self.fast_angle / self.fast_dihedral work on the
        # distance map; idx_a - 1 and idx_a - 2 are negative at the chain
        # start and are passed through as-is -- confirm the intended handling.
        chain_angle = self.fast_angle(
            distances, idx_a - 1, idx_a,
            (idx_a + 1) % distances.size(-1))[:, :, 0].permute(1, 0)
        chain_dihedral = self.fast_dihedral(
            distances, idx_a - 1, idx_a - 2, idx_a,
            (idx_a + 1) % distances.size(-1))[:, :, 0].permute(1, 0)
        contact_angles = self.fast_angle(distances, idx_a - 1, idx_a, idy_a)
        contact_dihedrals = self.fast_dihedral(distances, idx_a, idx_a - 1,
                                               idy_a, idy_a - 1)
        into_contact_dihedrals = self.fast_dihedral(distances, idx_a - 1,
                                                    idx_a - 2, idx_a, idy_a)

        angle_features = torch.cat(
            (chain_angle.sin(), chain_angle.cos(), chain_dihedral.sin(),
             chain_dihedral.cos()),
            dim=1)

        orientation = torch.cat(
            (distances, contact_angles.sin(), contact_angles.cos(),
             contact_dihedrals.sin(), contact_dihedrals.cos(),
             into_contact_dihedrals.sin(), into_contact_dihedrals.cos(),
             relative_sin[None], relative_cos[None]),
            dim=0)
        # Gather each residue's row of pair features at its neighbours:
        # result is (residues, neighbours, channels).
        row_ids = torch.arange(neighbours.connections.size(0))[:, None]
        orientation_slice = orientation[:, row_ids,
                                        neighbours.connections].permute(
                                            1, 2, 0).contiguous()

        protein = SubgraphStructure(torch.zeros_like(residues))

        inputs = (PackedTensor(angle_features), PackedTensor(primary_onehot),
                  PackedTensor(orientation_slice), neighbours, protein)

        targets = (PackedTensor(primary, split=False),
                   PackedTensor(mask_binary, split=False))

        return inputs, targets