def compare(self, source, target):
    """Compute edge features between corresponding rows of ``target`` and ``source``.

    Row layout, as sliced below: columns ``[:3]`` are handed to
    ``relative_orientation`` as positions; columns ``[3:-1]`` are reshaped to
    ``(-1, 3, 3)`` matrices (so presumably 9 columns per row — TODO confirm);
    the last column is a scalar whose scaled difference is encoded with
    sin/cos.

    Args:
        source: 2-D tensor with the row layout above.
        target: 2-D tensor with the same layout; presumably the same number of
            rows as ``source`` — TODO confirm against callers.

    Returns:
        Concatenation along dim 1 of ``gaussian_rbf(distance, *self.rbf)``,
        the ``direction`` and ``rotation`` outputs of ``relative_orientation``,
        and one column each of ``sin`` and ``cos`` of
        ``(target[:, -1] - source[:, -1]) / 10``.
    """
    # Note the argument order: target data is passed first, source second.
    target_pos = target[:, :3]
    source_pos = source[:, :3]
    target_frame = target[:, 3:-1].reshape(-1, 3, 3)
    source_frame = source[:, 3:-1].reshape(-1, 3, 3)
    distance, direction, rotation = relative_orientation(
        target_pos, source_pos, target_frame, source_frame
    )

    # Scaled difference of the trailing scalar column, encoded periodically.
    scaled_offset = (target[:, -1] - source[:, -1]) / 10
    offset_sin = torch.sin(scaled_offset).unsqueeze(1)
    offset_cos = torch.cos(scaled_offset).unsqueeze(1)

    pieces = [
        gaussian_rbf(distance, *self.rbf),
        direction,
        rotation,
        offset_sin,
        offset_cos,
    ]
    return torch.cat(pieces, dim=1)
def __getitem__(self, index):
    """Return pairwise geometric features for the ``index``-th valid item.

    The item is cropped to a window of at most ``self.N`` positions. Its
    coordinates are passed through ``self.backrub``. Pairwise distance,
    rotation and direction features are then built over all ordered position
    pairs.

    Args:
        index: Position in ``self.valid_indices``.

    Returns:
        A 1-tuple ``(result,)``. ``result`` concatenates along dim 0:
        the pairwise distance map (one channel, divided by 100),
        the rotation channels, and
        the direction channels.
        Each has spatial size ``count x count`` with ``count = tertiary.size(0)``.
        Because the window is cropped with ``min``, ``count`` may be smaller
        than ``self.N``.
    """
    N = self.N
    index = self.valid_indices[index]
    start = self.index[index]
    # Crop to at most N positions; items shorter than N give a shorter window.
    window = slice(start, min(self.index[index + 1], start + N))

    # Sequence noise: overwrite a random handful of positions with random
    # class ids in [0, 20). NOTE(review): ``primary`` is not part of the
    # returned features. It is kept so the RNG draws made here happen in the
    # same order as before, which matters if ``self.backrub`` also consumes
    # random numbers.
    primary = self.pris[window] - 1
    n_positions = random.randrange(max(1, primary.size(0) // 100))
    primary[torch.randint(0, primary.size(0), (n_positions,))] = torch.randint(
        0, 20, (n_positions,)
    )

    tertiary = self.ters[:, :, window]
    # Keep entries 0, 1 and 3 of the leading axis, and move the window axis
    # (dim 2) to the front before calling backrub.
    tertiary, _ = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
    # Keep entry 1 of the selected axis and scale by 1/100.
    # NOTE(review): ``distances`` is divided by 100 again below. Confirm that
    # the double scaling is intended.
    tertiary = tertiary[:, 1] / 100
    ors = self.orientations(tertiary)
    count = tertiary.size(0)

    # All ordered pairs (i, j) in row-major order:
    # rows = [0, 0, ..., 1, 1, ...], cols = [0, 1, ..., 0, 1, ...].
    positions = torch.arange(count)
    rows = positions.repeat_interleave(count)
    cols = positions.repeat(count)
    _, directions, rotations = relative_orientation(
        tertiary[rows], tertiary[cols], ors[rows], ors[cols]
    )
    directions = directions.reshape(count, count, *directions.shape[1:]).permute(2, 0, 1)
    rotations = rotations.reshape(count, count, *rotations.shape[1:]).permute(2, 0, 1)

    # Symmetrise: entries with i > j keep their own (i, j) value, and all
    # others take the (j, i) value.
    # The mask must be count x count rather than N x N. An N x N mask fails to
    # broadcast whenever the window is shorter than N.
    # torch.where also keeps non-finite values in the unselected half from
    # leaking in, which an arithmetic ``mask * a + (1 - mask) * b`` blend
    # would allow.
    lower = positions[:, None] > positions[None, :]
    directions = torch.where(lower, directions, directions.transpose(1, 2))
    rotations = torch.where(lower, rotations, rotations.transpose(1, 2))

    distances = (tertiary[None, :, :] - tertiary[:, None, :]).norm(dim=-1).unsqueeze(0)
    result = torch.cat((distances / 100, rotations, directions), dim=0)
    return (result,)