def main(FLAGS):
    """Train one classifier per digit position, then relabel all word images.

    Loads labeled images from the input label database, splits them into
    training and test sets, trains and reports train/test accuracy for one
    classifier per digit position, and then classifies every word image,
    writing the results into the output label database.

    Args:
      FLAGS: parsed command-line flags. Reads input_label_database,
        output_label_database, minimum_label_count, max_0000,
        train_data_fraction and mask_digits.

    Raises:
      ValueError: if the input and output databases are the same file, or if
        FLAGS.minimum_label_count is less than 1.
    """
    if FLAGS.input_label_database == FLAGS.output_label_database:
        raise ValueError(
            "Input and output label databases can't be the same file.")
    # The message must describe the check actually performed (count >= 1).
    if FLAGS.minimum_label_count < 1:
        raise ValueError(
            'The value of --minimum-label-count must be at least 1.')

    # Create new output label database if it doesn't exist yet; the header row
    # gives it the same CSV columns as any existing database.
    if not pathlib.Path(FLAGS.output_label_database).exists():
        with open(FLAGS.output_label_database, 'w') as f:
            f.write('"Filename","Label","Count"\n')

    # Open label databases.
    print('Opening input label database...')
    with label_database.Database(FLAGS.input_label_database,
                                 readonly=True) as db_in:
        print('Opening output label database...')
        with label_database.Database(FLAGS.output_label_database,
                                     save_backups=False) as db_out:

            # Load labeled images and per-digit labels.
            print('Loading labeled images; arranging test/train data...')
            all_data = load_data(db_in, FLAGS.minimum_label_count,
                                 FLAGS.max_0000)

            # Divide into training and test data.
            train_data, test_data = divide_data(all_data,
                                                FLAGS.train_data_fraction)
            print('   ...loaded', len(train_data), 'data points for training,',
                  len(test_data), 'for testing.')

            # Train one classifier per digit position; digits are numbered
            # from 1 in the data sets but from 0 for the masking helper.
            classifiers = []
            for d in range(1, train_data.num_digits() + 1):
                images_train = train_data.images
                images_test = test_data.images
                if FLAGS.mask_digits:
                    # Preprocess the images for this specific digit position
                    # (see mask_nth_digit_in_images for the exact masking).
                    print('Preprocessing data for digit {}...'.format(d))
                    images_train = mask_nth_digit_in_images(
                        images_train, d - 1)
                    images_test = mask_nth_digit_in_images(images_test, d - 1)

                print('Training classifier for digit {}...'.format(d))
                cfier = train_classifier(images_train, train_data[d],
                                         images_test, test_data[d])
                print('        Training set accuracy:',
                      test_classifier(cfier, images_train, train_data[d]))
                print('            Test set accuracy:',
                      test_classifier(cfier, images_test, test_data[d]))
                classifiers.append(cfier)

            # Now classify all of the data.
            print('Classifying all word images...')
            classify_everything(db_in, db_out, classifiers, FLAGS.mask_digits)

            # All done! Presumably the output database is written out when
            # its `with` block exits -- confirm in label_database.Database.
            print('Saving output label database...')
def main(FLAGS):
  """Interactively resolve word images that label databases disagree on.

  Opens FLAGS.label_database for modification and every database listed in
  FLAGS.databases_to_compare read-only. Finds word images the databases have
  labeled differently via find_ambiguous_word_images, then passes them, with
  FLAGS.screen_image_subdir, to ui_for_resolving_ambiguous_word_images.

  Every database is opened through its context manager, so all of them are
  released when this function returns or raises. This includes the comparison
  databases, which were previously never closed.

  Raises:
    ValueError: if FLAGS.label_database is also listed among the databases to
      compare, or if more than nine comparison databases are listed.
  """
  import contextlib  # Only needed here, to manage the comparison databases.

  if FLAGS.label_database in FLAGS.databases_to_compare:
    raise ValueError(
        'The label database to modify should not also be listed as one of the '
        'additional databases under comparison.')
  # Nine comparison databases plus the one being modified is the ten-database
  # maximum mentioned in the message.
  if len(FLAGS.databases_to_compare) > 9:
    raise ValueError(
        'Too many databases are listed. This program supports ten databases max.')

  with contextlib.ExitStack() as stack:
    # Register each comparison database as it opens. That way all of them are
    # exited at the end, including an already-opened subset if a later one
    # fails to open.
    print('Loading label databases to compare...')
    compare_dbs = [
        stack.enter_context(label_database.Database(dbfile, readonly=True))
        for dbfile in FLAGS.databases_to_compare]

    print('Opening label database to modify...')
    with label_database.Database(FLAGS.label_database) as db:
      # Identify word images that different databases have labeled differently.
      print('Looking for ambiguous images (takes a bit)...')
      ambiguous_word_images = find_ambiguous_word_images(db, compare_dbs)

      # OK, go to town!
      ui_for_resolving_ambiguous_word_images(
          db, compare_dbs, FLAGS.screen_image_subdir, ambiguous_word_images)
