Example #1
def read_tfrecords(path, proto=None, max_records=None, compression_type=None):
    """Yields the parsed records in a TFRecord file path.

    Note that path can be a sharded filespec (path@N), in which case this
    function will read each shard in order; i.e., all records in shard 0 are
    read in order, then shard 1, and so on.

    Args:
      path: String. A path to a TFRecord file containing protos.
      proto: A proto class. proto.FromString() will be called on each
        serialized record in path to parse it.
      max_records: int >= 0 or None. Maximum number of records to read from
        path. If None, the default, all records will be read.
      compression_type: 'GZIP', 'ZLIB', '' (uncompressed), or None to
        autodetect based on file extension.

    Yields:
      proto.FromString() values on each record in path in order.
    """
    if sharded_file_utils.is_sharded_file_spec(path):
        paths = sharded_file_utils.generate_sharded_filenames(path)
    else:
        paths = [path]

    i = 0
    for path in paths:
        for record in Reader(path, proto, compression_type=compression_type):
            i += 1
            if max_records is not None and i > max_records:
                return
            yield record
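A minimal usage sketch for this reader, assuming the nucleus.io.tfrecord import used in Example #4; the sharded path and proto class below are placeholders for illustration:

from nucleus.io import tfrecord
from tensorflow.core.example import example_pb2

# Read the first five tf.Example protos from a (hypothetical) sharded path.
# Shards are read in order: all of shard 0, then shard 1, and so on.
for example in tfrecord.read_tfrecords(
        'make_examples.tfrecord@64.gz',  # placeholder path
        proto=example_pb2.Example,
        max_records=5):
    print(sorted(example.features.feature.keys()))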
Example #2
def read_sharded_runtime_tsvs(path_string: str) -> pd.DataFrame:
    """Imports data from a single or sharded path into a pandas dataframe.

    Args:
      path_string: The path to the input file, which may be sharded.

    Returns:
      A dataframe matching the TSV file(s) but with an added Task column.
    """
    if sharded_file_utils.is_sharded_file_spec(path_string):
        paths = sharded_file_utils.generate_sharded_filenames(path_string)
    else:
        paths = [path_string]
    list_of_dataframes = []
    for i, path in enumerate(paths):
        if path.startswith('gs://'):
            # Once pandas is updated to 0.24+, pd.read_csv will work for gs://
            # without this workaround.
            with tf.io.gfile.GFile(path) as f:
                d = pd.read_csv(f, sep='\t')
        else:
            d = pd.read_csv(path, sep='\t')
        d['Task'] = i
        list_of_dataframes.append(d)

    return pd.concat(list_of_dataframes, axis=0, ignore_index=True)
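A hedged usage sketch for the function above (it assumes the same pandas, tf.io.gfile, and sharded_file_utils imports as the surrounding module); the sharded TSV path is a placeholder. Each shard's rows get a Task value equal to that shard's index, so the concatenated dataframe records which shard each row came from:

# Illustrative only; 'runtime_by_region.tsv@4' is a placeholder sharded spec.
df = read_sharded_runtime_tsvs('runtime_by_region.tsv@4')
print(df['Task'].unique())  # e.g. [0 1 2 3], one value per shard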
Example #3
def read_shard_sorted_tfrecords(path,
                                key,
                                proto=None,
                                max_records=None,
                                compression_type=None):
    """Yields the parsed records in a TFRecord file path in sorted order.

    The input TFRecord file must have each shard already in sorted order when
    using the key function for comparison (but elements can be interleaved
    across shards). Under those constraints, the elements will be yielded in a
    globally sorted order.

    Args:
      path: String. A path to a TFRecord-formatted file containing protos.
      key: Callable. A function that takes as input a single instance of the
        proto class and returns a value on which the comparison for sorted
        ordering is performed.
      proto: A proto class. proto.FromString() will be called on each
        serialized record in path to parse it.
      max_records: int >= 0 or None. Maximum number of records to read from
        path. If None, the default, all records will be read.
      compression_type: 'GZIP', 'ZLIB', '' (uncompressed), or None to
        autodetect based on file extension.

    Yields:
      proto.FromString() values on each record in path in sorted order.
    """
    if sharded_file_utils.is_sharded_file_spec(path):
        paths = sharded_file_utils.generate_sharded_filenames(path)
    else:
        paths = [path]

    keyed_iterables = []
    for path in paths:
        protos = Reader(path, proto,
                        compression_type=compression_type).iterate()
        keyed_iterables.append(((key(elem), elem) for elem in protos))

    for i, (_, value) in enumerate(heapq.merge(*keyed_iterables)):
        if max_records is not None and i >= max_records:
            return
        yield value
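The global ordering comes from heapq.merge, which lazily interleaves several iterables that are each already sorted. A small self-contained sketch of the same (key, value) tuple pattern, independent of TFRecords:

import heapq

# Each "shard" is already sorted by its key (here, an integer position).
shard_a = [(1, 'a1'), (5, 'a5'), (9, 'a9')]
shard_b = [(2, 'b2'), (3, 'b3'), (8, 'b8')]

# heapq.merge yields tuples in global key order without materializing
# everything in memory.
merged = [value for _, value in heapq.merge(shard_a, shard_b)]
assert merged == ['a1', 'b2', 'b3', 'a5', 'b8', 'a9']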
Example #4
def run():
  """Create pileup images from examples, filtered in various ways."""
  with errors.clean_commandline_error_exit():
    if FLAGS.column_labels:
      column_labels = FLAGS.column_labels.split(',')
    else:
      column_labels = None

    filter_to_vcf = FLAGS.vcf is not None
    if filter_to_vcf:
      ids_from_vcf = parse_vcf(FLAGS.vcf)
      logging.info(
          'Found %d loci in VCF. '
          'Only examples matching these loci will be output.',
          len(ids_from_vcf))

    filter_to_region = FLAGS.regions is not None
    if filter_to_region:
      passes_region_filter = create_region_filter(
          region_flag_string=FLAGS.regions, verbose=FLAGS.verbose)

    # Use nucleus.io.tfrecord to read all shards.
    dataset = tfrecord.read_tfrecords(FLAGS.examples)

    # Check flag here to avoid expensive string matching on every iteration.
    make_rgb = FLAGS.image_type in ['both', 'RGB']
    make_channels = FLAGS.image_type in ['both', 'channels']

    num_scanned = 0
    num_output = 0
    for example in dataset:
      num_scanned += 1
      # When scanning many examples, print a dot every UPDATE_EVERY_N_EXAMPLES
      # examples to show that the script is making progress and not stalled.
      if num_scanned % UPDATE_EVERY_N_EXAMPLES == 0:
        if num_scanned == UPDATE_EVERY_N_EXAMPLES:
          print('Reporting progress below. Writing one dot every time {} '
                'examples have been scanned:'.format(UPDATE_EVERY_N_EXAMPLES))
        # Print another dot on the same line, using print since logging does
        # not support printing without a newline.
        print('.', end='', flush=True)

      # Extract variant from example.
      variant = vis.variant_from_example(example)
      locus_id = vis.locus_id_from_variant(variant)
      indices = vis.alt_allele_indices_from_example(example)

      # Optionally filter to variants in the VCF.
      if filter_to_vcf:
        # Check if the locus is in the VCF.
        if locus_id not in ids_from_vcf:
          # Skip this example since it doesn't match the VCF.
          continue

      if filter_to_region and not passes_region_filter(variant):
        continue

      # Use locus ID in the filename, replacing long alleles with INS/DEL sizes.
      locus_with_alt_id = get_short_id(variant, indices)

      # Examples of long alleles replaced with their sizes:
      # 20:62456134_INS103bp.png
      # 20:62481177_DEL51bp.png

      # Examples of short alleles where the full string is included:
      # 1:55424995_TC->T.png
      # 1:55424996_CT->CTT.png
      # 1:55424996_CT->C.png
      # 1:55424996_CT->TTT.png
      # 1:55424996_CT->C|CTT.png

      if FLAGS.verbose:
        logging.info('\nOutputting image for: %s', locus_with_alt_id)
        full_id = get_full_id(variant, indices)
        if locus_with_alt_id != full_id:
          logging.info(
              'ID above was shortened due to long ref/alt strings. '
              'Original: %s', full_id)

      # If the example has a truth label, optionally include it.
      optional_truth_label = ''
      if FLAGS.truth_labels:
        truth_label = get_label(example)
        if truth_label is not None:
          optional_truth_label = '_label{}'.format(truth_label)

      # Extract and format example into channels.
      channels = vis.channels_from_example(example)
      if column_labels is not None and len(column_labels) != len(channels):
        raise ValueError(
            '--column_labels must have {} names separated by commas, since '
            'there are {} channels in the examples. '
            'However, {} column labels were found: {}'.format(
                len(channels), len(channels), len(column_labels),
                ','.join(['"{}"'.format(x) for x in column_labels])))

      output_prefix = '{}_'.format(
          FLAGS.output) if FLAGS.output is not None else ''

      # Create image with a grey-scale row of channels and save to file.
      if make_channels:
        channels_output = '{}channels_{}{}.png'.format(output_prefix,
                                                       locus_with_alt_id,
                                                       optional_truth_label)

        vis.draw_deepvariant_pileup(
            channels=channels,
            path=channels_output,
            scale=1,
            show=False,
            annotated=FLAGS.annotation,
            labels=column_labels)

      # Create RGB image and save to file.
      if make_rgb:
        rgb_output = '{}rgb_{}{}.png'.format(output_prefix, locus_with_alt_id,
                                             optional_truth_label)
        vis.draw_deepvariant_pileup(
            channels=channels,
            composite_type='RGB',
            path=rgb_output,
            scale=1,
            show=False,
            annotated=FLAGS.annotation,
            labels=column_labels)

      # Check if --num_records quota has been hit yet.
      num_output += 1
      if FLAGS.num_records != -1 and num_output >= FLAGS.num_records:
        break

    logging.info('Scanned %d examples and output %d images.', num_scanned,
                 num_output)

    if num_scanned == 0 and FLAGS.examples.startswith('gs://'):
      if sharded_file_utils.is_sharded_file_spec(FLAGS.examples):
        paths = sharded_file_utils.generate_sharded_filenames(FLAGS.examples)
        special_gcs_message = ('WARNING: --examples sharded files are either '
                               'all empty or do not exist. Please check that '
                               'the paths are correct:\n')
        for p in paths[0:3]:
          special_gcs_message += 'gsutil ls {}\n'.format(p)
        logging.warning(special_gcs_message)
      else:
        logging.warning(
            'WARNING: --examples file is either empty or does not exist. '
            'Please check that the path is correct: \n'
            'gsutil ls %s', FLAGS.examples)
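A condensed sketch of the core filtering loop in run(), stripped of flag handling and image output. It assumes the same vis and tfrecord helpers used by the script (the exact import paths are assumptions), and the example path and locus set are placeholders:

from nucleus.io import tfrecord
# 'vis' refers to the same visualization helper module used by run() above;
# the import path here is an assumption.
from deepvariant import vis

wanted_loci = {'20:62456134', '1:55424995'}  # hypothetical loci from a VCF

for example in tfrecord.read_tfrecords('examples.tfrecord@64.gz'):  # placeholder
    variant = vis.variant_from_example(example)
    locus_id = vis.locus_id_from_variant(variant)
    if locus_id not in wanted_loci:
        continue
    # ... draw the pileup image for this example, as run() does above.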
Example #5
def testGenerateShardedFilenamesManyShards(self):
    names = io.generate_sharded_filenames('/dir/foo/bar@100000')
    self.assertEqual(len(names), 100000)
    self.assertEqual(names[99999], '/dir/foo/bar-099999-of-100000')
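For reference, the expansion this test asserts: a spec of the form base@N produces N filenames of the form base-SSSSS-of-NNNNN, with shard indices zero-padded (at least five digits here, widened when N needs more, as with 100000). A plain re-implementation for illustration only; the real library may compute this differently:

def expand_spec(base, num_shards):
    # Pad shard indices to at least 5 digits, or wider if num_shards needs it.
    width = max(5, len(str(num_shards)))
    return ['{}-{}-of-{}'.format(base,
                                 str(i).zfill(width),
                                 str(num_shards).zfill(width))
            for i in range(num_shards)]

assert expand_spec('/dir/foo/bar', 100000)[99999] == '/dir/foo/bar-099999-of-100000'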
Example #6
def testGenerateShardedFilenames(self, spec, expected):
    names = io.generate_sharded_filenames(spec)
    self.assertEqual(names, expected)
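This appears to be a parameterized test; below is a hypothetical parameter pair it might receive, consistent with the padding convention asserted in Example #5 (not copied from the real test suite):

# Hypothetical parameterization for the test above.
spec = '/dir/foo/bar@3'
expected = [
    '/dir/foo/bar-00000-of-00003',
    '/dir/foo/bar-00001-of-00003',
    '/dir/foo/bar-00002-of-00003',
]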