Ejemplo n.º 1
0
def main(_):
    """Dump sample videos of a video problem as PNG frame files and GIFs.

    Builds the TRAIN-mode input pipeline for ``FLAGS.problem``. For every
    example it concatenates ``inputs`` and ``targets`` along axis 1. It then
    hands the result to ``decoding.save_video`` and ``create_gif``, using file
    names rooted at ``<FLAGS.output_dir>/<problem>_<count:05d>``.

    Stops after ``FLAGS.num_samples`` examples have been saved, or earlier if
    the session reports it should stop. The sample count is only compared for
    equality after each save. So a ``FLAGS.num_samples`` that can never be
    reached (zero or negative) means "run until the session stops".

    Args:
      _: Unused positional argument (presumably the argv list supplied by the
        app runner -- confirm against the entry point).
    """
    problem_name = FLAGS.problem
    if "video" not in problem_name and "gym" not in problem_name:
        print("This tool only works for video problems.")
        return

    mode = tf.estimator.ModeKeys.TRAIN
    hparams = trainer_lib.create_hparams(FLAGS.hparams_set,
                                         FLAGS.hparams,
                                         data_dir=os.path.expanduser(
                                             FLAGS.data_dir),
                                         problem_name=problem_name)

    dataset = hparams.problem.input_fn(mode, hparams)
    features = dataset.make_one_shot_iterator().get_next()

    tf.gfile.MakeDirs(FLAGS.output_dir)
    # Every output file name starts with <output_dir>/<problem_name>.
    base_template = os.path.join(FLAGS.output_dir, problem_name)
    count = 0
    with tf.train.MonitoredTrainingSession() as sess:
        while not sess.should_stop():
            # TODO(mbz): figure out what the second output is.
            # NOTE(review): presumably input_fn yields (features, labels)
            # pairs, so the discarded value is the labels -- confirm.
            data, _ = sess.run(features)
            # NOTE(review): assumes axis 1 is the frame axis, so each video
            # is its input frames followed by its target frames -- confirm.
            video_batch = np.concatenate((data["inputs"], data["targets"]),
                                         axis=1)

            for video in video_batch:
                print("Saving {}/{}".format(count, FLAGS.num_samples))
                name = "%s_%05d" % (base_template, count)
                # Per-frame file-name template; save_video presumably fills
                # in the frame index.
                decoding.save_video(video, name + "_{:05d}.png")
                create_gif(name)
                count += 1

                if count == FLAGS.num_samples:
                    # Return (rather than sys.exit) so the session's context
                    # manager closes it on the normal path and main() stays
                    # callable from other code.
                    return
Ejemplo n.º 2
0
def main(_):
  """Dump sample videos of a video problem as PNG frame files and GIFs.

  Builds the TRAIN-mode input pipeline for ``FLAGS.problem``. For every
  example it concatenates ``inputs`` and ``targets`` along axis 1. It then
  hands the result to ``decoding.save_video`` and ``create_gif``, using file
  names rooted at ``<FLAGS.output_dir>/<problem>_<count:05d>``.

  Stops after ``FLAGS.num_samples`` examples have been saved, or earlier if
  the session reports it should stop. The sample count is only compared for
  equality after each save. So a ``FLAGS.num_samples`` that can never be
  reached (zero or negative) means "run until the session stops".

  Args:
    _: Unused positional argument (presumably the argv list supplied by the
      app runner -- confirm against the entry point).
  """
  problem_name = FLAGS.problem
  if "video" not in problem_name and "gym" not in problem_name:
    print("This tool only works for video problems.")
    return

  mode = tf.estimator.ModeKeys.TRAIN
  hparams = trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=os.path.expanduser(FLAGS.data_dir),
      problem_name=problem_name)

  dataset = hparams.problem.input_fn(mode, hparams)
  features = dataset.make_one_shot_iterator().get_next()

  tf.gfile.MakeDirs(FLAGS.output_dir)
  # Every output file name starts with <output_dir>/<problem_name>.
  base_template = os.path.join(FLAGS.output_dir, problem_name)
  count = 0
  with tf.train.MonitoredTrainingSession() as sess:
    while not sess.should_stop():
      # TODO(mbz): figure out what the second output is.
      # NOTE(review): presumably input_fn yields (features, labels) pairs,
      # so the discarded value is the labels -- confirm.
      data, _ = sess.run(features)
      # NOTE(review): assumes axis 1 is the frame axis, so each video is its
      # input frames followed by its target frames -- confirm.
      video_batch = np.concatenate((data["inputs"], data["targets"]), axis=1)

      for video in video_batch:
        print("Saving {}/{}".format(count, FLAGS.num_samples))
        name = "%s_%05d" % (base_template, count)
        # Per-frame file-name template; save_video presumably fills in the
        # frame index.
        decoding.save_video(video, name + "_{:05d}.png")
        create_gif(name)
        count += 1

        if count == FLAGS.num_samples:
          # Return (rather than sys.exit) so the session's context manager
          # closes it on the normal path and main() stays callable from
          # other code.
          return