def forward(self, inputs):
    """Predict per-node outputs from angles and their k-NN geometry.

    Args:
        inputs: a ``(tertiary, subgraph)`` pair. ``tertiary`` is treated as
            angles (its ``sin``/``cos`` form the node features) and is also
            converted to geometry via ``self.lookup``. ``subgraph.indices``
            drives the batched pre/post-processing and the lookup.

    Returns:
        ``self.result(...)`` concatenated with ``self.angle_result(...)``
        along dim 0.
    """
    tertiary, subgraph = inputs
    angles = tertiary

    # Angle features: sin and cos stacked along the feature dimension.
    afeat = torch.cat((angles.sin(), angles.cos()), dim=1)
    angle_result = self.angle_result(afeat)
    features = ts.scatter.batched(self.preprocess, afeat, subgraph.indices)

    # Convert angles to geometry and build per-node data: the slice at
    # index 1 of dim 1 (presumably one backbone atom position -- confirm),
    # the orientations, and the node's running index.
    tertiary, _ = self.lookup(tertiary, torch.zeros_like(subgraph.indices))
    ors = self.orientations(tertiary)
    pos = tertiary[:, 1]
    inds = torch.arange(
        0, pos.size(0), dtype=torch.float, device=pos.device
    ).view(-1, 1)
    distances = torch.cat((pos, ors, inds), dim=1)

    # Only the neighbour structure is needed from the k-NN search.
    _, structure = self.knn_structure(tertiary, subgraph)

    distance_data = RelativeStructure(structure, self.rbf)
    relative_data = distance_data.message(distances, distances)
    relative_structure = OrientationStructure(structure, relative_data)

    encoding = self.encoder(features, relative_structure)
    encoding = ts.scatter.batched(self.postprocess, encoding, subgraph.indices)
    encoding = torch.cat((features, encoding), dim=1)
    result = self.result(encoding)

    # NOTE(review): concatenation is along dim 0, so rows of angle_result are
    # appended after rows of result; confirm dim=0 (not dim=1) is intended.
    return torch.cat((result, angle_result), dim=0)
  def forward(self, sample):
    """Iteratively refine angles with a structure-aware correction network.

    Args:
      sample: an ``(angles, subgraph)`` pair. ``angles`` is converted to
        geometry with ``self.position_lookup`` on every step;
        ``subgraph.indices`` drives batching and ``subgraph.counts`` gives
        the packed lengths of the result.

    Returns:
      ``(ts.PackedTensor(angles, lengths=list(subgraph.counts)), subgraph)``
      where ``angles`` has had ``self.repeats`` corrections added to it.
    """
    angles, subgraph = sample
    for _ in range(self.repeats):
      tertiary, _ = self.position_lookup(angles, torch.zeros_like(subgraph.indices))

      # Angle features: sin and cos stacked along the feature dimension.
      afeat = torch.cat((angles.sin(), angles.cos()), dim=1)
      features = ts.scatter.batched(self.preprocess_correct, afeat, subgraph.indices)

      # Per-node geometric data: position slice, orientations, node index.
      ors = self.orientations(tertiary)
      pos = tertiary[:, 1]
      inds = torch.arange(0, pos.size(0), dtype=torch.float, device=pos.device).view(-1, 1)
      distances = torch.cat((pos, ors, inds), dim=1)

      # Only the neighbour structure is needed from the k-NN search.
      _, structure = self.knn_structure(tertiary, subgraph)

      distance_data = RelativeStructure(structure, self.rbf)
      relative_data = distance_data.message(distances, distances)
      relative_structure = OrientationStructure(structure, relative_data)

      correction = self.correction(features, relative_structure)
      angles = angles + self.corrected_angles(correction)

    result = ts.PackedTensor(angles, lengths=list(subgraph.counts))

    return result, subgraph
    def forward(self,
                primary,
                gt_ignore,
                angles,
                orientation,
                neighbours,
                protein,
                return_deltas=False):
        """Score every residue class at each node from its neighbourhood.

        Args:
            primary: per-node features; ``primary.argmax(dim=1)`` selects the
                node's current class (presumably one-hot -- confirm).
            gt_ignore: unused.
            angles: per-node angle features, gathered over neighbours.
            orientation: per-node orientation data passed to
                ``RelativeStructure.message``.
            neighbours: structure whose ``connections`` index rows of
                ``primary`` and ``angles``.
            protein: grouping whose ``indices`` are used to average energies.
            return_deltas: if true, return ``(energy, -out), out`` instead of
                ``energy, out``.

        Returns:
            ``(energy, out)``. ``out`` has one column per class of
            ``primary``; ``energy`` is the per-group mean (over
            ``protein.indices``) of each node's score for its current class.
        """
        assert neighbours.connections.max() < primary.size(0)
        indices = neighbours.connections
        # ``primary`` is rebound below, never written in place, so the
        # original tensor can be kept without copying it.
        sequence = primary
        # Gather neighbour features and blank slot 0 of each neighbourhood
        # (presumably the node being scored -- confirm).
        primary = primary[indices].clone()
        primary[:, 0] = 0
        angles = angles[indices]
        relative = RelativeStructure(neighbours, self.rbf)
        orientation = relative.message(orientation, orientation)
        features = torch.cat((primary, angles, orientation),
                             dim=2).permute(0, 2, 1)

        # Pair every neighbour slot with the features of slot 0.
        inputs = torch.cat((features, features[:, :, 0:1].expand_as(features)),
                           dim=1)
        in_view = inputs.transpose(2, 1).reshape(-1, 2 * features.size(1))
        p = self.features(in_view)
        w = self.weight(in_view)
        prod = p * w
        cat = prod.reshape(inputs.size(0), -1)
        out = self.out(cat)

        # Score of each node's current class; the row index is created on
        # the same device as ``out``.
        rows = torch.arange(out.size(0), device=out.device)
        this_energy = out[rows, sequence.argmax(dim=1)]
        energy = scatter.mean(this_energy.unsqueeze(-1), protein.indices)
        if return_deltas:
            # NOTE(review): "deltas" are returned as ``-out`` rather than
            # ``this_energy.unsqueeze(-1) - out``; confirm this is intended.
            return (energy, -out), out
        return energy, out
    def forward(self, features, sequence, distances, structure):
        """Encode node features over a structure, then decode with a sequence.

        Args:
            features: node features passed to ``self.encoder``.
            sequence: prepared with ``self.prepare_sequence`` before decoding.
            distances: per-node data turned into relative messages.
            structure: neighbour structure shared by encoder and decoder.

        Returns:
            The output of ``self.decoder``.
        """
        relative_data = RelativeStructure(structure, self.rbf).message(
            distances, distances)
        encoding = self.encoder(
            features, OrientationStructure(structure, relative_data))

        prepared = self.prepare_sequence(sequence)
        masked = MaskedStructure(structure, relative_data, prepared, encoding)
        return self.decoder(encoding, masked)
    def forward(self, tertiary, noise, sequence, subgraph):
        """Predict outputs from noisy input, conditioned on the noise level.

        Args:
            tertiary: angles. When ``self.angles`` is set, their sin/cos are
                the node features; otherwise the flattened ``self.lookup``
                output is used.
            noise: per-node noise level, bucketed into a 10-way one-hot code.
            sequence: replaces the node features when ``self.conditional``.
            subgraph: grouping whose ``indices`` drive batching and lookup.

        Returns:
            The output of ``self.result``.
        """
        num_buckets = 10
        noise = noise / 3.15
        # Log-scale bucket with base 0.6: bucket k covers roughly
        # noise / 3.15 in (0.6 ** (k + 1), 0.6 ** k]. Clamp before the
        # integer cast. Otherwise very small noise indexes past the last
        # column, and noise >= ~5.25 gives a negative offset that silently
        # wraps around to the last bucket.
        bucket = torch.log(noise) / torch.log(torch.tensor(0.60))
        offset = bucket.clamp(0, num_buckets - 1).long()
        condition = torch.zeros(tertiary.size(0),
                                num_buckets,
                                device=tertiary.device,
                                dtype=torch.float)
        rows = torch.arange(offset.size(0), device=offset.device)
        condition[rows, offset.view(-1)] = 1

        angles = tertiary
        tertiary, _ = self.lookup(tertiary, torch.zeros_like(subgraph.indices))
        if self.angles:
            afeat = torch.cat((angles.sin(), angles.cos(), condition), dim=1)
        else:
            afeat = torch.cat(
                (tertiary.reshape(tertiary.size(0), -1), condition), dim=1)
        features = ts.scatter.batched(self.preprocess, afeat, subgraph.indices)
        if self.conditional:
            features = sequence

        # Per-node geometric data: position slice, orientations, node index.
        ors = self.orientations(tertiary)
        pos = tertiary[:, 1]
        inds = torch.arange(0,
                            pos.size(0),
                            dtype=torch.float,
                            device=pos.device).view(-1, 1)
        distances = torch.cat((pos, ors, inds), dim=1)

        # Only the neighbour structure is needed from the k-NN search.
        _, structure = self.knn_structure(tertiary, subgraph)

        distance_data = RelativeStructure(structure, self.rbf)
        relative_data = distance_data.message(distances, distances)
        relative_structure = OrientationStructure(structure, relative_data)

        encoding = self.encoder(features, relative_structure)
        encoding = ts.scatter.batched(self.postprocess, encoding,
                                      subgraph.indices)
        encoding = torch.cat((features, encoding), dim=1)
        result = self.result(encoding)

        return result
