Example #1
import time

import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging


def inference(reader, checkpoint_file, train_dir, data_pattern,
              out_file_location, batch_size, top_k):
  """Restores the latest checkpoint and writes top-k predictions to a CSV file."""
  # get_input_data_tensors and format_lines are helpers assumed to be defined
  # elsewhere in the same file.
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess, \
      gfile.Open(out_file_location, "w+") as out_file:
    video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
        reader, data_pattern, batch_size)
    if checkpoint_file:
      if not gfile.Exists(checkpoint_file + ".meta"):
        logging.fatal("Unable to find checkpoint file at provided location '%s'",
                      checkpoint_file)
      latest_checkpoint = checkpoint_file
    else:
      latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if latest_checkpoint is None:
      raise Exception("Unable to find a checkpoint at location: %s" % train_dir)
    meta_graph_location = latest_checkpoint + ".meta"
    logging.info("loading meta-graph: %s", meta_graph_location)
    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from %s", latest_checkpoint)
    saver.restore(sess, latest_checkpoint)
    # The training graph published these tensors through named collections.
    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for the num_epochs issue: set any "train_input" epoch counters
    # to 1 and initialize the remaining local variables normally.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(
        tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    out_file.write("VideoId,LabelConfidencePairs\n")

    try:
      while not coord.should_stop():
        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
            [video_id_batch, video_batch, num_frames_batch])
        predictions_val, = sess.run(
            [predictions_tensor],
            feed_dict={input_tensor: video_batch_val,
                       num_frames_tensor: num_frames_batch_val})
        now = time.time()
        num_examples_processed += len(video_batch_val)
        logging.info("num examples processed: %d elapsed seconds: %.2f",
                     num_examples_processed, now - start_time)
        for line in format_lines(video_id_batch_val, predictions_val, top_k):
          out_file.write(line)
        out_file.flush()
    except tf.errors.OutOfRangeError:
      logging.info("Done with inference. The output file was written to %s",
                   out_file_location)
    finally:
      coord.request_stop()

    coord.join(threads)
    sess.close()
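The core pattern in this example is the meta-graph round trip: save a graph whose tensors are published in named collections, then import_meta_graph and fetch them back by collection name. Below is a minimal, self-contained sketch of that round trip, assuming the TF1-style API is available via tf.compat.v1; the toy graph and paths are illustrative, not part of the example above.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Build a toy graph and publish its tensors through named collections.
x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
w = tf.Variable(tf.ones([3, 1]), name="w")
y = tf.matmul(x, w, name="y")
tf.add_to_collection("input_batch_raw", x)
tf.add_to_collection("predictions", y)

saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  ckpt = saver.save(sess, "/tmp/toy_model")

# In a fresh graph, restore both structure and weights, then fetch by name.
tf.reset_default_graph()
with tf.Session() as sess:
  restorer = tf.train.import_meta_graph(ckpt + ".meta", clear_devices=True)
  restorer.restore(sess, ckpt)
  x = tf.get_collection("input_batch_raw")[0]
  y = tf.get_collection("predictions")[0]
  print(sess.run(y, feed_dict={x: [[1.0, 2.0, 3.0]]}))  # -> [[6.]]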
Example #2
def check_status(self):
  """Checks that the status stored on disk matches the in-memory status."""
  status = io_util.load_text_proto(self.status_filename(),
                                   loop_pb2.LoopStatus, 'Status file')
  if (status.name != self.status.name or
      status.last_finished_round != self.status.last_finished_round or
      status.current_round != self.status.current_round or
      status.running_controller != self.running_controller):
    logging.fatal('Inconsistent status: the in-memory status does not match '
                  'the status stored on disk.')
Example #3
def make_dir(dir_name: str) -> str:
  """Creates dir_name if needed and returns it; fails if a file is in the way."""
  if gfile.Exists(dir_name):
    if gfile.IsDirectory(dir_name):
      return dir_name
    else:
      logging.fatal(
          'Trying to create directory "%s", but there '
          'is a file with the same name', dir_name)
  gfile.MakeDirs(dir_name)
  return dir_name
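A possible call site, assuming gfile here is tf.gfile as in the surrounding examples; the path is illustrative:

train_dir = make_dir("/tmp/experiment/train")  # creates the directory on first use
train_dir = make_dir("/tmp/experiment/train")  # second call returns it unchanged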
Example #4
    def historical_examples_pipeline(self, root, write_to_fresh=False):
        """Pipeline for generating historical examples.

        Args:
          root: The beam source to anchor the pipeline on.
          write_to_fresh: Boolean signifying whether the historical data
            should also be written to the fresh examples.
        """
        file_pattern = self.loop_meta.all_proof_logs_input_pattern()
        historical_dir = self.loop_meta.historical_examples_path()
        fresh_dir = self.loop_meta.fresh_examples_path()
        logging.info('Input proof logs file pattern:\n%s', file_pattern)
        # One read transform per comma-separated pattern that matches files.
        collections = [
            root | ('ReadAllProofLogs%d' % i) >> recordio.ReadFromRecordIO(
                pattern, coder=beam.coders.ProtoCoder(deephol_pb2.ProofLog))
            for i, pattern in enumerate(file_pattern.split(','))
            if gfile.Glob(pattern)
        ]
        if collections:
            logging.info('Historical prooflog collections: %d.',
                         len(collections))
            examples = (
                collections | 'FlattenInputProofLogs' >> beam.Flatten()
                | ('ConvertHistoricalToTFExamples' >> beam.ParDo(
                    ProofLogToTFExamplesDoFn(
                        str(self.config.prover_options.path_tactics),
                        self.theorem_db, self.config.convertor_options))))
            _ = examples | 'WriteHistoricalTFExamples' >> sstableio.WriteToSSTable(
                file_path_prefix=os.path.join(historical_dir,
                                              'train_examples'),
                key_coder=beam.coders.BytesCoder(),
                value_coder=beam.coders.BytesCoder(),
                num_shards=self.config.historical_examples_shards)
            if write_to_fresh:
                _ = examples | 'WriteFreshToHistoricalTFExamples' >> (
                    sstableio.WriteToSSTable(
                        file_path_prefix=os.path.join(fresh_dir,
                                                      'train_examples'),
                        key_coder=beam.coders.BytesCoder(),
                        value_coder=beam.coders.BytesCoder(),
                        num_shards=self.config.fresh_examples_shards))
        else:
            logging.fatal('There are no historical files to process.')
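The write fan-out above (one examples PCollection sent to two sinks, each write's result discarded via `_ =`) is plain Beam branching. Here is a minimal sketch of the same pattern using the public apache_beam API, with WriteToText standing in for the internal recordio/sstableio sinks; labels and paths are illustrative.

import apache_beam as beam

with beam.Pipeline() as p:
    examples = (
        p
        | 'CreateProofLogs' >> beam.Create(['log-a', 'log-b'])
        | 'ConvertToExamples' >> beam.Map(lambda s: s.upper()))
    # Each branch gets its own sink; assigning to _ discards the write
    # result, exactly as in the pipeline above.
    _ = examples | 'WriteHistorical' >> beam.io.WriteToText('/tmp/historical/train_examples')
    _ = examples | 'WriteFresh' >> beam.io.WriteToText('/tmp/fresh/train_examples')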
Example #5
def main(argv):
    runner.program_started()
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    assert FLAGS.root is not None, 'Required flag --root is missing.'
    config = io_util.load_text_proto(FLAGS.config, loop_pb2.LoopConfig)
    if (not FLAGS.rounds and not FLAGS.initial_examples
            and not config.inherited_proof_logs):
        logging.fatal('Loop setup requires either initial examples '
                      'or inherited proof logs')
    controller_fingerprint = loop_meta.create_fingerprint()
    meta = loop_meta.LoopMeta(FLAGS.root, config, controller_fingerprint,
                              False)
    assert meta.status, 'Could not read status'
    loop = loop_pipeline.LoopPipeline(meta, config)
    if not FLAGS.rounds:
        logging.info('Setting up loop...')
        loop.setup_examples(FLAGS.initial_examples)
    else:
        for _ in range(FLAGS.rounds):
            loop.perform_round(FLAGS.initial_examples)
Example #6
def main(unused_argv):
    logging.set_verbosity(tf.logging.INFO)
    # FLAGS.input_data_pattern is defined elsewhere in this file.
    paths = gfile.Glob(FLAGS.input_data_pattern)
    logging.info("Found %d files.", len(paths))
    for path in paths:
        # Records are binary, so open the file in binary mode.
        with gfile.Open(path, "rb") as f:
            first_read = True
            while True:
                length_raw = f.read(8)
                if not length_raw and first_read:
                    logging.fatal("File %s has no data.", path)
                    break
                elif not length_raw:
                    logging.info("File %s looks good.", path)
                    break
                else:
                    first_read = False
                if len(length_raw) != 8:
                    logging.fatal("File ends when reading record length: %s",
                                  path)
                    break
                # TFRecord lengths are little-endian unsigned 64-bit integers.
                length, = struct.unpack("<Q", length_raw)
                # +8 covers the 4-byte length CRC and the 4-byte data CRC.
                record = f.read(length + 8)
                if len(record) != length + 8:
                    logging.fatal("File ends in the middle of a record: %s",
                                  path)
                    break
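For reference, the framing this checker walks is the TFRecord layout: an 8-byte little-endian record length, a 4-byte CRC of the length, the payload, and a 4-byte CRC of the payload. The following self-contained sketch builds two such frames in memory (with zeroed placeholder CRCs, since the checker above validates lengths only) and parses them with the same arithmetic.

import struct


def frame_record(payload: bytes) -> bytes:
    """Wraps a payload in TFRecord-style framing with dummy CRC fields."""
    length_raw = struct.pack("<Q", len(payload))
    dummy_crc = struct.pack("<I", 0)  # real files use masked crc32c here
    return length_raw + dummy_crc + payload + dummy_crc


blob = frame_record(b"hello") + frame_record(b"world")
offset = 0
while offset < len(blob):
    length, = struct.unpack("<Q", blob[offset:offset + 8])
    offset += 8
    record = blob[offset:offset + length + 8]  # +8 covers both CRC fields
    offset += length + 8
    print(length, record[4:4 + length])  # -> 5 b'hello', then 5 b'world'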