Пример #1
0
def pose_to_net(pose: Pose):
    """Convert a Rosetta pose into backbone tensors for the network.

    Args:
        pose: Rosetta ``Pose`` to convert.

    Returns:
        A tuple ``(positions, angles, sequence_one_hot, mask)``:

        * ``positions`` -- float tensor of shape ``(4, 3, L)`` holding, per
          residue ``L``, the xyz of atoms N, CA, CB and C in that order
          (``1HA`` is used in place of CB for glycine).
        * ``angles`` -- float tensor of the dihedrals that
          ``compute_dihedrals`` returns for the N/CA/C trace.
        * ``sequence_one_hot`` -- ``one_hot_encode`` of ``pose.sequence()``
          over the keys of ``AA_ID_DICT`` ordered by their mapped id.
        * ``mask`` -- all-ones tensor of length ``L``.
    """
    length = pose.total_residue()
    positions = torch.zeros(4, 3, length)
    sequence = pose.sequence()
    # Alphabet ordered by the integer id each amino acid maps to.
    alphabet = sorted(AA_ID_DICT, key=AA_ID_DICT.get)
    sequence_one_hot = one_hot_encode(sequence, alphabet)
    mask = torch.ones(length)
    for idx, residue in enumerate(pose.residues):
        nn = residue.atom("N").xyz()
        ca = residue.atom("CA").xyz()
        if residue.aa() == rosetta.core.chemical.AA.aa_gly:
            # Glycine has no CB atom; an alpha hydrogen stands in for it.
            cb = residue.atom("1HA").xyz()
        else:
            cb = residue.atom("CB").xyz()
        co = residue.atom("C").xyz()
        positions[:, :, idx] = torch.tensor([nn, ca, cb, co])

    # Dihedrals come from atoms N, CA, C (rows 0, 1, 3), reordered to
    # (residue, atom, xyz) and flattened into a (3 * L, 3) coordinate trace.
    angles, _ = compute_dihedrals(
        positions[[0, 1, 3]].numpy().transpose(2, 0, 1).reshape(-1, 3),
        torch.ones(length))
    angles = torch.tensor(angles, dtype=torch.float)
    return positions, angles, sequence_one_hot, mask
Пример #2
0
    def __getitem__(self, index):
        """Return one example, truncated to at most the first 500 residues.

        Returns:
            ``(inputs, outputs, angle_targets)`` where

            * ``inputs`` is ``(PackedTensor(features), membership)``, with
              ``features`` the one-hot primary sequence concatenated with the
              evolutionary profile;
            * ``outputs`` is ``(PackedTensor(tertiary), PackedTensor(mask),
              membership)``;
            * ``angle_targets`` is ``(PackedTensor(angles), PackedTensor(mask))``.
        """
        max_length = 500
        result = super().__getitem__(index)
        primary = result["primary"][:max_length]
        evolutionary = result["evolutionary"][:, :max_length].t()
        # NOTE(review): divide by 100 presumably converts the stored units
        # (e.g. picometres) to angstroms -- confirm against the dataset.
        tertiary = result["tertiary"] / 100
        # Keep atom rows 0, 1 and 3, move the last axis (presumably residues)
        # to the front, truncate, and flatten to one xyz row per atom.
        tertiary = tertiary[[0, 1, 3], :, :].permute(
            2, 0, 1).contiguous()[:max_length].view(-1, 3)
        angles = result["angles"].contiguous().view(-1, 3)[:max_length].contiguous()
        mask = result["mask"][:max_length].view(-1)

        # All-zero membership: every residue is assigned to subgraph 0.
        membership = SubgraphStructure(
            torch.zeros(primary.size(0), dtype=torch.long))
        # Residue codes are shifted to 0-based before one-hot encoding.
        primary_onehot = one_hot_encode(primary - 1, range(20)).t()
        primary_onehot = torch.cat((primary_onehot, evolutionary), dim=1)

        inputs = (PackedTensor(primary_onehot), membership)

        outputs = (PackedTensor(tertiary, split=False),
                   PackedTensor(mask.unsqueeze(1), split=False), membership)

        angle_targets = (PackedTensor(angles, split=False),
                         PackedTensor(mask.view(-1), split=False))
        return inputs, outputs, angle_targets
