"""Label structures with deviation.""" import click import pandas as pd import parallel as par import atom3d.shard.shard as sh import atom3d.util.log as log logger = log.get_logger('rsr_label') @click.command(help='Label QM9 structures.') @click.argument('sharded_path', type=click.Path()) @click.argument('csv_file', type=click.Path(exists=True)) @click.option('-n', '--num_threads', default=8, help='Number of threads to use for parallel processing.') @click.option('--overwrite/--no-overwrite', default=False, help='Overwrite existing labels.') def gen_labels_sharded(sharded_path, csv_file, num_threads, overwrite): sharded = sh.Sharded.load(sharded_path) num_shards = sharded.get_num_shards() requested_shards = list(range(num_shards)) if not overwrite: produced_shards = [ x for x in requested_shards if sharded.has(x, 'labels') ] else:
import os
import subprocess

import Bio.PDB.Polypeptide as Poly
import dotenv as de
import numpy as np
import tqdm
from Bio import SeqIO
from Bio.Blast.Applications import NcbiblastpCommandline

import atom3d.util.file as fi
import atom3d.util.log as log

# Project root is three directories above this file; the .env file there is
# loaded at import time (it supplies e.g. BLAST_BIN, checked below).
project_root = os.path.abspath(os.path.join(__file__, '../../..'))
de.load_dotenv(os.path.join(project_root, '.env'))

logger = log.get_logger('sequence')


