Ejemplo n.º 1
0
 def predict(self, obs_np):
     """Predict per-step returns for a batch of observation paths.

     Args:
         obs_np: Array whose ``shape`` unpacks to
             (batch, path_len, obs_dim).

     Returns:
         A zero ``Variable`` of shape (batch, path_len) while no
         coefficients are stored on ``self._coeffs``; otherwise the dot
         product of the features with the coefficients, reshaped to
         (-1, path_len) and wrapped by ``np_to_var``.
     """
     n_paths, horizon, feat_dim = obs_np.shape
     flat_obs = obs_np.reshape(-1, feat_dim)

     # Without coefficients there is nothing to predict with yet.
     if self._coeffs is None:
         zeros = torch.zeros((n_paths, horizon))
         return Variable(zeros)

     features = self._features(flat_obs)
     predicted = features.dot(self._coeffs)
     return np_to_var(predicted.reshape((-1, horizon)))
Ejemplo n.º 2
0
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 num_layers,
                 output_dim,
                 log_var_network,
                 init=xavier_init,
                 scale_final=False,
                 min_var=1e-4,
                 obs_filter=None):
        """Build the LSTM trunk, the linear head and the log-variance network.

        Args:
            input_dim: Input feature size passed to ``nn.LSTM``.
            hidden_dim: LSTM hidden size; also the input size of ``linear``.
            num_layers: Number of stacked LSTM layers.
            output_dim: Output size of the ``linear`` head.
            log_var_network: Module stored as ``self.log_var_network``.
            init: Callable handed to ``self.apply`` to initialise weights.
            scale_final: If true, multiply the weights of ``self.linear`` by
                0.01 after initialisation.
            min_var: Variance floor; its log is stored in ``min_log_var``.
            obs_filter: Stored unchanged on ``self.obs_filter``.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.rnn = RNN(self.lstm, hidden_dim)
        self.softmax = nn.Softmax()
        self.linear = nn.Linear(hidden_dim, output_dim)

        self.log_var_network = log_var_network
        self.modules = [self.rnn, self.linear, self.log_var_network]

        self.obs_filter = obs_filter
        self.min_log_var = np_to_var(
            np.log(np.array([min_var])).astype(np.float32))

        self.apply(init)
        if scale_final:
            # The previous code scaled ``self.mean_network``, an attribute
            # this constructor never creates, so ``scale_final=True`` raised
            # AttributeError.  The last layer built here is ``self.linear``
            # (presumably the mean head -- confirm against ``forward``).
            self.linear.weight.data.mul_(0.01)
Ejemplo n.º 3
0
    def train_pd_match_sd(self, dataset, bs, itr, outer_itr):
        """Train the policy so that its rollouts match the state decoder.

        A batch of expert trajectories supplies the inputs for
        ``self.forward``, latents are drawn from a standard normal, and the
        log-likelihood of the policy observations under the state-decoder
        distribution is used as the policy reward.

        Args:
            dataset: Object whose ``sample(bs)`` returns a pair; the first
                element is the expert trajectory batch.
            bs: Batch size.
            itr: Unused here.
            outer_itr: Unused here.

        Returns:
            ``pd_traj['stats']`` with the keys 'mse_sd_pd' and 'll' added.
        """
        sampler = self.sampler
        expert_batch, _ = dataset.sample(bs)

        # Expert data is only used to initialise the trajectories.
        init_obs, _expert_actions = self.vae.splitobs(FloatTensor(expert_batch))

        latent = Variable(torch.randn((bs, self.latent_dim)))
        pd_traj, sd_traj = self.forward(sampler, init_obs, latent)
        decoder_obs_np = get_numpy(sd_traj.mle)

        shape_3d = (bs, -1, self.obs_dim)

        # Policy observations without the initial state.
        policy_obs = np_to_var(pd_traj['obs'][:, 1:])

        log_lik = sd_traj.reshape(shape_3d).log_likelihood(policy_obs)
        mse_sd_pd = self.compute_traj_mse(policy_obs, sd_traj.mle, shape_3d)

        # The decoder log-likelihood acts as the reward signal.
        pd_traj['rewards'] = get_numpy(log_lik)

        self.policy_algo.process_samples(
            0, pd_traj, augment_obs=get_numpy(latent))
        self.policy_algo.optimize_policy(0, pd_traj)

        # Kept for parity with the original; not used further below.
        _traj_sets = [decoder_obs_np, pd_traj['obs'][:, 1:]]

        stats = pd_traj['stats']
        stats['mse_sd_pd'] = get_numpy(mse_sd_pd.mean()).item()
        stats['ll'] = np.mean(get_numpy(log_lik))
        return stats
Ejemplo n.º 4
0
    def plot_compare(self, dataset, itr):
        """Plot expert, state-decoder and policy trajectories side by side.

        ``self.plot_size`` trajectories are drawn with
        ``dataset.sample_hard``, encoded deterministically, and passed
        through ``self.forward``.  The three resulting trajectory sets are
        plotted together, then up to five individual trajectories are
        plotted via ``dataset.plot_pd_compare``.

        Args:
            dataset: Provides ``sample_hard``, ``process``, ``env_id`` and
                ``plot_pd_compare``.
            itr: Iteration index forwarded to the plotting helpers.
        """
        trajs, _ = dataset.sample_hard(self.plot_size)
        x, _ = self.vae.splitobs(np_to_var(trajs))
        latent = self.vae.encode(x).sample(deterministic=True)

        pd_traj, sd_traj = self.forward(self.plot_sampler, FloatTensor(trajs),
                                        latent)

        # Bring all three sets into (plot_size, max_path_length, -1).
        traj_shape = (self.plot_size, self.max_path_length, -1)
        traj_sets = [
            traj.reshape(traj_shape)
            for traj in (
                get_numpy(x)[:, self.step_dim:],
                get_numpy(sd_traj.mle),
                pd_traj['obs'][:, 1:],
            )
        ]
        traj_names = ['expert', 'sd', 'pd']
        plot_traj_sets([dataset.process(traj) for traj in traj_sets],
                       traj_names,
                       itr,
                       env_id=dataset.env_id)

        # Each set holds exactly ``plot_size`` trajectories, so cap the
        # number of per-trajectory plots; a fixed 5 raised IndexError
        # whenever plot_size < 5.
        for traj_no in range(min(5, self.plot_size)):
            dataset.plot_pd_compare([traj[traj_no, ...] for traj in traj_sets],
                                    traj_names,
                                    itr,
                                    name='Full_State_%d' % traj_no,
                                    save_dir='pd_match_expert')
        self.zero_grad()
Ejemplo n.º 5
0
    def compute_loss_terms(self, sd):
        """Compute the surrogate loss and the mean KL for a sample batch.

        Args:
            sd: Sample dict read for the keys 'obs_flat', 'action_dist',
                'actions' and 'discount_adv'.

        Returns:
            Tuple ``(surr_loss, mean_kl)``.
        """
        flat_obs = np_to_var(sd['obs_flat'])

        # Detach the sampling-time distribution so it acts as a constant.
        sampled_dist = sd['action_dist']
        old_dist = sampled_dist.detach().reshape((-1, sampled_dist.dim))
        new_dist = self.policy.forward(flat_obs)

        mean_kl = old_dist.kl(new_dist).mean(0)

        flat_actions = sd['actions'].view(-1, old_dist.dim)
        ratio = old_dist.log_likelihood_ratio(flat_actions, new_dist)

        advantages = sd['discount_adv']
        weighted = ratio.view(advantages.shape) * np_to_var(advantages)
        surr_loss = (-weighted).sum(-1).mean(0)

        return surr_loss, mean_kl
Ejemplo n.º 6
0
 def sample(self, deterministic=False):
     """Draw a sample from the categorical distribution.

     Args:
         deterministic: If true, return the arg-max category index with a
             trailing singleton dimension instead of sampling.

     Returns:
         When ``deterministic`` is true, the indices of the most likely
         category.  Otherwise a one-hot encoding of multinomial draws,
         shaped like ``self.probs_3d`` and wrapped by ``np_to_var``.
         Note that the two branches return different encodings.
     """
     if deterministic:
         # Fixed: this branch read ``self.prob_3d`` while every other
         # access in this method uses ``self.probs_3d``.
         return self.probs_3d.max(-1)[1].unsqueeze(-1)

     cat_size = self.probs_3d.size()[-1]
     n_rows = self.bs * self.path_len
     idx = torch.multinomial(self.prob.view(-1, cat_size), 1)

     onehot = np.zeros((n_rows, cat_size))
     # squeeze(-1) drops only the sample axis, so a single-row batch still
     # yields a 1-D index vector.
     onehot[np.arange(n_rows), get_numpy(idx.squeeze(-1))] = 1
     return np_to_var(onehot.reshape(self.probs_3d.size()))
Ejemplo n.º 7
0
    def plot_interp(self, dataset, itr):
        """Plot decodings of latents interpolated between two samples.

        Two trajectories are sampled and encoded deterministically; seven
        latents are linearly interpolated between the two codes, decoded
        from the first trajectory's input, un-normalised and plotted.

        Args:
            dataset: Provides ``sample``, ``unnormalize`` and
                ``plot_pd_compare``.
            itr: Iteration index forwarded to the plotting helper.
        """
        batch = dataset.sample(2)[0]

        x1 = np_to_var(np.expand_dims(batch[0, ...], 0))
        x2 = np_to_var(np.expand_dims(batch[1, ...], 0))

        l1 = get_numpy(self.encode(x1).sample(deterministic=True))
        l2 = get_numpy(self.encode(x2).sample(deterministic=True))

        num_interp = 7
        steps = np.linspace(0, 1, num_interp)
        latents = np.zeros((num_interp, self.latent_dim))
        # Interpolate every latent dimension independently.
        for dim in range(self.latent_dim):
            endpoints = [l1[0, dim], l2[0, dim]]
            latents[:, dim] = np.interp(steps, [0, 1], endpoints)

        start_input = x1.repeat(num_interp, 1).data
        decoded = self.decode(start_input, np_to_var(latents))
        traj = dataset.unnormalize(get_numpy(decoded.mle))

        traj_sets = [traj[k, ...] for k in range(num_interp)]
        dataset.plot_pd_compare(traj_sets, range(num_interp), itr,
                                save_dir='interp')
0
    def test_pd(self, dataset, lim=-1):
        """Measure how closely policy rollouts reproduce dataset trajectories.

        Each training trajectory is encoded, a latent is sampled, and the
        policy is rolled out in ``self.env`` conditioned on that latent,
        starting from the trajectory's first observation.

        Args:
            dataset: Object exposing ``train_data``.
            lim: If not -1, only the first ``lim`` entries are evaluated.

        Returns:
            Mean over trajectories of the norm of the difference between
            the dataset observations and the rolled-out observations.
        """

        def rollout(env,
                    policy,
                    max_path_length,
                    add_input=None,
                    volatile=False,
                    reset_args=None):
            """Roll ``policy`` out in ``env`` for ``max_path_length`` steps."""
            obs_lst, reward_lst, action_lst = [], [], []
            obs = env.reset(reset_args)
            for step in range(max_path_length):
                obs_batch = from_numpy(np.array([obs])).float()
                policy_input = Variable(obs_batch, volatile=volatile)

                if add_input is not None:
                    policy_input = torch.cat([policy_input, add_input], -1)
                if step == 0:
                    policy.reset(1)
                if policy.recurrent():
                    policy_input = policy_input.unsqueeze(0)
                action = policy.forward(policy_input).sample()

                step_result = env.step(get_numpy(action))
                obs_lst.append(obs)
                reward_lst.append(step_result[1])
                action_lst.append(action)
                obs = step_result[0]
            # Also record the observation reached after the last action.
            obs_lst.append(obs)

            return dict(obs=np.array(obs_lst),
                        rewards=np.array(reward_lst),
                        actions=torch.stack(action_lst, 1),
                        action_dist_lst=[])

        if lim == -1:
            limdata = dataset.train_data
        else:
            limdata = dataset.train_data[:lim]

        errs = []
        for data in limdata:
            # Keep only the observation columns and add a batch axis.
            steps = data.reshape((-1, self.act_dim + self.obs_dim))
            traj = steps[None, :, :self.obs_dim]
            horizon = traj.shape[1]

            z = self.encode(np_to_var(traj)).sample()
            ro = rollout(self.env, self.policy, horizon, z,
                         reset_args=traj[0, 0])
            errs.append(np.linalg.norm(traj - ro['obs'][:horizon]))

        return np.mean(errs)
Ejemplo n.º 9
0
def rollout(policy, env, max_path_length, add_input=None, plot=False):
    """Roll a policy out in an environment for a fixed number of steps.

    Args:
        policy: Object whose ``forward`` returns a distribution with
            ``sample``.
        env: Environment with ``reset``, ``step`` and ``render``.
        max_path_length: Number of steps to take; the ``done`` flag from
            ``env.step`` is ignored.
        add_input: Optional tensor concatenated to every observation.
        plot: If true, call ``env.render()`` after every step.

    Returns:
        Dict with lists under 'obs', 'rewards', 'actions' and
        'action_dist_lst', one entry per step.
    """
    obs = env.reset()
    sd = dict(obs=[], rewards=[], actions=[], action_dist_lst=[])

    for _ in range(max_path_length):
        if add_input is None:
            policy_input = np_to_var(obs).unsqueeze(0)
        else:
            joined = torch.cat([np_to_var(obs), add_input], -1)
            policy_input = joined.view(1, -1)

        action_dist = policy.forward(policy_input)
        action = action_dist.sample()
        next_obs, reward, _done, _info = env.step(
            get_numpy(action.squeeze()))

        step_record = (('obs', obs), ('rewards', reward),
                       ('actions', action), ('action_dist_lst', action_dist))
        for key, value in step_record:
            sd[key].append(value)
        obs = next_obs

        if plot:
            env.render()

    return sd
Ejemplo n.º 10
0
    def __init__(self,
                 mean_network,
                 log_var_network,
                 prob_network,
                 recurrent_network,
                 path_len,
                 output_dim,
                 gaussian_output_dim,
                 cat_output_dim,
                 min_var=1e-4):
        """Store the sub-networks and output sizes of the decoder.

        Args:
            mean_network: Stored as ``self.mean_network``.
            log_var_network: Stored as ``self.log_var_network``.
            prob_network: Stored as ``self.prob_network``.
            recurrent_network: Stored as ``self.recurrent_network``.
            path_len: Stored as ``self.path_len``.
            output_dim: Stored as ``self.output_dim``.
            gaussian_output_dim: Stored as ``self.gaussian_output_dim``.
            cat_output_dim: One-hot size, stored as ``self.cat_output_dim``.
            min_var: Variance floor; its log is stored in ``min_log_var``.
        """
        super().__init__()
        self.mean_network = mean_network
        self.log_var_network = log_var_network
        self.prob_network = prob_network
        self.recurrent_network = recurrent_network
        self.modules = [
            self.mean_network,
            self.log_var_network,
            self.recurrent_network,
            self.prob_network,
        ]

        log_floor = np.log(np.array([min_var]))
        self.min_log_var = np_to_var(log_floor.astype(np.float32))
        self.apply(xavier_init)

        self.gaussian_output_dim = gaussian_output_dim
        self.cat_output_dim = cat_output_dim  # Onehot size
        self.output_dim = output_dim
        self.path_len = path_len
Ejemplo n.º 11
0
 def rollout_meta(self, latents, cur_obs, reward_fn, rstate):
     """Chain decoder rollouts over a sequence of latents and score them.

     Each latent is decoded starting from the last state of the previous
     segment; the segments are concatenated along axis 1 and passed to
     ``self.eval_rewards`` with ``discount=True``.

     Args:
         latents: Iterable of latent arrays; ``latents.shape[1]`` is used
             as the batch size.
         cur_obs: Start states; ``cur_obs.shape[1]`` is used as the
             observation size.
         reward_fn: Forwarded to ``self.eval_rewards``.
         rstate: Reward state forwarded to ``self.eval_rewards``.

     Returns:
         Tuple ``(rewards, combo_traj)``.
     """
     nbatch = latents.shape[1]
     state = cur_obs
     segments = []
     for latent in latents:
         latent_var = np_to_var(latent)
         state_tensor = from_numpy(state).float()
         sd_traj = self.vae.decode(state_tensor, latent_var)
         self.vae.decoder.zero_grad()

         segment_shape = (nbatch, -1, cur_obs.shape[1])
         segment = get_numpy(sd_traj.mle).reshape(segment_shape)
         segments.append(segment)
         # The next segment continues from where this one ended.
         state = segment[:, -1]

     combo_traj = np.concatenate(segments, axis=1)
     rewards, _ = self.eval_rewards(combo_traj,
                                    reward_fn,
                                    rstate,
                                    discount=True)
     return rewards, combo_traj
Ejemplo n.º 12
0
    def __init__(self,
                 mean_network,
                 log_var_network,
                 init=xavier_init,
                 scale_final=False,
                 min_var=1e-4,
                 obs_filter=None):
        """Store the mean and log-variance networks and initialise weights.

        Args:
            mean_network: Stored as ``self.mean_network``.
            log_var_network: Stored as ``self.log_var_network``.
            init: Callable handed to ``self.apply``.
            scale_final: If true and ``mean_network`` has a ``network``
                attribute, multiply its ``finallayer`` weights by 0.01.
            min_var: Variance floor; its log is stored in ``min_log_var``.
            obs_filter: Stored unchanged on ``self.obs_filter``.
        """
        super().__init__()
        self.mean_network = mean_network
        self.log_var_network = log_var_network
        self.modules = [self.mean_network, self.log_var_network]
        self.obs_filter = obs_filter

        log_floor = np.log(np.array([min_var]))
        self.min_log_var = np_to_var(log_floor.astype(np.float32))

        self.apply(init)
        if scale_final and hasattr(self.mean_network, 'network'):
            final_layer = self.mean_network.network.finallayer
            final_layer.weight.data.mul_(0.01)
Ejemplo n.º 13
0
    def forward(self, z, initial_input=None):
        """Unroll the recurrent decoder into a categorical distribution.

        Args:
            z: Latent tensor, (bs, latent_dim).
            initial_input: Optional first input; when omitted it is taken
                from ``self.init_input(bs)``.

        Returns:
            A ``RecurrentCategorical`` built from the stacked per-step
            probabilities and the one-hot encoding of the per-step argmax.
        """
        bs = z.size(0)
        self.recurrent_network.init_hidden(bs)
        if initial_input is None:
            initial_input = self.init_input(bs)

        z = z.unsqueeze(0)  # (1, bs, latent_dim)
        step_input = initial_input.unsqueeze(0)  # (1, bs, sum(cat_sizes))

        probs, argmaxs = [], []
        for _ in range(self.path_len):
            prob, argmax = self.step(torch.cat([step_input, z], -1))
            probs.append(prob)
            argmaxs.append(argmax)
            # The predicted probabilities become the next step's input.
            step_input = prob.unsqueeze(0)

        probs = torch.stack(probs, 1)  # (bs, path_len, sum(cat_sizes))
        argmaxs = torch.stack(argmaxs, 1)  # (bs, path_len, len(cat_sizes))

        identity = np.eye(self.output_dim)
        onehot = np_to_var(identity[get_numpy(argmaxs)])
        return RecurrentCategorical(probs, self.path_len, onehot)
Ejemplo n.º 14
0
    def forward(self, z, initial_input=None):
        """Unroll the decoder into a mixed Gaussian/categorical distribution.

        At every step the previous output (mean concatenated with the
        one-hot of the most likely category) is fed back in together with
        the latent.

        Args:
            z: Latent tensor, (bs, latent_dim).
            initial_input: Initial observation, (bs, obs_dim); when omitted
                it is taken from ``self.init_input(bs)``.

        Returns:
            ``Mixed(gauss_dist, cat_dist, self.path_len)``.
        """
        bs = z.size(0)
        self.recurrent_network.init_hidden(bs)
        if initial_input is None:
            initial_input = self.init_input(bs)

        z = z.unsqueeze(0)  # (1, bs, latent_dim)
        x = initial_input.unsqueeze(0)  # (1, bs, obs_dim)

        batch_rows = np.arange(bs)
        means, log_vars, probs, onehots = [], [], [], []
        for _ in range(self.path_len):
            mean, log_var, prob = self.step(torch.cat([x, z], -1))

            # One-hot encode the most likely category of every batch row.
            winners = get_numpy(
                torch.max(prob.squeeze(0), -1)[1]).astype(np.int32)
            onehot_np = np.zeros(prob.size()[1:])
            onehot_np[batch_rows, winners] = 1
            onehot = np_to_var(onehot_np).unsqueeze(0)

            x = torch.cat([mean, onehot], -1)

            per_step = ((means, mean), (log_vars, log_var), (probs, prob),
                        (onehots, onehot))
            for store, value in per_step:
                store.append(value.squeeze(0))

        def flatten(seq):
            """Stack per-step tensors on axis 1 and flatten per batch row."""
            return torch.stack(seq, 1).view(bs, -1)

        means = flatten(means)
        log_vars = flatten(log_vars)
        probs = flatten(probs)
        onehots = flatten(onehots)

        gauss_dist = Normal(means, log_var=log_vars)
        cat_dist = RecurrentCategorical(probs, self.path_len, onehots)

        return Mixed(gauss_dist, cat_dist, self.path_len)
Ejemplo n.º 15
0
    def predict(self, obs_np):
        """Run the network on a batch of observation paths.

        Args:
            obs_np: Array whose ``shape`` unpacks to
                (batch, path_len, obs_dim).

        Returns:
            The network output viewed as (-1, path_len).
        """
        _, path_len, obs_dim = obs_np.shape
        flat_obs = obs_np.reshape(-1, obs_dim).astype(np.float32)
        values = self.network(np_to_var(flat_obs))
        return values.view(-1, path_len)
Ejemplo n.º 16
0
 def loss(self, sd):
     """Return the negated advantage-weighted log-probability loss.

     Args:
         sd: Sample dict read for the keys 'log_prob' and 'discount_adv'.

     Returns:
         The loss summed over the last axis and averaged over axis 0.
     """
     log_prob = sd['log_prob']
     advantages = np_to_var(sd['discount_adv'])
     per_step = -(log_prob * advantages)
     return per_step.sum(-1).mean(0)
Ejemplo n.º 17
0
    def optimize_policy(self, itr, samples_data):
        """Optimise the policy with L-BFGS-B under an adaptive penalty.

        For up to ``self._max_penalty_itr`` rounds the penalised objective
        is minimised starting from the current parameters, the resulting
        constraint value is compared with ``self._max_constraint_val``, and
        the penalty is scaled up or down accordingly.  The parameters chosen
        by this search are written back to the policy at the end.

        Args:
            itr: Iteration index; not used in this method.
            samples_data: Sample dict read for 'obs_flat', 'action_dist',
                'actions' and 'discount_adv'.  Flattened views of these are
                written back into it under 'obs_flat_var',
                'action_dist_flat', 'actions_flat' and 'discount_adv_var'.
        """
        # Start from the stored penalty, kept inside its allowed range.
        try_penalty = float(
            np.clip(self._penalty, self._min_penalty, self._max_penalty))
        # Direction of the penalty search; chosen on the first round.
        penalty_scale_factor = None

        def gen_f_opt(penalty):
            # Build the objective handed to the optimiser for one penalty.
            def f(flat_params):
                # Every evaluation overwrites the policy parameters.
                self.policy.set_params_flat(from_numpy(flat_params))
                return self.get_opt_output(samples_data, penalty)

            return f

        cur_params = get_numpy(self.policy.get_params_flat().double())
        # Fall back to the current parameters if no round is accepted.
        opt_params = cur_params

        # Save views of objs for efficiency
        samples_data['obs_flat_var'] = np_to_var(samples_data['obs_flat'])
        samples_data['action_dist_flat'] = samples_data['action_dist'].detach(
        ).reshape((-1, samples_data['action_dist'].dim))
        samples_data['actions_flat'] = samples_data['actions'].view(
            -1, self.action_dim)
        samples_data['discount_adv_var'] = np_to_var(
            samples_data['discount_adv'])

        for penalty_itr in range(self._max_penalty_itr):
            logger.log('trying penalty=%.3f...' % try_penalty)

            # Each round restarts from cur_params, not from the last result.
            itr_opt_params, _, _ = scipy.optimize.fmin_l_bfgs_b(
                func=gen_f_opt(try_penalty),
                x0=cur_params,
                maxiter=self._max_opt_itr)

            # NOTE(review): the policy holds the parameters of the
            # optimiser's last objective evaluation here, which is not
            # necessarily itr_opt_params -- confirm this is intended.
            _, try_loss, try_constraint_val = self.compute_loss_terms(
                samples_data, try_penalty)
            try_loss = get_numpy(try_loss)[0]
            try_constraint_val = get_numpy(try_constraint_val)[0]

            logger.log('penalty %f => loss %f, %s %f' %
                       (try_penalty, try_loss, self._constraint_name,
                        try_constraint_val))

            # Accept this round's parameters when the constraint holds.
            # NOTE(review): opt_params is initialised to cur_params above,
            # so the ``opt_params is None`` clause can never be true.
            if try_constraint_val < self._max_constraint_val or \
                    (penalty_itr == self._max_penalty_itr - 1 and opt_params is None):
                opt_params = itr_opt_params

            if not self._adapt_penalty:
                break

            # Decide scale factor on the first iteration, or if constraint violation yields numerical error
            if penalty_scale_factor is None or np.isnan(try_constraint_val):
                # Increase penalty if constraint violated, or if constraint term is NAN
                if try_constraint_val > self._max_constraint_val or np.isnan(
                        try_constraint_val):
                    penalty_scale_factor = self._increase_penalty_factor
                else:
                    # Otherwise (i.e. constraint satisfied), shrink penalty
                    penalty_scale_factor = self._decrease_penalty_factor
                    opt_params = itr_opt_params
            else:
                # Stop once the search has crossed the constraint boundary
                # in the direction it was moving.
                if penalty_scale_factor > 1 and \
                                try_constraint_val <= self._max_constraint_val:
                    break
                elif penalty_scale_factor < 1 and \
                                try_constraint_val >= self._max_constraint_val:
                    break
            try_penalty *= penalty_scale_factor
            try_penalty = float(
                np.clip(try_penalty, self._min_penalty, self._max_penalty))
            # Remember the penalty for the next call.
            self._penalty = try_penalty

        self.policy.set_params_flat(from_numpy(opt_params))
Ejemplo n.º 18
0
 def compute_mle(self):
     """Return the one-hot encoding of the most likely category per row.

     Returns:
         An array shaped like ``self.prob`` with a 1 at each row's argmax,
         wrapped by ``np_to_var``.
     """
     winners = get_numpy(torch.max(self.prob, -1)[1]).astype(np.int32)
     rows = np.arange(0, self.bs)

     onehot = np.zeros(self.prob.size())
     onehot[rows, winners] = 1
     return np_to_var(onehot)
Ejemplo n.º 19
0
 def loss(self, sd):
     """Return the policy-gradient loss with an entropy bonus.

     Args:
         sd: Sample dict read for the keys 'log_prob', 'discount_adv' and
             'entropy'.

     Returns:
         The loss summed over the last axis and averaged over axis 0.
     """
     # Advantages are truncated to the first ``max_path_length`` columns.
     advantages = np_to_var(sd['discount_adv'])[:, :self.max_path_length]
     policy_term = -(sd['log_prob'] * advantages)
     entropy_term = self.entropy_bonus * sd['entropy']
     total = policy_term - entropy_term
     return total.sum(-1).mean(0)
Ejemplo n.º 20
0
    def train_explorer(self, dataset, test_dataset, dummy_dataset, itr):
        """Run one MPC + exploration iteration and train the explorer.

        The method
        1. plans with MPC for ``self.max_horizon`` steps, executing the first
           latent of each plan in the simulator,
        2. rolls the explorer policy out from a random subset of the states
           MPC visited,
        3. adds the explorer trajectories to the train, validation and dummy
           datasets,
        4. rewards the explorer with the VAE's negative ELBO (plus an
           optional scaled true reward) and optimises it, and
        5. saves a scatter plot and an ``.npy`` dump of the explorer's final
           observations under the logger's snapshot directory.

        Args:
            dataset: Training dataset; receives new samples and is used for
                plotting.
            test_dataset: Validation dataset; receives the tail of the new
                samples.
            dummy_dataset: Cleared and refilled with this iteration's
                non-validation samples.
            itr: Iteration index, used in log prefixes and file names.  When
                it equals 1 the explorer gets zero reward and is not
                optimised.

        Returns:
            ``ex_trajs['stats']`` with 'MPC Actual' and 'Final Reward' added.
        """
        # load fixed initial state and goals from config
        init_state = self.block_config[0]
        goals = np.array(self.block_config[1])

        # functions for computing the reward and initializing the reward state (rstate)
        # rstate is used to keep track of things such as which goal you are currently on
        reward_fn, init_rstate = self.reward_fn

        # total actual reward collected by MPC agent so far
        total_mpc_rew = np.zeros(self.mpc_batch)

        # keep track of states visited by MPC to initialize the explorer from
        all_inits = []

        # current state of the mpc batch
        cur_state = np.array([init_state] * self.mpc_batch)

        # initialize the reward state for the mpc batch
        rstate = init_rstate(self.mpc_batch)

        # for visualization purposes
        mpc_preds = []
        mpc_actual = []
        mpc_span = []
        rstates = []

        # Perform MPC over max_horizon
        for T in range(self.max_horizon):
            print(T)

            # for goal visualization
            rstates.append(rstate)

            # rollout imaginary trajectories using state decoder
            rollouts = self.mpc(cur_state,
                                min(self.plan_horizon,
                                    self.max_horizon - T), self.mpc_explore,
                                self.mpc_explore_batch, reward_fn, rstate)

            # get first latent of best trajectory for each batch
            np_latents = rollouts[2][:, 0]

            # rollout the first latent in simulator
            mpc_traj = self.sampler_mpc.obtain_samples(self.mpc_batch *
                                                       self.max_path_length,
                                                       self.max_path_length,
                                                       np_to_var(np_latents),
                                                       reset_args=cur_state)

            # update reward and reward state based on trajectory from simulator
            mpc_rew, rstate = self.eval_rewards(mpc_traj['obs'], reward_fn,
                                                rstate)

            # for logging and visualization purposes
            futures = rollouts[0] + total_mpc_rew
            total_mpc_rew += mpc_rew
            mpc_preds.append(rollouts[1][0])
            mpc_span.append(rollouts[3])
            mpc_stats = {
                'mean futures': np.mean(futures),
                'std futures': np.std(futures),
                'mean actual': np.mean(total_mpc_rew),
                'std actual': np.std(total_mpc_rew),
            }
            mpc_actual.append(mpc_traj['obs'][0])
            with logger.prefix('itr #%d mpc step #%d | ' % (itr, T)):
                self.vae.print_diagnostics(mpc_stats)
            record_tabular(mpc_stats, 'mpc_stats.csv')

            # add current state to list of states explorer can initialize from
            all_inits.append(cur_state)

            # update current state to current state of simulator
            cur_state = mpc_traj['obs'][:, -1]

        # for visualization
        for idx, (actual, pred, rs, span) in enumerate(
                zip(mpc_actual, mpc_preds, rstates, mpc_span)):
            dataset.plot_pd_compare(
                [actual, pred, span[:100], span[:100, :dataset.path_len]],
                ['actual', 'pred', 'imagined', 'singlestep'],
                itr,
                save_dir='mpc_match',
                name='Pred' + str(idx),
                goals=goals,
                goalidx=rs[0])

        # compute reward at final state, for some tasks that care about final state reward
        final_reward, _ = reward_fn(cur_state, rstate)
        print(total_mpc_rew)
        print(final_reward)

        # randomly select states for explorer to explore; sample with
        # replacement only when more states are requested than exist
        start_states = np.concatenate(all_inits, axis=0)
        start_states = start_states[np.random.choice(
            start_states.shape[0],
            self.rand_per_mpc_step,
            replace=self.rand_per_mpc_step > start_states.shape[0])]

        # run the explorer from those states
        explore_len = ((self.max_path_length + 1) * self.mpc_explore_len) - 1
        self.policy_ex_algo.max_path_length = explore_len
        ex_trajs = self.sampler_ex.obtain_samples(start_states.shape[0] *
                                                  explore_len,
                                                  explore_len,
                                                  None,
                                                  reset_args=start_states)

        # Now concat actions taken by explorer with observations for adding to the dataset
        trajs = ex_trajs['obs']
        # final observation of every explorer trajectory (plotted below)
        obs = trajs[:, -1]
        if hasattr(self.action_space,
                   'shape') and len(self.action_space.shape) > 0:
            acts = get_numpy(ex_trajs['actions'])
        else:
            # convert discrete actions into onehot
            act_idx = get_numpy(ex_trajs['actions'])
            acts = np.zeros(
                (trajs.shape[0], trajs.shape[1] - 1, dataset.action_dim))
            # acts_reshape is used as a view of acts, so writing through
            # it fills acts in place
            acts_reshape = acts.reshape((-1, dataset.action_dim))
            acts_reshape[range(acts_reshape.shape[0]),
                         act_idx.reshape(-1)] = 1.0

        # concat actions with obs; the last action is repeated so that the
        # action sequence is as long as the observation sequence
        acts = np.concatenate((acts, acts[:, -1:, :]), 1)
        trajacts = np.concatenate((ex_trajs['obs'], acts), axis=-1)
        trajacts = trajacts.reshape(
            (-1, self.max_path_length + 1, trajacts.shape[-1]))

        # compute train/val split
        ntrain = min(int(0.9 * trajacts.shape[0]),
                     dataset.buffer_size // self.add_frac)
        if dataset.n < dataset.batch_size and ntrain < dataset.batch_size:
            ntrain = dataset.batch_size
        nvalid = min(trajacts.shape[0] - ntrain,
                     test_dataset.buffer_size // self.add_frac)
        if test_dataset.n < test_dataset.batch_size and nvalid < test_dataset.batch_size:
            nvalid = test_dataset.batch_size

        print("Adding ", ntrain, ", Valid: ", nvalid)

        dataset.add_samples(trajacts[:ntrain].reshape((ntrain, -1)))
        test_dataset.add_samples(trajacts[-nvalid:].reshape((nvalid, -1)))

        # dummy dataset stores only data from this iteration
        dummy_dataset.clear()
        dummy_dataset.add_samples(trajacts[:-nvalid].reshape(
            (trajacts.shape[0] - nvalid, -1)))

        # compute negative ELBO on trajectories of explorer
        neg_elbos = []
        cur_batch = from_numpy(trajacts).float()
        for i in range(0, trajacts.shape[0], self.batch_size):
            mse, neg_ll, kl, bcloss, z_dist = self.vae.forward_batch(
                cur_batch[i:i + self.batch_size])
            neg_elbo = (get_numpy(neg_ll) + get_numpy(kl))
            neg_elbos.append(neg_elbo)

        # reward the explorer
        rewards = np.zeros_like(ex_trajs['rewards'])
        neg_elbos = np.concatenate(neg_elbos, axis=0)
        neg_elbos = neg_elbos.reshape((rewards.shape[0], -1))
        # just not on the first iteration, since the VAE has not been fitted yet
        if itr != 1:
            # reward is placed on the last step of every sub-trajectory
            rewidx = list(
                range(self.max_path_length, explore_len,
                      self.max_path_length + 1)) + [explore_len - 1]
            for i in range(rewards.shape[0]):
                rewards[i, rewidx] = neg_elbos[i]

            # add in true reward to explorer if desired
            if self.true_reward_scale != 0:
                rstate = init_rstate(rewards.shape[0])
                for oidx in range(rewards.shape[1]):
                    r, rstate = reward_fn(ex_trajs['obs'][:, oidx], rstate)
                    rewards[:, oidx] += r * self.true_reward_scale

        ex_trajs['rewards'] = rewards

        # train explorer using PPO with neg elbo
        self.policy_ex_algo.process_samples(
            0, ex_trajs)  #, augment_obs=get_numpy(z))
        if itr != 1:
            self.policy_ex_algo.optimize_policy(0, ex_trajs)
        ex_trajs['stats']['MPC Actual'] = np.mean(total_mpc_rew)
        ex_trajs['stats']['Final Reward'] = np.mean(final_reward)

        # reset explorer if necessary
        if ex_trajs['stats']['Entropy'] < self.reset_ent:
            if hasattr(self.policy_ex, "prob_network"):
                self.policy_ex.prob_network.apply(xavier_init)
            else:
                self.policy_ex.apply(xavier_init)
                self.policy_ex.log_var_network.params_var.data = self.policy_ex.log_var_network.param_init

        # for visualization purposes
        path = logger.get_snapshot_dir() + '/final_dist'
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists test.
        os.makedirs(path, exist_ok=True)

        colors = ['purple', 'magenta', 'green', 'black', 'yellow', 'black']
        fig, ax = plt.subplots(3, 2, figsize=(10, 10))
        try:
            for i in range(6):
                if i * 2 + 1 < obs.shape[1]:
                    axx = ax[i // 2][i % 2]
                    if i == 5:
                        axx.scatter(obs[:, -3],
                                    obs[:, -2],
                                    color=colors[i],
                                    s=10)
                    else:
                        axx.scatter(obs[:, i * 2],
                                    obs[:, i * 2 + 1],
                                    color=colors[i],
                                    s=10)
                    axx.set_xlim(-3, 3)
                    axx.set_ylim(-3, 3)
            fig.savefig('%s/%d.png' % (path, itr))
        finally:
            # pyplot keeps every figure alive until it is closed; without
            # this, one figure leaked per call.
            plt.close(fig)
        np.save(path + "/" + str(itr), obs)

        return ex_trajs['stats']