Пример #3
0
  def sample_residue(self, seq, position, mask=None, argmax=False):
    """Choose an amino acid for ``position`` using the network's prediction.

    A local window is built from the residues listed in
    ``self.indices[position]``. Their one-hot identities are partially
    hidden, and their coordinates are centred and rotated by
    ``self.rotations[position]``. The window is then fed to ``self.net``.

    Args:
      seq: current sequence, encoded against ``self.lookup``.
      position: index of the residue to resample.
      mask: optional full-length boolean mask of hidden residues. When
        omitted, each window residue is hidden with probability
        ``self.dropout``. The first window residue is always hidden.
      argmax: take the highest-scoring class instead of sampling.

    Returns:
      The element of ``self.lookup`` selected for ``position``.
    """
    window = self.indices[position]
    rotation = self.rotations[position]

    onehot = one_hot_encode(seq, self.lookup)[:, window]
    if mask is None:
      hidden = torch.rand(onehot.size(1)) < self.dropout
    else:
      hidden = mask[window].clone()
    # Presumably window[0] is the residue being designed; it is always hidden.
    hidden[0] = 1
    onehot[:, hidden] = 0.0
    onehot = torch.cat((hidden.unsqueeze(0).float(), onehot), dim=0)

    coords = self.tertiary[:, :, window].clone()
    # Centre on coords[0, :, 0], apply the per-position rotation, then
    # flatten all leading axes into feature rows and rescale by 1/10.
    coords = coords - coords[0:1, :, 0:1]
    coords = torch.tensor(rotation, dtype=torch.float) @ coords
    coords = coords.view(-1, coords.size(-1)) / 10
    local_angles = self.angles[:, window]

    features = torch.cat(
      (local_angles.sin(), local_angles.cos(), coords, onehot), dim=0
    ).unsqueeze(0)

    logits = self.net(features)
    if argmax:
      choice = logits.argmax(dim=1).view(-1)[0]
    else:
      choice = torch.distributions.Categorical(logits=logits).sample()[0]
    return self.lookup[choice]
Пример #4
0
  def single_score(self, seq, position, mask=None):
    """Score the amino acid currently at ``position`` under the network.

    Args:
      seq: current sequence, encoded against ``self.lookup``.
      position: index of the residue to score.
      mask: optional full-length boolean mask of hidden residues. When
        omitted, no window residue is hidden apart from the first one,
        which is always hidden.

    Returns:
      The negated softmax probability that the network assigns to
      ``seq[position]``. Lower values therefore mean a more likely residue.
    """
    inds = self.indices[position]
    rot = self.rotations[position]
    sequence = one_hot_encode(seq, self.lookup)
    sequence = sequence[:, inds]
    if mask is not None:
      mask = mask[inds].clone()
    else:
      # Hide nothing by default. The mask is built directly rather than
      # thresholding random draws, so scoring does not advance the RNG.
      mask = torch.zeros(sequence.size(1), dtype=torch.bool)
    # Presumably inds[0] is the scored residue itself; always hide it.
    mask[0] = 1
    sequence[:, mask] = 0.0
    sequence = torch.cat((mask.unsqueeze(0).float(), sequence), dim=0)

    tertiary = self.tertiary[:, :, inds].clone()
    # Centre on tertiary[0, :, 0], rotate, flatten and rescale by 1/10.
    tertiary = tertiary - tertiary[0:1, :, 0:1]
    tertiary = torch.tensor(rot, dtype=torch.float) @ tertiary
    tertiary = tertiary.view(-1, tertiary.size(-1)) / 10
    angles = self.angles[:, inds]
    features = torch.cat((
      angles.sin(), angles.cos(), tertiary, sequence
    ), dim=0).unsqueeze(0)

    logits = self.net(features)
    # NOTE(review): the class index comes from AA_ID_DICT (1-based), while the
    # input encoding uses self.lookup. This assumes both share one ordering --
    # confirm.
    result = -logits.softmax(dim=1)[0, AA_ID_DICT[seq[position]] - 1]
    return result
Пример #5
0
 def getfeatures(self, data):
     """Build the tiled input features for one example.

     The result concatenates, along dim 0:
       * a sin/cos encoding of the relative sequence offset, scaled so that
         250 residues correspond to pi radians;
       * the tiled one-hot primary sequence;
       * the tiled evolutionary profile.

     Args:
         data: mapping with a 1-D ``"primary"`` tensor of residue codes
             (presumably 1..20, matching the one-hot alphabet) and an
             ``"evolutionary"`` tensor.
     """
     primary = data["primary"]
     evolutionary = data["evolutionary"]
     position = torch.arange(primary.size(0), dtype=torch.float32).view(1, -1)
     primary_hot = one_hot_encode(primary, list(range(1, 21)))
     primary = self.tile(primary_hot).to(torch.float)
     evolutionary = self.tile(evolutionary)
     position = self.tile(position).to(torch.float)
     # Assumes self.tile yields two index planes along dim 0 -- TODO confirm.
     # Scale the *difference*. Without parentheses, only position[1:] was
     # scaled, and sin/cos saw the raw index of the first plane.
     offset = (position[:1] - position[1:]) / 250 * np.pi
     position = torch.cat((torch.sin(offset), torch.cos(offset)), dim=0)
     return torch.cat((position, primary, evolutionary), dim=0)
