def pose_to_net(pose: Pose):
    """Convert a Rosetta pose into network-ready tensors.

    Returns a tuple ``(positions, angles, sequence_one_hot, mask)``:
      - ``positions``: (4, 3, length) tensor of N, CA, CB, C coordinates
        per residue (glycine substitutes its 1HA hydrogen for the missing CB);
      - ``angles``: per-residue dihedrals recomputed from the N/CA/C trace;
      - ``sequence_one_hot``: one-hot encoding of the pose's sequence in
        ``AA_ID_DICT`` order;
      - ``mask``: all-ones residue mask of shape (length,).
    """
    n_res = pose.total_residue()

    # One-hot the sequence using the alphabet ordered by its numeric id.
    aa_alphabet = sorted(list(AA_ID_DICT.keys()), key=AA_ID_DICT.get)
    sequence_one_hot = one_hot_encode(pose.sequence(), aa_alphabet)

    positions = torch.zeros(4, 3, n_res)
    mask = torch.ones(n_res)
    for idx, residue in enumerate(pose.residues):
        # Glycine has no CB atom; fall back to the 1HA hydrogen instead.
        beta_atom = (
            "1HA"
            if residue.aa() == rosetta.core.chemical.AA.aa_gly
            else "CB"
        )
        coords = [
            residue.atom("N").xyz(),
            residue.atom("CA").xyz(),
            residue.atom(beta_atom).xyz(),
            residue.atom("C").xyz(),
        ]
        positions[:, :, idx] = torch.tensor(coords)

    # Dihedrals come from the backbone trace N, CA, C (rows 0, 1, 3 — CB at
    # row 2 is skipped), flattened residue-by-residue into an (L*3, 3) array.
    backbone = positions[[0, 1, 3]].numpy().transpose(2, 0, 1).reshape(-1, 3)
    angles, _ = compute_dihedrals(backbone, torch.ones(positions.size(-1)))
    angles = torch.tensor(angles, dtype=torch.float)

    return positions, angles, sequence_one_hot, mask
def __getitem__(self, index):
    """Fetch one protein record and package it for the network.

    Truncates every per-residue quantity to at most ``MAX_LEN`` residues,
    rescales tertiary coordinates, and wraps everything in ``PackedTensor`` /
    ``SubgraphStructure`` containers.

    Returns ``(inputs, outputs, extras)`` where ``inputs`` feeds the model,
    ``outputs`` carries coordinate/mask targets, and ``extras`` carries the
    angle targets with their mask.

    Fix: removed leftover debug ``print`` calls (angle range, angle size)
    and the dead ``mask = mask`` statement with its commented-out code.
    """
    MAX_LEN = 500  # hard truncation length used throughout this pipeline

    result = super().__getitem__(index)
    primary = result["primary"][:MAX_LEN]
    evolutionary = result["evolutionary"][:, :MAX_LEN].t()

    # Coordinates are stored in picometers (x100); rescale, keep the
    # N/CA/C backbone rows (0, 1, 3), and flatten to (residues*3, 3).
    tertiary = result["tertiary"] / 100
    tertiary = tertiary[[0, 1, 3], :, :].permute(
        2, 0, 1).contiguous()[:MAX_LEN].view(-1, 3)

    angles = result["angles"].contiguous().view(-1, 3)[:MAX_LEN].contiguous()
    mask = result["mask"][:MAX_LEN].view(-1)

    # All residues belong to a single subgraph (chain).
    membership = SubgraphStructure(
        torch.zeros(primary.size(0), dtype=torch.long))

    # primary is 1-based; shift to 0-based before one-hot encoding, then
    # append the evolutionary profile as extra per-residue features.
    primary_onehot = one_hot_encode(primary - 1, range(20)).t()
    primary_onehot = torch.cat((primary_onehot, evolutionary), dim=1)

    inputs = (PackedTensor(primary_onehot), membership)
    outputs = (
        PackedTensor(tertiary, split=False),
        PackedTensor(mask.unsqueeze(1), split=False),
        membership
    )
    return inputs, outputs, (
        PackedTensor(angles, split=False),
        PackedTensor(mask.view(-1), split=False))
