Ejemplo n.º 1
0
  def forward(self, sample):
    """Iteratively refine angles and return them packed per subgraph.

    Each of ``self.repeats`` rounds does the following:
      * maps the current angles through ``self.position_lookup``;
      * builds node features from the sine/cosine of the angles;
      * builds a k-nearest-neighbour structure with relative pair features;
      * adds ``self.corrected_angles(self.correction(...))`` to the angles.

    Args:
      sample: ``(angles, subgraph)`` pair. ``angles`` is concatenated with
        its sine/cosine along dim 1, so it is presumably (N, n_angles) --
        TODO confirm. ``subgraph`` must expose ``indices`` and ``counts``.

    Returns:
      ``(ts.PackedTensor(angles, lengths=list(subgraph.counts)), subgraph)``
      holding the refined angles.
    """
    angles, subgraph = sample
    for _round in range(self.repeats):
      tertiary, _ = self.position_lookup(angles, torch.zeros_like(subgraph.indices))

      afeat = torch.cat((angles.sin(), angles.cos()), dim=1)
      features = ts.scatter.batched(self.preprocess_correct, afeat, subgraph.indices)
      ors = self.orientations(tertiary)
      # NOTE(review): index 1 along dim 1 is presumably the reference backbone atom -- confirm.
      pos = tertiary[:, 1]
      inds = torch.arange(0, pos.size(0), dtype=torch.float, device=pos.device).view(-1, 1)
      distances = torch.cat((pos, ors, inds), dim=1)

      # Only the neighbour structure is needed here; the pairwise features
      # are produced by RelativeStructure.message from ``distances``.
      _, structure = self.knn_structure(tertiary, subgraph)
      relative_data = RelativeStructure(structure, self.rbf).message(distances, distances)
      relative_structure = OrientationStructure(structure, relative_data)

      correction = self.correction(features, relative_structure)
      angles = angles + self.corrected_angles(correction)

    result = ts.PackedTensor(angles, lengths=list(subgraph.counts))
    return result, subgraph
Ejemplo n.º 2
0
  def forward(self, inputs):
    """Compute an encoder-based output and an angle-only output from angles.

    The angles are embedded through their sine/cosine. The same features go
    to ``self.angle_result`` and, after ``self.preprocess``, to the graph
    encoder. The encoder runs over a k-nearest-neighbour structure built
    from ``self.lookup(angles)``.

    Args:
      inputs: ``(tertiary, subgraph)`` pair. Despite its name, ``tertiary``
        is treated as angles here: its sin/cos are taken directly, and it
        is mapped through ``self.lookup`` to obtain positions.
        ``subgraph`` must expose ``indices``.

    Returns:
      ``torch.cat((result, angle_result), dim=0)``: the encoder-based output
      followed by the angle-only output along dim 0.
    """
    angles, subgraph = inputs
    afeat = torch.cat((angles.sin(), angles.cos()), dim=1)
    angle_result = self.angle_result(afeat)
    features = ts.scatter.batched(self.preprocess, afeat, subgraph.indices)

    tertiary, _ = self.lookup(angles, torch.zeros_like(subgraph.indices))
    ors = self.orientations(tertiary)
    # NOTE(review): index 1 along dim 1 is presumably the reference backbone atom -- confirm.
    pos = tertiary[:, 1]
    inds = torch.arange(0, pos.size(0), dtype=torch.float, device=pos.device).view(-1, 1)
    distances = torch.cat((pos, ors, inds), dim=1)

    # Only the neighbour structure is needed; the pairwise features are
    # produced by RelativeStructure.message from ``distances``.
    _, structure = self.knn_structure(tertiary, subgraph)
    relative_data = RelativeStructure(structure, self.rbf).message(distances, distances)
    relative_structure = OrientationStructure(structure, relative_data)

    encoding = self.encoder(features, relative_structure)
    encoding = ts.scatter.batched(self.postprocess, encoding, subgraph.indices)
    # Skip connection: the head sees both the preprocessed input and the encoding.
    encoding = torch.cat((features, encoding), dim=1)
    result = self.result(encoding)

    return torch.cat((result, angle_result), dim=0)
