def __getitem__(self, index):
    """Return the input tuple for protein ``index``, using at most 64 residues.

    Steps:
      * cut out the residue window of the protein (capped at 64 residues);
      * overwrite a few randomly chosen residue labels with random labels;
      * run rows 0, 1 and 3 of the coordinate tensor through ``self.backrub``
        and keep point 1 of every residue, scaled by 1/100;
      * compute within-protein pairwise Euclidean distances with
        ``scatter.pairwise_no_pad``.

    Returns:
        ``(PackedTensor(angles transposed), PackedTensor(distances[:, None]),
        protein)`` where ``protein`` is a ``SubgraphStructure`` assigning every
        residue to subgraph 0.
    """
    first = self.index[index]
    last = min(self.index[index + 1], first + 64)
    window = slice(first, last)
    inds = self.inds[window]  # not used below
    primary = self.pris[window] - 1  # presumably stored 1-based; confirm

    # Label noise. The replacement labels are drawn before the positions;
    # keep this order so the random streams stay reproducible.
    length = primary.size(0)
    n_noisy = random.randrange(max(1, length // 100))
    new_labels = torch.randint(0, 20, (n_noisy, ))
    noisy_at = torch.randint(0, length, (n_noisy, ))
    primary[noisy_at] = new_labels

    backbone = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
    coords, angles = self.backrub(backbone)
    coords = coords[:, 1] / 100

    protein = SubgraphStructure(torch.zeros(coords.size(0), dtype=torch.long))

    def euclidean(x, y):
        return (x - y).norm(dim=1)

    distances, _ = scatter.pairwise_no_pad(euclidean, coords, protein.indices)
    distances = distances.unsqueeze(1)

    # One-hot labels; not part of the returned tuple.
    one_hot = torch.zeros(length, 20, dtype=torch.float)
    one_hot[torch.arange(length), primary] = 1
    one_hot = one_hot.clamp(0, 1)  # tensor only holds 0/1 here

    return (PackedTensor(angles.permute(1, 0).contiguous()),
            PackedTensor(distances), protein)
def __getitem__(self, index):
    """Return the input tuple for the whole protein at position ``index``.

    A few residue labels are overwritten with random labels. The coordinate
    tensor for the protein's window is permuted so residues come first and is
    scaled by 1/100. It is named ``distances`` downstream but is taken
    directly from ``self.ters``.

    Returns:
        ``(PackedTensor(coordinates), PackedTensor(one-hot labels), protein)``.

    Raises:
        AssertionError: if a neighbour index is not smaller than the number
            of residues.
    """
    lo, hi = self.index[index], self.index[index + 1]
    window = slice(lo, hi)
    inds = self.inds[window]
    primary = self.pris[window] - 1

    # Label noise: replacement labels first, then positions.
    count = primary.size(0)
    n_noisy = random.randrange(max(1, count // 100))
    new_labels = torch.randint(0, 20, (n_noisy, ))
    primary[torch.randint(0, count, (n_noisy, ))] = new_labels

    coords = self.ters[:, :, window].permute(2, 0, 1) / 100
    protein = SubgraphStructure(torch.zeros(coords.size(0), dtype=torch.long))
    # Neighbour indices rebased so they are relative to the protein start.
    neighbours = ConstantStructure(0, 0, (inds - lo).to(torch.long))

    one_hot = torch.zeros(count, 20, dtype=torch.float)
    one_hot[torch.arange(count), primary] = 1
    one_hot = one_hot.clamp(0, 1)
    assert neighbours.connections.max() < one_hot.size(0)

    return (PackedTensor(coords), PackedTensor(one_hot), protein)
def __getitem__(self, index):
    """Return an autoregressive angle-prediction example for protein ``index``.

    At most 500 residues are used. Labels receive light random corruption.
    The backbone (rows 0, 1, 3 of ``self.ters``) goes through
    ``self.backrub``. Its angle output, transposed, is both the input and,
    cloned, the target. Its coordinate output is scaled by 1/100 and used to
    build the neighbour structure via ``self.autoregressive_structure``.

    Returns:
        ``(inputs, PackedTensor(angles_target, split=False))`` with
        ``inputs = (PackedTensor(angles), PackedTensor(coords), neighbours,
        protein)``.

    Raises:
        AssertionError: if a neighbour index is not smaller than the number
            of residues.
    """
    first = self.index[index]
    window = slice(first, min(self.index[index + 1], first + 500))
    inds = self.inds[window]  # not used below
    primary = self.pris[window] - 1

    # Label noise: replacement labels first, then positions.
    residues = primary.size(0)
    n_noisy = random.randrange(max(1, residues // 100))
    new_labels = torch.randint(0, 20, (n_noisy, ))
    primary[torch.randint(0, residues, (n_noisy, ))] = new_labels

    atoms = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
    coords, raw_angles = self.backrub(atoms)
    angles = raw_angles.permute(1, 0).contiguous()
    angles_target = angles.clone()
    coords = coords / 100

    protein = SubgraphStructure(torch.zeros(coords.size(0), dtype=torch.long))
    neighbours = self.autoregressive_structure(coords)

    one_hot = torch.zeros(residues, 20, dtype=torch.float)
    one_hot[torch.arange(residues), primary] = 1
    one_hot = one_hot.clamp(0, 1)
    assert neighbours.connections.max() < one_hot.size(0)

    inputs = (PackedTensor(angles), PackedTensor(coords), neighbours, protein)
    return inputs, PackedTensor(angles_target, split=False)
def __getitem__(self, index):
    """Return a flattened all-point distance map for valid protein ``index``.

    ``index`` is first mapped through ``self.valid_indices``, and at most
    ``self.N`` residues of that protein are used. Residue labels are
    corrupted, but they are not part of the output. Every point returned by
    ``self.backrub`` is flattened into a ``(P, 3)`` list and scaled by 1/100.
    The all-pairs Euclidean distance matrix is then scaled by a further 1/100.

    Returns:
        ``(distance_map, )`` with ``distance_map`` of shape ``(1, P, P)``.
    """
    limit = self.N
    protein_id = self.valid_indices[index]
    begin = self.index[protein_id]
    window = slice(begin, min(self.index[protein_id + 1], begin + limit))
    inds = self.inds[window]  # not used below
    primary = self.pris[window] - 1

    # Label noise: replacement labels first, then positions.
    residues = primary.size(0)
    n_noisy = random.randrange(max(1, residues // 100))
    new_labels = torch.randint(0, 20, (n_noisy, ))
    primary[torch.randint(0, residues, (n_noisy, ))] = new_labels

    atoms = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
    moved, angles = self.backrub(atoms)
    points = moved.reshape(-1, 3) / 100
    protein = SubgraphStructure(torch.zeros(points.size(0), dtype=torch.long))

    # offsets[i, j] = points[j] - points[i]
    offsets = points.unsqueeze(0) - points.unsqueeze(1)
    distance_map = offsets.norm(dim=-1).unsqueeze(0)

    one_hot = torch.zeros(residues, 20, dtype=torch.float)
    one_hot[torch.arange(residues), primary] = 1
    one_hot = one_hot.clamp(0, 1)

    return (distance_map / 100, )
def __getitem__(self, index):
    """Return a GAN training example for the protein at ``index``.

    Proteins with fewer than 30 residues are skipped. The call is forwarded
    to ``GANNet.__getitem__`` with the next index, wrapping around at
    ``len(self)``.

    Returns:
        ``(inputs, angle_target, distance_bundle)`` where
        ``inputs = (PackedTensor(angles), protein)``,
        ``angle_target = PackedTensor(angles, split=False)`` and
        ``distance_bundle = (PackedTensor(coords, split=False),
        PackedTensor(ones, split=False), protein)``.

    Raises:
        AssertionError: if a neighbour index is not smaller than the number
            of residues.
    """
    lo = self.index[index]
    window = slice(lo, self.index[index + 1])
    inds = self.inds[window]
    primary = self.pris[window] - 1
    residues = primary.size(0)

    if residues < 30:
        # NOTE(review): if GANNet.__getitem__ is this method, this recurses
        # once per consecutive short protein and never terminates when no
        # protein has >= 30 residues.
        return GANNet.__getitem__(self, (index + 1) % len(self))

    # Label noise: replacement labels first, then positions.
    n_noisy = random.randrange(max(1, residues // 100))
    new_labels = torch.randint(0, 20, (n_noisy, ))
    primary[torch.randint(0, residues, (n_noisy, ))] = new_labels

    atoms = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
    coords, angles = self.backrub(atoms)
    angles = angles.permute(1, 0).contiguous()
    coords = coords / 100

    protein = SubgraphStructure(torch.zeros(coords.size(0), dtype=torch.long))
    neighbours = ConstantStructure(0, 0, (inds - lo).to(torch.long))

    one_hot = torch.zeros(residues, 20, dtype=torch.float)
    one_hot[torch.arange(residues), primary] = 1
    one_hot = one_hot.clamp(0, 1)
    assert neighbours.connections.max() < one_hot.size(0)

    inputs = (PackedTensor(angles), protein)
    angle_target = PackedTensor(angles, split=False)
    distance_bundle = (
        PackedTensor(coords, split=False),
        PackedTensor(torch.ones(coords.size(0)), split=False),
        protein,
    )
    return inputs, angle_target, distance_bundle
def __getitem__(self, index):
    """Build ``(inputs, outputs, angle_targets)`` for item ``index``.

    Wraps the parent class's ``__getitem__``. That call returns a mapping
    with ``"primary"``, ``"evolutionary"``, ``"tertiary"``, ``"angles"`` and
    ``"mask"`` entries. Every per-residue field is truncated to the first
    ``max_len`` residues.

    Returns:
        ``inputs``: ``(PackedTensor(one-hot labels ++ evolutionary features),
        membership)``.
        ``outputs``: ``(PackedTensor(coordinates / 100, split=False),
        PackedTensor(mask[:, None], split=False), membership)``.
        ``angle_targets``: ``(PackedTensor(angles, split=False),
        PackedTensor(mask, split=False))``.
    """
    max_len = 500  # residue cap applied to every per-residue field

    result = super().__getitem__(index)
    primary = result["primary"][:max_len]
    evolutionary = result["evolutionary"][:, :max_len].t()

    # Keep coordinate rows 0, 1 and 3, put residues first, then flatten to
    # one 3-vector per (residue, row) pair.
    tertiary = result["tertiary"] / 100
    tertiary = tertiary[[0, 1, 3], :, :].permute(
        2, 0, 1).contiguous()[:max_len].view(-1, 3)
    angles = result["angles"].contiguous().view(-1, 3)[:max_len].contiguous()
    mask = result["mask"][:max_len].view(-1)

    membership = SubgraphStructure(
        torch.zeros(primary.size(0), dtype=torch.long))
    primary_onehot = one_hot_encode(primary - 1, range(20)).t()
    primary_onehot = torch.cat((primary_onehot, evolutionary), dim=1)

    inputs = (PackedTensor(primary_onehot), membership)
    outputs = (PackedTensor(tertiary, split=False),
               PackedTensor(mask.unsqueeze(1), split=False), membership)
    angle_targets = (PackedTensor(angles, split=False),
                     PackedTensor(mask.view(-1), split=False))
    return inputs, outputs, angle_targets
def __getitem__(self, index):
    """Return jittered backbone angles and noisy labels for valid protein ``index``.

    ``index`` is mapped through ``self.valid_indices``, and at most
    ``self.N`` residues are used. A few residue labels are replaced at
    random. The angles returned by ``self.backrub`` get additive Gaussian
    noise with standard deviation 0.01.

    Returns:
        ``(PackedTensor(jittered angles transposed), PackedTensor(one-hot
        labels), protein)``.
    """
    limit = self.N
    protein_id = self.valid_indices[index]
    begin = self.index[protein_id]
    window = slice(begin, min(self.index[protein_id + 1], begin + limit))
    inds = self.inds[window]
    primary = self.pris[window] - 1

    # Label noise: replacement labels first, then positions.
    residues = primary.size(0)
    n_noisy = random.randrange(max(1, residues // 100))
    new_labels = torch.randint(0, 20, (n_noisy, ))
    primary[torch.randint(0, residues, (n_noisy, ))] = new_labels

    atoms = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
    coords, angles = self.backrub(atoms)
    trace = coords[:, 1] / 100
    protein = SubgraphStructure(torch.zeros(trace.size(0), dtype=torch.long))
    neighbours = ConstantStructure(0, 0, (inds - begin).to(torch.long))  # not used below

    one_hot = torch.zeros(residues, 20, dtype=torch.float)
    one_hot[torch.arange(residues), primary] = 1
    one_hot = one_hot.clamp(0, 1)

    jittered = angles + 0.01 * torch.randn_like(angles)
    return (PackedTensor(jittered.permute(1, 0)), PackedTensor(one_hot), protein)
def __getitem__(self, index):
    """Return stacked pairwise geometry maps for valid protein ``index``.

    ``index`` is mapped through ``self.valid_indices``, and at most
    ``self.N`` residues of that protein are used. Rows 0, 1 and 3 of
    ``self.ters`` go through ``self.backrub``. Point 1 of each residue,
    scaled by 1/100, is the per-residue position, and ``self.orientations``
    gives the per-residue orientation. ``relative_orientation`` is evaluated
    for every ordered residue pair.

    Returns:
        A 1-tuple holding one tensor of shape ``(1 + R + D, count, count)``.
        The first channel is the distance map (positions already scaled by
        1/100, distances scaled by another 1/100). It is followed by the
        ``R`` rotation channels and the ``D`` direction channels. For those
        maps, entry ``[i, j]`` with ``i > j`` is kept, and every other entry
        is taken from ``[j, i]``.
    """
    N = self.N
    index = self.valid_indices[index]
    start = self.index[index]
    window = slice(start, min(self.index[index + 1], start + N))
    primary = self.pris[window] - 1

    # Label noise. ``primary`` does not feed the output; the draws are kept
    # so the random state is unchanged in case ``self.backrub`` is random.
    n_positions = random.randrange(max(1, primary.size(0) // 100))
    primary[torch.randint(0, primary.size(0), (n_positions, ))] = torch.randint(
        0, 20, (n_positions, ))

    tertiary = self.ters[:, :, window]
    tertiary, _ = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
    tertiary = tertiary[:, 1] / 100
    ors = self.orientations(tertiary)

    count = tertiary.size(0)
    # Every ordered pair (x, y): x = 0,0,..,1,1,.. and y = 0,1,..,0,1,..
    x_range = torch.arange(count).repeat_interleave(count)
    y_range = torch.arange(count).repeat(count)
    _, directions, rotations = relative_orientation(
        tertiary[x_range], tertiary[y_range], ors[x_range], ors[y_range])
    directions = directions.reshape(
        count, count, *directions.shape[1:]).permute(2, 0, 1).contiguous()
    rotations = rotations.reshape(
        count, count, *rotations.shape[1:]).permute(2, 0, 1).contiguous()

    # Strictly-lower-triangular selector sized by the actual residue count.
    # Sizing it by ``N`` breaks when the protein is shorter than ``N``.
    positions = torch.arange(count)
    lower = ((positions[:, None] - positions[None, :]) > 0).float()
    directions = lower * directions + (1 - lower) * directions.permute(0, 2, 1).contiguous()
    rotations = lower * rotations + (1 - lower) * rotations.permute(0, 2, 1).contiguous()

    distances = (tertiary[None, :, :] - tertiary[:, None, :]).norm(dim=-1)
    distances = distances.unsqueeze(0)

    result = torch.cat((distances / 100, rotations, directions), dim=0)
    return (result, )
def __getitem__(self, index):
    """Return a random crop of at most ``self.N`` residues of a valid protein.

    ``index`` is mapped through ``self.valid_indices``. If the protein is
    longer than ``N``, the crop start is drawn uniformly from every start
    that leaves a full ``N``-residue crop. The crop never extends past the
    protein's last residue.

    A random number (below ``L // 5``) of positions, sampled with
    replacement, get random labels. The one-hot encoding then receives
    additive uniform noise in [0, 0.1) and is clamped to [0, 1].

    Returns:
        ``((positions, one_hot.T), keeps)``, where:
          * ``positions`` is point 1 of each residue taken from the raw
            (not backrub-perturbed) ``self.ters``, scaled by 1/100;
          * ``keeps`` is the window of ``self.keeps`` shifted so its first
            entry is 0.
    """
    N = self.N
    index = self.valid_indices[index]
    start = self.index[index]
    end = self.index[index + 1]

    rstart = start
    if end - start > N:
        # randrange's stop is exclusive; +1 makes the crop that ends exactly
        # at ``end`` selectable too.
        rstart = random.randrange(start, end - N + 1)
    # Never read past this protein's last residue into the next protein.
    rend = min(rstart + N, end)
    window = slice(rstart, rend)

    primary = self.pris[window] - 1

    # Sequence positions, rebased to start at 0.
    keeps = self.keeps[window]
    keeps = keeps - keeps[0]

    # Label noise: replacement labels are evaluated before positions.
    n_positions = random.randrange(max(1, primary.size(0) // 5))
    primary[torch.randint(0, primary.size(0), (n_positions, ))] = torch.randint(
        0, 20, (n_positions, ))

    tertiary = self.ters[:, :, window]
    # NOTE(review): the backrub output is discarded. The positions below come
    # from the raw ``tertiary``. The call is kept in case it consumes
    # randomness; confirm whether the perturbed coordinates were intended.
    self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
    distances = tertiary.permute(2, 0, 1)[:, 1] / 100

    primary_onehot = torch.zeros(primary.size(0), 20, dtype=torch.float)
    primary_onehot[torch.arange(primary.size(0)), primary] = 1
    primary_onehot = primary_onehot + 0.1 * torch.rand_like(primary_onehot)
    primary_onehot = primary_onehot.clamp(0, 1)

    return ((distances, primary_onehot.transpose(0, 1)), keeps)
def __getitem__(self, index):
    """Return sequence-design inputs and the label target for protein ``index``.

    The whole protein is used, and a few residue labels are replaced at
    random. Per-residue features are:
      * the one-hot labels (passed twice);
      * sin/cos of the angles stored in ``self.angs``;
      * an orientation vector concatenating point 1 of ``self.ters``
        (scaled by 1/100), the flattened ``self.ors`` entry and the absolute
        residue index.

    Returns:
        ``(inputs, PackedTensor(primary, split=False))``. The target is the
        noise-corrupted label tensor.

    Raises:
        AssertionError: if a neighbour index is not smaller than the number
            of residues.
    """
    lo = self.index[index]
    hi = self.index[index + 1]
    window = slice(lo, hi)
    inds = self.inds[window]
    primary = self.pris[window] - 1

    # Label noise: replacement labels first, then positions.
    residues = primary.size(0)
    n_noisy = random.randrange(max(1, residues // 100))
    new_labels = torch.randint(0, 20, (n_noisy, ))
    primary[torch.randint(0, residues, (n_noisy, ))] = new_labels

    evolutionary = self.evos[:, window]  # not used below
    tertiary = self.ters[:, :, window]  # not used below
    span = window.stop - window.start
    frame = self.ors[window, :, :].view(span, -1)
    position = self.ters[1, :, window].transpose(0, 1) / 100
    indices = torch.tensor(range(window.start, window.stop),
                           dtype=torch.float).view(-1, 1)
    orientation = torch.cat((position, frame, indices), dim=1)

    angles = self.angs[:, window].transpose(0, 1)
    protein = SubgraphStructure(torch.zeros(indices.size(0), dtype=torch.long))
    neighbours = ConstantStructure(0, 0, (inds - lo).to(torch.long))
    angle_features = torch.cat((angles.sin(), angles.cos()), dim=1)

    one_hot = torch.zeros(residues, 20, dtype=torch.float)
    one_hot[torch.arange(residues), primary] = 1
    one_hot = one_hot.clamp(0, 1)
    assert neighbours.connections.max() < one_hot.size(0)

    inputs = (
        PackedTensor(one_hot),
        PackedTensor(one_hot),  # second copy was tagged "gt_ignore"
        PackedTensor(angle_features),
        PackedTensor(orientation),
        neighbours,
        protein,
    )
    return inputs, PackedTensor(primary, split=False)
def __getitem__(self, index):
    """Return a noisy distance map and an empty sequence tensor for valid protein ``index``.

    ``index`` is mapped through ``self.valid_indices``, and at most
    ``self.N`` residues are used. The distance map is computed from
    point 1 of each backrub-perturbed residue:
      * positions are scaled by 1/100 and the distances by another 1/100;
      * the result is blended 99:1 with uniform noise.

    Returns:
        ``(dist_map, sequence)``. ``dist_map`` has shape ``(1, L, L)``.
        ``sequence`` has shape ``(42, N, N)``; it is all zeros except the
        last channel, which is all ones.
    """
    limit = self.N
    protein_id = self.valid_indices[index]
    begin = self.index[protein_id]
    window = slice(begin, min(self.index[protein_id + 1], begin + limit))
    inds = self.inds[window]
    primary = self.pris[window] - 1

    # Label noise (up to ~20%): replacement labels first, then positions.
    residues = primary.size(0)
    n_noisy = random.randrange(max(1, residues // 5))
    new_labels = torch.randint(0, 20, (n_noisy, ))
    primary[torch.randint(0, residues, (n_noisy, ))] = new_labels

    atoms = self.ters[:, :, window][[0, 1, 3]].permute(2, 0, 1)
    coords, _angles = self.backrub(atoms)
    trace = coords[:, 1] / 100
    dist_map = (trace.unsqueeze(0) - trace.unsqueeze(1)).norm(dim=-1).unsqueeze(0) / 100
    dist_map = 0.99 * dist_map + 0.01 * torch.rand_like(dist_map)

    sequence = torch.zeros(42, limit, limit)
    sequence[-1] = 1

    # NOTE(review): dist_map has a leading singleton dim, so ``protein``
    # covers a single element; neither it nor ``neighbours`` is returned.
    protein = SubgraphStructure(torch.zeros(dist_map.size(0), dtype=torch.long))
    neighbours = ConstantStructure(0, 0, (inds - begin).to(torch.long))

    one_hot = torch.zeros(residues, 20, dtype=torch.float)
    one_hot[torch.arange(residues), primary] = 1
    one_hot = one_hot.clamp(0, 1)

    return (dist_map, sequence)
def __getitem__(self, index):
    """Return the 9 pairwise point-to-point distance maps for valid protein ``index``.

    ``index`` is mapped through ``self.valid_indices``, and at most
    ``self.N`` residues are used. Rows 0, 1 and 3 of ``self.ters`` go
    through ``self.backrub``, and the resulting per-residue points are
    scaled by 1/100. For every pairing ``k`` of point ``left[k]`` with point
    ``right[k]`` (3 x 3 = 9 pairings), the map holds
    ``|point right[k] of residue j - point left[k] of residue i|``.

    Entries with ``i > j`` are kept, and every other entry is taken from
    ``[k, j, i]``.

    Returns:
        ``(distances / 100, )`` with ``distances`` of shape
        ``(9, count, count)``.
    """
    N = self.N
    index = self.valid_indices[index]
    start = self.index[index]
    window = slice(start, min(self.index[index + 1], start + N))
    primary = self.pris[window] - 1

    # Label noise. ``primary`` does not feed the output; the draws are kept
    # so the random state is unchanged in case ``self.backrub`` is random.
    n_positions = random.randrange(max(1, primary.size(0) // 100))
    primary[torch.randint(0, primary.size(0), (n_positions, ))] = torch.randint(
        0, 20, (n_positions, ))

    tertiary = self.ters[:, :, window]
    tertiary, _ = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
    tertiary = tertiary / 100

    # All 3 x 3 point pairings: left = 0,0,0,1,1,1,2,2,2; right = 0,1,2,0,1,2,...
    left = torch.arange(3).repeat_interleave(3)
    right = torch.arange(3).repeat(3)
    distances = (tertiary[None, :, right, :] - tertiary[:, None, left, :]).norm(dim=-1)
    distances = distances.permute(2, 0, 1).contiguous()

    # Strictly-lower-triangular selector sized by the actual residue count.
    # Sizing it by ``N`` breaks when the protein is shorter than ``N``.
    count = distances.size(-1)
    positions = torch.arange(count)
    lower = ((positions[:, None] - positions[None, :]) > 0).float()
    distances = lower * distances + (1 - lower) * distances.permute(0, 2, 1)

    return (distances / 100, )
def __getitem__(self, index):
    """Build a masked-residue-prediction example for the protein at ``index``.

    A random subset of residues has its label replaced by the extra class 20
    of a 21-way one-hot input. The targets are the original labels and the
    binary mask. Geometric inputs come from point 1 of each residue after
    rows 0, 1 and 3 of ``self.ters`` go through ``self.backrub``, scaled by
    1/100.

    Returns:
        ``(inputs, targets)`` where ``inputs`` is
        ``(PackedTensor(angle_features), PackedTensor(primary_onehot),
        PackedTensor(orientation_slice), neighbours, protein)`` and
        ``targets`` is ``(PackedTensor(primary, split=False),
        PackedTensor(mask_binary, split=False))``.
    """
    # Boundaries of the whole protein.
    window = slice(self.index[index], self.index[index + 1])
    seq_len = window.stop - window.start

    # Mask a random number of distinct residues, between 1 and seq_len - 1.
    # np.random.randint excludes its upper bound, so the whole sequence is
    # never masked, and seq_len must be at least 2.
    mask = np.random.choice(seq_len, size=np.random.randint(1, seq_len),
                            replace=False)
    mask = torch.tensor(mask, dtype=torch.long)
    mask_binary = torch.zeros(seq_len, dtype=torch.uint8)
    mask_binary[mask] = 1

    # Sequence input: masked residues are encoded as the extra class 20.
    primary = self.pris[window] - 1
    primary_masked = primary.clone()
    primary_masked[mask] = 20
    primary_onehot = torch.zeros((seq_len, 21), dtype=torch.float)
    primary_onehot[torch.arange(seq_len), primary_masked] = 1

    # Neighbourhood structure, with indices relative to the protein start.
    inds = self.inds[window]
    neighbours = ConstantStructure(0, 0, (inds - self.index[index]).to(
        torch.long))

    # Perturbed trace: keep point 1 of each residue, scaled by 1/100.
    tertiary = self.ters[:, :, window]
    tertiary, _ = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
    tertiary = tertiary[:, 1] / 100

    # Sinusoidal encoding of signed sequence separation.
    indices = torch.tensor(range(window.start, window.stop), dtype=torch.float)
    relative_indices = indices[None, :] - indices[:, None]
    relative_sin = (relative_indices / 10).sin()
    relative_cos = (relative_indices / 10).cos()

    distances = (tertiary[None, :, :] - tertiary[:, None, :]).norm(dim=-1)
    distances = distances.unsqueeze(0)
    length = distances.size(-1)
    positions = torch.arange(length)
    idx_a = positions[:, None]
    idy_a = positions[None, :]
    # Negative offsets (idx_a - 1, idx_a - 2) wrap around via negative indexing.
    chain_angle = self.fast_angle(
        distances, idx_a - 1, idx_a, (idx_a + 1) % length)[:, :, 0].permute(1, 0)
    chain_dihedral = self.fast_dihedral(
        distances, idx_a - 1, idx_a - 2, idx_a,
        (idx_a + 1) % length)[:, :, 0].permute(1, 0)
    contact_angles = self.fast_angle(distances, idx_a - 1, idx_a, idy_a)
    contact_dihedrals = self.fast_dihedral(distances, idx_a, idx_a - 1,
                                           idy_a, idy_a - 1)
    into_contact_dihedrals = self.fast_dihedral(distances, idx_a - 1,
                                                idx_a - 2, idx_a, idy_a)

    angle_features = torch.cat(
        (chain_angle.sin(), chain_angle.cos(),
         chain_dihedral.sin(), chain_dihedral.cos()), dim=1)
    orientation = torch.cat(
        (distances,
         contact_angles.sin(), contact_angles.cos(),
         contact_dihedrals.sin(), contact_dihedrals.cos(),
         into_contact_dihedrals.sin(), into_contact_dihedrals.cos(),
         relative_sin[None], relative_cos[None]), dim=0)
    # Gather each residue's pair features for its neighbours only.
    rows = torch.arange(neighbours.connections.size(0))[:, None]
    orientation_slice = orientation[:, rows, neighbours.connections].permute(
        1, 2, 0).contiguous()

    protein = SubgraphStructure(torch.zeros_like(positions))

    inputs = (PackedTensor(angle_features), PackedTensor(primary_onehot),
              PackedTensor(orientation_slice), neighbours, protein)
    targets = (PackedTensor(primary, split=False),
               PackedTensor(mask_binary, split=False))
    return inputs, targets