def sample_residue(self, seq, position, mask=None, argmax=False):
    """Sample an amino-acid identity for one residue position.

    Builds a local feature window around ``position`` (angles, centered and
    rotated coordinates, masked one-hot sequence), runs the network, and
    either takes the argmax or draws from the categorical distribution over
    residue types. Returns the sampled amino-acid symbol from ``self.lookup``.
    """
    neighborhood = self.indices[position]
    frame_rotation = self.rotations[position]

    # One-hot the full sequence, then restrict to the local neighborhood.
    local_seq = one_hot_encode(seq, self.lookup)[:, neighborhood]

    # Decide which neighborhood entries are hidden from the network: either
    # the caller-provided mask, or random dropout. The center (index 0) is
    # always hidden — it is the residue being predicted.
    if mask is None:
        hidden = torch.rand(local_seq.size(1)) < self.dropout
    else:
        hidden = mask[neighborhood].clone()
    hidden[0] = 1
    local_seq[:, hidden] = 0.0
    local_seq = torch.cat((hidden.unsqueeze(0).float(), local_seq), dim=0)

    # Center coordinates on the focal residue and rotate into its frame.
    coords = self.tertiary[:, :, neighborhood].clone()
    coords = coords - coords[0:1, :, 0:1]
    coords = torch.tensor(frame_rotation, dtype=torch.float) @ coords
    coords = coords.view(-1, coords.size(-1)) / 10

    local_angles = self.angles[:, neighborhood]
    features = torch.cat((
        local_angles.sin(),
        local_angles.cos(),
        coords,
        local_seq
    ), dim=0).unsqueeze(0)

    logits = self.net(features)
    if argmax:
        # Greedy choice: most likely residue type.
        chosen = logits.argmax(dim=1).view(-1)[0]
    else:
        # Stochastic choice: sample from the predicted distribution.
        chosen = torch.distributions.Categorical(logits=logits).sample()[0]
    return self.lookup[chosen]
def single_score(self, seq, position, mask=None):
    """Score the residue currently at ``position`` under the network.

    Mirrors :meth:`sample_residue`'s feature construction (with no random
    dropout — only the focal residue is hidden unless a mask is given) and
    returns the negated softmax probability assigned to the residue type
    actually present in ``seq`` at ``position`` (lower is better).
    """
    neighborhood = self.indices[position]
    frame_rotation = self.rotations[position]

    # One-hot the full sequence, restricted to the local neighborhood.
    local_seq = one_hot_encode(seq, self.lookup)[:, neighborhood]

    # No stochastic dropout when scoring (`< 0` is always false); only the
    # focal residue (index 0) is hidden, unless a caller mask overrides.
    if mask is None:
        hidden = torch.rand(local_seq.size(1)) < 0
    else:
        hidden = mask[neighborhood].clone()
    hidden[0] = 1
    local_seq[:, hidden] = 0.0
    local_seq = torch.cat((hidden.unsqueeze(0).float(), local_seq), dim=0)

    # Center on the focal residue and rotate into its local frame.
    coords = self.tertiary[:, :, neighborhood].clone()
    coords = coords - coords[0:1, :, 0:1]
    coords = torch.tensor(frame_rotation, dtype=torch.float) @ coords
    coords = coords.view(-1, coords.size(-1)) / 10

    local_angles = self.angles[:, neighborhood]
    features = torch.cat((
        local_angles.sin(),
        local_angles.cos(),
        coords,
        local_seq
    ), dim=0).unsqueeze(0)

    logits = self.net(features)
    # AA_ID_DICT is 1-based; subtract 1 to index the 0-based logit row.
    true_class = AA_ID_DICT[seq[position]] - 1
    return -logits.softmax(dim=1)[0, true_class]
def getfeatures(self, data):
    """Assemble tiled pairwise features: positional encoding, one-hot
    primary sequence, and evolutionary profile, concatenated on dim 0.
    """
    primary = data["primary"]
    evolutionary = data["evolutionary"]

    # Residue indices as a (1, L) float row vector.
    index_row = torch.tensor(
        range(primary.size(0)), dtype=torch.float32).view(1, -1)

    # primary uses 1-based amino-acid ids.
    primary_tiled = self.tile(
        one_hot_encode(primary, list(range(1, 21)))).to(torch.float)
    evolutionary_tiled = self.tile(evolutionary)
    position = self.tile(index_row).to(torch.float)

    # Relative-position phase. NOTE(review): this computes
    # pos_i - (pos_j / 250) * pi, not (pos_i - pos_j) / 250 * pi — the
    # parenthesization looks suspicious but is preserved as written; confirm
    # the intended formula before changing it.
    phase = position[:1] - position[1:] / 250 * np.pi
    encoded_position = torch.cat(
        (torch.sin(phase), torch.cos(phase)), dim=0)

    return torch.cat(
        (encoded_position, primary_tiled, evolutionary_tiled), dim=0)
