コード例 #1
0
    def compare(self, source, target):
        """Assemble per-row relative features between ``target`` and ``source``.

        Both inputs are sliced the same way: the first three columns, the
        columns from 3 up to (but excluding) the last one reshaped to 3x3
        blocks, and the last column.
        NOTE(review): presumably these are a position, a 3x3 orientation
        frame and a sequence index respectively -- confirm against caller.

        Returns the column-wise concatenation (``dim=1``) of the RBF-expanded
        distance, the direction and rotation returned by
        ``relative_orientation``, and the sine and cosine of the last-column
        difference divided by 10.
        """
        target_head = target[:, :3]
        source_head = source[:, :3]
        target_frame = target[:, 3:-1].reshape(-1, 3, 3)
        source_frame = source[:, 3:-1].reshape(-1, 3, 3)

        # Target-side arguments go first, matching the helper's call order.
        distance, direction, rotation = relative_orientation(
            target_head, source_head, target_frame, source_frame
        )

        # Difference of the last column, scaled down before the sin/cos pair.
        scaled_offset = (target[:, -1] - source[:, -1]) / 10
        offset_sin = torch.sin(scaled_offset)[:, None]
        offset_cos = torch.cos(scaled_offset)[:, None]

        features = (
            gaussian_rbf(distance, *self.rbf),
            direction,
            rotation,
            offset_sin,
            offset_cos,
        )
        return torch.cat(features, dim=1)
コード例 #2
0
  def __getitem__(self, index):
    """Return the pairwise feature maps for one dataset entry.

    The entry is looked up through ``self.valid_indices`` and clipped to at
    most ``self.N`` consecutive positions. For the ``count`` positions that
    remain, the result stacks along dim 0:

    * one channel of pairwise Euclidean distances (divided by 100),
    * the rotation channels from ``relative_orientation``,
    * the direction channels from ``relative_orientation``,

    giving a tensor of shape ``(channels, count, count)``. Rotation and
    direction maps are symmetrised by mirroring the strictly-lower triangle
    (row > column) onto the upper triangle and the diagonal.

    Returns:
      A one-element tuple ``(result,)``.
    """
    N = self.N
    index = self.valid_indices[index]
    start = self.index[index]
    window = slice(start, min(self.index[index + 1], start + N))
    primary = self.pris[window] - 1

    # add noise:
    # NOTE(review): `primary` is not part of the returned features; the
    # perturbation is kept so the `random` / torch RNG draw sequence (and
    # therefore seeded runs) stays the same as before. Drop it if that
    # reproducibility is not needed.
    n_positions = random.randrange(max(1, primary.size(0) // 100))
    primary[torch.randint(0, primary.size(0), (n_positions,))] = torch.randint(0, 20, (n_positions,))

    tertiary = self.ters[:, :, window]
    tertiary, _ = self.backrub(tertiary[[0, 1, 3]].permute(2, 0, 1))
    # Keep only entry 1 along dim 1 of the backrub output and rescale.
    tertiary = tertiary[:, 1] / 100

    ors = self.orientations(tertiary)

    count = tertiary.size(0)

    # Enumerate all ordered pairs (i, j) in row-major order:
    # x_range = i (each index repeated `count` times), y_range = j.
    positions = torch.arange(count)
    x_range = positions.repeat_interleave(count)
    y_range = positions.repeat(count)
    _, directions, rotations = relative_orientation(
      tertiary[x_range], tertiary[y_range], ors[x_range], ors[y_range]
    )
    # (count * count, C) -> (C, count, count)
    directions = directions.reshape(count, count, *directions.shape[1:]).permute(2, 0, 1).contiguous()
    rotations = rotations.reshape(count, count, *rotations.shape[1:]).permute(2, 0, 1).contiguous()

    # Strictly-lower-triangular selector. It is sized by `count`, not `N`:
    # the window may hold fewer than N positions, and an N x N mask would
    # not broadcast against count x count maps.
    mask = (positions[:, None] - positions[None, :]) > 0
    mask = mask.float()

    # Arithmetic blend (rather than a select) so the numerical behaviour of
    # the original expression is preserved exactly.
    directions = mask * directions + (1 - mask) * directions.permute(0, 2, 1).contiguous()
    rotations = mask * rotations + (1 - mask) * rotations.permute(0, 2, 1).contiguous()

    distances = (tertiary[None, :, :] - tertiary[:, None, :]).norm(dim=-1)
    distances = distances.unsqueeze(0)

    result = distances / 100
    result = torch.cat((result, rotations, directions), dim=0)

    return (result,)