def find_similar(chain_sequences, blast_db, cutoff, num_alignments):
    """Find all other pdbs that have sequence identity greater than cutoff.

    :param chain_sequences: iterable of (chain, sequence) pairs to query.
    :param blast_db: BLAST database to search against.
    :param cutoff: sequence-identity threshold in percent (0-100).
    :param num_alignments: maximum number of alignments requested per query.
    :raises RuntimeError: if BLAST_BIN is not set in the environment.
    :raises Exception: if cutoff is outside [0, 100].
    """
    if 'BLAST_BIN' not in os.environ:
        raise RuntimeError('Need to set BLAST_BIN in .env to use blastp')
    if not (0 <= cutoff <= 100):
        raise Exception('cutoff need to be between 0 and 100')

    sim = set()
    for chain, s in chain_sequences:
        # One blastp invocation per chain, asking for nident/sacc fields.
        # NOTE(review): the call is cut off in this fragment.
        blastp_cline = NcbiblastpCommandline(db=blast_db,
                                             outfmt="10 nident sacc",
                                             num_alignments=num_alignments,
"""Code for preparing a pairs dataset (filtering and splitting).""" import numpy as np import pandas as pd import click import atom3d.filters.pdb import atom3d.filters.sequence import atom3d.splits.splits as splits import atom3d.filters.filters as filters import atom3d.shard.shard as sh import atom3d.shard.shard_ops as sho import atom3d.util.file as fi import atom3d.util.log as log logger = log.get_logger('prepare') def split(input_sharded, output_root, scaffold_data, shuffle_buffer): """Split by sequence identity.""" if input_sharded.get_keys() != ['ensemble']: raise RuntimeError('Can only apply to sharded by ensemble.') logger.info('Splitting by scaffold') scaffold_list = scaffold_data['Scaffold'].tolist() train_idx, val_idx, test_idx = splits.scaffold_split(scaffold_list) train = scaffold_data['pdb'][train_idx].tolist() val = scaffold_data['pdb'][val_idx].tolist() test = scaffold_data['pdb'][test_idx].tolist() keys = input_sharded.get_keys() if keys != ['ensemble']:
"""Functions for splitting data into test, validation, and training sets.""" from functools import partial import numpy as np import torch import atom3d.util.log as log logger = log.get_logger('splits') def split(dataset, indices_train, indices_val, indices_test): train_dataset = torch.utils.data.Subset(dataset, indices_train) val_dataset = torch.utils.data.Subset(dataset, indices_val) test_dataset = torch.utils.data.Subset(dataset, indices_test) logger.info(f'Size of the training set: {len(indices_train):}') logger.info(f'Size of the validation set: {len(indices_val):}') logger.info(f'Size of the test set: {len(indices_test):}') return train_dataset, val_dataset, test_dataset def read_split_file(split_file): """ Read text file with pre-defined split, returning list of examples. One example per row in text file. """ with open(split_file) as f: # file may contain integer indices or string identifiers (e.g. PDB
"""Methods to extract protein interface labels pair.""" import click import numpy as np import pandas as pd import scipy.spatial as spa import atom3d.shard.shard as sh import atom3d.util.log as log logger = log.get_logger('neighbors') index_columns = \ ['ensemble', 'subunit', 'structure', 'model', 'chain', 'residue'] @click.command(help='Find neighbors for entry in sharded.') @click.argument('sharded_path', type=click.Path()) @click.argument('ensemble') @click.argument('output_labels', type=click.Path()) @click.option('-c', '--cutoff', type=int, default=8, help='Maximum distance (in angstroms), for two residues to be ' 'considered neighbors.') @click.option('--cutoff-type', default='CA', type=click.Choice(['heavy', 'CA'], case_sensitive=False), help='How to compute distance between residues: CA is based on ' 'alpha-carbons, heavy is based on any heavy atom.') def get_neighbors_main(sharded_path, ensemble, output_labels, cutoff, cutoff_type): sharded = sh.Sharded.load(sharded_path) ensemble = sharded.read_keyed(ensemble)
"""Label mutation pairs as beneficial or detrimental.""" import click import pandas as pd import parallel as par import atom3d.shard.shard as sh import atom3d.util.log as log logger = log.get_logger('msp_label') @click.command(help='Label SKEMPI pairs with good/bad label.') @click.argument('sharded_path', type=click.Path()) @click.argument('data_csv', type=click.Path(exists=True)) @click.option('-n', '--num_threads', default=8, help='Number of threads to use for parallel processing.') @click.option('--overwrite/--no-overwrite', default=False, help='Overwrite existing labels.') def gen_labels_sharded(sharded_path, data_csv, num_threads, overwrite): sharded = sh.Sharded.load(sharded_path) num_shards = sharded.get_num_shards() requested_shards = list(range(num_shards)) if not overwrite: produced_shards = [ x for x in requested_shards if sharded.has(x, 'labels') ]
"""TODO: This code has been significantly re-written and should be tested.""" import math import random import numpy as np import atom3d.protein.sequence as seq import atom3d.util.file as fi import atom3d.util.log as log import atom3d.splits.splits as splits logger = log.get_logger('sequence_splits') #################################### # split by pre-clustered sequence # identity clusters from PDB #################################### def cluster_split(dataset, cutoff, val_split=0.1, test_split=0.1, min_fam_in_split=5, random_seed=None): """ Splits pdb dataset using pre-computed sequence identity clusters from PDB. Generates train, val, test sets. We assume there is one PDB code per entry in dataset.
"""Code to generate pair ensembles.""" import click import pandas as pd import parallel as par import atom3d.datasets.ppi.neighbors as nb import atom3d.util.log as log logger = log.get_logger('shard_pairs') def _gen_pairs_per_ensemble(x, cutoff, cutoff_type): pairs = [] if len(x['subunit'].unique()) > 1: raise RuntimeError('Cannot find pairs on existing ensemble') # Only keep first model. x = x[x['model'] == sorted(x['model'].unique())[0]] names, subunits = _gen_subunits(x) for i in range(len(subunits)): for j in range(i + 1, len(subunits)): curr = nb.get_neighbors(subunits[i], subunits[j], cutoff, cutoff_type) if len(curr) > 0: tmp0 = subunits[i].copy() tmp0['subunit'] = names[i] tmp1 = subunits[j].copy() tmp1['subunit'] = names[j] pair = pd.concat([tmp0, tmp1]) pair['ensemble'] = names[i] + '_' + names[j] pairs.append(pair)
"""Code for preparing a lep dataset (filtering and splitting).""" import click import pandas as pd import atom3d.filters.filters as filters import atom3d.shard.shard as sh import atom3d.shard.shard_ops as sho import atom3d.util.log as log import atom3d.splits.splits as splits logger = log.get_logger('lep_prepare') def split(input_sharded, output_root, info_csv, shuffle_buffer): """Split by protein.""" if input_sharded.get_keys() != ['ensemble']: raise RuntimeError('Can only apply to sharded by ensemble.') info = pd.read_csv(info_csv) info['ensemble'] = info.apply( lambda x: x['ligand'] + '__' + x['active_struc'].split('_')[ 2] + '__' + x['inactive_struc'].split('_')[2], axis=1) info = info.set_index('ensemble') # Remove duplicate ensembles. info = info[~info.index.duplicated()] ensembles = input_sharded.get_names()['ensemble'] in_use = info.loc[ensembles] active = in_use[in_use['label'] == 'A'] inactive = in_use[in_use['label'] == 'I']
"""Code for preparing rsr dataset (splitting).""" import click import atom3d.datasets.rsr.score as sc import atom3d.filters.filters as filters import atom3d.shard.shard as sh import atom3d.shard.shard_ops as sho import atom3d.util.log as log logger = log.get_logger('rsr_prepare') # Canonical splits. TRAIN = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13'] VAL = ['14b', '14f', '15', '17'] TEST = ['18', '19', '20', '21'] def split(input_sharded, output_root): """Split temporally and shuffle examples.""" prefix = sh.get_prefix(output_root) num_shards = sh.get_num_shards(output_root) # Re-key to ensemble, subunit. To allow for shuffling across targets. tmp_sharded = sh.Sharded( f'{prefix:}_rekeyed@{num_shards:}', ['ensemble', 'subunit']) logger.info("Rekeying") sho.rekey(input_sharded, tmp_sharded) keys = tmp_sharded.get_keys() train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
"""Label pairs as active or inactive.""" import click import pandas as pd import parallel as par import atom3d.shard.shard as sh import atom3d.util.log as log logger = log.get_logger('lep_label') @click.command(help='Label LEP pairs with inactive/active label.') @click.argument('sharded_path', type=click.Path()) @click.argument('info_csv', type=click.Path(exists=True)) @click.option('-n', '--num_threads', default=8, help='Number of threads to use for parallel processing.') @click.option('--overwrite/--no-overwrite', default=False, help='Overwrite existing labels.') def gen_labels_sharded(sharded_path, info_csv, num_threads, overwrite): sharded = sh.Sharded.load(sharded_path) num_shards = sharded.get_num_shards() requested_shards = list(range(num_shards)) if not overwrite: produced_shards = [ x for x in requested_shards if sharded.has(x, 'labels') ] else:
"""Generate protein interfaces labels for sharded dataset.""" import warnings import click import pandas as pd import parallel as par import atom3d.datasets.ppi.neighbors as nb import atom3d.shard.shard as sh import atom3d.util.log as log warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning) logger = log.get_logger('genLabels') @click.command(help='Find neighbors for sharded dataset.') @click.argument('sharded_path', type=click.Path()) @click.option('-c', '--cutoff', type=int, default=8, help='Maximum distance (in angstroms), for two residues to be ' 'considered neighbors.') @click.option('--cutoff-type', default='CA', type=click.Choice(['heavy', 'CA'], case_sensitive=False), help='How to compute distance between residues: CA is based on ' 'alpha-carbons, heavy is based on any heavy atom.') @click.option('-n', '--num_threads', default=8, help='Number of threads to use for parallel processing.') @click.option('--overwrite/--no-overwrite', default=False, help='Overwrite existing neighbors.') def get_neighbors_sharded(sharded_path, cutoff, cutoff_type, num_threads, overwrite):
"""Generate BSA database for sharded dataset.""" import multiprocessing as mp import os import timeit import click import pandas as pd import parallel as par import atom3d.datasets.ppi.bsa as bsa import atom3d.datasets.ppi.neighbors as nb import atom3d.shard.shard as sh import atom3d.util.log as log logger = log.get_logger('bsa') db_sem = mp.Semaphore() @click.command(help='Generate Buried Surface Area database for sharded.') @click.argument('sharded_path', type=click.Path()) @click.argument('output_bsa', type=click.Path()) @click.option('-n', '--num_threads', default=8, help='Number of threads to use for parallel processing.') def bsa_db(sharded_path, output_bsa, num_threads): sharded = sh.Sharded.load(sharded_path) num_shards = sharded.get_num_shards() dirname = os.path.dirname(output_bsa)