예제 #1
0
파일: train.py 프로젝트: Hillyess/im2txt
def main(argv):
    """Train the CaptionGenerator model with TF-Slim's training loop.

    Copies the relevant command-line FLAGS onto a fresh ``Config``, ensures
    the training directory exists, builds the "train"-mode model in its own
    graph and hands it to ``tf.contrib.slim.learning.train`` together with a
    checkpoint Saver and a session config that enables GPU memory growth.

    Args:
      argv: Command-line arguments (presumably supplied by the app runner);
        not used by this function.
    """
    config = Config()
    # (Config attribute, value taken from the command line) pairs.
    flag_overrides = (
        ("input_file_pattern", FLAGS.input_file_pattern),
        ("optimizer", FLAGS.optimizer),
        ("attention_mechanism", FLAGS.attention),
        ("save_dir", FLAGS.train_dir),
    )
    for attr_name, attr_value in flag_overrides:
        setattr(config, attr_name, attr_value)

    # Checkpoints and summaries are written into this directory.
    out_dir = config.save_dir
    if not tf.gfile.IsDirectory(out_dir):
        tf.logging.info("Creating training directory: %s", out_dir)
        tf.gfile.MakeDirs(out_dir)

    # The model and its Saver must live in the same dedicated graph that is
    # passed to the training loop below.
    train_graph = tf.Graph()
    with train_graph.as_default():
        caption_model = CaptionGenerator(config, mode="train")
        caption_model.build()
        checkpoint_saver = tf.train.Saver(
            max_to_keep=config.max_checkpoints_to_keep)

    # Grow GPU memory on demand instead of reserving it all up front.
    session_cfg = tf.ConfigProto()
    session_cfg.gpu_options.allow_growth = True

    train_kwargs = dict(
        log_every_n_steps=config.log_every_n_steps,
        graph=train_graph,
        global_step=caption_model.global_step,
        number_of_steps=FLAGS.number_of_steps,
        summary_op=caption_model.summary,
        save_summaries_secs=60,   # write TensorBoard summaries every minute
        save_interval_secs=600,   # write a checkpoint every ten minutes
        init_fn=None,
        saver=checkpoint_saver,
        session_config=session_cfg,
    )
    tf.contrib.slim.learning.train(caption_model.opt_op, out_dir, **train_kwargs)
예제 #2
0
def run():
  """Evaluate the model forever, writing summaries for TensorBoard.

  Ensures FLAGS.eval_dir exists and loads the vocabulary. Builds an
  "eval"-mode CaptionGenerator in a dedicated, finalized graph. Then calls
  ``run_once`` repeatedly, starting a new pass at most every
  FLAGS.eval_interval_secs seconds. This function never returns normally.
  """
  eval_dir = FLAGS.eval_dir
  if not tf.gfile.IsDirectory(eval_dir):
    tf.logging.info("Creating eval directory: %s", eval_dir)
    tf.gfile.MakeDirs(eval_dir)

  vocab = vocabulary.Vocabulary(FLAGS.vocab_file)

  eval_graph = tf.Graph()
  with eval_graph.as_default():
    eval_config = Config()
    eval_config.input_file_pattern = FLAGS.input_file_pattern
    eval_config.beam_size = FLAGS.beam_size

    eval_model = CaptionGenerator(eval_config, mode="eval")
    eval_model.build()

    # Restores model variables from checkpoints inside run_once.
    restorer = tf.train.Saver()
    writer = tf.summary.FileWriter(eval_dir)

    # Nothing may add ops to the graph from here on.
    eval_graph.finalize()

    while True:
      started_at = time.time()
      stamp = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
      tf.logging.info("Starting evaluation at " + stamp)
      run_once(eval_model, vocab, restorer, writer)
      # Sleep only for whatever is left of the interval after this pass.
      remaining = started_at + FLAGS.eval_interval_secs - time.time()
      if remaining > 0:
        time.sleep(remaining)
예제 #3
0
def export_graph(model_folder,model_name,config):
  """Freeze the CaptionGenerator inference graph into a single GraphDef file.

  Builds the model in "inference" mode and restores the most recent
  checkpoint found in `model_folder`. Converts the variables needed by the
  exported output nodes into constants, then writes the serialized GraphDef
  (binary) to `model_name`.

  Exported output nodes wrap the same-named model attributes with
  tf.identity: initial_memory, initial_output, alpha, memory, output, probs.

  Args:
    model_folder: Directory searched with tf.train.latest_checkpoint.
    model_name: Path of the frozen-graph file to write.
    config: Configuration object passed to CaptionGenerator.

  Raises:
    ValueError: If no checkpoint is found in `model_folder`.
  """
  # Resolve the checkpoint first. latest_checkpoint returns None when none
  # exists, and Saver.restore(sess, None) would otherwise fail with an opaque
  # error only after the whole graph has been built.
  latest_ckpt = tf.train.latest_checkpoint(model_folder)
  if latest_ckpt is None:
    raise ValueError("No checkpoint found in %s" % model_folder)

  graph = tf.Graph()
  with graph.as_default():
    model = CaptionGenerator(config, mode="inference")
    model.build()

    # Input tensors are not renamed with tf.identity(). Presumably this is
    # because downstream ops consume the original input tensors, so feeding an
    # identity alias would not reach them.
    # Creation order is kept stable so the generated node names are stable.
    outputs = {}
    # Initial state.
    outputs['initial_memory'] = tf.identity(model.initial_memory, name='initial_memory')
    outputs['initial_output'] = tf.identity(model.initial_output, name='initial_output')
    # Per-step results.
    outputs['alpha'] = tf.identity(model.alpha, name='alpha')
    outputs['memory'] = tf.identity(model.memory, name='memory')
    outputs['output'] = tf.identity(model.output, name='output')
    outputs['probs'] = tf.identity(model.probs, name='probs')

    restore_saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    restore_saver.restore(sess, latest_ckpt)
    # Use the real op names rather than the dict keys. tf.identity uniquifies
    # a requested name (e.g. "output_1") if the graph already has a node with
    # that name, so the keys are not guaranteed to match.
    output_node_names = [tensor.op.name for tensor in outputs.values()]
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), output_node_names)

  with tf.gfile.GFile(model_name, "wb") as f:
    f.write(output_graph_def.SerializeToString())