def OuterProductMeanTest(args, config, global_config):
    """Compare OuterProductMeanFF against OuterProductMeanOpt on CUDA.

    Both modules are built from the same AF2 ``outer_product_mean`` weights.
    Each is run once under ``torch.profiler``, with traces written under
    ``Log/OuterProductMean`` and ``Log/OuterProductMeanOpt``. The two outputs
    are compared with ``check_recursive``, and the ``get_total_alloc`` delta
    around each run is printed. The reference ``res`` from ``load_data`` is
    not used here.

    NOTE(review): a second ``OuterProductMeanTest`` is defined later in this
    file and replaces this one at import time.
    """
    feat, params, res = load_data(args, 'OuterProductMean')
    conf = config.model.embeddings_and_evoformer.evoformer.outer_product_mean
    msa_dim = feat['msa_act'].shape[-1]

    ff_module = OuterProductMeanFF(conf, global_config, msa_dim=msa_dim, num_output_channel=256)
    ff_module.load_weights_from_af2(params, 'outer_product_mean')
    ref_module = OuterProductMeanOpt(conf, global_config, msa_dim=msa_dim, num_output_channel=256)
    ref_module.load_weights_from_af2(params, 'outer_product_mean')
    ref_module.cuda()

    # Only the first 127 entries along dim 1 are used for this comparison.
    msa_act = feat['msa_act'] = feat['msa_act'].to(device='cuda', dtype=torch.float32)[:, :127, :]
    msa_mask = feat['msa_mask'] = feat['msa_mask'].to(device='cuda', dtype=torch.float32)[:, :127]

    def profiled_call(module, log_name):
        """Run ``module`` once under the profiler; return (output, alloc delta)."""
        alloc_before = get_total_alloc()
        trace_handler = torch.profiler.tensorboard_trace_handler(Path('Log') / Path(log_name))
        with torch.profiler.profile(on_trace_ready=trace_handler, with_stack=True,
                                    with_modules=True, profile_memory=True,
                                    record_shapes=True) as prof:
            output = module(msa_act, msa_mask)
            prof.step()
        return output, get_total_alloc() - alloc_before

    out_ref, mem_ref = profiled_call(ref_module, 'OuterProductMean')
    # The FF module is moved to the GPU only after the reference run has been measured.
    ff_module.cuda()
    out_ff, mem_ff = profiled_call(ff_module, 'OuterProductMeanOpt')

    check_recursive(out_ff, out_ref)
    print(f'Mem vanilla: {mem_to_str(mem_ref)} \t opt: {mem_to_str(mem_ff)}')
def EmbeddingsAndEvoformerTest(args, config, global_config, cuda: bool = False):
    """Run a reduced EmbeddingsAndEvoformer and compare it to stored output.

    The function first prints every parameter scope with its parameter names
    and shapes, then every input feature with its shape. It mutates ``config``
    and ``global_config`` in place: templates and recycling are disabled, both
    stacks are set to one block, and deterministic mode is switched on.

    Args:
        cuda: if true, the module and every tensor-valued feature are moved to
            the GPU before the forward pass.
    """
    feat, params, res = load_data(args, 'EmbeddingsAndEvoformer')
    conf = config.model.embeddings_and_evoformer

    for scope, scope_params in params.items():
        print(scope)
        for param_name, value in scope_params.items():
            print('\t' + param_name + ' ' + str(value.shape))
    for name, value in feat.items():
        print(name, value.shape)

    # Shrink the model to a single deterministic pass without templates/recycling.
    conf.template.enabled = False
    conf.recycle_pos = False
    conf.recycle_features = False
    conf.evoformer_num_block = 1
    conf.extra_msa_stack_num_block = 1
    global_config.deterministic = True

    model = EmbeddingsAndEvoformer(
        conf, global_config,
        target_dim=feat['target_feat'].shape[-1],
        msa_dim=feat['msa_feat'].shape[-1],
        extra_msa_dim=25,
        clear_cache=False)
    model.load_weights_from_af2(params, rel_path='evoformer')

    if cuda:
        model = model.cuda()
        # Only tensor-valued entries are moved; other values are left untouched.
        for name, value in list(feat.items()):
            if isinstance(value, torch.Tensor):
                feat[name] = value.to(device='cuda')

    this_res = model(feat, is_training=False)
    check_recursive(this_res, res)
