    def update(self):
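        """Run one training update: gather rooms/returns across MPI workers,
        normalize intrinsic rewards, compute GAE for the intrinsic and extrinsic
        value heads, and optimize the policy for several epochs on the buffer."""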

        summary = tf.Summary()

        # Gather the best return, visited rooms, and episode stats across MPI workers.
        temp = sum(MPI.COMM_WORLD.allgather(self.local_rooms), [])
        temp = sorted(list(set(temp)))
        self.rooms = temp

        temp = sum(MPI.COMM_WORLD.allgather(self.scores), [])
        temp = sorted(list(set(temp)))
        self.scores = temp

        temp = sum(MPI.COMM_WORLD.allgather([self.local_best_ret]), [])
        self.best_ret = max(temp)

        eprews = MPI.COMM_WORLD.allgather(safemean(list(self.I.statlists["eprew"])))
        self.ep_rews.append(eprews[0])

        local_best_rets = MPI.COMM_WORLD.allgather(self.local_best_ret)
        n_rooms = sum(MPI.COMM_WORLD.allgather([len(self.local_rooms)]), [])

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.info(f"Rooms visited {self.rooms}")
            logger.info(f"Best return {self.best_ret}")
            logger.info(f"Best local return {sorted(local_best_rets)}")
            logger.info(f"eprews {sorted(eprews)}")
            logger.info(f"n_rooms {sorted(n_rooms)}")
            logger.info(f"Extrinsic coefficient {self.ext_coeff}")
            logger.info(f"Gamma {self.gamma}")
            logger.info(f"Gamma ext {self.gamma_ext}")
            # logger.info(f"All scores {sorted(self.scores)}")
            logger.info(f"Experiment name {self.exp_name}")
            summary.value.add(tag='Episode_mean_reward', simple_value=eprews[0])


        # Normalize intrinsic rewards.
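        # rff_int accumulates a per-env discounted sum of intrinsic rewards over time;
        # its running std (rff_rms_int) rescales the raw intrinsic rewards so their
        # scale stays roughly stationary during training (as in RND).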
        rffs_int = np.array([self.I.rff_int.update(rew) for rew in self.I.buf_rews_int.T])
        self.I.rff_rms_int.update(rffs_int.ravel())
        rews_int = self.I.buf_rews_int / np.sqrt(self.I.rff_rms_int.var)
        self.mean_int_rew = safemean(rews_int)
        self.max_int_rew = np.max(rews_int)

        # Don't normalize extrinsic rewards.
        rews_ext = self.I.buf_rews_ext

        rewmean, rewstd, rewmax = self.I.buf_rews_int.mean(), self.I.buf_rews_int.std(), np.max(self.I.buf_rews_int)

        # Calculate intrinsic returns and advantages.
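        # GAE(lambda), computed backwards in time (new_{t+1} marks an episode boundary):
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - new_{t+1}) - V(s_t)
        #   A_t     = delta_t + gamma * lam * (1 - new_{t+1}) * A_{t+1}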
        lastgaelam = 0
        for t in range(self.nsteps-1, -1, -1): # nsteps-1 ... 0
            if self.use_news:
                nextnew = self.I.buf_news[:, t + 1] if t + 1 < self.nsteps else self.I.buf_new_last
            else:
                nextnew = 0.0 # No dones for intrinsic reward.
            nextvals = self.I.buf_vpreds_int[:, t + 1] if t + 1 < self.nsteps else self.I.buf_vpred_int_last
            nextnotnew = 1 - nextnew
            delta = rews_int[:, t] + self.gamma * nextvals * nextnotnew - self.I.buf_vpreds_int[:, t]
            self.I.buf_advs_int[:, t] = lastgaelam = delta + self.gamma * self.lam * nextnotnew * lastgaelam
        rets_int = self.I.buf_advs_int + self.I.buf_vpreds_int

        # Calculate extrinsic returns and advantages.
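        # Same GAE recursion as above, but with gamma_ext and always masking on
        # episode boundaries, which the intrinsic stream skips when use_news is False.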
        lastgaelam = 0
        for t in range(self.nsteps-1, -1, -1): # nsteps-1 ... 0
            nextnew = self.I.buf_news[:, t + 1] if t + 1 < self.nsteps else self.I.buf_new_last
            # Use dones for extrinsic reward.
            nextvals = self.I.buf_vpreds_ext[:, t + 1] if t + 1 < self.nsteps else self.I.buf_vpred_ext_last
            nextnotnew = 1 - nextnew
            delta = rews_ext[:, t] + self.gamma_ext * nextvals * nextnotnew - self.I.buf_vpreds_ext[:, t]
            self.I.buf_advs_ext[:, t] = lastgaelam = delta + self.gamma_ext * self.lam * nextnotnew * lastgaelam
        rets_ext = self.I.buf_advs_ext + self.I.buf_vpreds_ext

        # Combine the extrinsic and intrinsic advantages.
        self.I.buf_advs = self.int_coeff*self.I.buf_advs_int + self.ext_coeff*self.I.buf_advs_ext

        # Collects info for reporting.
        info = dict(
            advmean = self.I.buf_advs.mean(),
            advstd  = self.I.buf_advs.std(),
            retintmean = rets_int.mean(), # previously retmean
            retintstd  = rets_int.std(), # previously retstd
            retextmean = rets_ext.mean(), # previously not there
            retextstd  = rets_ext.std(), # previously not there
            rewintmean_unnorm = rewmean, # previously rewmean
            rewintmax_unnorm = rewmax, # previously not there
            rewintmean_norm = self.mean_int_rew, # previously rewintmean
            rewintmax_norm = self.max_int_rew, # previously rewintmax
            rewintstd_unnorm  = rewstd, # previously rewstd
            vpredintmean = self.I.buf_vpreds_int.mean(), # previously vpredmean
            vpredintstd  = self.I.buf_vpreds_int.std(), # previously vrpedstd
            vpredextmean = self.I.buf_vpreds_ext.mean(), # previously not there
            vpredextstd  = self.I.buf_vpreds_ext.std(), # previously not there
            ev_int = np.clip(explained_variance(self.I.buf_vpreds_int.ravel(), rets_int.ravel()), -1, None),
            ev_ext = np.clip(explained_variance(self.I.buf_vpreds_ext.ravel(), rets_ext.ravel()), -1, None),
            rooms = SemicolonList(self.rooms),
            n_rooms = len(self.rooms),
            best_ret = self.best_ret,
            reset_counter = self.I.reset_counter,
            max_table = self.stochpol.max_table
        )

        info['mem_available'] = psutil.virtual_memory().available

        to_record = {'acs': self.I.buf_acs,
                     'rews_int': self.I.buf_rews_int,
                     'rews_int_norm': rews_int,
                     'rews_ext': self.I.buf_rews_ext,
                     'vpred_int': self.I.buf_vpreds_int,
                     'vpred_ext': self.I.buf_vpreds_ext,
                     'adv_int': self.I.buf_advs_int,
                     'adv_ext': self.I.buf_advs_ext,
                     'ent': self.I.buf_ent,
                     'ret_int': rets_int,
                     'ret_ext': rets_ext,
                     }
        if self.I.venvs[0].record_obs:
            to_record['obs'] = self.I.buf_obs[None]

        # Create feeddict for optimization.
        envsperbatch = self.I.nenvs // self.nminibatches
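        # Each minibatch is a slice of envsperbatch whole environments, so per-env
        # time axes (and any recurrent state) stay aligned within a minibatch.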

        ph_buf = [
            (self.stochpol.ph_ac, self.I.buf_acs),
            (self.ph_ret_int, rets_int),
            (self.ph_ret_ext, rets_ext),
            (self.ph_oldnlp, self.I.buf_nlps),
            (self.ph_adv, self.I.buf_advs),
        ]
        if self.I.mem_state is not NO_STATES:
            ph_buf.extend([
                (self.stochpol.ph_istate, self.I.seg_init_mem_state),
                (self.stochpol.ph_new, self.I.buf_news),
            ])

        verbose = False
        if verbose and self.is_log_leader:
            samples = np.prod(self.I.buf_advs.shape)
            logger.info("buffer shape %s, samples_per_mpi=%i, mini_per_mpi=%i, samples=%i, mini=%i " % (
                    str(self.I.buf_advs.shape),
                    samples, samples//self.nminibatches,
                    samples*self.comm_train_size, samples*self.comm_train_size//self.nminibatches))
            logger.info(" "*6 + fmt_row(13, self.loss_names))


        epoch = 0
        start = 0
        # Optimizes on current data for several epochs.
        while epoch < self.nepochs:
            end = start + envsperbatch
            mbenvinds = slice(start, end, None)

            fd = {ph : buf[mbenvinds] for (ph, buf) in ph_buf}
            fd.update({self.ph_lr : self.lr, self.ph_cliprange : self.cliprange})
            all_obs = np.concatenate([self.I.buf_obs[None][mbenvinds], self.I.buf_ob_last[None][mbenvinds, None]], 1)
            fd[self.stochpol.ph_ob[None]] = all_obs
            assert list(fd[self.stochpol.ph_ob[None]].shape) == [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.shape), \
                [fd[self.stochpol.ph_ob[None]].shape, [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.shape)]
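            # Feed the running observation mean/std so inputs are normalized the
            # same way during optimization as during rollouts.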
            fd.update({self.stochpol.ph_mean:self.stochpol.ob_rms.mean, self.stochpol.ph_std:self.stochpol.ob_rms.var**0.5})

            ret = tf.get_default_session().run(self._losses+[self._train], feed_dict=fd)[:-1]

            if not self.testing:
                lossdict = dict(zip(self.loss_names, ret))
            else:
                lossdict = {}
            # Synchronize the lossdict across mpi processes, otherwise weights may be rolled back on one process but not another.
            _maxkl = lossdict.pop('maxkl')
            lossdict = dict_gather(self.comm_train, lossdict, op='mean')
            maxmaxkl = dict_gather(self.comm_train, {"maxkl":_maxkl}, op='max')
            lossdict["maxkl"] = maxmaxkl["maxkl"]
            if verbose and self.is_log_leader:
                logger.info("%i:%03i %s" % (epoch, start, fmt_row(13, [lossdict[n] for n in self.loss_names])))
            start += envsperbatch
            if start == self.I.nenvs:
                epoch += 1
                start = 0

        if self.is_train_leader:
            self.I.stats["n_updates"] += 1
            info.update([('opt_'+n, lossdict[n]) for n in self.loss_names])
            tnow = time.time()
            info['tps'] = self.nsteps * self.I.nenvs / (tnow - self.I.t_last_update)
            info['time_elapsed'] = time.time() - self.t0
            self.I.t_last_update = tnow
        self.stochpol.update_normalization( # Necessary for continuous control tasks with odd obs ranges, only implemented in mlp policy,
            ob=self.I.buf_obs               # NOTE: not shared via MPI
        )

        self.summary_writer.add_summary(summary, self.I.stats['n_updates'])
        self.summary_writer.flush()

        return info
Example #2
def learn(
    env,
    policy_func,
    *,
    timesteps_per_batch,  # timesteps per actor per update
    clip_param,  # clipping parameter epsilon
    entcoeff,  # entropy coefficient
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    max_timesteps=0,
    max_episodes=0,
    max_iters=0,
    max_seconds=0,  # time constraint
    callback=None,  # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5,
    schedule='constant'  # annealing for stepsize parameters (epsilon and adam)
):
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  # clipped surrogate
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]
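    # total_loss = -L^CLIP + (-entcoeff) * entropy + value-function MSE; kl and ent
    # are tracked for logging only and do not contribute to the gradient.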

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()

    U.load_state("save/Humanoid-v1")

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)
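        # add_vtarg_and_adv fills seg["adv"] with GAE(lambda) advantage estimates and
        # seg["tdlamret"] with the corresponding lambda-returns (value targets).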

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        #if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
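                # MpiAdam averages the flat gradient across MPI workers before
                # taking the Adam step, keeping parameters in sync.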
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()
        U.save_state("save/Humanoid-v1")
Example #3
    def update(self):
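        """Variant of the update above that can also record observations and
        attention maps, and supports separate 'normal'/'ego' observation streams."""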

        # Gather the best return, visited rooms, and episode stats across MPI workers.
        temp = sum(MPI.COMM_WORLD.allgather(self.local_rooms), [])
        temp = sorted(list(set(temp)))
        self.rooms = temp

        temp = sum(MPI.COMM_WORLD.allgather(self.scores), [])
        temp = sorted(list(set(temp)))
        self.scores = temp

        temp = sum(MPI.COMM_WORLD.allgather([self.local_best_ret]), [])
        self.best_ret = max(temp)

        eprews = MPI.COMM_WORLD.allgather(
            np.mean(list(self.I.statlists["eprew"])))
        local_best_rets = MPI.COMM_WORLD.allgather(self.local_best_ret)
        n_rooms = sum(MPI.COMM_WORLD.allgather([len(self.local_rooms)]), [])

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.info(f"Rooms visited {self.rooms}")
            logger.info(f"Best return {self.best_ret}")
            logger.info(f"Best local return {sorted(local_best_rets)}")
            logger.info(f"eprews {sorted(eprews)}")
            logger.info(f"n_rooms {sorted(n_rooms)}")
            logger.info(f"Extrinsic coefficient {self.ext_coeff}")
            logger.info(f"Gamma {self.gamma}")
            logger.info(f"Gamma ext {self.gamma_ext}")
            logger.info(f"All scores {sorted(self.scores)}")

        #Normalize intrinsic rewards.
        rffs_int = np.array(
            [self.I.rff_int.update(rew) for rew in self.I.buf_rews_int.T])
        self.I.rff_rms_int.update(rffs_int.ravel())
        rews_int = self.I.buf_rews_int / np.sqrt(self.I.rff_rms_int.var)
        self.mean_int_rew = np.mean(rews_int)
        self.max_int_rew = np.max(rews_int)

        #Don't normalize extrinsic rewards.
        rews_ext = self.I.buf_rews_ext

        rewmean, rewstd, rewmax = (self.I.buf_rews_int.mean(),
                                   self.I.buf_rews_int.std(),
                                   np.max(self.I.buf_rews_int))

        # Calculate intrinsic returns and advantages.
        lastgaelam = 0
        for t in range(self.nsteps - 1, -1, -1):  # nsteps-1 ... 0
            if self.use_news:
                nextnew = (self.I.buf_news[:, t + 1]
                           if t + 1 < self.nsteps else self.I.buf_new_last)
            else:
                nextnew = 0.0  # No dones for intrinsic reward.
            nextvals = (self.I.buf_vpreds_int[:, t + 1]
                        if t + 1 < self.nsteps else self.I.buf_vpred_int_last)
            nextnotnew = 1 - nextnew
            delta = (rews_int[:, t] + self.gamma * nextvals * nextnotnew -
                     self.I.buf_vpreds_int[:, t])
            self.I.buf_advs_int[:, t] = lastgaelam = (
                delta + self.gamma * self.lam * nextnotnew * lastgaelam)
        rets_int = self.I.buf_advs_int + self.I.buf_vpreds_int

        # Calculate extrinsic returns and advantages.
        lastgaelam = 0
        for t in range(self.nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = (self.I.buf_news[:, t + 1]
                       if t + 1 < self.nsteps else self.I.buf_new_last)
            # Use dones for extrinsic reward.
            nextvals = (self.I.buf_vpreds_ext[:, t + 1]
                        if t + 1 < self.nsteps else self.I.buf_vpred_ext_last)
            nextnotnew = 1 - nextnew
            delta = (rews_ext[:, t] + self.gamma_ext * nextvals * nextnotnew -
                     self.I.buf_vpreds_ext[:, t])
            self.I.buf_advs_ext[:, t] = lastgaelam = (
                delta + self.gamma_ext * self.lam * nextnotnew * lastgaelam)
        rets_ext = self.I.buf_advs_ext + self.I.buf_vpreds_ext

        #Combine the extrinsic and intrinsic advantages.
        self.I.buf_advs = self.int_coeff * self.I.buf_advs_int + self.ext_coeff * self.I.buf_advs_ext

        #Collects info for reporting.
        info = dict(
            advmean=self.I.buf_advs.mean(),
            advstd=self.I.buf_advs.std(),
            retintmean=rets_int.mean(),  # previously retmean
            retintstd=rets_int.std(),  # previously retstd
            retextmean=rets_ext.mean(),  # previously not there
            retextstd=rets_ext.std(),  # previously not there
            rewintmean_unnorm=rewmean,  # previously rewmean
            rewintmax_unnorm=rewmax,  # previously not there
            rewintmean_norm=self.mean_int_rew,  # previously rewintmean
            rewintmax_norm=self.max_int_rew,  # previously rewintmax
            rewintstd_unnorm=rewstd,  # previously rewstd
            vpredintmean=self.I.buf_vpreds_int.mean(),  # previously vpredmean
            vpredintstd=self.I.buf_vpreds_int.std(),  # previously vrpedstd
            vpredextmean=self.I.buf_vpreds_ext.mean(),  # previously not there
            vpredextstd=self.I.buf_vpreds_ext.std(),  # previously not there
            ev_int=np.clip(
                explained_variance(self.I.buf_vpreds_int.ravel(),
                                   rets_int.ravel()), -1, None),
            ev_ext=np.clip(
                explained_variance(self.I.buf_vpreds_ext.ravel(),
                                   rets_ext.ravel()), -1, None),
            rooms=SemicolonList(self.rooms),
            n_rooms=len(self.rooms),
            best_ret=self.best_ret,
            reset_counter=self.I.reset_counter)

        info['mem_available'] = psutil.virtual_memory().available

        to_record = {
            'acs': self.I.buf_acs,
            'rews_int': self.I.buf_rews_int,
            'rews_int_norm': rews_int,
            'rews_ext': self.I.buf_rews_ext,
            'vpred_int': self.I.buf_vpreds_int,
            'vpred_ext': self.I.buf_vpreds_ext,
            'adv_int': self.I.buf_advs_int,
            'adv_ext': self.I.buf_advs_ext,
            'ent': self.I.buf_ent,
            'ret_int': rets_int,
            'ret_ext': rets_ext,
        }

        if self.I.venvs[0].record_obs:
            if None in self.I.buf_obs:
                to_record['obs'] = self.I.buf_obs[None]
            else:
                to_record['obs'] = self.I.buf_obs['normal']

        self.recorder.record(bufs=to_record, infos=self.I.buf_epinfos)

        #Create feeddict for optimization.
        envsperbatch = self.I.nenvs // self.nminibatches
        ph_buf = [
            (self.stochpol.ph_ac, self.I.buf_acs),
            (self.ph_ret_int, rets_int),
            (self.ph_ret_ext, rets_ext),
            (self.ph_oldnlp, self.I.buf_nlps),
            (self.ph_adv, self.I.buf_advs),
        ]
        if self.I.mem_state is not NO_STATES:
            ph_buf.extend([
                (self.stochpol.ph_istate, self.I.seg_init_mem_state),
                (self.stochpol.ph_new, self.I.buf_news),
            ])

        #verbose = True
        verbose = False
        if verbose and self.is_log_leader:
            samples = np.prod(self.I.buf_advs.shape)
            logger.info(
                "buffer shape %s, samples_per_mpi=%i, mini_per_mpi=%i, samples=%i, mini=%i "
                % (str(self.I.buf_advs.shape), samples, samples //
                   self.nminibatches, samples * self.comm_train_size,
                   samples * self.comm_train_size // self.nminibatches))
            logger.info(" " * 6 + fmt_row(13, self.loss_names))

        to_record_attention = None
        attention_output = None
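        # For the attention/ego experiments, look up the attention feature map by
        # tensor name so its activations can be fetched along with the losses below.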
        if os.environ['EXPERIMENT_LVL'] in ('attention', 'ego'):
            try:
                #attention_output = tf.get_default_graph().get_tensor_by_name("ppo/pol/augmented2/attention_output_combined:0")
                #attention_output = tf.get_default_graph().get_tensor_by_name("ppo/pol/augmented2/attention_output_combined/kernel:0")
                attention_output = tf.get_default_graph().get_tensor_by_name(
                    "ppo/pol/augmented2/attention_output_combined/Conv2D:0")
            except Exception as e:
                logger.error("Exception in attention_output: {}".format(e))
                attention_output = None

        epoch = 0
        start = 0
        #Optimizes on current data for several epochs.
        while epoch < self.nepochs:
            end = start + envsperbatch
            mbenvinds = slice(start, end, None)

            fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
            fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})

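            # Feed observations: either a single stream (key None) or separate
            # 'normal' and 'ego' views, each extended with the final observation so
            # the policy sees nsteps + 1 frames per env in the minibatch.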
            if None in self.stochpol.ph_ob:
                fd[self.stochpol.ph_ob[None]] = np.concatenate([
                    self.I.buf_obs[None][mbenvinds],
                    self.I.buf_ob_last[None][mbenvinds, None]
                ], 1)
                assert list(fd[self.stochpol.ph_ob[None]].shape) == [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.shape), \
                [fd[self.stochpol.ph_ob[None]].shape, [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.shape)]

            else:
                fd[self.stochpol.ph_ob['normal']] = np.concatenate([
                    self.I.buf_obs['normal'][mbenvinds],
                    self.I.buf_ob_last['normal'][mbenvinds, None]
                ], 1)
                fd[self.stochpol.ph_ob['ego']] = np.concatenate([
                    self.I.buf_obs['ego'][mbenvinds],
                    self.I.buf_ob_last['ego'][mbenvinds, None]
                ], 1)

                assert list(fd[self.stochpol.ph_ob['normal']].shape) == [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.spaces['normal'].shape), \
                [fd[self.stochpol.ph_ob['normal']].shape, [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.spaces['normal'].shape)]
                assert list(fd[self.stochpol.ph_ob['ego']].shape) == [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.spaces['ego'].shape), \
                [fd[self.stochpol.ph_ob['ego']].shape, [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.spaces['ego'].shape)]

            fd.update({
                self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                self.stochpol.ph_std: self.stochpol.ob_rms.var**0.5
            })

            if attention_output is not None:
                _train_losses = [attention_output, self._train]
            else:
                _train_losses = [self._train]

            ret = tf.get_default_session().run(self._losses + _train_losses,
                                               feed_dict=fd)[:-1]

            if attention_output is not None:
                attn_output = ret[-1]
                ret = ret[:-1]
                if None in self.I.buf_obs:
                    outshape = list(
                        self.I.buf_obs[None][mbenvinds].shape[:2]) + list(
                            attn_output.shape[1:])
                else:
                    # does not matter if it's normal or ego, the first 2 axes are the same
                    outshape = list(
                        self.I.buf_obs['normal'][mbenvinds].shape[:2]) + list(
                            attn_output.shape[1:])
                attn_output = np.reshape(attn_output, outshape)
                attn_output = attn_output[:, :, :, :, :64]

            if not self.testing:
                lossdict = dict(zip(self.loss_names, ret))
            else:
                lossdict = {}
            #Synchronize the lossdict across mpi processes, otherwise weights may be rolled back on one process but not another.
            _maxkl = lossdict.pop('maxkl')
            lossdict = dict_gather(self.comm_train, lossdict, op='mean')
            maxmaxkl = dict_gather(self.comm_train, {"maxkl": _maxkl},
                                   op='max')
            lossdict["maxkl"] = maxmaxkl["maxkl"]
            if verbose and self.is_log_leader:
                logger.info(
                    "%i:%03i %s" %
                    (epoch, start,
                     fmt_row(13, [lossdict[n] for n in self.loss_names])))
            start += envsperbatch
            if start == self.I.nenvs:
                epoch += 1
                start = 0

                if attention_output is not None:
                    if to_record_attention is None:
                        to_record_attention = attn_output
                    else:
                        to_record_attention = np.concatenate(
                            [to_record_attention, attn_output])

        # if to_record_attention is not None:
        #     if None in self.I.buf_obs:
        #         to_record['obs'] = self.I.buf_obs[None]
        #     else:
        #         to_record['obs'] = self.I.buf_obs['normal']

        #     to_record['attention'] = to_record_attention

        to_record_attention = None

        if self.is_train_leader:
            self.I.stats["n_updates"] += 1
            info.update([('opt_' + n, lossdict[n]) for n in self.loss_names])
            tnow = time.time()
            info['tps'] = self.nsteps * self.I.nenvs / (tnow -
                                                        self.I.t_last_update)
            info['time_elapsed'] = time.time() - self.t0
            self.I.t_last_update = tnow
        self.stochpol.update_normalization(  # Necessary for continuous control tasks with odd obs ranges, only implemented in mlp policy,
            ob=self.I.buf_obs  # NOTE: not shared via MPI
        )
        return info
    def update(self):
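        """Variant of the update above that feeds a single 'obs' stream and, when
        self.meta_rl is set, the previous action and extrinsic reward as well."""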
        # Gather the best return, visited rooms, and episode stats across MPI workers.
        temp = sum(MPI.COMM_WORLD.allgather(self.local_rooms), [])
        temp = sorted(list(set(temp)))
        self.rooms = temp

        temp = sum(MPI.COMM_WORLD.allgather(self.scores), [])
        temp = sorted(list(set(temp)))
        self.scores = temp

        temp = sum(MPI.COMM_WORLD.allgather([self.local_best_ret]), [])
        self.best_ret = max(temp)

        eprews = MPI.COMM_WORLD.allgather(
            np.mean(list(self.I.statlists["eprew"])))
        local_best_rets = MPI.COMM_WORLD.allgather(self.local_best_ret)
        n_rooms = sum(MPI.COMM_WORLD.allgather([len(self.local_rooms)]), [])

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.info(f"Rooms visited {self.rooms}")
            logger.info(f"Best return {self.best_ret}")
            logger.info(f"Best local return {sorted(local_best_rets)}")
            logger.info(f"eprews {sorted(eprews)}")
            logger.info(f"n_rooms {sorted(n_rooms)}")
            logger.info(f"Extrinsic coefficient {self.ext_coeff}")
            logger.info(f"Gamma {self.gamma}")
            logger.info(f"Gamma ext {self.gamma_ext}")
            logger.info(f"All scores {sorted(self.scores)}")

        # Normalize intrinsic rewards.
        rffs_int = np.array(
            [self.I.rff_int.update(rew) for rew in self.I.buf_rews_int.T])
        self.I.rff_rms_int.update(rffs_int.ravel())
        rews_int = self.I.buf_rews_int / np.sqrt(self.I.rff_rms_int.var)
        self.mean_int_rew = np.mean(rews_int)
        self.max_int_rew = np.max(rews_int)

        # Don't normalize extrinsic rewards.
        rews_ext = self.I.buf_rews_ext

        rewmean, rewstd, rewmax = (
            self.I.buf_rews_int.mean(),
            self.I.buf_rews_int.std(),
            np.max(self.I.buf_rews_int),
        )

        # Calculate intrinsic returns and advantages.
        lastgaelam = 0
        for t in range(self.nsteps - 1, -1, -1):  # nsteps-1 ... 0
            if self.use_news:
                nextnew = (self.I.buf_news[:, t + 1]
                           if t + 1 < self.nsteps else self.I.buf_new_last)
            else:
                nextnew = 0.0  # No dones for intrinsic reward.
            nextvals = (self.I.buf_vpreds_int[:, t + 1]
                        if t + 1 < self.nsteps else self.I.buf_vpred_int_last)
            nextnotnew = 1 - nextnew
            delta = (rews_int[:, t] + self.gamma * nextvals * nextnotnew -
                     self.I.buf_vpreds_int[:, t])
            self.I.buf_advs_int[:, t] = lastgaelam = (
                delta + self.gamma * self.lam * nextnotnew * lastgaelam)
        rets_int = self.I.buf_advs_int + self.I.buf_vpreds_int

        # Calculate extrinsic returns and advantages.
        lastgaelam = 0
        for t in range(self.nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = (self.I.buf_news[:, t + 1]
                       if t + 1 < self.nsteps else self.I.buf_new_last)
            # Use dones for extrinsic reward.
            nextvals = (self.I.buf_vpreds_ext[:, t + 1]
                        if t + 1 < self.nsteps else self.I.buf_vpred_ext_last)
            nextnotnew = 1 - nextnew
            delta = (rews_ext[:, t] + self.gamma_ext * nextvals * nextnotnew -
                     self.I.buf_vpreds_ext[:, t])
            self.I.buf_advs_ext[:, t] = lastgaelam = (
                delta + self.gamma_ext * self.lam * nextnotnew * lastgaelam)
        rets_ext = self.I.buf_advs_ext + self.I.buf_vpreds_ext

        # Combine the extrinsic and intrinsic advantages.
        self.I.buf_advs = (self.int_coeff * self.I.buf_advs_int +
                           self.ext_coeff * self.I.buf_advs_ext)

        # Collects info for reporting.
        info = dict(
            advmean=self.I.buf_advs.mean(),
            advstd=self.I.buf_advs.std(),
            retintmean=rets_int.mean(),  # previously retmean
            retintstd=rets_int.std(),  # previously retstd
            retextmean=rets_ext.mean(),  # previously not there
            retextstd=rets_ext.std(),  # previously not there
            rewintmean_unnorm=rewmean,  # previously rewmean
            rewintmax_unnorm=rewmax,  # previously not there
            rewintmean_norm=self.mean_int_rew,  # previously rewintmean
            rewintmax_norm=self.max_int_rew,  # previously rewintmax
            rewintstd_unnorm=rewstd,  # previously rewstd
            vpredintmean=self.I.buf_vpreds_int.mean(),  # previously vpredmean
            vpredintstd=self.I.buf_vpreds_int.std(),  # previously vrpedstd
            vpredextmean=self.I.buf_vpreds_ext.mean(),  # previously not there
            vpredextstd=self.I.buf_vpreds_ext.std(),  # previously not there
            ev_int=np.clip(
                explained_variance(self.I.buf_vpreds_int.ravel(),
                                   rets_int.ravel()),
                -1,
                None,
            ),
            ev_ext=np.clip(
                explained_variance(self.I.buf_vpreds_ext.ravel(),
                                   rets_ext.ravel()),
                -1,
                None,
            ),
            rooms=SemicolonList(self.rooms),
            n_rooms=len(self.rooms),
            best_ret=self.best_ret,
            reset_counter=self.I.reset_counter,
        )

        info[f"mem_available"] = psutil.virtual_memory().available

        to_record = {
            "acs": self.I.buf_acs,
            "rews_int": self.I.buf_rews_int,
            "rews_int_norm": rews_int,
            "rews_ext": self.I.buf_rews_ext,
            "vpred_int": self.I.buf_vpreds_int,
            "vpred_ext": self.I.buf_vpreds_ext,
            "adv_int": self.I.buf_advs_int,
            "adv_ext": self.I.buf_advs_ext,
            "ent": self.I.buf_ent,
            "ret_int": rets_int,
            "ret_ext": rets_ext,
        }
        if self.I.venvs[0].record_obs:
            to_record["obs"] = self.I.buf_obs['obs']
        self.recorder.record(bufs=to_record, infos=self.I.buf_epinfos)

        # Create feeddict for optimization.
        envsperbatch = self.I.nenvs // self.nminibatches
        ph_buf = [
            (self.stochpol.ph_ac, self.I.buf_acs),
            (self.ph_ret_int, rets_int),
            (self.ph_ret_ext, rets_ext),
            (self.ph_oldnlp, self.I.buf_nlps),
            (self.ph_adv, self.I.buf_advs),
        ]
        if self.I.mem_state is not NO_STATES:
            ph_buf.extend([
                (self.stochpol.ph_istate, self.I.seg_init_mem_state),
                (self.stochpol.ph_new, self.I.buf_news),
            ])

        verbose = True
        if verbose and self.is_log_leader:
            samples = np.prod(self.I.buf_advs.shape)
            logger.info(
                f"buffer shape {self.I.buf_advs.shape}, "
                f"samples_per_mpi={samples:d}, "
                f"mini_per_mpi={samples // self.nminibatches:d}, "
                f"samples={samples * self.comm_train_size:d}, "
                f"mini={samples * self.comm_train_size // self.nminibatches:d} "
            )
            logger.info(" " * 6 + fmt_row(13, self.loss_names))

        epoch = 0
        start = 0
        # Optimizes on current data for several epochs.
        while epoch < self.nepochs:
            end = start + envsperbatch
            mbenvinds = slice(start, end, None)

            fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
            fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})
            fd[self.stochpol.ph_ob['obs']] = np.concatenate(
                [
                    self.I.buf_obs['obs'][mbenvinds],
                    self.I.buf_ob_last['obs'][mbenvinds, None],
                ],
                1,
            )

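            # Meta-RL inputs: one-hot previous actions and the previous extrinsic
            # reward are fed as additional observation streams.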
            if self.meta_rl:
                fd[self.stochpol.ph_ob['prev_acs']] = one_hot(
                    self.I.buf_acs[mbenvinds], self.ac_space.n)
                fd[self.stochpol.ph_ob['prev_rew']] = self.I.buf_rews_ext[
                    mbenvinds, ..., None]

            assert list(fd[self.stochpol.ph_ob['obs']].shape) == [
                self.I.nenvs // self.nminibatches,
                self.nsteps + 1,
            ] + list(self.ob_space.shape), [
                fd[self.stochpol.ph_ob['obs']].shape,
                [self.I.nenvs // self.nminibatches, self.nsteps + 1] +
                list(self.ob_space.shape),
            ]
            fd.update({
                self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                self.stochpol.ph_std: self.stochpol.ob_rms.var**0.5,
            })

            ret = tf.get_default_session().run(self._losses + [self._train],
                                               feed_dict=fd)[:-1]
            if not self.testing:
                lossdict = dict(zip(self.loss_names, ret))
            else:
                lossdict = {}
            # Synchronize the lossdict across mpi processes, otherwise weights may be rolled back on one process but not another.
            _maxkl = lossdict.pop("maxkl")
            lossdict = dict_gather(self.comm_train, lossdict, op="mean")
            maxmaxkl = dict_gather(self.comm_train, {"maxkl": _maxkl},
                                   op="max")
            lossdict["maxkl"] = maxmaxkl["maxkl"]
            if verbose and self.is_log_leader:
                logger.info(
                    f"{epoch:d}:{start:03d} {fmt_row(13, [lossdict[n] for n in self.loss_names])}"
                )
            start += envsperbatch
            if start == self.I.nenvs:
                epoch += 1
                start = 0

        if self.is_train_leader:
            self.I.stats["n_updates"] += 1
            info.update([("opt_" + n, lossdict[n]) for n in self.loss_names])
            tnow = time.time()
            info["tps"] = self.nsteps * self.I.nenvs / (tnow -
                                                        self.I.t_last_update)
            info["time_elapsed"] = time.time() - self.t0
            self.I.t_last_update = tnow
        self.stochpol.update_normalization(
            # Necessary for continuous control tasks with odd obs ranges, only implemented in mlp policy,
            ob=self.I.buf_obs  # NOTE: not shared via MPI
        )
        return info