Ejemplo n.º 3
0
  def forward(self, features, sequence, pair_features, structure, protein=None):
    """Encode a structure, then decode a sequence with scheduled sampling.

    Args:
      features: node features passed to ``self.encoder``.
      sequence: input sequence, prepared by ``self.prepare_sequence``.
      pair_features: pair features indexed as ``[:, :, k]``. Channel 0 is a
        distance that is RBF-expanded; the remaining channels are kept as-is.
      structure: neighbour structure shared by the encoder and the decoder.
      protein: not used by this method.

    Returns:
      The decoder output after ``self.schedule`` rounds of scheduled
      sampling, or the initial decoder output if ``self.schedule`` is 0.
    """
    # featurize distances: RBF-expand channel 0, keep the remaining channels
    distance = pair_features[:, :, 0]
    distance = gaussian_rbf(distance.view(-1, 1), *self.rbf).reshape(distance.size(0), distance.size(1), -1)
    pair_features = torch.cat((distance, pair_features[:, :, 1:]), dim=2)

    relative_structure = OrientationStructure(structure, pair_features)
    encoding = self.encoder(features, relative_structure)

    # initial evaluation
    sequence = self.prepare_sequence(sequence)
    masked_structure = MaskedStructure(structure, pair_features, sequence, encoding)
    result = self.decoder(encoding, masked_structure)

    # differentiable scheduled sampling: each round, rows are selected with
    # probability 0.25. The selected rows of the prepared sequence are
    # replaced by embeddings of the decoder's hard one-hot predictions, then
    # the sequence is decoded again.
    for _ in range(self.schedule):
      samples = self.sequence_embedding(hard_one_hot(result))
      mask = torch.rand(samples.size(0), device=samples.device) < 0.25
      mask = mask.unsqueeze(1).float()
      sequence = mask * samples + (1 - mask) * sequence
      # The decoder must see the same featurized pair data as the initial
      # evaluation; there is no other pair-feature tensor in scope.
      masked_structure = MaskedStructure(structure, pair_features, sequence, encoding)
      result = self.decoder(encoding, masked_structure)

    return result
    def forward(self,
                angle_features,
                sequence,
                pair_features,
                structure,
                protein=None):
        """Encode angle and sequence features over a structure and decode them.

        Args:
            angle_features: per-node angle features. They are concatenated
                with ``sequence`` along dim 1 to form the node features.
            sequence: per-node sequence features. They are also passed to
                ``SequenceOrientationStructure`` when ``self.sequence`` is set.
            pair_features: pair features indexed as ``[:, :, k]``. Channel 0
                is a distance that is RBF-expanded; the remaining channels
                are kept as-is.
            structure: neighbour structure used by the encoder.
            protein: optional. When given and ``self.local`` is not None,
                ``protein.indices`` drives a batched ``self.local`` pass over
                the concatenated encoding and node features.

        Returns:
            The output of ``self.decoder`` on the (optionally refined) encoding.
        """
        features = torch.cat((angle_features, sequence), dim=1)

        # featurize distances: RBF-expand channel 0, keep the other channels
        distance = pair_features[:, :, 0]
        distance = gaussian_rbf(distance.view(-1, 1),
                                *self.rbf).reshape(distance.size(0),
                                                   distance.size(1), -1)
        pair_features = torch.cat((distance, pair_features[:, :, 1:]), dim=2)

        if self.sequence:
            relative_structure = SequenceOrientationStructure(
                structure, pair_features, sequence)
        else:
            relative_structure = OrientationStructure(structure, pair_features)
        encoding = self.encoder(features, relative_structure)

        if self.local is not None and protein is not None:
            combined = torch.cat((encoding, features), dim=1)
            encoding = ts.scatter.batched(self.local, combined,
                                          protein.indices)

        return self.decoder(encoding)
Ejemplo n.º 5
0
 def compare(self, source, target):
     """Build pair features from two row-aligned batches of node vectors.

     Columns ``[:, :3]`` are used as positions and the last column as an
     index (presumably the node's sequence position -- TODO confirm).

     Returns the RBF-expanded Euclidean distance between the positions,
     concatenated along dim 1 with the sine and cosine of
     ``(target_index - source_index) / 10``.
     """
     source_pos = source[:, :3]
     target_pos = target[:, :3]
     distance = (source_pos - target_pos).norm(dim=1, keepdim=True)

     index_offset = (target[:, -1] - source[:, -1]) / 10
     pieces = (
         gaussian_rbf(distance, *self.rbf),
         torch.sin(index_offset)[:, None],
         torch.cos(index_offset)[:, None],
     )
     return torch.cat(pieces, dim=1)
