Beispiel #1
0
                             inputs_2d=feat['inputs_2d'],
                             mask=feat['mask'],
                             affine=qa_single)
    res_batch = attn_batch(inputs_1d=inputs_1d_batch,
                           inputs_2d=inputs_2d_batch,
                           mask=mask_batch,
                           affine=qa_batch)
    print(check_recursive(res_single, res))
    print(check_recursive(res_batch[0, ...], res))

    for i in range(batch_size):
        err = torch.sum(torch.abs(res_batch[i, ...] - res_single))
        print(i, err.item())
        assert err < 1e-2


if __name__ == '__main__':
    # Command-line interface: the only option is the directory with debug dumps.
    # NOTE(review): the description text appears copied from a training script;
    # this entry point only runs a module consistency check.
    parser = argparse.ArgumentParser(description='Train deep protein docking')
    parser.add_argument(
        '-debug_dir',
        type=str,
        default='/home/lupoglaz/Projects/alphafold/Debug',
    )
    args = parser.parse_args()

    # Load the 'model_1' preset and extract its shared global settings.
    config = model_config('model_1')
    global_config = config.model.global_config

    # Only the invariant-point-attention check is enabled. Alternative checks
    # that can be swapped in here, each taking an extra `is_training=True`
    # keyword: TriangleAttentionTest, TriangleMultiplicationTest,
    # OuterProductMeanTest, TransitionTest.
    InvariantPointAttentionTest(args, config, global_config)
Beispiel #2
0

if __name__ == '__main__':
    # Script entry: build an AlphaFoldFeatures pipeline for the chosen model
    # preset, load previously saved raw and processed feature pickles for
    # target 'T1024' from output_dir, then recompute the processed features
    # from the raw ones.
    parser = argparse.ArgumentParser(description='Train deep protein docking')
    # NOTE(review): -fasta_path and -data_dir are parsed but not read in the
    # lines visible here; the target name 'T1024' is hard-coded below instead
    # of being derived from -fasta_path — confirm this is intentional.
    parser.add_argument('-fasta_path', default='T1024.fas', type=str)
    parser.add_argument('-output_dir',
                        default='/media/lupoglaz/AlphaFold2Output',
                        type=str)
    parser.add_argument('-model_name', default='model_1', type=str)
    parser.add_argument('-data_dir',
                        default='/media/lupoglaz/AlphaFold2Data',
                        type=str)

    args = parser.parse_args()

    # NOTE(review): this rebinds the module-level name `model_config` from the
    # callable to the config object it returns, so any later call to
    # model_config(...) in this process will fail. Renaming the result would
    # avoid that, but check code after this point that reads `model_config`
    # as the config object first.
    model_config = model_config(args.model_name)
    # Override the preset: eval num_ensemble = 1 and templates disabled.
    model_config.data.eval.num_ensemble = 1
    model_config.data.common.use_templates = False
    af2features = AlphaFoldFeatures(config=model_config)

    # Both pickles live under <output_dir>/T1024/.
    features_path = Path(
        args.output_dir) / Path('T1024') / Path('features.pkl')
    proc_features_path = Path(
        args.output_dir) / Path('T1024') / Path('proc_features.pkl')
    # pickle.load can execute arbitrary code embedded in the file; only point
    # -output_dir at trusted, locally generated feature files.
    with open(features_path, 'rb') as f:
        raw_feature_dict = pickle.load(f)
    with open(proc_features_path, 'rb') as f:
        af2_proc_features = pickle.load(f)

    # Recompute processed features from the raw dict with a fixed seed (42);
    # presumably to compare against af2_proc_features — the comparison is not
    # visible here.
    this_proc_features = af2features(raw_feature_dict, random_seed=42)