def TransitionTest(args, config, global_config):
    """Compare TransitionFF against TransitionOpt on CUDA, using the pair-transition config.

    The function sets ``global_config.subbatch_size = 2``, which mutates the
    caller's config. It builds both modules from the AF2 ``transition_block``
    weights and runs each once under ``torch.profiler``, with traces under
    ``Log/Transition`` and ``Log/TransitionOpt``. It then compares the outputs
    with ``check_recursive`` and prints the ``get_total_alloc`` delta for each
    run. The loaded reference ``res`` is not used.

    NOTE(review): another ``TransitionTest`` is defined later in this file and
    replaces this one at import time.
    """
    feat, params, res = load_data(args, 'Transition')
    conf = config.model.embeddings_and_evoformer.evoformer.pair_transition
    global_config.subbatch_size = 2
    channels = feat['seq_act'].shape[-1]

    ff_module = TransitionFF(conf, global_config, num_channel=channels)
    ff_module.load_weights_from_af2(params, 'transition_block')
    ref_module = TransitionOpt(conf, global_config, num_channel=channels)
    ref_module.load_weights_from_af2(params, 'transition_block')
    ref_module.cuda()

    seq_act = feat['seq_act'] = feat['seq_act'].to(device='cuda', dtype=torch.float32)
    seq_mask = feat['seq_mask'] = feat['seq_mask'].to(device='cuda', dtype=torch.float32)

    def profiled_call(module, log_name):
        """Run ``module`` once under the profiler; return (output, alloc delta)."""
        alloc_before = get_total_alloc()
        trace_handler = torch.profiler.tensorboard_trace_handler(Path('Log') / Path(log_name))
        with torch.profiler.profile(on_trace_ready=trace_handler, with_stack=True,
                                    with_modules=True, profile_memory=True,
                                    record_shapes=True) as prof:
            output = module(seq_act, seq_mask)
            prof.step()
        return output, get_total_alloc() - alloc_before

    out_ref, mem_ref = profiled_call(ref_module, 'Transition')
    ff_module.cuda()
    out_ff, mem_ff = profiled_call(ff_module, 'TransitionOpt')

    check_recursive(out_ff, out_ref)
    print(f'Mem vanilla: {mem_to_str(mem_ref)} \t opt: {mem_to_str(mem_ff)}')
def AlphaFoldTest(args, config):
    """Run a reduced AlphaFold model without gradients and compare it to stored output.

    The function prints parameter scopes and batch feature shapes. It then
    mutates ``config.model`` in place: recycling, MSA resampling and templates
    are disabled, both stacks are set to one block, and deterministic mode is
    switched on. The full result is printed before it is checked against
    ``res``.
    """
    batch, params, res = load_data(args, 'AlphaFold')
    conf = config.model

    for scope, scope_params in params.items():
        print(scope)
        for param_name, value in scope_params.items():
            print('\t' + param_name + ' ' + str(value.shape))
    for name, value in batch.items():
        print(name, value.shape)

    evo_conf = conf.embeddings_and_evoformer
    evo_conf.recycle_pos = False
    evo_conf.recycle_features = False
    evo_conf.template.enabled = False
    evo_conf.evoformer_num_block = 1
    evo_conf.extra_msa_stack_num_block = 1
    conf.num_recycle = 0
    conf.resample_msa_in_recycling = False
    conf.global_config.deterministic = True

    model = AlphaFold(
        conf,
        num_res=batch['target_feat'].shape[-2],
        target_dim=batch['target_feat'].shape[-1],
        msa_dim=batch['msa_feat'].shape[-1],
        extra_msa_dim=25)
    model.load_weights_from_af2(params, rel_path='alphafold')

    with torch.no_grad():
        this_res = model(batch, is_training=False)
    print(this_res)
    check_recursive(res, this_res)