Ejemplo n.º 6
0
    def compare(self, source, target):
        """Build orientation-aware pair features for two batches of node vectors.

        Each row is laid out as position ``[:, :3]``, a flattened 3x3 block
        ``[:, 3:-1]`` (presumably an orientation frame -- TODO confirm), and
        a trailing index column ``[:, -1]``.

        Returns ``relative_orientation``'s outputs concatenated along dim 1:
        the RBF-expanded distance, the direction, and the rotation. The sine
        and cosine of ``(target_index - source_index) / 10`` are appended.
        """
        target_pos, source_pos = target[:, :3], source[:, :3]
        target_frame = target[:, 3:-1].reshape(-1, 3, 3)
        source_frame = source[:, 3:-1].reshape(-1, 3, 3)

        distance, direction, rotation = relative_orientation(
            target_pos, source_pos, target_frame, source_frame)

        index_offset = (target[:, -1] - source[:, -1]) / 10
        pieces = (
            gaussian_rbf(distance, *self.rbf),
            direction,
            rotation,
            torch.sin(index_offset)[:, None],
            torch.cos(index_offset)[:, None],
        )
        return torch.cat(pieces, dim=1)
Ejemplo n.º 7
0
    def forward(self, tertiary, noise, sequence, subgraph):
        """Noise-conditioned graph encoding of angles (or of their coordinates).

        Args:
            tertiary: input angles. Their sin/cos are used as node features
                when ``self.angles`` is set. Otherwise the output of
                ``self.lookup`` is flattened per row and used instead.
            noise: noise magnitude, presumably one value per row of
                ``tertiary`` (TODO confirm). It is discretised into one of 10
                one-hot levels that are appended to the node features.
            sequence: replaces the preprocessed node features when
                ``self.conditional`` is set.
            subgraph: must expose ``indices``.

        Returns:
            ``self.result`` applied to the preprocessed features concatenated
            (dim 1) with the post-processed encoding.
        """
        num_levels = 10
        # Discretise noise into level k with noise / 3.15 ~ 0.6 ** k
        # (truncated toward zero). Clamp before casting:
        #   * noise above ~5.25 would otherwise give a negative level, which
        #     silently wraps to the last column;
        #   * very small noise would otherwise index past the last column.
        noise = noise / 3.15
        level = torch.log(noise) / torch.log(torch.tensor(0.60))
        offset = level.clamp(0, num_levels - 1).long()
        condition = torch.zeros(tertiary.size(0),
                                num_levels,
                                device=tertiary.device,
                                dtype=torch.float)
        condition[torch.arange(offset.size(0)), offset.view(-1)] = 1

        angles = tertiary
        tertiary, _ = self.lookup(tertiary, torch.zeros_like(subgraph.indices))
        if self.angles:
            afeat = torch.cat((angles.sin(), angles.cos(), condition), dim=1)
        else:
            afeat = torch.cat(
                (tertiary.reshape(tertiary.size(0), -1), condition), dim=1)
        features = ts.scatter.batched(self.preprocess, afeat, subgraph.indices)
        if self.conditional:
            # NOTE(review): the preprocessed features above are discarded here.
            features = sequence
        ors = self.orientations(tertiary)
        # NOTE(review): index 1 along dim 1 is presumably the reference backbone atom -- confirm.
        pos = tertiary[:, 1]
        inds = torch.arange(0,
                            pos.size(0),
                            dtype=torch.float,
                            device=pos.device).view(-1, 1)
        distances = torch.cat((pos, ors, inds), dim=1)

        # Only the neighbour structure is needed; the pairwise features are
        # produced by RelativeStructure.message from ``distances``.
        _, structure = self.knn_structure(tertiary, subgraph)
        distance_data = RelativeStructure(structure, self.rbf)
        relative_data = distance_data.message(distances, distances)
        relative_structure = OrientationStructure(structure, relative_data)

        encoding = self.encoder(features, relative_structure)
        encoding = ts.scatter.batched(self.postprocess, encoding,
                                      subgraph.indices)
        # Skip connection: the head sees both the input features and the encoding.
        encoding = torch.cat((features, encoding), dim=1)
        return self.result(encoding)
    def forward(self, angle_features, sequence, pair_features, structure):
        """Encode angle and sequence features over a structure and decode them.

        Args:
            angle_features: per-node angle features. They are concatenated
                with ``sequence`` along dim 1 to form the node features.
            sequence: per-node sequence features.
            pair_features: pair features indexed as ``[:, :, k]``. Channel 0
                is a distance that is RBF-expanded; the remaining channels
                are kept as-is.
            structure: neighbour structure used by the encoder.

        Returns:
            The output of ``self.decoder`` on the encoding.
        """
        features = torch.cat((angle_features, sequence), dim=1)

        # featurize distances: RBF-expand channel 0, keep the other channels
        distance = pair_features[:, :, 0]
        distance = gaussian_rbf(distance.view(-1, 1),
                                *self.rbf).reshape(distance.size(0),
                                                   distance.size(1), -1)
        pair_features = torch.cat((distance, pair_features[:, :, 1:]), dim=2)

        relative_structure = OrientationStructure(structure, pair_features)
        encoding = self.encoder(features, relative_structure)
        return self.decoder(encoding)
