def get_initial_time_step(env):
    """Return the initial time step.

    Args:
        env (TFPyEnvironment): the environment to read the current time
            step from
    Returns:
        time_step (ActionTimeStep): the initial time step, with actions as
            zero tensors
    """
    time_step = env.current_time_step()
    action = zero_tensor_from_nested_spec(env.action_spec(), env.batch_size)
    return make_action_time_step(time_step, action)
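# A minimal sketch of how get_initial_time_step() might seed an unroll loop.
# Everything except get_initial_time_step and make_action_time_step is an
# assumption for illustration: `policy` is presumed to expose a
# TF-Agents-style action(time_step, policy_state) -> PolicyStep interface.
def _unroll_sketch(env, policy, policy_state, num_steps):
    time_step = get_initial_time_step(env)
    for _ in range(num_steps):
        policy_step = policy.action(time_step, policy_state)
        # Step the env and attach the action that produced the new time step.
        time_step = make_action_time_step(
            env.step(policy_step.action), policy_step.action)
        policy_state = policy_step.state
    return time_step, policy_state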
def _step(self, time_step, policy_state):
    # Reset the policy state for any environment that just started a new
    # episode.
    policy_state = common.reset_state_if_necessary(
        policy_state, self._initial_policy_state, time_step.is_first())
    # Hand the observation to the shared actor thread and block until it
    # returns the corresponding action.
    self._actor_q.enqueue([time_step, policy_state, self._id])
    policy_step, act_dist_param = self._action_return_q.dequeue()
    action = policy_step.action
    next_time_step = make_action_time_step(self._env.step(action), action)
    # Temporarily store the transition in a local queue; it is dequeued
    # later as part of a complete unroll.
    self._unroll_queue.enqueue(
        LearningBatch(
            time_step=time_step,
            state=policy_state if self._tfq._store_state else (),
            policy_step=policy_step,
            act_dist_param=act_dist_param,
            next_time_step=next_time_step,
            env_id=()))
    return [next_time_step, policy_step.state]
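# A minimal sketch of the queue handoff _step() relies on: environment
# threads enqueue [time_step, policy_state, env_id] onto a shared actor
# queue and block on a return queue for the action. The capacity, dtypes,
# shapes, and component names below are illustrative assumptions, not the
# library's configuration. Assumes TF2 eager execution.
def _actor_queue_sketch():
    import tensorflow as tf  # local import keeps the sketch self-contained

    actor_q = tf.queue.FIFOQueue(
        capacity=8,
        dtypes=[tf.float32, tf.int32],
        shapes=[[4], []],
        names=['observation', 'env_id'])
    # An env thread hands its observation (and its id) to the actor thread.
    actor_q.enqueue({'observation': tf.ones([4]), 'env_id': tf.constant(0)})
    # The actor thread pulls one request and knows which env to answer.
    request = actor_q.dequeue()
    return request['env_id']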
def _env_step(self, action):
    """Step the environment and attach `action` to the resulting time step."""
    time_step = self._env.step(action)
    return make_action_time_step(time_step, action)
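# Illustrative only: a plausible shape for ActionTimeStep and
# make_action_time_step, assuming a TF-Agents-style TimeStep with
# (step_type, reward, discount, observation) fields. The real definitions
# may differ; this sketch only shows why _env_step() re-wraps the time step
# with the action that produced it.
import collections

_ActionTimeStepSketch = collections.namedtuple(
    '_ActionTimeStepSketch',
    ['step_type', 'reward', 'discount', 'observation', 'action'])

def _make_action_time_step_sketch(time_step, action):
    # Bundle (s', a) together so downstream training can see which action
    # led to each time step.
    return _ActionTimeStepSketch(
        step_type=time_step.step_type,
        reward=time_step.reward,
        discount=time_step.discount,
        observation=time_step.observation,
        action=action)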