def forward(self, inputs, length, initial_inputs=None, static_inputs=None):
    """
    :param inputs: These are sliced by time. Time is the second dimension.
    :param length: Rollout length.
    :param initial_inputs: These are not sliced and are overridden by cell output.
    :param static_inputs: These are not sliced and can't be overridden by cell output.
    :return: Dict of outputs, each stacked along the time (second) dimension.
    """
    # NOTE! Unrolling the cell directly will result in a crash as the hidden state is not reset.
    # Use this function or CustomLSTMCell.unroll if needed.
    initial_inputs, static_inputs = self.assert_begin(inputs, initial_inputs, static_inputs)

    step_inputs = initial_inputs.copy()
    step_inputs.update(static_inputs)
    lstm_outputs = []
    for t in range(length):
        step_inputs.update(map_dict(lambda x: x[:, t], inputs))    # slice inputs at the current time step
        output = self.cell(**step_inputs)
        self.assert_post(output, inputs, initial_inputs, static_inputs)

        # TODO test what signature does with *args
        # feed back those cell outputs that the cell's forward() accepts as inputs (autoregression)
        autoregressive_output = subdict(output, output.keys() & signature(self.cell.forward).parameters)
        step_inputs.update(autoregressive_output)
        lstm_outputs.append(output)

    # TODO recursively stack outputs
    lstm_outputs = listdict2dictlist(lstm_outputs)
    lstm_outputs = map_dict(lambda x: stack(x, dim=1), lstm_outputs)

    self.cell.reset()
    return lstm_outputs
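# Hedged sketch, not the library's actual implementations: the unroll above assumes
# `signature` (from inspect), `stack` (torch.stack), and a few dict utilities
# (map_dict, subdict, listdict2dictlist). Minimal versions consistent with how they
# are used in forward() could look like this; the real helpers in the repo may return
# AttrDicts or handle nested structures.
from inspect import signature   # inspect the cell's forward() arguments
from torch import stack         # stack per-step outputs along the time dimension

def map_dict(fn, d):
    """Apply fn to every value of a dict, keeping the keys (sketch)."""
    return {k: fn(v) for k, v in d.items()}

def subdict(d, keys):
    """Return a copy of d restricted to the given keys (sketch)."""
    return {k: d[k] for k in keys}

def listdict2dictlist(list_of_dicts):
    """Turn a list of per-step dicts into a dict of per-key lists (sketch)."""
    return {k: [d[k] for d in list_of_dicts] for k in list_of_dicts[0]}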
def sample_batch(self, batch_size, is_train=True, global_step=None):
    """Samples an experience batch of the required size."""
    experience_batch = []
    step = 0
    with self._env.val_mode() if not is_train else contextlib.suppress():
        with self._agent.val_mode() if not is_train else contextlib.suppress():
            with self._agent.rollout_mode():
                while step < batch_size:
                    # perform one rollout step
                    agent_output = self.sample_action(self._obs)
                    if agent_output.action is None:
                        self._episode_reset(global_step)
                        continue
                    agent_output = self._postprocess_agent_output(agent_output)
                    obs, reward, done, info = self._env.step(agent_output.action)
                    obs = self._postprocess_obs(obs)
                    experience_batch.append(AttrDict(
                        observation=self._obs,
                        reward=reward,
                        done=done,
                        action=agent_output.action,
                        observation_next=obs,
                    ))

                    # update stored observation
                    self._obs = obs
                    step += 1
                    self._episode_step += 1
                    self._episode_reward += reward

                    # reset if episode ends
                    if done or self._episode_step >= self._max_episode_len:
                        if not done:    # force done to be True for timeout
                            experience_batch[-1].done = True
                        self._episode_reset(global_step)

    return listdict2dictlist(experience_batch), step
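# Hedged sketch of the val_mode()/rollout_mode() pattern used above: contextlib.suppress()
# with no arguments is a no-op context manager, so
# `with self._env.val_mode() if not is_train else contextlib.suppress():` only switches the
# env/agent into evaluation mode for validation rollouts. The env and agent are assumed to
# expose such context managers; a minimal version could look like this.
import contextlib

class ModeSwitchable:
    def __init__(self):
        self.is_val_mode = False

    @contextlib.contextmanager
    def val_mode(self):
        """Temporarily flip this component into validation mode (sketch)."""
        self.is_val_mode = True
        try:
            yield self
        finally:
            self.is_val_mode = False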
def sample_episode(self, is_train, render=False):
    """Samples one episode from the environment."""
    self.init(is_train)
    episode, done = [], False
    with self._env.val_mode() if not is_train else contextlib.suppress():
        with self._agent.val_mode() if not is_train else contextlib.suppress():
            with self._agent.rollout_mode():
                while not done and self._episode_step < self._max_episode_len:
                    # perform one rollout step
                    agent_output = self.sample_action(self._obs)
                    if agent_output.action is None:
                        break
                    agent_output = self._postprocess_agent_output(agent_output)
                    if render:
                        render_obs = self._env.render()
                    obs, reward, done, info = self._env.step(agent_output.action)
                    obs = self._postprocess_obs(obs)
                    episode.append(AttrDict(
                        observation=self._obs,
                        reward=reward,
                        done=done,
                        action=agent_output.action,
                        observation_next=obs,
                        info=obj2np(info),
                    ))
                    if render:
                        episode[-1].update(AttrDict(image=render_obs))

                    # update stored observation
                    self._obs = obs
                    self._episode_step += 1

    episode[-1].done = True     # make sure episode is marked as done at final time step
    return listdict2dictlist(episode)
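# Hedged usage sketch: sample_episode() returns a dict of per-step lists (via
# listdict2dictlist), so episode statistics can be computed directly from it; the
# discount factor name below is an illustrative assumption, not part of the sampler.
def episode_discounted_return(episode, gamma=0.99):
    """Compute the discounted return from the episode's per-step rewards (sketch)."""
    ret = 0.0
    for r in reversed(episode['reward']):
        ret = r + gamma * ret
    return ret

# e.g. episode_discounted_return({'reward': [1.0, 0.0, 2.0]}) == 1.0 + 0.99 * (0.0 + 0.99 * 2.0)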
def sample_batch(self, batch_size, is_train=True, global_step=None, store_ll=True):
    """Samples the required number of high-level transitions. Number of LL transitions can be higher."""
    hl_experience_batch, ll_experience_batch = [], []
    env_steps, hl_step = 0, 0
    with self._env.val_mode() if not is_train else contextlib.suppress():
        with self._agent.val_mode() if not is_train else contextlib.suppress():
            with self._agent.rollout_mode():
                while hl_step < batch_size or len(ll_experience_batch) <= 1:
                    # perform one rollout step
                    agent_output = self.sample_action(self._obs)
                    agent_output = self._postprocess_agent_output(agent_output)
                    obs, reward, done, info = self._env.step(agent_output.action)
                    obs = self._postprocess_obs(obs)

                    if store_ll:
                        # update last step's 'observation_next' with HL action
                        if ll_experience_batch:
                            ll_experience_batch[-1].observation_next = \
                                self._agent.make_ll_obs(ll_experience_batch[-1].observation_next, agent_output.hl_action)

                        # store current step in ll_experience_batch
                        ll_experience_batch.append(AttrDict(
                            observation=self._agent.make_ll_obs(self._obs, agent_output.hl_action),
                            reward=reward,
                            done=done,
                            action=agent_output.action,
                            observation_next=obs,       # this will get updated in the next step
                        ))

                    # store HL experience batch if this was HL action or episode is done
                    if agent_output.is_hl_step or (done or self._episode_step >= self._max_episode_len - 1):
                        if self.last_hl_obs is not None and self.last_hl_action is not None:
                            hl_experience_batch.append(AttrDict(
                                observation=self.last_hl_obs,
                                reward=self.reward_since_last_hl,
                                done=done,
                                action=self.last_hl_action,
                                observation_next=obs,
                            ))
                            hl_step += 1
                            if hl_step % 1000 == 0:
                                print("Sample step {}".format(hl_step))
                        self.last_hl_obs = self._obs
                        self.last_hl_action = agent_output.hl_action
                        self.reward_since_last_hl = 0

                    # update stored observation
                    self._obs = obs
                    env_steps += 1
                    self._episode_step += 1
                    self._episode_reward += reward
                    self.reward_since_last_hl += reward

                    # reset if episode ends
                    if done or self._episode_step >= self._max_episode_len:
                        if not done:    # force done to be True for timeout
                            ll_experience_batch[-1].done = True
                            if hl_experience_batch:     # can potentially be empty
                                hl_experience_batch[-1].done = True
                        self._episode_reset(global_step)

    return AttrDict(
        hl_batch=listdict2dictlist(hl_experience_batch),
        ll_batch=listdict2dictlist(ll_experience_batch[:-1]),   # last element does not have updated obs_next!
    ), env_steps
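# Hedged sketch of the make_ll_obs() call used above: the low-level observation is
# assumed to be the raw environment observation concatenated with the currently active
# high-level action (e.g. a latent skill vector); the agent's actual implementation may
# construct it differently.
import numpy as np

def make_ll_obs(obs, hl_action):
    """Concatenate raw observation and HL action into the LL policy input (sketch)."""
    return np.concatenate((obs, hl_action), axis=-1)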