def __getitem__(self, index):
        """Return an angle/distance training sample for protein ``index``.

        Proteins with fewer than 30 residues are skipped: the request is
        forwarded to ``GANNet.__getitem__`` for the next index, wrapping
        around at the end of the dataset.

        Returns:
            ``(inputs, angle_target, (distance_target, ones, protein))`` where
            ``inputs`` is ``(PackedTensor(angles), protein)``.
        """
        first = self.index[index]
        window = slice(first, self.index[index + 1])
        inds = self.inds[window]
        primary = self.pris[window] - 1
        length = primary.size(0)

        if length < 30:
            return GANNet.__getitem__(self, (index + 1) % len(self))

        # Corrupt a few residues (up to ~1% of the chain) with random codes.
        # Replacement codes are drawn before their positions.
        n_positions = random.randrange(max(1, length // 100))
        replacements = torch.randint(0, 20, (n_positions, ))
        positions = torch.randint(0, length, (n_positions, ))
        primary[positions] = replacements

        backbone = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
        distances, angles = self.backrub(backbone)
        angles = angles.permute(1, 0).contiguous()
        distances = distances / 100
        n_residues = distances.size(0)

        protein = SubgraphStructure(torch.zeros(n_residues, dtype=torch.long))
        neighbours = ConstantStructure(0, 0, (inds - first).to(torch.long))

        onehot = torch.zeros(length, 20, dtype=torch.float)
        onehot[torch.arange(length), primary] = 1
        onehot = onehot.clamp(0, 1)

        # Sanity check: every neighbour index addresses a residue of this chain.
        assert neighbours.connections.max() < onehot.size(0)

        packed_angles = PackedTensor(angles)
        angle_target = PackedTensor(angles, split=False)
        distance_target = PackedTensor(distances, split=False)
        ones_target = PackedTensor(torch.ones(n_residues), split=False)
        return ((packed_angles, protein), angle_target,
                (distance_target, ones_target, protein))
Beispiel #2
0
    def __getitem__(self, index):
        """Return ``(distances, one-hot sequence, protein)`` inputs for ``index``.

        Up to ~1% of residues are overwritten with random amino-acid codes
        before one-hot encoding. ``distances`` is ``self.ters`` over the
        protein window, permuted to put residues first and scaled by 1/100.
        """
        lo = self.index[index]
        window = slice(lo, self.index[index + 1])
        inds = self.inds[window]
        primary = self.pris[window] - 1
        length = primary.size(0)

        # Draw replacement codes first, then where to place them.
        n_positions = random.randrange(max(1, length // 100))
        new_residues = torch.randint(0, 20, (n_positions, ))
        where = torch.randint(0, length, (n_positions, ))
        primary[where] = new_residues

        distances = self.ters[:, :, window].permute(2, 0, 1) / 100

        protein = SubgraphStructure(
            torch.zeros(distances.size(0), dtype=torch.long))
        neighbours = ConstantStructure(0, 0, (inds - lo).to(torch.long))

        onehot = torch.zeros(length, 20, dtype=torch.float)
        onehot[torch.arange(length), primary] = 1
        onehot = onehot.clamp(0, 1)

        # Sanity check: every neighbour index addresses a residue of this chain.
        assert neighbours.connections.max() < onehot.size(0)
        return PackedTensor(distances), PackedTensor(onehot), protein
    def __getitem__(self, index):
        """Return noisy angles and the one-hot sequence of a valid protein.

        ``index`` is first mapped through ``self.valid_indices``; at most the
        first ``self.N`` residues of the protein are used. Up to ~1% of
        residue codes are randomised, and Gaussian noise (std 0.01) is added
        to the angles returned by ``self.backrub``.

        Returns:
            ``(PackedTensor(noisy_angles), PackedTensor(onehot), protein)``.
        """
        limit = self.N
        index = self.valid_indices[index]
        begin = self.index[index]
        window = slice(begin, min(self.index[index + 1], begin + limit))
        inds = self.inds[window]
        primary = self.pris[window] - 1
        length = primary.size(0)

        # Draw replacement codes first, then where to place them.
        n_positions = random.randrange(max(1, length // 100))
        new_residues = torch.randint(0, 20, (n_positions, ))
        where = torch.randint(0, length, (n_positions, ))
        primary[where] = new_residues

        backbone = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
        distances, angles = self.backrub(backbone)
        distances = distances[:, 1] / 100

        protein = SubgraphStructure(
            torch.zeros(distances.size(0), dtype=torch.long))
        # Built but not used in the returned inputs.
        neighbours = ConstantStructure(0, 0, (inds - begin).to(torch.long))

        onehot = torch.zeros(length, 20, dtype=torch.float)
        onehot[torch.arange(length), primary] = 1
        onehot = onehot.clamp(0, 1)

        jitter = 0.01 * torch.randn_like(angles)
        noisy_angles = (angles + jitter).permute(1, 0)
        return PackedTensor(noisy_angles), PackedTensor(onehot), protein
Beispiel #4
0
    def __getitem__(self, index):
        """Collect angle, sequence and orientation inputs for protein ``index``.

        Returns:
            ``(inputs, target)`` with ``inputs`` =
            ``(angle_features, primary, orientation, neighbours)`` and
            ``target`` = ``PackedTensor(primary, split=False)``.
            ``angle_features`` concatenates sin and cos of ``self.angs``;
            ``orientation`` concatenates ``self.ters[1]`` / 100, the flattened
            ``self.ors`` rows and the absolute residue index.
        """
        lo = self.index[index]
        hi = self.index[index + 1]
        window = slice(lo, hi)

        inds = self.inds[window]
        primary = self.pris[window] - 1
        # Looked up but not used below.
        evolutionary = self.evos[:, window]
        tertiary = self.ters[:, :, window]

        ors_flat = self.ors[window, :, :].view(hi - lo, -1)
        scaled = self.ters[1, :, window].transpose(0, 1) / 100
        position = torch.tensor(range(lo, hi), dtype=torch.float).unsqueeze(1)
        orientation = torch.cat((scaled, ors_flat, position), dim=1)

        angles = self.angs[:, window].transpose(0, 1)
        neighbours = ConstantStructure(0, 0, (inds - lo).to(torch.long))
        angle_features = torch.cat((angles.sin(), angles.cos()), dim=1)

        packed = (
            PackedTensor(angle_features),
            PackedTensor(primary),
            PackedTensor(orientation),
            neighbours,
        )
        return packed, PackedTensor(primary, split=False)
    def __getitem__(self, index):
        """Build a masked sequence-recovery sample for protein ``index``.

        A random subset of positions is replaced by the mask token 20 in the
        one-hot input. The targets are the unmasked sequence and the binary
        mask.

        Returns:
            ``(inputs, targets)`` where ``inputs`` is
            ``(local_features, primary_onehot, orientation, neighbours)`` and
            ``targets`` is ``(primary, mask_binary)``.
        """
        # Extract the boundaries of a whole protein
        window = slice(self.index[index], self.index[index + 1])
        seq_len = window.stop - window.start

        # Mask between seq_len // 20 and seq_len - 1 distinct positions
        # (np.random.randint excludes its upper bound).
        mask = np.random.choice(seq_len,
                                size=np.random.randint(seq_len // 20, seq_len),
                                replace=False)
        mask = torch.tensor(mask, dtype=torch.long)
        mask_binary = torch.zeros(seq_len, dtype=torch.uint8)
        mask_binary[mask] = 1

        # Get sequence info; token 20 marks a masked residue
        primary = self.pris[window] - 1
        primary_masked = primary.clone()
        primary_masked[mask] = 20
        primary_onehot = torch.zeros((seq_len, 21), dtype=torch.float)
        primary_onehot[torch.arange(seq_len), primary_masked] = 1

        tertiary = self.ters[:, :, window]
        distances, _ = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
        distances = distances[:, 1] / 100
        # Deliberate corruption of the coordinates (noise augmentation). FIXME
        distances = distances + 0.01 * torch.randn_like(distances)

        # Orientation infos: scaled coordinates plus the absolute residue index
        indices = torch.tensor(range(window.start, window.stop),
                               dtype=torch.float)
        indices = indices.view(-1, 1)
        orientation = torch.cat((distances, indices), dim=1)

        # Prepare neighborhood structure
        inds = self.inds[window]
        neighbours = ConstantStructure(0, 0, (inds - self.index[index]).to(
            torch.long))

        # Local features: distances to the 15 residues nearest in *sequence*
        # (self plus up to 7 on each side). `largest=False` is required:
        # topk's default returns the 15 residues farthest apart in sequence.
        # NOTE(review): topk(15) raises for proteins shorter than 15 residues.
        dmap = (distances[None, :] - distances[:, None]).norm(dim=-1)
        positions = torch.arange(distances.size(0))
        seq_gap = (positions[None, :] - positions[:, None]).abs()
        closest = seq_gap.topk(15, dim=1, largest=False).indices
        local_features = dmap[torch.arange(dmap.size(0))[:, None],
                              closest] / 100

        inputs = (PackedTensor(local_features), PackedTensor(primary_onehot),
                  PackedTensor(orientation), neighbours)

        targets = (PackedTensor(primary, split=False),
                   PackedTensor(mask_binary, split=False))

        return inputs, targets
    def __getitem__(self, index):
        """Return a random crop of ``self.N`` residues from a valid protein.

        ``index`` is mapped through ``self.valid_indices``. For proteins
        longer than ``N``, the crop start is drawn uniformly from every start
        that keeps the whole crop inside the protein. Some residue codes are
        overwritten with random codes. The one-hot encoding is perturbed with
        uniform noise in [0, 0.1) and clamped to [0, 1].

        Returns:
            ``((distances, primary_onehot.T), keeps)`` where ``keeps`` is
            ``self.keeps`` over the crop, shifted so its first entry is 0.
        """
        N = self.N
        index = self.valid_indices[index]
        start = self.index[index]
        end = self.index[index + 1]
        rstart = start
        if end - start > N:
            # randrange excludes its upper bound. The last valid crop start is
            # end - N, so the bound must be end - N + 1.
            rstart = random.randrange(start, end - N + 1)
        # NOTE(review): for proteins shorter than N this window extends past
        # `end` into the next protein; presumably valid_indices excludes such
        # proteins -- confirm.
        window = slice(rstart, rstart + N)
        primary = self.pris[window] - 1

        # get sequence positions, relative to the start of the crop
        keeps = self.keeps[window]
        keeps = keeps - keeps[0]

        # add noise: overwrite up to ~20% of positions with random codes
        n_positions = random.randrange(max(1, primary.size(0) // 5))
        primary[torch.randint(0, primary.size(0),
                              (n_positions, ))] = torch.randint(
                                  0, 20, (n_positions, ))

        tertiary = self.ters[:, :, window]
        distances = tertiary.permute(2, 0, 1)[:, 1] / 100

        primary_onehot = torch.zeros(primary.size(0), 20, dtype=torch.float)
        primary_onehot[torch.arange(primary.size(0)), primary] = 1
        primary_onehot = primary_onehot + 0.1 * torch.rand_like(primary_onehot)
        primary_onehot = primary_onehot.clamp(0, 1)

        return (distances, primary_onehot.transpose(0, 1)), keeps
    def __getitem__(self, index):
        """Return sequence, angle and orientation inputs for protein ``index``.

        Up to ~1% of residues are overwritten with random amino-acid codes.
        The corrupted sequence feeds both one-hot inputs and the returned
        target.

        Returns:
            ``(inputs, target)`` with ``inputs`` = ``(onehot, onehot,
            angle_features, orientation, neighbours, protein)`` and ``target``
            = ``PackedTensor(primary, split=False)``.
        """
        lo = self.index[index]
        hi = self.index[index + 1]
        window = slice(lo, hi)
        inds = self.inds[window]
        primary = self.pris[window] - 1
        length = primary.size(0)

        # Draw replacement codes first, then where to place them.
        n_positions = random.randrange(max(1, length // 100))
        new_residues = torch.randint(0, 20, (n_positions, ))
        where = torch.randint(0, length, (n_positions, ))
        primary[where] = new_residues

        # Looked up but not used below.
        evolutionary = self.evos[:, window]
        tertiary = self.ters[:, :, window]

        ors_flat = self.ors[window, :, :].view(hi - lo, -1)
        scaled = self.ters[1, :, window].transpose(0, 1) / 100
        position = torch.tensor(range(lo, hi), dtype=torch.float).unsqueeze(1)
        orientation = torch.cat((scaled, ors_flat, position), dim=1)
        angles = self.angs[:, window].transpose(0, 1)

        protein = SubgraphStructure(
            torch.zeros(position.size(0), dtype=torch.long))
        neighbours = ConstantStructure(0, 0, (inds - lo).to(torch.long))

        angle_features = torch.cat((angles.sin(), angles.cos()), dim=1)

        onehot = torch.zeros(length, 20, dtype=torch.float)
        onehot[torch.arange(length), primary] = 1
        onehot = onehot.clamp(0, 1)

        # Sanity check: every neighbour index addresses a residue of this chain.
        assert neighbours.connections.max() < onehot.size(0)
        packed = (
            PackedTensor(onehot),
            PackedTensor(onehot),  # gt_ignore
            PackedTensor(angle_features),
            PackedTensor(orientation),
            neighbours,
            protein,
        )
        return packed, PackedTensor(primary, split=False)
    def __getitem__(self, index):
        """Return a noisy pairwise distance map and a placeholder sequence grid.

        ``index`` is mapped through ``self.valid_indices``; at most the first
        ``self.N`` residues are used. The distance map (scaled by 1/100, with a
        leading channel axis) is mixed 99:1 with uniform noise. The returned
        ``sequence`` is a ``(42, N, N)`` tensor whose last channel is all ones.

        Returns:
            ``(distances, sequence)``.
        """
        size = self.N
        index = self.valid_indices[index]
        begin = self.index[index]
        window = slice(begin, min(self.index[index + 1], begin + size))
        inds = self.inds[window]
        primary = self.pris[window] - 1
        length = primary.size(0)

        # Draw replacement codes first, then where to place them.
        n_positions = random.randrange(max(1, length // 5))
        new_residues = torch.randint(0, 20, (n_positions, ))
        where = torch.randint(0, length, (n_positions, ))
        primary[where] = new_residues

        backbone = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
        coords, _ = self.backrub(backbone)
        coords = coords[:, 1] / 100

        pairwise = (coords[None, :] -
                    coords[:, None]).norm(dim=-1).unsqueeze(0) / 100
        distances = 0.99 * pairwise + 0.01 * torch.rand_like(pairwise)

        sequence = torch.zeros(42, size, size)
        sequence[-1] = 1

        # NOTE(review): distances.size(0) is the channel axis (1) here, not the
        # residue count; protein/neighbours/onehot are not returned anyway.
        protein = SubgraphStructure(
            torch.zeros(distances.size(0), dtype=torch.long))
        neighbours = ConstantStructure(0, 0, (inds - begin).to(torch.long))

        onehot = torch.zeros(length, 20, dtype=torch.float)
        onehot[torch.arange(length), primary] = 1
        onehot = onehot.clamp(0, 1)

        return distances, sequence
    def __getitem__(self, index):
        """Build a masked sequence-recovery sample for protein ``index``.

        Proteins longer than 500 residues are skipped by recursing on the
        next index (wrapping around at the end of the dataset).

        Randomly chosen positions are hidden (zeroed) in the one-hot sequence
        input. The target is the log one-hot encoding of the full, unmasked
        sequence.

        Returns:
            ``(inputs, target)`` where ``inputs`` is ``(angle_features,
            primary_onehot, mask, orientation, neighbours)`` (tensors wrapped
            with ``box=True``) and ``target`` is the boxed ``primary_logits``.
        """
        window = slice(self.index[index], self.index[index + 1])
        inds = self.inds[window]
        primary = self.pris[window] - 1
        if len(primary) > 500:
            # NOTE(review): recursion depth grows with the number of
            # consecutive over-long proteins.
            return self.__getitem__((index + 1) % len(self))
        orientation = self.ors[window, :, :].view(window.stop - window.start,
                                                  -1)
        distances = self.ters[1, :, window].transpose(0, 1) / 100
        indices = torch.tensor(range(window.start, window.stop),
                               dtype=torch.float)
        indices = indices.view(-1, 1)

        # Per-residue orientation: self.ters[1] / 100, the flattened self.ors
        # entries and the absolute residue index.
        orientation = torch.cat((distances, orientation, indices), dim=1)
        angles = self.angs[:, window].transpose(0, 1)

        neighbours = ConstantStructure(0, 0, (inds - self.index[index]).to(
            torch.long))

        # Pick positions to hide. They are drawn with replacement, so the
        # number of distinct masked residues can be below number_of_masked.
        mask = torch.zeros(primary.size(0))
        number_of_masked = random.randrange(0, primary.size(0))
        mask_positions = torch.randint(0, primary.size(0),
                                       (number_of_masked, ))
        mask[mask_positions] = 1.0

        primary_onehot = torch.zeros(primary.size(0), 20)
        primary_onehot[torch.arange(0, primary.size(0)), primary.view(-1)] = 1
        # Target: log one-hot of the complete sequence, taken before masking.
        primary_logits = torch.log(primary_onehot + 1e-16)
        # Input: hide the masked residues. The mask must already be filled in
        # at this point, otherwise nothing gets hidden.
        primary_onehot[mask.nonzero().view(-1)] = 0

        sin = torch.sin(angles)
        cos = torch.cos(angles)
        angle_features = torch.cat((sin, cos), dim=1)

        inputs = (PackedTensor(angle_features, box=True),
                  PackedTensor(primary_onehot, box=True),
                  PackedTensor(mask, box=True),
                  PackedTensor(orientation, box=True),
                  neighbours)

        return inputs, PackedTensor(primary_logits, box=True)
    def __getitem__(self, index):
        """Build a conditional masked-sequence sample with PSSM resampling.

        A fraction ``self.desired_resample`` of positions is redrawn from the
        profile in the first 20 rows of ``self.evos``. A random subset of
        positions is then replaced by the mask token 20 in the one-hot input.

        Returns:
            ``(inputs, targets)`` where ``inputs`` is
            ``(angle_features, primary_onehot, orientation, neighbours)`` and
            ``targets`` is ``(primary, mask_binary)``. ``primary`` is the
            resampled, unmasked sequence.
        """
        # Extract the boundaries of a whole protein
        window = slice(self.index[index], self.index[index + 1])
        seq_len = window.stop - window.start

        # Mask for the conditional training: between seq_len // 20 and
        # seq_len - 1 distinct positions (np.random.randint excludes its
        # upper bound).
        mask = np.random.choice(seq_len,
                                size=np.random.randint(seq_len // 20, seq_len),
                                replace=False)
        mask = torch.tensor(mask, dtype=torch.long)
        mask_binary = torch.zeros(seq_len, dtype=torch.uint8)
        mask_binary[mask] = 1

        # Get sequence info. The subtraction already yields a new tensor, so
        # the in-place resampling below never modifies self.pris.
        primary = self.pris[window] - 1

        # Do the resampling
        to_resample = np.random.choice(seq_len,
                                       size=(self.desired_resample *
                                             seq_len.float()).long().numpy(),
                                       replace=False)
        pssm = self.evos[:20, window][:, to_resample]

        try:
            resampled = torch.multinomial(pssm.t(), 1).flatten()
        except RuntimeError:
            # torch.multinomial rejects rows that are not valid distributions
            # (negative/NaN entries or a zero sum). Resampling is best-effort,
            # so keep the native sequence in that case.
            pass
        else:
            primary[to_resample] = resampled

        # Run the masking for conditional transformer training
        primary_masked = primary.clone()
        primary_masked[mask] = 20
        primary_onehot = torch.zeros((seq_len, 21), dtype=torch.float)
        primary_onehot[torch.arange(seq_len), primary_masked] = 1

        # Prepare orientation infos
        orientation = self.ors[window, :, :].view(seq_len, -1)

        tertiary = self.ters[:, :, window]
        distances, angles = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
        distances = distances[:, 1] / 100
        angles = angles.transpose(0, 1)
        indices = torch.tensor(range(window.start, window.stop),
                               dtype=torch.float)
        indices = indices.view(-1, 1)
        orientation = torch.cat((distances, orientation, indices), dim=1)

        # Prepare neighborhood structure
        inds = self.inds[window]
        neighbours = ConstantStructure(0, 0, (inds - self.index[index]).to(
            torch.long))

        # Prepare angle features
        sin = torch.sin(angles)
        cos = torch.cos(angles)
        angle_features = torch.cat((sin, cos), dim=1)

        inputs = (PackedTensor(angle_features), PackedTensor(primary_onehot),
                  PackedTensor(orientation), neighbours)

        targets = (PackedTensor(primary, split=False),
                   PackedTensor(mask_binary, split=False))

        return inputs, targets
    def __getitem__(self, index):
        """Build a masked-sequence sample with pairwise geometric features.

        Between 1 and seq_len - 1 positions are replaced by the mask token 20
        in the one-hot input. Pairwise features are built from the distance
        map, angles/dihedrals from ``self.fast_angle`` / ``self.fast_dihedral``,
        and sin/cos encodings of the relative residue index. They are then
        gathered along the neighbour connections.

        Returns:
            ``(inputs, targets)`` where ``inputs`` is ``(angle_features,
            primary_onehot, orientation_slice, neighbours, protein)`` and
            ``targets`` is ``(primary, mask_binary)``.
        """
        # Extract the boundaries of a whole protein
        window = slice(self.index[index], self.index[index + 1])
        seq_len = window.stop - window.start

        # Mask between 1 and seq_len - 1 distinct positions (np.random.randint
        # excludes its upper bound, so seq_len must be at least 2).
        mask = np.random.choice(seq_len,
                                size=np.random.randint(1, seq_len),
                                replace=False)
        mask = torch.tensor(mask, dtype=torch.long)
        mask_binary = torch.zeros(seq_len, dtype=torch.uint8)
        mask_binary[mask] = 1

        # Get sequence info; token 20 marks a masked residue
        primary = self.pris[window] - 1

        primary_masked = primary.clone()
        primary_masked[mask] = 20
        primary_onehot = torch.zeros((seq_len, 21), dtype=torch.float)
        primary_onehot[torch.arange(seq_len), primary_masked] = 1

        # Prepare neighborhood structure
        inds = self.inds[window]
        neighbours = ConstantStructure(0, 0, (inds - self.index[index]).to(
            torch.long))

        # Row 1 of the backrub output, scaled by 1/100
        tertiary = self.ters[:, :, window]
        tertiary, _ = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
        tertiary = tertiary[:, 1] / 100

        indices = torch.tensor(range(window.start, window.stop),
                               dtype=torch.float)
        relative_indices = indices[None, :] - indices[:, None]
        relative_sin = (relative_indices / 10).sin()
        relative_cos = (relative_indices / 10).cos()

        # Pairwise distance map with a leading channel axis: (1, L, L)
        distances = (tertiary[None, :, :] - tertiary[:, None, :]).norm(dim=-1)
        distances = distances.unsqueeze(0)

        positions = torch.arange(distances.size(-1))
        idx_a = positions[:, None]
        idy_a = positions[None, :]
        chain_angle = self.fast_angle(
            distances, idx_a - 1, idx_a,
            (idx_a + 1) % distances.size(-1))[:, :, 0].permute(1, 0)
        chain_dihedral = self.fast_dihedral(
            distances, idx_a - 1, idx_a - 2, idx_a,
            (idx_a + 1) % distances.size(-1))[:, :, 0].permute(1, 0)
        contact_angles = self.fast_angle(distances, idx_a - 1, idx_a, idy_a)
        contact_dihedrals = self.fast_dihedral(distances, idx_a, idx_a - 1,
                                               idy_a, idy_a - 1)
        into_contact_dihedrals = self.fast_dihedral(distances, idx_a - 1,
                                                    idx_a - 2, idx_a, idy_a)

        angle_features = torch.cat(
            (chain_angle.sin(), chain_angle.cos(), chain_dihedral.sin(),
             chain_dihedral.cos()),
            dim=1)

        orientation = torch.cat(
            (distances, contact_angles.sin(), contact_angles.cos(),
             contact_dihedrals.sin(), contact_dihedrals.cos(),
             into_contact_dihedrals.sin(), into_contact_dihedrals.cos(),
             relative_sin[None], relative_cos[None]),
            dim=0)
        # Gather each residue's pairwise features for its neighbours only,
        # moving the feature channel last.
        orientation_slice = orientation[:,
                                        torch.arange(
                                            neighbours.connections.size(0)
                                        )[:, None],
                                        neighbours.connections].permute(
                                            1, 2, 0).contiguous()

        protein = SubgraphStructure(torch.zeros_like(positions))

        inputs = (PackedTensor(angle_features), PackedTensor(primary_onehot),
                  PackedTensor(orientation_slice), neighbours, protein)

        targets = (PackedTensor(primary, split=False),
                   PackedTensor(mask_binary, split=False))

        return inputs, targets