def forward(self, inputs):
    """Encode a structure and stack the encoder output with an angle-only output.

    Args:
        inputs: A ``(tertiary, subgraph)`` pair. ``tertiary`` is used both as
            raw angles (through ``sin``/``cos``) and as the input to
            ``self.lookup``; ``subgraph`` must expose ``indices``.

    Returns:
        The output of ``self.result`` concatenated along dim 0 with the
        output of ``self.angle_result``.
    """
    tertiary, subgraph = inputs

    # Sine/cosine features of the raw input, concatenated along dim 1.
    afeat = torch.cat((tertiary.sin(), tertiary.cos()), dim=1)
    angle_result = self.angle_result(afeat)
    features = ts.scatter.batched(self.preprocess, afeat, subgraph.indices)

    # Replace the raw input by the output of the lookup module.
    # NOTE(review): presumably converts angles to atom positions - confirm.
    tertiary, _ = self.lookup(tertiary, torch.zeros_like(subgraph.indices))
    ors = self.orientations(tertiary)
    pos = tertiary[:, 1]
    inds = torch.arange(
        0, pos.size(0), dtype=torch.float, device=pos.device
    ).view(-1, 1)
    # Per-row geometric descriptor: position, orientation and running index.
    distances = torch.cat((pos, ors, inds), dim=1)

    # Only the neighbourhood structure is needed; the first return value of
    # knn_structure was previously overwritten by an RBF expansion that
    # nothing ever read, so that dead computation has been dropped.
    _, structure = self.knn_structure(tertiary, subgraph)

    distance_data = RelativeStructure(structure, self.rbf)
    relative_data = distance_data.message(distances, distances)
    relative_structure = OrientationStructure(structure, relative_data)

    encoding = self.encoder(features, relative_structure)
    encoding = ts.scatter.batched(self.postprocess, encoding, subgraph.indices)
    # Skip connection: keep the pre-encoder features next to the encoding.
    encoding = torch.cat((features, encoding), dim=1)

    result = self.result(encoding)
    return torch.cat((result, angle_result), dim=0)
def forward(self, sample):
    """Iteratively correct a set of angles using structure-aware updates.

    Each of the ``self.repeats`` iterations runs the current angles through
    ``self.position_lookup``, builds a neighbourhood structure from the
    result, and adds ``self.corrected_angles(self.correction(...))`` to the
    angles.

    Args:
        sample: An ``(angles, subgraph)`` pair. ``subgraph`` must expose
            ``indices`` and ``counts``.

    Returns:
        A ``(result, subgraph)`` pair, where ``result`` is a
        ``ts.PackedTensor`` of the final angles with lengths taken from
        ``subgraph.counts``.
    """
    angles, subgraph = sample

    for _step in range(self.repeats):
        tertiary, _ = self.position_lookup(
            angles, torch.zeros_like(subgraph.indices)
        )

        # Sine/cosine features of the current angles.
        afeat = torch.cat((angles.sin(), angles.cos()), dim=1)
        features = ts.scatter.batched(
            self.preprocess_correct, afeat, subgraph.indices
        )

        ors = self.orientations(tertiary)
        pos = tertiary[:, 1]
        inds = torch.arange(
            0, pos.size(0), dtype=torch.float, device=pos.device
        ).view(-1, 1)
        distances = torch.cat((pos, ors, inds), dim=1)

        # Only the structure is used; the RBF-expanded neighbour distances
        # that used to be computed here were never read and were removed.
        _, structure = self.knn_structure(tertiary, subgraph)

        distance_data = RelativeStructure(structure, self.rbf)
        relative_data = distance_data.message(distances, distances)
        relative_structure = OrientationStructure(structure, relative_data)

        correction = self.correction(features, relative_structure)
        # Additive update: the next iteration starts from the corrected angles.
        angles = angles + self.corrected_angles(correction)

    result = ts.PackedTensor(angles, lengths=list(subgraph.counts))
    return result, subgraph