# Example #3
# 0
def main(FLAGS):
  """Print the filenames whose labels differ between two label databases.

  Every filename labeled with count >= FLAGS.minimum_label_count in
  FLAGS.label_database_1 is looked up in FLAGS.label_database_2. A difference
  is printed to stdout as ``<filename>   <label1> <> <label2>`` unless:
    * the second database's count is below the minimum,
    * FLAGS.skip_XXXX is set and either label is 'XXXX',
    * the two labels agree, or
    * the filename is labeled 'XXXX' (at any count) in the optional
      FLAGS.skip_XXXX_from database.
  Progress messages go to stderr.
  """
  # Filenames whose labels we ignore: those marked 'XXXX' in the optional
  # skip database.
  ignorables = set()
  skip_db_path = FLAGS.skip_XXXX_from
  if skip_db_path:
    sys.stderr.write('Opening {}...\n'.format(skip_db_path))
    with label_database.Database(skip_db_path, readonly=True) as skip_db:
      for filename, lbl in skip_db.all_labels_with_counts_of_at_least(0):
        if lbl == 'XXXX':
          ignorables.add(filename)

  first_path = FLAGS.label_database_1
  sys.stderr.write('Opening {}...\n'.format(first_path))
  with label_database.Database(first_path, readonly=True) as first_db:
    second_path = FLAGS.label_database_2
    sys.stderr.write('Opening {}...\n'.format(second_path))
    with label_database.Database(second_path, readonly=True) as second_db:
      min_count = FLAGS.minimum_label_count
      candidates = first_db.all_labels_with_counts_of_at_least(min_count)
      for filename, first_label in candidates:
        second_label, second_count = second_db[filename]
        # Same checks, in the same short-circuit order, as a chain of guards.
        uninteresting = (
            second_count < min_count
            or (FLAGS.skip_XXXX and 'XXXX' in (first_label, second_label))
            or first_label == second_label
            or filename in ignorables)
        if not uninteresting:
          print('{}   {} <> {}'.format(filename, first_label, second_label))
# Example #4
# 0
def main(FLAGS):
  """Run the interactive labeling loop against FLAGS.label_database.

  If FLAGS.mark_apl_ros_c000_zeros is set, first calls
  mark_apl_ros_c000_zeros on the database. Then repeatedly asks
  next_image_and_housekeeping for the next image, passing a running
  iteration count, and shows it via quiz_user_for_label. An empty answer
  skips the image, 'Q' quits, and anything else is stored with db.label.
  Returns when no image is left or the user quits.
  """
  print('Loading...')
  with label_database.Database(FLAGS.label_database) as labels_db:
    if FLAGS.mark_apl_ros_c000_zeros:
      print('Marking APL ROS known-zeros at C000...')
      mark_apl_ros_c000_zeros(labels_db)

    for round_index in itertools.count():
      image_file, word_image = next_image_and_housekeeping(
          labels_db, FLAGS.num_labels, FLAGS.label_bias, FLAGS.scale,
          round_index)

      # A None filename means there is nothing left to label.
      if image_file is None:
        print('You are finished! Thank you for your hard work!')
        return

      answer = quiz_user_for_label(word_image)
      if not answer:
        print('Skipping this image.')
        continue
      if answer == 'Q':
        print('Quitting...')
        return
      labels_db.label(image_file, answer)
