Пример #1
0
def make_pfam_pairs_loader(root_dir,
                           extra_keys=('fam_key', ),
                           task='pfam34_pairs/iid_ood_clans',
                           branch_key='ci_100',
                           suffixes=('1', '2')):
    """Creates a Loader for pre-paired Pfam-A seed data.

    Args:
      root_dir: Base directory; records are read from
        `root_dir/task/branch_key`.
      extra_keys: Name, or iterable of names, of additional features to
        parse. Each is read as a scalar int64 feature named
        `{key}_{suffix}` for every entry of `suffixes`. A single string is
        treated as one key.
      task: Sub-path of `root_dir` holding the task. When it ends with
        'with_ctx', the `full_seq_{suffix}` features are parsed as well and
        mapped to `full_sequence_{suffix}` outputs.
      branch_key: Sub-folder of the task folder to read from.
      suffixes: Suffixes appended to every feature name, one per element of
        the pair.

    Returns:
      A `loaders.TFRecordsLoader` configured with a `FlatCoder` over the
      selected features.
    """
    has_context = task.endswith('with_ctx')

    # A bare string would otherwise be iterated character by character,
    # yielding bogus one-letter feature names.
    if isinstance(extra_keys, str):
        extra_keys = (extra_keys, )

    input_prefixes = ('seq', 'full_seq') if has_context else ('seq', )
    output_prefixes = (('sequence', 'full_sequence')
                       if has_context else ('sequence', ))

    # Prefix-major ordering: all suffixes of 'seq' precede those of
    # 'full_seq', so input and output keys stay aligned position by position.
    input_sequence_key = tuple(f'{prefix}_{suffix}'
                               for prefix in input_prefixes
                               for suffix in suffixes)
    output_sequence_key = tuple(f'{prefix}_{suffix}'
                                for prefix in output_prefixes
                                for suffix in suffixes)

    specs = {key: tf.io.VarLenFeature(tf.int64) for key in input_sequence_key}
    for suffix in suffixes:
        for key in extra_keys:
            specs[f'{key}_{suffix}'] = tf.io.FixedLenFeature([], tf.int64)

    return loaders.TFRecordsLoader(
        folder=os.path.join(root_dir, task, branch_key),
        coder=serialization.FlatCoder(specs=specs),
        input_sequence_key=input_sequence_key,
        output_sequence_key=output_sequence_key)
Пример #2
0
 def test_flat_coder_not_flat(self):
     """Checks that encoding `self.example` with a 'multi' spec raises."""
     specs = dict(seq=tf.float32, name=tf.string, multi=tf.float32)
     coder = serialization.FlatCoder(specs=specs)
     # Callable form of assertRaises: the encode call itself must raise.
     self.assertRaises(TypeError, coder.encode, self.example)
def serialize(example, pid_ths):
  """Encodes `example` with a FlatCoder built from its values' dtypes.

  Args:
    example: Mapping from feature name to value. Values stored under the
      recognized keys must expose a `.dtype` attribute. The mapping is not
      modified.
    pid_ths: Iterable of thresholds; each one adds an integer key named
      `ci_{pid_th}` to the set of recognized keys.

  Returns:
    The result of `FlatCoder.encode` on the example.

  Raises:
    KeyError: If any of the integer keys is absent from `example`.
  """
  int_keys = ['seq_key', 'start', 'end', 'cla_key', 'fam_key', 'seq_len']
  int_keys.extend(f'ci_{pid_th}' for pid_th in pid_ths)
  str_keys = ['id', 'ac', 'ss']
  keys = set(int_keys + str_keys + ['sequence'])

  # Specs are taken from the original values, before any wrapping.
  all_specs = {k: v.dtype for k, v in example.items() if k in keys}
  coder = serialization.FlatCoder(specs=all_specs)

  # Work on a shallow copy so the caller's dict is left untouched; wrapping
  # in place would double-wrap the values if the same example were
  # serialized twice.
  wrapped = dict(example)
  for k in int_keys:
    # Integer entries are passed to the coder as single-element lists
    # (presumably because it expects sequence-valued features -- confirm).
    wrapped[k] = [example[k]]
  return coder.encode(wrapped)
Пример #4
0
 def test_flat_coder(self):
     """Checks that two encode/decode round trips reproduce `self.example`."""
     coder = serialization.FlatCoder(
         specs=dict(seq=tf.float32, name=tf.string))
     roundtripped = self.example
     # Two passes: the second one feeds decoded output back into encode.
     for _ in range(2):
         roundtripped = coder.decode(coder.encode(roundtripped))
     expected = self.example
     self.assertAllClose(roundtripped['seq'], expected['seq'], atol=1e-7)
     self.assertEqual(roundtripped['name'], expected['name'])
Пример #5
0
def make_pfam34_loader(
    root_dir,
    sub_dir = 'pfam34',
    extra_keys = ('fam_key',),
    task = 'iid'):
  """Builds a TFRecords loader over Pfam-A seed 34.0 data.

  Args:
    root_dir: Base directory of the datasets.
    sub_dir: Sub-folder of `root_dir` holding `metadata.json` and the task
      folders.
    extra_keys: Name, or iterable of names, of features to parse besides the
      sequence features. Names with no known spec are dropped silently.
    task: Task folder inside `root_dir/sub_dir`. When it ends with
      'with_ctx', the 'full_seq' feature is parsed too.

  Returns:
    A `loaders.TFRecordsLoader` reading from `root_dir/sub_dir/task`.
  """
  has_context = task.endswith('with_ctx')

  data_dir = os.path.join(root_dir, sub_dir)
  metadata_path = os.path.join(data_dir, 'metadata.json')
  with tf.io.gfile.GFile(metadata_path, 'r') as f:
    metadata = json.load(f)

  # Catalogue of every feature that may be requested, keyed by name.
  int_scalar_names = ['seq_key', 'fam_key', 'cla_key', 'seq_len', 'start', 'end']
  int_scalar_names += [f'ci_{pid_th}' for pid_th in metadata['pid_ths']]
  str_scalar_names = ['id', 'ac', 'ss']

  known_specs = {'seq': tf.io.VarLenFeature(tf.int64)}
  for name in int_scalar_names:
    known_specs[name] = tf.io.FixedLenFeature([], tf.int64)
  for name in str_scalar_names:
    known_specs[name] = tf.io.FixedLenFeature([], tf.string)

  if has_context:
    known_specs['full_seq'] = tf.io.VarLenFeature(tf.int64)
    input_sequence_key = ('seq', 'full_seq')
    output_sequence_key = ('sequence', 'full_sequence')
  else:
    input_sequence_key = ('seq',)
    output_sequence_key = ('sequence',)

  if isinstance(extra_keys, str):
    extra_keys = (extra_keys,)

  # Requested names without a known spec are skipped rather than rejected.
  specs = {}
  for name in input_sequence_key + tuple(extra_keys):
    if name in known_specs:
      specs[name] = known_specs[name]

  coder = serialization.FlatCoder(specs=specs)
  return loaders.TFRecordsLoader(
      folder=os.path.join(data_dir, task),
      coder=coder,
      input_sequence_key=input_sequence_key,
      output_sequence_key=output_sequence_key)