def forward(self, affine: QuatAffine, representations_list: Sequence[torch.Tensor], aatype: torch.Tensor):
    # Project each input representation and sum the projections
    act = [self.input_projection[i](self.relu(x)) for i, x in enumerate(representations_list)]
    act = sum(act)

    # Residual MLP stack
    for i in range(self.config.num_residual_block):
        old_act = act
        act = self.resblock1[i](self.relu(act))
        act = self.resblock2[i](self.relu(act))
        act += old_act

    # Predict 7 torsion angles per residue as (sin, cos) pairs and normalize them
    unnormalized_angles = self.unnormalized_angles(self.relu(act))
    unnormalized_angles = unnormalized_angles.view(act.size(0), 7, 2)
    angles = self.l2_normalize(unnormalized_angles, dim=-1)

    # Compose frames in float32 to avoid precision loss in half precision
    affine.cast_to(dtype=torch.float32)
    angles = angles.to(dtype=torch.float32)
    backb_to_global = affine.to_rigids()
    all_frames_to_global = protein.torsion_angles_to_frames(aatype, backb_to_global, angles)
    pred_positions = protein.frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global)
    affine.cast_to(dtype=act.dtype)

    outputs = {
        'angles_sin_cos': angles,
        'unnormalized_angles_sin_cos': unnormalized_angles,
        'atom_pos': pred_positions,
        'frames': all_frames_to_global,
    }
    return outputs
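# For reference, a minimal sketch of the l2_normalize helper assumed by the forward
# pass above (hypothetical; the module's own definition may differ). It divides each
# (sin, cos) pair by its L2 norm, with an epsilon guard against division by zero.
# Assumes torch is imported as elsewhere in this code.
def l2_normalize(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
    return x / torch.sqrt(torch.clamp(torch.sum(x * x, dim=dim, keepdim=True), min=eps))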
def InvariantPointAttentionTest(args, config, global_config):
    print('InvariantPointAttentionTest')
    feat, params, res = load_data(args, 'InvariantPointAttention')
    conf = config.model.heads.structure_module

    attn_single = InvariantPointAttention(conf, global_config,
                                          num_feat_1d=feat['inputs_1d'].shape[-1],
                                          num_feat_2d=feat['inputs_2d'].shape[-1])
    attn_single.load_weights_from_af2(params, rel_path='invariant_point_attention')

    attn_batch = InvariantPointAttentionB(conf, global_config,
                                          num_feat_1d=feat['inputs_1d'].shape[-1],
                                          num_feat_2d=feat['inputs_2d'].shape[-1])
    attn_batch.load_weights_from_af2(params, rel_path='invariant_point_attention')

    print('inputs1d:', feat['inputs_1d'].size())
    print('inputs2d:', feat['inputs_2d'].size())
    print('activations:', feat['activations'].size())
    print('mask:', feat['mask'].size())

    # Replicate the single example along a new batch dimension
    batch_size = 8
    inputs_1d_batch = feat['inputs_1d'][None, ...].repeat(batch_size, 1, 1)
    inputs_2d_batch = feat['inputs_2d'][None, ...].repeat(batch_size, 1, 1, 1)
    activations_batch = feat['activations'][None, ...].repeat(batch_size, 1, 1)
    mask_batch = feat['mask'][None, ...].repeat(batch_size, 1, 1)

    qa_single = QuatAffine.from_tensor(feat['activations'].to(dtype=torch.float32))
    qa_batch = QuatAffine.from_tensor(activations_batch.to(dtype=torch.float32))

    res_single = attn_single(inputs_1d=feat['inputs_1d'], inputs_2d=feat['inputs_2d'],
                             mask=feat['mask'], affine=qa_single)
    res_batch = attn_batch(inputs_1d=inputs_1d_batch, inputs_2d=inputs_2d_batch,
                           mask=mask_batch, affine=qa_batch)

    print(check_recursive(res_single, res))
    print(check_recursive(res_batch[0, ...], res))

    # Every batch element should reproduce the single-example result
    for i in range(batch_size):
        err = torch.sum(torch.abs(res_batch[i, ...] - res_single))
        print(i, err.item())
        assert err < 1e-2
def pre_compose_test(args, name):
    (tensor, update), res = load_data(args, name)
    qa = QuatAffine.from_tensor(tensor)
    this_res = qa.pre_compose(update).to_tensor()
    err, max_err, mean_err = check_recursive(res, this_res)
    print(f'Max error = {max_err}, mean error = {mean_err}, total error = {err}')
    print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
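# Note: in the AF2 quat_affine convention, `update` packs six numbers per residue —
# the vector part (x, y, z) of a quaternion update (the scalar part is implicitly 1
# before renormalization) followed by a translation update.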
def invert_point_test(args, name):
    (tensor, point), res = load_data(args, name)
    qa = QuatAffine.from_tensor(tensor)
    this_res = qa.invert_point(point)
    err, max_err, mean_err = check_recursive(res, this_res)
    print(f'Max error = {max_err}, mean error = {mean_err}, total error = {err}')
    print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
def test_torsion_angles_to_frames(args):
    print('test_torsion_angles_to_frames')
    (activations, aatype, torsion_angles_sin_cos), res = load_data(args, 'test_torsion_angles_to_frames')
    rigs = QuatAffine.from_tensor(activations).to_rigids()
    this_res = torsion_angles_to_frames(aatype=aatype, backb_to_global=rigs,
                                        torsion_angles_sin_cos=torsion_angles_sin_cos)
    this_res = affine.rigids_to_tensor_flat12(this_res)
    print(check_recursive(this_res, res))
def forward(self, activations: Dict[str, torch.Tensor], sequence_mask: torch.Tensor,
            update_affine: bool, initial_act: torch.Tensor, is_training: bool = False,
            static_feat_2d: torch.Tensor = None, aatype: torch.Tensor = None):
    affine = QuatAffine.from_tensor(activations['affine'])
    act = activations['act']
    affine.cast_to(dtype=act.dtype)  # cast to the activation dtype (e.g. float16)

    attn = self.attention_module(inputs_1d=act, inputs_2d=static_feat_2d,
                                 mask=sequence_mask, affine=affine)
    act = act + attn
    act = self.attention_layer_norm(act)

    input_act = act
    for i in range(self.config.num_layer_in_transition):
        act = self.transition[i](act)
        if i < self.config.num_layer_in_transition - 1:
            act = self.relu(act)
    act = act + input_act
    act = self.transition_layer_norm(act)

    if update_affine:
        affine_update = self.affine_update(act)
        # Compose in float32 to avoid multiplying rotations in float16 (unclear if it helps)
        affine.cast_to(dtype=torch.float32)
        affine = affine.pre_compose(affine_update.to(dtype=torch.float32))
        affine.cast_to(dtype=act.dtype)  # cast back to the activation dtype

    sc = self.side_chain(affine.scale_translation(self.config.position_scale),
                         [act, initial_act], aatype)

    # Final affines are returned in float32
    outputs = {'affine': affine.to_tensor().to(dtype=torch.float32), 'sc': sc}
    new_activations = {'act': act,
                       'affine': affine.apply_rotation_tensor_fn(torch.detach).to_tensor()}
    return new_activations, outputs
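# A hedged usage sketch (names such as fold_iteration, num_iterations, and pair_act
# are hypothetical, not from this module): the fold iteration above is typically
# unrolled with shared weights, threading `new_activations` through each step and
# collecting the per-iteration `outputs`:
#
#   activations = {'act': initial_act,
#                  'affine': self.generate_new_affine(sequence_mask).to_tensor()}
#   outputs = []
#   for _ in range(num_iterations):
#       activations, out = fold_iteration(activations, sequence_mask, update_affine=True,
#                                         initial_act=initial_act, static_feat_2d=pair_act,
#                                         aatype=aatype)
#       outputs.append(out)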
def init_test(args, name):
    inputs, res = load_data(args, name)  # renamed to avoid shadowing the `args` parameter
    qa = QuatAffine.from_tensor(*inputs)
    this_res = qa.quaternion, qa.translation, qa.rotation
    for a, b in zip(res, this_res):
        err, max_err, mean_err = check_recursive(a, b)
        print(f'Max error = {max_err}, mean error = {mean_err}, total error = {err}')
        print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
def to_tensor_test(args, name):
    (tensor, ), res = load_data(args, name)
    qa = QuatAffine.from_tensor(tensor)
    qa = qa.apply_rotation_tensor_fn(lambda t: t + 1.0)
    this_res = qa.to_tensor()
    err, max_err, mean_err = check_recursive(res, this_res)
    print(f'Max error = {max_err}, mean error = {mean_err}, total error = {err}')
    print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
def test_frames_and_literature_positions_to_atom14_pos(args):
    print('test_frames_and_literature_positions_to_atom14_pos')
    (activations, aatype, torsion_angles_sin_cos), res = load_data(args, 'test_frames_and_literature_positions_to_atom14_pos')
    rigs = QuatAffine.from_tensor(activations).to_rigids()
    all_frames = torsion_angles_to_frames(aatype=aatype, backb_to_global=rigs,
                                          torsion_angles_sin_cos=torsion_angles_sin_cos)
    this_res = frames_and_literature_positions_to_atom14_pos(aatype=aatype,
                                                             all_frames_to_global=all_frames)
    this_res = affine.vecs_to_tensor(this_res)
    print(check_recursive(this_res, res))
def apply_rot_func_test(args, name):
    (tensor, ), res = load_data(args, name)
    qa = QuatAffine.from_tensor(tensor)
    qa = qa.apply_rotation_tensor_fn(lambda t: t + 1.0)
    this_res = qa.quaternion, qa.translation, qa.rotation
    for a, b in zip(res, this_res):
        err, max_err, mean_err = check_recursive(a, b)
        print(f'Max error = {max_err}, mean error = {mean_err}, total error = {err}')
        print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
def InvariantPointAttentionTest(args, config, global_config):
    print('InvariantPointAttentionTest')
    feat, params, res = load_data(args, 'InvariantPointAttention')
    conf = config.model.heads.structure_module
    attn = InvariantPointAttention(conf, global_config,
                                   num_res=feat['inputs_1d'].shape[-2],
                                   num_feat_1d=feat['inputs_1d'].shape[-1],
                                   num_feat_2d=feat['inputs_2d'].shape[-1])
    attn.load_weights_from_af2(params, rel_path='invariant_point_attention')
    qa = QuatAffine.from_tensor(feat['activations'].to(dtype=torch.float32))
    this_res = attn(inputs_1d=feat['inputs_1d'], inputs_2d=feat['inputs_2d'],
                    mask=feat['mask'], affine=qa)
    print(check_recursive(this_res, res))
def loss(self, value: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    predicted_affine = QuatAffine.from_tensor(value['structure_module']['final_affines'])
    true_affine = QuatAffine.from_tensor(batch['backbone_affine_tensor'])
    mask = batch['backbone_affine_mask']
    square_mask = mask[:, None] * mask[None, :]
    num_bins = self.config.num_bins
    breaks = value['predicted_aligned_error']['breaks']
    logits = value['predicted_aligned_error']['logits']

    def _local_frame_points(affine):
        points = [x.unsqueeze(dim=-2) for x in affine.translation]
        return affine.invert_point(points, extra_dims=1)

    # Squared distance between predicted and true points, expressed in each local frame
    error_dist2_xyz = [torch.square(x - y)
                       for x, y in zip(_local_frame_points(predicted_affine),
                                       _local_frame_points(true_affine))]
    error_dist2 = sum(error_dist2_xyz)
    # detach returns a new tensor; assign it so no gradients flow through the targets
    error_dist2 = error_dist2.detach()

    # Bin each squared error: count how many squared break values it exceeds
    sq_breaks = torch.square(breaks)
    true_bins = torch.sum((error_dist2[..., None] > sq_breaks).to(dtype=torch.int32), dim=-1)

    errors = self.loss_function(logits, true_bins)
    loss = torch.sum(errors * square_mask, dim=(-2, -1)) / (
        1e-8 + torch.sum(square_mask, dim=(-2, -1)))

    if self.config.filter_by_resolution:
        loss *= ((batch['resolution'] >= self.config.min_resolution)
                 & (batch['resolution'] < self.config.max_resolution)).to(dtype=torch.float32)
    return {'loss': loss}
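# Standalone illustration of the binning above (numbers are made up): each squared
# error is assigned the count of squared break values it exceeds, so larger errors
# land in higher bins.
#   breaks      = torch.linspace(0.0, 15.0, steps=3)   # tensor([0.0, 7.5, 15.0])
#   error_dist2 = torch.tensor([0.5, 100.0, 400.0])
#   true_bins   = torch.sum((error_dist2[..., None] > torch.square(breaks)).to(torch.int32), dim=-1)
#   # true_bins -> tensor([1, 2, 3])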
def MultiRigidSidechainTest(args, config, global_config):
    print('MultiRigidSidechainTest')
    feat, params, res = load_data(args, 'MultiRigidSidechain')
    conf = config.model.heads.structure_module.sidechain

    # List the loaded parameter tree for inspection
    for key in params.keys():
        print(key)
        for param in params[key].keys():
            print('\t' + param)

    attn = MultiRigidSidechain(conf, global_config,
                               repr_dim=feat['representations_list'][0].shape[-1],
                               num_repr=len(feat['representations_list']))
    attn.load_weights_from_af2(params, rel_path='rigid_sidechain')

    qa = QuatAffine.from_tensor(feat['activations'].to(dtype=torch.float32))
    this_res = attn(affine=qa, representations_list=feat['representations_list'],
                    aatype=feat['aatype'])
    this_res['atom_pos'] = affine.vecs_to_tensor(this_res['atom_pos'])
    this_res['frames'] = affine.rigids_to_tensor_flat12(this_res['frames'])
    print(check_recursive(this_res, res))
def backbone_loss(self, ret: Dict[str, torch.Tensor], value: Dict[str, torch.Tensor],
                  batch: Dict[str, torch.Tensor]) -> None:
    affine_trajectory = QuatAffine.from_tensor(value['traj'])
    rigid_trajectory = affine_trajectory.to_rigids()

    gt_rigid = rigids_from_tensor_flat12(batch['backbone_affine_tensor'])
    backbone_mask = batch['backbone_affine_mask']

    fape_loss_fn = functools.partial(protein.frame_aligned_point_error,
                                     l1_clamp_distance=self.config.fape.clamp_distance,
                                     length_scale=self.config.fape.loss_unit_distance)

    # Clamped FAPE for each structure-module iteration in the trajectory
    fape_loss = []
    for i in range(value['traj'].size(0)):
        pred = rigids_apply(lambda x: x[i, ...], rigid_trajectory)
        fape_loss.append(fape_loss_fn(pred, gt_rigid, backbone_mask,
                                      pred.trans, gt_rigid.trans, backbone_mask))
    fape_loss = torch.stack(fape_loss, dim=0)

    if 'use_clamped_fape' in batch:
        use_clamped_fape = torch.as_tensor(batch['use_clamped_fape'], dtype=torch.float32)
        unclamped_fape_loss_fn = functools.partial(protein.frame_aligned_point_error,
                                                   l1_clamp_distance=None,
                                                   length_scale=self.config.fape.loss_unit_distance)
        fape_loss_unclamped = []
        for i in range(value['traj'].size(0)):
            pred = rigids_apply(lambda x: x[i, ...], rigid_trajectory)
            fape_loss_unclamped.append(unclamped_fape_loss_fn(pred, gt_rigid, backbone_mask,
                                                              pred.trans, gt_rigid.trans,
                                                              backbone_mask))
        fape_loss_unclamped = torch.stack(fape_loss_unclamped, dim=0)
        # Interpolate between clamped and unclamped FAPE per example
        fape_loss = fape_loss * use_clamped_fape + fape_loss_unclamped * (1.0 - use_clamped_fape)

    ret['fape'] = fape_loss[-1]
    ret['loss'] += torch.mean(fape_loss)
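# For reference, frame_aligned_point_error follows the FAPE definition from the
# AlphaFold 2 paper: with predicted frames T_i and points x_j, and ground truth
# counterparts with tildes,
#   FAPE = (1/Z) * mean_{i,j} min(clamp, || T_i^{-1} ∘ x_j  -  T̃_i^{-1} ∘ x̃_j ||)
# where Z = length_scale and clamp = l1_clamp_distance (no clamping when it is None).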
def generate_new_affine(self, sequence_mask: torch.Tensor):
    num_residues = sequence_mask.size(0)
    # Identity quaternion [1, 0, 0, 0] and zero translation for every residue
    quaternion = torch.tile(
        sequence_mask.new_tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32).reshape(1, 4),
        (num_residues, 1))
    translation = sequence_mask.new_zeros(num_residues, 3, dtype=torch.float32)
    return QuatAffine(quaternion, translation, unstack_inputs=True)
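# Standalone sanity check of the convention above: the identity quaternion
# (w, x, y, z) = (1, 0, 0, 0) maps to the identity rotation under the standard
# unit-quaternion formula
#   R = [[1 - 2(y^2 + z^2), 2(xy - wz),       2(xz + wy)],
#        [2(xy + wz),       1 - 2(x^2 + z^2), 2(yz - wx)],
#        [2(xz - wy),       2(yz + wx),       1 - 2(x^2 + y^2)]] = I,
# so the trajectory starts from all residues at the origin with no rotation.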