# Example #5
# 0
def main(FLAGS):
    """Load trace data and screen images, sanity-check paths, start the UI.

    The trace CSV is read and passed through deltaify_traces,
    filter_silly_bytes and expand_long_bytes. Screen image paths are listed
    (sorted) from FLAGS.screen_image_path, and FLAGS.label_database is opened.

    Before starting ui(), a random word-image path from the database is
    compared with the first screen image path. They must differ only in the
    innermost directory name, with the word-image file named after the screen
    image plus a '_1_1' suffix.

    Raises:
      ValueError: if the screen image directory is empty, or if the label
        database's paths and the screen image paths differ in more ways than
        the innermost directory name. Both checks are explicit raises rather
        than `assert`, so they still run under `python -O`.
    """
    print('Loading trace data...')
    traces = read_traces_csv(FLAGS.traces)
    traces = deltaify_traces(traces)
    traces = filter_silly_bytes(traces, FLAGS.min_byte_duration)
    traces = expand_long_bytes(traces, FLAGS.split_bytes_longer_than,
                               FLAGS.divide_long_bytes_by,
                               FLAGS.max_long_byte_splits)

    print('Listing screen image files...')
    images = sorted(
        os.path.join(FLAGS.screen_image_path, f)
        for f in os.listdir(FLAGS.screen_image_path))
    # The path check below needs at least one screen image; fail clearly
    # instead of with an IndexError on images[0].
    if not images:
        raise ValueError(
            'No screen image files found in {}. Giving up...'.format(
                FLAGS.screen_image_path))

    print('Opening label database...')
    with label_database.Database(FLAGS.label_database) as db:
        # Determine how the paths of word images and screen images differ.
        # Verify that they differ in only the right way.
        wordfile = db.random()
        wf_parts = wordfile.split(os.sep)
        wf_innermost_dir = wf_parts[-2]

        screenfile = images[0]
        sf_parts = screenfile.split(os.sep)
        sf_root, sf_ext = os.path.splitext(sf_parts[-1])

        # Synthesize the word-image path the database should contain for the
        # first screen image: swap in the word images' innermost directory
        # and append the '_1_1' suffix to the screen image's file stem.
        made_up_wordfile = os.sep.join(
            sf_parts[:-2] + [wf_innermost_dir, sf_root + '_1_1' + sf_ext])
        if made_up_wordfile not in db:
            raise ValueError(
                'Label database file paths and the screen image path appear to differ '
                'in more ways than the innermost directory name. Parts that should '
                'have been more common---labels: {}, images: {}. Giving up...'
                ''.format(wf_parts, sf_parts))

        # Start up the user interface.
        ui(db, traces, images, wf_innermost_dir)
