Example #1
def collect_rollouts_per_task(task_idx, agent, policy_storage, env,
                              num_rollouts):
    for rollout in range(num_rollouts):
        obs = ptu.from_numpy(env.reset(task_idx))
        obs = obs.reshape(-1, obs.shape[-1])
        done_rollout = False

        while not done_rollout:
            action, _, _, _ = agent.act(obs=obs)  # SAC
            # observe reward and next obs
            next_obs, reward, done, info = utl.env_step(
                env, action.squeeze(dim=0))
            done_rollout = False if ptu.get_numpy(done[0][0]) == 0. else True

            # add data to policy buffer - (s+, a, r, s'+, term)
            term = env.unwrapped.is_goal_state() if "is_goal_state" in dir(
                env.unwrapped) else False
            rew_to_buffer = ptu.get_numpy(reward.squeeze(dim=0))
            policy_storage.add_sample(
                task=0,  # note: stored under task index 0 rather than task_idx
                observation=ptu.get_numpy(obs.squeeze(dim=0)),
                action=ptu.get_numpy(action.squeeze(dim=0)),
                reward=rew_to_buffer,
                terminal=np.array([term], dtype=float),
                next_observation=ptu.get_numpy(next_obs.squeeze(dim=0)))

            # set: obs <- next_obs
            obs = next_obs.clone()
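
All of the snippets in this collection lean on the same small set of helpers: NumPy, PyTorch, a tensor-utility module aliased as ptu (from_numpy, get_numpy, FloatTensor) and an environment-helper module aliased as utl (env_step, select_action, update_encoding). A minimal import sketch is given below; the exact module paths are assumptions and may differ between forks of this codebase.

# Assumed imports for the snippets in this collection; the module paths
# (torchkit.pytorch_utils, utils.helpers) are guesses and may differ per project.
import numpy as np
import torch

import torchkit.pytorch_utils as ptu   # from_numpy / get_numpy / FloatTensor wrappers
import utils.helpers as utl            # env_step, select_action, update_encoding, ...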
Example #2
    def evaluate(self, tasks):
        num_episodes = self.args.max_rollouts_per_task
        num_steps_per_episode = self.env.unwrapped._max_episode_steps

        returns_per_episode = np.zeros((len(tasks), num_episodes))
        success_rate = np.zeros(len(tasks))

        if self.args.policy == 'dqn':
            values = np.zeros((len(tasks), self.args.max_trajectory_len))
        else:
            obs_size = self.env.unwrapped.observation_space.shape[0]
            observations = np.zeros((len(tasks), self.args.max_trajectory_len + 1, obs_size))
            log_probs = np.zeros((len(tasks), self.args.max_trajectory_len))

        for task_idx, task in enumerate(tasks):

            obs = ptu.from_numpy(self.env.reset(task))
            obs = obs.reshape(-1, obs.shape[-1])
            step = 0

            if self.args.policy == 'sac':
                observations[task_idx, step, :] = ptu.get_numpy(obs[0, :obs_size])

            for episode_idx in range(num_episodes):
                running_reward = 0.
                for step_idx in range(num_steps_per_episode):
                    # select action (greedy for DQN; deterministic for SAC if eval_deterministic is set)
                    if self.args.policy == 'dqn':
                        action, value = self.agent.act(obs=obs, deterministic=True)
                    else:
                        action, _, _, log_prob = self.agent.act(obs=obs,
                                                                deterministic=self.args.eval_deterministic,
                                                                return_log_prob=True)
                    # observe reward and next obs
                    next_obs, reward, done, info = utl.env_step(self.env, action.squeeze(dim=0))
                    running_reward += reward.item()
                    if self.args.policy == 'dqn':
                        values[task_idx, step] = value.item()
                    else:
                        observations[task_idx, step + 1, :] = ptu.get_numpy(next_obs[0, :obs_size])
                        log_probs[task_idx, step] = ptu.get_numpy(log_prob[0])

                    if "is_goal_state" in dir(self.env.unwrapped) and self.env.unwrapped.is_goal_state():
                        success_rate[task_idx] = 1.
                    # set: obs <- next_obs
                    obs = next_obs.clone()
                    step += 1

                returns_per_episode[task_idx, episode_idx] = running_reward

        if self.args.policy == 'dqn':
            return returns_per_episode, success_rate, values
        else:
            return returns_per_episode, success_rate, log_probs, observations
Example #3
    def collect_rollouts(self, num_rollouts, random_actions=False):
        '''
        Collect rollouts in the current task and add them to the policy buffer.

        :param num_rollouts: number of rollouts to collect
        :param random_actions: if True, sample actions uniformly from the action space
                               instead of querying the policy
        :return: None (transitions are written to self.policy_storage)
        '''

        for rollout in range(num_rollouts):
            obs = ptu.from_numpy(self.env.reset(self.task_idx))
            obs = obs.reshape(-1, obs.shape[-1])
            done_rollout = False

            while not done_rollout:
                if random_actions:
                    if self.args.policy == 'dqn':
                        action = ptu.FloatTensor([[[self.env.action_space.sample()]]]).long()   # Sample random action
                    else:
                        action = ptu.FloatTensor([self.env.action_space.sample()])  # Sample random action
                else:
                    if self.args.policy == 'dqn':
                        action, _ = self.agent.act(obs=obs)   # DQN
                    else:
                        action, _, _, _ = self.agent.act(obs=obs)   # SAC
                # observe reward and next obs
                next_obs, reward, done, info = utl.env_step(self.env, action.squeeze(dim=0))
                done_rollout = False if ptu.get_numpy(done[0][0]) == 0. else True

                # add data to policy buffer - (s+, a, r, s'+, term)
                term = self.env.unwrapped.is_goal_state() if "is_goal_state" in dir(self.env.unwrapped) else False
                if self.args.dense_train_sparse_test:
                    rew_to_buffer = {rew_type: rew for rew_type, rew in info.items()
                                     if rew_type.startswith('reward')}
                else:
                    rew_to_buffer = ptu.get_numpy(reward.squeeze(dim=0))
                self.policy_storage.add_sample(task=self.task_idx,
                                               observation=ptu.get_numpy(obs.squeeze(dim=0)),
                                               action=ptu.get_numpy(action.squeeze(dim=0)),
                                               reward=rew_to_buffer,
                                               terminal=np.array([term], dtype=float),
                                               next_observation=ptu.get_numpy(next_obs.squeeze(dim=0)))

                # set: obs <- next_obs
                obs = next_obs.clone()

                # update statistics
                self._n_env_steps_total += 1
                if "is_goal_state" in dir(self.env.unwrapped) and self.env.unwrapped.is_goal_state():  # count successes
                    self._successes_in_buffer += 1
            self._n_rollouts_total += 1
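
A minimal sketch of how a collect_rollouts method like the one above is usually driven: a random-action warm-up to seed the replay buffer, then on-policy collection before each round of gradient updates. The attribute names (num_init_rollouts_pool, num_rollouts_per_iter, num_iters) and the update() call are assumptions for illustration, not taken from the example.

def run_training(learner):
    # warm-up: fill the buffer with uniformly sampled actions (hypothetical arg names)
    learner.collect_rollouts(num_rollouts=learner.args.num_init_rollouts_pool,
                             random_actions=True)
    for _ in range(learner.args.num_iters):
        # on-policy collection for the current task, followed by gradient steps on the buffer
        learner.collect_rollouts(num_rollouts=learner.args.num_rollouts_per_iter)
        learner.update()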
Example #4
    def load_and_render(self, load_iter):
        #save_path = os.path.join('/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahJoint-v0/varibad_73__15:05_17:14:07', 'models')
        #save_path = os.path.join('/ext/varibad_github/v2/varibad/logs/hfield', 'models')
        save_path = os.path.join(
            '/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahBlocks-v0/varibad_73__15:05_20:20:25',
            'models')
        self.policy.actor_critic = torch.load(
            os.path.join(save_path, "policy{0}.pt".format(load_iter)))
        self.vae.encoder = torch.load(
            os.path.join(save_path, "encoder{0}.pt".format(load_iter)))

        args = self.args
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        num_processes = 1
        num_episodes = 100
        num_steps = 1999

        # initialise environments
        envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=num_processes,  # 1
            gamma=args.policy_gamma,
            log_dir=args.agent_log_dir,
            device=device,
            allow_early_resets=False,
            episodes_per_task=self.args.max_rollouts_per_task,
            obs_rms=None,
            ret_rms=None,
        )

        # reset latent state to prior
        latent_sample, latent_mean, latent_logvar, hidden_state = self.vae.encoder.prior(
            num_processes)

        for episode_idx in range(num_episodes):
            (prev_obs_raw, prev_obs_normalised) = envs.reset()
            prev_obs_raw = prev_obs_raw.to(device)
            prev_obs_normalised = prev_obs_normalised.to(device)
            for step_idx in range(num_steps):

                with torch.no_grad():
                    _, action, _ = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar,
                        deterministic=True)

                # observe reward and next obs
                (next_obs_raw, next_obs_normalised), (
                    rew_raw,
                    rew_normalised), done, infos = utl.env_step(envs, action)
                # render
                envs.venv.venv.envs[0].env.env.env.env.render()

                # update the hidden state
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=self.vae.encoder,
                    next_obs=next_obs_raw,
                    action=action,
                    reward=rew_raw,
                    done=None,
                    hidden_state=hidden_state)

                prev_obs_normalised = next_obs_normalised
                prev_obs_raw = next_obs_raw

                if done[0]:
                    break
Example #5
    def train(self):
        """
        Given some stream of environments and a logger (tensorboard),
        (meta-)trains the policy.
        """

        start_time = time.time()

        # reset environments
        (prev_obs_raw, prev_obs_normalised) = self.envs.reset()
        prev_obs_raw = prev_obs_raw.to(device)
        prev_obs_normalised = prev_obs_normalised.to(device)

        # insert initial observation / embeddings to rollout storage
        self.policy_storage.prev_obs_raw[0].copy_(prev_obs_raw)
        self.policy_storage.prev_obs_normalised[0].copy_(prev_obs_normalised)
        self.policy_storage.to(device)

        vae_is_pretrained = False
        for self.iter_idx in range(self.args.num_updates):

            # First, re-compute the hidden states given the current rollouts (since the VAE might've changed)
            # compute latent embedding (will return prior if current trajectory is empty)
            with torch.no_grad():
                latent_sample, latent_mean, latent_logvar, hidden_state = self.encode_running_trajectory(
                )

            # check if we flushed the policy storage
            assert len(self.policy_storage.latent_mean) == 0

            # add this initial hidden state to the policy storage
            self.policy_storage.hidden_states[0].copy_(hidden_state)
            self.policy_storage.latent_samples.append(latent_sample.clone())
            self.policy_storage.latent_mean.append(latent_mean.clone())
            self.policy_storage.latent_logvar.append(latent_logvar.clone())

            # rollout policies for a few steps
            for step in range(self.args.policy_num_steps):

                # sample actions from policy
                with torch.no_grad():
                    value, action, action_log_prob = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        deterministic=False,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar,
                    )
                # observe reward and next obs
                (next_obs_raw, next_obs_normalised), (
                    rew_raw, rew_normalised), done, infos = utl.env_step(
                        self.envs, action)
                tasks = torch.FloatTensor([info['task']
                                           for info in infos]).to(device)
                done = torch.from_numpy(np.array(
                    done, dtype=int)).to(device).float().view((-1, 1))

                # create mask for episode ends
                masks_done = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done]).to(device)
                # bad_mask is true if episode ended because time limit was reached
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos]).to(device)

                # compute next embedding (for next loop and/or value prediction bootstrap)
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=self.vae.encoder,
                    next_obs=next_obs_raw,
                    action=action,
                    reward=rew_raw,
                    done=done,
                    hidden_state=hidden_state)

                # before resetting, update the embedding and add to vae buffer
                # (last state might include useful task info)
                if not (self.args.disable_decoder
                        and self.args.disable_stochasticity_in_latent):
                    self.vae.rollout_storage.insert(prev_obs_raw.clone(),
                                                    action.detach().clone(),
                                                    next_obs_raw.clone(),
                                                    rew_raw.clone(),
                                                    done.clone(),
                                                    tasks.clone())

                # add the obs before reset to the policy storage
                # (only used to recompute embeddings if the RL loss is backpropagated through the encoder)
                self.policy_storage.next_obs_raw[step] = next_obs_raw.clone()
                self.policy_storage.next_obs_normalised[
                    step] = next_obs_normalised.clone()

                # reset environments that are done
                done_indices = np.argwhere(
                    done.cpu().detach().flatten()).flatten()
                if len(done_indices) == self.args.num_processes:
                    [next_obs_raw, next_obs_normalised] = self.envs.reset()
                    if not self.args.sample_embeddings:
                        latent_sample = latent_sample  # no-op placeholder: keep the current latent
                else:
                    for i in done_indices:
                        [next_obs_raw[i],
                         next_obs_normalised[i]] = self.envs.reset(index=i)
                        if not self.args.sample_embeddings:
                            latent_sample[i] = latent_sample[i]  # no-op placeholder: keep the current latent

                # add experience to policy buffer
                self.policy_storage.insert(
                    obs_raw=next_obs_raw,
                    obs_normalised=next_obs_normalised,
                    actions=action,
                    action_log_probs=action_log_prob,
                    rewards_raw=rew_raw,
                    rewards_normalised=rew_normalised,
                    value_preds=value,
                    masks=masks_done,
                    bad_masks=bad_masks,
                    done=done,
                    hidden_states=hidden_state.squeeze(0).detach(),
                    latent_sample=latent_sample.detach(),
                    latent_mean=latent_mean.detach(),
                    latent_logvar=latent_logvar.detach(),
                )

                prev_obs_normalised = next_obs_normalised
                prev_obs_raw = next_obs_raw

                self.frames += self.args.num_processes

            # --- UPDATE ---

            if self.args.precollect_len <= self.frames:
                # check if we are pre-training the VAE
                if self.args.pretrain_len > 0 and not vae_is_pretrained:
                    for _ in range(self.args.pretrain_len):
                        self.vae.compute_vae_loss(update=True)
                    vae_is_pretrained = True

                # otherwise do the normal update (policy + vae)
                else:

                    train_stats = self.update(
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar)

                    # log
                    run_stats = [action, action_log_prob, value]
                    if train_stats is not None:
                        self.log(run_stats, train_stats, start_time)

            # clean up after update
            self.policy_storage.after_update()
Example #6
    def train(self):
        """ Main training loop """
        start_time = time.time()

        # reset environments
        state, belief, task = utl.reset_env(self.envs, self.args)

        # insert initial observation / embeddings to rollout storage
        self.policy_storage.prev_state[0].copy_(state)

        # log once before training
        with torch.no_grad():
            self.log(None, None, start_time)

        for self.iter_idx in range(self.num_updates):

            # rollout policies for a few steps
            for step in range(self.args.policy_num_steps):

                # sample actions from policy
                with torch.no_grad():
                    value, action, action_log_prob = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        state=state,
                        belief=belief,
                        task=task,
                        deterministic=False)

                # observe reward and next obs
                [state, belief,
                 task], (rew_raw, rew_normalised), done, infos = utl.env_step(
                     self.envs, action, self.args)

                # create mask for episode ends
                masks_done = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done]).to(device)
                # bad_mask is true if episode ended because time limit was reached
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos]).to(device)

                # reset environments that are done
                done_indices = np.argwhere(done.flatten()).flatten()
                if len(done_indices) > 0:
                    state, belief, task = utl.reset_env(self.envs,
                                                        self.args,
                                                        indices=done_indices,
                                                        state=state)

                # add experience to policy buffer
                self.policy_storage.insert(
                    state=state,
                    belief=belief,
                    task=task,
                    actions=action,
                    action_log_probs=action_log_prob,
                    rewards_raw=rew_raw,
                    rewards_normalised=rew_normalised,
                    value_preds=value,
                    masks=masks_done,
                    bad_masks=bad_masks,
                    done=torch.from_numpy(np.array(done,
                                                   dtype=float)).unsqueeze(1),
                )

                self.frames += self.args.num_processes

            # --- UPDATE ---

            train_stats = self.update(state=state, belief=belief, task=task)

            # log
            run_stats = [action, action_log_prob, value]
            if train_stats is not None:
                with torch.no_grad():
                    self.log(run_stats, train_stats, start_time)

            # clean up after update
            self.policy_storage.after_update()
Example #7
    def visualise_behaviour(
        self,
        env,
        args,
        policy,
        iter_idx,
        encoder=None,
        image_folder=None,
        return_pos=False,
        **kwargs,
    ):

        num_episodes = args.max_rollouts_per_task

        # --- initialise things we want to keep track of ---

        episode_prev_obs = [[] for _ in range(num_episodes)]
        episode_next_obs = [[] for _ in range(num_episodes)]
        episode_actions = [[] for _ in range(num_episodes)]
        episode_rewards = [[] for _ in range(num_episodes)]

        episode_returns = []
        episode_lengths = []

        if encoder is not None:
            episode_latent_samples = [[] for _ in range(num_episodes)]
            episode_latent_means = [[] for _ in range(num_episodes)]
            episode_latent_logvars = [[] for _ in range(num_episodes)]
        else:
            episode_latent_samples = episode_latent_means = episode_latent_logvars = None

        # --- roll out policy ---

        # (re)set environment
        env.reset_task()
        state, belief, task = utl.reset_env(env, args)
        start_obs_raw = state.clone()
        task = task.view(-1) if task is not None else None

        # initialise the hidden state (only used if the encoder/policy is recurrent)
        if hasattr(args, 'hidden_size'):
            hidden_state = torch.zeros((1, args.hidden_size)).to(device)
        else:
            hidden_state = None

        # keep track of what task we're in and the position of the agent
        pos = [[] for _ in range(args.max_rollouts_per_task)]
        start_pos = state

        for episode_idx in range(num_episodes):

            curr_rollout_rew = []
            pos[episode_idx].append(start_pos[0])

            if episode_idx == 0:
                if encoder is not None:
                    # reset to prior
                    curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(
                        1)
                    curr_latent_sample = curr_latent_sample[0].to(device)
                    curr_latent_mean = curr_latent_mean[0].to(device)
                    curr_latent_logvar = curr_latent_logvar[0].to(device)
                else:
                    curr_latent_sample = curr_latent_mean = curr_latent_logvar = None

            if encoder is not None:
                episode_latent_samples[episode_idx].append(
                    curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(
                    curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(
                    curr_latent_logvar[0].clone())

            for step_idx in range(1, env._max_episode_steps + 1):

                if step_idx == 1:
                    episode_prev_obs[episode_idx].append(start_obs_raw.clone())
                else:
                    episode_prev_obs[episode_idx].append(state.clone())
                # act
                latent = utl.get_latent_for_policy(
                    args,
                    latent_sample=curr_latent_sample,
                    latent_mean=curr_latent_mean,
                    latent_logvar=curr_latent_logvar)
                _, action = policy.act(state=state.view(-1),
                                       latent=latent,
                                       belief=belief,
                                       task=task,
                                       deterministic=True)

                (state, belief,
                 task), (rew, rew_normalised), done, info = utl.env_step(
                     env, action, args)
                state = state.float().reshape((1, -1)).to(device)
                task = task.view(-1) if task is not None else None

                # keep track of position
                pos[episode_idx].append(state[0])

                if encoder is not None:
                    # update task embedding
                    curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                        action.reshape(1, -1).float().to(device),
                        state,
                        rew.reshape(1, -1).float().to(device),
                        hidden_state,
                        return_prior=False)

                    episode_latent_samples[episode_idx].append(
                        curr_latent_sample[0].clone())
                    episode_latent_means[episode_idx].append(
                        curr_latent_mean[0].clone())
                    episode_latent_logvars[episode_idx].append(
                        curr_latent_logvar[0].clone())

                episode_next_obs[episode_idx].append(state.clone())
                episode_rewards[episode_idx].append(rew.clone())
                episode_actions[episode_idx].append(action.clone())
                curr_rollout_rew.append(rew.clone())  # accumulate reward for the episode return

                if info[0]['done_mdp'] and not done:
                    start_obs_raw = info[0]['start_state']
                    start_obs_raw = torch.from_numpy(
                        start_obs_raw).float().reshape((1, -1)).to(device)
                    start_pos = start_obs_raw
                    break

            episode_returns.append(sum(curr_rollout_rew))
            episode_lengths.append(step_idx)

        # clean up
        if encoder is not None:
            episode_latent_means = [
                torch.stack(e) for e in episode_latent_means
            ]
            episode_latent_logvars = [
                torch.stack(e) for e in episode_latent_logvars
            ]

        episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
        episode_next_obs = [torch.cat(e) for e in episode_next_obs]
        episode_actions = [torch.stack(e) for e in episode_actions]
        episode_rewards = [torch.cat(e) for e in episode_rewards]

        figsize = (5.5, 4)
        figure, axis = plt.subplots(1, 1, figsize=figsize)
        xlim = (-1.3, 1.3)
        if self.goal_sampler == semi_circle_goal_sampler:
            ylim = (-0.3, 1.3)
        else:
            ylim = (-1.3, 1.3)
        color_map = mpl.colors.ListedColormap(
            sns.color_palette("husl", num_episodes))

        observations = torch.stack(
            [episode_prev_obs[i] for i in range(num_episodes)]).cpu().numpy()
        curr_task = env.get_task()

        # plot goal
        axis.scatter(*curr_task, marker='x', color='k', s=50)
        # radius where we get reward
        if hasattr(self, 'goal_radius'):
            circle1 = plt.Circle(curr_task,
                                 self.goal_radius,
                                 color='c',
                                 alpha=0.2,
                                 edgecolor='none')
            plt.gca().add_artist(circle1)

        for i in range(num_episodes):
            color = color_map(i)
            path = observations[i]

            # plot (semi-)circle
            r = 1.0
            if self.goal_sampler == semi_circle_goal_sampler:
                angle = np.linspace(0, np.pi, 100)
            else:
                angle = np.linspace(0, 2 * np.pi, 100)
            goal_range = r * np.array((np.cos(angle), np.sin(angle)))
            plt.plot(goal_range[0], goal_range[1], 'k--', alpha=0.1)

            # plot trajectory
            axis.plot(path[:, 0], path[:, 1], '-', color=color, label=i)
            axis.scatter(*path[0, :2], marker='.', color=color, s=50)

        plt.xlim(xlim)
        plt.ylim(ylim)
        plt.xticks([])
        plt.yticks([])
        plt.legend()
        plt.tight_layout()
        if image_folder is not None:
            plt.savefig('{}/{}_behaviour.png'.format(image_folder, iter_idx),
                        dpi=300,
                        bbox_inches='tight')
            plt.close()
        else:
            plt.show()

        plt_rew = [
            episode_rewards[i][:episode_lengths[i]]
            for i in range(len(episode_rewards))
        ]
        plt.plot(torch.cat(plt_rew).view(-1).cpu().numpy())
        plt.xlabel('env step')
        plt.ylabel('reward per step')
        plt.tight_layout()
        if image_folder is not None:
            plt.savefig('{}/{}_rewards.png'.format(image_folder, iter_idx),
                        dpi=300,
                        bbox_inches='tight')
            plt.close()
        else:
            plt.show()

        if not return_pos:
            return episode_latent_means, episode_latent_logvars, \
                   episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
                   episode_returns
        else:
            return episode_latent_means, episode_latent_logvars, \
                   episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
                   episode_returns, pos
Example #8
def rollout_policy(env, learner):
    is_vae_exist = "vae" in dir(learner)

    observations = []
    actions = []
    rewards = []
    values = []
    if is_vae_exist:
        latent_samples = []
        latent_means = []
        latent_logvars = []

    obs = ptu.from_numpy(env.reset())
    obs = obs.reshape(-1, obs.shape[-1])
    observations.append(obs)
    done_rollout = False
    if is_vae_exist:
        # get prior parameters
        with torch.no_grad():
            task_sample, task_mean, task_logvar, hidden_state = learner.vae.encoder.prior(
                batch_size=1)
        # store
        latent_samples.append(ptu.get_numpy(task_sample[0, 0]))
        latent_means.append(ptu.get_numpy(task_mean[0, 0]))
        latent_logvars.append(ptu.get_numpy(task_logvar[0, 0]))

    while not done_rollout:
        if is_vae_exist:
            # add distribution parameters to observation - policy is conditioned on posterior
            augmented_obs = learner.get_augmented_obs(obs=obs,
                                                      task_mu=task_mean,
                                                      task_std=task_logvar)
            with torch.no_grad():
                action, value = learner.agent.act(obs=augmented_obs,
                                                  deterministic=True)
        else:
            action, _, _, _ = learner.agent.act(obs=obs)
            value = None  # SAC-style agent returns no value estimate here

        # observe reward and next obs
        next_obs, reward, done, info = utl.env_step(env, action.squeeze(dim=0))
        # store
        observations.append(next_obs)
        actions.append(action)
        values.append(value)
        rewards.append(reward.item())
        done_rollout = False if ptu.get_numpy(done[0][0]) == 0. else True

        if is_vae_exist:
            # update encoding
            task_sample, task_mean, task_logvar, hidden_state = learner.vae.encoder(
                action,
                next_obs,
                reward.reshape((1, 1)),
                hidden_state,
                return_prior=False)

            # values.append(value.item())
            latent_samples.append(ptu.get_numpy(task_sample[0]))
            latent_means.append(ptu.get_numpy(task_mean[0]))
            latent_logvars.append(ptu.get_numpy(task_logvar[0]))
        # set: obs <- next_obs
        obs = next_obs.clone()
    if is_vae_exist:
        return observations, actions, rewards, values, \
               latent_samples, latent_means, latent_logvars
    else:
        return observations, actions, rewards, values
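
A short usage sketch for rollout_policy, assuming env and learner are already constructed elsewhere; the slicing and summary statistics below are illustrative only (the function returns four or seven values depending on whether the learner has a VAE).

# Sketch: run one evaluation rollout and summarise it (env and learner assumed to exist).
observations, actions, rewards, values = rollout_policy(env, learner)[:4]
obs_array = torch.cat(observations, dim=0).cpu().numpy()   # roughly (T + 1, obs_dim)
act_array = torch.cat(actions, dim=0).cpu().numpy()        # roughly (T, act_dim)
print('return: {:.2f}, episode length: {}'.format(sum(rewards), len(rewards)))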
Example #9
    def visualise_behaviour(
        env,
        args,
        policy,
        iter_idx,
        encoder=None,
        image_folder=None,
        return_pos=False,
        **kwargs,
    ):

        num_episodes = args.max_rollouts_per_task
        unwrapped_env = env.venv.unwrapped.envs[0].unwrapped

        # --- initialise things we want to keep track of ---

        episode_prev_obs = [[] for _ in range(num_episodes)]
        episode_next_obs = [[] for _ in range(num_episodes)]
        episode_actions = [[] for _ in range(num_episodes)]
        episode_rewards = [[] for _ in range(num_episodes)]

        episode_returns = []
        episode_lengths = []

        if encoder is not None:
            episode_latent_samples = [[] for _ in range(num_episodes)]
            episode_latent_means = [[] for _ in range(num_episodes)]
            episode_latent_logvars = [[] for _ in range(num_episodes)]
        else:
            episode_latent_samples = episode_latent_means = episode_latent_logvars = None

        # --- roll out policy ---

        # (re)set environment
        env.reset_task()
        state, belief, task = utl.reset_env(env, args)
        start_obs_raw = state.clone()
        task = task.view(-1) if task is not None else None

        # initialise the hidden state (only used if the encoder/policy is recurrent)
        if hasattr(args, 'hidden_size'):
            hidden_state = torch.zeros((1, args.hidden_size)).to(device)
        else:
            hidden_state = None

        # keep track of what task we're in and the position of the agent's torso
        pos = [[] for _ in range(args.max_rollouts_per_task)]
        start_pos = unwrapped_env.get_body_com("torso")[:2].copy()

        for episode_idx in range(num_episodes):

            curr_rollout_rew = []
            pos[episode_idx].append(start_pos)

            if episode_idx == 0:
                if encoder is not None:
                    # reset to prior
                    curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(
                        1)
                    curr_latent_sample = curr_latent_sample[0].to(device)
                    curr_latent_mean = curr_latent_mean[0].to(device)
                    curr_latent_logvar = curr_latent_logvar[0].to(device)
                else:
                    curr_latent_sample = curr_latent_mean = curr_latent_logvar = None

            if encoder is not None:
                episode_latent_samples[episode_idx].append(
                    curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(
                    curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(
                    curr_latent_logvar[0].clone())

            for step_idx in range(1, env._max_episode_steps + 1):

                if step_idx == 1:
                    episode_prev_obs[episode_idx].append(start_obs_raw.clone())
                else:
                    episode_prev_obs[episode_idx].append(state.clone())
                # act
                latent = utl.get_latent_for_policy(
                    args,
                    latent_sample=curr_latent_sample,
                    latent_mean=curr_latent_mean,
                    latent_logvar=curr_latent_logvar)
                _, action, _ = policy.act(state=state.view(-1),
                                          latent=latent,
                                          belief=belief,
                                          task=task,
                                          deterministic=True)

                (state, belief,
                 task), (rew, rew_normalised), done, info = utl.env_step(
                     env, action, args)
                state = state.float().reshape((1, -1)).to(device)
                task = task.view(-1) if task is not None else None

                # keep track of position
                pos[episode_idx].append(
                    unwrapped_env.get_body_com("torso")[:2].copy())

                if encoder is not None:
                    # update task embedding
                    curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                        action.reshape(1, -1).float().to(device),
                        state,
                        rew.reshape(1, -1).float().to(device),
                        hidden_state,
                        return_prior=False)

                    episode_latent_samples[episode_idx].append(
                        curr_latent_sample[0].clone())
                    episode_latent_means[episode_idx].append(
                        curr_latent_mean[0].clone())
                    episode_latent_logvars[episode_idx].append(
                        curr_latent_logvar[0].clone())

                episode_next_obs[episode_idx].append(state.clone())
                episode_rewards[episode_idx].append(rew.clone())
                episode_actions[episode_idx].append(action.clone())
                curr_rollout_rew.append(rew.clone())  # accumulate reward for the episode return

                if info[0]['done_mdp'] and not done:
                    start_obs_raw = info[0]['start_state']
                    start_obs_raw = torch.from_numpy(
                        start_obs_raw).float().reshape((1, -1)).to(device)
                    start_pos = unwrapped_env.get_body_com("torso")[:2].copy()
                    break

            episode_returns.append(sum(curr_rollout_rew))
            episode_lengths.append(step_idx)

        # clean up
        if encoder is not None:
            episode_latent_means = [
                torch.stack(e) for e in episode_latent_means
            ]
            episode_latent_logvars = [
                torch.stack(e) for e in episode_latent_logvars
            ]

        episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
        episode_next_obs = [torch.cat(e) for e in episode_next_obs]
        episode_actions = [torch.stack(e) for e in episode_actions]
        episode_rewards = [torch.cat(e) for e in episode_rewards]

        # plot the movement of the ant
        # print(pos)
        plt.figure(figsize=(5, 4 * num_episodes))
        min_dim = -3.5
        max_dim = 3.5
        span = max_dim - min_dim

        for i in range(num_episodes):
            plt.subplot(num_episodes, 1, i + 1)

            x = list(map(lambda p: p[0], pos[i]))
            y = list(map(lambda p: p[1], pos[i]))
            plt.plot(x[0], y[0], 'bo')

            plt.scatter(x, y, 1, 'g')

            curr_task = env.get_task()
            plt.title('task: {}'.format(curr_task), fontsize=15)
            if 'Goal' in args.env_name:
                plt.plot(curr_task[0], curr_task[1], 'rx')

            plt.ylabel('y-position (ep {})'.format(i), fontsize=15)

            if i == num_episodes - 1:
                plt.xlabel('x-position', fontsize=15)
                plt.ylabel('y-position (ep {})'.format(i), fontsize=15)
            plt.xlim(min_dim - 0.05 * span, max_dim + 0.05 * span)
            plt.ylim(min_dim - 0.05 * span, max_dim + 0.05 * span)

        plt.tight_layout()
        if image_folder is not None:
            plt.savefig('{}/{}_behaviour'.format(image_folder, iter_idx))
            plt.close()
        else:
            plt.show()

        if not return_pos:
            return episode_latent_means, episode_latent_logvars, \
                   episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
                   episode_returns
        else:
            return episode_latent_means, episode_latent_logvars, \
                   episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
                   episode_returns, pos
Example #10
def evaluate(args,
             policy,
             ret_rms,
             iter_idx,
             tasks,
             encoder=None,
             num_episodes=None):
    env_name = args.env_name
    if hasattr(args, 'test_env_name'):
        env_name = args.test_env_name
    if num_episodes is None:
        num_episodes = args.max_rollouts_per_task
    num_processes = args.num_processes

    # --- set up the things we want to log ---

    # for each process, we log the returns during the first, second, ... episode
    # (such that we have a minimum of [num_episodes]; the last column is for
    #  any overflow and will be discarded at the end, because we need to wait until
    #  all processes have at least [num_episodes] many episodes)
    returns_per_episode = torch.zeros(
        (num_processes, num_episodes + 1)).to(device)

    # --- initialise environments and latents ---

    envs = make_vec_envs(
        env_name,
        seed=args.seed * 42 + iter_idx,
        num_processes=num_processes,
        gamma=args.policy_gamma,
        device=device,
        rank_offset=num_processes +
        1,  # to use diff tmp folders than main processes
        episodes_per_task=num_episodes,
        normalise_rew=args.norm_rew_for_policy,
        ret_rms=ret_rms,
        tasks=tasks,
        add_done_info=args.max_rollouts_per_task > 1,
    )
    num_steps = envs._max_episode_steps

    # reset environments
    state, belief, task = utl.reset_env(envs, args)

    # this counts how often an agent has done the same task already
    task_count = torch.zeros(num_processes).long().to(device)

    if encoder is not None:
        # reset latent state to prior
        latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(
            num_processes)
    else:
        latent_sample = latent_mean = latent_logvar = hidden_state = None

    for episode_idx in range(num_episodes):

        for step_idx in range(num_steps):

            with torch.no_grad():
                _, action = utl.select_action(args=args,
                                              policy=policy,
                                              state=state,
                                              belief=belief,
                                              task=task,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar,
                                              deterministic=True)

            # observe reward and next obs
            [state, belief,
             task], (rew_raw, rew_normalised), done, infos = utl.env_step(
                 envs, action, args)
            done_mdp = [info['done_mdp'] for info in infos]

            if encoder is not None:
                # update the hidden state
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=encoder,
                    next_obs=state,
                    action=action,
                    reward=rew_raw,
                    done=None,
                    hidden_state=hidden_state)

            # add rewards
            returns_per_episode[range(num_processes),
                                task_count] += rew_raw.view(-1)

            for i in np.argwhere(done_mdp).flatten():
                # count the episode up, but cap at the overflow column (index num_episodes)
                task_count[i] = min(task_count[i] + 1, num_episodes)
            if np.sum(done) > 0:
                done_indices = np.argwhere(done.flatten()).flatten()
                state, belief, task = utl.reset_env(envs,
                                                    args,
                                                    indices=done_indices,
                                                    state=state)

    envs.close()

    return returns_per_episode[:, :num_episodes]
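
The evaluate function above returns a (num_processes, num_episodes) tensor of per-episode returns; a hedged sketch of reducing it to the scalars usually logged:

# Sketch: average evaluation returns over parallel processes, one value per episode.
returns = evaluate(args, policy, ret_rms, iter_idx, tasks, encoder=encoder)
mean_return_per_episode = returns.mean(dim=0)          # shape: (num_episodes,)
print('eval returns per episode:', mean_return_per_episode.cpu().numpy())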
Example #11
    def collect_rollouts(self, num_rollouts, random_actions=False):
        '''
        Collect rollouts in the current task, conditioning the policy on the
        current posterior, and add them to the policy buffer.

        :param num_rollouts: number of rollouts to collect
        :param random_actions: if True, sample actions uniformly from the action space
                               instead of querying the policy
        :return: None (transitions are written to self.policy_storage)
        '''

        for rollout in range(num_rollouts):
            obs = ptu.from_numpy(self.env.reset(self.task_idx))
            obs = obs.reshape(-1, obs.shape[-1])
            done_rollout = False
            # self.policy_storage.reset_running_episode(self.task_idx)

            # if self.args.fixed_latent_params:
            #     assert 2 ** self.args.task_embedding_size >= self.args.num_tasks
            #     task_mean = ptu.FloatTensor(utl.vertices(self.args.task_embedding_size)[self.task_idx])
            #     task_logvar = -2. * ptu.ones_like(task_logvar)   # arbitrary negative enough number
            # add distribution parameters to observation - policy is conditioned on posterior
            augmented_obs = self.get_augmented_obs(obs=obs)

            while not done_rollout:
                if random_actions:
                    if self.args.policy == 'dqn':
                        action = ptu.FloatTensor([[
                            self.env.action_space.sample()
                        ]]).type(torch.long)  # Sample random action
                    else:
                        action = ptu.FloatTensor(
                            [self.env.action_space.sample()])
                else:
                    if self.args.policy == 'dqn':
                        action, _ = self.agent.act(obs=augmented_obs)  # DQN
                    else:
                        action, _, _, _ = self.agent.act(
                            obs=augmented_obs)  # SAC
                # observe reward and next obs
                next_obs, reward, done, info = utl.env_step(
                    self.env, action.squeeze(dim=0))
                done_rollout = False if ptu.get_numpy(
                    done[0][0]) == 0. else True

                # get augmented next obs
                augmented_next_obs = self.get_augmented_obs(obs=next_obs)

                # add data to policy buffer - (s+, a, r, s'+, term)
                term = self.env.unwrapped.is_goal_state(
                ) if "is_goal_state" in dir(self.env.unwrapped) else False
                self.policy_storage.add_sample(
                    task=self.task_idx,
                    observation=ptu.get_numpy(augmented_obs.squeeze(dim=0)),
                    action=ptu.get_numpy(action.squeeze(dim=0)),
                    reward=ptu.get_numpy(reward.squeeze(dim=0)),
                    terminal=np.array([term], dtype=float),
                    next_observation=ptu.get_numpy(
                        augmented_next_obs.squeeze(dim=0)))
                if not random_actions:
                    self.current_experience_storage.add_sample(
                        task=self.task_idx,
                        observation=ptu.get_numpy(
                            augmented_obs.squeeze(dim=0)),
                        action=ptu.get_numpy(action.squeeze(dim=0)),
                        reward=ptu.get_numpy(reward.squeeze(dim=0)),
                        terminal=np.array([term], dtype=float),
                        next_observation=ptu.get_numpy(
                            augmented_next_obs.squeeze(dim=0)))

                # set: obs <- next_obs
                obs = next_obs.clone()
                augmented_obs = augmented_next_obs.clone()

                # update statistics
                self._n_env_steps_total += 1
                if "is_goal_state" in dir(
                        self.env.unwrapped
                ) and self.env.unwrapped.is_goal_state():  # count successes
                    self._successes_in_buffer += 1
            self._n_rollouts_total += 1
Example #12
    def collect_rollouts(self):
        self.training_mode(False)
        num_episodes = self.args.max_rollouts_per_task
        num_steps_per_episode = self.env.unwrapped._max_episode_steps
        num_tasks = self.args.num_eval_tasks
        obs_size = self.env.unwrapped.observation_space.shape[0]

        returns_per_episode = np.zeros((num_tasks, num_episodes))
        success_rate = np.zeros(num_tasks)

        rewards = np.zeros((num_tasks, self.args.trajectory_len))
        observations = np.zeros(
            (num_tasks, self.args.trajectory_len + 1, obs_size))
        actions = np.zeros(
            (num_tasks, self.args.trajectory_len, self.args.action_dim))

        log_probs = np.zeros((num_tasks, self.args.trajectory_len))

        for task in self.env.unwrapped.get_all_task_idx():
            obs = ptu.from_numpy(self.env.reset(task))
            obs = obs.reshape(-1, obs.shape[-1])
            step = 0

            # get prior parameters
            task_sample, task_mean, task_logvar, hidden_state = self.vae.encoder.prior(
                batch_size=1)

            observations[task, step, :] = ptu.get_numpy(obs[0, :obs_size])

            for episode_idx in range(num_episodes):
                running_reward = 0.
                for step_idx in range(num_steps_per_episode):
                    # add distribution parameters to observation - policy is conditioned on posterior
                    augmented_obs = self.get_augmented_obs(
                        obs, task_mean, task_logvar)
                    action, _, _, log_prob = self.agent.act(
                        obs=augmented_obs,
                        deterministic=self.args.eval_deterministic,
                        return_log_prob=True)

                    # observe reward and next obs
                    next_obs, reward, done, info = utl.env_step(
                        self.env, action.squeeze(dim=0))
                    running_reward += reward.item()

                    # update encoding
                    task_sample, task_mean, task_logvar, hidden_state = self.update_encoding(
                        obs=next_obs,
                        action=action,
                        reward=reward,
                        done=done,
                        hidden_state=hidden_state)
                    rewards[task, step] = reward.item()
                    #reward_preds[task, step] = ptu.get_numpy(
                    #    self.vae.reward_decoder(task_sample, next_obs, obs, action)[0, 0])

                    observations[task, step + 1, :] = ptu.get_numpy(
                        next_obs[0, :obs_size])
                    actions[task, step, :] = ptu.get_numpy(action[0, :])
                    log_probs[task, step] = ptu.get_numpy(log_prob[0])

                    if "is_goal_state" in dir(
                            self.env.unwrapped
                    ) and self.env.unwrapped.is_goal_state():
                        success_rate[task] = 1.
                    # set: obs <- next_obs
                    obs = next_obs.clone()
                    step += 1

                returns_per_episode[task, episode_idx] = running_reward

        return returns_per_episode, success_rate, log_probs, observations, rewards, actions
Example #13
    def train(self):
        """
        Given some stream of environments and a logger (tensorboard),
        (meta-)trains the policy.
        """

        start_time = time.time()

        # reset environments
        (prev_obs_raw, prev_obs_normalised) = self.envs.reset()
        prev_obs_raw = prev_obs_raw.to(device)
        prev_obs_normalised = prev_obs_normalised.to(device)

        # insert initial observation / embeddings to rollout storage
        self.policy_storage.prev_obs_raw[0].copy_(prev_obs_raw)
        self.policy_storage.prev_obs_normalised[0].copy_(prev_obs_normalised)
        self.policy_storage.to(device)

        for self.iter_idx in range(self.args.num_updates):

            # check if we flushed the policy storage
            assert len(self.policy_storage.latent_mean) == 0

            # rollout policies for a few steps
            for step in range(self.args.policy_num_steps):

                # sample actions from policy
                with torch.no_grad():
                    value, action, action_log_prob = utl.select_action(
                        policy=self.policy,
                        args=self.args,
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        deterministic=False)

                # observe reward and next obs
                (next_obs_raw, next_obs_normalised), (
                    rew_raw, rew_normalised), done, infos = utl.env_step(
                        self.envs, action)
                action = action.float()

                # create mask for episode ends
                masks_done = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done]).to(device)
                # bad_mask is true if episode ended because time limit was reached
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos]).to(device)

                # add the obs before reset to the policy storage
                self.policy_storage.next_obs_raw[step] = next_obs_raw.clone()
                self.policy_storage.next_obs_normalised[
                    step] = next_obs_normalised.clone()

                # reset environments that are done
                # (no latent state to carry over here, unlike the VAE-based trainer)
                done_indices = np.argwhere(done.flatten()).flatten()
                if len(done_indices) == self.args.num_processes:
                    [next_obs_raw, next_obs_normalised] = self.envs.reset()
                else:
                    for i in done_indices:
                        [next_obs_raw[i],
                         next_obs_normalised[i]] = self.envs.reset(index=i)

                # add experience to policy buffer
                self.policy_storage.insert(
                    obs_raw=next_obs_raw.clone(),
                    obs_normalised=next_obs_normalised.clone(),
                    actions=action.clone(),
                    action_log_probs=action_log_prob.clone(),
                    rewards_raw=rew_raw.clone(),
                    rewards_normalised=rew_normalised.clone(),
                    value_preds=value.clone(),
                    masks=masks_done.clone(),
                    bad_masks=bad_masks.clone(),
                    done=torch.from_numpy(np.array(
                        done, dtype=float)).unsqueeze(1).clone(),
                )

                prev_obs_normalised = next_obs_normalised
                prev_obs_raw = next_obs_raw

                self.frames += self.args.num_processes

            # --- UPDATE ---

            train_stats = self.update(prev_obs_normalised if self.args.
                                      norm_obs_for_policy else prev_obs_raw)

            # log
            run_stats = [action, action_log_prob, value]
            if train_stats is not None:
                self.log(run_stats, train_stats, start_time)

            # clean up after update
            self.policy_storage.after_update()
Example #14
    def visualise_behaviour(env,
                            args,
                            policy,
                            iter_idx,
                            encoder=None,
                            reward_decoder=None,
                            image_folder=None,
                            **kwargs
                            ):
        """
        Visualises the behaviour of the policy, together with the latent state and belief.
        The environment passed to this method should be a SubProcVec or DummyVecEnv, not the raw env!
        """

        num_episodes = args.max_rollouts_per_task
        unwrapped_env = env.venv.unwrapped.envs[0]

        # --- initialise things we want to keep track of ---

        episode_all_obs = [[] for _ in range(num_episodes)]
        episode_prev_obs = [[] for _ in range(num_episodes)]
        episode_next_obs = [[] for _ in range(num_episodes)]
        episode_actions = [[] for _ in range(num_episodes)]
        episode_rewards = [[] for _ in range(num_episodes)]

        episode_returns = []
        episode_lengths = []

        episode_goals = []
        if getattr(unwrapped_env, 'belief_oracle', False):
            episode_beliefs = [[] for _ in range(num_episodes)]
        else:
            episode_beliefs = None

        if encoder is not None:
            # keep track of latent spaces
            episode_latent_samples = [[] for _ in range(num_episodes)]
            episode_latent_means = [[] for _ in range(num_episodes)]
            episode_latent_logvars = [[] for _ in range(num_episodes)]
        else:
            episode_latent_samples = episode_latent_means = episode_latent_logvars = None

        curr_latent_sample = curr_latent_mean = curr_latent_logvar = None

        # --- roll out policy ---

        env.reset_task()
        (obs_raw, obs_normalised) = env.reset()
        obs_raw = obs_raw.float().reshape((1, -1)).to(device)
        obs_normalised = obs_normalised.float().reshape((1, -1)).to(device)
        start_obs_raw = obs_raw.clone()

        for episode_idx in range(args.max_rollouts_per_task):

            curr_goal = env.get_task()
            curr_rollout_rew = []
            curr_rollout_goal = []

            if encoder is not None:

                if episode_idx == 0:
                    # reset to prior
                    curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                    curr_latent_sample = curr_latent_sample[0].to(device)
                    curr_latent_mean = curr_latent_mean[0].to(device)
                    curr_latent_logvar = curr_latent_logvar[0].to(device)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_all_obs[episode_idx].append(start_obs_raw.clone())
            if getattr(unwrapped_env, 'belief_oracle', False):
                episode_beliefs[episode_idx].append(unwrapped_env.unwrapped._belief_state.copy())

            for step_idx in range(1, env._max_episode_steps + 1):

                if step_idx == 1:
                    episode_prev_obs[episode_idx].append(start_obs_raw.clone())
                else:
                    episode_prev_obs[episode_idx].append(obs_raw.clone())

                # act
                _, action, _ = utl.select_action(args=args,
                                                 policy=policy,
                                                 obs=obs_normalised if args.norm_obs_for_policy else obs_raw,
                                                 deterministic=True,
                                                 latent_sample=curr_latent_sample, latent_mean=curr_latent_mean,
                                                 latent_logvar=curr_latent_logvar)

                # observe reward and next obs
                (obs_raw, obs_normalised), (rew_raw, rew_normalised), done, infos = utl.env_step(env, action)
                obs_raw = obs_raw.reshape((1, -1)).to(device)
                obs_normalised = obs_normalised.reshape((1, -1)).to(device)

                if encoder is not None:
                    # update task embedding
                    curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                        action.float().to(device),
                        obs_raw,
                        rew_raw.reshape((1, 1)).float().to(device),
                        hidden_state,
                        return_prior=False)

                    episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                    episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                    episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

                episode_all_obs[episode_idx].append(obs_raw.clone())
                episode_next_obs[episode_idx].append(obs_raw.clone())
                episode_rewards[episode_idx].append(rew_raw.clone())
                episode_actions[episode_idx].append(action.clone())

                curr_rollout_rew.append(rew_raw.clone())
                curr_rollout_goal.append(env.get_task().copy())

                if getattr(unwrapped_env, 'belief_oracle', False):
                    episode_beliefs[episode_idx].append(unwrapped_env.unwrapped._belief_state.copy())

                if infos[0]['done_mdp'] and not done:
                    start_obs_raw = infos[0]['start_state']
                    start_obs_raw = torch.from_numpy(start_obs_raw).float().reshape((1, -1)).to(device)
                    break

            episode_returns.append(sum(curr_rollout_rew))
            episode_lengths.append(step_idx)
            episode_goals.append(curr_goal)

        # clean up

        if encoder is not None:
            episode_latent_means = [torch.stack(e) for e in episode_latent_means]
            episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]

        episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
        episode_next_obs = [torch.cat(e) for e in episode_next_obs]
        episode_actions = [torch.cat(e) for e in episode_actions]
        episode_rewards = [torch.cat(e) for e in episode_rewards]

        # plot behaviour & visualise belief in env

        rew_pred_means, rew_pred_vars = plot_bb(env, args, episode_all_obs, episode_goals, reward_decoder,
                                                episode_latent_means, episode_latent_logvars,
                                                image_folder, iter_idx, episode_beliefs)

        if reward_decoder:
            plot_rew_reconstruction(env, rew_pred_means, rew_pred_vars, image_folder, iter_idx)

        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns
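Both `encoder.prior(...)` and the per-step `encoder(action, obs, reward, hidden_state, ...)` calls above assume a recurrent task encoder that maintains a Gaussian posterior over a latent task variable. Below is a minimal sketch of that interface, assuming a GRU over (action, state, reward) transitions; it is an illustrative stand-in, not the encoder used in these examples.

import torch
import torch.nn as nn


class TinyTaskEncoder(nn.Module):
    """Illustrative recurrent task encoder with a prior() / step-update interface."""

    def __init__(self, action_dim, state_dim, latent_dim=5, hidden_dim=64):
        super().__init__()
        self.gru = nn.GRU(action_dim + state_dim + 1, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.hidden_dim = hidden_dim

    def _posterior(self, hidden):
        mean, logvar = self.fc_mu(hidden), self.fc_logvar(hidden)
        # reparameterised sample from N(mean, exp(logvar))
        sample = mean + torch.randn_like(mean) * (0.5 * logvar).exp()
        return sample, mean, logvar

    def prior(self, batch_size):
        # before seeing any data, the belief is the prior given by a zero hidden state
        hidden = torch.zeros(1, batch_size, self.hidden_dim)
        sample, mean, logvar = self._posterior(hidden)
        return sample, mean, logvar, hidden

    def forward(self, action, state, reward, hidden, return_prior=False):
        # one belief update from a single (action, state, reward) transition
        inp = torch.cat((action, state, reward), dim=-1).unsqueeze(0)  # (seq=1, batch, feat)
        _, hidden = self.gru(inp, hidden)
        sample, mean, logvar = self._posterior(hidden)
        return sample, mean, logvar, hidden


enc = TinyTaskEncoder(action_dim=2, state_dim=4)
sample, mean, logvar, h = enc.prior(batch_size=1)
sample, mean, logvar, h = enc(torch.zeros(1, 2), torch.zeros(1, 4), torch.zeros(1, 1), h)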
Example #15
0
def get_test_rollout(args, env, policy, encoder=None):
    num_episodes = args.max_rollouts_per_task

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
        episode_latent_means = episode_latent_logvars = None

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    state, belief, task = utl.reset_env(env, args)
    state = state.reshape((1, -1)).to(device)
    task = task.view(-1) if task is not None else None

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []

        if encoder is not None:
            if episode_idx == 0:
                # reset to prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(
                    1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)
            episode_latent_samples[episode_idx].append(
                curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(
                curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(
                curr_latent_logvar[0].clone())

        for step_idx in range(1, env._max_episode_steps + 1):

            episode_prev_obs[episode_idx].append(state.clone())

            latent = utl.get_latent_for_policy(
                args,
                latent_sample=curr_latent_sample,
                latent_mean=curr_latent_mean,
                latent_logvar=curr_latent_logvar)
            _, action = policy.act(state=state.view(-1),
                                   latent=latent,
                                   belief=belief,
                                   task=task,
                                   deterministic=True)
            action = action.reshape((1, *action.shape))

            # observe reward and next obs
            (state, belief,
             task), (rew_raw, rew_normalised), done, infos = utl.env_step(
                 env, action, args)
            state = state.reshape((1, -1)).to(device)
            task = task.view(-1) if task is not None else None

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.float().to(device),
                    state,
                    rew_raw.reshape((1, 1)).float().to(device),
                    hidden_state,
                    return_prior=False)

                episode_latent_samples[episode_idx].append(
                    curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(
                    curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(
                    curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew_raw.clone())
            episode_actions[episode_idx].append(action.clone())
            # track the running return for this episode (otherwise the
            # episode_returns computed below would always be zero)
            curr_rollout_rew.append(rew_raw.clone())

            if infos[0]['done_mdp']:
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [
            torch.stack(e) for e in episode_latent_logvars
        ]

    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.cat(e) for e in episode_actions]
    episode_rewards = [torch.cat(r) for r in episode_rewards]

    return episode_latent_means, episode_latent_logvars, \
           episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
           episode_returns
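`utl.get_latent_for_policy` above assembles whatever latent representation the policy is conditioned on. A plausible minimal version is sketched below; the `sample_embeddings` flag is an assumption borrowed from the earlier examples, not necessarily the exact set of options the real helper supports.

import torch


def get_latent_for_policy_sketch(latent_sample, latent_mean, latent_logvar,
                                 sample_embeddings=False):
    # Either condition the policy on a single sampled latent, or on the full
    # posterior parameters (mean and log-variance concatenated).
    if sample_embeddings:
        return latent_sample
    return torch.cat((latent_mean, latent_logvar), dim=-1)


latent = get_latent_for_policy_sketch(torch.zeros(1, 5), torch.zeros(1, 5),
                                      torch.zeros(1, 5))
print(latent.shape)  # torch.Size([1, 10])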
Example #16
0
    def visualise_behaviour(
        env,
        args,
        policy,
        iter_idx,
        encoder=None,
        image_folder=None,
        return_pos=False,
        **kwargs,
    ):

        num_episodes = args.max_rollouts_per_task
        unwrapped_env = env.venv.unwrapped.envs[0].unwrapped

        # --- initialise things we want to keep track of ---

        episode_prev_obs = [[] for _ in range(num_episodes)]
        episode_next_obs = [[] for _ in range(num_episodes)]
        episode_actions = [[] for _ in range(num_episodes)]
        episode_rewards = [[] for _ in range(num_episodes)]

        episode_returns = []
        episode_lengths = []

        if encoder is not None:
            episode_latent_samples = [[] for _ in range(num_episodes)]
            episode_latent_means = [[] for _ in range(num_episodes)]
            episode_latent_logvars = [[] for _ in range(num_episodes)]
        else:
            curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
            episode_latent_samples = episode_latent_means = episode_latent_logvars = None

        # --- roll out policy ---

        # (re)set environment
        env.reset_task()
        state, belief, task = utl.reset_env(env, args)
        start_state = state.clone()

        # if hasattr(args, 'hidden_size'):
        #     hidden_state = torch.zeros((1, args.hidden_size)).to(device)
        # else:
        #     hidden_state = None

        # keep track of what task we're in and the position of the cheetah
        pos = [[] for _ in range(args.max_rollouts_per_task)]
        start_pos = unwrapped_env.get_body_com("torso")[0].copy()

        for episode_idx in range(num_episodes):

            curr_rollout_rew = []
            pos[episode_idx].append(start_pos)

            if encoder is not None:
                if episode_idx == 0:
                    # reset to prior
                    curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(
                        1)
                    curr_latent_sample = curr_latent_sample[0].to(device)
                    curr_latent_mean = curr_latent_mean[0].to(device)
                    curr_latent_logvar = curr_latent_logvar[0].to(device)
                episode_latent_samples[episode_idx].append(
                    curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(
                    curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(
                    curr_latent_logvar[0].clone())

            for step_idx in range(1, env._max_episode_steps + 1):

                if step_idx == 1:
                    episode_prev_obs[episode_idx].append(start_state.clone())
                else:
                    episode_prev_obs[episode_idx].append(state.clone())
                # act
                latent = utl.get_latent_for_policy(
                    args,
                    latent_sample=curr_latent_sample,
                    latent_mean=curr_latent_mean,
                    latent_logvar=curr_latent_logvar)
                _, action = policy.act(state=state.view(-1),
                                       latent=latent,
                                       belief=belief,
                                       task=task,
                                       deterministic=True)

                (state, belief,
                 task), (rew, rew_normalised), done, info = utl.env_step(
                     env, action, args)
                state = state.reshape((1, -1)).float().to(device)

                # keep track of position
                pos[episode_idx].append(
                    unwrapped_env.get_body_com("torso")[0].copy())

                if encoder is not None:
                    # update task embedding
                    curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                        action.reshape(1, -1).float().to(device),
                        state,
                        rew.reshape(1, -1).float().to(device),
                        hidden_state,
                        return_prior=False)

                    episode_latent_samples[episode_idx].append(
                        curr_latent_sample[0].clone())
                    episode_latent_means[episode_idx].append(
                        curr_latent_mean[0].clone())
                    episode_latent_logvars[episode_idx].append(
                        curr_latent_logvar[0].clone())

                episode_next_obs[episode_idx].append(state.clone())
                episode_rewards[episode_idx].append(rew.clone())
                episode_actions[episode_idx].append(
                    action.reshape(1, -1).clone())
                # track the running return for this episode (otherwise the
                # episode_returns computed below would always be zero)
                curr_rollout_rew.append(rew.clone())

                if info[0]['done_mdp'] and not done:
                    start_state = info[0]['start_state']
                    start_state = torch.from_numpy(start_state).reshape(
                        (1, -1)).float().to(device)
                    start_pos = unwrapped_env.get_body_com("torso")[0].copy()
                    break

            episode_returns.append(sum(curr_rollout_rew))
            episode_lengths.append(step_idx)

        # clean up
        if encoder is not None:
            episode_latent_means = [
                torch.stack(e) for e in episode_latent_means
            ]
            episode_latent_logvars = [
                torch.stack(e) for e in episode_latent_logvars
            ]

        episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
        episode_next_obs = [torch.cat(e) for e in episode_next_obs]
        episode_actions = [torch.cat(e) for e in episode_actions]
        episode_rewards = [torch.cat(e) for e in episode_rewards]

        # plot the movement of the half-cheetah
        plt.figure(figsize=(7, 4 * num_episodes))
        min_x = min([min(p) for p in pos])
        max_x = max([max(p) for p in pos])
        span = max_x - min_x
        for i in range(num_episodes):
            plt.subplot(num_episodes, 1, i + 1)
            # (not plotting the last step because this gives weird artefacts)
            plt.plot(pos[i][:-1], range(len(pos[i][:-1])), 'k')
            plt.title('task: {}'.format(task), fontsize=15)
            plt.ylabel('steps (ep {})'.format(i), fontsize=15)
            if i == num_episodes - 1:
                plt.xlabel('position', fontsize=15)
            # else:
            #     plt.xticks([])
            plt.xlim(min_x - 0.05 * span, max_x + 0.05 * span)
            # dashed vertical reference line at the starting position (x = 0)
            plt.plot([0, 0], [0, 200], 'b--', alpha=0.2)
        plt.tight_layout()
        if image_folder is not None:
            plt.savefig('{}/{}_behaviour'.format(image_folder, iter_idx))
            plt.close()
        else:
            plt.show()

        if not return_pos:
            return episode_latent_means, episode_latent_logvars, \
                   episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
                   episode_returns
        else:
            return episode_latent_means, episode_latent_logvars, \
                   episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
                   episode_returns, pos
Example #17
0
def get_test_rollout(args, env, policy, encoder=None):
    num_episodes = args.max_rollouts_per_task

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
        episode_latent_means = episode_latent_logvars = None
    # --- roll out policy ---

    # (re)set environment
    [obs_raw, obs_normalised] = env.reset()
    obs_raw = obs_raw.reshape((1, -1)).to(ptu.device)
    obs_normalised = obs_normalised.reshape((1, -1)).to(ptu.device)

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []

        if encoder is not None:
            if episode_idx == 0:
                # reset to prior (encoder is guaranteed non-None here)
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(
                    1)
                curr_latent_sample = curr_latent_sample[0].to(ptu.device)
                curr_latent_mean = curr_latent_mean[0].to(ptu.device)
                curr_latent_logvar = curr_latent_logvar[0].to(ptu.device)

            episode_latent_samples[episode_idx].append(
                curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(
                curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(
                curr_latent_logvar[0].clone())

        for step_idx in range(1, env._max_episode_steps + 1):

            episode_prev_obs[episode_idx].append(obs_raw.clone())

            _, action, _ = utl.select_action(
                args=args,
                policy=policy,
                obs=obs_normalised if args.norm_obs_for_policy else obs_raw,
                deterministic=True,
                task_sample=curr_latent_sample,
                task_mean=curr_latent_mean,
                task_logvar=curr_latent_logvar)

            # observe reward and next obs
            (obs_raw,
             obs_normalised), (rew_raw,
                               rew_normalised), done, infos = utl.env_step(
                                   env, action)
            obs_raw = obs_raw.reshape((1, -1)).to(ptu.device)
            obs_normalised = obs_normalised.reshape((1, -1)).to(ptu.device)

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.float().to(ptu.device),
                    obs_raw,
                    rew_raw.reshape((1, 1)).float().to(ptu.device),
                    hidden_state,
                    return_prior=False)

                episode_latent_samples[episode_idx].append(
                    curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(
                    curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(
                    curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(obs_raw.clone())
            episode_rewards[episode_idx].append(rew_raw.clone())
            episode_actions[episode_idx].append(action.clone())
            # track the running return for this episode (otherwise the
            # episode_returns computed below would always be zero)
            curr_rollout_rew.append(rew_raw.clone())

            if infos[0]['done_mdp']:
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [
            torch.stack(e) for e in episode_latent_logvars
        ]

    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.cat(e) for e in episode_actions]
    episode_rewards = [torch.cat(r) for r in episode_rewards]

    return episode_latent_means, episode_latent_logvars, \
           episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
           episode_returns
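The `obs_raw` / `obs_normalised` pair and the `args.norm_obs_for_policy` switch above assume the vectorised environment keeps running observation statistics. A self-contained sketch of that kind of normaliser is given below (illustrative only; the examples rely on their own normalisation wrapper).

import numpy as np


class RunningObsNormaliser:
    """Running mean/std observation normaliser (parallel-variance update)."""

    def __init__(self, obs_dim, eps=1e-8, clip=10.0):
        self.mean = np.zeros(obs_dim)
        self.var = np.ones(obs_dim)
        self.count = eps
        self.clip = clip

    def update(self, obs_batch):
        # combine the running statistics with those of the new batch
        batch_mean = obs_batch.mean(axis=0)
        batch_var = obs_batch.var(axis=0)
        n = obs_batch.shape[0]
        delta = batch_mean - self.mean
        total = self.count + n
        self.mean = self.mean + delta * n / total
        self.var = (self.var * self.count + batch_var * n
                    + delta ** 2 * self.count * n / total) / total
        self.count = total

    def normalise(self, obs):
        return np.clip((obs - self.mean) / np.sqrt(self.var + 1e-8),
                       -self.clip, self.clip)


norm = RunningObsNormaliser(obs_dim=3)
norm.update(np.random.randn(16, 3))
obs_normalised = norm.normalise(np.random.randn(3))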
Example #18
0
    def evaluate(self):
        num_episodes = self.args.max_rollouts_per_task
        num_steps_per_episode = self.env.unwrapped._max_episode_steps
        num_tasks = self.args.num_eval_tasks
        obs_size = self.env.unwrapped.observation_space.shape[0]

        returns_per_episode = np.zeros((num_tasks, num_episodes))
        success_rate = np.zeros(num_tasks)

        rewards = np.zeros((num_tasks, self.args.trajectory_len))
        reward_preds = np.zeros((num_tasks, self.args.trajectory_len))
        observations = np.zeros(
            (num_tasks, self.args.trajectory_len + 1, obs_size))
        if self.args.policy == 'sac':
            log_probs = np.zeros((num_tasks, self.args.trajectory_len))

        # This part is very specific for the Semi-Circle env
        # if self.args.env_name == 'PointRobotSparse-v0':
        #     reward_belief = np.zeros((num_tasks, self.args.trajectory_len))
        #
        #     low_x, high_x, low_y, high_y = -2., 2., -1., 2.
        #     resolution = 0.1
        #     grid_x = np.arange(low_x, high_x + resolution, resolution)
        #     grid_y = np.arange(low_y, high_y + resolution, resolution)
        #     centers_x = (grid_x[:-1] + grid_x[1:]) / 2
        #     centers_y = (grid_y[:-1] + grid_y[1:]) / 2
        #     yv, xv = np.meshgrid(centers_y, centers_x, sparse=False, indexing='ij')
        #     centers = np.vstack([xv.ravel(), yv.ravel()]).T
        #     n_grid_points = centers.shape[0]
        #     reward_belief_discretized = np.zeros((num_tasks, self.args.trajectory_len, centers.shape[0]))

        for task_loop_i, task in enumerate(
                self.env.unwrapped.get_all_eval_task_idx()):
            obs = ptu.from_numpy(self.env.reset(task))
            obs = obs.reshape(-1, obs.shape[-1])
            step = 0

            # get prior parameters
            with torch.no_grad():
                task_sample, task_mean, task_logvar, hidden_state = self.vae.encoder.prior(
                    batch_size=1)

            observations[task_loop_i,
                         step, :] = ptu.get_numpy(obs[0, :obs_size])

            for episode_idx in range(num_episodes):
                running_reward = 0.
                for step_idx in range(num_steps_per_episode):
                    # add distribution parameters to observation - policy is conditioned on posterior
                    augmented_obs = self.get_augmented_obs(
                        obs, task_mean, task_logvar)
                    if self.args.policy == 'dqn':
                        action, value = self.agent.act(obs=augmented_obs,
                                                       deterministic=True)
                    else:
                        action, _, _, log_prob = self.agent.act(
                            obs=augmented_obs,
                            deterministic=self.args.eval_deterministic,
                            return_log_prob=True)

                    # observe reward and next obs
                    next_obs, reward, done, info = utl.env_step(
                        self.env, action.squeeze(dim=0))
                    running_reward += reward.item()
                    # done_rollout = False if ptu.get_numpy(done[0][0]) == 0. else True
                    # update encoding
                    task_sample, task_mean, task_logvar, hidden_state = self.update_encoding(
                        obs=next_obs,
                        action=action,
                        reward=reward,
                        done=done,
                        hidden_state=hidden_state)
                    rewards[task_loop_i, step] = reward.item()
                    reward_preds[task_loop_i, step] = ptu.get_numpy(
                        self.vae.reward_decoder(task_sample, next_obs, obs,
                                                action)[0, 0])

                    # This part is very specific for the Semi-Circle env
                    # if self.args.env_name == 'PointRobotSparse-v0':
                    #     reward_belief[task, step] = ptu.get_numpy(
                    #         self.vae.compute_belief_reward(task_mean, task_logvar, obs, next_obs, action)[0])
                    #
                    #     reward_belief_discretized[task, step, :] = ptu.get_numpy(
                    #         self.vae.compute_belief_reward(task_mean.repeat(n_grid_points, 1),
                    #                                        task_logvar.repeat(n_grid_points, 1),
                    #                                        None,
                    #                                        torch.cat((ptu.FloatTensor(centers),
                    #                                                   ptu.zeros(centers.shape[0], 1)), dim=-1).unsqueeze(0),
                    #                                        None)[:, 0])

                    observations[task_loop_i, step + 1, :] = ptu.get_numpy(
                        next_obs[0, :obs_size])
                    if self.args.policy != 'dqn':
                        log_probs[task_loop_i,
                                  step] = ptu.get_numpy(log_prob[0])

                    if "is_goal_state" in dir(
                            self.env.unwrapped
                    ) and self.env.unwrapped.is_goal_state():
                        success_rate[task_loop_i] = 1.
                    # set: obs <- next_obs
                    obs = next_obs.clone()
                    step += 1

                returns_per_episode[task_loop_i, episode_idx] = running_reward

        if self.args.policy == 'dqn':
            return returns_per_episode, success_rate, observations, rewards, reward_preds
        # This part is very specific for the Semi-Circle env
        # elif self.args.env_name == 'PointRobotSparse-v0':
        #     return returns_per_episode, success_rate, log_probs, observations, \
        #            rewards, reward_preds, reward_belief, reward_belief_discretized, centers
        else:
            return returns_per_episode, success_rate, log_probs, observations, rewards, reward_preds
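The comment "policy is conditioned on posterior" and the `get_augmented_obs` call above amount to attaching the belief parameters to the raw observation. A minimal sketch of that operation follows; it is an assumption about what the method does, kept deliberately simple.

import torch


def get_augmented_obs_sketch(obs, task_mean, task_logvar):
    # obs: (batch, obs_dim); task_mean / task_logvar: (batch, latent_dim)
    # The belief-conditioned policy sees the observation together with the
    # current posterior mean and log-variance over the task variable.
    return torch.cat((obs, task_mean, task_logvar), dim=-1)


augmented = get_augmented_obs_sketch(torch.zeros(1, 4),
                                     torch.zeros(1, 5),
                                     torch.zeros(1, 5))
print(augmented.shape)  # torch.Size([1, 14])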
Example #19
0
    def train(self):
        """ Main Meta-Training loop """
        start_time = time.time()

        # reset environments
        prev_state, belief, task = utl.reset_env(self.envs, self.args)

        # insert initial observation / embeddings to rollout storage
        self.policy_storage.prev_state[0].copy_(prev_state)

        # log once before training
        with torch.no_grad():
            self.log(None, None, start_time)

        for self.iter_idx in range(self.num_updates):

            # First, re-compute the hidden states given the current rollouts (since the VAE might've changed)
            with torch.no_grad():
                latent_sample, latent_mean, latent_logvar, hidden_state = \
                    self.encode_running_trajectory()

            # add this initial hidden state to the policy storage
            assert len(self.policy_storage.latent_mean) == 0  # make sure we emptied buffers
            self.policy_storage.hidden_states[0].copy_(hidden_state)
            self.policy_storage.latent_samples.append(latent_sample.clone())
            self.policy_storage.latent_mean.append(latent_mean.clone())
            self.policy_storage.latent_logvar.append(latent_logvar.clone())

            # rollout policies for a few steps
            for step in range(self.args.policy_num_steps):

                # sample actions from policy
                with torch.no_grad():
                    value, action = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        state=prev_state,
                        belief=belief,
                        task=task,
                        deterministic=False,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar,
                    )

                # take step in the environment
                [next_state, belief,
                 task], (rew_raw, rew_normalised), done, infos = utl.env_step(
                     self.envs, action, self.args)

                done = torch.from_numpy(np.array(
                    done, dtype=int)).to(device).float().view((-1, 1))
                # create mask for episode ends
                masks_done = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done]).to(device)
                # bad_mask is true if episode ended because time limit was reached
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos]).to(device)

                with torch.no_grad():
                    # compute next embedding (for next loop and/or value prediction bootstrap)
                    latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                        encoder=self.vae.encoder,
                        next_obs=next_state,
                        action=action,
                        reward=rew_raw,
                        done=done,
                        hidden_state=hidden_state)

                # before resetting, update the embedding and add to vae buffer
                # (last state might include useful task info)
                if not (self.args.disable_decoder
                        and self.args.disable_kl_term):
                    self.vae.rollout_storage.insert(
                        prev_state.clone(),
                        action.detach().clone(), next_state.clone(),
                        rew_raw.clone(), done.clone(),
                        task.clone() if task is not None else None)

                # add the obs before reset to the policy storage
                self.policy_storage.next_state[step] = next_state.clone()

                # reset environments that are done
                done_indices = np.argwhere(done.cpu().flatten()).flatten()
                if len(done_indices) > 0:
                    next_state, belief, task = utl.reset_env(
                        self.envs,
                        self.args,
                        indices=done_indices,
                        state=next_state)

                # TODO: deal with resampling for posterior sampling algorithm
                #     latent_sample = latent_sample
                #     latent_sample[i] = latent_sample[i]

                # add experience to policy buffer
                self.policy_storage.insert(
                    state=next_state,
                    belief=belief,
                    task=task,
                    actions=action,
                    rewards_raw=rew_raw,
                    rewards_normalised=rew_normalised,
                    value_preds=value,
                    masks=masks_done,
                    bad_masks=bad_masks,
                    done=done,
                    hidden_states=hidden_state.squeeze(0),
                    latent_sample=latent_sample,
                    latent_mean=latent_mean,
                    latent_logvar=latent_logvar,
                )

                prev_state = next_state

                self.frames += self.args.num_processes

            # --- UPDATE ---

            if self.args.precollect_len <= self.frames:

                # check if we are pre-training the VAE
                if self.args.pretrain_len > self.iter_idx:
                    for p in range(self.args.num_vae_updates_per_pretrain):
                        self.vae.compute_vae_loss(
                            update=True,
                            pretrain_index=self.iter_idx *
                            self.args.num_vae_updates_per_pretrain + p)
                # otherwise do the normal update (policy + vae)
                else:

                    train_stats = self.update(state=prev_state,
                                              belief=belief,
                                              task=task,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar)

                    # log
                    run_stats = [
                        action, self.policy_storage.action_log_probs, value
                    ]
                    with torch.no_grad():
                        self.log(run_stats, train_stats, start_time)

            # clean up after update
            self.policy_storage.after_update()

        self.envs.close()
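The `masks_done` and `bad_masks` tensors built above are the standard devices for distinguishing real episode ends from time-limit cut-offs when computing discounted returns. The sketch below shows the usual pytorch-a2c-ppo-style way of combining the two; it is illustrative and not taken from the storage class used in this example (here masks/bad_masks are aligned with the transition at each step).

import torch


def compute_returns_sketch(rewards, value_preds, masks, bad_masks, gamma=0.99):
    # rewards, masks, bad_masks: (T, num_processes, 1); value_preds: (T + 1, num_processes, 1)
    T = rewards.shape[0]
    returns = torch.zeros(T + 1, *rewards.shape[1:])
    returns[-1] = value_preds[-1]
    for step in reversed(range(T)):
        # masks zero out bootstrapping across true episode ends; bad_masks
        # additionally replace the return by the value estimate when the
        # episode was cut off by the time limit rather than truly terminating.
        returns[step] = (returns[step + 1] * gamma * masks[step] + rewards[step]) \
            * bad_masks[step] + (1 - bad_masks[step]) * value_preds[step]
    return returns[:-1]


T, P = 4, 2
ret = compute_returns_sketch(torch.ones(T, P, 1), torch.zeros(T + 1, P, 1),
                             torch.ones(T, P, 1), torch.ones(T, P, 1))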