def EvoformerIterationTest(args, config, global_config):
    """Compare EvoformerIterationFF against EvoformerIterationOpt on CUDA.

    Both modules are built with ``is_extra_msa=False`` from the AF2
    ``evoformer_iteration`` weights. Each is run once under ``torch.profiler``
    (traces under ``Log/EvoformerIteration`` and ``Log/EvoformerIterationOpt``).
    The two outputs are compared with ``check_recursive``, and the
    ``get_total_alloc`` delta for each run is printed. The loaded reference
    ``res`` is not used.
    """
    feat, params, res = load_data(args, 'EvoformerIteration1')
    conf = config.model.embeddings_and_evoformer.evoformer
    dims = dict(msa_dim=feat['msa_act'].shape[-1],
                pair_dim=feat['pair_act'].shape[-1],
                is_extra_msa=False)

    ref_module = EvoformerIterationOpt(conf, global_config, **dims)
    ref_module.load_weights_from_af2(params, rel_path='evoformer_iteration')
    ff_module = EvoformerIterationFF(conf, global_config, **dims)
    ff_module.load_weights_from_af2(params, rel_path='evoformer_iteration')

    for key in ('msa_act', 'pair_act', 'msa_mask', 'pair_mask'):
        feat[key] = feat[key].to(device='cuda', dtype=torch.float32)
    ref_module.cuda()

    def profiled_call(module, log_name):
        """Run ``module`` once under the profiler; return (output, alloc delta)."""
        alloc_before = get_total_alloc()
        trace_handler = torch.profiler.tensorboard_trace_handler(Path('Log') / Path(log_name))
        with torch.profiler.profile(on_trace_ready=trace_handler, with_stack=True,
                                    with_modules=True, profile_memory=True,
                                    record_shapes=True) as prof:
            output = module(msa_act=feat['msa_act'], pair_act=feat['pair_act'],
                            msa_mask=feat['msa_mask'], pair_mask=feat['pair_mask'],
                            is_training=False)
            prof.step()
        return output, get_total_alloc() - alloc_before

    res_vanilla, mem_vanilla = profiled_call(ref_module, 'EvoformerIteration')
    ff_module.cuda()
    res_opt, mem_opt = profiled_call(ff_module, 'EvoformerIterationOpt')

    check_recursive(res_opt, res_vanilla)
    # Previously this f-string was split by a raw newline, which is an
    # unterminated string literal (SyntaxError). It is now one line, like the
    # sibling tests.
    print(f'Mem vanilla: {mem_to_str(mem_vanilla)} \t opt: {mem_to_str(mem_opt)}')
def TriangleMultiplicationTest(args, config, global_config):
    """Compare TriangleMultiplicationFF against TriangleMultiplicationOpt on CUDA.

    The ``triangle_multiplication_incoming`` config is used. To exercise the
    outgoing variant, switch to ``triangle_multiplication_outgoing``. Both
    modules load the AF2 ``triangle_multiplication`` weights and run once under
    ``torch.profiler``, with traces under ``Log/TriangleMultiplication`` and
    ``Log/TriangleMultiplicationOpt``. The ``get_total_alloc`` delta for each
    run is printed.
    """
    feat, params, res = load_data(args, 'TriangleMultiplication')
    conf = config.model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming
    pair_dim = feat['pair_act'].shape[-1]

    ff_module = TriangleMultiplicationFF(conf, global_config, pair_dim=pair_dim)
    ff_module.load_weights_from_af2(params, 'triangle_multiplication')
    ref_module = TriangleMultiplicationOpt(conf, global_config, pair_dim=pair_dim)
    ref_module.load_weights_from_af2(params, 'triangle_multiplication')
    ref_module.cuda()

    pair_act = feat['pair_act'] = feat['pair_act'].to(device='cuda', dtype=torch.float32)
    pair_mask = feat['pair_mask'] = feat['pair_mask'].to(device='cuda', dtype=torch.float32)

    def profiled_call(module, log_name):
        """Run ``module`` once under the profiler; return (output, alloc delta)."""
        alloc_before = get_total_alloc()
        trace_handler = torch.profiler.tensorboard_trace_handler(Path('Log') / Path(log_name))
        with torch.profiler.profile(on_trace_ready=trace_handler, with_stack=True,
                                    with_modules=True, profile_memory=True,
                                    record_shapes=True) as prof:
            output = module(pair_act, pair_mask, is_training=False)
            prof.step()
        return output, get_total_alloc() - alloc_before

    out_ref, mem_ref = profiled_call(ref_module, 'TriangleMultiplication')
    ff_module.cuda()
    out_ff, mem_ff = profiled_call(ff_module, 'TriangleMultiplicationOpt')

    # The FF output is compared with the reference output plus the input
    # activations. Presumably the FF variant applies the residual add
    # internally; confirm against the module.
    expected = out_ref + pair_act if isinstance(ff_module, TriangleMultiplicationFF) else out_ref
    check_recursive(out_ff, expected)
    print(f'Mem vanilla: {mem_to_str(mem_ref)} \t opt: {mem_to_str(mem_ff)}')