Пример #6
0
    def __getitem__(self, index):
        """Return ``(features, target)`` for a sample picked by residue type.

        A residue type ``0..19`` is drawn uniformly at random. ``index``
        (taken modulo the list length) then selects one of the dataset
        entries listed for that type in ``self.aa_indices``.

        The target is the 0-based code of the entry's first residue. In the
        sequence features, that residue plus a random subset of the others
        are replaced by the unknown code ``-1``.
        """
        candidates = self.aa_indices[random.randrange(0, 20)]
        chosen = candidates[index % candidates.size(0)]
        data = super().__getitem__(chosen)

        coords = data["tertiary"]
        coords = coords.view(-1, coords.size(-1)) / 1000
        angles = data["angles"]

        # 0-based residue codes; -1 marks a residue hidden from the model.
        masked = data["primary"] - 1
        masked[0] = -1
        length = len(masked)
        hidden = random.sample(range(length), random.randint(0, length))
        masked[hidden] = -1
        onehot = one_hot_encode(masked, list(range(-1, 20)))

        features = torch.cat(
            (torch.sin(angles), torch.cos(angles), coords, onehot), dim=0)
        target = data["primary"][0] - 1
        return features, target
Пример #7
0
  def design(self):
    """Iteratively redesign ``self.sequence`` with the network.

    The process runs as follows:

    1. Sample a sequence from the logits for
       ``self.init_sequence(self.sequence)``.
    2. For ``self.steps`` rounds, re-predict and resample the positions
       chosen by ``self.update_mask``. Positions in ``self.fix`` are never
       changed.
    3. Score every candidate as the sum of the logits selected by its
       residues. A candidate scoring at least ``self.best`` replaces
       ``self.sequence`` and ``self.best``.

    Returns:
      ``(self.sequence, self.best)`` after the search.
    """
    starting_sequence = self.init_sequence(self.sequence)
    logits = self.net(
      self.angle_features,
      starting_sequence,
      self.orientations,
      self.structure
    )
    sample = self.sample(logits)

    if self.fix.size(0) > 0:
      # Fixed positions keep their starting residue. The trailing mask
      # column is excluded from the argmax.
      sample[self.fix] = starting_sequence[self.fix, :-1].argmax(dim=1)

    positions = torch.arange(logits.size(0))

    def record(candidate):
      # Score `candidate`, keep it if it ties or beats the best so far,
      # and return it as a string.
      text = "".join(map(lambda x: self.lookup[x], candidate))
      # NOTE(review): this sums raw network outputs. It is a log-likelihood
      # only if self.net returns log-probabilities -- confirm.
      log_likelihood = logits[positions, candidate].sum()
      if log_likelihood >= self.best:
        self.sequence = text
        self.best = log_likelihood
      return text

    for _ in range(self.steps):
      result = record(sample)
      seq = one_hot_encode(result, self.lookup).permute(1, 0)
      zero = self.update_mask(logits)
      zero[self.fix] = 0
      seq[zero] = 0
      seq = torch.cat((seq, zero.float()[:, None]), dim=1)
      new_logits = self.net(
        self.angle_features,
        seq,
        self.orientations,
        self.structure)
      logits[zero] = new_logits[zero]
      new_sample = self.sample(logits)
      sample[zero] = new_sample[zero]

    # Each round scores its sample *before* refining it. The sample produced
    # by the last refinement would otherwise never be considered.
    if self.steps > 0:
      record(sample)

    return self.sequence, self.best
 def prepare_sequence(self, sequence):
     """Embed an integer-coded sequence with ``self.sequence_embedding``.

     ``sequence`` is processed in three steps:
       * one-hot encoded over the codes 0..19;
       * transposed by swapping dims 0 and 1;
       * moved to ``sequence``'s device before being embedded.
     """
     alphabet = list(range(20))
     encoded = one_hot_encode(sequence, alphabet)
     encoded = encoded.transpose(0, 1)
     encoded = encoded.to(sequence.device)
     return self.sequence_embedding(encoded)
Пример #9
0
def one_hot_secondary(data, numeric=False):
    """Return *data* (a secondary structure) one-hot encoded.

    Thin wrapper around ``one_hot_encode`` that uses the
    ``OneHotSecondary.code`` alphabet and forwards ``numeric`` unchanged.
    """
    alphabet = OneHotSecondary.code
    return one_hot_encode(data, alphabet, numeric=numeric)
Пример #10
0
def one_hot_aa(data, numeric=False):
    """Return *data* (an amino-acid sequence) one-hot encoded.

    Thin wrapper around ``one_hot_encode`` that uses the ``OneHotAA.code``
    alphabet and forwards ``numeric`` unchanged.
    """
    alphabet = OneHotAA.code
    return one_hot_encode(data, alphabet, numeric=numeric)
Пример #11
0
 def init_sequence(self, sequence):
   """Encode ``sequence`` with only the fixed positions left visible.

   ``sequence`` is one-hot encoded against ``self.lookup`` and transposed.
   Every row not listed in ``self.fix`` is zeroed. A last column is then
   appended that is 1.0 for zeroed (unknown) rows and 0.0 for fixed ones.
   """
   encoding = one_hot_encode(sequence, self.lookup).permute(1, 0)
   unknown = torch.ones(encoding.size(0), dtype=torch.bool)
   unknown[self.fix] = False
   encoding[unknown] = 0
   flag_column = unknown.float().unsqueeze(1)
   return torch.cat((encoding, flag_column), dim=1)