def main(unused_argv):
  """Build train/val graphs and run the video-prediction training loop.

  Constructs a training model and a validation model (the latter reusing the
  training scope's weights), optionally restores a pretrained checkpoint,
  then trains for FLAGS.num_iterations steps, periodically running the
  validation graph and saving checkpoints.

  Args:
    unused_argv: leftover command-line args from app.run(); ignored.
  """
  print('Constructing models and inputs.')
  with tf.variable_scope('model', reuse=None) as training_scope:
    images, actions, states = build_tfrecord_input(training=True)
    model = Model(images, actions, states, FLAGS.sequence_length,
                  prefix='train')

  with tf.variable_scope('val_model', reuse=None):
    val_images, val_actions, val_states = build_tfrecord_input(training=False)
    val_model = Model(val_images, val_actions, val_states,
                      FLAGS.sequence_length, training_scope, prefix='val')

  print('Constructing saver.')
  # Make saver over all global variables; max_to_keep=0 keeps every checkpoint.
  saver = tf.train.Saver(
      tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)

  # Make training session.
  sess = tf.InteractiveSession()
  summary_writer = tf.summary.FileWriter(
      FLAGS.event_log_dir, graph=sess.graph, flush_secs=10)

  # BUG FIX: variables must be initialized BEFORE restoring a checkpoint.
  # The original code ran the initializer after saver.restore(), which
  # overwrote the restored weights with fresh random initial values.
  sess.run(tf.global_variables_initializer())

  if FLAGS.pretrained_model:
    saver.restore(sess, FLAGS.pretrained_model)

  # Start the threads that fill the input queues.
  tf.train.start_queue_runners(sess)

  tf.logging.info('iteration number, cost')

  # Run training.
  for itr in range(FLAGS.num_iterations):
    # Generate new batch of data.
    feed_dict = {model.iter_num: np.float32(itr),
                 model.lr: FLAGS.learning_rate}
    cost, _, summary_str = sess.run([model.loss, model.train_op, model.summ_op],
                                    feed_dict)

    # Print info: iteration #, cost.
    tf.logging.info(str(itr) + ' ' + str(cost))

    if (itr) % VAL_INTERVAL == 2:
      # Run through validation set (lr=0.0 so train_op makes no update).
      feed_dict = {val_model.lr: 0.0,
                   val_model.iter_num: np.float32(itr)}
      _, val_summary_str = sess.run([val_model.train_op, val_model.summ_op],
                                     feed_dict)
      summary_writer.add_summary(val_summary_str, itr)

    if (itr) % SAVE_INTERVAL == 2:
      tf.logging.info('Saving model.')
      saver.save(sess, FLAGS.output_dir + '/model' + str(itr))

    # NOTE(review): this writes a summary on every iteration EXCEPT multiples
    # of SUMMARY_INTERVAL (truthiness test, not `== 0`). Preserved as-is since
    # it matches the upstream code, but it looks inverted — confirm intent.
    if (itr) % SUMMARY_INTERVAL:
      summary_writer.add_summary(summary_str, itr)

  tf.logging.info('Saving model.')
  saver.save(sess, FLAGS.output_dir + '/model')
  tf.logging.info('Training complete')