def MSAColumnGlobalAttentionTest(args, config, global_config):
    """Check MSAColumnGlobalAttention against the stored reference output.

    The module is configured from the evoformer ``msa_column_attention`` block
    and loads the AF2 ``msa_column_global_attention`` weights.
    """
    feat, params, res = load_data(args, 'MSAColumnGlobalAttention')
    evo_conf = config.model.embeddings_and_evoformer.evoformer
    module = MSAColumnGlobalAttention(evo_conf.msa_column_attention, global_config,
                                      msa_dim=feat['msa_act'].shape[-1])
    module.load_weights_from_af2(params, rel_path='msa_column_global_attention')
    check_recursive(module(feat['msa_act'], feat['msa_mask']), res)
def TriangleMultiplicationIncomingTest(args, config, global_config):
    """Check an incoming-edge TriangleMultiplication against stored output.

    The module loads the AF2 ``triangle_multiplication`` weights.
    """
    feat, params, res = load_data(args, 'TriangleMultiplicationIncoming')
    evo_conf = config.model.embeddings_and_evoformer.evoformer
    module = TriangleMultiplication(evo_conf.triangle_multiplication_incoming, global_config,
                                    pair_dim=feat['pair_act'].shape[-1])
    module.load_weights_from_af2(params, rel_path='triangle_multiplication')
    check_recursive(module(feat['pair_act'], feat['pair_mask']), res)
def frame_aligned_point_error_test(args):
    """Check frame_aligned_point_error against the stored reference value.

    Both activation tensors are converted to rigids via ``QuatAffine``. The
    rigids' own translations are used as the positions, and the final
    positional argument is ``1.0``.
    """
    (activations1, activations2, frames_mask, pos_mask), res = load_data(args, 'frame_aligned_point_error_test')
    rigids_a, rigids_b = (affine.QuatAffine.from_tensor(act).to_rigids()
                          for act in (activations1, activations2))
    fape = frame_aligned_point_error(rigids_a, rigids_b, frames_mask,
                                     rigids_a.trans, rigids_b.trans, pos_mask, 1.0)
    check_recursive(res, fape)
def GlobalAttentionTest(args, config, global_config):
    """Compare GlobalAttentionOpt against GlobalAttention on CUDA.

    Both modules load ``params['attention']`` and run once under
    ``torch.profiler``, with traces under ``Log/GlobalAttention`` and
    ``Log/GlobalAttentionOpt``. A ``MemReporter`` report is printed after the
    optimized run. The loaded reference ``res`` is not used.

    NOTE(review): ``feat['bias']`` is passed as loaded and is not moved to
    CUDA. Also, a second ``GlobalAttentionTest`` is defined later in this file
    and replaces this one at import time.
    """
    feat, params, res = load_data(args, 'GlobalAttention')
    conf = config.model.embeddings_and_evoformer.evoformer.msa_row_attention_with_pair_bias
    dims = dict(output_dim=256, key_dim=feat['q_data'].shape[-1], value_dim=feat['m_data'].shape[-1])

    opt_module = GlobalAttentionOpt(conf, global_config, **dims)
    opt_module.load_weights_from_af2(params['attention'], None)
    ref_module = GlobalAttention(conf, global_config, **dims)
    ref_module.load_weights_from_af2(params['attention'], None)
    ref_module.cuda()

    for key in ('q_data', 'm_data', 'q_mask'):
        feat[key] = feat[key].to(device='cuda', dtype=torch.float32)

    def profiled_call(module, log_name):
        """Run ``module`` once under the profiler and return its output."""
        trace_handler = torch.profiler.tensorboard_trace_handler(Path('Log') / Path(log_name))
        with torch.profiler.profile(on_trace_ready=trace_handler, with_stack=True,
                                    with_modules=True, profile_memory=True,
                                    record_shapes=True) as prof:
            output = module(q_data=feat['q_data'], m_data=feat['m_data'],
                            q_mask=feat['q_mask'].to(dtype=torch.float32),
                            bias=feat['bias'])
            prof.step()
        return output

    out_ref = profiled_call(ref_module, 'GlobalAttention')
    opt_module.cuda()
    reporter = MemReporter()
    out_opt = profiled_call(opt_module, 'GlobalAttentionOpt')
    reporter.report()
    check_recursive(out_opt, out_ref)
