Example #1
0
def make_pfam_pairs_loader(root_dir,
                           extra_keys=('fam_key', ),
                           task='pfam34_pairs/iid_ood_clans',
                           branch_key='ci_100',
                           suffixes=('1', '2')):
    """Creates a Loader for pre-paired Pfam-A seed data.

    Args:
      root_dir: Root directory under which the task's TFRecords live.
      extra_keys: Names of per-example scalar int64 features to parse; each
        key is expanded once per pair suffix (e.g. 'fam_key_1', 'fam_key_2').
        A single string is accepted and treated as a one-element tuple.
      task: Dataset sub-path under `root_dir`. Tasks whose name ends in
        'with_ctx' additionally carry full-length context sequences.
      branch_key: Name of the branch sub-folder (e.g. a clustering threshold
        such as 'ci_100').
      suffixes: Suffixes identifying the two elements of each pair.

    Returns:
      A `loaders.TFRecordsLoader` configured for the paired data.
    """
    has_context = task.endswith('with_ctx')

    # Mirror make_pfam34_loader: accept a bare string for extra_keys instead
    # of silently iterating over its characters.
    extra_keys = (extra_keys,) if isinstance(extra_keys, str) else extra_keys

    input_keys = ('seq', 'full_seq') if has_context else ('seq',)
    output_keys = (('sequence', 'full_sequence') if has_context
                   else ('sequence',))
    # Expand each base key once per pair suffix, e.g. 'seq_1', 'seq_2'.
    input_sequence_key = tuple(
        f'{key}_{suffix}' for key in input_keys for suffix in suffixes)
    output_sequence_key = tuple(
        f'{key}_{suffix}' for key in output_keys for suffix in suffixes)

    # Sequences are variable-length int64; extra keys are scalar int64.
    specs = {
        key: tf.io.VarLenFeature(tf.int64) for key in input_sequence_key
    }
    for suffix in suffixes:
        for key in extra_keys:
            specs[f'{key}_{suffix}'] = tf.io.FixedLenFeature([], tf.int64)

    return loaders.TFRecordsLoader(
        folder=os.path.join(root_dir, task, branch_key),
        coder=serialization.FlatCoder(specs=specs),
        input_sequence_key=input_sequence_key,
        output_sequence_key=output_sequence_key)
Example #2
0
def make_sequence_coded_loader(specs, root_dir, folder):
    """Returns a Loader for Proteinnet data.

    Args:
      specs: Feature specs passed through to `get_sequence_coder`.
      root_dir: Root directory of the dataset.
      folder: Sub-folder under `root_dir` holding the TFRecords.

    Returns:
      A `loaders.TFRecordsLoader` reading the 'primary' sequence.
    """
    data_folder = os.path.join(root_dir, folder)
    return loaders.TFRecordsLoader(data_folder,
                                   coder=get_sequence_coder(specs),
                                   input_sequence_key='primary',
                                   split_folder=False)
Example #3
0
def make_pfam34_loader(
    root_dir,
    sub_dir = 'pfam34',
    extra_keys = ('fam_key',),
    task = 'iid'):
  """Creates a loader for Pfam-A seed 34.0 data.

  Args:
    root_dir: Root directory of the dataset.
    sub_dir: Sub-folder under `root_dir` holding the Pfam 34.0 data and its
      'metadata.json'.
    extra_keys: Names of additional features to parse alongside the
      sequence(s). A single string is treated as a one-element tuple.
    task: Split sub-folder to read. Tasks ending in 'with_ctx' also carry a
      full-length context sequence.

  Returns:
    A `loaders.TFRecordsLoader` for the selected split.
  """
  has_context = task.endswith('with_ctx')

  folder = os.path.join(root_dir, sub_dir)
  with tf.io.gfile.GFile(os.path.join(folder, 'metadata.json'), 'r') as f:
    metadata = json.load(f)

  def _scalar(dtype):
    # Shorthand for a scalar fixed-length feature of the given dtype.
    return tf.io.FixedLenFeature([], dtype)

  all_specs = {
      'seq': tf.io.VarLenFeature(tf.int64),
      'seq_key': _scalar(tf.int64),
      'fam_key': _scalar(tf.int64),
      'cla_key': _scalar(tf.int64),
      'seq_len': _scalar(tf.int64),
      'id': _scalar(tf.string),
      'ac': _scalar(tf.string),
      'start': _scalar(tf.int64),
      'end': _scalar(tf.int64),
      'ss': _scalar(tf.string),
  }
  # One cluster-id feature per percent-identity threshold in the metadata.
  for pid_th in metadata['pid_ths']:
    all_specs[f'ci_{pid_th}'] = _scalar(tf.int64)
  if has_context:
    all_specs['full_seq'] = tf.io.VarLenFeature(tf.int64)

  input_sequence_key = ('seq',) + (('full_seq',) if has_context else ())
  output_sequence_key = ('sequence',) + (
      ('full_sequence',) if has_context else ())
  if isinstance(extra_keys, str):
    extra_keys = (extra_keys,)
  # Silently ignores misspecified keys.
  wanted = input_sequence_key + tuple(extra_keys)
  specs = {key: all_specs[key] for key in wanted if key in all_specs}
  return loaders.TFRecordsLoader(
      folder=os.path.join(folder, task),
      coder=serialization.FlatCoder(specs=specs),
      input_sequence_key=input_sequence_key,
      output_sequence_key=output_sequence_key)