Ejemplo n.º 9
0
    def forward(self, tertiary, subgraph):
        """Predict outputs from geometry alone, starting from zero node features.

        Args:
            tertiary: input mapped through ``self.lookup`` to obtain the
                positions/orientations used to build the neighbour graph.
            subgraph: must expose ``indices``.

        Returns:
            ``(result, result_tertiary)``: ``result`` is ``self.result`` on
            the post-processed encoding, and ``result_tertiary`` is
            ``self.lookup`` applied to ``result``.
        """
        # Node features are all zeros, so the encoder only sees the pairwise
        # geometric features built below.
        features = torch.zeros(tertiary.size(0),
                               self.size,
                               dtype=tertiary.dtype,
                               device=tertiary.device)

        tertiary, _ = self.lookup(tertiary, torch.zeros_like(subgraph.indices))
        ors = self.orientations(tertiary)
        # NOTE(review): index 1 along dim 1 is presumably the reference backbone atom -- confirm.
        pos = tertiary[:, 1]
        inds = torch.arange(0,
                            pos.size(0),
                            dtype=torch.float,
                            device=pos.device).view(-1, 1)
        distances = torch.cat((pos, ors, inds), dim=1)

        # Only the neighbour structure is needed; the pairwise features are
        # produced by RelativeStructure.message from ``distances``.
        _, structure = self.knn_structure(tertiary, subgraph)
        distance_data = RelativeStructure(structure, self.rbf)
        relative_data = distance_data.message(distances, distances)
        relative_structure = OrientationStructure(structure, relative_data)

        encoding = self.encoder(features, relative_structure)
        encoding = ts.scatter.batched(self.postprocess, encoding,
                                      subgraph.indices)
        result = self.result(encoding)
        result_tertiary, _ = self.lookup(result,
                                         torch.zeros_like(subgraph.indices))

        return result, result_tertiary
Ejemplo n.º 10
0
  def forward(self, inputs):
    """Jointly process per-node angles and pairwise distances.

    Args:
      inputs: ``(angles, distances, structure)`` triple. ``structure`` must
        expose ``indices`` and is also passed to ``get_distance_structure``
        and to ``self.transformer``.

    Returns:
      ``torch.cat((angle_pred, distance_pred), dim=0)``. ``angle_pred`` is
      ``self.angle_result`` applied to node outputs averaged over groups
      given by ``structure.indices``. ``distance_pred`` is
      ``self.distance_result`` applied to the transformer's distance outputs.
    """
    angles, distances, structure = inputs

    node_struc, dist_struc = get_distance_structure(structure)

    node_out = self.angle_lookup(angles)
    dist_out = self.distance_lookup(gaussian_rbf(distances, *self.rbf))
    node_out = ts.scatter.batched(self.preproc, node_out, structure.indices)

    node_out, dist_out = self.transformer(node_out, dist_out, node_struc, dist_struc, structure)

    angle_pred = self.angle_result(ts.scatter.mean(node_out, structure.indices))
    distance_pred = self.distance_result(dist_out)

    return torch.cat((angle_pred, distance_pred), dim=0)