def MSAColumnAttentionTest(args, config, global_config, is_training=False):
    """Compare MSAColumnAttentionFF against MSAColumnAttentionOpt on CUDA.

    Both modules load the AF2 ``msa_column_attention`` weights and run once
    under ``torch.profiler``, with traces under ``Log/MSAColumnAttention`` and
    ``Log/MSAColumnAttentionOpt``. ``is_training`` is forwarded to both calls.
    The ``MemReporter`` instances are created but never reported.

    Args:
        is_training: forwarded to both module calls.
    """
    feat, params, res = load_data(args, 'MSAColumnAttention')
    conf = config.model.embeddings_and_evoformer.evoformer.msa_column_attention
    msa_dim = feat['msa_act'].shape[-1]

    ff_module = MSAColumnAttentionFF(conf, global_config, msa_dim=msa_dim)
    ff_module.load_weights_from_af2(params, rel_path='msa_column_attention')
    ref_module = MSAColumnAttentionOpt(conf, global_config, msa_dim=msa_dim)
    ref_module.load_weights_from_af2(params, rel_path='msa_column_attention')
    ref_module.cuda()

    msa_act = feat['msa_act'] = feat['msa_act'].to(device='cuda', dtype=torch.float32)
    msa_mask = feat['msa_mask'] = feat['msa_mask'].to(device='cuda', dtype=torch.float32)

    def profiled_call(module, log_name):
        """Run ``module`` once under the profiler and return its output."""
        trace_handler = torch.profiler.tensorboard_trace_handler(Path('Log') / Path(log_name))
        with torch.profiler.profile(on_trace_ready=trace_handler, with_stack=True,
                                    with_modules=True, profile_memory=True,
                                    record_shapes=True) as prof:
            output = module(msa_act, msa_mask, is_training=is_training)
            prof.step()
        return output

    reporter = MemReporter()
    out_ref = profiled_call(ref_module, 'MSAColumnAttention')
    ff_module.cuda()
    reporter = MemReporter()
    out_ff = profiled_call(ff_module, 'MSAColumnAttentionOpt')

    # The FF output is compared with the reference output plus the input
    # activations. Presumably FF applies the residual add itself; confirm.
    if isinstance(ff_module, MSAColumnAttentionFF):
        check_recursive(out_ff, out_ref + msa_act)
    else:
        check_recursive(out_ff, out_ref)
def InvariantPointAttentionTest(args, config, global_config):
    """Check unbatched and batched InvariantPointAttention against reference.

    The unbatched module runs on the raw features. The batched module
    (``InvariantPointAttentionB``) runs on the same features tiled 8 times
    along a new leading dim. Each batch item must match the unbatched result
    with a summed absolute difference below 1e-2.

    NOTE(review): a second ``InvariantPointAttentionTest`` is defined later in
    this file and replaces this one at import time.
    """
    print('InvariantPointAttentionTest')
    feat, params, res = load_data(args, 'InvariantPointAttention')
    conf = config.model.heads.structure_module
    feat_dims = dict(num_feat_1d=feat['inputs_1d'].shape[-1],
                     num_feat_2d=feat['inputs_2d'].shape[-1])

    single_ipa = InvariantPointAttention(conf, global_config, **feat_dims)
    single_ipa.load_weights_from_af2(params, rel_path='invariant_point_attention')
    batched_ipa = InvariantPointAttentionB(conf, global_config, **feat_dims)
    batched_ipa.load_weights_from_af2(params, rel_path='invariant_point_attention')

    for label, key in (('inputs1d:', 'inputs_1d'), ('inputs2d:', 'inputs_2d'),
                       ('activations:', 'activations'), ('mask:', 'mask')):
        print(label, feat[key].size())

    batch_size = 8
    # Tile each input along a new leading dim of size batch_size.
    inputs_1d_batch = feat['inputs_1d'][None, ...].repeat(batch_size, 1, 1)
    inputs_2d_batch = feat['inputs_2d'][None, ...].repeat(batch_size, 1, 1, 1)
    activations_batch = feat['activations'][None, ...].repeat(batch_size, 1, 1)
    mask_batch = feat['mask'][None, ...].repeat(batch_size, 1, 1)

    affine_single = QuatAffine.from_tensor(feat['activations'].to(dtype=torch.float32))
    affine_batch = QuatAffine.from_tensor(activations_batch.to(dtype=torch.float32))

    out_single = single_ipa(inputs_1d=feat['inputs_1d'], inputs_2d=feat['inputs_2d'],
                            mask=feat['mask'], affine=affine_single)
    out_batch = batched_ipa(inputs_1d=inputs_1d_batch, inputs_2d=inputs_2d_batch,
                            mask=mask_batch, affine=affine_batch)

    print(check_recursive(out_single, res))
    print(check_recursive(out_batch[0, ...], res))
    for idx in range(batch_size):
        diff = torch.sum(torch.abs(out_batch[idx, ...] - out_single))
        print(idx, diff.item())
        assert diff < 1e-2
