Example #1
0
	def forward(self, affine:QuatAffine, representations_list: Sequence[torch.Tensor], aatype:torch.Tensor):
		"""Predict torsion angles and build side-chain frames and atom positions.

		Args:
			affine: Backbone affine. ``cast_to`` is called on it for its side
				effect: float32 while frames are built, then ``act.dtype``.
			representations_list: Tensors that are each passed through ReLU and
				their own input projection, then summed.
			aatype: Forwarded unchanged to the ``protein`` helpers.

		Returns:
			Dict with keys ``angles_sin_cos``, ``unnormalized_angles_sin_cos``,
			``atom_pos`` and ``frames``.
		"""
		projected = []
		for idx, representation in enumerate(representations_list):
			projected.append(self.input_projection[idx](self.relu(representation)))
		act = sum(projected)

		# Residual stack: two ReLU -> layer steps, then add the block input back.
		for block_idx in range(self.config.num_residual_block):
			residual = act
			act = self.resblock1[block_idx](self.relu(act))
			act = self.resblock2[block_idx](self.relu(act))
			act += residual

		# The flat head output is viewed as 7 pairs per leading index and each
		# pair is normalised along the last dimension.
		raw_angles = self.unnormalized_angles(self.relu(act))
		unnormalized_angles = raw_angles.view(act.size(0), 7, 2)
		angles = self.l2_normalize(unnormalized_angles, dim=-1)

		# Frame construction runs in float32; the affine is cast back afterwards.
		affine.cast_to(dtype=torch.float32)
		angles = angles.to(dtype=torch.float32)
		backb_to_global = affine.to_rigids()
		all_frames_to_global = protein.torsion_angles_to_frames(aatype, backb_to_global, angles)
		pred_positions = protein.frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global)
		affine.cast_to(dtype=act.dtype)

		return {
			'angles_sin_cos': angles,
			'unnormalized_angles_sin_cos': unnormalized_angles,
			'atom_pos': pred_positions,
			'frames': all_frames_to_global,
		}
Example #2
0
def InvariantPointAttentionTest(args, config, global_config):
    """Check the single-sample and batched attention modules against each other.

    Both modules load the same weights. The batched module is fed the
    single-sample inputs repeated ``batch_size`` times, and every batch entry
    must stay within a summed absolute error of 1e-2 of the single result.
    Comparisons against the stored reference are printed, not asserted.
    """
    print('InvariantPointAttentionTest')
    feat, params, res = load_data(args, 'InvariantPointAttention')
    conf = config.model.heads.structure_module
    feature_dims = dict(
        num_feat_1d=feat['inputs_1d'].shape[-1],
        num_feat_2d=feat['inputs_2d'].shape[-1],
    )
    weights_path = 'invariant_point_attention'

    attn_single = InvariantPointAttention(conf, global_config, **feature_dims)
    attn_single.load_weights_from_af2(params, rel_path=weights_path)
    attn_batch = InvariantPointAttentionB(conf, global_config, **feature_dims)
    attn_batch.load_weights_from_af2(params, rel_path=weights_path)

    for label, key in (('inputs1d:', 'inputs_1d'),
                       ('inputs2d:', 'inputs_2d'),
                       ('activations:', 'activations'),
                       ('mask:', 'mask')):
        print(label, feat[key].size())

    # Build the batch by prepending a dimension and repeating along it.
    batch_size = 8
    inputs_1d_batch = feat['inputs_1d'][None, ...].repeat(batch_size, 1, 1)
    inputs_2d_batch = feat['inputs_2d'][None, ...].repeat(batch_size, 1, 1, 1)
    activations_batch = feat['activations'][None, ...].repeat(batch_size, 1, 1)
    mask_batch = feat['mask'][None, ...].repeat(batch_size, 1, 1)

    qa_single = QuatAffine.from_tensor(
        feat['activations'].to(dtype=torch.float32))
    qa_batch = QuatAffine.from_tensor(
        activations_batch.to(dtype=torch.float32))

    res_single = attn_single(
        inputs_1d=feat['inputs_1d'],
        inputs_2d=feat['inputs_2d'],
        mask=feat['mask'],
        affine=qa_single,
    )
    res_batch = attn_batch(
        inputs_1d=inputs_1d_batch,
        inputs_2d=inputs_2d_batch,
        mask=mask_batch,
        affine=qa_batch,
    )
    print(check_recursive(res_single, res))
    print(check_recursive(res_batch[0, ...], res))

    for i in range(batch_size):
        difference = res_batch[i, ...] - res_single
        err = torch.sum(torch.abs(difference))
        print(i, err.item())
        assert err < 1e-2
