Example #1
    def forward(self, inputs, length, initial_inputs=None, static_inputs=None):
        """
        
        :param inputs: These are sliced by time. Time is the second dimension
        :param length: Rollout length
        :param initial_inputs: These are not sliced and are overridden by cell output
        :param static_inputs: These are not sliced and can't be overridden by cell output
        :return:
        """
        # NOTE: unrolling the cell directly will crash because the hidden state is never reset.
        # Use this function or CustomLSTMCell.unroll instead.
        initial_inputs, static_inputs = self.assert_begin(inputs, initial_inputs, static_inputs)

        step_inputs = initial_inputs.copy()
        step_inputs.update(static_inputs)
        lstm_outputs = []
        for t in range(length):
            step_inputs.update(map_dict(lambda x: x[:, t], inputs))  # slice out time step t
            output = self.cell(**step_inputs)
            
            self.assert_post(output, inputs, initial_inputs, static_inputs)
            # TODO Test what signature does with *args
            autoregressive_output = subdict(output, output.keys() & signature(self.cell.forward).parameters)
            step_inputs.update(autoregressive_output)
            lstm_outputs.append(output)
        
        # TODO recursively stack outputs
        lstm_outputs = listdict2dictlist(lstm_outputs)
        lstm_outputs = map_dict(lambda x: stack(x, dim=1), lstm_outputs)
            
        self.cell.reset()
        return lstm_outputs
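
The loop above leans on a couple of small dict utilities: map_dict applies a function to every value of a dict, and listdict2dictlist transposes a list of per-step dicts into a dict of lists so the results can be stacked along the time axis. The sketch below shows the behaviour these helpers are assumed to have based on their usage here; the actual spirl implementations may differ in detail.

import torch

def map_dict(fn, d):
    # Apply fn to every value, keeping the keys.
    return {k: fn(v) for k, v in d.items()}

def listdict2dictlist(list_of_dicts):
    # Transpose a list of dicts into a dict of lists, key by key.
    keys = list_of_dicts[0].keys()
    return {k: [d[k] for d in list_of_dicts] for k in keys}

# After the rollout, per-step outputs are stacked along the time dimension:
steps = [{'hidden': torch.zeros(4, 8)} for _ in range(3)]
stacked = map_dict(lambda x: torch.stack(x, dim=1), listdict2dictlist(steps))
print(stacked['hidden'].shape)  # torch.Size([4, 3, 8])
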
Example #2
File: sampler.py Project: xiaofei-w/spirl
    def sample_batch(self, batch_size, is_train=True, global_step=None):
        """Samples an experience batch of the required size."""
        experience_batch = []
        step = 0
        with self._env.val_mode() if not is_train else contextlib.suppress():
            with self._agent.val_mode() if not is_train else contextlib.suppress():
                with self._agent.rollout_mode():
                    while step < batch_size:
                        # perform one rollout step
                        agent_output = self.sample_action(self._obs)
                        if agent_output.action is None:
                            self._episode_reset(global_step)
                            continue
                        agent_output = self._postprocess_agent_output(agent_output)
                        obs, reward, done, info = self._env.step(agent_output.action)
                        obs = self._postprocess_obs(obs)
                        experience_batch.append(AttrDict(
                            observation=self._obs,
                            reward=reward,
                            done=done,
                            action=agent_output.action,
                            observation_next=obs,
                        ))

                        # update stored observation
                        self._obs = obs
                        step += 1; self._episode_step += 1; self._episode_reward += reward

                        # reset if episode ends
                        if done or self._episode_step >= self._max_episode_len:
                            if not done:    # force done to be True for timeout
                                experience_batch[-1].done = True
                            self._episode_reset(global_step)

        return listdict2dictlist(experience_batch), step
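
Both samplers wrap the rollout loop in `with self._env.val_mode() if not is_train else contextlib.suppress():`. Since contextlib.suppress() with no arguments is a no-op context manager, the environment and agent only switch into validation mode when is_train is False. A minimal, self-contained illustration of that idiom (the ToyEnv class is purely illustrative, not the spirl API):

import contextlib

class ToyEnv:
    @contextlib.contextmanager
    def val_mode(self):
        # Switch to validation mode for the duration of the block.
        print("entering val mode")
        try:
            yield
        finally:
            print("leaving val mode")

def collect(env, is_train):
    # suppress() without arguments is a do-nothing context manager.
    with env.val_mode() if not is_train else contextlib.suppress():
        print("sampling, is_train =", is_train)

collect(ToyEnv(), is_train=True)    # no mode switch
collect(ToyEnv(), is_train=False)   # rollout wrapped in val_mode()
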
Example #3
File: sampler.py Project: xiaofei-w/spirl
    def sample_episode(self, is_train, render=False):
        """Samples one episode from the environment."""
        self.init(is_train)
        episode, done = [], False
        with self._env.val_mode() if not is_train else contextlib.suppress():
            with self._agent.val_mode() if not is_train else contextlib.suppress():
                with self._agent.rollout_mode():
                    while not done and self._episode_step < self._max_episode_len:
                        # perform one rollout step
                        agent_output = self.sample_action(self._obs)
                        if agent_output.action is None:
                            break
                        agent_output = self._postprocess_agent_output(agent_output)
                        if render:
                            render_obs = self._env.render()
                        obs, reward, done, info = self._env.step(agent_output.action)
                        obs = self._postprocess_obs(obs)
                        episode.append(AttrDict(
                            observation=self._obs,
                            reward=reward,
                            done=done,
                            action=agent_output.action,
                            observation_next=obs,
                            info=obj2np(info),
                        ))
                        if render:
                            episode[-1].update(AttrDict(image=render_obs))

                        # update stored observation
                        self._obs = obs
                        self._episode_step += 1
        episode[-1].done = True     # make sure episode is marked as done at final time step

        return listdict2dictlist(episode)
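
Each transition is stored as an AttrDict, which the snippets treat as a dict whose entries are also accessible as attributes (e.g. episode[-1].done = True). Below is a minimal sketch of such a class, assuming attribute-style access is all the behaviour needed here; the real spirl AttrDict may offer more.

class AttrDict(dict):
    # Dict with attribute-style read/write access, as used in the snippets above.
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e

    def __setattr__(self, key, value):
        self[key] = value

step = AttrDict(observation=[0.0], reward=1.0, done=False, action=[1])
step.done = True                    # attribute writes update the underlying dict
print(step['done'], step.reward)    # True 1.0
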
Example #4
File: sampler.py Project: xiaofei-w/spirl
    def sample_batch(self, batch_size, is_train=True, global_step=None, store_ll=True):
        """Samples the required number of high-level transitions. Number of LL transitions can be higher."""
        hl_experience_batch, ll_experience_batch = [], []

        env_steps, hl_step = 0, 0
        with self._env.val_mode() if not is_train else contextlib.suppress():
            with self._agent.val_mode() if not is_train else contextlib.suppress():
                with self._agent.rollout_mode():
                    while hl_step < batch_size or len(ll_experience_batch) <= 1:
                        # perform one rollout step
                        agent_output = self.sample_action(self._obs)
                        agent_output = self._postprocess_agent_output(agent_output)
                        obs, reward, done, info = self._env.step(agent_output.action)
                        obs = self._postprocess_obs(obs)

                        # update last step's 'observation_next' with HL action
                        if store_ll:
                            if ll_experience_batch:
                                ll_experience_batch[-1].observation_next = \
                                    self._agent.make_ll_obs(ll_experience_batch[-1].observation_next, agent_output.hl_action)

                            # store current step in ll_experience_batch
                            ll_experience_batch.append(AttrDict(
                                observation=self._agent.make_ll_obs(self._obs, agent_output.hl_action),
                                reward=reward,
                                done=done,
                                action=agent_output.action,
                                observation_next=obs,       # this will get updated in the next step
                            ))

                        # store HL experience batch if this was HL action or episode is done
                        if agent_output.is_hl_step or (done or self._episode_step >= self._max_episode_len-1):
                            if self.last_hl_obs is not None and self.last_hl_action is not None:
                                hl_experience_batch.append(AttrDict(
                                    observation=self.last_hl_obs,
                                    reward=self.reward_since_last_hl,
                                    done=done,
                                    action=self.last_hl_action,
                                    observation_next=obs,
                                ))
                                hl_step += 1
                                if hl_step % 1000 == 0:
                                    print("Sample step {}".format(hl_step))
                            self.last_hl_obs = self._obs
                            self.last_hl_action = agent_output.hl_action
                            self.reward_since_last_hl = 0

                        # update stored observation
                        self._obs = obs
                        env_steps += 1; self._episode_step += 1; self._episode_reward += reward
                        self.reward_since_last_hl += reward

                        # reset if episode ends
                        if done or self._episode_step >= self._max_episode_len:
                            if not done:    # force done to be True for timeout
                                ll_experience_batch[-1].done = True
                                if hl_experience_batch:   # can potentially be empty 
                                    hl_experience_batch[-1].done = True
                            self._episode_reset(global_step)


        return AttrDict(
            hl_batch=listdict2dictlist(hl_experience_batch),
            ll_batch=listdict2dictlist(ll_experience_batch[:-1]),   # last element does not have updated obs_next!
        ), env_steps
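
The low-level transitions store observations produced by self._agent.make_ll_obs, which combines the raw environment observation with the current high-level action so the low-level policy is conditioned on it. A hedged sketch of that augmentation, assuming simple concatenation (the real method may differ):

import numpy as np

def make_ll_obs(obs, hl_action):
    # Assumed behaviour: condition the low-level observation on the HL action
    # by concatenating the two along the feature axis.
    return np.concatenate([obs, hl_action], axis=-1)

obs = np.zeros(4)
hl_action = np.ones(2)
print(make_ll_obs(obs, hl_action).shape)  # (6,)
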