def OuterProductMeanTest(args, config, global_config):
    feat, params, res = load_data(args, 'OuterProductMean')
    conf = config.model.embeddings_and_evoformer.evoformer.outer_product_mean

    attn_opt = OuterProductMeanFF(conf, global_config, msa_dim=feat['msa_act'].shape[-1], num_output_channel=256)
    attn_opt.load_weights_from_af2(params, 'outer_product_mean')
    attn_vanilla = OuterProductMeanOpt(conf, global_config, msa_dim=feat['msa_act'].shape[-1], num_output_channel=256)
    attn_vanilla.load_weights_from_af2(params, 'outer_product_mean')

    attn_vanilla.cuda()
    feat['msa_act'] = feat['msa_act'].to(device='cuda', dtype=torch.float32)[:, :127, :]
    feat['msa_mask'] = feat['msa_mask'].to(device='cuda', dtype=torch.float32)[:, :127]

    alloc_start_vanilla = get_total_alloc()
    handler_vanilla = torch.profiler.tensorboard_trace_handler(Path('Log') / Path('OuterProductMean'))
    with torch.profiler.profile(on_trace_ready=handler_vanilla, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
        res_vanilla = attn_vanilla(feat['msa_act'], feat['msa_mask'])
        profiler.step()
    alloc_end_vanilla = get_total_alloc()

    attn_opt.cuda()
    alloc_start_opt = get_total_alloc()
    handler_opt = torch.profiler.tensorboard_trace_handler(Path('Log') / Path('OuterProductMeanOpt'))
    with torch.profiler.profile(on_trace_ready=handler_opt, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
        res_opt = attn_opt(feat['msa_act'], feat['msa_mask'])
        profiler.step()
    alloc_end_opt = get_total_alloc()

    check_recursive(res_opt, res_vanilla)
    print(f'Mem vanilla: {mem_to_str(alloc_end_vanilla-alloc_start_vanilla)} \t opt: {mem_to_str(alloc_end_opt-alloc_start_opt)}')

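# The memory comparison above relies on the repo's get_total_alloc()/mem_to_str() helpers.
# A minimal, hedged sketch of the same idea using only the public PyTorch CUDA allocator API
# (torch.cuda.memory_allocated); it is illustrative and not part of the test suite.
def _measure_cuda_alloc(fn, *args, **kwargs):
    """Run fn once and return (result, bytes of CUDA memory allocated during the call)."""
    torch.cuda.synchronize()
    start = torch.cuda.memory_allocated()
    result = fn(*args, **kwargs)
    torch.cuda.synchronize()
    return result, torch.cuda.memory_allocated() - start
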
def test_torsion_angles_to_frames(args):
    print('test_torsion_angles_to_frames')
    (activations, aatype, torsion_angles_sin_cos), res = load_data(args, 'test_torsion_angles_to_frames')

    rigs = QuatAffine.from_tensor(activations).to_rigids()
    this_res = torsion_angles_to_frames(aatype=aatype, backb_to_global=rigs, torsion_angles_sin_cos=torsion_angles_sin_cos)
    this_res = affine.rigids_to_tensor_flat12(this_res)
    print(check_recursive(this_res, res))

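# The tests in this file compare outputs against dumped AF2 references via the repo's
# check_recursive() helper. A minimal sketch of such a comparator is given here for orientation
# only; it assumes dict/list/tuple/tensor structures and a fixed tolerance, both of which are
# assumptions and may differ from the real helper.
def check_recursive_sketch(a, b, atol: float = 1e-4) -> float:
    """Return the maximum elementwise absolute error between two nested structures."""
    if isinstance(a, dict):
        return max(check_recursive_sketch(a[k], b[k], atol) for k in a)
    if isinstance(a, (list, tuple)):
        return max(check_recursive_sketch(x, y, atol) for x, y in zip(a, b))
    a_t = torch.as_tensor(a).detach().cpu().to(torch.float32)
    b_t = torch.as_tensor(b).detach().cpu().to(torch.float32)
    err = (a_t - b_t).abs().max().item()
    assert err < atol, f'max abs error {err} exceeds tolerance {atol}'
    return err
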
def TransitionTest(args, config, global_config):
    feat, params, res = load_data(args, 'Transition')
    conf = config.model.embeddings_and_evoformer.evoformer.pair_transition
    global_config.subbatch_size = 2

    attn_opt = TransitionFF(conf, global_config, num_channel=feat['seq_act'].shape[-1])
    attn_opt.load_weights_from_af2(params, 'transition_block')
    attn_vanilla = TransitionOpt(conf, global_config, num_channel=feat['seq_act'].shape[-1])
    attn_vanilla.load_weights_from_af2(params, 'transition_block')

    attn_vanilla.cuda()
    feat['seq_act'] = feat['seq_act'].to(device='cuda', dtype=torch.float32)
    feat['seq_mask'] = feat['seq_mask'].to(device='cuda', dtype=torch.float32)

    alloc_start_vanilla = get_total_alloc()
    handler_vanilla = torch.profiler.tensorboard_trace_handler(Path('Log') / Path('Transition'))
    with torch.profiler.profile(on_trace_ready=handler_vanilla, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
        res_vanilla = attn_vanilla(feat['seq_act'], feat['seq_mask'])
        profiler.step()
    alloc_end_vanilla = get_total_alloc()

    attn_opt.cuda()
    alloc_start_opt = get_total_alloc()
    handler_opt = torch.profiler.tensorboard_trace_handler(Path('Log') / Path('TransitionOpt'))
    with torch.profiler.profile(on_trace_ready=handler_opt, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
        res_opt = attn_opt(feat['seq_act'], feat['seq_mask'])
        profiler.step()
    alloc_end_opt = get_total_alloc()

    check_recursive(res_opt, res_vanilla)
    print(f'Mem vanilla: {mem_to_str(alloc_end_vanilla-alloc_start_vanilla)} \t opt: {mem_to_str(alloc_end_opt-alloc_start_opt)}')

def EvoformerIterationTest(args, config, global_config):
    feat, params, res = load_data(args, 'EvoformerIteration1')
    conf = config.model.embeddings_and_evoformer.evoformer

    attn_vanilla = EvoformerIterationOpt(conf, global_config, msa_dim=feat['msa_act'].shape[-1], pair_dim=feat['pair_act'].shape[-1], is_extra_msa=False)
    attn_vanilla.load_weights_from_af2(params, rel_path='evoformer_iteration')
    attn_opt = EvoformerIterationFF(conf, global_config, msa_dim=feat['msa_act'].shape[-1], pair_dim=feat['pair_act'].shape[-1], is_extra_msa=False)
    attn_opt.load_weights_from_af2(params, rel_path='evoformer_iteration')

    feat['msa_act'] = feat['msa_act'].to(device='cuda', dtype=torch.float32)
    feat['pair_act'] = feat['pair_act'].to(device='cuda', dtype=torch.float32)
    feat['msa_mask'] = feat['msa_mask'].to(device='cuda', dtype=torch.float32)
    feat['pair_mask'] = feat['pair_mask'].to(device='cuda', dtype=torch.float32)

    attn_vanilla.cuda()
    alloc_start_vanilla = get_total_alloc()
    handler_vanilla = torch.profiler.tensorboard_trace_handler(Path('Log') / Path('EvoformerIteration'))
    with torch.profiler.profile(on_trace_ready=handler_vanilla, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
        res_vanilla = attn_vanilla(msa_act=feat['msa_act'], pair_act=feat['pair_act'], msa_mask=feat['msa_mask'], pair_mask=feat['pair_mask'], is_training=False)
        profiler.step()
    alloc_end_vanilla = get_total_alloc()

    attn_opt.cuda()
    alloc_start_opt = get_total_alloc()
    handler_opt = torch.profiler.tensorboard_trace_handler(Path('Log') / Path('EvoformerIterationOpt'))
    with torch.profiler.profile(on_trace_ready=handler_opt, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
        res_opt = attn_opt(msa_act=feat['msa_act'], pair_act=feat['pair_act'], msa_mask=feat['msa_mask'], pair_mask=feat['pair_mask'], is_training=False)
        profiler.step()
    alloc_end_opt = get_total_alloc()

    check_recursive(res_opt, res_vanilla)
    print(f'Mem vanilla: {mem_to_str(alloc_end_vanilla-alloc_start_vanilla)} \t opt: {mem_to_str(alloc_end_opt-alloc_start_opt)}')

def AlphaFoldTest(args, config):
    batch, params, res = load_data(args, 'AlphaFold')
    conf = config.model

    for key in params.keys():
        print(key)
        for param in params[key].keys():
            print('\t' + param + ' ' + str(params[key][param].shape))
    for key in batch.keys():
        print(key, batch[key].shape)

    conf.embeddings_and_evoformer.recycle_pos = False
    conf.embeddings_and_evoformer.recycle_features = False
    conf.embeddings_and_evoformer.template.enabled = False
    conf.embeddings_and_evoformer.evoformer_num_block = 1
    conf.embeddings_and_evoformer.extra_msa_stack_num_block = 1
    conf.num_recycle = 0
    conf.resample_msa_in_recycling = False
    conf.global_config.deterministic = True

    attn = AlphaFold(conf, num_res=batch['target_feat'].shape[-2], target_dim=batch['target_feat'].shape[-1], msa_dim=batch['msa_feat'].shape[-1], extra_msa_dim=25)
    attn.load_weights_from_af2(params, rel_path='alphafold')

    with torch.no_grad():
        this_res = attn(batch, is_training=False)
    print(this_res)
    check_recursive(res, this_res)

def TriangleMultiplicationTest(args, config, global_config):
    feat, params, res = load_data(args, 'TriangleMultiplication')
    # conf = config.model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing
    conf = config.model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming

    attn_opt = TriangleMultiplicationFF(conf, global_config, pair_dim=feat['pair_act'].shape[-1])
    attn_opt.load_weights_from_af2(params, 'triangle_multiplication')
    attn_vanilla = TriangleMultiplicationOpt(conf, global_config, pair_dim=feat['pair_act'].shape[-1])
    attn_vanilla.load_weights_from_af2(params, 'triangle_multiplication')

    attn_vanilla.cuda()
    feat['pair_act'] = feat['pair_act'].to(device='cuda', dtype=torch.float32)
    feat['pair_mask'] = feat['pair_mask'].to(device='cuda', dtype=torch.float32)

    alloc_start_vanilla = get_total_alloc()
    handler_vanilla = torch.profiler.tensorboard_trace_handler(Path('Log') / Path('TriangleMultiplication'))
    with torch.profiler.profile(on_trace_ready=handler_vanilla, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
        res_vanilla = attn_vanilla(feat['pair_act'], feat['pair_mask'], is_training=False)
        profiler.step()
    alloc_end_vanilla = get_total_alloc()

    attn_opt.cuda()
    alloc_start_opt = get_total_alloc()
    handler_opt = torch.profiler.tensorboard_trace_handler(Path('Log') / Path('TriangleMultiplicationOpt'))
    with torch.profiler.profile(on_trace_ready=handler_opt, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
        res_opt = attn_opt(feat['pair_act'], feat['pair_mask'], is_training=False)
        profiler.step()
    alloc_end_opt = get_total_alloc()

    if isinstance(attn_opt, TriangleMultiplicationFF):
        # The FF variant returns its output with the residual connection already applied,
        # so the input activation is added to the vanilla result before comparing.
        check_recursive(res_opt, res_vanilla + feat['pair_act'])
    else:
        check_recursive(res_opt, res_vanilla)
    print(f'Mem vanilla: {mem_to_str(alloc_end_vanilla-alloc_start_vanilla)} \t opt: {mem_to_str(alloc_end_opt-alloc_start_opt)}')

def EvoformerIterationTest(args, config, global_config, is_training: bool = False):
    feat, params, res = load_data(args, 'EvoformerIteration1')
    conf = config.model.embeddings_and_evoformer.evoformer
    conf.msa_row_attention_with_pair_bias.dropout_rate = 0.0
    conf.msa_column_attention.dropout_rate = 0.0
    conf.triangle_attention_starting_node.dropout_rate = 0.0
    conf.triangle_attention_ending_node.dropout_rate = 0.0
    conf.triangle_multiplication_outgoing.dropout_rate = 0.0
    conf.triangle_multiplication_incoming.dropout_rate = 0.0
    conf.outer_product_mean.dropout_rate = 0.0
    conf.pair_transition.dropout_rate = 0.0

    attn_batch = EvoformerIterationFFB(conf, global_config, msa_dim=feat['msa_act'].shape[-1], pair_dim=feat['pair_act'].shape[-1], is_extra_msa=False)
    attn_batch.load_weights_from_af2(params, rel_path='evoformer_iteration')
    attn_single = EvoformerIterationFF(conf, global_config, msa_dim=feat['msa_act'].shape[-1], pair_dim=feat['pair_act'].shape[-1], is_extra_msa=False)
    attn_single.load_weights_from_af2(params, rel_path='evoformer_iteration')

    batch_size = 8
    feat['msa_act'] = feat['msa_act'].to(device='cuda', dtype=torch.float32)
    feat['pair_act'] = feat['pair_act'].to(device='cuda', dtype=torch.float32)
    feat['msa_mask'] = feat['msa_mask'].to(device='cuda', dtype=torch.float32)
    feat['pair_mask'] = feat['pair_mask'].to(device='cuda', dtype=torch.float32)
    batch_msa_mask = feat['msa_mask'][None, ...].repeat(batch_size, 1, 1)
    batch_msa_act = feat['msa_act'][None, ...].repeat(batch_size, 1, 1, 1)
    batch_pair_mask = feat['pair_mask'][None, ...].repeat(batch_size, 1, 1)
    batch_pair_act = feat['pair_act'][None, ...].repeat(batch_size, 1, 1, 1)

    attn_single.cuda()
    attn_batch.cuda()

    res_single_msa, res_single_pair = attn_single(msa_act=feat['msa_act'], pair_act=feat['pair_act'], msa_mask=feat['msa_mask'], pair_mask=feat['pair_mask'], is_training=is_training)
    res_batch_msa, res_batch_pair = attn_batch(msa_act=batch_msa_act, pair_act=batch_pair_act, msa_mask=batch_msa_mask, pair_mask=batch_pair_mask, is_training=is_training)

    for i in range(batch_size):
        err_msa = torch.sum(torch.abs(res_batch_msa[i, ...] - res_single_msa))
        err_pair = torch.sum(torch.abs(res_batch_pair[i, ...] - res_single_pair))
        print(i, err_msa.item(), err_pair.item())
        assert (err_msa < 1e-5) and (err_pair < 1e-5)

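# The batched-vs-single parity tests (above and below) all tile a single example along a new
# leading batch dimension via x[None, ...].repeat(batch_size, 1, ..., 1). A small hedged helper
# for the same pattern, written against generic tensors and not part of the repo's API:
def tile_batch(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Repeat one example batch_size times along a new leading dimension."""
    return x.unsqueeze(0).repeat(batch_size, *([1] * x.dim()))
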
def MSAColumnAttentionTest(args, config, global_config, is_training=False):
    feat, params, res = load_data(args, 'MSAColumnAttention')
    conf = config.model.embeddings_and_evoformer.evoformer.msa_column_attention
    conf.dropout_rate = 0.0

    attn_batch = MSAColumnAttentionFFB(conf, global_config, msa_dim=feat['msa_act'].shape[-1])
    attn_batch.load_weights_from_af2(params, rel_path='msa_column_attention')
    attn_single = MSAColumnAttentionFF(conf, global_config, msa_dim=feat['msa_act'].shape[-1])
    attn_single.load_weights_from_af2(params, rel_path='msa_column_attention')
    attn_batch.cuda()
    attn_single.cuda()

    batch_size = 8
    feat['msa_act'] = feat['msa_act'].to(device='cuda', dtype=torch.float32)
    feat['msa_mask'] = feat['msa_mask'].to(device='cuda', dtype=torch.float32)
    batch_msa_act = feat['msa_act'][None, ...].repeat(batch_size, 1, 1, 1)
    batch_msa_mask = feat['msa_mask'][None, ...].repeat(batch_size, 1, 1)

    res_batch = attn_batch(batch_msa_act, batch_msa_mask, is_training=is_training)
    res_single = attn_single(feat['msa_act'], feat['msa_mask'], is_training=is_training)

    for i in range(batch_size):
        err = torch.sum(torch.abs(res_batch[i, ...] - res_single))
        print(i, err.item())
        assert err < 1e-5

def EmbeddingsAndEvoformerTest(args, config, global_config, cuda: bool = False):
    feat, params, res = load_data(args, 'EmbeddingsAndEvoformer')
    conf = config.model.embeddings_and_evoformer

    for key in params.keys():
        print(key)
        for param in params[key].keys():
            print('\t' + param + ' ' + str(params[key][param].shape))
    for key in feat.keys():
        print(key, feat[key].shape)

    conf.template.enabled = False
    conf.recycle_pos = False
    conf.recycle_features = False
    conf.evoformer_num_block = 1
    conf.extra_msa_stack_num_block = 1
    global_config.deterministic = True

    attn = EmbeddingsAndEvoformer(conf, global_config, target_dim=feat['target_feat'].shape[-1], msa_dim=feat['msa_feat'].shape[-1], extra_msa_dim=25, clear_cache=False)
    attn.load_weights_from_af2(params, rel_path='evoformer')

    if cuda:
        attn = attn.cuda()
        for key in feat.keys():
            if isinstance(feat[key], torch.Tensor):
                feat[key] = feat[key].to(device='cuda')

    this_res = attn(feat, is_training=False)
    check_recursive(this_res, res)

def TransitionTest(args, config, global_config, is_training: bool = False):
    feat, params, res = load_data(args, 'Transition')
    conf = config.model.embeddings_and_evoformer.evoformer.pair_transition
    conf.dropout_rate = 0.0
    global_config.subbatch_size = 2

    attn_batch = TransitionFFB(conf, global_config, num_channel=feat['seq_act'].shape[-1])
    attn_batch.load_weights_from_af2(params, 'transition_block')
    attn_single = TransitionFF(conf, global_config, num_channel=feat['seq_act'].shape[-1])
    attn_single.load_weights_from_af2(params, 'transition_block')
    attn_single.cuda()
    attn_batch.cuda()

    batch_size = 8
    feat['seq_act'] = feat['seq_act'].to(device='cuda', dtype=torch.float32)
    feat['seq_mask'] = feat['seq_mask'].to(device='cuda', dtype=torch.float32)
    batch_seq_mask = feat['seq_mask'][None, ...].repeat(batch_size, 1, 1)
    batch_seq_act = feat['seq_act'][None, ...].repeat(batch_size, 1, 1, 1)

    res_single = attn_single(feat['seq_act'], feat['seq_mask'], is_training=is_training)
    res_batch = attn_batch(batch_seq_act, batch_seq_mask, is_training=is_training)

    for i in range(batch_size):
        err = torch.sum(torch.abs(res_batch[i, ...] - res_single))
        print(i, err.item())
        assert err < 1e-5

def test_frames_and_literature_positions_to_atom14_pos(args):
    print('test_frames_and_literature_positions_to_atom14_pos')
    (activations, aatype, torsion_angles_sin_cos), res = load_data(args, 'test_frames_and_literature_positions_to_atom14_pos')

    rigs = QuatAffine.from_tensor(activations).to_rigids()
    all_frames = torsion_angles_to_frames(aatype=aatype, backb_to_global=rigs, torsion_angles_sin_cos=torsion_angles_sin_cos)
    this_res = frames_and_literature_positions_to_atom14_pos(aatype=aatype, all_frames_to_global=all_frames)
    this_res = affine.vecs_to_tensor(this_res)
    print(check_recursive(this_res, res))

def MSARowAttentionWithPairBiasTest(args, config, global_config, is_training=False):
    feat, params, res = load_data(args, 'MSARowAttentionWithPairBias')
    for key in params.keys():
        print(key)
        for param in params[key].keys():
            print('\t' + param)

    conf = config.model.embeddings_and_evoformer.evoformer.msa_row_attention_with_pair_bias
    conf.dropout_rate = 0.0

    attn_single = MSARowAttentionWithPairBiasFF(conf, global_config, pair_dim=feat['pair_act'].shape[-1], msa_dim=feat['msa_act'].shape[-1])
    attn_single.load_weights_from_af2(params, rel_path='msa_row_attention_with_pair_bias')
    attn_batch = MSARowAttentionWithPairBiasFFB(conf, global_config, pair_dim=feat['pair_act'].shape[-1], msa_dim=feat['msa_act'].shape[-1])
    attn_batch.load_weights_from_af2(params, rel_path='msa_row_attention_with_pair_bias')
    attn_single.cuda()
    attn_batch.cuda()

    batch_size = 8
    feat['msa_act'] = feat['msa_act'].to(device='cuda', dtype=torch.float32)  # [:63,:,:]
    feat['pair_act'] = feat['pair_act'].to(device='cuda', dtype=torch.float32)  # [:63,:,:]
    feat['msa_mask'] = feat['msa_mask'].to(device='cuda', dtype=torch.float32)  # [:63,:]
    batch_msa_act = feat['msa_act'][None, ...].repeat(batch_size, 1, 1, 1)
    batch_pair_act = feat['pair_act'][None, ...].repeat(batch_size, 1, 1, 1)
    batch_msa_mask = feat['msa_mask'][None, ...].repeat(batch_size, 1, 1)

    print(feat['pair_act'].size(), batch_pair_act.size())
    print(feat['msa_act'].size(), batch_msa_act.size())
    print(feat['msa_mask'].size(), batch_msa_mask.size())

    res_single = attn_single(feat['msa_act'], feat['msa_mask'], feat['pair_act'], is_training=is_training)
    res_batch = attn_batch(batch_msa_act, batch_msa_mask, batch_pair_act, is_training=is_training)

    for i in range(batch_size):
        err = torch.sum(torch.abs(res_batch[i, ...] - res_single))
        print(i, err.item())
        assert err < 1e-5

def MSAColumnGlobalAttentionTest(args, config, global_config):
    feat, params, res = load_data(args, 'MSAColumnGlobalAttention')
    conf = config.model.embeddings_and_evoformer.evoformer.msa_column_attention
    attn = MSAColumnGlobalAttention(conf, global_config, msa_dim=feat['msa_act'].shape[-1])
    attn.load_weights_from_af2(params, rel_path='msa_column_global_attention')
    this_res = attn(feat['msa_act'], feat['msa_mask'])
    check_recursive(this_res, res)

def TriangleMultiplicationIncomingTest(args, config, global_config):
    feat, params, res = load_data(args, 'TriangleMultiplicationIncoming')
    conf = config.model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming
    attn = TriangleMultiplication(conf, global_config, pair_dim=feat['pair_act'].shape[-1])
    attn.load_weights_from_af2(params, rel_path='triangle_multiplication')
    this_res = attn(feat['pair_act'], feat['pair_mask'])
    check_recursive(this_res, res)

def frame_aligned_point_error_test(args):
    (activations1, activations2, frames_mask, pos_mask), res = load_data(args, 'frame_aligned_point_error_test')

    qa1 = affine.QuatAffine.from_tensor(activations1)
    rigs1 = qa1.to_rigids()
    qa2 = affine.QuatAffine.from_tensor(activations2)
    rigs2 = qa2.to_rigids()

    this_res = frame_aligned_point_error(rigs1, rigs2, frames_mask, rigs1.trans, rigs2.trans, pos_mask, 1.0)
    check_recursive(res, this_res)

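# For orientation: frame_aligned_point_error (FAPE), as defined in AlphaFold, expresses the
# predicted and target points in every predicted/target local frame and averages the masked
# distances between them,
#   FAPE = mean_{i,j} (1 / length_scale) * || T_i^{-1} o x_j  -  T'_i^{-1} o x'_j ||,
# where T_i / T'_i are the predicted/target frames and x_j / x'_j the predicted/target points.
# The test above passes the frames' own translations as the points and length_scale = 1.0;
# clamping and epsilon handling are left to the repo's implementation and not restated here.
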
def GlobalAttentionTest(args, config, global_config):
    feat, params, res = load_data(args, 'GlobalAttention')
    conf = config.model.embeddings_and_evoformer.evoformer.msa_row_attention_with_pair_bias

    attn_opt = GlobalAttentionOpt(conf, global_config, output_dim=256, key_dim=feat['q_data'].shape[-1], value_dim=feat['m_data'].shape[-1])
    attn_opt.load_weights_from_af2(params['attention'], None)
    attn_vanilla = GlobalAttention(conf, global_config, output_dim=256, key_dim=feat['q_data'].shape[-1], value_dim=feat['m_data'].shape[-1])
    attn_vanilla.load_weights_from_af2(params['attention'], None)

    attn_vanilla.cuda()
    feat['q_data'] = feat['q_data'].to(device='cuda', dtype=torch.float32)
    feat['m_data'] = feat['m_data'].to(device='cuda', dtype=torch.float32)
    feat['q_mask'] = feat['q_mask'].to(device='cuda', dtype=torch.float32)

    handler_vanilla = torch.profiler.tensorboard_trace_handler(Path('Log') / Path('GlobalAttention'))
    with torch.profiler.profile(on_trace_ready=handler_vanilla, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
        res_vanilla = attn_vanilla(q_data=feat['q_data'], m_data=feat['m_data'], q_mask=feat['q_mask'].to(dtype=torch.float32), bias=feat['bias'])
        profiler.step()

    attn_opt.cuda()
    reporter = MemReporter()
    handler_opt = torch.profiler.tensorboard_trace_handler(Path('Log') / Path('GlobalAttentionOpt'))
    with torch.profiler.profile(on_trace_ready=handler_opt, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
        res_opt = attn_opt(q_data=feat['q_data'], m_data=feat['m_data'], q_mask=feat['q_mask'].to(dtype=torch.float32), bias=feat['bias'])
        profiler.step()
    reporter.report()

    check_recursive(res_opt, res_vanilla)

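# Besides the TensorBoard traces written by tensorboard_trace_handler above, a torch.profiler
# run can be summarized in-process. A minimal sketch using only the standard PyTorch profiler
# API (independent of the repo's helpers and not used by the tests themselves):
def print_profile_summary(module, *inputs):
    """Profile one forward pass and print the top ops by self CUDA memory usage."""
    with torch.profiler.profile(profile_memory=True, record_shapes=True) as prof:
        module(*inputs)
    print(prof.key_averages().table(sort_by='self_cuda_memory_usage', row_limit=10))
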
def InvariantPointAttentionTest(args, config, global_config):
    print('InvariantPointAttentionTest')
    feat, params, res = load_data(args, 'InvariantPointAttention')
    conf = config.model.heads.structure_module

    attn = InvariantPointAttention(conf, global_config, num_res=feat['inputs_1d'].shape[-2], num_feat_1d=feat['inputs_1d'].shape[-1], num_feat_2d=feat['inputs_2d'].shape[-1])
    attn.load_weights_from_af2(params, rel_path='invariant_point_attention')

    qa = QuatAffine.from_tensor(feat['activations'].to(dtype=torch.float32))
    this_res = attn(inputs_1d=feat['inputs_1d'], inputs_2d=feat['inputs_2d'], mask=feat['mask'], affine=qa)
    print(check_recursive(this_res, res))

def TransitionTest(args, config, global_config):
    feat, params, res = load_data(args, 'Transition')
    conf = config.model.embeddings_and_evoformer.evoformer.pair_transition
    attn = Transition(conf, global_config, num_channel=feat['seq_act'].shape[-1])

    for key in params.keys():
        print(key)
        for param in params[key].keys():
            print('\t' + param)

    attn.load_weights_from_af2(params, rel_path='transition_block')
    this_res = attn(feat['seq_act'], feat['seq_mask'])
    check_recursive(this_res, res)

def TriangleAttentionTest(args, config, global_config):
    feat, params, res = load_data(args, 'TriangleAttention')
    conf = config.model.embeddings_and_evoformer.evoformer.triangle_attention_starting_node
    attn = TriangleAttention(conf, global_config, pair_dim=feat['pair_act'].shape[-1])

    for key in params.keys():
        print(key)
        for param in params[key].keys():
            print('\t' + param)

    attn.load_weights_from_af2(params, rel_path='triangle_attention')
    this_res = attn(feat['pair_act'], feat['pair_mask'])
    check_recursive(this_res, res)

def InvariantPointAttentionTest(args, config, global_config):
    print('InvariantPointAttentionTest')
    feat, params, res = load_data(args, 'InvariantPointAttention')
    conf = config.model.heads.structure_module

    attn_single = InvariantPointAttention(conf, global_config, num_feat_1d=feat['inputs_1d'].shape[-1], num_feat_2d=feat['inputs_2d'].shape[-1])
    attn_single.load_weights_from_af2(params, rel_path='invariant_point_attention')
    attn_batch = InvariantPointAttentionB(conf, global_config, num_feat_1d=feat['inputs_1d'].shape[-1], num_feat_2d=feat['inputs_2d'].shape[-1])
    attn_batch.load_weights_from_af2(params, rel_path='invariant_point_attention')

    print('inputs1d:', feat['inputs_1d'].size())
    print('inputs2d:', feat['inputs_2d'].size())
    print('activations:', feat['activations'].size())
    print('mask:', feat['mask'].size())

    batch_size = 8
    inputs_1d_batch = feat['inputs_1d'][None, ...].repeat(batch_size, 1, 1)
    inputs_2d_batch = feat['inputs_2d'][None, ...].repeat(batch_size, 1, 1, 1)
    activations_batch = feat['activations'][None, ...].repeat(batch_size, 1, 1)
    mask_batch = feat['mask'][None, ...].repeat(batch_size, 1, 1)

    qa_single = QuatAffine.from_tensor(feat['activations'].to(dtype=torch.float32))
    qa_batch = QuatAffine.from_tensor(activations_batch.to(dtype=torch.float32))

    res_single = attn_single(inputs_1d=feat['inputs_1d'], inputs_2d=feat['inputs_2d'], mask=feat['mask'], affine=qa_single)
    res_batch = attn_batch(inputs_1d=inputs_1d_batch, inputs_2d=inputs_2d_batch, mask=mask_batch, affine=qa_batch)

    print(check_recursive(res_single, res))
    print(check_recursive(res_batch[0, ...], res))

    for i in range(batch_size):
        err = torch.sum(torch.abs(res_batch[i, ...] - res_single))
        print(i, err.item())
        assert err < 1e-2

def MSAColumnAttentionTest(args, config, global_config, is_training=False):
    feat, params, res = load_data(args, 'MSAColumnAttention')
    conf = config.model.embeddings_and_evoformer.evoformer.msa_column_attention

    attn_opt = MSAColumnAttentionFF(conf, global_config, msa_dim=feat['msa_act'].shape[-1])
    attn_opt.load_weights_from_af2(params, rel_path='msa_column_attention')
    attn_vanilla = MSAColumnAttentionOpt(conf, global_config, msa_dim=feat['msa_act'].shape[-1])
    attn_vanilla.load_weights_from_af2(params, rel_path='msa_column_attention')

    attn_vanilla.cuda()
    feat['msa_act'] = feat['msa_act'].to(device='cuda', dtype=torch.float32)
    feat['msa_mask'] = feat['msa_mask'].to(device='cuda', dtype=torch.float32)

    reporter = MemReporter()
    handler_vanilla = torch.profiler.tensorboard_trace_handler(Path('Log') / Path('MSAColumnAttention'))
    with torch.profiler.profile(on_trace_ready=handler_vanilla, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
        res_vanilla = attn_vanilla(feat['msa_act'], feat['msa_mask'], is_training=is_training)
        profiler.step()
    # reporter.report()

    attn_opt.cuda()
    reporter = MemReporter()
    handler_opt = torch.profiler.tensorboard_trace_handler(Path('Log') / Path('MSAColumnAttentionOpt'))
    with torch.profiler.profile(on_trace_ready=handler_opt, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
        res_opt = attn_opt(feat['msa_act'], feat['msa_mask'], is_training=is_training)
        profiler.step()
    # reporter.report()

    if isinstance(attn_opt, MSAColumnAttentionFF):
        # The FF variant returns its output with the residual connection already applied,
        # so the input activation is added to the vanilla result before comparing.
        check_recursive(res_opt, res_vanilla + feat['msa_act'])
    else:
        check_recursive(res_opt, res_vanilla)

def DistogramHeadTest(args, config, global_config):
    print('DistogramHeadTest')
    feat, params, res = load_data(args, 'DistogramHead')
    conf = config.model.heads.distogram
    representations = feat['representations']
    batch = feat['batch']

    attn = DistogramHead(conf, global_config, num_feat_2d=representations['pair'].shape[-1])
    attn.load_weights_from_af2(params, rel_path='distogram_head')
    this_res = attn(representations, batch)
    print(check_recursive(this_res, res))

def OuterProductMeanTest(args, config, global_config):
    feat, params, res = load_data(args, 'OuterProductMean')
    conf = config.model.embeddings_and_evoformer.evoformer.outer_product_mean
    attn = OuterProductMean(conf, global_config, msa_dim=feat['msa_act'].shape[-1], num_output_channel=256)

    for key in params.keys():
        print(key)
        for param in params[key].keys():
            print('\t' + param)

    attn.load_weights_from_af2(params, rel_path='outer_product_mean')
    this_res = attn(feat['msa_act'], feat['msa_mask'])
    check_recursive(this_res, res)

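# The OuterProductMean module tested above turns MSA activations [N_seq, N_res, C] into a pair
# update [N_res, N_res, C_out]. Its core reduction is a masked outer product averaged over
# sequences; the sketch below shows only that reduction, with the projections, layer norm and
# output linear layer of the real module deliberately omitted (an assumption-level sketch, not
# the repo's implementation).
def outer_product_mean_core(left: torch.Tensor, right: torch.Tensor, mask: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """left, right: [N_seq, N_res, C_hidden]; mask: [N_seq, N_res] -> [N_res, N_res, C_hidden, C_hidden]."""
    left = left * mask[..., None]
    right = right * mask[..., None]
    outer = torch.einsum('sic,sjd->ijcd', left, right)
    norm = torch.einsum('si,sj->ij', mask, mask)[..., None, None]
    return outer / (norm + eps)
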
def create_extra_msa_features_test(args, config, global_config):
    feat, params, res = load_data(args, 'create_extra_msa_feature')
    conf = config.model.embeddings_and_evoformer

    # for key in params.keys():
    #     print(key)
    #     for param in params[key].keys():
    #         print('\t' + param + ' ' + str(params[key][param].shape))
    for key in feat.keys():
        print(key, feat[key].shape)

    emb = ExtraMSAEmbedding(conf, global_config, msa_dim=feat['msa_feat'].shape[-1])
    this_res = emb.create_extra_msa_features(feat)
    check_recursive(this_res, res)

def ExperimentallyResolvedHeadTest(args, config, global_config):
    print('ExperimentallyResolvedHeadTest')
    feat, params, res = load_data(args, 'ExperimentallyResolvedHead')
    conf = config.model.heads.experimentally_resolved
    representations = feat['representations']
    batch = feat['batch']

    attn = ExperimentallyResolvedHead(conf, global_config, num_feat_1d=representations['single'].shape[-1])
    attn.load_weights_from_af2(params, rel_path='experimentally_resolved_head')
    this_res = attn(representations, batch)
    print(check_recursive(this_res, res))

def MaskedMSAHeadTest(args, config, global_config):
    print('MaskedMSAHeadTest')
    feat, params, res = load_data(args, 'MaskedMSAHead')
    conf = config.model.heads.masked_msa
    representations = feat['representations']
    batch = feat['batch']

    attn = MaskedMSAHead(conf, global_config, num_feat_2d=representations['msa'].shape[-1])
    attn.load_weights_from_af2(params, rel_path='masked_msa_head')
    this_res = attn(representations, batch)
    print(check_recursive(this_res, res))

def PredictedAlignedErrorHeadTest(args, config, global_config):
    print('PredictedAlignedErrorHeadTest')
    feat, params, res = load_data(args, 'PredictedAlignedErrorHead')
    conf = config.model.heads.predicted_aligned_error
    representations = feat['representations']
    batch = feat['batch']

    attn = PredictedAlignedErrorHead(conf, global_config, num_feat_2d=representations['pair'].shape[-1])
    attn.load_weights_from_af2(params, rel_path='predicted_aligned_error_head')
    this_res = attn(representations, batch)
    print(check_recursive(this_res, res))

def GlobalAttentionTest(args, config, global_config):
    feat, params, res = load_data(args, 'GlobalAttention')
    conf = config.model.embeddings_and_evoformer.evoformer.msa_row_attention_with_pair_bias
    attn = GlobalAttention(conf, global_config, output_dim=256, key_dim=feat['q_data'].shape[-1], value_dim=feat['m_data'].shape[-1])
    attn.load_weights_from_af2(params['attention'], None)
    this_res = attn(q_data=feat['q_data'], m_data=feat['m_data'], q_mask=feat['q_mask'], bias=feat['bias'])
    check_recursive(this_res, res)

def MSARowAttentionWithPairBiasTest(args, config, global_config):
    feat, params, res = load_data(args, 'MSARowAttentionWithPairBias')
    # for key in params.keys():
    #     print(key)
    #     for param in params[key].keys():
    #         print('\t' + param)
    conf = config.model.embeddings_and_evoformer.evoformer.msa_row_attention_with_pair_bias
    # conf.gating = False
    attn = MSARowAttentionWithPairBias(conf, global_config, pair_dim=feat['pair_act'].shape[-1], msa_dim=feat['msa_act'].shape[-1])
    attn.load_weights_from_af2(params, rel_path='msa_row_attention_with_pair_bias')
    this_res = attn(feat['msa_act'], feat['msa_mask'], feat['pair_act'])
    check_recursive(this_res, res)

def StructureModuleTest(args, config, global_config):
    print('StructureModuleTest')
    feat, params, res = load_data(args, 'StructureModule')
    conf = config.model.heads.structure_module
    representations = feat['representations']
    batch = feat['batch']

    attn = StructureModule(conf, global_config, num_res=representations['single'].shape[-2], num_feat_1d=representations['single'].shape[-1], num_feat_2d=representations['pair'].shape[-1])
    attn.load_weights_from_af2(params, rel_path='structure_module')
    this_res = attn(representations, batch)
    print(check_recursive(this_res, res))