def prediction():
    """Restore a (possibly pretrained) model and run one prediction batch.

    Builds the model over the tfrecord input pipeline, initializes and
    optionally restores weights, runs a single session step, prints the
    loss and PSNR, and writes the input and generated frames out as GIFs.
    """
    print('Constructing models and inputs.')
    # Build the model graph on top of the (non-training) input pipeline.
    with tf.variable_scope('model', reuse=None) as training_scope:
        images, actions, states = build_tfrecord_input(training=False,
                                                       vil=False)
        model = Model(images,
                      actions,
                      states,
                      FLAGS.sequence_length,
                      prefix='train')

    print('Constructing saver.')
    # Saver over every global variable; max_to_keep=0 keeps all checkpoints.
    checkpoint_saver = tf.train.Saver(
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)

    # Session: initialize first, then restore pretrained weights if given.
    session = tf.InteractiveSession()
    session.run(tf.global_variables_initializer())
    if FLAGS.pretrained_model:
        checkpoint_saver.restore(session, FLAGS.pretrained_model)

    # Start the threads that keep the input queues filled.
    tf.train.start_queue_runners(session)

    tf.logging.info('iteration number, cost')

    # One session step over a single batch.
    # NOTE(review): this also runs model.train_op, which presumably updates
    # weights during "prediction" — confirm whether that is intended.
    feeds = {model.iter_num: np.float32(0), model.lr: FLAGS.learning_rate}
    fetches = [
        model.psnr_all, model.loss, model.train_op, model.summ_op,
        model.images, model.gen_images
    ]
    psnr, cost, _, summary_str, images, gen_images = session.run(fetches,
                                                                 feeds)

    # Report loss and PSNR, then dump real and predicted frames as GIFs.
    print(cost)
    print('-----')
    print(psnr)
    output_gif(images, 'images')
    output_gif(gen_images, 'gen_images')
# ---- 示例#3 (Example #3) — scraped-sample separator, commented out so the file parses ----
# 0
def main(unused_argv):
  """Train the video-prediction model with periodic validation and saving.

  Builds a training model and a validation model sharing the training
  scope's weights, optionally restores a pretrained checkpoint, and runs
  FLAGS.num_iterations training steps.

  Args:
    unused_argv: leftover command-line args from app.run(); ignored.
  """
  print('Constructing models and inputs.')
  with tf.variable_scope('model', reuse=None) as training_scope:
    images, actions, states = build_tfrecord_input(training=True)
    model = Model(images, actions, states, FLAGS.sequence_length,
                  prefix='train')

  with tf.variable_scope('val_model', reuse=None):
    val_images, val_actions, val_states = build_tfrecord_input(training=False)
    val_model = Model(val_images, val_actions, val_states,
                      FLAGS.sequence_length, training_scope, prefix='val')

  print('Constructing saver.')
  # Make saver over all globals; max_to_keep=0 retains every checkpoint.
  saver = tf.train.Saver(
      tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)

  # Make training session.
  sess = tf.InteractiveSession()
  summary_writer = tf.summary.FileWriter(
      FLAGS.event_log_dir, graph=sess.graph, flush_secs=10)

  # BUG FIX: initialize variables BEFORE restoring. The original ordering
  # (restore, then global_variables_initializer) wiped the restored
  # checkpoint weights with fresh random values.
  sess.run(tf.global_variables_initializer())

  if FLAGS.pretrained_model:
    saver.restore(sess, FLAGS.pretrained_model)

  # Launch the input-queue filling threads.
  tf.train.start_queue_runners(sess)

  tf.logging.info('iteration number, cost')

  # Run training.
  for itr in range(FLAGS.num_iterations):
    # Generate new batch of data.
    feed_dict = {model.iter_num: np.float32(itr),
                 model.lr: FLAGS.learning_rate}
    cost, _, summary_str = sess.run([model.loss, model.train_op, model.summ_op],
                                    feed_dict)

    # Print info: iteration #, cost.
    tf.logging.info(str(itr) + ' ' + str(cost))

    if (itr) % VAL_INTERVAL == 2:
      # Run through validation set (lr=0.0 so train_op makes no update).
      feed_dict = {val_model.lr: 0.0,
                   val_model.iter_num: np.float32(itr)}
      _, val_summary_str = sess.run([val_model.train_op, val_model.summ_op],
                                     feed_dict)
      summary_writer.add_summary(val_summary_str, itr)

    if (itr) % SAVE_INTERVAL == 2:
      tf.logging.info('Saving model.')
      saver.save(sess, FLAGS.output_dir + '/model' + str(itr))

    # NOTE(review): truthiness test writes a summary on every iteration
    # EXCEPT multiples of SUMMARY_INTERVAL; preserved from upstream, but
    # it looks inverted — confirm intent.
    if (itr) % SUMMARY_INTERVAL:
      summary_writer.add_summary(summary_str, itr)

  tf.logging.info('Saving model.')
  saver.save(sess, FLAGS.output_dir + '/model')
  tf.logging.info('Training complete')
  tf.logging.flush()
