def inference(reader, checkpoint_file, train_dir, data_pattern,
              out_file_location, batch_size, top_k):
  """Restores the latest checkpoint and writes top-k predictions as CSV."""
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess, \
      gfile.Open(out_file_location, "w+") as out_file:
    video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
        reader, data_pattern, batch_size)
    if checkpoint_file:
      if not gfile.Exists(checkpoint_file + ".meta"):
        logging.fatal("Unable to find checkpoint file at provided location "
                      "'%s'", checkpoint_file)
      latest_checkpoint = checkpoint_file
    else:
      latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if latest_checkpoint is None:
      raise Exception("unable to find a checkpoint at location: %s" % train_dir)
    meta_graph_location = latest_checkpoint + ".meta"
    logging.info("loading meta-graph: " + meta_graph_location)
    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from " + latest_checkpoint)
    saver.restore(sess, latest_checkpoint)
    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for the num_epochs issue: reset the epoch counters among the
    # input pipeline's local variables so the queue runners read the data
    # exactly once, and initialize the remaining local variables normally.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(
        tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    out_file.write("VideoId,LabelConfidencePairs\n")

    try:
      while not coord.should_stop():
        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
            [video_id_batch, video_batch, num_frames_batch])
        predictions_val, = sess.run(
            [predictions_tensor],
            feed_dict={input_tensor: video_batch_val,
                       num_frames_tensor: num_frames_batch_val})
        now = time.time()
        num_examples_processed += len(video_batch_val)
        logging.info("num examples processed: %d elapsed seconds: %.2f",
                     num_examples_processed, now - start_time)
        for line in format_lines(video_id_batch_val, predictions_val, top_k):
          out_file.write(line)
        out_file.flush()
    except tf.errors.OutOfRangeError:
      logging.info("Done with inference. The output file was written to " +
                   out_file_location)
    finally:
      coord.request_stop()
      coord.join(threads)
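
# A minimal driver sketch for inference() above. The reader class, flag
# values, and paths are illustrative assumptions, not taken from the source.
def example_inference_run():
  reader = readers.YT8MFrameFeatureReader(  # hypothetical reader setup
      feature_names=["rgb"], feature_sizes=[1024])
  inference(reader,
            checkpoint_file=None,  # fall back to the latest in train_dir
            train_dir="/tmp/yt8m_model",
            data_pattern="/tmp/yt8m/test*.tfrecord",
            out_file_location="/tmp/predictions.csv",
            batch_size=512,
            top_k=20)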
def check_status(self):
  """Verifies that the status stored on disk matches the in-memory status."""
  status = io_util.load_text_proto(self.status_filename(), loop_pb2.LoopStatus,
                                   'Status file')
  # Note: the last comparison read `self.running_controller` in the original,
  # which breaks the pattern of the other three checks; compare the stored
  # status proto field instead.
  if (status.name != self.status.name or
      status.last_finished_round != self.status.last_finished_round or
      status.current_round != self.status.current_round or
      status.running_controller != self.status.running_controller):
    logging.fatal('Inconsistent status between in-memory status and the '
                  'status file on disk.')
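
# The internal io_util.load_text_proto helper is not shown here; a minimal
# stand-in using the public protobuf text-format API might look like this
# (the real helper's exact signature and error handling are assumptions).
from google.protobuf import text_format

def load_text_proto(filename, proto_class, description='proto'):
  # Parse a text-format proto of the given class from a file.
  message = proto_class()
  with gfile.Open(filename, 'r') as f:
    text_format.Parse(f.read(), message)
  return message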
def make_dir(dir_name: str) -> str:
  if gfile.Exists(dir_name):
    if gfile.IsDirectory(dir_name):
      return dir_name
    else:
      logging.fatal('Trying to create directory "%s", but there '
                    'is a file with the same name', dir_name)
  gfile.MakeDirs(dir_name)
  return dir_name
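
# Hypothetical usage: ensure a round's output directory exists before writing.
examples_dir = make_dir('/tmp/loop/round_0007/examples')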
def historical_examples_pipeline(self, root, write_to_fresh=False):
  """Pipeline for generating historical examples.

  Args:
    root: The beam source to anchor the pipeline on.
    write_to_fresh: Boolean signifying whether the historical data should
      also be written to the fresh examples.
  """
  file_pattern = self.loop_meta.all_proof_logs_input_pattern()
  historical_dir = self.loop_meta.historical_examples_path()
  fresh_dir = self.loop_meta.fresh_examples_path()
  logging.info('Input proof logs file pattern:\n%s', file_pattern)
  collections = [
      root | ('ReadAllProofLogs%d' % i) >> recordio.ReadFromRecordIO(
          pattern, coder=beam.coders.ProtoCoder(deephol_pb2.ProofLog))
      for i, pattern in enumerate(file_pattern.split(','))
      if gfile.Glob(pattern)
  ]
  if collections:
    logging.info('Historical prooflog collections: %d.', len(collections))
    examples = (
        collections
        | 'FlattenInputProofLogs' >> beam.Flatten()
        | 'ConvertHistoricalToTFExamples' >> beam.ParDo(
            ProofLogToTFExamplesDoFn(
                str(self.config.prover_options.path_tactics), self.theorem_db,
                self.config.convertor_options)))
    _ = examples | 'WriteHistoricalTFExamples' >> sstableio.WriteToSSTable(
        file_path_prefix=os.path.join(historical_dir, 'train_examples'),
        key_coder=beam.coders.BytesCoder(),
        value_coder=beam.coders.BytesCoder(),
        num_shards=self.config.historical_examples_shards)
    if write_to_fresh:
      _ = examples | 'WriteFreshToHistoricalTFExamples' >> (
          sstableio.WriteToSSTable(
              file_path_prefix=os.path.join(fresh_dir, 'train_examples'),
              key_coder=beam.coders.BytesCoder(),
              value_coder=beam.coders.BytesCoder(),
              num_shards=self.config.fresh_examples_shards))
  else:
    logging.fatal('There are no historical files to process.')
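
# recordio and sstableio above are internal Beam IO connectors. A rough
# open-source sketch of the same read-convert-write shape, using
# apache_beam's public TFRecord IO as a stand-in (all names below are
# assumptions, not the project's actual API):
import apache_beam as beam
from apache_beam.io import tfrecordio

def convert_record(raw_bytes):
  # Hypothetical conversion step standing in for ProofLogToTFExamplesDoFn.
  return raw_bytes

def run_example_pipeline(input_pattern, output_prefix):
  with beam.Pipeline() as p:
    _ = (p
         | 'ReadLogs' >> tfrecordio.ReadFromTFRecord(input_pattern)
         | 'Convert' >> beam.Map(convert_record)
         | 'WriteExamples' >> tfrecordio.WriteToTFRecord(output_prefix))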
def main(argv):
  runner.program_started()
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  assert FLAGS.root is not None, 'Required flag --root is missing.'
  config = io_util.load_text_proto(FLAGS.config, loop_pb2.LoopConfig)
  if (not FLAGS.rounds and not FLAGS.initial_examples and
      not config.inherited_proof_logs):
    logging.fatal('Loop setup requires either initial examples '
                  'or inherited proof logs.')
  controller_fingerprint = loop_meta.create_fingerprint()
  meta = loop_meta.LoopMeta(FLAGS.root, config, controller_fingerprint, False)
  assert meta.status, 'Could not read status'
  loop = loop_pipeline.LoopPipeline(meta, config)
  if not FLAGS.rounds:
    logging.info('Setting up loop...')
    loop.setup_examples(FLAGS.initial_examples)
  else:
    for _ in xrange(FLAGS.rounds):
      loop.perform_round(FLAGS.initial_examples)
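
# The usual absl entry point for the main() above, plus a hypothetical
# invocation (binary name, paths, and round count are placeholders; the
# flags themselves match those referenced in main()):
#
#   python loop_main.py --root=/tmp/loop --config=/tmp/loop_config.pbtxt \
#       --initial_examples=/tmp/seed_examples --rounds=5
if __name__ == '__main__':
  app.run(main)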
def main(unused_argv):
  logging.set_verbosity(tf.logging.INFO)
  paths = gfile.Glob(FLAGS.input_data_pattern)
  logging.info("Found %d files.", len(paths))
  for path in paths:
    # Records must be read as raw bytes; text mode ("r") would break the
    # framing under Python 3.
    with gfile.Open(path, "rb") as f:
      first_read = True
      while True:
        length_raw = f.read(8)
        if not length_raw and first_read:
          logging.fatal("File %s has no data.", path)
          break
        elif not length_raw:
          logging.info("File %s looks good.", path)
          break
        else:
          first_read = False
          if len(length_raw) != 8:
            logging.fatal("File %s ends when reading record length.", path)
            break
          # TFRecord framing: a little-endian uint64 length ("<Q"; the
          # original native "L" format is not portable across platforms),
          # a 4-byte CRC of the length, the data, then a 4-byte CRC of the
          # data.
          length, = struct.unpack("<Q", length_raw)
          # +8 to include the two 4-byte CRC values.
          record = f.read(length + 8)
          if len(record) != length + 8:
            logging.fatal("File %s ends in the middle of a record.", path)
            break
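
# As a cross-check against the hand-rolled parser above, TensorFlow's own
# record iterator (TF 1.x API, matching the code in this file) can walk the
# same files and verify the CRCs itself.
def count_records(path):
  count = 0
  for _ in tf.python_io.tf_record_iterator(path):
    count += 1
  return count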