Example #3
0
def pre_compose_test(args, name):
	"""Compare ``QuatAffine.pre_compose`` with stored reference data.

	Loads ``(tensor, update)`` and the expected result for ``name``, applies
	the update, then prints the error statistics and a pass/fail flag.
	"""
	(tensor, update), expected = load_data(args, name)
	composed = QuatAffine.from_tensor(tensor).pre_compose(update)
	total_err, max_err, mean_err = check_recursive(expected, composed.to_tensor())
	print(f'Max error = {max_err}, mean error = {mean_err} total error = {total_err}')
	passed = (max_err < 1e-4) and (mean_err < 1e-5)
	print(f'Success = {passed}')
Example #4
0
def invert_point_test(args, name):
	"""Compare ``QuatAffine.invert_point`` with stored reference data.

	Loads ``(tensor, point)`` and the expected result for ``name``, then
	prints the error statistics and a pass/fail flag.
	"""
	(tensor, point), expected = load_data(args, name)
	inverted = QuatAffine.from_tensor(tensor).invert_point(point)
	total_err, max_err, mean_err = check_recursive(expected, inverted)
	print(f'Max error = {max_err}, mean error = {mean_err} total error = {total_err}')
	passed = (max_err < 1e-4) and (mean_err < 1e-5)
	print(f'Success = {passed}')
Example #5
0
def test_torsion_angles_to_frames(args):
	"""Run ``torsion_angles_to_frames`` on stored inputs and print the comparison.

	The resulting frames are flattened with ``affine.rigids_to_tensor_flat12``
	before being compared with the stored reference.
	"""
	print('test_torsion_angles_to_frames')
	inputs, expected = load_data(args, 'test_torsion_angles_to_frames')
	activations, aatype, torsion_angles_sin_cos = inputs
	backbone_frames = QuatAffine.from_tensor(activations).to_rigids()
	frames = torsion_angles_to_frames(
		aatype=aatype,
		backb_to_global=backbone_frames,
		torsion_angles_sin_cos=torsion_angles_sin_cos,
	)
	flat_frames = affine.rigids_to_tensor_flat12(frames)
	print(check_recursive(flat_frames, expected))
Example #6
0
	def forward(self, activations:torch.Tensor, sequence_mask:torch.Tensor, update_affine:bool, initial_act:torch.Tensor, 
					is_training:bool=False, static_feat_2d:torch.Tensor=None, aatype:torch.Tensor=None):
		"""Run one iteration: attention, transition, optional affine update, side chains.

		Args:
			activations: Indexed with the keys ``'act'`` and ``'affine'``, so it
				is used as a mapping despite the annotation.
			sequence_mask: Passed to the attention module as ``mask``.
			update_affine: When true, the affine is pre-composed with an update
				predicted from the activations.
			initial_act: Passed to the side-chain module alongside ``act``.
			is_training: Not read inside this method.
			static_feat_2d: Passed to the attention module as ``inputs_2d``.
			aatype: Forwarded to the side-chain module.

		Returns:
			``(new_activations, outputs)``. ``new_activations`` has ``'act'`` and
			an ``'affine'`` tensor built after applying ``torch.detach`` through
			``apply_rotation_tensor_fn``. ``outputs`` has the float32 ``'affine'``
			tensor and the side-chain result under ``'sc'``.
		"""
		act = activations['act']
		affine = QuatAffine.from_tensor(activations['affine'])
		# Match the affine to the activation dtype before attention.
		affine.cast_to(dtype=act.dtype)

		attention_out = self.attention_module(inputs_1d=act, inputs_2d=static_feat_2d, mask=sequence_mask, affine=affine)
		act = self.attention_layer_norm(act + attention_out)

		# Transition stack with ReLU between layers, but not after the last one.
		transition_input = act
		last_layer = self.config.num_layer_in_transition - 1
		for layer_idx in range(self.config.num_layer_in_transition):
			act = self.transition[layer_idx](act)
			if layer_idx != last_layer:
				act = self.relu(act)
		act = self.transition_layer_norm(act + transition_input)

		if update_affine:
			affine_update = self.affine_update(act)
			# Compose in float32, then cast back to the activation dtype.
			affine.cast_to(dtype=torch.float32)
			affine = affine.pre_compose(affine_update.to(dtype=torch.float32))
			affine.cast_to(dtype=act.dtype)

		scaled_affine = affine.scale_translation(self.config.position_scale)
		sc = self.side_chain(scaled_affine, [act, initial_act], aatype)

		# The reported affine is always float32.
		final_affine = affine.to_tensor().to(dtype=torch.float32)
		carried_affine = affine.apply_rotation_tensor_fn(torch.detach).to_tensor()

		outputs = {'affine': final_affine, 'sc': sc}
		new_activations = {'act': act, 'affine': carried_affine}
		return new_activations, outputs