# ---- 示例#4 (Example #4) — scraped-sample separator, commented out so the file parses ----
# 0
def main(unused_argv):
    """Build train/val models and run one profiled validation pass.

    This variant of the training script is an evaluation/profiling harness:
    it trains until the first validation trigger, runs the validation loss
    with full tracing, dumps a CompGraph op analysis pickle, then exits.

    Args:
        unused_argv: leftover command-line args from app.run(); ignored.
    """
    print('Constructing models and inputs.')
    with tf.variable_scope('model', reuse=None) as training_scope:
        images, actions, states = build_tfrecord_input(training=True)
        model = Model(images,
                      actions,
                      states,
                      FLAGS.sequence_length,
                      prefix='train')

    with tf.variable_scope('val_model', reuse=None):
        val_images, val_actions, val_states = build_tfrecord_input(
            training=False)
        val_model = Model(val_images,
                          val_actions,
                          val_states,
                          FLAGS.sequence_length,
                          training_scope,
                          prefix='val')

    print('Constructing saver.')
    # Saver over all global variables; max_to_keep=0 keeps every checkpoint.
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES),
                           max_to_keep=0)

    # Make training session. Initialization correctly precedes restore here.
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    summary_writer = tf.summary.FileWriter(FLAGS.event_log_dir,
                                           graph=sess.graph,
                                           flush_secs=10)

    if FLAGS.pretrained_model:
        saver.restore(sess, FLAGS.pretrained_model)

    tf.train.start_queue_runners(sess)

    tf.logging.info('iteration number, cost')

    # Run training.
    for itr in range(FLAGS.num_iterations):
        # Generate new batch of data.
        feed_dict = {
            model.iter_num: np.float32(itr),
            model.lr: FLAGS.learning_rate
        }
        cost, _, summary_str = sess.run(
            [model.loss, model.train_op, model.summ_op], feed_dict)

        # Print info: iteration #, cost.
        tf.logging.info(str(itr) + ' ' + str(cost))

        if (itr) % VAL_INTERVAL == 2:
            # Run the validation loss once with full run-metadata tracing,
            # analyze the op graph with CompGraph, then terminate.
            feed_dict = {
                val_model.lr: 0.0,
                val_model.iter_num: np.float32(itr)
            }

            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            model_name = 'video_prediction_{}'.format(FLAGS.model.lower())

            eval_loss = sess.run(val_model.loss,
                                 feed_dict=feed_dict,
                                 options=options,
                                 run_metadata=run_metadata)
            cg = CompGraph(model_name, run_metadata, tf.get_default_graph())

            # Resolve every traced tensor's concrete shape in one run call.
            cg_tensor_dict = cg.get_tensors()
            cg_sorted_keys = sorted(cg_tensor_dict.keys())
            cg_sorted_items = []
            for cg_key in cg_sorted_keys:
                cg_sorted_items.append(tf.shape(cg_tensor_dict[cg_key]))

            cg_sorted_shape = sess.run(cg_sorted_items, feed_dict=feed_dict)
            cg.op_analysis(dict(zip(cg_sorted_keys, cg_sorted_shape)),
                           '{}.pickle'.format(model_name))

            print('Evaluation finished')
            # BUG FIX: the original code had an unreachable
            # summary_writer.add_summary(val_summary_str, itr) after this
            # exit; val_summary_str was never defined (its producing sess.run
            # was commented out), so the line was dead code hiding a latent
            # NameError. It has been removed.
            exit(0)

        if (itr) % SAVE_INTERVAL == 2:
            tf.logging.info('Saving model.')
            saver.save(sess, FLAGS.output_dir + '/model' + str(itr))

        # NOTE(review): truthiness test writes a summary on every iteration
        # EXCEPT multiples of SUMMARY_INTERVAL; preserved from upstream, but
        # it looks inverted — confirm intent.
        if (itr) % SUMMARY_INTERVAL:
            summary_writer.add_summary(summary_str, itr)

    tf.logging.info('Saving model.')
    saver.save(sess, FLAGS.output_dir + '/model')
    tf.logging.info('Training complete')
    tf.logging.flush()