Example #1
def read_forwards():
    train_forward = LMDBDataset(
        data_file=
        '/export/home/tape/data/alignment/pfam/index_map/pfam_train.lmdb')
    valid_forward = LMDBDataset(
        data_file=
        '/export/home/tape/data/alignment/pfam/index_map/pfam_valid.lmdb')
    holdout_forward = LMDBDataset(
        data_file=
        '/export/home/tape/data/alignment/pfam/index_map/pfam_holdout.lmdb')
    return train_forward, valid_forward, holdout_forward
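
LMDBDataset from tape.datasets is a map-style dataset (it supports len() and integer indexing), so the splits returned above can be inspected directly. A minimal usage sketch, assuming the stored records are plain dicts as in the later examples:

train_forward, valid_forward, holdout_forward = read_forwards()
print(len(train_forward))   # number of records in the training split
print(train_forward[0])     # first record, typically a dict of fields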
Example #2
def main(args=None):
    if args is None:
        parser = create_parser()
        args = parser.parse_args()

    print(f'Input from {args.input_file}')
    print(f'Output to {args.output_file}')
    num_records = 0
    dataset = LMDBDataset(data_file=args.input_file)
    env = lmdb.open(args.output_file, map_size=args.map_size)
    with env.begin(write=True) as txn:
        for record in tqdm(dataset):
            if isinstance(record[args.length_key], str):
                length = len(record[args.length_key])
            elif isinstance(record[args.length_key], int):
                length = record[args.length_key]
            else:
                raise ValueError(
                    f'Unrecognized length value {record[args.length_key]}')

            if args.maximum_length is not None and \
                length > args.maximum_length:
                continue
            if args.minimum_length is not None and \
                length < args.minimum_length:
                continue

            write_record_to_transaction(num_records, record, txn)
            num_records += 1
        txn.put('num_examples'.encode(), pickle.dumps(num_records))
    print(f'Wrote {num_records}/{len(dataset)} records to {args.output_file}')
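
write_record_to_transaction is defined elsewhere in the original script. A plausible sketch, assuming it follows the same convention as the num_examples entry above (stringified index as the key, pickled record as the value, which is the layout LMDBDataset reads back):

def write_record_to_transaction(index, record, txn):
    # Hypothetical sketch: store the pickled record under its integer index.
    txn.put(str(index).encode(), pickle.dumps(record))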
Example #3
def write_reverse(input_directory,
                  output_directory,
                  map_size=1e+13,
                  split='train',
                  args=None,
                  batch_size=10000):
    data_file = os.path.join(input_directory, f'pfam_{split}.lmdb')
    dataset = LMDBDataset(data_file=data_file)

    print("Creating index map...")
    output_file = os.path.join(output_directory, f'reverse_{split}.lmdb')
    # lmdb expects an integer map_size, so coerce in case a float (e.g. 1e13) is passed.
    env = lmdb.open(output_file, map_size=int(map_size))
    with env.begin(write=True) as txn:
        batch_count = 0
        batch_dict = {}
        for item in tqdm(dataset, total=len(dataset)):
            dataset_index = int(item['dataset_index'])
            pfam_id = item['pfam_id']
            species = item['species']
            uniprot_id = item['uniprot_id']

            key = uniprot_id
            batch_count += 1
            if key in batch_dict:
                batch_dict[key].append(dataset_index)
            else:
                batch_dict[key] = [dataset_index]

            if batch_count >= batch_size:
                _add_dict(txn, batch_dict)
                batch_count = 0
                batch_dict = {}

        # Flush the final partial batch so trailing records are not dropped.
        if batch_dict:
            _add_dict(txn, batch_dict)
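
_add_dict is also defined outside this snippet. A rough sketch under the assumption that each uniprot_id key maps to a pickled list of dataset indices, extended when the key was already written by an earlier batch:

def _add_dict(txn, batch_dict):
    # Hypothetical sketch: merge one batch of index lists into the LMDB.
    for key, indices in batch_dict.items():
        existing = txn.get(key.encode())
        if existing is not None:
            indices = pickle.loads(existing) + indices
        txn.put(key.encode(), pickle.dumps(indices))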
Example #4
def get_lengths_from_indices(indices, data_file, sequence_key):
    dataset = LMDBDataset(data_file)
    lengths = []
    for count, i in enumerate(indices):
        # Print rough progress about five times per chunk; the max() guards
        # against a zero modulus when the chunk has fewer than five indices.
        if count % max(1, len(indices) // 5) == 0:
            print(f'{i}/{indices[-1]}')
        length = len(dataset[i][sequence_key])
        lengths.append(length)
    return lengths
Example #5
def main(args=None):
    if args is None:
        parser = create_parser()
        args = parser.parse_args()

    dataset = setup_dataset(task=args.task,
                            data_dir=args.data_dir,
                            split=args.split,
                            tokenizer=args.tokenizer)

    family_dataset = None
    if args.restrict_id:
        family_data_file = os.path.join(args.id_map_dir,
                                        f'pfam_{args.split}.lmdb')
        family_dataset = LMDBDataset(data_file=family_data_file)

    write_dataset_as_fasta(args.task, args.split, dataset, args.output_file,
                           args.pfam_id, family_dataset)
Example #6
    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):

        if split not in ('train', 'test', 'valid'):
            raise ValueError(f"Unrecognized split: {split}. Must be one of "
                             f"['train', 'test', 'valid']")

        if isinstance(tokenizer, str):
            # If the tokenizer was passed in as a string, build the actual tokenizer object
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        # Define the path to the data file. There are three helper datasets
        # that you can import from tape.datasets - a FastaDataset,
        # a JSONDataset, and an LMDBDataset. You can use these to load raw
        # data from your files (or of course, you can do this manually).
        data_path = Path(data_path)
        data_file = f'deeploc/deeploc_{split}.lmdb'
        self.data = LMDBDataset(data_path / data_file, in_memory=in_memory)
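
Only the constructor is shown above. For context, a minimal sketch of the __len__ and __getitem__ that usually accompany such an __init__ in tape-style datasets; the 'primary' and 'localization' field names and numpy imported as np are assumptions, not taken from the snippet:

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        item = self.data[index]
        # Tokenize the amino-acid sequence and build an all-ones input mask;
        # the field names here are assumed, not confirmed by the snippet.
        token_ids = self.tokenizer.encode(item['primary'])
        input_mask = np.ones_like(token_ids)
        return token_ids, input_mask, item['localization']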
Example #7
def write_split(split, data_dir, out_dir, sequence_key, num_jobs):
    data_file = os.path.join(data_dir, f'pfam_{split}.lmdb')
    dataset = LMDBDataset(data_file)

    apply_func = partial(get_lengths_from_indices,
                         data_file=data_file,
                         sequence_key=sequence_key)
    total_indices = list(range(len(dataset)))
    chunk_size = int(len(dataset) / num_jobs) + 1
    chunks = [
        list(range(i, min(i + chunk_size, len(dataset))))
        for i in range(0, len(dataset), chunk_size)
    ]

    out_file = os.path.join(out_dir, f'{split}_lengths.pkl')
    print(f'Writing lengths from {data_file} to {out_file}')
    with Pool(num_jobs) as pool:
        length_batches = pool.map(apply_func, chunks)
    sequence_lengths = [
        length for sublist in length_batches for length in sublist
    ]
    with open(out_file, 'wb') as handle:
        pickle.dump(sequence_lengths, handle)
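
The pickled lengths can later be reloaded for bucketing or analysis; a small hypothetical follow-up (the path is a placeholder):

with open('train_lengths.pkl', 'rb') as handle:
    sequence_lengths = pickle.load(handle)
print(f'{len(sequence_lengths)} lengths, max {max(sequence_lengths)}')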
Example #8
def load_real_values(lmdb_filename):
    ds = LMDBDataset(lmdb_filename)
    return {v['id']: float(v['target']) for v in ds}
Example #9
import argparse
import math
from tqdm import tqdm
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from tape.datasets import LMDBDataset
from pathlib import Path

parser = argparse.ArgumentParser(
    description='Convert an lmdb file into a fasta file')
parser.add_argument('lmdbdir',
                    type=str,
                    help='The dir with lmdb files to convert')
args = parser.parse_args()

for path in Path(args.lmdbdir).rglob('*.lmdb'):
    dataset = LMDBDataset(path)
    id_fill = math.ceil(math.log10(len(dataset)))
    print(id_fill)
    # Swap only the suffix so a directory name containing 'lmdb' is untouched.
    fastafile = str(path.with_suffix('.fasta'))
    with open(fastafile, 'w') as outfile:
        for i, element in enumerate(tqdm(dataset)):
            id_ = element.get('id', str(i).zfill(id_fill))
            if isinstance(id_, bytes):
                id_ = id_.decode()

            primary = element['primary']
            seq = Seq(primary)
            record = SeqRecord(seq,
                               id_,
                               description=path.name.replace('.lmdb', ''))
            outfile.write(record.format('fasta'))
Example #10
def get_dataset(data_dir, split):
    data_file = os.path.join(data_dir, f'pfam_{split}.lmdb')
    dataset = LMDBDataset(data_file=data_file)
    return dataset
Example #11
import argparse
import math
from tqdm import tqdm
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from tape.datasets import LMDBDataset

parser = argparse.ArgumentParser(
    description='Convert an lmdb file into a fasta file')
parser.add_argument('lmdbfile', type=str, help='The lmdb file to convert')
parser.add_argument('fastafile', type=str, help='The fasta file to output')
args = parser.parse_args()

dataset = LMDBDataset(args.lmdbfile)

id_fill = math.ceil(math.log10(len(dataset)))

fastafile = args.fastafile
if not fastafile.endswith('.fasta'):
    fastafile += '.fasta'

with open(fastafile, 'w') as outfile:
    for i, element in enumerate(tqdm(dataset)):
        id_ = element.get('id', str(i).zfill(id_fill))
        if isinstance(id_, bytes):
            id_ = id_.decode()

        primary = element['primary']
        seq = Seq(primary)
        record = SeqRecord(seq, id_)
        outfile.write(record.format('fasta'))
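
Invoked from the command line, for example (script and file names are placeholders):

python lmdb_to_fasta.py data/pfam_valid.lmdb pfam_valid.fasta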