def forward(self, primary, gt_ignore, angles, orientation, neighbours, protein, return_deltas=False):
    """Score each position from its neighbourhood and pool a mean per group.

    Args:
        primary: Tensor indexed along dim 0 by ``neighbours.connections``.
            Its ``argmax`` over dim 1 selects which output column is read as
            the position's own value.
        gt_ignore: Unused by this method.
        angles: Tensor gathered with ``neighbours.connections``.
        orientation: Passed as both arguments to
            ``RelativeStructure.message``.
        neighbours: Structure exposing ``connections``.
        protein: Object exposing ``indices``, used to group the mean.
        return_deltas: If true, the first returned element is the pair
            ``(energy, -out)`` instead of ``energy`` alone.

    Returns:
        ``(energy, out)``, or ``((energy, -out), out)`` when
        ``return_deltas`` is true.
    """
    assert neighbours.connections.max() < primary.size(0)
    indices = neighbours.connections

    # Keep an untouched copy for the argmax below; the gathered copy is
    # modified in place.
    sequence = primary.clone()
    primary = primary[indices].clone()
    # Zero the first neighbour slot of every row.
    # NOTE(review): presumably slot 0 is the position itself, i.e. this hides
    # the identity being scored - confirm.
    primary[:, 0] = 0
    angles = angles[indices]

    relative = RelativeStructure(neighbours, self.rbf)
    orientation = relative.message(orientation, orientation)

    features = torch.cat((primary, angles, orientation), dim=2).permute(0, 2, 1)
    # Pair every neighbour's features with those of the first neighbour slot.
    inputs = torch.cat(
        (features, features[:, :, 0:1].expand_as(features)), dim=1
    )
    in_view = inputs.transpose(2, 1).reshape(-1, 2 * features.size(1))

    # Gated features: elementwise product of the two learned maps.
    p = self.features(in_view)
    w = self.weight(in_view)
    prod = p * w
    cat = prod.reshape(inputs.size(0), -1)
    out = self.out(cat)

    # Pick, per row, the output column given by the argmax of ``sequence``.
    this_energy = out[torch.arange(out.size(0)), sequence.argmax(dim=1)]
    energy = scatter.mean(this_energy.unsqueeze(-1), protein.indices)

    if return_deltas:
        return (energy, -out), out
    return energy, out
def forward(self, features, sequence, distances, structure):
    """Encode ``features`` over ``structure`` and decode with a masked structure.

    Args:
        features: Input passed to ``self.encoder``.
        sequence: Passed through ``self.prepare_sequence`` before decoding.
        distances: Passed as both arguments to
            ``RelativeStructure.message``.
        structure: Structure shared by the encoder and decoder.

    Returns:
        The output of ``self.decoder``.
    """
    pairwise = RelativeStructure(structure, self.rbf).message(distances, distances)
    oriented = OrientationStructure(structure, pairwise)
    encoded = self.encoder(features, oriented)

    prepared = self.prepare_sequence(sequence)
    masked = MaskedStructure(structure, pairwise, prepared, encoded)
    return self.decoder(encoded, masked)
def forward(self, tertiary, noise, sequence, subgraph):
    """Encode a structure conditioned on a discretised noise level.

    Args:
        tertiary: Raw input; used through ``sin``/``cos`` and as the input
            to ``self.lookup``.
        noise: Tensor of noise levels. It is divided by 3.15 and mapped to
            an integer bin via a base-0.60 logarithm.
            NOTE(review): the bin indexes a 10-column one-hot, so this
            assumes the resulting index is a valid column index - confirm
            the expected noise range.
        sequence: Replaces the computed features when ``self.conditional``
            is true.
        subgraph: Object exposing ``indices``.

    Returns:
        The output of ``self.result`` applied to the concatenation of the
        input features and the post-processed encoding.
    """
    # Discretise the noise level and one-hot encode it per row.
    noise = noise / 3.15
    offset = (torch.log(noise) / torch.log(torch.tensor(0.60))).long()
    condition = torch.zeros(
        tertiary.size(0), 10, device=tertiary.device, dtype=torch.float
    )
    condition[torch.arange(offset.size(0)), offset.view(-1)] = 1

    # Angle features are taken from the raw input, before the lookup below
    # replaces ``tertiary``.
    asin = tertiary.sin()
    acos = tertiary.cos()
    tertiary, _ = self.lookup(tertiary, torch.zeros_like(subgraph.indices))

    if self.angles:
        afeat = torch.cat((asin, acos, condition), dim=1)
    else:
        # Use the flattened lookup output instead of the angle features.
        afeat = torch.cat(
            (tertiary.reshape(tertiary.size(0), -1), condition), dim=1
        )
    features = ts.scatter.batched(self.preprocess, afeat, subgraph.indices)
    if self.conditional:
        features = sequence

    ors = self.orientations(tertiary)
    pos = tertiary[:, 1]
    inds = torch.arange(
        0, pos.size(0), dtype=torch.float, device=pos.device
    ).view(-1, 1)
    distances = torch.cat((pos, ors, inds), dim=1)

    # Only the structure is used; the RBF-expanded neighbour distances that
    # used to be computed here were never read and were removed.
    _, structure = self.knn_structure(tertiary, subgraph)

    distance_data = RelativeStructure(structure, self.rbf)
    relative_data = distance_data.message(distances, distances)
    relative_structure = OrientationStructure(structure, relative_data)

    encoding = self.encoder(features, relative_structure)
    encoding = ts.scatter.batched(self.postprocess, encoding, subgraph.indices)
    # Skip connection: keep the pre-encoder features next to the encoding.
    encoding = torch.cat((features, encoding), dim=1)
    return self.result(encoding)
