Example #1
    def processed_experience_spec(self):
        """Spec for processed experience.

        Returns:
            Spec for the experience returned by preprocess_experience().
        """
        if self._processed_experience_spec is not None:
            return self._processed_experience_spec
        batch_size = 4  # arbitrary; only used to build dummy tensors for spec extraction
        exp = common.zeros_from_spec(self.experience_spec, batch_size)
        transformed_exp = self.transform_timestep(exp)
        processed_exp = self.preprocess_experience(transformed_exp)
        self._processed_experience_spec = self.experience_spec._replace(
            observation=common.extract_spec(processed_exp.observation),
            rollout_info=common.extract_spec(processed_exp.rollout_info))
        if not self._use_rollout_state:
            self._processed_experience_spec = \
                self._processed_experience_spec._replace(state=())
        return self._processed_experience_spec
Example #2
    def rollout_info_spec(self):
        """The spec for the PolicyInfo.info returned from rollout()."""
        if self._rollout_info_spec is not None:
            return self._rollout_info_spec
        batch_size = 4
        time_step = common.zeros_from_spec(self.time_step_spec, batch_size)
        state = common.zeros_from_spec(self.train_state_spec, batch_size)
        policy_step = self.rollout(
            self.transform_timestep(time_step), state,
            RLAlgorithm.PREPARE_SPEC)
        self._rollout_info_spec = common.extract_spec(policy_step.info)
        return self._rollout_info_spec
Example #3
    def train_step_info_spec(self):
        """The spec for the PolicyInfo.info returned from train_step()."""
        if self._train_step_info_spec is not None:
            return self._train_step_info_spec
        batch_size = 4
        processed_exp = common.zeros_from_spec(
            self.processed_experience_spec, batch_size)
        state = common.zeros_from_spec(self.train_state_spec, batch_size)
        policy_step = self.train_step(processed_exp, state)
        self._train_step_info_spec = common.extract_spec(policy_step.info)
        return self._train_step_info_spec
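
Examples #1-#3 all follow the same lazy pattern: on first access, build a dummy zero-valued batch from a known spec, run it through the relevant step function, extract a spec from the result, and cache it for later calls. The following is a minimal, self-contained sketch of that pattern; zeros_from_spec, extract_spec, and ToyAlgorithm are hypothetical stand-ins for illustration, not ALF's common helpers.

import numpy as np


def zeros_from_spec(spec, batch_size):
    """Build a zero-filled batch from a {name: shape} spec (stand-in helper)."""
    return {name: np.zeros((batch_size,) + shape) for name, shape in spec.items()}


def extract_spec(batch):
    """Recover a {name: shape} spec from a batch by dropping the batch dim."""
    return {name: tensor.shape[1:] for name, tensor in batch.items()}


class ToyAlgorithm:
    time_step_spec = {'observation': (3,), 'reward': ()}

    def __init__(self):
        self._rollout_info_spec = None  # computed lazily, then cached

    def rollout(self, time_step):
        # Toy step whose info structure is only known by actually running it.
        return {'value': time_step['reward'][..., None]}

    @property
    def rollout_info_spec(self):
        if self._rollout_info_spec is None:
            batch = zeros_from_spec(self.time_step_spec, batch_size=4)
            self._rollout_info_spec = extract_spec(self.rollout(batch))
        return self._rollout_info_spec


print(ToyAlgorithm().rollout_info_spec)  # {'value': (1,)}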
Example #4
    def _prepare_specs(self, algorithm):
        """Prepare various tensor specs."""

        time_step = self.get_initial_time_step()
        self._time_step_spec = common.extract_spec(time_step)
        self._action_spec = self._env.action_spec()

        policy_step = algorithm.rollout(
            algorithm.transform_timestep(time_step), self._initial_state)
        info_spec = common.extract_spec(policy_step.info)
        self._policy_step_spec = PolicyStep(
            action=self._action_spec,
            state=algorithm.train_state_spec,
            info=info_spec)

        self._action_distribution_spec = tf.nest.map_structure(
            common.to_distribution_spec, algorithm.action_distribution_spec)
        self._action_dist_param_spec = tf.nest.map_structure(
            lambda spec: spec.input_params_spec,
            self._action_distribution_spec)

        algorithm.prepare_off_policy_specs(time_step)
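
Example #4 relies on tf.nest.map_structure to apply a conversion to every leaf of a possibly nested spec structure while preserving the nesting (here, mapping common.to_distribution_spec over action_distribution_spec and then pulling out each distribution's input_params_spec). A minimal illustration of that call, with plain Python leaves standing in for real specs:

import tensorflow as tf

nested = {'discrete': 1, 'continuous': (2, 3)}

# map_structure applies the function to every leaf and keeps the nesting intact.
doubled = tf.nest.map_structure(lambda leaf: leaf * 2, nested)
print(doubled['discrete'], doubled['continuous'])  # 2 (4, 6)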
Example #5
    def prepare_off_policy_specs(self, time_step: ActionTimeStep):
        """Prepare various tensor specs for off_policy training.

        prepare_off_policy_specs is called by OffPolicyDriver._prepare_spec().

        """

        self._env_batch_size = time_step.step_type.shape[0]
        self._time_step_spec = common.extract_spec(time_step)
        initial_state = common.get_initial_policy_state(
            self._env_batch_size, self.train_state_spec)
        transformed_timestep = self.transform_timestep(time_step)
        policy_step = self.rollout(transformed_timestep, initial_state)
        info_spec = common.extract_spec(policy_step.info)

        self._action_distribution_spec = tf.nest.map_structure(
            common.to_distribution_spec, self.action_distribution_spec)
        self._action_dist_param_spec = tf.nest.map_structure(
            lambda spec: spec.input_params_spec,
            self._action_distribution_spec)

        self._experience_spec = Experience(
            step_type=self._time_step_spec.step_type,
            reward=self._time_step_spec.reward,
            discount=self._time_step_spec.discount,
            observation=self._time_step_spec.observation,
            prev_action=self._action_spec,
            action=self._action_spec,
            info=info_spec,
            action_distribution=self._action_dist_param_spec,
            state=self.train_state_spec if self._use_rollout_state else ())

        action_dist_params = common.zero_tensor_from_nested_spec(
            self._experience_spec.action_distribution, self._env_batch_size)
        action_dist = nested_distributions_from_specs(
            self._action_distribution_spec, action_dist_params)

        exp = Experience(step_type=time_step.step_type,
                         reward=time_step.reward,
                         discount=time_step.discount,
                         observation=time_step.observation,
                         prev_action=time_step.prev_action,
                         # prev_action has the same spec as action, so it
                         # serves as a placeholder value here
                         action=time_step.prev_action,
                         info=policy_step.info,
                         action_distribution=action_dist,
                         state=initial_state if self._use_rollout_state else
                         ())

        transformed_exp = self.transform_timestep(exp)
        processed_exp = self.preprocess_experience(transformed_exp)
        self._processed_experience_spec = self._experience_spec._replace(
            observation=common.extract_spec(processed_exp.observation),
            info=common.extract_spec(processed_exp.info))

        policy_step = common.algorithm_step(
            algorithm_step_func=self.train_step,
            time_step=processed_exp,
            state=initial_state)
        info_spec = common.extract_spec(policy_step.info)
        self._training_info_spec = TrainingInfo(
            action_distribution=self._action_dist_param_spec, info=info_spec)
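
Both this method and Example #1 build the processed spec with namedtuple's _replace, which returns a new tuple with only the named fields swapped while leaving the original untouched. A small stand-alone illustration with a toy Experience tuple (not ALF's actual Experience definition):

from collections import namedtuple

Experience = namedtuple('Experience', ['observation', 'info', 'state'])

spec = Experience(observation='obs_spec', info='info_spec', state='state_spec')

# _replace returns a new namedtuple; the original spec is unchanged.
processed = spec._replace(state=())
print(processed.state)  # ()
print(spec.state)       # state_spec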