Esempio n. 1
0
 def rollout_info_spec(self):
     """Return the spec of the ``info`` field produced by ``rollout()``.

     On the first call, ``rollout()`` is run once on zero-filled inputs
     built from ``time_step_spec`` and ``train_state_spec``; the spec
     extracted from the resulting ``info`` is cached on the instance and
     returned directly on later calls.

     Returns:
         The spec extracted from ``policy_step.info`` of ``rollout()``.
     """
     if self._rollout_info_spec is None:
         # Batch size used only to build the zero-filled dummy inputs.
         dummy_batch_size = 4
         zero_time_step = common.zeros_from_spec(self.time_step_spec,
                                                 dummy_batch_size)
         zero_state = common.zeros_from_spec(self.train_state_spec,
                                             dummy_batch_size)
         step = self.rollout(
             self.transform_timestep(zero_time_step), zero_state,
             RLAlgorithm.PREPARE_SPEC)
         self._rollout_info_spec = common.extract_spec(step.info)
     return self._rollout_info_spec
Esempio n. 2
0
 def train_step_info_spec(self):
     """Return the spec of the ``info`` field produced by ``train_step()``.

     On the first call, ``train_step()`` is run once on zero-filled inputs
     built from ``processed_experience_spec`` and ``train_state_spec``;
     the spec extracted from the resulting ``info`` is cached on the
     instance and returned directly on later calls.

     Returns:
         The spec extracted from ``policy_step.info`` of ``train_step()``.
     """
     if self._train_step_info_spec is None:
         # Batch size used only to build the zero-filled dummy inputs.
         dummy_batch_size = 4
         zero_exp = common.zeros_from_spec(self.processed_experience_spec,
                                           dummy_batch_size)
         zero_state = common.zeros_from_spec(self.train_state_spec,
                                             dummy_batch_size)
         step = self.train_step(zero_exp, zero_state)
         self._train_step_info_spec = common.extract_spec(step.info)
     return self._train_step_info_spec
Esempio n. 3
0
    def processed_experience_spec(self):
        """Spec for processed experience.

        On the first call, a zero-filled experience built from
        ``experience_spec`` is passed through ``transform_timestep()`` and
        ``preprocess_experience()``. The specs of the resulting
        ``observation`` and ``rollout_info`` replace the corresponding
        fields of ``experience_spec``. The result is cached on the instance
        and returned directly on later calls.

        Returns:
            Spec for the experience returned by preprocess_experience().
            When ``_use_rollout_state`` is false, its ``state`` field is
            ``()``.
        """
        if self._processed_experience_spec is not None:
            return self._processed_experience_spec
        # Batch size used only to build the zero-filled dummy experience.
        batch_size = 4
        exp = common.zeros_from_spec(self.experience_spec, batch_size)
        transformed_exp = self.transform_timestep(exp)
        processed_exp = self.preprocess_experience(transformed_exp)
        spec = self.experience_spec._replace(
            observation=common.extract_spec(processed_exp.observation),
            rollout_info=common.extract_spec(processed_exp.rollout_info))
        if not self._use_rollout_state:
            # Previously this result was assigned to a misspelled attribute
            # (`_procesed_experience_spec`), so the cached/returned spec
            # silently kept the original `state` field.
            spec = spec._replace(state=())
        self._processed_experience_spec = spec
        return self._processed_experience_spec