Example #7
0
def init_test(args, name):
	"""Compare the fields of a freshly built ``QuatAffine`` with reference data.

	The loaded inputs are splatted into ``QuatAffine.from_tensor``. The
	quaternion, translation and rotation are each compared with the matching
	reference entry, and the error statistics are printed.
	"""
	inputs, expected = load_data(args, name)
	qa = QuatAffine.from_tensor(*inputs)
	actual = (qa.quaternion, qa.translation, qa.rotation)
	for reference, computed in zip(expected, actual):
		total_err, max_err, mean_err = check_recursive(reference, computed)
		print(f'Max error = {max_err}, mean error = {mean_err} total error = {total_err}')
		passed = (max_err < 1e-4) and (mean_err < 1e-5)
		print(f'Success = {passed}')
Example #8
0
def to_tensor_test(args, name):
	"""Compare ``QuatAffine.to_tensor`` with reference data after a ``+1.0`` map.

	The affine is built from the single loaded tensor and transformed with
	``apply_rotation_tensor_fn(lambda t: t+1.0)`` before being converted back
	to a tensor and compared.
	"""
	(tensor, ), expected = load_data(args, name)
	shifted = QuatAffine.from_tensor(tensor).apply_rotation_tensor_fn(lambda t: t+1.0)
	total_err, max_err, mean_err = check_recursive(expected, shifted.to_tensor())
	print(f'Max error = {max_err}, mean error = {mean_err} total error = {total_err}')
	passed = (max_err < 1e-4) and (mean_err < 1e-5)
	print(f'Success = {passed}')
Example #9
0
def test_frames_and_literature_positions_to_atom14_pos(args):
	"""Run the frames-to-atom14 pipeline on stored inputs and print the comparison.

	Frames come from ``torsion_angles_to_frames`` and positions from
	``frames_and_literature_positions_to_atom14_pos``. The positions are
	converted with ``affine.vecs_to_tensor`` before the comparison.
	"""
	print('test_frames_and_literature_positions_to_atom14_pos')
	inputs, expected = load_data(args, 'test_frames_and_literature_positions_to_atom14_pos')
	activations, aatype, torsion_angles_sin_cos = inputs
	backbone_frames = QuatAffine.from_tensor(activations).to_rigids()
	all_frames = torsion_angles_to_frames(
		aatype=aatype,
		backb_to_global=backbone_frames,
		torsion_angles_sin_cos=torsion_angles_sin_cos,
	)
	positions = frames_and_literature_positions_to_atom14_pos(aatype=aatype, all_frames_to_global=all_frames)
	positions_tensor = affine.vecs_to_tensor(positions)
	print(check_recursive(positions_tensor, expected))
Example #10
0
def apply_rot_func_test(args, name):
	"""Compare ``apply_rotation_tensor_fn`` output fields with reference data.

	Applies ``lambda t: t+1.0`` through ``apply_rotation_tensor_fn``, then
	compares the quaternion, translation and rotation of the result with the
	matching reference entries and prints the error statistics.
	"""
	(tensor, ), expected = load_data(args, name)
	shifted = QuatAffine.from_tensor(tensor).apply_rotation_tensor_fn(lambda t: t+1.0)
	actual = (shifted.quaternion, shifted.translation, shifted.rotation)
	for reference, computed in zip(expected, actual):
		total_err, max_err, mean_err = check_recursive(reference, computed)
		print(f'Max error = {max_err}, mean error = {mean_err} total error = {total_err}')
		passed = (max_err < 1e-4) and (mean_err < 1e-5)
		print(f'Success = {passed}')
