Example #1
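 # Builds the policy graph for batched observations, samples temperature-scaled
 # actions, and restores the policy network's variables from policy_dir.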
 def __init__(
     self, batch_size, observation_space, action_space, policy_hparams,
     policy_dir, sampling_temp
 ):
   super(PolicyAgent, self).__init__(
       batch_size, observation_space, action_space
   )
   self._sampling_temp = sampling_temp
   with tf.Graph().as_default():
     self._observations_t = tf.placeholder(
         shape=((batch_size,) + self.observation_space.shape),
         dtype=self.observation_space.dtype
     )
     (logits, self._values_t) = rl.get_policy(
         self._observations_t, policy_hparams, self.action_space
     )
     actions = common_layers.sample_with_temperature(logits, sampling_temp)
     self._probs_t = tf.nn.softmax(logits / sampling_temp)
     self._actions_t = tf.cast(actions, tf.int32)
     model_saver = tf.train.Saver(
         tf.global_variables(policy_hparams.policy_network + "/.*")  # pylint: disable=unexpected-keyword-arg
     )
     self._sess = tf.Session()
     self._sess.run(tf.global_variables_initializer())
     trainer_lib.restore_checkpoint(policy_dir, model_saver, self._sess)
Example #2
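  # Runs an in-graph evaluation rollout (_define_collect) after restoring agent
  # variables matching ".*network_parameters.*" from self.agent_model_dir.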
  def evaluate(self, env_fn, hparams, stochastic):
    if stochastic:
      policy_to_actions_lambda = lambda policy: policy.sample()
    else:
      policy_to_actions_lambda = lambda policy: policy.mode()

    with tf.Graph().as_default():
      with tf.name_scope("rl_eval"):
        eval_env = env_fn(in_graph=True)
        (collect_memory, _, collect_init) = _define_collect(
            eval_env,
            hparams,
            "ppo_eval",
            eval_phase=True,
            frame_stack_size=self.frame_stack_size,
            force_beginning_resets=False,
            policy_to_actions_lambda=policy_to_actions_lambda)
        model_saver = tf.train.Saver(
            tf.global_variables(".*network_parameters.*"))

        with tf.Session() as sess:
          sess.run(tf.global_variables_initializer())
          collect_init(sess)
          trainer_lib.restore_checkpoint(self.agent_model_dir, model_saver,
                                         sess)
          sess.run(collect_memory)
Example #3
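# Restores agent weights from model_dir, then runs PPO training iterations under
# the restarter, writing train/eval summaries and saving checkpoints periodically.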
def _run_train(ppo_hparams,
               event_dir,
               model_dir,
               restarter,
               train_summary_op,
               eval_summary_op,
               initializers,
               report_fn=None):
    """Train."""
    summary_writer = tf.summary.FileWriter(event_dir,
                                           graph=tf.get_default_graph(),
                                           flush_secs=60)

    model_saver = tf.train.Saver(
        tf.global_variables(ppo_hparams.policy_network + "/.*") +
        tf.global_variables("training/" + ppo_hparams.policy_network + "/.*") +
        # tf.global_variables("clean_scope.*") +  # Needed for sharing params.
        tf.global_variables("global_step") +
        tf.global_variables("losses_avg.*") +
        tf.global_variables("train_stats.*"))

    global_step = tf.train.get_or_create_global_step()
    with tf.control_dependencies([tf.assign_add(global_step, 1)]):
        train_summary_op = tf.identity(train_summary_op)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for initializer in initializers:
            initializer(sess)
        trainer_lib.restore_checkpoint(model_dir, model_saver, sess)

        num_target_iterations = restarter.target_local_step
        num_completed_iterations = num_target_iterations - restarter.steps_to_go
        with restarter.training_loop():
            for epoch_index in range(num_completed_iterations,
                                     num_target_iterations):
                summary = sess.run(train_summary_op)
                if summary_writer:
                    summary_writer.add_summary(summary, epoch_index)

                if (ppo_hparams.eval_every_epochs
                        and epoch_index % ppo_hparams.eval_every_epochs == 0):
                    eval_summary = sess.run(eval_summary_op)
                    if summary_writer:
                        summary_writer.add_summary(eval_summary, epoch_index)
                    if report_fn:
                        summary_proto = tf.Summary()
                        summary_proto.ParseFromString(eval_summary)
                        for elem in summary_proto.value:
                            if "mean_score" in elem.tag:
                                report_fn(elem.simple_value, epoch_index)
                                break

                if (model_saver and ppo_hparams.save_models_every_epochs and
                    (epoch_index % ppo_hparams.save_models_every_epochs == 0 or
                     (epoch_index + 1) == num_target_iterations)):
                    ckpt_path = os.path.join(
                        model_dir, "model.ckpt-{}".format(
                            tf.train.global_step(sess, global_step)))
                    model_saver.save(sess, ckpt_path)
Example #4
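 # Restores world-model variables (scope "next_frame*") from self._model_dir;
 # must_restore=True makes a missing checkpoint an error.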
 def initialize(self, sess):
   model_loader = tf.train.Saver(
       var_list=tf.global_variables(scope="next_frame*")  # pylint:disable=unexpected-keyword-arg
   )
   trainer_lib.restore_checkpoint(
       self._model_dir, saver=model_loader, sess=sess, must_restore=True
   )
Example #5
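    # Collects evaluation rollouts with a sampling temperature (and
    # distributional_size) after restoring the policy network's variables.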
    def evaluate(self, env_fn, hparams, sampling_temp):
        with tf.Graph().as_default():
            with tf.name_scope("rl_eval"):
                eval_env = env_fn(in_graph=True)
                (collect_memory, _, collect_init) = _define_collect(
                    eval_env,
                    hparams,
                    "ppo_eval",
                    eval_phase=True,
                    frame_stack_size=self.frame_stack_size,
                    force_beginning_resets=False,
                    sampling_temp=sampling_temp,
                    distributional_size=self._distributional_size,
                )
                model_saver = tf.train.Saver(
                    tf.global_variables(hparams.policy_network + "/.*")
                    # tf.global_variables("clean_scope.*")  # Needed for sharing params.
                )

                with tf.Session() as sess:
                    sess.run(tf.global_variables_initializer())
                    collect_init(sess)
                    trainer_lib.restore_checkpoint(self.agent_model_dir,
                                                   model_saver, sess)
                    sess.run(collect_memory)
Example #6
 def initialize(self, sess):
   model_loader = tf.train.Saver(
       var_list=tf.global_variables(scope="next_frame*")  # pylint:disable=unexpected-keyword-arg
   )
   trainer_lib.restore_checkpoint(
       self._model_dir, saver=model_loader, sess=sess, must_restore=True
   )
Example #7
def encode_dataset(model, dataset, problem, ae_hparams, autoencoder_path,
                   out_files):
    """Encode all frames in dataset with model and write them out to out_files."""
    batch_size = 8
    dataset = dataset.batch(batch_size)
    examples = dataset.make_one_shot_iterator().get_next()
    images = examples.pop("frame")
    images = tf.cast(images, tf.int32)

    encoded = model.encode(images)
    encoded_frame_height = int(
        math.ceil(problem.frame_height / 2**ae_hparams.num_hidden_layers))
    encoded_frame_width = int(
        math.ceil(problem.frame_width / 2**ae_hparams.num_hidden_layers))
    num_bits = 8
    encoded = tf.reshape(
        encoded, [-1, encoded_frame_height, encoded_frame_width, 3, num_bits])
    encoded = tf.cast(discretization.bit_to_int(encoded, num_bits), tf.uint8)

    pngs = tf.map_fn(tf.image.encode_png,
                     encoded,
                     dtype=tf.string,
                     back_prop=False)

    with tf.Session() as sess:
        autoencoder_saver = tf.train.Saver(
            tf.global_variables("autoencoder.*"))
        trainer_lib.restore_checkpoint(autoencoder_path,
                                       autoencoder_saver,
                                       sess,
                                       must_restore=True)

        def generator():
            """Generate examples."""
            while True:
                try:
                    pngs_np, examples_np = sess.run([pngs, examples])
                    rewards = examples_np["reward"].tolist()
                    actions = examples_np["action"].tolist()
                    frame_numbers = examples_np["frame_number"].tolist()
                    for action, reward, frame_number, png in \
                            zip(actions, rewards, frame_numbers, pngs_np):
                        yield {
                            "action": action,
                            "reward": reward,
                            "frame_number": frame_number,
                            "image/encoded": [png],
                            "image/format": ["png"],
                            "image/height": [encoded_frame_height],
                            "image/width": [encoded_frame_width],
                        }
                except tf.errors.OutOfRangeError:
                    break

        generator_utils.generate_files(
            generator(),
            out_files,
            cycle_every_n=problem.total_number_of_frames // 10)
Example #8
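# PPO training entry point; for simulated environments it first restores the
# world model ("next_frame*"), then restores/trains the agent and checkpoints it.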
def train(hparams, event_dir=None, model_dir=None,
          restore_agent=True, epoch=0):
  """Train."""
  with tf.name_scope("rl_train"):
    train_summary_op, _, initialization = define_train(hparams, event_dir)
    if event_dir:
      summary_writer = tf.summary.FileWriter(
          event_dir, graph=tf.get_default_graph(), flush_secs=60)
    else:
      summary_writer = None
    if model_dir:
      model_saver = tf.train.Saver(
          tf.global_variables(".*network_parameters.*"))
    else:
      model_saver = None

    # TODO(piotrmilos): This should be refactored, possibly with
    # handlers for each type of env
    if hparams.environment_spec.simulated_env:
      env_model_loader = tf.train.Saver(
          tf.global_variables("next_frame*"))
    else:
      env_model_loader = None

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      initialization(sess)
      if env_model_loader:
        trainer_lib.restore_checkpoint(
            hparams.world_model_dir, env_model_loader, sess, must_restore=True)
      start_step = 0
      if model_saver and restore_agent:
        start_step = trainer_lib.restore_checkpoint(
            model_dir, model_saver, sess)

      # Fail-friendly, don't train if already trained for this epoch
      if start_step >= ((hparams.epochs_num * (epoch + 1))):
        tf.logging.info("Skipping PPO training for epoch %d as train steps "
                        "(%d) already reached", epoch, start_step)
        return

      for epoch_index in range(hparams.epochs_num):
        summary = sess.run(train_summary_op)
        if summary_writer:
          summary_writer.add_summary(summary, epoch_index)
        if (hparams.eval_every_epochs and
            epoch_index % hparams.eval_every_epochs == 0):
          if summary_writer and summary:
            summary_writer.add_summary(summary, epoch_index)
          else:
            tf.logging.info("Eval summary not saved")
        if (model_saver and hparams.save_models_every_epochs and
            (epoch_index % hparams.save_models_every_epochs == 0 or
             (epoch_index + 1) == hparams.epochs_num)):
          ckpt_path = os.path.join(
              model_dir, "model.ckpt-{}".format(epoch_index + 1 + start_step))
          model_saver.save(sess, ckpt_path)
Example #9
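# Older training variant: restores a "basic_conv_gen.*" world model for simulated
# environments, then trains the agent, evaluating and checkpointing periodically.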
def train(hparams,
          environment_spec,
          event_dir=None,
          model_dir=None,
          restore_agent=True):
    """Train."""
    with tf.name_scope("rl_train"):
        train_summary_op, eval_summary_op = define_train(
            hparams, environment_spec, event_dir)
        if event_dir:
            summary_writer = tf.summary.FileWriter(
                event_dir, graph=tf.get_default_graph(), flush_secs=60)
        else:
            summary_writer = None
        if model_dir:
            model_saver = tf.train.Saver(
                tf.global_variables(".*network_parameters.*"))
        else:
            model_saver = None

        if hparams.simulated_environment:
            env_model_loader = tf.train.Saver(
                tf.global_variables("basic_conv_gen.*"))
        else:
            env_model_loader = None

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            if env_model_loader:
                trainer_lib.restore_checkpoint(hparams.world_model_dir,
                                               env_model_loader,
                                               sess,
                                               must_restore=True)
            start_step = 0
            if model_saver and restore_agent:
                start_step = trainer_lib.restore_checkpoint(
                    model_dir, model_saver, sess)
            for epoch_index in range(hparams.epochs_num):
                summary = sess.run(train_summary_op)
                if summary_writer:
                    summary_writer.add_summary(summary, epoch_index)
                if (hparams.eval_every_epochs
                        and epoch_index % hparams.eval_every_epochs == 0):
                    summary = sess.run(eval_summary_op)
                    if summary_writer and summary:
                        summary_writer.add_summary(summary, epoch_index)
                    else:
                        tf.logging.info("Eval summary not saved")
                if (model_saver and hparams.save_models_every_epochs and
                    (epoch_index % hparams.save_models_every_epochs == 0 or
                     (epoch_index + 1) == hparams.epochs_num)):
                    ckpt_path = os.path.join(
                        model_dir,
                        "model.ckpt-{}".format(epoch_index + start_step))
                    model_saver.save(sess, ckpt_path)
Example #10
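 # Restores world-model variables (scope "next_frame*"); if self._model_dir is a
 # directory the latest checkpoint is used, otherwise it is treated as a
 # checkpoint path and loaded directly.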
 def initialize(self, sess):
     model_loader = tf.train.Saver(
         var_list=tf.global_variables(scope="next_frame*")  # pylint:disable=unexpected-keyword-arg
     )
     # TODO(afrozm): use TF methods to be on the safe side here.
     if os.path.isdir(self._model_dir):
         trainer_lib.restore_checkpoint(self._model_dir,
                                        saver=model_loader,
                                        sess=sess,
                                        must_restore=True)
     else:
         model_loader.restore(sess=sess, save_path=self._model_dir)
Example #11
def encode_dataset(model, dataset, problem, ae_hparams, autoencoder_path,
                   out_files):
  """Encode all frames in dataset with model and write them out to out_files."""
  batch_size = 8
  dataset = dataset.batch(batch_size)
  examples = dataset.make_one_shot_iterator().get_next()
  images = examples.pop("frame")
  images = tf.expand_dims(images, 1)

  encoded = model.encode(images)
  encoded_frame_height = int(
      math.ceil(problem.frame_height / 2**ae_hparams.num_hidden_layers))
  encoded_frame_width = int(
      math.ceil(problem.frame_width / 2**ae_hparams.num_hidden_layers))
  num_bits = 8
  encoded = tf.reshape(
      encoded, [-1, encoded_frame_height, encoded_frame_width, 3, num_bits])
  encoded = tf.cast(discretization.bit_to_int(encoded, num_bits), tf.uint8)

  pngs = tf.map_fn(tf.image.encode_png, encoded, dtype=tf.string,
                   back_prop=False)

  with tf.Session() as sess:
    autoencoder_saver = tf.train.Saver(tf.global_variables("autoencoder.*"))
    trainer_lib.restore_checkpoint(autoencoder_path, autoencoder_saver, sess,
                                   must_restore=True)

    def generator():
      """Generate examples."""
      while True:
        try:
          pngs_np, examples_np = sess.run([pngs, examples])
          rewards_np = [list(el) for el in examples_np["reward"]]
          actions_np = [list(el) for el in examples_np["action"]]
          pngs_np = [el for el in pngs_np]
          for action, reward, png in zip(actions_np, rewards_np, pngs_np):
            yield {
                "action": action,
                "reward": reward,
                "image/encoded": [png],
                "image/format": ["png"],
                "image/height": [encoded_frame_height],
                "image/width": [encoded_frame_width],
            }
        except tf.errors.OutOfRangeError:
          break

    generator_utils.generate_files(
        generator(), out_files,
        cycle_every_n=problem.total_number_of_frames // 10)
Example #12
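# Restores agent variables and runs one evaluation rollout via collect.define_collect.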
def evaluate(hparams, model_dir, name_scope="rl_eval"):
    """Evaluate."""
    hparams = copy.copy(hparams)
    hparams.add_hparam("eval_phase", True)
    with tf.Graph().as_default():
        with tf.name_scope(name_scope):
            (collect_memory, _,
             collect_init) = collect.define_collect(hparams, "ppo_eval")
            model_saver = tf.train.Saver(
                tf.global_variables(".*network_parameters.*"))

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                collect_init(sess)
                trainer_lib.restore_checkpoint(model_dir, model_saver, sess)
                sess.run(collect_memory)
Example #13
def _run_train(ppo_hparams,
               event_dir,
               model_dir,
               num_target_iterations,
               train_summary_op,
               eval_summary_op,
               initializers,
               report_fn=None):
    """Train."""
    summary_writer = tf.summary.FileWriter(event_dir,
                                           graph=tf.get_default_graph(),
                                           flush_secs=60)

    model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for initializer in initializers:
            initializer(sess)
        num_completed_iterations = trainer_lib.restore_checkpoint(
            model_dir, model_saver, sess)

        # Fail-friendly, complete only unfinished epoch
        num_iterations_to_go = num_target_iterations - num_completed_iterations

        if num_iterations_to_go <= 0:
            tf.logging.info(
                "Skipping PPO training. Requested %d iterations while %d train "
                "iterations already reached", num_target_iterations,
                num_completed_iterations)
            return

        for epoch_index in range(num_iterations_to_go):
            summary = sess.run(train_summary_op)
            if summary_writer:
                summary_writer.add_summary(summary, epoch_index)

            if (ppo_hparams.eval_every_epochs
                    and epoch_index % ppo_hparams.eval_every_epochs == 0):
                eval_summary = sess.run(eval_summary_op)
                if summary_writer:
                    summary_writer.add_summary(eval_summary, epoch_index)
                if report_fn:
                    summary_proto = tf.Summary()
                    summary_proto.ParseFromString(eval_summary)
                    for elem in summary_proto.value:
                        if "mean_score" in elem.tag:
                            report_fn(elem.simple_value, epoch_index)
                            break

            epoch_index_and_start = epoch_index + num_completed_iterations
            if (model_saver and ppo_hparams.save_models_every_epochs and
                (epoch_index_and_start % ppo_hparams.save_models_every_epochs
                 == 0 or (epoch_index + 1) == num_iterations_to_go)):
                ckpt_path = os.path.join(
                    model_dir,
                    "model.ckpt-{}".format(epoch_index + 1 +
                                           num_completed_iterations))
                model_saver.save(sess, ckpt_path)
Example #14
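    # Wraps SimulatedBatchEnv behind step/reset ops, applies the configured
    # environment wrappers, and restores world-model weights ("next_frame*").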
    def __init__(self,
                 environment_spec,
                 batch_size,
                 model_dir=None,
                 sess=None):
        self.batch_size = batch_size

        with tf.Graph().as_default():
            self._batch_env = SimulatedBatchEnv(environment_spec,
                                                self.batch_size)

            self.action_space = self._batch_env.action_space
            # TODO(kc): check for the stack wrapper and correct number of channels in
            # observation_space
            self.observation_space = self._batch_env.observ_space
            self._sess = sess if sess is not None else tf.Session()
            self._to_initialize = [self._batch_env]

            environment_wrappers = environment_spec.wrappers
            wrappers = copy.copy(
                environment_wrappers) if environment_wrappers else []

            for w in wrappers:
                self._batch_env = w[0](self._batch_env, **w[1])
                self._to_initialize.append(self._batch_env)

            self._sess.run(tf.global_variables_initializer())
            for wrapped_env in self._to_initialize:
                wrapped_env.initialize(self._sess)

            self._actions_t = tf.placeholder(shape=(batch_size, ),
                                             dtype=tf.int32)
            self._rewards_t, self._dones_t = self._batch_env.simulate(
                self._actions_t)
            self._obs_t = self._batch_env.observ
            self._reset_op = self._batch_env.reset(
                tf.range(batch_size, dtype=tf.int32))

            env_model_loader = tf.train.Saver(
                var_list=tf.global_variables(scope="next_frame*"))  # pylint:disable=unexpected-keyword-arg
            trainer_lib.restore_checkpoint(model_dir,
                                           saver=env_model_loader,
                                           sess=self._sess,
                                           must_restore=True)
Example #15
    def __init__(self, hparams, action_space, observation_space, policy_dir):
        assert hparams.base_algo == "ppo"
        ppo_hparams = trainer_lib.create_hparams(hparams.base_algo_params)

        frame_stack_shape = (
            1, hparams.frame_stack_size) + observation_space.shape
        self._frame_stack = np.zeros(frame_stack_shape, dtype=np.uint8)

        with tf.Graph().as_default():
            self.obs_t = tf.placeholder(shape=frame_stack_shape,
                                        dtype=np.uint8)
            self.logits_t, self.value_function_t = get_policy(
                self.obs_t, ppo_hparams, action_space)
            model_saver = tf.train.Saver(
                tf.global_variables(scope=ppo_hparams.policy_network + "/.*")  # pylint: disable=unexpected-keyword-arg
            )
            self.sess = tf.Session()
            self.sess.run(tf.global_variables_initializer())
            trainer_lib.restore_checkpoint(policy_dir, model_saver, self.sess)
Example #16
def train(hparams,
          event_dir=None,
          model_dir=None,
          restore_agent=True,
          name_scope="rl_train",
          report_fn=None):
    """Train."""
    with tf.Graph().as_default():
        with tf.name_scope(name_scope):
            train_summary_op, eval_summary_op, initializers = define_train(
                hparams)
            if event_dir:
                summary_writer = tf.summary.FileWriter(
                    event_dir, graph=tf.get_default_graph(), flush_secs=60)
            else:
                summary_writer = None

            if model_dir:
                model_saver = tf.train.Saver(
                    tf.global_variables(".*network_parameters.*"))
            else:
                model_saver = None

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                for initializer in initializers:
                    initializer(sess)
                start_step = 0
                if model_saver and restore_agent:
                    start_step = trainer_lib.restore_checkpoint(
                        model_dir, model_saver, sess)

                # Fail-friendly, complete only unfinished epoch
                steps_to_go = hparams.epochs_num - start_step

                if steps_to_go <= 0:
                    tf.logging.info(
                        "Skipping PPO training. Requested %d steps while "
                        "%d train steps already reached", hparams.epochs_num,
                        start_step)
                    return

                for epoch_index in range(steps_to_go):
                    summary = sess.run(train_summary_op)
                    if summary_writer:
                        summary_writer.add_summary(summary, epoch_index)

                    if (hparams.eval_every_epochs
                            and epoch_index % hparams.eval_every_epochs == 0):
                        eval_summary = sess.run(eval_summary_op)
                        if summary_writer:
                            summary_writer.add_summary(eval_summary,
                                                       epoch_index)
                        if report_fn:
                            summary_proto = tf.Summary()
                            summary_proto.ParseFromString(eval_summary)
                            for elem in summary_proto.value:
                                if "mean_score" in elem.tag:
                                    report_fn(elem.simple_value, epoch_index)
                                    break

                    epoch_index_and_start = epoch_index + start_step
                    if (model_saver and hparams.save_models_every_epochs
                            and (epoch_index_and_start %
                                 hparams.save_models_every_epochs == 0 or
                                 (epoch_index + 1) == steps_to_go)):
                        ckpt_path = os.path.join(
                            model_dir, "model.ckpt-{}".format(epoch_index + 1 +
                                                              start_step))
                        model_saver.save(sess, ckpt_path)
Example #17
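# Interactive debugging entry point: registers the simulated problem if needed,
# restores the world model and the PPO agent, and plays the environment with
# keyboard controls.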
def main(_):
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)
    output_dir = FLAGS.output_dir

    subdirectories = ["data", "tmp", "world_model", "ppo"]
    using_autoencoder = hparams.autoencoder_train_steps > 0
    if using_autoencoder:
        subdirectories.append("autoencoder")
    directories = setup_directories(output_dir, subdirectories)

    if hparams.game in gym_problems_specs.ATARI_GAMES:
        game_with_mode = hparams.game + "_deterministic-v4"
    else:
        game_with_mode = hparams.game

    if using_autoencoder:
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
    else:
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)
        if simulated_problem_name not in registry.list_problems():
            tf.logging.info(
                "Game Problem %s not found; dynamically registering",
                simulated_problem_name)
            gym_problems_specs.create_problems_for_game(
                hparams.game, game_mode="Deterministic-v4")

    epoch = hparams.epochs - 1
    epoch_data_dir = os.path.join(directories["data"], str(epoch))
    ppo_model_dir = directories["ppo"]

    world_model_dir = directories["world_model"]

    gym_problem = registry.problem(simulated_problem_name)

    model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
    environment_spec = copy.copy(gym_problem.environment_spec)
    environment_spec.simulation_random_starts = hparams.simulation_random_starts

    batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
    batch_env_hparams.add_hparam("model_hparams", model_hparams)
    batch_env_hparams.add_hparam("environment_spec", environment_spec)
    batch_env_hparams.num_agents = 1

    with temporary_flags({
            "problem": simulated_problem_name,
            "model": hparams.generative_model,
            "hparams_set": hparams.generative_model_params,
            "output_dir": world_model_dir,
            "data_dir": epoch_data_dir,
    }):
        sess = tf.Session()
        env = DebugBatchEnv(batch_env_hparams, sess)
        sess.run(tf.global_variables_initializer())
        env.initialize()

        env_model_loader = tf.train.Saver(tf.global_variables("next_frame*"))
        trainer_lib.restore_checkpoint(world_model_dir,
                                       env_model_loader,
                                       sess,
                                       must_restore=True)

        model_saver = tf.train.Saver(
            tf.global_variables(".*network_parameters.*"))
        trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

        key_mapping = gym_problem.env.env.get_keys_to_action()
        # map special codes
        key_mapping[()] = 100
        key_mapping[(ord("r"), )] = 101
        key_mapping[(ord("p"), )] = 102

        play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)
Example #18
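# PPO training in a fresh graph: optionally restores a simulated-environment
# world model and the agent, then completes only the remaining epochs.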
def train(hparams,
          event_dir=None,
          model_dir=None,
          restore_agent=True,
          name_scope="rl_train"):
    """Train."""
    with tf.Graph().as_default():
        with tf.name_scope(name_scope):
            train_summary_op, _, initialization = define_train(hparams)
            if event_dir:
                summary_writer = tf.summary.FileWriter(
                    event_dir, graph=tf.get_default_graph(), flush_secs=60)
            else:
                summary_writer = None

            if model_dir:
                model_saver = tf.train.Saver(
                    tf.global_variables(".*network_parameters.*"))
            else:
                model_saver = None

            # TODO(piotrmilos): This should be refactored, possibly with
            # handlers for each type of env
            if hparams.environment_spec.simulated_env:
                env_model_loader = tf.train.Saver(
                    tf.global_variables("next_frame*"))
            else:
                env_model_loader = None

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                initialization(sess)
                if env_model_loader:
                    trainer_lib.restore_checkpoint(hparams.world_model_dir,
                                                   env_model_loader,
                                                   sess,
                                                   must_restore=True)
                start_step = 0
                if model_saver and restore_agent:
                    start_step = trainer_lib.restore_checkpoint(
                        model_dir, model_saver, sess)

                # Fail-friendly, complete only unfinished epoch
                steps_to_go = hparams.epochs_num - start_step

                if steps_to_go <= 0:
                    tf.logging.info(
                        "Skipping PPO training. Requested %d steps while "
                        "%d train steps already reached", hparams.epochs_num,
                        start_step)
                    return

                for epoch_index in range(steps_to_go):
                    summary = sess.run(train_summary_op)
                    if summary_writer:
                        summary_writer.add_summary(summary, epoch_index)
                    if (hparams.eval_every_epochs
                            and epoch_index % hparams.eval_every_epochs == 0):
                        if summary_writer and summary:
                            summary_writer.add_summary(summary, epoch_index)
                        else:
                            tf.logging.info("Eval summary not saved")
                    epoch_index_and_start = epoch_index + start_step
                    if (model_saver and hparams.save_models_every_epochs
                            and (epoch_index_and_start %
                                 hparams.save_models_every_epochs == 0 or
                                 (epoch_index + 1) == steps_to_go)):
                        ckpt_path = os.path.join(
                            model_dir, "model.ckpt-{}".format(epoch_index + 1 +
                                                              start_step))
                        model_saver.save(sess, ckpt_path)
Example #19
def main(_):
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  output_dir = FLAGS.output_dir

  subdirectories = ["data", "tmp", "world_model", "ppo"]
  using_autoencoder = hparams.autoencoder_train_steps > 0
  if using_autoencoder:
    subdirectories.append("autoencoder")
  directories = setup_directories(output_dir, subdirectories)

  if hparams.game in gym_env.ATARI_GAMES:
    game_with_mode = hparams.game + "_deterministic-v4"
  else:
    game_with_mode = hparams.game

  if using_autoencoder:
    simulated_problem_name = (
        "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded"
        % game_with_mode)
  else:
    simulated_problem_name = ("gym_simulated_discrete_problem_with_agent_on_%s"
                              % game_with_mode)
    if simulated_problem_name not in registry.list_problems():
      tf.logging.info("Game Problem %s not found; dynamically registering",
                      simulated_problem_name)
      gym_env.register_game(hparams.game, game_mode="Deterministic-v4")

  epoch = hparams.epochs-1
  epoch_data_dir = os.path.join(directories["data"], str(epoch))
  ppo_model_dir = directories["ppo"]

  world_model_dir = directories["world_model"]

  gym_problem = registry.problem(simulated_problem_name)

  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  environment_spec = copy.copy(gym_problem.environment_spec)
  environment_spec.simulation_random_starts = hparams.simulation_random_starts

  batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
  batch_env_hparams.add_hparam("model_hparams", model_hparams)
  batch_env_hparams.add_hparam("environment_spec", environment_spec)
  batch_env_hparams.num_agents = 1

  with temporary_flags({
      "problem": simulated_problem_name,
      "model": hparams.generative_model,
      "hparams_set": hparams.generative_model_params,
      "output_dir": world_model_dir,
      "data_dir": epoch_data_dir,
  }):
    sess = tf.Session()
    env = DebugBatchEnv(batch_env_hparams, sess)
    sess.run(tf.global_variables_initializer())
    env.initialize()

    env_model_loader = tf.train.Saver(
        tf.global_variables("next_frame*"))
    trainer_lib.restore_checkpoint(world_model_dir, env_model_loader, sess,
                                   must_restore=True)

    model_saver = tf.train.Saver(
        tf.global_variables(".*network_parameters.*"))
    trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

    key_mapping = gym_problem.env.env.get_keys_to_action()
    # map special codes
    key_mapping[()] = 100
    key_mapping[(ord("r"),)] = 101
    key_mapping[(ord("p"),)] = 102

    play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)