Beispiel #1
0
 def test_draw_deepvariant_pileup_with_example_input(self, composite_type):
     """Smoke test: drawing a pileup from a mocked example must not raise."""
     # The mock helper returns a pair; only its second item is used here.
     unused_first, mocked_example = _mock_example_with_image((100, 10, 7))
     # No assertions on the result: completing without an exception is the
     # whole check.
     vis.draw_deepvariant_pileup(
         example=mocked_example, composite_type=composite_type)
Beispiel #2
0
 def test_draw_deepvariant_pileup_with_channels_input(self, composite_type):
     """Smoke test: drawing a pileup from raw channels must not raise."""
     # Build six image arrays, each created from the same (100, 221) argument.
     channel_images = []
     for _ in range(6):
         channel_images.append(_image_array((100, 221)))
     # No assertions on the result: completing without an exception is the
     # whole check.
     vis.draw_deepvariant_pileup(
         channels=channel_images, composite_type=composite_type)
def run():
  """Create pileup images from examples, filtered in various ways.

  Reads examples from --examples and, for each one that survives the optional
  --vcf and --regions filters, writes a channels image and/or an RGB image
  (selected by --image_type) named after the example's locus and alt alleles.
  Scanning stops early once --num_records examples have been output; a value
  of -1 disables that limit.

  Flags read: column_labels, vcf, regions, verbose, examples, image_type,
  truth_labels, output, annotation, num_records.

  Raises:
    ValueError: If --column_labels is given but the number of comma-separated
      names differs from the number of channels in an example. This is raised
      inside errors.clean_commandline_error_exit(); how that context manager
      handles the exception is not visible here.
  """
  with errors.clean_commandline_error_exit():
    if FLAGS.column_labels:
      column_labels = FLAGS.column_labels.split(',')
    else:
      column_labels = None

    filter_to_vcf = FLAGS.vcf is not None
    if filter_to_vcf:
      ids_from_vcf = parse_vcf(FLAGS.vcf)
      logging.info(
          'Found %d loci in VCF. '
          'Only examples matching these loci will be output.',
          len(ids_from_vcf))

    filter_to_region = FLAGS.regions is not None
    if filter_to_region:
      passes_region_filter = create_region_filter(
          region_flag_string=FLAGS.regions, verbose=FLAGS.verbose)

    # Use nucleus.io.tfrecord to read all shards.
    dataset = tfrecord.read_tfrecords(FLAGS.examples)

    # Check flag here to avoid expensive string matching on every iteration.
    make_rgb = FLAGS.image_type in ['both', 'RGB']
    make_channels = FLAGS.image_type in ['both', 'channels']

    # The output prefix does not depend on the example, so build it once
    # instead of once per loop iteration.
    output_prefix = '{}_'.format(
        FLAGS.output) if FLAGS.output is not None else ''

    num_scanned = 0
    num_output = 0
    for example in dataset:
      num_scanned += 1
      # Only when scanning many examples, print a dot for each one to
      # indicate that the script is making progress and not stalled.
      if num_scanned % UPDATE_EVERY_N_EXAMPLES == 0:
        if num_scanned == UPDATE_EVERY_N_EXAMPLES:
          print('Reporting progress below. Writing one dot every time {} '
                'examples have been scanned:'.format(UPDATE_EVERY_N_EXAMPLES))
        # Print another dot on the same line, using print since logging does
        # not support printing without a newline.
        print('.', end='', flush=True)

      # Extract variant from example.
      variant = vis.variant_from_example(example)

      # Optionally filter to variants in the VCF. The locus ID is only needed
      # for this check, so it is computed only when the filter is active.
      if filter_to_vcf:
        locus_id = vis.locus_id_from_variant(variant)
        if locus_id not in ids_from_vcf:
          # Skip this example since it doesn't match the VCF.
          continue

      if filter_to_region and not passes_region_filter(variant):
        continue

      # Alt allele indices are only used for naming, so they are extracted
      # after filtering to avoid wasted work on skipped examples.
      indices = vis.alt_allele_indices_from_example(example)

      # Use locus ID in the filename, replacing long alleles with INS/DEL sizes.
      locus_with_alt_id = get_short_id(variant, indices)

      # Examples of long alleles replaced with their sizes:
      # 20:62456134_INS103bp.png
      # 20:62481177_DEL51bp.png

      # Examples of short alleles where the full string is included:
      # 1:55424995_TC->T.png
      # 1:55424996_CT->CTT.png
      # 1:55424996_CT->C.png
      # 1:55424996_CT->TTT.png
      # 1:55424996_CT->C|CTT.png

      if FLAGS.verbose:
        logging.info('\nOutputting image for: %s', locus_with_alt_id)
        full_id = get_full_id(variant, indices)
        if locus_with_alt_id != full_id:
          logging.info(
              'ID above was shortened due to long ref/alt strings. '
              'Original: %s', full_id)

      # If the example has a truth label, optionally include it.
      optional_truth_label = ''
      if FLAGS.truth_labels:
        truth_label = get_label(example)
        if truth_label is not None:
          optional_truth_label = '_label{}'.format(truth_label)

      # Extract and format example into channels.
      channels = vis.channels_from_example(example)
      # Validated per example because the channel count is only known once an
      # example has been decoded.
      if column_labels is not None and len(column_labels) != len(channels):
        raise ValueError(
            '--column_labels must have {} names separated by commas, since '
            'there are {} channels in the examples. '
            'However, {} column labels were found: {}'.format(
                len(channels), len(channels), len(column_labels),
                ','.join(['"{}"'.format(x) for x in column_labels])))

      # Create image with a grey-scale row of channels and save to file.
      if make_channels:
        channels_output = '{}channels_{}{}.png'.format(output_prefix,
                                                       locus_with_alt_id,
                                                       optional_truth_label)

        vis.draw_deepvariant_pileup(
            channels=channels,
            path=channels_output,
            scale=1,
            show=False,
            annotated=FLAGS.annotation,
            labels=column_labels)

      # Create RGB image and save to file.
      if make_rgb:
        rgb_output = '{}rgb_{}{}.png'.format(output_prefix, locus_with_alt_id,
                                             optional_truth_label)
        vis.draw_deepvariant_pileup(
            channels=channels,
            composite_type='RGB',
            path=rgb_output,
            scale=1,
            show=False,
            annotated=FLAGS.annotation,
            labels=column_labels)

      # Check if --num_records quota has been hit yet.
      # NOTE(review): the check runs after an example has been output, so
      # --num_records=0 still produces one example's images; confirm whether
      # that is intended.
      num_output += 1
      if FLAGS.num_records != -1 and num_output >= FLAGS.num_records:
        break

    # num_output counts examples, not files: with --image_type=both each
    # counted example produced two image files.
    logging.info('Scanned %d examples and output %d images.', num_scanned,
                 num_output)

    # Nothing was read from a GCS path: give the user commands to verify that
    # the input files actually exist.
    if num_scanned == 0 and FLAGS.examples.startswith('gs://'):
      if sharded_file_utils.is_sharded_file_spec(FLAGS.examples):
        paths = sharded_file_utils.generate_sharded_filenames(FLAGS.examples)
        # Only the first three shard paths are listed to keep the hint short.
        special_gcs_message = (
            'WARNING: --examples sharded files are either '
            'all empty or do not exist. Please check that '
            'the paths are correct:\n' +
            ''.join('gsutil ls {}\n'.format(p) for p in paths[0:3]))
        logging.warning(special_gcs_message)
      else:
        logging.warning(
            'WARNING: --examples file is either empty or does not exist. '
            'Please check that the path is correct: \n'
            'gsutil ls %s', FLAGS.examples)