def __getitem__(self, index):
    """Return a training example balanced over amino-acid classes.

    Picks a random amino-acid type, maps ``index`` into the pool of
    positions with that type, hides a random subset of the sequence
    (always including the focal residue 0), and returns
    ``(structure_features, primary)`` where ``primary`` is the 0-based
    class of the focal residue.
    """
    # Uniformly pick one of the 20 amino-acid classes, then pick a sample
    # of that class by wrapping index into its pool.
    aa_class = random.randrange(0, 20)
    pool = self.aa_indices[aa_class]
    data = super().__getitem__(pool[index % pool.size(0)])

    # Flatten coordinates to (atoms*3, L) and rescale.
    coords = data["tertiary"]
    coords = coords.view(-1, coords.size(-1)) / 1000

    dihedrals = data["angles"]
    sin_part = torch.sin(dihedrals)
    cos_part = torch.cos(dihedrals)

    # Shift the 1-based sequence to 0-based; -1 marks "unknown". The focal
    # residue (index 0, the prediction target) is always unknown, plus a
    # random subset of the rest.
    masked_primary = data["primary"] - 1
    masked_primary[0] = -1
    hide_count = random.randint(0, len(masked_primary))
    hide_at = random.sample(range(len(masked_primary)), hide_count)
    masked_primary[hide_at] = -1
    sequence_onehot = one_hot_encode(masked_primary, list(range(-1, 20)))

    structure_features = torch.cat(
        (sin_part, cos_part, coords, sequence_onehot), dim=0)

    # Target: the 0-based class of the focal residue.
    target = data["primary"][0] - 1
    return structure_features, target
def design(self):
    """Iteratively design a sequence by masked resampling.

    Starts from ``self.sequence`` (with ``self.fix`` positions pinned to
    their original identities), then for ``self.steps`` rounds: tracks the
    best-scoring sequence seen so far, masks a subset of positions chosen
    by ``self.update_mask``, re-predicts logits there, and resamples those
    positions. Returns ``(best_sequence, best_log_likelihood)``.

    Fix: removed a dead pre-loop ``result = ...`` assignment (it was
    recomputed as the first statement of every iteration and unused when
    ``self.steps == 0``, since the return reads ``self.sequence``) and the
    unused loop variable ``idx``.
    """
    starting_sequence = self.init_sequence(self.sequence)
    logits = self.net(
        self.angle_features, starting_sequence,
        self.orientations, self.structure
    )
    sample = self.sample(logits)

    # Pin fixed positions to their original identities (drop the trailing
    # mask column before taking the argmax).
    if self.fix.size(0) > 0:
        sample[self.fix] = starting_sequence[self.fix, :-1].argmax(dim=1)

    for _ in range(self.steps):
        result = "".join(map(lambda x: self.lookup[x], sample))

        # Keep the best-scoring sequence seen so far.
        log_likelihood = logits[torch.arange(logits.size(0)), sample].sum()
        if log_likelihood >= self.best:
            self.sequence = result
            self.best = log_likelihood

        # Mask the positions selected for update (never the fixed ones)
        # and append the mask as an extra input column.
        seq = one_hot_encode(result, self.lookup)
        seq = seq.permute(1, 0)
        zero = self.update_mask(logits)
        zero[self.fix] = 0
        seq[zero] = 0
        mask = zero.float()[:, None]
        seq = torch.cat((seq, mask), dim=1)

        # Re-predict and resample only the masked positions.
        new_logits = self.net(
            self.angle_features, seq, self.orientations, self.structure)
        logits[zero] = new_logits[zero]
        new_sample = self.sample(logits)
        sample[zero] = new_sample[zero]

    return self.sequence, self.best
def prepare_sequence(self, sequence):
    """One-hot encode an integer-coded sequence (20 classes) and pass it
    through the sequence embedding, keeping the input's device.
    """
    encoded = one_hot_encode(sequence, list(range(20)))
    encoded = encoded.transpose(0, 1).to(sequence.device)
    return self.sequence_embedding(encoded)
def one_hot_secondary(data, numeric=False):
    """One-hot encode a secondary-structure sequence.

    Uses the ``OneHotSecondary.code`` alphabet; ``numeric`` is forwarded to
    ``one_hot_encode`` unchanged.
    """
    alphabet = OneHotSecondary.code
    return one_hot_encode(data, alphabet, numeric=numeric)
def one_hot_aa(data, numeric=False):
    """One-hot encode an amino-acid sequence.

    Uses the ``OneHotAA.code`` alphabet; ``numeric`` is forwarded to
    ``one_hot_encode`` unchanged.
    """
    alphabet = OneHotAA.code
    return one_hot_encode(data, alphabet, numeric=numeric)
def init_sequence(self, sequence):
    """Build the initial network input for design: a one-hot sequence with
    every position masked out except those in ``self.fix``, plus a trailing
    column flagging which positions are masked.
    """
    onehot = one_hot_encode(sequence, self.lookup).permute(1, 0)
    # All positions start masked; un-mask only the fixed ones.
    masked = torch.ones(onehot.size(0), dtype=torch.bool)
    masked[self.fix] = 0
    onehot[masked] = 0
    return torch.cat((onehot, masked[:, None].float()), dim=1)