def make_pfam_pairs_loader(root_dir,
                           extra_keys=('fam_key', ),
                           task='pfam34_pairs/iid_ood_clans',
                           branch_key='ci_100',
                           suffixes=('1', '2')):
  """Creates a Loader for pre-paired Pfam-A seed data.

  Args:
    root_dir: base directory. Records are read from
      `os.path.join(root_dir, task, branch_key)`.
    extra_keys: name(s) of scalar int64 features to load for every pair
      member. Each name is suffixed with `_{suffix}` for every entry of
      `suffixes`. A single string is accepted and treated as a one-element
      tuple.
    task: task sub-directory. If it ends with 'with_ctx', the `full_seq_*`
      inputs (mapped to `full_sequence_*` outputs) are loaded in addition to
      `seq_*`.
    branch_key: sub-directory of `task` holding the records.
    suffixes: suffixes appended to every per-sequence key, one per pair member.

  Returns:
    A `loaders.TFRecordsLoader` decoding the records with a `FlatCoder`.
  """
  has_context = task.endswith('with_ctx')
  # Without this, iterating a bare string would yield its individual
  # characters as feature names.
  if isinstance(extra_keys, str):
    extra_keys = (extra_keys,)

  input_base_keys = ('seq', 'full_seq') if has_context else ('seq',)
  output_base_keys = (
      ('sequence', 'full_sequence') if has_context else ('sequence',))
  # Key-major ordering: every suffix of the first base key, then the next one.
  input_sequence_key = tuple(
      f'{key}_{suffix}' for key in input_base_keys for suffix in suffixes)
  output_sequence_key = tuple(
      f'{key}_{suffix}' for key in output_base_keys for suffix in suffixes)

  specs = {key: tf.io.VarLenFeature(tf.int64) for key in input_sequence_key}
  for suffix in suffixes:
    for key in extra_keys:
      specs[f'{key}_{suffix}'] = tf.io.FixedLenFeature([], tf.int64)

  return loaders.TFRecordsLoader(
      folder=os.path.join(root_dir, task, branch_key),
      coder=serialization.FlatCoder(specs=specs),
      input_sequence_key=input_sequence_key,
      output_sequence_key=output_sequence_key)
def test_flat_coder_not_flat(self):
  """Encoding `self.example` with a 'multi' spec must raise TypeError."""
  # NOTE(review): presumably self.example['multi'] holds a non-flat value;
  # the fixture is defined outside this view.
  specs = dict(seq=tf.float32, name=tf.string, multi=tf.float32)
  coder = serialization.FlatCoder(specs=specs)
  self.assertRaises(TypeError, coder.encode, self.example)
def serialize(example, pid_ths):
  """Encodes one Pfam example with a `serialization.FlatCoder`.

  The coder specs are taken from the dtypes of the example's own values,
  restricted to the known integer, string and 'sequence' keys.

  Args:
    example: mapping from feature name to value. Values whose keys are
      serialized must expose `.dtype` (presumably numpy/tf arrays or scalars;
      TODO confirm against callers). The mapping is not modified.
    pid_ths: iterable of percent-identity thresholds. For each threshold `t`,
      the integer feature `ci_{t}` is serialized as well.

  Returns:
    The result of `FlatCoder.encode` on the example.

  Raises:
    KeyError: if any integer key (including each `ci_{t}`) is absent from
      `example`.
  """
  int_keys = ['seq_key', 'start', 'end', 'cla_key', 'fam_key', 'seq_len']
  int_keys.extend(f'ci_{pid_th}' for pid_th in pid_ths)
  str_keys = ['id', 'ac', 'ss']
  keys = set(int_keys + str_keys + ['sequence'])
  all_specs = {k: v.dtype for k, v in example.items() if k in keys}
  coder = serialization.FlatCoder(specs=all_specs)
  # Integer scalars are wrapped in one-element lists before encoding. Do it on
  # a shallow copy so the caller's example is left intact and repeated calls
  # on the same dict do not produce nested lists.
  to_encode = dict(example)
  for k in int_keys:
    to_encode[k] = [example[k]]
  return coder.encode(to_encode)
def test_flat_coder(self):
  """Two encode/decode round trips preserve the 'seq' and 'name' fields."""
  coder = serialization.FlatCoder(specs=dict(seq=tf.float32, name=tf.string))
  result = self.example
  # First pass starts from the raw example, second from the decoded output.
  for _ in range(2):
    result = coder.decode(coder.encode(result))
  self.assertAllClose(result['seq'], self.example['seq'], atol=1e-7)
  self.assertEqual(result['name'], self.example['name'])
def make_pfam34_loader(root_dir,
                       sub_dir='pfam34',
                       extra_keys=('fam_key',),
                       task='iid'):
  """Creates a loader for Pfam-A seed 34.0 data.

  Reads `metadata.json` from `root_dir/sub_dir` to learn the `pid_ths`
  thresholds. These define the available `ci_{t}` features.

  Args:
    root_dir: base directory of the dataset.
    sub_dir: dataset sub-directory containing `metadata.json` and one folder
      per task.
    extra_keys: feature name, or iterable of names, to load alongside the
      sequence(s). Names without a known spec are silently ignored.
    task: task folder under `root_dir/sub_dir`. If it ends with 'with_ctx',
      'full_seq' is loaded (as 'full_sequence') in addition to 'seq'.

  Returns:
    A `loaders.TFRecordsLoader` over `root_dir/sub_dir/task`.
  """
  with_ctx = task.endswith('with_ctx')
  base_dir = os.path.join(root_dir, sub_dir)
  with tf.io.gfile.GFile(os.path.join(base_dir, 'metadata.json'), 'r') as fp:
    metadata = json.load(fp)

  known_specs = {'seq': tf.io.VarLenFeature(tf.int64)}
  for name in ('seq_key', 'fam_key', 'cla_key', 'seq_len', 'start', 'end'):
    known_specs[name] = tf.io.FixedLenFeature([], tf.int64)
  for name in ('id', 'ac', 'ss'):
    known_specs[name] = tf.io.FixedLenFeature([], tf.string)
  for threshold in metadata['pid_ths']:
    known_specs[f'ci_{threshold}'] = tf.io.FixedLenFeature([], tf.int64)

  if with_ctx:
    known_specs['full_seq'] = tf.io.VarLenFeature(tf.int64)
    input_sequence_key = ('seq', 'full_seq')
    output_sequence_key = ('sequence', 'full_sequence')
  else:
    input_sequence_key = ('seq',)
    output_sequence_key = ('sequence',)

  if isinstance(extra_keys, str):
    extra_keys = (extra_keys,)
  # Requested keys lacking a spec are dropped rather than raising.
  requested = input_sequence_key + tuple(extra_keys)
  specs = {name: known_specs[name] for name in requested if name in known_specs}

  return loaders.TFRecordsLoader(
      folder=os.path.join(base_dir, task),
      coder=serialization.FlatCoder(specs=specs),
      input_sequence_key=input_sequence_key,
      output_sequence_key=output_sequence_key)