def forward(self, primary, mode, gt_ignore, angles, orientation, neighbours, protein):
    """Compute one output row per position from its gathered neighbourhood.

    Args:
        primary: Tensor gathered along dim 0 by ``neighbours.connections``.
        mode: Unused by this method.
        gt_ignore: Unused by this method.
        angles: Tensor gathered along dim 0 by ``neighbours.connections``.
        orientation: Passed as both arguments to
            ``RelativeStructure.message``.
        neighbours: Structure exposing ``connections``.
        protein: Unused by this method.

    Returns:
        The output of ``self.out``.
    """
    assert neighbours.connections.max() < primary.size(0)
    connections = neighbours.connections

    neighbour_primary = primary[connections]
    neighbour_angles = angles[connections]
    relative_orientation = RelativeStructure(neighbours, self.rbf).message(
        orientation, orientation
    )

    stacked = torch.cat(
        (neighbour_primary, neighbour_angles, relative_orientation), dim=2
    ).permute(0, 2, 1)
    # Pair every neighbour's features with those of the first neighbour slot.
    first_slot = stacked[:, :, 0:1].expand_as(stacked)
    paired = torch.cat((stacked, first_slot), dim=1)
    flat = paired.transpose(2, 1).reshape(-1, 2 * stacked.size(1))

    # Gated features: elementwise product of the two learned maps.
    gated = self.features(flat) * self.weight(flat)
    return self.out(gated.reshape(paired.size(0), -1))
def forward(self, tertiary, subgraph):
    """Encode a structure from geometry alone and look up the predicted output.

    Args:
        tertiary: Raw input handed to ``self.lookup``.
        subgraph: Object exposing ``indices``.

    Returns:
        A ``(result, result_tertiary)`` pair: the output of ``self.result``
        and the first return value of ``self.lookup`` applied to it.
    """
    # Node features are all zeros, so the encoder only receives information
    # through the relative structure.
    features = torch.zeros(
        tertiary.size(0), self.size, dtype=tertiary.dtype, device=tertiary.device
    )

    tertiary, _ = self.lookup(tertiary, torch.zeros_like(subgraph.indices))
    ors = self.orientations(tertiary)
    pos = tertiary[:, 1]
    inds = torch.arange(
        0, pos.size(0), dtype=torch.float, device=pos.device
    ).view(-1, 1)
    distances = torch.cat((pos, ors, inds), dim=1)

    # Only the structure is used; the RBF-expanded neighbour distances that
    # used to be computed here were never read and were removed.
    _, structure = self.knn_structure(tertiary, subgraph)

    distance_data = RelativeStructure(structure, self.rbf)
    relative_data = distance_data.message(distances, distances)
    relative_structure = OrientationStructure(structure, relative_data)

    encoding = self.encoder(features, relative_structure)
    encoding = ts.scatter.batched(self.postprocess, encoding, subgraph.indices)
    result = self.result(encoding)

    # Run the prediction back through the same lookup as the input.
    result_tertiary, _ = self.lookup(result, torch.zeros_like(subgraph.indices))
    return result, result_tertiary
def forward(self, tertiary, sequence, subgraph):
    """Encode a structure, pool the encoding and map it through ``self.energy``.

    Args:
        tertiary: Structure input. When ``self.angles`` is true it is used
            through ``sin``/``cos`` and replaced by the output of
            ``self.lookup``.
        sequence: Replaces the node features when ``self.conditional`` is
            true.
        subgraph: Object exposing ``indices``.

    Returns:
        The output of ``self.energy`` applied to the pooled encoding.
    """
    # Default node features: a constant block of ones with 27 columns.
    node_features = torch.ones(
        tertiary.size(0), 27, dtype=tertiary.dtype, device=tertiary.device
    )

    if self.angles:
        angle_features = torch.cat((tertiary.sin(), tertiary.cos()), dim=1)
        node_features = ts.scatter.batched(
            self.local_features, angle_features, subgraph.indices
        )
        # NOTE(review): the original formatting does not show whether this
        # lookup belongs inside the ``self.angles`` branch or runs
        # unconditionally - confirm against the upstream source.
        tertiary, _ = self.lookup(tertiary, torch.zeros_like(subgraph.indices))

    if self.conditional:
        node_features = sequence

    orientation = self.orientations(tertiary)
    position = tertiary[:, 1]
    row_index = torch.arange(
        0, position.size(0), dtype=torch.float, device=position.device
    ).view(-1, 1)
    descriptors = torch.cat((position, orientation, row_index), dim=1)

    structure = self.knn_structure(tertiary, subgraph)
    pairwise = RelativeStructure(structure, self.rbf).message(
        descriptors, descriptors
    )
    oriented = OrientationStructure(structure, pairwise)

    encoded = self.encoder(node_features, oriented)
    pooled = self.pool(encoded, subgraph.indices)
    return self.energy(pooled)