Example No. 1
def OuterProductMeanTest(args, config, global_config):
	feat, params, res = load_data(args, 'OuterProductMean')
		
	conf = config.model.embeddings_and_evoformer.evoformer.outer_product_mean
	attn_opt = OuterProductMeanFF(conf, global_config, msa_dim=feat['msa_act'].shape[-1], num_output_channel=256)
	attn_opt.load_weights_from_af2(params, 'outer_product_mean')
	attn_vanilla = OuterProductMeanOpt(conf, global_config, msa_dim=feat['msa_act'].shape[-1], num_output_channel=256)
	attn_vanilla.load_weights_from_af2(params, 'outer_product_mean')
	
	attn_vanilla.cuda()
	feat['msa_act'] = feat['msa_act'].to(device='cuda', dtype=torch.float32)[:,:127,:]
	feat['msa_mask'] = feat['msa_mask'].to(device='cuda', dtype=torch.float32)[:,:127]
		
			
	alloc_start_vanilla = get_total_alloc()
	handler_vanilla = torch.profiler.tensorboard_trace_handler(Path('Log')/Path('OuterProductMean'))
	with torch.profiler.profile(on_trace_ready=handler_vanilla, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
		res_vanilla = attn_vanilla(feat['msa_act'], feat['msa_mask'])	
		profiler.step()
	alloc_end_vanilla = get_total_alloc()

	attn_opt.cuda()
	alloc_start_opt = get_total_alloc()
	handler_opt = torch.profiler.tensorboard_trace_handler(Path('Log')/Path('OuterProductMeanOpt'))
	with torch.profiler.profile(on_trace_ready=handler_opt, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
		res_opt = attn_opt(feat['msa_act'], feat['msa_mask'])
		profiler.step()
	alloc_end_opt = get_total_alloc()
	
	check_recursive(res_opt, res_vanilla)
	print(f'Mem vanilla: {mem_to_str(alloc_end_vanilla-alloc_start_vanilla)} \t opt: {mem_to_str(alloc_end_opt-alloc_start_opt)}')
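The helpers get_total_alloc and mem_to_str are not defined in these excerpts. A minimal sketch of what they presumably do (the names are kept, the implementations are assumptions based on how the benchmarks use them):

import torch

def get_total_alloc() -> int:
    # Assumed: bytes currently allocated by tensors on the active CUDA device.
    return torch.cuda.memory_allocated()

def mem_to_str(num_bytes: float) -> str:
    # Assumed: human-readable byte count for the printed memory report.
    for unit in ('b', 'Kb', 'Mb', 'Gb'):
        if abs(num_bytes) < 1024.0:
            return f'{num_bytes:.1f}{unit}'
        num_bytes /= 1024.0
    return f'{num_bytes:.1f}Tb'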
Example No. 2
def EmbeddingsAndEvoformerTest(args,
                               config,
                               global_config,
                               cuda: bool = False):
    feat, params, res = load_data(args, 'EmbeddingsAndEvoformer')
    conf = config.model.embeddings_and_evoformer
    for key in params.keys():
        print(key)
        for param in params[key].keys():
            print('\t' + param + '  ' + str(params[key][param].shape))
    for key in feat.keys():
        print(key, feat[key].shape)

    conf.template.enabled = False
    conf.recycle_pos = False
    conf.recycle_features = False
    conf.evoformer_num_block = 1
    conf.extra_msa_stack_num_block = 1
    global_config.deterministic = True
    attn = EmbeddingsAndEvoformer(conf,
                                  global_config,
                                  target_dim=feat['target_feat'].shape[-1],
                                  msa_dim=feat['msa_feat'].shape[-1],
                                  extra_msa_dim=25,
                                  clear_cache=False)
    attn.load_weights_from_af2(params, rel_path='evoformer')

    if cuda:
        attn = attn.cuda()
        for key in feat.keys():
            if isinstance(feat[key], torch.Tensor):
                feat[key] = feat[key].to(device='cuda')

    this_res = attn(feat, is_training=False)
    check_recursive(this_res, res)
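check_recursive appears in every test here; it compares a result against the reference output and, in the QuatAffine tests further down, returns (total, max, mean) error statistics. A minimal sketch under those assumptions:

import torch

def check_recursive(this_res, res, errs=None):
    # Assumed behaviour: walk nested dicts/lists/tuples and accumulate
    # element-wise absolute differences over all tensors found.
    if errs is None:
        errs = []
    if isinstance(this_res, dict):
        for key in this_res:
            check_recursive(this_res[key], res[key], errs)
    elif isinstance(this_res, (list, tuple)):
        for a, b in zip(this_res, res):
            check_recursive(a, b, errs)
    else:
        diff = torch.as_tensor(this_res).detach().cpu() - torch.as_tensor(res).detach().cpu()
        errs.append(torch.abs(diff).reshape(-1))
    flat = torch.cat(errs)
    return flat.sum().item(), flat.max().item(), flat.mean().item()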
Example No. 3
def TransitionTest(args, config, global_config):
	feat, params, res = load_data(args, 'Transition')

	conf = config.model.embeddings_and_evoformer.evoformer.pair_transition
	global_config.subbatch_size = 2
	attn_opt = TransitionFF(conf, global_config, num_channel=feat['seq_act'].shape[-1])
	attn_opt.load_weights_from_af2(params, 'transition_block')
	attn_vanilla = TransitionOpt(conf, global_config, num_channel=feat['seq_act'].shape[-1])
	attn_vanilla.load_weights_from_af2(params, 'transition_block')
	

	attn_vanilla.cuda()
	feat['seq_act'] = feat['seq_act'].to(device='cuda', dtype=torch.float32)
	feat['seq_mask'] = feat['seq_mask'].to(device='cuda', dtype=torch.float32)
		
	alloc_start_vanilla = get_total_alloc()
	handler_vanilla = torch.profiler.tensorboard_trace_handler(Path('Log')/Path('Transition'))
	with torch.profiler.profile(on_trace_ready=handler_vanilla, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
		res_vanilla = attn_vanilla(feat['seq_act'], feat['seq_mask'])	
		profiler.step()
	alloc_end_vanilla = get_total_alloc()

	attn_opt.cuda()
	alloc_start_opt = get_total_alloc()
	handler_opt = torch.profiler.tensorboard_trace_handler(Path('Log')/Path('TransitionOpt'))
	with torch.profiler.profile(on_trace_ready=handler_opt, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
		res_opt = attn_opt(feat['seq_act'], feat['seq_mask'])
		profiler.step()
	alloc_end_opt = get_total_alloc()
		
	check_recursive(res_opt, res_vanilla)
	print(f'Mem vanilla: {mem_to_str(alloc_end_vanilla-alloc_start_vanilla)} \t opt: {mem_to_str(alloc_end_opt-alloc_start_opt)}')
Example No. 4
def AlphaFoldTest(args, config):
	batch, params, res = load_data(args, 'AlphaFold')
	conf = config.model
	for key in params.keys():
		print(key)
		for param in params[key].keys():
			print('\t' + param + '  ' + str(params[key][param].shape))
	for key in batch.keys():
		print(key, batch[key].shape)

	conf.embeddings_and_evoformer.recycle_pos = False
	conf.embeddings_and_evoformer.recycle_features = False
	conf.embeddings_and_evoformer.template.enabled = False
	conf.embeddings_and_evoformer.evoformer_num_block = 1
	conf.embeddings_and_evoformer.extra_msa_stack_num_block = 1
	conf.num_recycle = 0
	conf.resample_msa_in_recycling = False
	conf.global_config.deterministic = True

	attn = AlphaFold(conf,
					num_res=batch['target_feat'].shape[-2],
					target_dim=batch['target_feat'].shape[-1], 
					msa_dim=batch['msa_feat'].shape[-1],
					extra_msa_dim=25)
	attn.load_weights_from_af2(params, rel_path='alphafold')
	with torch.no_grad():
		this_res = attn(batch, is_training=False)
	print(this_res)
	check_recursive(res, this_res)
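load_data is also external to these excerpts. A plausible sketch (the debug-dump directory attribute and file layout are assumptions, not part of the original code) is that it unpickles the recorded inputs, AlphaFold parameters, and reference output for the named test:

import pickle
from pathlib import Path

def load_data(args, name):
    # Assumed: args carries the path to a directory of per-test pickle dumps.
    with open(Path(args.debug_dir) / f'{name}.pkl', 'rb') as f:
        return pickle.load(f)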
Example No. 5
def EvoformerIterationTest(args, config, global_config):
	feat, params, res = load_data(args, 'EvoformerIteration1')
	conf = config.model.embeddings_and_evoformer.evoformer
	
	attn_vanilla = EvoformerIterationOpt(conf, global_config, msa_dim=feat['msa_act'].shape[-1], pair_dim=feat['pair_act'].shape[-1], is_extra_msa=False)
	attn_vanilla.load_weights_from_af2(params, rel_path='evoformer_iteration')
	
	attn_opt = EvoformerIterationFF(conf, global_config, msa_dim=feat['msa_act'].shape[-1], pair_dim=feat['pair_act'].shape[-1], is_extra_msa=False)
	attn_opt.load_weights_from_af2(params, rel_path='evoformer_iteration')
		
	feat['msa_act'] = feat['msa_act'].to(device='cuda',dtype=torch.float32)
	feat['pair_act'] = feat['pair_act'].to(device='cuda',dtype=torch.float32)
	feat['msa_mask'] = feat['msa_mask'].to(device='cuda',dtype=torch.float32)
	feat['pair_mask'] = feat['pair_mask'].to(device='cuda',dtype=torch.float32)

	attn_vanilla.cuda()
	alloc_start_vanilla = get_total_alloc()
	handler_vanilla = torch.profiler.tensorboard_trace_handler(Path('Log')/Path('EvoformerIteration'))
	with torch.profiler.profile(on_trace_ready=handler_vanilla, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
		res_vanilla = attn_vanilla(msa_act=feat['msa_act'], pair_act=feat['pair_act'], 
								msa_mask=feat['msa_mask'], pair_mask=feat['pair_mask'], is_training=False)
		profiler.step()
	alloc_end_vanilla = get_total_alloc()
	
	attn_opt.cuda()
	alloc_start_opt = get_total_alloc()
	handler_opt = torch.profiler.tensorboard_trace_handler(Path('Log')/Path('EvoformerIterationOpt'))
	with torch.profiler.profile(on_trace_ready=handler_opt, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
		res_opt = attn_opt(msa_act=feat['msa_act'], pair_act=feat['pair_act'], 
						msa_mask=feat['msa_mask'], pair_mask=feat['pair_mask'], is_training=False)
		profiler.step()
	alloc_end_opt = get_total_alloc()

	check_recursive(res_opt, res_vanilla)
	print(f'Mem vanilla: {mem_to_str(alloc_end_vanilla-alloc_start_vanilla)} \t opt: {mem_to_str(alloc_end_opt-alloc_start_opt)}')
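The profiler traces written by torch.profiler.tensorboard_trace_handler in these benchmarks end up under Log/<TestName>; they can be inspected by pointing TensorBoard at that directory (tensorboard --logdir Log), assuming the torch-tb-profiler plugin is installed.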
Example No. 6
def TriangleMultiplicationTest(args, config, global_config):
	feat, params, res = load_data(args, 'TriangleMultiplication')
		
	# conf = config.model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing
	conf = config.model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming
	attn_opt = TriangleMultiplicationFF(conf, global_config, pair_dim=feat['pair_act'].shape[-1])
	attn_opt.load_weights_from_af2(params, 'triangle_multiplication')
	attn_vanilla = TriangleMultiplicationOpt(conf, global_config, pair_dim=feat['pair_act'].shape[-1])
	attn_vanilla.load_weights_from_af2(params, 'triangle_multiplication')
	

	attn_vanilla.cuda()
	feat['pair_act'] = feat['pair_act'].to(device='cuda', dtype=torch.float32)
	feat['pair_mask'] = feat['pair_mask'].to(device='cuda', dtype=torch.float32)
	
	alloc_start_vanilla = get_total_alloc()
	handler_vanilla = torch.profiler.tensorboard_trace_handler(Path('Log')/Path('TriangleMultiplication'))
	with torch.profiler.profile(on_trace_ready=handler_vanilla, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
		res_vanilla = attn_vanilla(feat['pair_act'], feat['pair_mask'], is_training=False)
		profiler.step()
	alloc_end_vanilla = get_total_alloc()

	attn_opt.cuda()
	alloc_start_opt = get_total_alloc()
	handler_opt = torch.profiler.tensorboard_trace_handler(Path('Log')/Path('TriangleMultiplicationOpt'))
	with torch.profiler.profile(on_trace_ready=handler_opt, with_stack=True, with_modules=True, profile_memory=True, record_shapes=True) as profiler:
		res_opt = attn_opt(feat['pair_act'], feat['pair_mask'], is_training=False)
		profiler.step()
	alloc_end_opt = get_total_alloc()
	
	if isinstance(attn_opt, TriangleMultiplicationFF):	
		check_recursive(res_opt, res_vanilla + feat['pair_act'])
	else:
		check_recursive(res_opt, res_vanilla)
	print(f'Mem vanilla: {mem_to_str(alloc_end_vanilla-alloc_start_vanilla)} \t opt: {mem_to_str(alloc_end_opt-alloc_start_opt)}')
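Note the asymmetric comparison above: the FF variant evidently returns its output with the residual connection already folded in, so the test adds feat['pair_act'] back onto the vanilla result before comparing. The MSAColumnAttention test further down uses the same pattern.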
Example No. 7
def MSAColumnGlobalAttentionTest(args, config, global_config):
    feat, params, res = load_data(args, 'MSAColumnGlobalAttention')
    conf = config.model.embeddings_and_evoformer.evoformer.msa_column_attention
    attn = MSAColumnGlobalAttention(conf,
                                    global_config,
                                    msa_dim=feat['msa_act'].shape[-1])

    attn.load_weights_from_af2(params, rel_path='msa_column_global_attention')
    this_res = attn(feat['msa_act'], feat['msa_mask'])

    check_recursive(this_res, res)
Example No. 8
def TriangleMultiplicationIncomingTest(args, config, global_config):
    feat, params, res = load_data(args, 'TriangleMultiplicationIncoming')
    conf = config.model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming
    attn = TriangleMultiplication(conf,
                                  global_config,
                                  pair_dim=feat['pair_act'].shape[-1])

    attn.load_weights_from_af2(params, rel_path='triangle_multiplication')
    this_res = attn(feat['pair_act'], feat['pair_mask'])

    check_recursive(this_res, res)
Example No. 9
def frame_aligned_point_error_test(args):
    (activations1, activations2, frames_mask,
     pos_mask), res = load_data(args, 'frame_aligned_point_error_test')
    qa1 = affine.QuatAffine.from_tensor(activations1)
    rigs1 = qa1.to_rigids()
    qa2 = affine.QuatAffine.from_tensor(activations2)
    rigs2 = qa2.to_rigids()

    this_res = frame_aligned_point_error(rigs1, rigs2, frames_mask,
                                         rigs1.trans, rigs2.trans, pos_mask,
                                         1.0)
    check_recursive(res, this_res)
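For context, frame_aligned_point_error computes the FAPE metric from the AlphaFold 2 paper: the predicted positions are expressed in the local frame of each predicted residue, the true positions in the corresponding true frames, and the masked mean (optionally clamped) distance between the two sets is returned, divided by a length scale, which is presumably what the final 1.0 argument supplies here.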
Example No. 10
def GlobalAttentionTest(args, config, global_config):
    feat, params, res = load_data(args, 'GlobalAttention')

    conf = config.model.embeddings_and_evoformer.evoformer.msa_row_attention_with_pair_bias
    attn_opt = GlobalAttentionOpt(conf,
                                  global_config,
                                  output_dim=256,
                                  key_dim=feat['q_data'].shape[-1],
                                  value_dim=feat['m_data'].shape[-1])
    attn_opt.load_weights_from_af2(params['attention'], None)
    attn_vanilla = GlobalAttention(conf,
                                   global_config,
                                   output_dim=256,
                                   key_dim=feat['q_data'].shape[-1],
                                   value_dim=feat['m_data'].shape[-1])
    attn_vanilla.load_weights_from_af2(params['attention'], None)

    attn_vanilla.cuda()
    feat['q_data'] = feat['q_data'].to(device='cuda', dtype=torch.float32)
    feat['m_data'] = feat['m_data'].to(device='cuda', dtype=torch.float32)
    feat['q_mask'] = feat['q_mask'].to(device='cuda', dtype=torch.float32)
    handler_vanilla = torch.profiler.tensorboard_trace_handler(
        Path('Log') / Path('GlobalAttention'))
    with torch.profiler.profile(on_trace_ready=handler_vanilla,
                                with_stack=True,
                                with_modules=True,
                                profile_memory=True,
                                record_shapes=True) as profiler:
        res_vanilla = attn_vanilla(
            q_data=feat['q_data'],
            m_data=feat['m_data'],
            q_mask=feat['q_mask'].to(dtype=torch.float32),
            bias=feat['bias'])
        profiler.step()

    attn_opt.cuda()
    reporter = MemReporter()
    handler_opt = torch.profiler.tensorboard_trace_handler(
        Path('Log') / Path('GlobalAttentionOpt'))
    with torch.profiler.profile(on_trace_ready=handler_opt,
                                with_stack=True,
                                with_modules=True,
                                profile_memory=True,
                                record_shapes=True) as profiler:
        res_opt = attn_opt(q_data=feat['q_data'],
                           m_data=feat['m_data'],
                           q_mask=feat['q_mask'].to(dtype=torch.float32),
                           bias=feat['bias'])
        profiler.step()
    reporter.report()

    check_recursive(res_opt, res_vanilla)
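MemReporter presumably comes from the pytorch_memlab package (from pytorch_memlab import MemReporter); reporter.report() prints a per-tensor breakdown of device memory, which complements the allocator totals printed by the other benchmarks.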
Example No. 11
def MSAColumnAttentionTest(args, config, global_config, is_training=False):
    feat, params, res = load_data(args, 'MSAColumnAttention')

    conf = config.model.embeddings_and_evoformer.evoformer.msa_column_attention
    attn_opt = MSAColumnAttentionFF(conf,
                                    global_config,
                                    msa_dim=feat['msa_act'].shape[-1])
    attn_opt.load_weights_from_af2(params, rel_path='msa_column_attention')
    attn_vanilla = MSAColumnAttentionOpt(conf,
                                         global_config,
                                         msa_dim=feat['msa_act'].shape[-1])
    attn_vanilla.load_weights_from_af2(params, rel_path='msa_column_attention')

    attn_vanilla.cuda()
    feat['msa_act'] = feat['msa_act'].to(device='cuda', dtype=torch.float32)
    feat['msa_mask'] = feat['msa_mask'].to(device='cuda', dtype=torch.float32)

    reporter = MemReporter()
    handler_vanilla = torch.profiler.tensorboard_trace_handler(
        Path('Log') / Path('MSAColumnAttention'))
    with torch.profiler.profile(on_trace_ready=handler_vanilla,
                                with_stack=True,
                                with_modules=True,
                                profile_memory=True,
                                record_shapes=True) as profiler:
        res_vanilla = attn_vanilla(feat['msa_act'],
                                   feat['msa_mask'],
                                   is_training=is_training)
        profiler.step()
    # reporter.report()

    attn_opt.cuda()
    reporter = MemReporter()
    handler_opt = torch.profiler.tensorboard_trace_handler(
        Path('Log') / Path('MSAColumnAttentionOpt'))
    with torch.profiler.profile(on_trace_ready=handler_opt,
                                with_stack=True,
                                with_modules=True,
                                profile_memory=True,
                                record_shapes=True) as profiler:
        res_opt = attn_opt(feat['msa_act'],
                           feat['msa_mask'],
                           is_training=is_training)
        profiler.step()
    # reporter.report()

    if isinstance(attn_opt, MSAColumnAttentionFF):
        check_recursive(res_opt, res_vanilla + feat['msa_act'])
    else:
        check_recursive(res_opt, res_vanilla)
Example No. 12
def InvariantPointAttentionTest(args, config, global_config):
    print('InvariantPointAttentionTest')
    feat, params, res = load_data(args, 'InvariantPointAttention')
    conf = config.model.heads.structure_module

    attn_single = InvariantPointAttention(
        conf,
        global_config,
        num_feat_1d=feat['inputs_1d'].shape[-1],
        num_feat_2d=feat['inputs_2d'].shape[-1])
    attn_single.load_weights_from_af2(params,
                                      rel_path='invariant_point_attention')
    attn_batch = InvariantPointAttentionB(
        conf,
        global_config,
        num_feat_1d=feat['inputs_1d'].shape[-1],
        num_feat_2d=feat['inputs_2d'].shape[-1])
    attn_batch.load_weights_from_af2(params,
                                     rel_path='invariant_point_attention')

    print('inputs1d:', feat['inputs_1d'].size())
    print('inputs2d:', feat['inputs_2d'].size())
    print('activations:', feat['activations'].size())
    print('mask:', feat['mask'].size())
    batch_size = 8
    inputs_1d_batch = feat['inputs_1d'][None, ...].repeat(batch_size, 1, 1)
    inputs_2d_batch = feat['inputs_2d'][None, ...].repeat(batch_size, 1, 1, 1)
    activations_batch = feat['activations'][None, ...].repeat(batch_size, 1, 1)
    mask_batch = feat['mask'][None, ...].repeat(batch_size, 1, 1)

    qa_single = QuatAffine.from_tensor(
        feat['activations'].to(dtype=torch.float32))
    qa_batch = QuatAffine.from_tensor(
        activations_batch.to(dtype=torch.float32))

    res_single = attn_single(inputs_1d=feat['inputs_1d'],
                             inputs_2d=feat['inputs_2d'],
                             mask=feat['mask'],
                             affine=qa_single)
    res_batch = attn_batch(inputs_1d=inputs_1d_batch,
                           inputs_2d=inputs_2d_batch,
                           mask=mask_batch,
                           affine=qa_batch)
    print(check_recursive(res_single, res))
    print(check_recursive(res_batch[0, ...], res))

    for i in range(batch_size):
        err = torch.sum(torch.abs(res_batch[i, ...] - res_single))
        print(i, err.item())
        assert err < 1e-2
Example No. 13
def TransitionTest(args, config, global_config):
    feat, params, res = load_data(args, 'Transition')
    conf = config.model.embeddings_and_evoformer.evoformer.pair_transition
    attn = Transition(conf,
                      global_config,
                      num_channel=feat['seq_act'].shape[-1])
    for key in params.keys():
        print(key)
        for param in params[key].keys():
            print('\t' + param)
    attn.load_weights_from_af2(params, rel_path='transition_block')
    this_res = attn(feat['seq_act'], feat['seq_mask'])

    check_recursive(this_res, res)
Example No. 14
def TriangleAttentionTest(args, config, global_config):
    feat, params, res = load_data(args, 'TriangleAttention')
    conf = config.model.embeddings_and_evoformer.evoformer.triangle_attention_starting_node
    attn = TriangleAttention(conf,
                             global_config,
                             pair_dim=feat['pair_act'].shape[-1])
    for key in params.keys():
        print(key)
        for param in params[key].keys():
            print('\t' + param)
    attn.load_weights_from_af2(params, rel_path='triangle_attention')
    this_res = attn(feat['pair_act'], feat['pair_mask'])

    check_recursive(this_res, res)
Example No. 15
def create_extra_msa_features_test(args, config, global_config):
    feat, params, res = load_data(args, 'create_extra_msa_feature')
    conf = config.model.embeddings_and_evoformer
    # for key in params.keys():
    # 	print(key)
    # 	for param in params[key].keys():
    # 		print('\t' + param + '  ' + str(params[key][param].shape))
    for key in feat.keys():
        print(key, feat[key].shape)

    emb = ExtraMSAEmbedding(conf,
                            global_config,
                            msa_dim=feat['msa_feat'].shape[-1])
    this_res = emb.create_extra_msa_features(feat)
    check_recursive(this_res, res)
Example No. 16
def OuterProductMeanTest(args, config, global_config):
    feat, params, res = load_data(args, 'OuterProductMean')
    conf = config.model.embeddings_and_evoformer.evoformer.outer_product_mean
    attn = OuterProductMean(conf,
                            global_config,
                            msa_dim=feat['msa_act'].shape[-1],
                            num_output_channel=256)
    for key in params.keys():
        print(key)
        for param in params[key].keys():
            print('\t' + param)
    attn.load_weights_from_af2(params, rel_path='outer_product_mean')
    this_res = attn(feat['msa_act'], feat['msa_mask'])

    check_recursive(this_res, res)
Example No. 17
def test_torsion_angles_to_frames(args):
	print('test_torsion_angles_to_frames')
	(activations, aatype, torsion_angles_sin_cos), res = load_data(args, 'test_torsion_angles_to_frames')
	rigs = QuatAffine.from_tensor(activations).to_rigids()
	this_res = torsion_angles_to_frames(aatype=aatype, backb_to_global=rigs, torsion_angles_sin_cos=torsion_angles_sin_cos)
	this_res = affine.rigids_to_tensor_flat12(this_res)
	print(check_recursive(this_res, res))
Example No. 18
def GlobalAttentionTest(args, config, global_config):
    feat, params, res = load_data(args, 'GlobalAttention')

    conf = config.model.embeddings_and_evoformer.evoformer.msa_row_attention_with_pair_bias
    attn = GlobalAttention(conf,
                           global_config,
                           output_dim=256,
                           key_dim=feat['q_data'].shape[-1],
                           value_dim=feat['m_data'].shape[-1])
    attn.load_weights_from_af2(params['attention'], None)
    this_res = attn(q_data=feat['q_data'],
                    m_data=feat['m_data'],
                    q_mask=feat['q_mask'],
                    bias=feat['bias'])

    check_recursive(this_res, res)
Example No. 19
def pre_compose_test(args, name):
	(tensor, update), res = load_data(args, name)
	qa = QuatAffine.from_tensor(tensor)
	this_res = qa.pre_compose(update).to_tensor()
	err, max_err, mean_err = check_recursive(res, this_res)
	print(f'Max error = {max_err}, mean error = {mean_err} total error = {err}')
	print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
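The error-reporting pattern used here repeats verbatim in the QuatAffine tests that follow (invert_point_test, init_test, to_tensor_test, apply_rot_func_test). A small helper such as the sketch below (the name report_errors is an assumption, not part of the original code) would remove the duplication:

def report_errors(res, this_res, max_tol=1e-4, mean_tol=1e-5):
    # Same checks as the inline version: total/max/mean error plus a pass flag.
    err, max_err, mean_err = check_recursive(res, this_res)
    print(f'Max error = {max_err}, mean error = {mean_err} total error = {err}')
    success = (max_err < max_tol) and (mean_err < mean_tol)
    print(f'Success = {success}')
    return success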
Example No. 20
def invert_point_test(args, name):
	(tensor, point), res = load_data(args, name)
	qa = QuatAffine.from_tensor(tensor)
	this_res = qa.invert_point(point)
	err, max_err, mean_err = check_recursive(res, this_res)
	print(f'Max error = {max_err}, mean error = {mean_err} total error = {err}')
	print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
Example No. 21
def init_test(args, name):
	args, res = load_data(args, name)
	qa = QuatAffine.from_tensor(*args)
	this_res = qa.quaternion, qa.translation, qa.rotation
	for a,b in zip(res, this_res):
		err, max_err, mean_err = check_recursive(a,b)
		print(f'Max error = {max_err}, mean error = {mean_err} total error = {err}')
		print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
Example No. 22
def to_tensor_test(args, name):
	(tensor, ), res = load_data(args, name)
	qa = QuatAffine.from_tensor(tensor)
	qa = qa.apply_rotation_tensor_fn(lambda t: t+1.0)
	this_res = qa.to_tensor()
	err, max_err, mean_err = check_recursive(res,this_res)
	print(f'Max error = {max_err}, mean error = {mean_err} total error = {err}')
	print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
Example No. 23
def MSARowAttentionWithPairBiasTest(args, config, global_config):
    feat, params, res = load_data(args, 'MSARowAttentionWithPairBias')
    # for key in params.keys():
    # 	print(key)
    # 	for param in params[key].keys():
    # 		print('\t' + param)
    conf = config.model.embeddings_and_evoformer.evoformer.msa_row_attention_with_pair_bias
    # conf.gating = False
    attn = MSARowAttentionWithPairBias(conf,
                                       global_config,
                                       pair_dim=feat['pair_act'].shape[-1],
                                       msa_dim=feat['msa_act'].shape[-1])
    attn.load_weights_from_af2(params,
                               rel_path='msa_row_attention_with_pair_bias')
    this_res = attn(feat['msa_act'], feat['msa_mask'], feat['pair_act'])

    check_recursive(this_res, res)
Example No. 24
def test_frames_and_literature_positions_to_atom14_pos(args):
	print('test_frames_and_literature_positions_to_atom14_pos')
	(activations, aatype, torsion_angles_sin_cos), res = load_data(args, 'test_frames_and_literature_positions_to_atom14_pos')
	rigs = QuatAffine.from_tensor(activations).to_rigids()
	all_frames = torsion_angles_to_frames(aatype=aatype, backb_to_global=rigs, torsion_angles_sin_cos=torsion_angles_sin_cos)
	this_res = frames_and_literature_positions_to_atom14_pos(aatype=aatype, all_frames_to_global=all_frames)
	this_res = affine.vecs_to_tensor(this_res)
	print(check_recursive(this_res, res))
Example No. 25
def apply_rot_func_test(args, name):
	(tensor, ), res = load_data(args, name)
	qa = QuatAffine.from_tensor(tensor)
	qa = qa.apply_rotation_tensor_fn(lambda t: t+1.0)
	this_res = qa.quaternion, qa.translation, qa.rotation
	for a,b in zip(res, this_res):
		err, max_err, mean_err = check_recursive(a,b)
		print(f'Max error = {max_err}, mean error = {mean_err} total error = {err}')
		print(f'Success = {(max_err < 1e-4) and (mean_err < 1e-5)}')
Example No. 26
def EvoformerIterationTest1(args, config, global_config):
    feat, params, res = load_data(args, 'EvoformerIteration1')
    conf = config.model.embeddings_and_evoformer.evoformer

    attn = EvoformerIteration(conf,
                              global_config,
                              msa_dim=feat['msa_act'].shape[-1],
                              pair_dim=feat['pair_act'].shape[-1],
                              is_extra_msa=False)
    attn.load_weights_from_af2(params, rel_path='evoformer_iteration')

    this_res = attn(msa_act=feat['msa_act'],
                    pair_act=feat['pair_act'],
                    msa_mask=feat['msa_mask'].float(),
                    pair_mask=feat['pair_mask'].float(),
                    is_training=False)
    this_res = {'msa': this_res[0], 'pair': this_res[1]}
    check_recursive(this_res, res)
Example No. 27
def AttentionTest(args, config, global_config):
    feat, params, res = load_data(args, 'Attention')
    # for param in params['attention'].keys():
    # 	print(param)

    conf = config.model.embeddings_and_evoformer.evoformer.msa_row_attention_with_pair_bias
    # conf.gating = False
    attn = Attention(conf,
                     global_config,
                     output_dim=256,
                     key_dim=feat['q_data'].shape[-1],
                     value_dim=feat['m_data'].shape[-1])
    attn.load_weights_from_af2(params['attention'], None)
    this_res = attn(q_data=feat['q_data'],
                    m_data=feat['m_data'],
                    bias=feat['bias'],
                    nonbatched_bias=feat['nonbatched_bias'])

    check_recursive(this_res, res)
Example No. 28
def InvariantPointAttentionTest(args, config, global_config):
	print('InvariantPointAttentionTest')
	feat, params, res = load_data(args, 'InvariantPointAttention')
	conf = config.model.heads.structure_module
	
	attn = InvariantPointAttention(	conf, global_config, 
									num_res=feat['inputs_1d'].shape[-2], 
									num_feat_1d=feat['inputs_1d'].shape[-1],
									num_feat_2d=feat['inputs_2d'].shape[-1])
	attn.load_weights_from_af2(params, rel_path='invariant_point_attention')
	
	qa = QuatAffine.from_tensor(feat['activations'].to(dtype=torch.float32))
	this_res = attn(inputs_1d = feat['inputs_1d'], inputs_2d = feat['inputs_2d'], mask=feat['mask'], affine=qa)
	print(check_recursive(this_res, res))
Example No. 29
def PredictedAlignedErrorHeadTest(args, config, global_config):
	print('PredictedAlignedErrorHeadTest')
	feat, params, res = load_data(args, 'PredictedAlignedErrorHead')
	conf = config.model.heads.predicted_aligned_error
	representations = feat['representations']
	batch = feat['batch']
	
	attn = PredictedAlignedErrorHead(conf, global_config, 
						num_feat_2d=representations['pair'].shape[-1]
						)
						
	attn.load_weights_from_af2(params, rel_path='predicted_aligned_error_head')
	
	this_res = attn(representations, batch)
	print(check_recursive(this_res, res))
Example No. 30
def DistogramHeadTest(args, config, global_config):
	print('DistogramHeadTest')
	feat, params, res = load_data(args, 'DistogramHead')
	conf = config.model.heads.distogram
	representations = feat['representations']
	batch = feat['batch']
	
	attn = DistogramHead(conf, global_config, 
						num_feat_2d=representations['pair'].shape[-1]
						)
						
	attn.load_weights_from_af2(params, rel_path='distogram_head')
	
	this_res = attn(representations, batch)
	print(check_recursive(this_res, res))