def TransitionTest(args, config, global_config):
    """Check Transition (pair-transition config) against the stored reference.

    Each parameter scope and its parameter names are printed before the AF2
    ``transition_block`` weights are loaded.
    """
    feat, params, res = load_data(args, 'Transition')
    evo_conf = config.model.embeddings_and_evoformer.evoformer
    module = Transition(evo_conf.pair_transition, global_config,
                        num_channel=feat['seq_act'].shape[-1])
    for scope, scope_params in params.items():
        print(scope)
        for param_name in scope_params.keys():
            print('\t' + param_name)
    module.load_weights_from_af2(params, rel_path='transition_block')
    check_recursive(module(feat['seq_act'], feat['seq_mask']), res)
def TriangleAttentionTest(args, config, global_config):
    """Check TriangleAttention (starting-node config) against stored output.

    Each parameter scope and its parameter names are printed before the AF2
    ``triangle_attention`` weights are loaded.
    """
    feat, params, res = load_data(args, 'TriangleAttention')
    evo_conf = config.model.embeddings_and_evoformer.evoformer
    module = TriangleAttention(evo_conf.triangle_attention_starting_node, global_config,
                               pair_dim=feat['pair_act'].shape[-1])
    for scope, scope_params in params.items():
        print(scope)
        for param_name in scope_params.keys():
            print('\t' + param_name)
    module.load_weights_from_af2(params, rel_path='triangle_attention')
    check_recursive(module(feat['pair_act'], feat['pair_mask']), res)
def create_extra_msa_features_test(args, config, global_config):
    """Check ExtraMSAEmbedding.create_extra_msa_features against stored output.

    Every input feature's name and shape is printed first. No weights are
    loaded; ``params`` is unused.
    """
    feat, params, res = load_data(args, 'create_extra_msa_feature')
    conf = config.model.embeddings_and_evoformer
    for name, value in feat.items():
        print(name, value.shape)
    embedding = ExtraMSAEmbedding(conf, global_config, msa_dim=feat['msa_feat'].shape[-1])
    check_recursive(embedding.create_extra_msa_features(feat), res)
def OuterProductMeanTest(args, config, global_config):
    """Check OuterProductMean against the stored reference output.

    The output channel count is fixed at 256. Each parameter scope and its
    parameter names are printed before the AF2 ``outer_product_mean`` weights
    are loaded.
    """
    feat, params, res = load_data(args, 'OuterProductMean')
    evo_conf = config.model.embeddings_and_evoformer.evoformer
    module = OuterProductMean(evo_conf.outer_product_mean, global_config,
                              msa_dim=feat['msa_act'].shape[-1], num_output_channel=256)
    for scope, scope_params in params.items():
        print(scope)
        for param_name in scope_params.keys():
            print('\t' + param_name)
    module.load_weights_from_af2(params, rel_path='outer_product_mean')
    check_recursive(module(feat['msa_act'], feat['msa_mask']), res)
def test_torsion_angles_to_frames(args):
    """Check torsion_angles_to_frames (flattened to 12 values) against reference."""
    print('test_torsion_angles_to_frames')
    (activations, aatype, torsion_angles_sin_cos), res = load_data(args, 'test_torsion_angles_to_frames')
    backbone_rigids = QuatAffine.from_tensor(activations).to_rigids()
    frames = torsion_angles_to_frames(aatype=aatype,
                                      backb_to_global=backbone_rigids,
                                      torsion_angles_sin_cos=torsion_angles_sin_cos)
    print(check_recursive(affine.rigids_to_tensor_flat12(frames), res))
def GlobalAttentionTest(args, config, global_config):
    """Check GlobalAttention against the stored reference output.

    The output dim is fixed at 256. Key and value dims come from the last axis
    of ``q_data`` and ``m_data``, and weights are loaded from
    ``params['attention']``.
    """
    feat, params, res = load_data(args, 'GlobalAttention')
    evo_conf = config.model.embeddings_and_evoformer.evoformer
    module = GlobalAttention(evo_conf.msa_row_attention_with_pair_bias, global_config,
                             output_dim=256,
                             key_dim=feat['q_data'].shape[-1],
                             value_dim=feat['m_data'].shape[-1])
    module.load_weights_from_af2(params['attention'], None)
    output = module(q_data=feat['q_data'], m_data=feat['m_data'],
                    q_mask=feat['q_mask'], bias=feat['bias'])
    check_recursive(output, res)