Example #6
0
    def forward(self, primary, mode, gt_ignore, angles, orientation,
                neighbours, protein):
        """Score each node from its gathered neighbourhood features.

        Args:
            primary: per-node features, gathered over neighbours.
            mode: unused.
            gt_ignore: unused.
            angles: per-node angle features, gathered over neighbours.
            orientation: per-node orientation data passed to
                ``RelativeStructure.message``.
            neighbours: structure whose ``connections`` index rows of
                ``primary`` and ``angles``.
            protein: unused.

        Returns:
            The output of ``self.out``, one row per node.
        """
        assert neighbours.connections.max() < primary.size(0)
        indices = neighbours.connections
        primary = primary[indices]
        angles = angles[indices]
        relative = RelativeStructure(neighbours, self.rbf)
        orientation = relative.message(orientation, orientation)
        features = torch.cat((primary, angles, orientation),
                             dim=2).permute(0, 2, 1)

        # Pair every neighbour slot with the features of slot 0.
        inputs = torch.cat((features, features[:, :, 0:1].expand_as(features)),
                           dim=1)
        in_view = inputs.transpose(2, 1).reshape(-1, 2 * features.size(1))
        p = self.features(in_view)
        w = self.weight(in_view)
        prod = p * w
        cat = prod.reshape(inputs.size(0), -1)
        return self.out(cat)
