Example #1
def define_train(hparams, environment_spec, event_dir):
    """Define the training setup."""
    policy_lambda = hparams.network

    if environment_spec == "stacked_pong":
        environment_spec = lambda: gym.make("PongNoFrameskip-v4")
        wrappers = hparams.in_graph_wrappers if hasattr(
            hparams, "in_graph_wrappers") else []
        wrappers.append((tf_atari_wrappers.MaxAndSkipWrapper, {"skip": 4}))
        hparams.in_graph_wrappers = wrappers
    if isinstance(environment_spec, str):
        env_lambda = lambda: gym.make(environment_spec)
    else:
        env_lambda = environment_spec

    batch_env = utils.batch_env_factory(env_lambda,
                                        hparams,
                                        num_agents=hparams.num_agents)

    policy_factory = functools.partial(policy_lambda, batch_env.action_space,
                                       hparams)

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        memory, collect_summary = collect.define_collect(
            policy_factory,
            batch_env,
            hparams,
            eval_phase=False,
            on_simulated=hparams.simulated_environment)
        ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams)
        summary = tf.summary.merge([collect_summary, ppo_summary])

    with tf.variable_scope("eval", reuse=tf.AUTO_REUSE):
        eval_env_lambda = env_lambda
        if event_dir and hparams.video_during_eval:
            # Some environments reset automatically when the done state is
            # reached. For those we record only every second episode.
            d = 2 if env_lambda().metadata.get("semantics.autoreset") else 1
            monitored_env_lambda = lambda: gym.wrappers.Monitor(  # pylint: disable=g-long-lambda
                env_lambda(),
                event_dir,
                video_callable=lambda i: i % d == 0)
            # Bind the Monitor factory under its own name; reusing
            # eval_env_lambda here would make the wrapper lambda call itself.
            eval_env_lambda = (
                lambda: utils.EvalVideoWrapper(monitored_env_lambda()))
        eval_batch_env = utils.batch_env_factory(
            eval_env_lambda,
            hparams,
            num_agents=hparams.num_eval_agents,
            xvfb=hparams.video_during_eval)

        # TODO(blazej0): correct to the version below.
        corrected = True
        eval_summary = tf.no_op()
        if corrected:
            _, eval_summary = collect.define_collect(policy_factory,
                                                     eval_batch_env,
                                                     hparams,
                                                     eval_phase=True)
    return summary, eval_summary
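
For context, here is a minimal sketch of how the two ops returned by define_train might be driven from a TF1 session. The driver loop, hparams.epochs_num, and the FileWriter usage are assumptions added for illustration; they are not part of the example above.

# Hypothetical driver loop; assumes the graph above was already built and
# that hparams defines epochs_num.
summary_op, eval_summary_op = define_train(hparams, "stacked_pong", event_dir)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter(event_dir, sess.graph)
    for epoch in range(hparams.epochs_num):
        # Running the merged summary also runs the collect and PPO ops it
        # depends on.
        writer.add_summary(sess.run(summary_op), epoch)
        sess.run(eval_summary_op)  # Runs the eval-phase collect from the "eval" scope.
    writer.close()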
Example #2
  def _setup(self):
    in_graph_wrappers = [(atari.MemoryWrapper, {})] + self.in_graph_wrappers
    env_hparams = tf.contrib.training.HParams(
        in_graph_wrappers=in_graph_wrappers,
        simulated_environment=self.simulated_environment)

    generator_batch_env = batch_env_factory(
        self.environment_spec, env_hparams, num_agents=1, xvfb=False)

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      if FLAGS.agent_policy_path:
        policy_lambda = self.collect_hparams.network
      else:
        # When no agent_policy_path is set, just generate random samples.
        policy_lambda = rl.random_policy_fun
      policy_factory = tf.make_template(
          "network",
          functools.partial(policy_lambda, self.environment_spec().action_space,
                            self.collect_hparams),
          create_scope_now_=True,
          unique_name_="network")

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      self.collect_hparams.epoch_length = 10
      _, self.collect_trigger_op = collect.define_collect(
          policy_factory, generator_batch_env, self.collect_hparams,
          eval_phase=False, scope="define_collect")

    self.avilable_data_size_op = atari.MemoryWrapper.singleton.speculum.size()
    self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue()
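
A minimal sketch of how these ops might be consumed from another method of the same class, assuming a TF1 session named sess; the drain loop is illustrative and not part of the snippet above.

# Hypothetical consumption loop: trigger one collect rollout, then drain the
# transitions that MemoryWrapper buffered in its speculum queue.
sess.run(self.collect_trigger_op)                 # collect_hparams.epoch_length steps
available = sess.run(self.avilable_data_size_op)  # number of queued transitions
for _ in range(available):
    transition = sess.run(self.data_get_op)       # dequeue one buffered transition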
Example #3
  def _setup(self):
    in_graph_wrappers = [(atari.ShiftRewardWrapper, {"add_value": 2}),
                         (atari.MemoryWrapper, {})] + self.in_graph_wrappers
    env_hparams = tf.contrib.training.HParams(
        in_graph_wrappers=in_graph_wrappers,
        simulated_environment=self.simulated_environment)

    generator_batch_env = batch_env_factory(
        self.environment_spec, env_hparams, num_agents=1, xvfb=False)

    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
      policy_lambda = self.collect_hparams.network
      policy_factory = tf.make_template(
          "network",
          functools.partial(policy_lambda, self.environment_spec().action_space,
                            self.collect_hparams),
          create_scope_now_=True,
          unique_name_="network")

    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
      sample_policy = lambda policy: 0 * policy.sample()

      self.collect_hparams.epoch_length = 10
      _, self.collect_trigger_op = collect.define_collect(
          policy_factory, generator_batch_env, self.collect_hparams,
          eval_phase=False, policy_to_actions_lambda=sample_policy,
          scope="define_collect")

    self.avilable_data_size_op = atari.MemoryWrapper.singleton.speculum.size()
    self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue()
    self.history_buffer = deque(maxlen=self.history_size+1)
Example #4
    def _setup(self):
        if self.make_extra_debug_info:
            self.report_reward_statistics_every = 10
            self.dones = 0
            self.real_reward = 0
            self.real_env.reset()
            # Slight weirdness to keep the sim env and the real env aligned.
            num_frames = (
                simulated_batch_env.SimulatedBatchEnv.NUMBER_OF_HISTORY_FRAMES)
            for _ in range(num_frames):
                self.real_ob, _, _, _ = self.real_env.step(0)
            self.total_sim_reward, self.total_real_reward = 0.0, 0.0
            self.successful_dones = 0

        in_graph_wrappers = (
            self.in_graph_wrappers + [(atari.MemoryWrapper, {})])
        env_hparams = tf.contrib.training.HParams(
            in_graph_wrappers=in_graph_wrappers,
            simulated_environment=self.simulated_environment)

        generator_batch_env = batch_env_factory(self.environment_spec,
                                                env_hparams,
                                                num_agents=1,
                                                xvfb=False)

        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            if FLAGS.agent_policy_path:
                policy_lambda = self.collect_hparams.network
            else:
                # When no agent_policy_path is set, just generate random samples.
                policy_lambda = rl.random_policy_fun
            policy_factory = tf.make_template(
                "network",
                functools.partial(policy_lambda,
                                  self.environment_spec().action_space,
                                  self.collect_hparams),
                create_scope_now_=True,
                unique_name_="network")

        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            self.collect_hparams.epoch_length = 10
            _, self.collect_trigger_op = collect.define_collect(
                policy_factory,
                generator_batch_env,
                self.collect_hparams,
                eval_phase=False,
                scope="define_collect")

        self.avilable_data_size_op = (
            atari.MemoryWrapper.singleton.speculum.size())
        self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue()
Example #5
    def _setup(self):
        if self.make_extra_debug_info:
            self.report_reward_statistics_every = 10
            self.dones = 0
            self.real_reward = 0
            # Slight weirdness to keep the sim env and the real env aligned.
            if self.simulated_environment:
                self.real_env.reset()
                for _ in range(self.num_input_frames):
                    self.real_ob, _, _, _ = self.real_env.step(0)
            self.total_sim_reward, self.total_real_reward = 0.0, 0.0
            self.sum_of_rewards = 0.0
            self.successful_episode_reward_predictions = 0

        in_graph_wrappers = self.in_graph_wrappers + [
            (atari.MemoryWrapper, {}),
            (StackAndSkipWrapper, {"skip": 4}),
        ]
        env_hparams = tf.contrib.training.HParams(
            in_graph_wrappers=in_graph_wrappers,
            problem=self.real_env_problem if self.real_env_problem else self,
            simulated_environment=self.simulated_environment)
        if self.simulated_environment:
            env_hparams.add_hparam("simulation_random_starts",
                                   self.simulation_random_starts)
            env_hparams.add_hparam("intrinsic_reward_scale",
                                   self.intrinsic_reward_scale)

        generator_batch_env = batch_env_factory(self.environment_spec,
                                                env_hparams,
                                                num_agents=1,
                                                xvfb=False)

        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            if FLAGS.agent_policy_path:
                policy_lambda = self.collect_hparams.network
            else:
                # When no agent_policy_path is set, just generate random samples.
                policy_lambda = rl.random_policy_fun

        if FLAGS.autoencoder_path:
            # TODO(lukaszkaiser): remove hard-coded autoencoder params.
            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                self.setup_autoencoder()
                autoencoder_model = self.autoencoder_model
                # Feeds for autoencoding.
                shape = [
                    self.raw_frame_height, self.raw_frame_width,
                    self.num_channels
                ]
                self.autoencoder_feed = tf.placeholder(tf.int32, shape=shape)
                self.autoencoder_result = self.autoencode_tensor(
                    self.autoencoder_feed)
                # Now for autodecoding.
                shape = self.frame_shape
                self.autodecoder_feed = tf.placeholder(tf.int32, shape=shape)
                bottleneck = tf.reshape(
                    discretization.int_to_bit(self.autodecoder_feed, 8), [
                        1, 1, self.frame_height, self.frame_width,
                        self.num_channels * 8
                    ])
                autoencoder_model.set_mode(tf.estimator.ModeKeys.PREDICT)
                self.autodecoder_result = autoencoder_model.decode(bottleneck)

        def preprocess_fn(x):
            shape = [
                self.raw_frame_height, self.raw_frame_width, self.num_channels
            ]
            # TODO(lukaszkaiser): we assume x comes from StackAndSkipWrapper skip=4.
            xs = [tf.reshape(t, [1] + shape) for t in tf.split(x, 4, axis=-1)]
            autoencoded = self.autoencode_tensor(tf.concat(xs, axis=0),
                                                 batch_size=4)
            encs = [
                tf.squeeze(t, axis=[0])
                for t in tf.split(autoencoded, 4, axis=0)
            ]
            res = tf.to_float(tf.concat(encs, axis=-1))
            return tf.expand_dims(res, axis=0)

        # TODO(lukaszkaiser): x is from StackAndSkipWrapper thus 4*num_channels.
        shape = [1, self.frame_height, self.frame_width, 4 * self.num_channels]
        do_preprocess = (self.autoencoder_model is not None
                         and not self.simulated_environment)
        preprocess = (preprocess_fn, shape) if do_preprocess else None

        def policy(x):
            return policy_lambda(self.environment_spec().action_space,
                                 self.collect_hparams, x)

        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            self.collect_hparams.epoch_length = 10
            _, self.collect_trigger_op = collect.define_collect(
                policy,
                generator_batch_env,
                self.collect_hparams,
                eval_phase=self.eval_phase,
                scope="define_collect",
                preprocess=preprocess)

        self.avilable_data_size_op = (
            atari.MemoryWrapper.singleton.speculum.size())
        self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue()
Example #6
  def _setup(self):
    if self.make_extra_debug_info:
      self.report_reward_statistics_every = 10
      self.dones = 0
      self.real_reward = 0
      self.real_env.reset()
      # Slight weirdness to keep the sim env and the real env aligned.
      for _ in range(self.num_input_frames):
        self.real_ob, _, _, _ = self.real_env.step(0)
      self.total_sim_reward, self.total_real_reward = 0.0, 0.0
      self.sum_of_rewards = 0.0
      self.successful_episode_reward_predictions = 0

    in_graph_wrappers = self.in_graph_wrappers + [(atari.MemoryWrapper, {})]
    env_hparams = tf.contrib.training.HParams(
        in_graph_wrappers=in_graph_wrappers,
        problem=self,
        simulated_environment=self.simulated_environment)

    generator_batch_env = batch_env_factory(
        self.environment_spec, env_hparams, num_agents=1, xvfb=False)

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      if FLAGS.agent_policy_path:
        policy_lambda = self.collect_hparams.network
      else:
        # When no agent_policy_path is set, just generate random samples.
        policy_lambda = rl.random_policy_fun
      policy_factory = tf.make_template(
          "network",
          functools.partial(policy_lambda, self.environment_spec().action_space,
                            self.collect_hparams),
          create_scope_now_=True,
          unique_name_="network")

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      self.collect_hparams.epoch_length = 10
      _, self.collect_trigger_op = collect.define_collect(
          policy_factory, generator_batch_env, self.collect_hparams,
          eval_phase=False, scope="define_collect")

    if FLAGS.autoencoder_path:
      # TODO(lukaszkaiser): remove hard-coded autoencoder params.
      with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        self.setup_autoencoder()
        autoencoder_model = self.autoencoder_model
        # Feeds for autoencoding.
        shape = [self.raw_frame_height, self.raw_frame_width, self.num_channels]
        self.autoencoder_feed = tf.placeholder(tf.int32, shape=shape)
        autoencoded = autoencoder_model.encode(
            tf.reshape(self.autoencoder_feed, [1, 1] + shape))
        autoencoded = tf.reshape(
            autoencoded, [self.frame_height, self.frame_width,
                          self.num_channels, 8])  # 8-bit groups.
        self.autoencoder_result = discretization.bit_to_int(autoencoded, 8)
        # Now for autodecoding.
        shape = [self.frame_height, self.frame_width, self.num_channels]
        self.autodecoder_feed = tf.placeholder(tf.int32, shape=shape)
        bottleneck = tf.reshape(
            discretization.int_to_bit(self.autodecoder_feed, 8),
            [1, 1, self.frame_height, self.frame_width, self.num_channels * 8])
        autoencoder_model.set_mode(tf.estimator.ModeKeys.PREDICT)
        self.autodecoder_result = autoencoder_model.decode(bottleneck)

    self.avilable_data_size_op = atari.MemoryWrapper.singleton.speculum.size()
    self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue()
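
A minimal sketch of the encode/decode round trip through the placeholders defined above, e.g. from another method of the same class; sess and raw_frame (an integer array of shape [raw_frame_height, raw_frame_width, num_channels]) are assumed for illustration.

# Hypothetical round trip: encode a raw frame into 8-bit codes, then decode
# it back to a frame with the autoencoder in PREDICT mode.
encoded_frame = sess.run(self.autoencoder_result,
                         feed_dict={self.autoencoder_feed: raw_frame})
decoded_frame = sess.run(self.autodecoder_result,
                         feed_dict={self.autodecoder_feed: encoded_frame})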