def pre_compose_test(args, name):
    """Check QuatAffine.pre_compose(...).to_tensor() against the stored reference.

    Prints the error statistics from ``check_recursive``. Success means
    max error < 1e-4 and mean error < 1e-5.
    """
    (tensor, update), res = load_data(args, name)
    composed = QuatAffine.from_tensor(tensor).pre_compose(update).to_tensor()
    err, max_err, mean_err = check_recursive(res, composed)
    print(f'Max error = {max_err}, mean error = {mean_err} total error = {err}')
    print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
def invert_point_test(args, name):
    """Check QuatAffine.invert_point against the stored reference.

    Prints the error statistics from ``check_recursive``. Success means
    max error < 1e-4 and mean error < 1e-5.
    """
    (tensor, point), res = load_data(args, name)
    inverted = QuatAffine.from_tensor(tensor).invert_point(point)
    err, max_err, mean_err = check_recursive(res, inverted)
    print(f'Max error = {max_err}, mean error = {mean_err} total error = {err}')
    print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
def init_test(args, name):
    """Check the components of a QuatAffine built by ``from_tensor`` against reference.

    The stored inputs are unpacked into ``QuatAffine.from_tensor``. The
    resulting quaternion, translation and rotation are then compared one by
    one with the reference entries, and error statistics plus a success flag
    (max < 1e-4 and mean < 1e-5) are printed for each.
    """
    inputs, res = load_data(args, name)
    quat_affine = QuatAffine.from_tensor(*inputs)
    components = (quat_affine.quaternion, quat_affine.translation, quat_affine.rotation)
    for expected, actual in zip(res, components):
        err, max_err, mean_err = check_recursive(expected, actual)
        print(f'Max error = {max_err}, mean error = {mean_err} total error = {err}')
        print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
def to_tensor_test(args, name):
    """Check QuatAffine.to_tensor after a +1.0 rotation transform against reference.

    The rotation tensor function ``t -> t + 1.0`` is applied before
    serialising. Error statistics and a success flag (max < 1e-4 and
    mean < 1e-5) are printed.
    """
    (tensor, ), res = load_data(args, name)
    shifted = QuatAffine.from_tensor(tensor).apply_rotation_tensor_fn(lambda t: t+1.0)
    err, max_err, mean_err = check_recursive(res, shifted.to_tensor())
    print(f'Max error = {max_err}, mean error = {mean_err} total error = {err}')
    print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
def MSARowAttentionWithPairBiasTest(args, config, global_config):
    """Check MSARowAttentionWithPairBias against the stored reference output.

    The config's ``gating`` setting is used as-is. The AF2
    ``msa_row_attention_with_pair_bias`` weights are loaded.
    """
    feat, params, res = load_data(args, 'MSARowAttentionWithPairBias')
    evo_conf = config.model.embeddings_and_evoformer.evoformer
    module = MSARowAttentionWithPairBias(evo_conf.msa_row_attention_with_pair_bias, global_config,
                                         pair_dim=feat['pair_act'].shape[-1],
                                         msa_dim=feat['msa_act'].shape[-1])
    module.load_weights_from_af2(params, rel_path='msa_row_attention_with_pair_bias')
    check_recursive(module(feat['msa_act'], feat['msa_mask'], feat['pair_act']), res)
def test_frames_and_literature_positions_to_atom14_pos(args):
    """Check frames_and_literature_positions_to_atom14_pos against reference.

    Frames are first produced by ``torsion_angles_to_frames`` from the backbone
    rigids. The resulting vectors are converted to a tensor before comparison.
    """
    print('test_frames_and_literature_positions_to_atom14_pos')
    (activations, aatype, torsion_angles_sin_cos), res = load_data(args, 'test_frames_and_literature_positions_to_atom14_pos')
    backbone_rigids = QuatAffine.from_tensor(activations).to_rigids()
    frames = torsion_angles_to_frames(aatype=aatype,
                                      backb_to_global=backbone_rigids,
                                      torsion_angles_sin_cos=torsion_angles_sin_cos)
    atom_positions = frames_and_literature_positions_to_atom14_pos(aatype=aatype,
                                                                   all_frames_to_global=frames)
    print(check_recursive(affine.vecs_to_tensor(atom_positions), res))
def apply_rot_func_test(args, name):
    """Check QuatAffine components after a +1.0 rotation transform against reference.

    After ``apply_rotation_tensor_fn(lambda t: t + 1.0)``, the quaternion,
    translation and rotation are compared one by one with the reference
    entries. Error statistics and a success flag (max < 1e-4 and mean < 1e-5)
    are printed for each.
    """
    (tensor, ), res = load_data(args, name)
    shifted = QuatAffine.from_tensor(tensor).apply_rotation_tensor_fn(lambda t: t+1.0)
    components = (shifted.quaternion, shifted.translation, shifted.rotation)
    for expected, actual in zip(res, components):
        err, max_err, mean_err = check_recursive(expected, actual)
        print(f'Max error = {max_err}, mean error = {mean_err} total error = {err}')
        print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