Example #7
0
    def forward(self, tertiary, subgraph):
        """Predict outputs from geometry alone, starting from zero features.

        Args:
            tertiary: input passed to ``self.lookup`` to obtain the geometry
                used for orientations and the k-NN graph.
            subgraph: grouping whose ``indices`` drive post-processing and
                the lookups.

        Returns:
            ``(result, result_tertiary)``. ``result`` is the output of
            ``self.result``; ``result_tertiary`` is the first output of
            ``self.lookup(result, ...)``.
        """
        # Node features are all zeros: every signal comes from the geometry.
        features = torch.zeros(tertiary.size(0),
                               self.size,
                               dtype=tertiary.dtype,
                               device=tertiary.device)

        tertiary, _ = self.lookup(tertiary, torch.zeros_like(subgraph.indices))
        ors = self.orientations(tertiary)
        pos = tertiary[:, 1]
        inds = torch.arange(0,
                            pos.size(0),
                            dtype=torch.float,
                            device=pos.device).view(-1, 1)
        distances = torch.cat((pos, ors, inds), dim=1)

        # Only the neighbour structure is needed from the k-NN search.
        _, structure = self.knn_structure(tertiary, subgraph)

        distance_data = RelativeStructure(structure, self.rbf)
        relative_data = distance_data.message(distances, distances)
        relative_structure = OrientationStructure(structure, relative_data)

        encoding = self.encoder(features, relative_structure)
        encoding = ts.scatter.batched(self.postprocess, encoding,
                                      subgraph.indices)
        result = self.result(encoding)
        result_tertiary, _ = self.lookup(result,
                                         torch.zeros_like(subgraph.indices))

        return result, result_tertiary
Example #8
0
    def forward(self, tertiary, sequence, subgraph):
        """Pool a structure-aware encoding into one energy per group.

        Args:
            tertiary: when ``self.angles`` is set, treated as angles (sin/cos
                give the node features and ``self.lookup`` converts them to
                geometry); otherwise used directly as geometry.
            sequence: replaces the node features when ``self.conditional``.
            subgraph: grouping whose ``indices`` drive batching and pooling.

        Returns:
            ``self.energy`` applied to the pooled encoding.
        """
        if self.angles:
            angle_feats = torch.cat((tertiary.sin(), tertiary.cos()), dim=1)
            features = ts.scatter.batched(self.local_features, angle_feats,
                                          subgraph.indices)
            tertiary, _ = self.lookup(tertiary,
                                      torch.zeros_like(subgraph.indices))
        else:
            # Constant placeholder features of width 27 when no angles are used.
            features = torch.ones(tertiary.size(0),
                                  27,
                                  dtype=tertiary.dtype,
                                  device=tertiary.device)
        if self.conditional:
            features = sequence

        orientations = self.orientations(tertiary)
        positions = tertiary[:, 1]
        node_ids = torch.arange(0,
                                positions.size(0),
                                dtype=torch.float,
                                device=positions.device).view(-1, 1)
        distances = torch.cat((positions, orientations, node_ids), dim=1)

        structure = self.knn_structure(tertiary, subgraph)
        relative_data = RelativeStructure(structure, self.rbf).message(
            distances, distances)
        encoding = self.encoder(
            features, OrientationStructure(structure, relative_data))

        pooled = self.pool(encoding, subgraph.indices)
        return self.energy(pooled)