Example #1
def predict_rewards(learner, means, logvars):
    reward_preds = ptu.zeros([means.shape[0], learner.env.num_states])
    for t in range(reward_preds.shape[0]):
        task_samples = learner.vae.encoder._sample_gaussian(
            ptu.FloatTensor(means[t]), ptu.FloatTensor(logvars[t]), num=50)
        reward_preds[t, :] = learner.vae.reward_decoder(
            ptu.FloatTensor(task_samples), None).mean(dim=0).detach()

    return ptu.get_numpy(reward_preds)
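
The prediction is averaged over 50 latent samples per timestep. As a point of reference, here is a minimal sketch of the reparameterized Gaussian sampling that an encoder's `_sample_gaussian` typically implements (the exact behavior of the method above is an assumption; only torch is required):

import torch

def sample_gaussian(mean, logvar, num=50):
    # reparameterization trick: z = mu + sigma * eps, with sigma = exp(0.5 * logvar)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn(num, *mean.shape)
    return mean + std * eps                       # shape: (num, latent_dim)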
Example #2
def vis_rew_pred(args, rew_pred_arr, goal, **kwargs):
    env = gym.make(args.env_name)
    if args.env_name.startswith('GridNavi'):
        fig = plt.figure(figsize=(6, 6))
    else:  # 'TwoRooms'
        fig = plt.figure(figsize=(12, 6))

    ax = plt.gca()
    cmap = plt.cm.viridis
    for state in env.states:
        cell = Rectangle((state[0], state[1]),
                         width=1,
                         height=1,
                         fc=cmap(rew_pred_arr[ptu.get_numpy(
                             env.task_to_id(ptu.FloatTensor(state)))[0]]))
        ax.add_patch(cell)
        ax.text(state[0] + 0.5,
                state[1] + 0.5,
                rew_pred_arr[ptu.get_numpy(
                    env.task_to_id(ptu.FloatTensor(state)))[0]],
                ha="center",
                va="center",
                color="w")

    plt.xlim(env.observation_space.low[0] - 0.1,
             env.observation_space.high[0] + 1 + 0.1)
    plt.ylim(env.observation_space.low[1] - 0.1,
             env.observation_space.high[1] + 1 + 0.1)

    # add goal's position on grid
    line = Line2D([goal[0] + 0.3, goal[0] + 0.7],
                  [goal[1] + 0.3, goal[1] + 0.7],
                  lw=5,
                  color='black',
                  axes=ax)
    ax.add_line(line)
    line = Line2D([goal[0] + 0.3, goal[0] + 0.7],
                  [goal[1] + 0.7, goal[1] + 0.3],
                  lw=5,
                  color='black',
                  axes=ax)
    ax.add_line(line)
    if 'title' in kwargs:
        plt.title(kwargs['title'])

    if args.env_name.startswith('GridNavi'):
        ax.axis('equal')

    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(axis='both', which='both', length=0)

    fig.tight_layout()
    return fig
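
The drawing technique in isolation: a self-contained matplotlib sketch (toy 5x5 grid and a hard-coded goal cell, no `ptu`/gym dependencies) that colors cells with `Rectangle` patches and marks the goal with two crossed `Line2D` segments:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D

values = np.random.rand(5, 5)            # toy per-cell reward predictions
goal = (3, 1)                            # toy goal cell
fig, ax = plt.subplots(figsize=(6, 6))
cmap = plt.cm.viridis
for x in range(5):
    for y in range(5):
        ax.add_patch(Rectangle((x, y), 1, 1, fc=cmap(values[x, y])))
        ax.text(x + 0.5, y + 0.5, '{:.2f}'.format(values[x, y]),
                ha='center', va='center', color='w')
ax.add_line(Line2D([goal[0] + 0.3, goal[0] + 0.7],
                   [goal[1] + 0.3, goal[1] + 0.7], lw=5, color='black'))
ax.add_line(Line2D([goal[0] + 0.3, goal[0] + 0.7],
                   [goal[1] + 0.7, goal[1] + 0.3], lw=5, color='black'))
ax.set_xlim(-0.1, 5.1)
ax.set_ylim(-0.1, 5.1)
ax.set_aspect('equal')
plt.show()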
Example #3
    def collect_rollouts(self, num_rollouts, random_actions=False):
        '''
        Collect rollouts and add the transitions to the policy buffer.
        :param num_rollouts: number of rollouts to collect
        :param random_actions: if True, sample actions uniformly from the action space instead of from the policy
        :return:
        '''

        for rollout in range(num_rollouts):
            obs = ptu.from_numpy(self.env.reset(self.task_idx))
            obs = obs.reshape(-1, obs.shape[-1])
            done_rollout = False

            while not done_rollout:
                if random_actions:
                    if self.args.policy == 'dqn':
                        action = ptu.FloatTensor([[[self.env.action_space.sample()]]]).long()   # Sample random action
                    else:
                        action = ptu.FloatTensor([self.env.action_space.sample()])  # Sample random action
                else:
                    if self.args.policy == 'dqn':
                        action, _ = self.agent.act(obs=obs)   # DQN
                    else:
                        action, _, _, _ = self.agent.act(obs=obs)   # SAC
                # observe reward and next obs
                next_obs, reward, done, info = utl.env_step(self.env, action.squeeze(dim=0))
                done_rollout = False if ptu.get_numpy(done[0][0]) == 0. else True

                # add data to policy buffer - (s+, a, r, s'+, term)
                term = self.env.unwrapped.is_goal_state() if "is_goal_state" in dir(self.env.unwrapped) else False
                if self.args.dense_train_sparse_test:
                    rew_to_buffer = {rew_type: rew for rew_type, rew in info.items()
                                     if rew_type.startswith('reward')}
                else:
                    rew_to_buffer = ptu.get_numpy(reward.squeeze(dim=0))
                self.policy_storage.add_sample(task=self.task_idx,
                                               observation=ptu.get_numpy(obs.squeeze(dim=0)),
                                               action=ptu.get_numpy(action.squeeze(dim=0)),
                                               reward=rew_to_buffer,
                                               terminal=np.array([term], dtype=float),
                                               next_observation=ptu.get_numpy(next_obs.squeeze(dim=0)))

                # set: obs <- next_obs
                obs = next_obs.clone()

                # update statistics
                self._n_env_steps_total += 1
                if "is_goal_state" in dir(self.env.unwrapped) and self.env.unwrapped.is_goal_state():  # count successes
                    self._successes_in_buffer += 1
            self._n_rollouts_total += 1
Example #4
    def estimate_log_sum_exp_q(self, qf, obs, N, action_space):
        '''
            estimate log(sum(exp(Q))) for CQL objective
        :param qf: Q function
        :param obs: state batch from buffer (s~D)
        :param N: number of actions to sample for estimation
        :param action_space: space of actions -- for uniform sampling
        :return:
        '''
        batch_size = obs.shape[0]
        obs_rep = obs.repeat(N, 1)

        # draw actions at uniform
        random_actions = ptu.FloatTensor(np.vstack([action_space.sample() for _ in range(N)]))
        random_actions = torch.repeat_interleave(random_actions, batch_size, dim=0)
        unif_a = 1 / np.prod(action_space.high - action_space.low)  # uniform density over action space

        # draw actions from current policy
        with torch.no_grad():
            policy_actions, _, _, policy_log_probs = self.act(obs_rep, return_log_prob=True)

        exp_q_unif = qf(obs_rep, random_actions) / unif_a
        exp_q_policy = qf(obs_rep, policy_actions) / torch.exp(policy_log_probs)
        log_sum_exp = torch.log(0.5 * torch.mean((exp_q_unif + exp_q_policy).reshape(N, batch_size, -1), dim=0))

        return log_sum_exp
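
For intuition, here is a self-contained numpy sketch of the uniform-proposal importance-sampling estimate that this kind of log-sum-exp term relies on (1-D toy Q-function; it illustrates the identity log ∫ exp(q(a)) da ≈ log mean[exp(q(a_i)) / μ(a_i)] and is not a drop-in replacement for the method above):

import numpy as np

rng = np.random.default_rng(0)
q = lambda a: -2.0 * a ** 2                       # toy 1-D "Q-function"
low, high = -1.0, 1.0
unif_density = 1.0 / (high - low)                 # uniform proposal density

# Monte-Carlo importance-sampling estimate of log integral of exp(q(a)) da
a = rng.uniform(low, high, size=10_000)
mc_estimate = np.log(np.mean(np.exp(q(a)) / unif_density))

# dense Riemann-sum reference
grid = np.linspace(low, high, 100_001)
reference = np.log(np.sum(np.exp(q(grid))) * (grid[1] - grid[0]))
print(mc_estimate, reference)                     # the two agree closely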
Example #5
def visualize_latent_space(latent_dim, n_samples, decoder):
    from sklearn.manifold import TSNE
    latents = ptu.FloatTensor(sample_random_normal(latent_dim, n_samples))

    pred_rewards = ptu.get_numpy(decoder(latents, None))
    goal_locations = np.argmax(pred_rewards, axis=-1)

    # embed to lower dim space - if dim > 2
    if latent_dim > 2:
        tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
        tsne_results = tsne.fit_transform(ptu.get_numpy(latents))  # t-SNE expects numpy, not torch tensors

    # create DataFrame
    data = tsne_results if latent_dim > 2 else ptu.get_numpy(latents)

    df = pd.DataFrame(data, columns=['x1', 'x2'])
    df["y"] = goal_locations

    fig = plt.figure(figsize=(6, 6))
    sns.scatterplot(x="x1",
                    y="x2",
                    hue="y",
                    s=30,
                    palette=sns.color_palette("hls", len(np.unique(df["y"]))),
                    data=df,
                    legend="full",
                    ax=plt.gca())
    fig.show()

    return data, goal_locations
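
A stripped-down, self-contained version of the same idea, with random points and a toy label instead of a decoder (sklearn/seaborn only):

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

latents = np.random.randn(500, 5)                 # toy latent samples
labels = (latents[:, 0] > 0).astype(int)          # toy "goal" label per sample
emb = TSNE(n_components=2, perplexity=40).fit_transform(latents)
df = pd.DataFrame(emb, columns=['x1', 'x2'])
df['y'] = labels
sns.scatterplot(x='x1', y='x2', hue='y', s=30, data=df, legend='full')
plt.show()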
Example #6
def load_transitions(path, device=ptu.device):
    '''
        return tensors of obs, actions, rewards, next_obs, terminals
    :param path: path to a directory containing the numpy files
    :return:
    '''

    obs = ptu.FloatTensor(np.load(os.path.join(path, 'obs.npy'))).to(device)
    actions = ptu.FloatTensor(np.load(os.path.join(path,
                                                   'actions.npy'))).to(device)
    rewards = ptu.FloatTensor(np.load(os.path.join(path,
                                                   'rewards.npy'))).to(device)
    next_obs = ptu.FloatTensor(np.load(os.path.join(
        path, 'next_obs.npy'))).to(device)
    terminals = ptu.FloatTensor(np.load(os.path.join(
        path, 'terminals.npy'))).to(device)
    return obs, actions, rewards, next_obs, terminals
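
A hypothetical round-trip sketch (the file names match the ones expected above; the toy shapes and the temporary directory are assumptions, and the project's `ptu` module must be importable):

import os
import numpy as np

path = '/tmp/transitions_demo'                      # hypothetical directory
os.makedirs(path, exist_ok=True)
toy = {
    'obs': np.random.randn(100, 4),
    'actions': np.random.randn(100, 1),
    'rewards': np.random.randn(100, 1),
    'next_obs': np.random.randn(100, 4),
    'terminals': np.zeros((100, 1)),
}
for name, arr in toy.items():
    np.save(os.path.join(path, name + '.npy'), arr)

obs, actions, rewards, next_obs, terminals = load_transitions(path)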
Example #7
def env_step(env, action):
    # action should be of size: batch x 1
    action = ptu.get_numpy(action.squeeze(dim=-1))
    next_obs, reward, done, info = env.step(action)
    # move to torch
    next_obs = ptu.from_numpy(next_obs).view(-1, next_obs.shape[0])
    reward = ptu.FloatTensor([reward]).view(-1, 1)
    done = ptu.from_numpy(np.array(done, dtype=int)).view(-1, 1)

    return next_obs, reward, done, info
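
A hypothetical single-step usage sketch (assumes the project's `ptu` helpers are importable and the classic Gym API in which `reset()` returns only the observation; `Pendulum-v1` is just an arbitrary continuous-action example):

import gym

env = gym.make('Pendulum-v1')
obs = ptu.from_numpy(env.reset()).view(1, -1)
action = ptu.FloatTensor([[0.0]])                   # batch x 1, as env_step expects
next_obs, reward, done, info = env_step(env, action)
print(next_obs.shape, reward.shape, done.shape)     # (1, 3), (1, 1), (1, 1)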
Example #8
def transform_mdps_ds_to_bamdp_ds(dataset, vae, args):
    '''
    Transform an MDP dataset into a BAMDP dataset (observations augmented with the belief).
    :param dataset: list with one entry per task; each entry is a list of
    (s, a, r, s', done) arrays of shape (traj_len, n_trajs, dim)
    :param vae: trained VAE model
    :return:
    '''

    bamdp_dataset = []

    for i, task_set in enumerate(dataset):
        obs, actions, rewards, next_obs, terminals = task_set
        augmented_obs, belief_rewards, augmented_next_obs = \
            transform_mdp_to_bamdp_rollouts(vae, args,
                                            ptu.FloatTensor(obs),
                                            ptu.FloatTensor(actions),
                                            ptu.FloatTensor(rewards),
                                            ptu.FloatTensor(next_obs),
                                            ptu.FloatTensor(terminals))
        if belief_rewards is not None:
            rewards = belief_rewards
        else:
            rewards = ptu.FloatTensor(rewards)

        bamdp_dataset.append([
            ptu.get_numpy(augmented_obs), actions,
            ptu.get_numpy(rewards),
            ptu.get_numpy(augmented_next_obs), terminals
        ])
        print('{} datasets were processed.'.format(i + 1))
    return bamdp_dataset
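
For clarity, a toy example of the nested dataset layout the function expects (dimensions are arbitrary; the final call is commented out because it needs a trained VAE and an `args` namespace):

import numpy as np

traj_len, n_trajs = 20, 4
toy_task = [
    np.zeros((traj_len, n_trajs, 2)),   # observations
    np.zeros((traj_len, n_trajs, 1)),   # actions
    np.zeros((traj_len, n_trajs, 1)),   # rewards
    np.zeros((traj_len, n_trajs, 2)),   # next observations
    np.zeros((traj_len, n_trajs, 1)),   # terminals
]
dataset = [toy_task, toy_task]          # one entry per task
# bamdp_dataset = transform_mdps_ds_to_bamdp_ds(dataset, vae, args)  # needs a trained VAE and args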
Example #9
def relabel_rollout(env, goal, observations, actions):
    env.set_goal(goal)
    obs_np = observations if isinstance(observations, np.ndarray) else ptu.get_numpy(observations)
    act_np = actions if isinstance(actions, np.ndarray) else ptu.get_numpy(actions)
    rewards = [env.reward(obs, action) for obs, action in zip(obs_np, act_np)]
    if isinstance(observations, np.ndarray):
        return np.vstack(rewards)
    else:
        return ptu.FloatTensor(np.vstack(rewards))
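
A hypothetical usage sketch with a toy goal-conditioned env (the `ToyGoalEnv` class and its dense reward are inventions for illustration; only the `set_goal` / `reward(obs, action)` interface is taken from the function above):

import numpy as np

class ToyGoalEnv:
    # invented for illustration; only set_goal / reward(obs, action) are taken from above
    def set_goal(self, goal):
        self.goal = np.asarray(goal)

    def reward(self, obs, action):
        return -np.linalg.norm(obs[:2] - self.goal)   # toy dense reward

env = ToyGoalEnv()
observations = np.random.randn(10, 2)
actions = np.random.randn(10, 1)
relabeled = relabel_rollout(env, np.array([1.0, 1.0]), observations, actions)
print(relabeled.shape)                                # (10, 1): one relabeled reward per step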
Example #10
def eval_vae(dataset, vae, args):

    num_tasks = len(dataset)
    reward_preds = np.zeros((num_tasks, args.trajectory_len))
    rewards = np.zeros((num_tasks, args.trajectory_len))
    random_tasks = np.random.choice(len(dataset), 10)  # which tasks to evaluate
    states, actions = get_heatmap_params()
    state_preds = np.zeros((num_tasks, states.shape[0]))

    for task_idx, task in enumerate(random_tasks):
        traj_idx_random = np.random.choice(dataset[0][0].shape[1])  # which trajectory to evaluate
        # get prior parameters
        with torch.no_grad():
            task_sample, task_mean, task_logvar, hidden_state = vae.encoder.prior(batch_size=1)
        for step in range(args.trajectory_len):
            # update encoding
            task_sample, task_mean, task_logvar, hidden_state = utl.update_encoding(
                encoder=vae.encoder,
                obs=ptu.FloatTensor(dataset[task][3][step, traj_idx_random]).unsqueeze(0),
                action=ptu.FloatTensor(dataset[task][1][step, traj_idx_random]).unsqueeze(0),
                reward=ptu.FloatTensor(dataset[task][2][step, traj_idx_random]).unsqueeze(0),
                done=ptu.FloatTensor(dataset[task][4][step, traj_idx_random]).unsqueeze(0),
                hidden_state=hidden_state
            )

            rewards[task_idx, step] = dataset[task][2][step, traj_idx_random].item()
            reward_preds[task_idx, step] = ptu.get_numpy(
                vae.reward_decoder(task_sample.unsqueeze(0),
                                   ptu.FloatTensor(dataset[task][3][step, traj_idx_random]).unsqueeze(0).unsqueeze(0),
                                   ptu.FloatTensor(dataset[task][0][step, traj_idx_random]).unsqueeze(0).unsqueeze(0),
                                   ptu.FloatTensor(dataset[task][1][step, traj_idx_random]).unsqueeze(0).unsqueeze(0))[0, 0])

        states, actions = get_heatmap_params()
        prediction = ptu.get_numpy(vae.state_decoder(task_sample.expand((1, 30, task_sample.shape[-1])),
                                                     ptu.FloatTensor(states).unsqueeze(0),
                                                     ptu.FloatTensor(actions).unsqueeze(0))).squeeze()
        for i in range(30):
            state_preds[task_idx, i] = 1 if np.linalg.norm(prediction[i, :]) > 1 else 0

    return rewards, reward_preds, state_preds, random_tasks
Example #11
def eval_vae(dataset, vae, args):

    num_tasks = len(dataset)
    reward_preds = np.zeros((num_tasks, args.trajectory_len))
    rewards = np.zeros((num_tasks, args.trajectory_len))
    random_tasks = np.random.choice(
        len(dataset), NUM_EVAL_TASKS)  # which tasks to evaluate

    for task_idx, task in enumerate(random_tasks):
        traj_idx_random = np.random.choice(
            dataset[task][0].shape[1])  # which trajectory to evaluate
        # traj_idx_random = np.random.choice(np.min([d[0].shape[1] for d in dataset]))
        # get prior parameters
        with torch.no_grad():
            task_sample, task_mean, task_logvar, hidden_state = vae.encoder.prior(
                batch_size=1)
        for step in range(args.trajectory_len):
            # update encoding
            task_sample, task_mean, task_logvar, hidden_state = utl.update_encoding(
                encoder=vae.encoder,
                obs=ptu.FloatTensor(
                    dataset[task][3][step, traj_idx_random]).unsqueeze(0),
                action=ptu.FloatTensor(
                    dataset[task][1][step, traj_idx_random]).unsqueeze(0),
                reward=ptu.FloatTensor(
                    dataset[task][2][step, traj_idx_random]).unsqueeze(0),
                done=ptu.FloatTensor(
                    dataset[task][4][step, traj_idx_random]).unsqueeze(0),
                hidden_state=hidden_state)

            rewards[task_idx, step] = dataset[task][2][step,
                                                       traj_idx_random].item()
            reward_preds[task_idx, step] = ptu.get_numpy(
                vae.reward_decoder(
                    task_sample.unsqueeze(0),
                    ptu.FloatTensor(dataset[task][3][
                        step, traj_idx_random]).unsqueeze(0).unsqueeze(0),
                    ptu.FloatTensor(dataset[task][0][
                        step, traj_idx_random]).unsqueeze(0).unsqueeze(0),
                    ptu.FloatTensor(dataset[task][1][
                        step, traj_idx_random]).unsqueeze(0).unsqueeze(0))[0,
                                                                           0])

    return rewards, reward_preds
Example #12
    def act(self, obs, deterministic=False):
        '''
            epsilon-greedy policy based on Q values
        :param obs: observation batch
        :param deterministic: if True, always take the greedy (argmax) action; otherwise act epsilon-greedily
        :return: action and its corresponding Q value
        '''
        q_values = self.qf(obs)
        if deterministic:
            action = q_values.argmax(dim=-1, keepdims=True)
        else:  # epsilon greedy
            if random.random() <= self.eps:
                action = ptu.FloatTensor([
                    random.randrange(q_values.shape[-1])
                    for _ in range(q_values.shape[0])
                ]).long().unsqueeze(dim=-1)
            else:
                action = q_values.argmax(dim=-1, keepdims=True)
        value = q_values.gather(dim=-1, index=action)
        return action, value
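
The final `gather` call simply picks, for each row, the Q-value at the chosen action index; a tiny self-contained torch illustration:

import torch

q_values = torch.tensor([[1.0, 5.0, 2.0],
                         [0.5, 0.1, 3.0]])
actions = torch.tensor([[1], [2]])                  # chosen action index per row
values = q_values.gather(dim=-1, index=actions)     # tensor([[5.0], [3.0]])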
Example #13
def train(vae, dataset, args):
    '''
    Train the VAE on the given multi-task dataset.
    :param vae: VAE model to train
    :param dataset: list with one entry per task; each entry contains tensors of (s, a, r, s', t)
    :param args: training arguments
    :return:
    '''

    if args.log_tensorboard:
        writer = SummaryWriter(args.full_save_path)

    num_tasks = len(dataset)

    start_time = time.time()
    total_updates = 0
    for iter_ in range(args.num_iters):
        n_batches = np.min([
            int(np.ceil(d[0].shape[1] / args.vae_batch_num_rollouts_per_task))
            for d in dataset
        ])
        # traj_permutation = np.random.permutation(dataset[0][0].shape[1])
        traj_permutation = np.random.permutation(
            np.min([d[0].shape[1] for d in dataset]))
        loss_tr, rew_loss_tr, state_loss_tr, kl_loss_tr = 0, 0, 0, 0  # initialize loss for epoch
        n_updates = 0  # count number of updates
        for i in tqdm(range(n_batches), desc="Epoch {}".format(iter_)):

            if i == n_batches - 1:
                traj_indices = traj_permutation[
                    i * args.vae_batch_num_rollouts_per_task:]
            else:
                traj_indices = traj_permutation[
                    i * args.vae_batch_num_rollouts_per_task:(i + 1) *
                    args.vae_batch_num_rollouts_per_task]

            n_task_batches = int(np.ceil(num_tasks / args.tasks_batch_size))
            task_permutation = np.random.permutation(num_tasks)

            for j in range(n_task_batches):  # run over tasks
                if j == n_task_batches - 1:
                    indices = task_permutation[j * args.tasks_batch_size:]
                else:
                    indices = task_permutation[j *
                                               args.tasks_batch_size:(j + 1) *
                                               args.tasks_batch_size]

                obs, actions, rewards, next_obs = [], [], [], []
                for idx in indices:
                    # random_subset = np.random.permutation(dataset[idx][0].shape[1], )
                    # random_subset = np.random.choice(dataset[idx][0].shape[1], args.vae_batch_num_rollouts_per_task)
                    obs.append(
                        ptu.FloatTensor(dataset[idx][0][:, traj_indices, :]))
                    actions.append(
                        ptu.FloatTensor(dataset[idx][1][:, traj_indices, :]))
                    rewards.append(
                        ptu.FloatTensor(dataset[idx][2][:, traj_indices, :]))
                    next_obs.append(
                        ptu.FloatTensor(dataset[idx][3][:, traj_indices, :]))
                obs = torch.cat(obs, dim=1)
                actions = torch.cat(actions, dim=1)
                rewards = torch.cat(rewards, dim=1)
                next_obs = torch.cat(next_obs, dim=1)
                rew_recon_loss, state_recon_loss, kl_term = update_step(
                    vae, obs, actions, rewards, next_obs, args)

                # take average (this is the expectation over p(M))
                loss = args.rew_loss_coeff * rew_recon_loss + \
                       args.state_loss_coeff * state_recon_loss + \
                       args.kl_weight * kl_term
                # update
                vae.optimizer.zero_grad()
                loss.backward()
                vae.optimizer.step()

                n_updates += 1
                loss_tr += loss.item()
                rew_loss_tr += rew_recon_loss.item()
                state_loss_tr += state_recon_loss.item()
                kl_loss_tr += kl_term.item()

        print(
            'Elapsed time: {:.2f}, loss: {:.4f} -- rew_loss: {:.4f} -- state_loss: {:.4f} -- kl: {:.4f}'
            .format(time.time() - start_time, loss_tr / n_updates,
                    rew_loss_tr / n_updates, state_loss_tr / n_updates,
                    kl_loss_tr / n_updates))

        total_updates += n_updates
        # log tb
        if args.log_tensorboard:
            writer.add_scalar('loss/vae_loss', loss_tr / n_updates,
                              total_updates)
            writer.add_scalar('loss/rew_recon_loss', rew_loss_tr / n_updates,
                              total_updates)
            writer.add_scalar('loss/state_recon_loss',
                              state_loss_tr / n_updates, total_updates)
            writer.add_scalar('loss/kl', kl_loss_tr / n_updates, total_updates)
            if args.env_name != 'GridNavi-v2':  # TODO: eval for gridworld domain
                rewards_eval, reward_preds_eval = eval_vae(dataset, vae, args)
                for task in range(NUM_EVAL_TASKS):
                    writer.add_figure(
                        'reward_prediction/task_{}'.format(task),
                        utl_eval.plot_rew_pred_vs_rew(
                            rewards_eval[task, :], reward_preds_eval[task, :]),
                        total_updates)

        if (iter_ + 1) % args.eval_interval == 0:
            pass

        if args.save_model and (iter_ + 1) % args.save_interval == 0:
            save_path = os.path.join(os.getcwd(), args.full_save_path,
                                     'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)
            torch.save(
                vae.encoder.state_dict(),
                os.path.join(save_path, "encoder{0}.pt".format(iter_ + 1)))
            if vae.reward_decoder is not None:
                torch.save(
                    vae.reward_decoder.state_dict(),
                    os.path.join(save_path,
                                 "reward_decoder{0}.pt".format(iter_ + 1)))
            if vae.state_decoder is not None:
                torch.save(
                    vae.state_decoder.state_dict(),
                    os.path.join(save_path,
                                 "state_decoder{0}.pt".format(iter_ + 1)))
Example #14
    def collect_rollouts(self, num_rollouts, random_actions=False):
        '''
        Collect rollouts with belief-augmented observations and add them to the policy buffer.
        :param num_rollouts: number of rollouts to collect
        :param random_actions: if True, sample actions uniformly from the action space instead of from the policy
        :return:
        '''

        for rollout in range(num_rollouts):
            obs = ptu.from_numpy(self.env.reset(self.task_idx))
            obs = obs.reshape(-1, obs.shape[-1])
            done_rollout = False
            # self.policy_storage.reset_running_episode(self.task_idx)

            # if self.args.fixed_latent_params:
            #     assert 2 ** self.args.task_embedding_size >= self.args.num_tasks
            #     task_mean = ptu.FloatTensor(utl.vertices(self.args.task_embedding_size)[self.task_idx])
            #     task_logvar = -2. * ptu.ones_like(task_logvar)   # arbitrary negative enough number
            # add distribution parameters to observation - policy is conditioned on posterior
            augmented_obs = self.get_augmented_obs(obs=obs)

            while not done_rollout:
                if random_actions:
                    if self.args.policy == 'dqn':
                        action = ptu.FloatTensor([[
                            self.env.action_space.sample()
                        ]]).type(torch.long)  # Sample random action
                    else:
                        action = ptu.FloatTensor(
                            [self.env.action_space.sample()])
                else:
                    if self.args.policy == 'dqn':
                        action, _ = self.agent.act(obs=augmented_obs)  # DQN
                    else:
                        action, _, _, _ = self.agent.act(
                            obs=augmented_obs)  # SAC
                # observe reward and next obs
                next_obs, reward, done, info = utl.env_step(
                    self.env, action.squeeze(dim=0))
                done_rollout = False if ptu.get_numpy(
                    done[0][0]) == 0. else True

                # get augmented next obs
                augmented_next_obs = self.get_augmented_obs(obs=next_obs)

                # add data to policy buffer - (s+, a, r, s'+, term)
                term = self.env.unwrapped.is_goal_state(
                ) if "is_goal_state" in dir(self.env.unwrapped) else False
                self.policy_storage.add_sample(
                    task=self.task_idx,
                    observation=ptu.get_numpy(augmented_obs.squeeze(dim=0)),
                    action=ptu.get_numpy(action.squeeze(dim=0)),
                    reward=ptu.get_numpy(reward.squeeze(dim=0)),
                    terminal=np.array([term], dtype=float),
                    next_observation=ptu.get_numpy(
                        augmented_next_obs.squeeze(dim=0)))
                if not random_actions:
                    self.current_experience_storage.add_sample(
                        task=self.task_idx,
                        observation=ptu.get_numpy(
                            augmented_obs.squeeze(dim=0)),
                        action=ptu.get_numpy(action.squeeze(dim=0)),
                        reward=ptu.get_numpy(reward.squeeze(dim=0)),
                        terminal=np.array([term], dtype=float),
                        next_observation=ptu.get_numpy(
                            augmented_next_obs.squeeze(dim=0)))

                # set: obs <- next_obs
                obs = next_obs.clone()
                augmented_obs = augmented_next_obs.clone()

                # update statistics
                self._n_env_steps_total += 1
                if "is_goal_state" in dir(
                        self.env.unwrapped
                ) and self.env.unwrapped.is_goal_state():  # count successes
                    self._successes_in_buffer += 1
            self._n_rollouts_total += 1
Example #15
    def evaluate(self):
        num_episodes = self.args.max_rollouts_per_task
        num_steps_per_episode = self.env.unwrapped._max_episode_steps
        num_tasks = self.args.num_eval_tasks
        obs_size = self.env.unwrapped.observation_space.shape[0]

        returns_per_episode = np.zeros((num_tasks, num_episodes))
        success_rate = np.zeros(num_tasks)

        rewards = np.zeros((num_tasks, self.args.trajectory_len))
        reward_preds = np.zeros((num_tasks, self.args.trajectory_len))
        observations = np.zeros(
            (num_tasks, self.args.trajectory_len + 1, obs_size))
        if self.args.policy == 'sac':
            log_probs = np.zeros((num_tasks, self.args.trajectory_len))

        # This part is very specific for the Semi-Circle env
        # if self.args.env_name == 'PointRobotSparse-v0':
        #     reward_belief = np.zeros((num_tasks, self.args.trajectory_len))
        #
        #     low_x, high_x, low_y, high_y = -2., 2., -1., 2.
        #     resolution = 0.1
        #     grid_x = np.arange(low_x, high_x + resolution, resolution)
        #     grid_y = np.arange(low_y, high_y + resolution, resolution)
        #     centers_x = (grid_x[:-1] + grid_x[1:]) / 2
        #     centers_y = (grid_y[:-1] + grid_y[1:]) / 2
        #     yv, xv = np.meshgrid(centers_y, centers_x, sparse=False, indexing='ij')
        #     centers = np.vstack([xv.ravel(), yv.ravel()]).T
        #     n_grid_points = centers.shape[0]
        #     reward_belief_discretized = np.zeros((num_tasks, self.args.trajectory_len, centers.shape[0]))

        circle_states, circle_actions = get_heatmap_params()
        state_preds = np.zeros(
            (num_tasks, self.args.trajectory_len, circle_states.shape[0]))

        for task in self.env.unwrapped.get_all_task_idx():
            obs = ptu.from_numpy(self.env.reset(task))
            obs = obs.reshape(-1, obs.shape[-1])
            step = 0

            # get prior parameters
            with torch.no_grad():
                task_sample, task_mean, task_logvar, hidden_state = self.vae.encoder.prior(
                    batch_size=1)

            observations[task, step, :] = ptu.get_numpy(obs[0, :obs_size])

            for episode_idx in range(num_episodes):
                running_reward = 0.
                for step_idx in range(num_steps_per_episode):
                    # add distribution parameters to observation - policy is conditioned on posterior
                    augmented_obs = self.get_augmented_obs(
                        obs, task_mean, task_logvar)
                    if self.args.policy == 'dqn':
                        action, value = self.agent.act(obs=augmented_obs,
                                                       deterministic=True)
                    else:
                        action, _, _, log_prob = self.agent.act(
                            obs=augmented_obs,
                            deterministic=self.args.eval_deterministic,
                            return_log_prob=True)

                    # observe reward and next obs
                    next_obs, reward, done, info = utl.env_step(
                        self.env, action.squeeze(dim=0))
                    running_reward += reward.item()
                    # done_rollout = False if ptu.get_numpy(done[0][0]) == 0. else True
                    # update encoding
                    task_sample, task_mean, task_logvar, hidden_state = self.update_encoding(
                        obs=next_obs,
                        action=action,
                        reward=reward,
                        done=done,
                        hidden_state=hidden_state)
                    rewards[task, step] = reward.item()
                    reward_preds[task, step] = ptu.get_numpy(
                        self.vae.reward_decoder(task_sample, next_obs, obs,
                                                action)[0, 0])

                    # This part is very specific for the Semi-Circle env
                    # if self.args.env_name == 'PointRobotSparse-v0':
                    #     reward_belief[task, step] = ptu.get_numpy(
                    #         self.vae.compute_belief_reward(task_mean, task_logvar, obs, next_obs, action)[0])
                    #
                    #     reward_belief_discretized[task, step, :] = ptu.get_numpy(
                    #         self.vae.compute_belief_reward(task_mean.repeat(n_grid_points, 1),
                    #                                        task_logvar.repeat(n_grid_points, 1),
                    #                                        None,
                    #                                        torch.cat((ptu.FloatTensor(centers),
                    #                                                   ptu.zeros(centers.shape[0], 1)), dim=-1).unsqueeze(0),
                    #                                        None)[:, 0])

                    observations[task, step + 1, :] = ptu.get_numpy(
                        next_obs[0, :obs_size])
                    if self.args.policy != 'dqn':
                        log_probs[task, step] = ptu.get_numpy(log_prob[0])

                    prediction = ptu.get_numpy(
                        self.vae.state_decoder(
                            task_sample.expand((1, 30, task_sample.shape[-1])),
                            ptu.FloatTensor(circle_states).unsqueeze(0),
                            ptu.FloatTensor(circle_actions).unsqueeze(
                                0))).squeeze()
                    for i in range(30):
                        state_preds[task, step, i] = 1 if np.linalg.norm(
                            prediction[i, :]) > 1 else 0

                    if "is_goal_state" in dir(
                            self.env.unwrapped
                    ) and self.env.unwrapped.is_goal_state():
                        success_rate[task] = 1.
                    # set: obs <- next_obs
                    obs = next_obs.clone()
                    step += 1

                returns_per_episode[task, episode_idx] = running_reward

        if self.args.policy == 'dqn':
            return returns_per_episode, success_rate, observations, rewards, reward_preds, state_preds
        # This part is very specific for the Semi-Circle env
        # elif self.args.env_name == 'PointRobotSparse-v0':
        #     return returns_per_episode, success_rate, log_probs, observations, \
        #            rewards, reward_preds, reward_belief, reward_belief_discretized, centers
        else:
            return returns_per_episode, success_rate, log_probs, observations, rewards, reward_preds, state_preds