def EvoformerIterationTest1(args, config, global_config):
    """Check one EvoformerIteration (non-extra-MSA) against stored output.

    The masks are cast with ``.float()``. The first two elements of the module
    output are compared as ``{'msa': ..., 'pair': ...}``.
    """
    feat, params, res = load_data(args, 'EvoformerIteration1')
    conf = config.model.embeddings_and_evoformer.evoformer
    block = EvoformerIteration(conf, global_config,
                               msa_dim=feat['msa_act'].shape[-1],
                               pair_dim=feat['pair_act'].shape[-1],
                               is_extra_msa=False)
    block.load_weights_from_af2(params, rel_path='evoformer_iteration')
    outputs = block(msa_act=feat['msa_act'], pair_act=feat['pair_act'],
                    msa_mask=feat['msa_mask'].float(), pair_mask=feat['pair_mask'].float(),
                    is_training=False)
    check_recursive({'msa': outputs[0], 'pair': outputs[1]}, res)
def AttentionTest(args, config, global_config):
    """Check Attention against the stored reference output.

    The module uses the ``msa_row_attention_with_pair_bias`` config. The
    output dim is fixed at 256, and key and value dims come from the last axis
    of ``q_data`` and ``m_data``. Weights are loaded from
    ``params['attention']``.
    """
    feat, params, res = load_data(args, 'Attention')
    evo_conf = config.model.embeddings_and_evoformer.evoformer
    module = Attention(evo_conf.msa_row_attention_with_pair_bias, global_config,
                       output_dim=256,
                       key_dim=feat['q_data'].shape[-1],
                       value_dim=feat['m_data'].shape[-1])
    module.load_weights_from_af2(params['attention'], None)
    output = module(q_data=feat['q_data'], m_data=feat['m_data'],
                    bias=feat['bias'], nonbatched_bias=feat['nonbatched_bias'])
    check_recursive(output, res)
def InvariantPointAttentionTest(args, config, global_config):
    """Check InvariantPointAttention (with ``num_res``) against stored output.

    ``num_res`` is read from ``inputs_1d.shape[-2]``. The affine is built from
    the activations cast to float32, and the comparison result is printed.
    """
    print('InvariantPointAttentionTest')
    feat, params, res = load_data(args, 'InvariantPointAttention')
    inputs_1d, inputs_2d = feat['inputs_1d'], feat['inputs_2d']
    module = InvariantPointAttention(config.model.heads.structure_module, global_config,
                                     num_res=inputs_1d.shape[-2],
                                     num_feat_1d=inputs_1d.shape[-1],
                                     num_feat_2d=inputs_2d.shape[-1])
    module.load_weights_from_af2(params, rel_path='invariant_point_attention')
    quat_affine = QuatAffine.from_tensor(feat['activations'].to(dtype=torch.float32))
    output = module(inputs_1d=inputs_1d, inputs_2d=inputs_2d, mask=feat['mask'], affine=quat_affine)
    print(check_recursive(output, res))
def PredictedAlignedErrorHeadTest(args, config, global_config):
    """Check PredictedAlignedErrorHead against stored output and print the result.

    The pair feature width comes from ``representations['pair'].shape[-1]``.
    """
    print('PredictedAlignedErrorHeadTest')
    feat, params, res = load_data(args, 'PredictedAlignedErrorHead')
    representations, batch = feat['representations'], feat['batch']
    head = PredictedAlignedErrorHead(config.model.heads.predicted_aligned_error, global_config,
                                     num_feat_2d=representations['pair'].shape[-1])
    head.load_weights_from_af2(params, rel_path='predicted_aligned_error_head')
    print(check_recursive(head(representations, batch), res))
def DistogramHeadTest(args, config, global_config):
    """Check DistogramHead against stored output and print the result.

    The pair feature width comes from ``representations['pair'].shape[-1]``.
    """
    print('DistogramHeadTest')
    feat, params, res = load_data(args, 'DistogramHead')
    representations, batch = feat['representations'], feat['batch']
    head = DistogramHead(config.model.heads.distogram, global_config,
                         num_feat_2d=representations['pair'].shape[-1])
    head.load_weights_from_af2(params, rel_path='distogram_head')
    print(check_recursive(head(representations, batch), res))