Example #11
0
def InvariantPointAttentionTest(args, config, global_config):
	"""Run ``InvariantPointAttention`` on stored inputs and print the comparison.

	The module is sized from the shapes of the loaded 1D and 2D inputs and
	initialised from the stored parameters. The affine is built from the
	float32 ``activations`` feature.
	"""
	print('InvariantPointAttentionTest')
	feat, params, res = load_data(args, 'InvariantPointAttention')
	conf = config.model.heads.structure_module
	inputs_1d = feat['inputs_1d']
	inputs_2d = feat['inputs_2d']

	attn = InvariantPointAttention(
		conf,
		global_config,
		num_res=inputs_1d.shape[-2],
		num_feat_1d=inputs_1d.shape[-1],
		num_feat_2d=inputs_2d.shape[-1],
	)
	attn.load_weights_from_af2(params, rel_path='invariant_point_attention')

	activations_fp32 = feat['activations'].to(dtype=torch.float32)
	qa = QuatAffine.from_tensor(activations_fp32)
	attention_out = attn(inputs_1d=inputs_1d, inputs_2d=inputs_2d, mask=feat['mask'], affine=qa)
	print(check_recursive(attention_out, res))
Example #12
0
    def loss(self, value: Dict[str, torch.Tensor],
             batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Compute the binned aligned-error classification loss.

        Squared distances between predicted and true points, each expressed
        in its own local frames, are bucketed against the squared ``breaks``.
        The resulting bin indices are the targets for ``self.loss_function``
        applied to ``logits``.

        Args:
            value: Reads ``value['structure_module']['final_affines']`` and
                ``value['predicted_aligned_error']`` (``'breaks'``,
                ``'logits'``).
            batch: Reads ``'backbone_affine_tensor'`` and
                ``'backbone_affine_mask'``. Also reads ``'resolution'`` when
                ``self.config.filter_by_resolution`` is set.

        Returns:
            ``{'loss': loss}``, where ``loss`` is the masked mean of the
            per-pair errors over the last two dimensions.
        """
        predicted_affine = QuatAffine.from_tensor(
            value['structure_module']['final_affines'])
        true_affine = QuatAffine.from_tensor(batch['backbone_affine_tensor'])
        mask = batch['backbone_affine_mask']
        # Outer product of the mask selects pairs where both entries are valid.
        square_mask = mask[:, None] * mask[None, :]
        breaks = value['predicted_aligned_error']['breaks']
        logits = value['predicted_aligned_error']['logits']

        def _local_frame_points(affine):
            # Express every translation in every frame of the same affine.
            points = [x.unsqueeze(dim=-2) for x in affine.translation]
            return affine.invert_point(points, extra_dims=1)

        error_dist2_xyz = [
            torch.square(x - y)
            for x, y in zip(_local_frame_points(predicted_affine),
                            _local_frame_points(true_affine))
        ]
        error_dist2 = sum(error_dist2_xyz)
        # Tensor.detach() returns a new tensor rather than modifying in place,
        # so the result must be rebound for the targets to leave the graph.
        error_dist2 = error_dist2.detach()

        # Bin index = number of squared break points the squared error exceeds.
        sq_breaks = torch.square(breaks)
        true_bins = torch.sum(
            (error_dist2[..., None] > sq_breaks).to(dtype=torch.int32), dim=-1)

        errors = self.loss_function(logits, true_bins)
        # The small constant guards against division by zero for an empty mask.
        loss = torch.sum(errors * square_mask, dim=(-2, -1)) / (
            1e-8 + torch.sum(square_mask, dim=(-2, -1)))

        if self.config.filter_by_resolution:
            # Zero the loss when the resolution is outside [min, max).
            loss *= ((batch['resolution'] >= self.config.min_resolution) &
                     (batch['resolution'] < self.config.max_resolution)).to(
                         dtype=torch.float32)
        return {'loss': loss}
Example #13
0
def MultiRigidSidechainTest(args, config, global_config):
	"""Run ``MultiRigidSidechain`` on stored inputs and print the comparison.

	Prints the two-level parameter names first. The ``atom_pos`` and
	``frames`` outputs are converted to plain tensors with the ``affine``
	helpers before the comparison.
	"""
	print('MultiRigidSidechainTest')
	feat, params, res = load_data(args, 'MultiRigidSidechain')
	conf = config.model.heads.structure_module.sidechain

	# Dump the parameter tree: group name, then tab-indented parameter names.
	for group_name, group in params.items():
		print(group_name)
		for param_name in group:
			print('\t' + param_name)

	representations = feat['representations_list']
	sidechain = MultiRigidSidechain(
		conf,
		global_config,
		repr_dim=representations[0].shape[-1],
		num_repr=len(representations),
	)
	sidechain.load_weights_from_af2(params, rel_path='rigid_sidechain')

	qa = QuatAffine.from_tensor(feat['activations'].to(dtype=torch.float32))
	result = sidechain(affine=qa, representations_list=representations, aatype=feat['aatype'])
	result['atom_pos'] = affine.vecs_to_tensor(result['atom_pos'])
	result['frames'] = affine.rigids_to_tensor_flat12(result['frames'])
	print(check_recursive(result, res))
Example #14
0
	def backbone_loss(self, ret:Dict[str, torch.Tensor], value:Dict[str, torch.Tensor], batch:Dict[str, torch.Tensor]) -> None:
		"""Add the backbone frame-aligned point error (FAPE) to ``ret``.

		One FAPE value is computed per index along the first dimension of
		``value['traj']``. When ``batch`` contains ``'use_clamped_fape'``, the
		clamped and unclamped losses are blended with it as the weight.

		Args:
			ret: Updated in place. ``ret['fape']`` is set to the last
				per-step loss and the mean over steps is added to
				``ret['loss']``, which must already exist.
			value: Reads ``value['traj']``.
			batch: Reads ``'backbone_affine_tensor'`` and
				``'backbone_affine_mask'``, and optionally
				``'use_clamped_fape'``.
		"""
		rigid_trajectory = QuatAffine.from_tensor(value['traj']).to_rigids()

		# Ground-truth frames come straight from the flat tensor, not via QuatAffine.
		gt_rigid = rigids_from_tensor_flat12(batch['backbone_affine_tensor'])
		backbone_mask = batch['backbone_affine_mask']
		num_steps = value['traj'].size(0)

		def _fape_per_step(l1_clamp_distance):
			# FAPE for every trajectory index, stacked along a new leading dimension.
			loss_fn = functools.partial(protein.frame_aligned_point_error,
							l1_clamp_distance=l1_clamp_distance,
							length_scale=self.config.fape.loss_unit_distance)
			per_step = []
			for i in range(num_steps):
				pred = rigids_apply(lambda x: x[i,...], rigid_trajectory)
				per_step.append(loss_fn(pred, gt_rigid, backbone_mask, pred.trans, gt_rigid.trans, backbone_mask))
			return torch.stack(per_step, dim=0)

		fape_loss = _fape_per_step(self.config.fape.clamp_distance)

		if 'use_clamped_fape' in batch:
			# torch.Tensor(...) rejects a dtype keyword; as_tensor converts both
			# Python scalars and existing tensors.
			use_clamped_fape = torch.as_tensor(batch['use_clamped_fape'], dtype=torch.float32, device=fape_loss.device)
			# Every step must be collected; overwriting the list would break the stack.
			fape_loss_unclamped = _fape_per_step(None)
			fape_loss = fape_loss*use_clamped_fape + fape_loss_unclamped*(1.0-use_clamped_fape)

		ret['fape'] = fape_loss[-1]
		ret['loss'] += torch.mean(fape_loss)
Example #15
0
	def generate_new_affine(self, sequence_mask:torch.Tensor):
		"""Build an initial ``QuatAffine`` with one row per entry of ``sequence_mask``.

		Every row gets the quaternion ``[1, 0, 0, 0]`` and a zero translation,
		both float32 and created through ``sequence_mask.new_*``.
		NOTE(review): ``[1, 0, 0, 0]`` is presumably the identity rotation in
		(w, x, y, z) order; confirm against ``QuatAffine``.
		"""
		num_residues = sequence_mask.size(0)
		translation = sequence_mask.new_zeros(num_residues, 3, dtype=torch.float32)
		unit_quaternion = sequence_mask.new_tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
		quaternion = torch.tile(unit_quaternion.reshape(1, 4), (num_residues, 1))
		return QuatAffine(quaternion, translation, unstack_inputs=True)