# Example #6
# 0
def main(FLAGS):
  """Tabulate, per memory address, the labels each label database assigned.

  Databases are numbered 0 (FLAGS.ground_truth_database) and 1, 2, ... (the
  entries of FLAGS.label_databases, in order). Word-image filenames are
  gathered from every database and numbered in sorted order. A filename
  qualifies if its label count is at least 2 and it contains one of the
  comma-separated FLAGS.image_substrings.

  Each filename is treated as a stem plus a seven-character "<line>_<pos>.png"
  suffix, for lines 0 and 1. The image at position 0 of a line holds a hex
  address; the images at positions 1-8 get the addresses base, base + 2, ...,
  base + 14. A stem is skipped for a database if either address image:
    * has a count below 2,
    * is not valid hex (e.g. 'XXXX'), or
    * gives line addresses that are not exactly 0x10 apart.
  Stems whose address images are 'XXXX' in the ground truth are dropped
  entirely.

  Prints to stdout the numbered database list, the numbered image list, and
  one line per address, ``ADDR: LABEL@db,img/db,img ...``. Progress and
  informational summaries (address range, gaps, label-count histogram) go to
  stderr.

  Returns:
    A dict mapping zero-padded uppercase hex address strings to dicts mapping
    each label to a list of (database index, image index) locator pairs.
  """
  sys.stderr.write('Loading...\n')
  # Only the ground-truth database stays loaded throughout. Holding every
  # database open at once used too much RAM, so each of the others is loaded,
  # used, and released in turn (twice), trading time for space.
  db_truth = label_database.Database(FLAGS.ground_truth_database, readonly=True)
  all_dbfiles = [FLAGS.ground_truth_database] + FLAGS.label_databases

  def open_db(dbfile):
    # Reuse the already-loaded ground truth instead of loading it again.
    if dbfile == FLAGS.ground_truth_database:
      return db_truth
    return label_database.Database(dbfile, readonly=True)

  sys.stderr.write('Preliminaries..')
  # Print out the databases we've loaded; index 0 is always the ground truth.
  print('======== Label databases:')
  print('0:', FLAGS.ground_truth_database)
  for i, dbfile in enumerate(FLAGS.label_databases, start=1):
    print('{}:'.format(i), dbfile)

  # Gather the filenames of images labeled (count >= 2) in any database whose
  # names contain at least one of the filtering substrings.
  filtering_substrings = FLAGS.image_substrings.split(',')
  image_filenames = set()
  for dbfile in all_dbfiles:
    db = open_db(dbfile)
    sys.stderr.write('.')
    sys.stderr.flush()
    image_filenames.update(
        fn for fn, _ in db.all_labels_with_counts_of_at_least(2)
        if any(substring in fn for substring in filtering_substrings))
    # Drop our reference before loading the next database; db_truth itself
    # stays alive through its own name.
    del db
  # Number the images in sorted order; these numbers are the image indices
  # that appear in the printed locators.
  image_filenames = collections.OrderedDict(
      (imfile, i) for i, imfile in enumerate(sorted(image_filenames)))
  print('======== Images:')
  for imfile, i in image_filenames.items():
    print('{}:'.format(i), imfile)

  # Isolate image filename stems; that is, everything prior to the 1_3.png part
  # at the end of the filename. We could do it the right way and parse the
  # filenames, but easier just to lop off the last seven characters.
  image_filename_stems = {imfile[:-7] for imfile in image_filenames}

  # Filter out all stems where either of the leftmost columns has the label
  # 'XXXX' in the ground truth database.
  filtered_image_filename_stems = set()
  for stem in image_filename_stems:
    l1, c1 = db_truth['{}0_0.png'.format(stem)]
    l2, c2 = db_truth['{}1_0.png'.format(stem)]
    if c1 >= 2 and l1 == 'XXXX': continue
    if c2 >= 2 and l2 == 'XXXX': continue
    filtered_image_filename_stems.add(stem)
  image_filename_stems = filtered_image_filename_stems
  del filtered_image_filename_stems  # No longer used.

  sys.stderr.write('\n')

  sys.stderr.write('Assembly..')
  # Address string -> label -> list of (database index, image index) pairs.
  result = collections.defaultdict(lambda: collections.defaultdict(list))

  for db_index, dbfile in enumerate(all_dbfiles):
    # 0. Load the database.
    db = open_db(dbfile)
    sys.stderr.write('.')
    sys.stderr.flush()

    for stem in image_filename_stems:
      # 1. Collect memory addresses associated with this video frame by the
      #    current database. There should be a gap of 0x10 between both.
      l1, c1 = db['{}0_0.png'.format(stem)]
      l2, c2 = db['{}1_0.png'.format(stem)]
      if c1 < 2 or c2 < 2: continue  # Address words haven't been parsed.
      # Only the ground truth was screened for 'XXXX' address labels; any
      # database may still hold a non-hex label here. Skip the frame for this
      # database rather than aborting the whole run.
      try:
        base1, base2 = int(l1, 16), int(l2, 16)
      except ValueError:
        continue
      if base2 - base1 != 0x10: continue  # Gap is not 0x10.

      # 2. For each memory address shown in this video frame, record labels
      #    associated with that address by the current database. Positions
      #    1-8 of each line sit at 2-byte steps from that line's address.
      for line_prefix, base in (('0', base1), ('1', base2)):
        for i in range(8):
          imfile = '{}{}_{}.png'.format(stem, line_prefix, i + 1)
          label, count = db[imfile]
          if count >= 2:  # Only if this image is parsed in this database.
            result['{:04X}'.format(base + 2 * i)][label].append(
                (db_index, image_filenames[imfile]))

    # 3. Help the garbage collector.
    del db

  sys.stderr.write('\n')

  # Print out the result.
  print('======== Labels:')
  for address in sorted(result):
    line = '{}:'.format(address)
    for label, locators in sorted(result[address].items()):
      line += ' {}@{}'.format(label, '/'.join(
          '{},{}'.format(dbi, imi) for dbi, imi in sorted(locators)))
    print(line)

  # The summaries below call min()/max() on the result, which would raise
  # ValueError when nothing was recovered.
  if not result:
    sys.stderr.write('WARNING: no address data was recovered\n')
    return result

  # Just for the user's information, print out whether we have a contiguous
  # range of addresses. Sorting the strings orders them numerically because
  # they are zero-padded to the same width.
  prev_address = int(min(result), 16)
  sys.stderr.write('INFO: first address {:04x}\n'.format(prev_address))
  for address in sorted(result):
    address_value = int(address, 16)
    if address_value - prev_address > 0x10:
      for gone_address in range(prev_address + 0x10, address_value, 0x10):
        sys.stderr.write('WARNING: no data for {:04X}\n'.format(gone_address))
    prev_address = address_value
  sys.stderr.write('INFO: last address {:04X}\n'.format(prev_address))

  # Just for the user's information, print out a histogram of the number of
  # addresses that have K labels associated with them.
  histogram = collections.Counter(
      len(labelings) for labelings in result.values())
  sys.stderr.write(
      'INFO: {} addresses with unanimous labeling\n'.format(histogram[1]))
  for i in range(2, max(histogram) + 1):
    sys.stderr.write(
        'INFO: {} addresses with {} labels\n'.format(histogram[i], i))

  return result