Example #1
class Runner(AbstractEnvRunner):

    def __init__(self, env, model, nsteps, icm, gamma, curiosity):
        super().__init__(env=env, model=model, nsteps=nsteps, icm=icm)
        assert isinstance(env.action_space, spaces.Discrete), 'This ACER implementation works only with discrete action spaces!'
        assert isinstance(env, VecFrameStack)

        self.nact = env.action_space.n
        nenv = self.nenv
        self.nbatch = nenv * nsteps
        self.batch_ob_shape = (nenv*(nsteps+1),) + env.observation_space.shape

        self.obs = env.reset()
        self.obs_dtype = env.observation_space.dtype
        self.ac_dtype = env.action_space.dtype
        self.nstack = self.env.nstack
        self.nc = self.batch_ob_shape[-1] // self.nstack

        # >
        self.curiosity = curiosity
        if self.curiosity:
            self.rff = RewardForwardFilter(gamma)
            self.rff_rms = RunningMeanStd()


        # >


    def run(self):
        # enc_obs = np.split(self.obs, self.nstack, axis=3)  # so now list of obs steps
        enc_obs = np.split(self.env.stackedobs, self.env.nstack, axis=-1)
        mb_obs, mb_actions, mb_mus, mb_dones, mb_rewards, mb_next_states = [], [], [], [], [], []
        icm_testing_rewards = []

        for _ in range(self.nsteps):
            actions, mus, states = self.model._step(self.obs, S=self.states, M=self.dones)
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
            mb_mus.append(mus)
            mb_dones.append(self.dones)

            # >
            if self.curiosity:
                # print("3 icm here ")
                icm_states = self.obs
            # >

            obs, rewards, dones, _ = self.env.step(actions)
            # states information for stateful models like LSTM

            if self.curiosity:
                icm_next_states = obs
                # print("Sent parameters for \n icm_states {} , icm_next_states {} , actions {}".format(
                #     icm_states.shape, icm_next_states.shape, actions.shape))
                icm_rewards = self.icm.calculate_intrinsic_reward(icm_states, icm_next_states, actions)
                icm_testing_rewards.append(icm_rewards)

            mb_next_states.append(np.copy(obs)) # s_t+1

            self.states = states
            self.dones = dones
            self.obs = obs
            mb_rewards.append(rewards)
            enc_obs.append(obs[..., -self.nc:])

        mb_obs.append(np.copy(self.obs))
        mb_dones.append(self.dones)
        mb_next_states.append(np.copy(obs))
        
        icm_actions = mb_actions  # aliases the mb_actions list; the final action is appended below before conversion

        # >

        if self.curiosity:
            # print("5 icm here ")
            # icm_testing_rewards.append(rewards)
            icm_testing_rewards = np.array(icm_testing_rewards, dtype=np.float32).swapaxes(1, 0)

        # >




        enc_obs = np.asarray(enc_obs, dtype=self.obs_dtype).swapaxes(1, 0)
        mb_obs = np.asarray(mb_obs, dtype=self.obs_dtype).swapaxes(1, 0)
        mb_actions = np.asarray(mb_actions, dtype=self.ac_dtype).swapaxes(1, 0)

        # >
        icm_actions.append(actions)
        icm_actions = np.asarray(icm_actions, dtype=self.ac_dtype).swapaxes(1, 0)
        
        mb_next_states = np.array(mb_next_states, dtype=self.obs_dtype).swapaxes(1,0)
        # >
        
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)

        if self.curiosity:  # r_e + discounted( r_i )
            rffs = np.array([self.rff.update(rew) for rew in icm_testing_rewards.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std ** 2, rffs_count)
            rews = icm_testing_rewards / np.sqrt(self.rff_rms.var)

            mb_rewards = rews + mb_rewards


        mb_mus = np.asarray(mb_mus, dtype=np.float32).swapaxes(1, 0)

        mb_dones = np.asarray(mb_dones, dtype=bool).swapaxes(1, 0)

        mb_masks = mb_dones  # used for stateful models like LSTMs to mask state when done
        mb_dones = mb_dones[:, 1:]  # used for calculating returns; the dones array is now aligned with rewards

        # shapes are now [nenv, nsteps, []]
        # When pulling from buffer, arrays will now be reshaped in place, preventing a deep copy.

        # print("sent parameters \n mb_obs {} next_obs {} mb_actions {} mb_rewards {} , mb_icm_actions {} , icm_testing_rewards {} ".format( 
            # mb_obs.shape, mb_next_states.shape , mb_actions.shape, mb_rewards.shape , icm_actions.shape , icm_testing_rewards.shape) )



        return enc_obs, mb_obs, mb_actions, mb_rewards, mb_mus, mb_dones, mb_masks, mb_next_states, icm_actions
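Both runners scale the intrinsic reward with two helpers, RewardForwardFilter and RunningMeanStd, dividing by a running estimate of the standard deviation of the discounted intrinsic return. A minimal sketch of the filter and of how run() applies it, assuming the baselines / large-scale-curiosity style definition (the toy array and the running_var stand-in below are illustrative, not part of the original code):

import numpy as np

class RewardForwardFilter(object):
    """Keeps a per-env discounted running sum of rewards (a 'backwards return')."""
    def __init__(self, gamma):
        self.rewems = None
        self.gamma = gamma

    def update(self, rews):
        if self.rewems is None:
            self.rewems = rews
        else:
            self.rewems = self.rewems * self.gamma + rews
        return self.rewems

# usage mirroring run(): icm_rewards has shape [nenv, nsteps]
rff = RewardForwardFilter(gamma=0.99)
icm_rewards = np.random.rand(4, 8).astype(np.float32)                  # toy intrinsic rewards
rffs = np.array([rff.update(step_rews) for step_rews in icm_rewards.T])  # one update per time step
running_var = rffs.var()                                                # stands in for rff_rms.var
scaled = icm_rewards / np.sqrt(running_var)                             # this is what gets added to mb_rewards

In the original code the variance is not recomputed from a single batch but maintained by a RunningMeanStd instance, fed with moments pooled across MPI workers via mpi_moments.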
Example #2
class Runner(AbstractEnvRunner):
    """
    We use this class to generate batches of experiences

    __init__:
    - Initialize the runner

    run():
    - Make a mini batch of experiences
    """
    def __init__(self, env, model, icm, curiosity, nsteps=5, gamma=0.99):
        super().__init__(env=env, model=model, icm=icm, nsteps=nsteps)
        self.gamma = gamma
        self.icm = icm
        self.curiosity = curiosity
        self.batch_action_shape = [
            x if x is not None else -1
            for x in model.train_model.action.shape.as_list()
        ]
        self.ob_dtype = model.train_model.X.dtype.as_numpy_dtype
        self.rff = RewardForwardFilter(self.gamma)
        self.rff_rms = RunningMeanStd()

    def run(self):
        # curiosity = True
        # curiosity = False

        # We initialize the lists that will contain the mb of experiences
        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_next_states = [],[],[],[],[],[]
        mb_states = self.states
        icm_testing_rewards = []
        for n in range(self.nsteps):
            # Given observations, take action and value (V(s))
            # We already have self.obs because the Runner superclass ran self.obs[:] = env.reset() at init
            actions, values, states, _ = self.model.step(self.obs,
                                                         S=self.states,
                                                         M=self.dones)

            # Append the experiences
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
            mb_values.append(values)
            mb_dones.append(self.dones)

            if self.curiosity:
                icm_states = self.obs

            # Take actions in env and look the results
            obs, rewards, dones, _ = self.env.step(actions)
            # print("received Rewards from step function ")

            # print("received Rewards ",rewards)
            if self.curiosity:
                icm_next_states = obs

                icm_rewards = self.icm.calculate_intrinsic_reward(
                    icm_states, icm_next_states, actions)
                # print("shape of icm rewards ",np.shape(icm_rewards))
                icm_testing_rewards.append(icm_rewards)
                # icm_rewards = [icm_rewards] * len(rewards)

                # icm_rewards = icm_rewards * 2
                # print("intrinsic Reward : ",icm_rewards)

                # icm_rewards = np.clip(icm_rewards,-constants['REWARD_CLIP'], constants['REWARD_CLIP'])

                # print("icm _ rewards : ",icm_rewards)

                # rewards = icm_rewards  + rewards
                # print("Rewards icm {} , commulative reward {} ".format(icm_rewards , rewards))

                # rewards = np.clip(rewards,-constants['REWARD_CLIP'], +constants['REWARD_CLIP'])
                # print("icm rewards ", rewards)

                # print("calculated rewards ",rewards)

            mb_next_states.append(np.copy(obs))
            self.states = states
            self.dones = dones
            for n, done in enumerate(dones):
                if done:
                    self.obs[n] = self.obs[n] * 0
            self.obs = obs
            mb_rewards.append(rewards)
        mb_dones.append(self.dones)

        # Batch of steps to batch of rollouts
        mb_obs = np.asarray(mb_obs, dtype=self.ob_dtype).swapaxes(
            1, 0).reshape(self.batch_ob_shape)
        mb_next_states = np.asarray(mb_next_states,
                                    dtype=self.ob_dtype).swapaxes(
                                        1, 0).reshape(self.batch_ob_shape)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
        # > testing mean std of rewards
        if self.curiosity:
            icm_testing_rewards = np.asarray(icm_testing_rewards,
                                             dtype=np.float32).swapaxes(1, 0)
            # print("Icm rewards" ,icm_testing_rewards)
        # > testing mean std of rewards
        mb_actions = np.asarray(
            mb_actions,
            dtype=self.model.train_model.action.dtype.name).swapaxes(1, 0)
        mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(1, 0)
        mb_dones = np.asarray(mb_dones, dtype=bool).swapaxes(1, 0)
        mb_masks = mb_dones[:, :-1]
        mb_dones = mb_dones[:, 1:]

        # > passing reward to reward forward filter

        # print("Merged things obs {} rewards {} actions {} dones {}".
        # format(np.shape(mb_obs) , np.shape(mb_rewards) , np.shape(mb_actions) , np.shape(mb_dones)))

        # >
        # rffs = np.array([self.rff.update(rew) for rew in mb_rewards.T])

        if self.curiosity:
            rffs = np.array(
                [self.rff.update(rew) for rew in icm_testing_rewards.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = icm_testing_rewards / np.sqrt(self.rff_rms.var)

            # mb_rewards = rews

            mb_rewards = mb_rewards + rews

            # now clipping the reward (-1,1)

            # mb_rewards = np.clip(mb_rewards,-constants['REWARD_CLIP'], constants['REWARD_CLIP'])
            # print(mb_rewards)
            # print(" shape of normalized reward ", np.shape(rews))

            # icm_testing_rewards = (icm_testing_rewards >  rffs_mean).astype(np.float32)
            # np.place(icm_testing_rewards, icm_testing_rewards > 0, 0.2)

        # np.interp(icm_testing_rewards.ravel() , (rffs_mean+)  , ())

        # icm_testing_rewards = icm_testing_rewards.ravel()
        # print("\n\nIcm Rewards : ",icm_testing_rewards)

        # print(" icm testing rewards ")
        # print("icm testing reward : mean {} , std {} , division {} ".format(rffs_mean , rffs_std , ((rffs_mean + rffs_std)/2 ) ) )

        # print("ICM testing rewards " , icm_testing_rewards)

        # icm_testing_rewards[icm_testing_rewards > rffs_mean] = 0.5
        # icm_testing_rewards[icm_testing_rewards < rffs_mean] = 0
        # icm_testing_rewards[icm_testing_rewards < rffs_mean] = 0
        # print("icm rewards ", icm_testing_rewards)

        # mb_rewards = icm_testing_rewards + mb_rewards

        # print( mb_rewards)
        # mb_rewards = mb_rewards[mb_rewards > 1]
        # mb_rewards = [1 if mb_rewards[mb_rewards >1 ] else 1]
        # mb_rewards[mb_rewards > 1] = 1
        # mask = mb_rewards[((icm_testing_rewards + mb_rewards ) % 2) == 0]

        # print("Mask ",mask)
        # mb_rewards[mask == 0] = 1

        # print("Mb reward ",mb_rewards )

        # print("Icm Rewards : ",icm_testing_rewards)
        # self.rff_rms.update_from_moments(rffs_mean, rffs_std ** 2, rffs_count)
        # rews = mb_rewards / np.sqrt(self.rff_rms.var)
        # >

        # print("update : rffs_mean {} , rffs_std {} , rffs_count {} ".format(
        # np.shape(rffs_mean),np.shape(rffs_std),np.shape(rffs_count)))

        # print(" update :  final rews {} rff_rms.var {} ".format(
        # rews , np.shape(self.rff_rms.var)))

        # print(">> the shape of rffs testing ", np.shape(rffs))

        # mb_rewards_copy = mb_rewards

        if self.curiosity:
            if self.gamma > 0.0:
                # Discount/bootstrap off value fn
                last_values = self.model.value(self.obs,
                                               S=self.states,
                                               M=self.dones).tolist()
                for n, (rewards, dones, value) in enumerate(
                        zip(mb_rewards, mb_dones, last_values)):
                    rewards = rewards.tolist()
                    dones = dones.tolist()
                    # if dones[-1] == 0:
                    rewards = discount_with_dones(rewards + [value],
                                                  dones + [0], self.gamma)[:-1]
                    # else:
                    # rewards = discount_with_dones(rewards, dones, self.gamma)

                    mb_rewards[n] = rewards
        else:
            # print(" Before discount_with_dones ")
            # print("Rewards " , mb_rewards)

            # print("Before rewards and values ")
            # print("Reward {} values {} ".format(mb_rewards , mb_values))
            if self.gamma > 0.0:
                # Discount/bootstrap off value fn
                last_values = self.model.value(self.obs,
                                               S=self.states,
                                               M=self.dones).tolist()
                for n, (rewards, dones, value) in enumerate(
                        zip(mb_rewards, mb_dones, last_values)):
                    rewards = rewards.tolist()
                    dones = dones.tolist()
                    if dones[-1] == 0:
                        rewards = discount_with_dones(rewards + [value],
                                                      dones + [0],
                                                      self.gamma)[:-1]
                    else:
                        rewards = discount_with_dones(rewards, dones,
                                                      self.gamma)

                    mb_rewards[n] = rewards

        # print(" After discount_with_dones ")
        # print("Orgnal discounterd Rewards " , np.shape(mb_rewards))

        # rffs_mean, rffs_std, rffs_count = mpi_moments(mb_rewards.ravel())
        # self.rff_rms.update_from_moments(rffs_mean, rffs_std ** 2, rffs_count)
        # mb_rewards = mb_rewards_copy / np.sqrt(self.rff_rms.var)

        mb_actions = mb_actions.reshape(self.batch_action_shape)

        mb_rewards = mb_rewards.flatten()
        mb_values = mb_values.flatten()
        mb_masks = mb_masks.flatten()

        if self.curiosity:
            mb_rews_icm = rews.flatten()

        # mb_new_updated_reward = mb_rews_icm + mb_rewards

        # print("New udpated rewards ",mb_new_updated_reward)

        # rffs_mean, rffs_std, rffs_count = mpi_moments(mb_new_updated_reward.ravel())
        # self.rff_rms.update_from_moments(rffs_mean, rffs_std ** 2, rffs_count)
        # rews = mb_new_updated_reward / np.sqrt(self.rff_rms.var)

        # print("After normalized",rews)

        # mb_new_rew = rews.flatten()

        # print("Flatten rewards and values ")
        # print("Reward {} ".format(mb_rewards ))

        # print("Merged things after obs {} rewards {} actions {} masks {}".
        # format(np.shape(mb_obs) , np.shape(mb_rewards) , np.shape(mb_actions) , np.shape(mb_masks)))

        return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values, mb_next_states  # , mb_rews_icm, mb_new_updated_reward #, mb_new_rew
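The extrinsic branch above bootstraps returns with discount_with_dones. For reference, a sketch of that helper in the style of OpenAI baselines' A2C utilities (reproduced from memory, so treat the exact details as an assumption):

def discount_with_dones(rewards, dones, gamma):
    """Discounted returns that reset at episode boundaries (done == 1)."""
    discounted = []
    r = 0
    for reward, done in zip(rewards[::-1], dones[::-1]):
        r = reward + gamma * r * (1. - done)  # zero the running return when an episode ended
        discounted.append(r)
    return discounted[::-1]

# run() appends the critic's value estimate and a final done=0 so the last step
# bootstraps off V(s_T), then drops that extra entry again with [:-1].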
class PpoOptimizer(object):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, stochpol, ent_coef, gamma,
                 lam, nepochs, lr, cliprange, nminibatches, normrew, normadv,
                 use_news, ext_coeff, int_coeff, nsteps_per_seg, nsegs_per_env,
                 unity, dynamics_list):
        self.dynamics_list = dynamics_list
        with tf.variable_scope(scope):
            self.unity = unity
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.stochpol = stochpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.ph_adv = tf.placeholder(tf.float32, [None, None])
            self.ph_ret = tf.placeholder(tf.float32, [None, None])
            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            vpred = self.stochpol.vpred

            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret)**2)
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = -self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(
                ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (-ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(
                tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(
                tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {
                'tot': self.total_loss,
                'pg': pg_loss,
                'vf': vf_loss,
                'ent': entropy,
                'approxkl': approxkl,
                'clipfrac': clipfrac
            }

    def start_interaction(self, env_fns, dynamics_list, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                       comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]
        if self.unity:
            for i in tqdm(range(int(nenvs * 2.5 + 10))):
                time.sleep(1)
            print('... long overdue sleep ends now')
            sys.stdout.flush()

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics_list=dynamics_list)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # t = nsteps-1 ... 0
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.rollout.buf_vpreds.mean(),
            vpredstd=self.rollout.buf_vpreds.std(),
            ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                  self.buf_rets.ravel()),
            rew_mean=np.mean(self.rollout.buf_rews),
            recent_best_ext_ret=self.rollout.current_max
            if self.rollout.current_max is not None else 0,
        )
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # store images for debugging
        # from PIL import Image
        # if not os.path.exists('logs/images/'):
        #         os.makedirs('logs/images/')
        # for i in range(self.rollout.buf_obs_last.shape[0]):
        #     obs = self.rollout.buf_obs_last[i][0]
        #     Image.fromarray((obs*255.).astype(np.uint8)).save('logs/images/%04d.png'%i)

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([(self.dynamics_list[0].last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        mblossvals = []

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.ph_lr: self.lr,
                    self.ph_cliprange: self.cliprange
                })
                mblossvals.append(getsess().run(self._losses + (self._train, ),
                                                fd)[:-1])

        mblossvals = [mblossvals[0]]
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)
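calculate_advantages above is the standard GAE(λ) recursion, just hard-wrapped by the formatter. A standalone NumPy version of the same computation may be easier to read (the buffer names here are illustrative, not the Rollout class's actual attributes):

import numpy as np

def gae_advantages(rews, vpreds, vpred_last, news, new_last, gamma, lam, use_news=True):
    """rews, vpreds, news: [nenv, nsteps]; vpred_last, new_last: [nenv]."""
    nenv, nsteps = rews.shape
    advs = np.zeros((nenv, nsteps), np.float32)
    lastgaelam = 0
    for t in range(nsteps - 1, -1, -1):
        nextnew = news[:, t + 1] if t + 1 < nsteps else new_last
        if not use_news:
            nextnew = 0  # ignore episode boundaries, as in the curiosity papers
        nextvals = vpreds[:, t + 1] if t + 1 < nsteps else vpred_last
        nextnotnew = 1 - nextnew
        delta = rews[:, t] + gamma * nextvals * nextnotnew - vpreds[:, t]
        advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
    rets = advs + vpreds  # regression targets for the value function
    return advs, rets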
class PpoOptimizer(object):
    def __init__(self, scope, ob_space, ac_space, policy, use_news, gamma, lam,
                 nepochs, nminibatches, lr, cliprange, nsteps_per_seg,
                 nsegs_per_env, ent_coeff, normrew, normadv, ext_coeff,
                 int_coeff, dynamics):
        self.dynamics = dynamics
        with tf.variable_scope(scope):

            self.bootstrapped = self.dynamics.bootstrapped
            self.flipout = self.dynamics.flipout

            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.policy = policy
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.ent_coeff = ent_coeff

            self.placeholder_advantage = tf.placeholder(
                tf.float32, [None, None])
            self.placeholder_ret = tf.placeholder(tf.float32, [None, None])
            self.placeholder_rews = tf.placeholder(tf.float32, [None, None])
            self.placeholder_oldnlp = tf.placeholder(tf.float32, [None, None])
            self.placeholder_oldvpred = tf.placeholder(tf.float32,
                                                       [None, None])
            self.placeholder_lr = tf.placeholder(tf.float32, [])
            self.placeholder_cliprange = tf.placeholder(tf.float32, [])

            # if self.flipout:
            #     self.placeholder_dyn_mean = tf.placeholder(tf.float32, [None,None])
            #     self.dyn_mean = tf.reduce_max(self.placeholder_dyn_mean)

            neglogpa = self.policy.pd.neglogp(self.policy.placeholder_action)
            entropy = tf.reduce_mean(self.policy.pd.entropy())
            vpred = self.policy.vpred

            c1 = .5

            vf_loss = c1 * tf.reduce_mean(
                tf.square(vpred - self.placeholder_ret))
            ratio = tf.exp(self.placeholder_oldnlp - neglogpa)
            negadv = -self.placeholder_advantage

            polgrad_losses1 = negadv * ratio
            polgrad_losses2 = negadv * tf.clip_by_value(
                ratio, 1.0 - self.placeholder_cliprange,
                1.0 + self.placeholder_cliprange)
            polgrad_loss_surr = tf.maximum(polgrad_losses1, polgrad_losses2)
            polgrad_loss = tf.reduce_mean(polgrad_loss_surr)
            entropy_loss = (-self.ent_coeff) * entropy

            approxkl = .5 * tf.reduce_mean(
                tf.square(neglogpa - self.placeholder_oldnlp))
            clipfrac = tf.reduce_mean(
                tf.to_float(
                    tf.abs(polgrad_losses2 - polgrad_loss_surr) > 1e-6))
            if self.dynamics.dropout:
                regloss = tf.reduce_sum(
                    tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
                self.total_loss = polgrad_loss + entropy_loss + vf_loss + regloss  # TODO: a negative sign on regloss was also tried, experimentally
                self.to_report = {
                    'tot': self.total_loss,
                    'pg': polgrad_loss,
                    'vf': vf_loss,
                    'ent': entropy,
                    'approxkl': approxkl,
                    'clipfrac': clipfrac,
                    'regloss': regloss
                }
                self.dropout_rates = tf.get_collection('DROPOUT_RATES')
            else:
                self.total_loss = polgrad_loss + entropy_loss + vf_loss
                self.to_report = {
                    'tot': self.total_loss,
                    'pg': polgrad_loss,
                    'vf': vf_loss,
                    'ent': entropy,
                    'approxkl': approxkl,
                    'clipfrac': clipfrac
                }

    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.placeholder_lr,
                                       comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.placeholder_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        if MPI.COMM_WORLD.Get_rank() == 0:
            tf.get_default_session().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            tf.get_default_session(),
            tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.policy,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()
            if self.dynamics.dropout:
                self.rff2 = RewardForwardFilter(self.gamma)
                self.rff_rms2 = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # t = nsteps-1 ... 0
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
            if self.dynamics.dropout:
                rffs2 = np.array([
                    self.rff2.update(rew)
                    for rew in self.rollout.buf_rews_mean.T
                ])
                rffs2_mean, rffs2_std, rffs2_count = mpi_moments(rffs2.ravel())
                self.rff_rms2.update_from_moments(rffs2_mean, rffs2_std**2,
                                                  rffs2_count)
                rews_m = self.rollout.buf_rews_mean / np.sqrt(
                    self.rff_rms2.var)
                rews = rews_m + rews

        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(advmean=self.buf_advs.mean(),
                    advstd=self.buf_advs.std(),
                    retmean=self.buf_rets.mean(),
                    retstd=self.buf_rets.std(),
                    vpredmean=self.rollout.buf_vpreds.mean(),
                    vpredstd=self.rollout.buf_vpreds.std(),
                    ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                          self.buf_rets.ravel()),
                    rew_mean=np.mean(self.rollout.buf_rews),
                    recent_best_ext_ret=self.rollout.current_max)
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # if self.flipout:
        #     info['dyn_mean'] = np.mean(self.rollout.buf_dyn_rew)
        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.policy.placeholder_action, resh(self.rollout.buf_acs)),
            (self.placeholder_rews, resh(self.rollout.buf_rews)),
            (self.placeholder_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.placeholder_oldnlp, resh(self.rollout.buf_nlps)),
            (self.policy.placeholder_observation, resh(self.rollout.buf_obs)),
            (self.placeholder_ret, resh(self.buf_rets)),
            (self.placeholder_advantage, resh(self.buf_advs)),
        ]
        ph_buf.extend([(self.dynamics.last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        # if self.flipout:
        #     ph_buf.extend([(self.placeholder_dyn_mean, resh(self.buf_n_dyn_rew))])

        if self.bootstrapped:
            ph_buf.extend([
                (self.dynamics.mask_placeholder,
                 self.rollout.buf_mask.reshape(-1, self.dynamics.n_heads, 1))
            ])
        mblossvals = []

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.placeholder_lr: self.lr,
                    self.placeholder_cliprange: self.cliprange
                })
                if self.dynamics.dropout:
                    fd.update({self.dynamics.is_training: True})
                mblossvals.append(tf.get_default_session().run(
                    self._losses + (self._train, ), fd)[:-1])

        mblossvals = [mblossvals[0]]
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.policy.get_var_values()

    def set_var_values(self, vv):
        self.policy.set_var_values(vv)
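Each optimizer builds the same clipped PPO surrogate in TensorFlow graph code. A small NumPy sketch of the identical arithmetic, with toy inputs, to make the ratio / clipping step concrete (not part of the original code):

import numpy as np

def ppo_pg_loss(neglogp_old, neglogp_new, adv, cliprange):
    ratio = np.exp(neglogp_old - neglogp_new)          # pi_new(a|s) / pi_old(a|s)
    negadv = -adv
    pg_losses1 = negadv * ratio                        # unclipped surrogate
    pg_losses2 = negadv * np.clip(ratio, 1.0 - cliprange, 1.0 + cliprange)
    pg_loss_surr = np.maximum(pg_losses1, pg_losses2)  # pessimistic (clipped) objective
    approxkl = 0.5 * np.mean((neglogp_new - neglogp_old) ** 2)
    clipfrac = np.mean(np.abs(pg_losses2 - pg_loss_surr) > 1e-6)
    return pg_loss_surr.mean(), approxkl, clipfrac

# toy batch
old = np.array([1.2, 0.7, 2.1])
new = np.array([1.0, 0.9, 2.0])
adv = np.array([0.5, -0.3, 0.1])
print(ppo_pg_loss(old, new, adv, cliprange=0.1))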
class PpoOptimizer(object):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, stochpol, ent_coef, gamma, lam,
                 nepochs, lr, cliprange, nminibatches, normrew, normadv,
                 use_news, ext_coeff, int_coeff, nsteps_per_seg, nsegs_per_env, dynamics, nepochs_dvae):
        self.dynamics = dynamics
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space                  # Box(84,84,4)
            self.ac_space = ac_space                  # Discrete(4)
            self.stochpol = stochpol                  # cnn policy object
            self.nepochs = nepochs                    # 3
            self.lr = lr                              # 1e-4
            self.cliprange = cliprange                # 0.1
            self.nsteps_per_seg = nsteps_per_seg      # 128
            self.nsegs_per_env = nsegs_per_env        # 1
            self.nminibatches = nminibatches          # 8
            self.gamma = gamma                        # 0.99, a PPO hyperparameter
            self.lam = lam                            # 0.95, a PPO hyperparameter
            self.normrew = normrew                    # 1
            self.normadv = normadv                    # 1
            self.use_news = use_news                  # False
            self.ext_coeff = ext_coeff                # 0.0, i.e. exploration is driven purely by the intrinsic reward
            self.int_coeff = int_coeff                # 1.0
            self.ph_adv = tf.placeholder(tf.float32, [None, None])
            self.ph_ret = tf.placeholder(tf.float32, [None, None])
            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])    # stores -log pi(a|s)
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)    # -log prob of the previously chosen actions under the current policy
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            vpred = self.stochpol.vpred

            # Define the PPO losses: critic loss, actor loss, entropy loss, plus approximate KL and clip-frac
            # value function (critic) loss
            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret) ** 2)
            # policy (actor) loss
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = - self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (- ent_coef) * entropy            # entropy regularization, ent_coef=0.001
            approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - self.ph_oldnlp))   # approximate KL
            clipfrac = tf.reduce_mean(tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {'tot': self.total_loss, 'pg': pg_loss, 'vf': vf_loss, 'ent': entropy,
                              'approxkl': approxkl, 'clipfrac': clipfrac}

            # add bai.
            self.dynamics_loss = None
            self.nepochs_dvae = nepochs_dvae

    def start_interaction(self, env_fns, dynamics, nlump=2):
        # Define variables and the compute graph when starting to interact with the environment; initialize the Rollout class
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        # Define the loss, gradients and backprop; during training, sess.run(self._train) is called to iterate
        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        params_dvae = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="dvae_reward")
        print("total params:", np.sum([np.prod(v.get_shape().as_list()) for v in params]))      # 6629459
        print("dvae params:", np.sum([np.prod(v.get_shape().as_list()) for v in params_dvae]))  # 2726144
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr, comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        # add bai.  compute the DVAE gradients separately
        gradsandvars_dvae = trainer.compute_gradients(self.dynamics_loss, params_dvae)
        self._train_dvae = trainer.apply_gradients(gradsandvars_dvae)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(tf.variables_initializer(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)        # default 128
        self.nlump = nlump                       # default 1
        self.lump_stride = nenvs // self.nlump   # 128/1=128
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride: (l + 1) * self.lump_stride], spaces=[self.ob_space, self.ac_space]) for
            l in range(self.nlump)]

        # this class is defined in rollouts.py
        self.rollout = Rollout(ob_space=self.ob_space, ac_space=self.ac_space, nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env, nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        # number of environments (threads) x rollout length T
        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        # Compute returns and advantages (GAE) from the stored rewards; the loop is written in a somewhat convoluted way.
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # t = nsteps-1 ... 0, iterating backwards
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:         # normalize rewards, using moments gathered from the other MPI workers
            rffs = np.array([self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std ** 2, rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)

        # Compute the advantage function from the reward sequence rews
        self.calculate_advantages(rews=rews, use_news=self.use_news, gamma=self.gamma, lam=self.lam)

        # Record some statistics for logging
        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.rollout.buf_vpreds.mean(),
            vpredstd=self.rollout.buf_vpreds.std(),
            ev=explained_variance(self.rollout.buf_vpreds.ravel(), self.buf_rets.ravel()),
            rew_mean=np.mean(self.rollout.buf_rews),
            rew_mean_norm=np.mean(rews),
            recent_best_ext_ret=self.rollout.current_max
        )
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages by their mean and std
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env, self.nsteps_per_seg) + sh[2:])

        # Pair the placeholders defined in this class with the numpy buffers collected by the Rollout class, to be fed as the feed_dict
        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),   # the entries above are buffers recorded by the rollout while interacting with the env
            (self.ph_ret, resh(self.buf_rets)),                  # returns computed from the rollout buffers
            (self.ph_adv, resh(self.buf_advs)),                  # advantages computed from the rollout buffers
        ]
        ph_buf.extend([
            (self.dynamics.last_ob,
             self.rollout.buf_obs_last.reshape([self.nenvs * self.nsegs_per_env, 1, *self.ob_space.shape]))
        ])
        mblossvals = []          # records the losses during training

        # Train on the agent loss
        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}     # build the feed_dict
                fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})
                mblossvals.append(getsess().run(self._losses + (self._train,), fd)[:-1])    # compute the losses and apply the update

        # add bai.  additionally train the DVAE on its own
        for tmp in range(self.nepochs_dvae):
            print("extra DVAE training pass:", tmp)
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):     # loops 8 times
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}                       # build the feed_dict
                fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})
                d_loss, _ = getsess().run([self.dynamics_loss, self._train_dvae], fd)   # compute the DVAE loss and apply the update
                print(d_loss, end=", ")
            print("\n")

        mblossvals = [mblossvals[0]]
        info.update(zip(['opt_' + ln for ln in self.loss_names], np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({dn: (np.mean(dvs) if len(dvs) > 0 else 0) for (dn, dvs) in self.rollout.statlists.items()})
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info

    def step(self):
        self.rollout.collect_rollout()    # collect samples and compute intrinsic rewards
        update_info = self.update()       # update the weights
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)
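Every update() above pools the filtered rewards across workers with mpi_moments and folds the result into RunningMeanStd via update_from_moments. A sketch of single-process equivalents, assuming the baselines-style helpers (the real mpi_moments also reduces across MPI.COMM_WORLD):

import numpy as np

def local_moments(x):
    """Single-process stand-in for mpi_moments: (mean, std, count) of the flattened array."""
    x = np.asarray(x, dtype=np.float64).ravel()
    return x.mean(), x.std(), x.size

def update_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    """Parallel-variance combination (Chan et al.), as used inside RunningMeanStd."""
    delta = batch_mean - mean
    tot_count = count + batch_count
    new_mean = mean + delta * batch_count / tot_count
    m2 = var * count + batch_var * batch_count + delta ** 2 * count * batch_count / tot_count
    return new_mean, m2 / tot_count, tot_count

# rews = buf_rews / np.sqrt(running_var) is then what calculate_advantages consumes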
class RnnPpoOptimizer(SaveLoad):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, actionpol, trainpol,
                 ent_coef, gamma, lam, nepochs, lr, cliprange, nminibatches,
                 normrew, normadv, use_news, ext_coeff, int_coeff,
                 nsteps_per_seg, nsegs_per_env, action_dynamics,
                 train_dynamics, policy_mode, logdir, full_tensorboard_log,
                 tboard_period):
        self.action_dynamics = action_dynamics
        self.train_dynamics = train_dynamics
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.actionpol = actionpol
            self.trainpol = trainpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.policy_mode = policy_mode  # New
            self.full_tensorboard_log = full_tensorboard_log  # New
            self.tboard_period = tboard_period  # New
            self.ph_adv = tf.placeholder(tf.float32, [None, None],
                                         name='ph_adv')
            self.ph_ret = tf.placeholder(tf.float32, [None, None],
                                         name='ph_ret')
            self.ph_rews = tf.placeholder(tf.float32, [None, None],
                                          name='ph_rews')
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None],
                                            name='ph_oldnlp')
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None],
                                              name='ph_oldvpred')
            self.ph_lr = tf.placeholder(tf.float32, [], name='ph_lr')
            self.ph_cliprange = tf.placeholder(tf.float32, [],
                                               name='ph_cliprange')
            neglogpac = self.trainpol.pd.neglogp(self.trainpol.ph_ac)
            entropy = tf.reduce_mean(self.trainpol.pd.entropy())
            vpred = self.trainpol.vpred

            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret)**2)
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = -self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(
                ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (-ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(
                tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(
                tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {
                'tot': self.total_loss,
                'pg': pg_loss,
                'vf': vf_loss,
                'ent': entropy,
                'approxkl': approxkl,
                'clipfrac': clipfrac
            }

            self.logdir = logdir  #logger.get_dir()
            params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            if self.full_tensorboard_log:  # full Tensorboard logging
                for var in params:
                    tf.summary.histogram(var.name, var)
            if MPI.COMM_WORLD.Get_rank() == 0:
                self.summary_writer = tf.summary.FileWriter(
                    self.logdir, graph=getsess())  # New
                print("tensorboard dir : ", self.logdir)
                self.merged_summary_op = tf.summary.merge_all()  # New

    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                       comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
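        # Variables are initialized on rank 0 only; the broadcast above keeps
        # every MPI worker starting from identical parameters.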

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nminibatches=self.nminibatches,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.actionpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               train_dynamics=self.train_dynamics,
                               action_dynamics=self.action_dynamics,
                               policy_mode=self.policy_mode)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()
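            # rff (a discounted forward filter over rewards) and rff_rms feed
            # update(), which scales rewards by their running standard deviation.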

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        # GAE(lambda): delta_t = r_t + gamma * V(s_{t+1}) * (1 - new_{t+1}) - V(s_t),
        # accumulated backwards with decay gamma * lam.
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = (self.rollout.buf_news[:, t + 1]
                       if t + 1 < nsteps else self.rollout.buf_new_last)
            if not use_news:
                nextnew = 0
            nextvals = (self.rollout.buf_vpreds[:, t + 1]
                        if t + 1 < nsteps else self.rollout.buf_vpred_last)
            nextnotnew = 1 - nextnew
            delta = (rews[:, t] + gamma * nextvals * nextnotnew
                     - self.rollout.buf_vpreds[:, t])
            self.buf_advs[:, t] = lastgaelam = (
                delta + gamma * lam * nextnotnew * lastgaelam)
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(advmean=self.buf_advs.mean(),
                    advstd=self.buf_advs.std(),
                    retmean=self.buf_rets.mean(),
                    retstd=self.buf_rets.std(),
                    vpredmean=self.rollout.buf_vpreds.mean(),
                    vpredstd=self.rollout.buf_vpreds.std(),
                    ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                          self.buf_rets.ravel()),
                    rew_mean=np.mean(self.rollout.buf_rews),
                    recent_best_ext_ret=self.rollout.current_max)
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.trainpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.trainpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
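        # (placeholder, rollout buffer) pairs; each minibatch below slices these
        # along the environment axis to build the feed dict.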
        ph_buf.extend([(self.train_dynamics.last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        ph_buf.extend([
            (self.trainpol.states_ph,
             resh(self.rollout.buf_states_first)),  # rnn inputs
            (self.trainpol.masks_ph, resh(self.rollout.buf_news))
        ])
        if 'err' in self.policy_mode:
            ph_buf.extend([(self.trainpol.pred_error,
                            resh(self.rollout.buf_errs))])  # New
        if 'ac' in self.policy_mode:
            ph_buf.extend([(self.trainpol.ph_ac, resh(self.rollout.buf_acs)),
                           (self.trainpol.ph_ac_first,
                            resh(self.rollout.buf_acs_first))])
        if 'pred' in self.policy_mode:
            ph_buf.extend([(self.trainpol.obs_pred,
                            resh(self.rollout.buf_obpreds))])

        # with open(os.getcwd() + "/record_instruction.txt", 'r') as rec_inst:
        #     rec_n = []
        #     rec_all_n = []
        #     while True:
        #         line = rec_inst.readline()
        #         if not line: break
        #         args = line.split()
        #         rec_n.append(int(args[0]))
        #         if len(args) > 1:
        #             rec_all_n.append(int(args[0]))
        #     if self.n_updates in rec_n and MPI.COMM_WORLD.Get_rank() == 0:
        #         print("Enter!")
        #         with open(self.logdir + '/full_log' + str(self.n_updates) + '.pk', 'wb') as full_log:
        #             import pickle
        #             debug_data = {"buf_obs" : self.rollout.buf_obs,
        #                           "buf_obs_last" : self.rollout.buf_obs_last,
        #                           "buf_acs" : self.rollout.buf_acs,
        #                           "buf_acs_first" : self.rollout.buf_acs_first,
        #                           "buf_news" : self.rollout.buf_news,
        #                           "buf_news_last" : self.rollout.buf_new_last,
        #                           "buf_rews" : self.rollout.buf_rews,
        #                           "buf_ext_rews" : self.rollout.buf_ext_rews}
        #             if self.n_updates in rec_all_n:
        #                 debug_data.update({"buf_err": self.rollout.buf_errs,
        #                                     "buf_err_last": self.rollout.buf_errs_last,
        #                                     "buf_obpreds": self.rollout.buf_obpreds,
        #                                     "buf_obpreds_last": self.rollout.buf_obpreds_last,
        #                                     "buf_vpreds": self.rollout.buf_vpreds,
        #                                     "buf_vpred_last": self.rollout.buf_vpred_last,
        #                                     "buf_states": self.rollout.buf_states,
        #                                     "buf_states_first": self.rollout.buf_states_first,
        #                                     "buf_nlps": self.rollout.buf_nlps,})
        #             pickle.dump(debug_data, full_log)

        mblossvals = []
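        # PPO optimization: nepochs passes over shuffled env indices, one
        # minibatch of envsperbatch environments at a time. The session run
        # returns the losses plus the train op; [:-1] drops the train op's output.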

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.ph_lr: self.lr,
                    self.ph_cliprange: self.cliprange
                })
                mblossvals.append(getsess().run(self._losses + (self._train, ),
                                                fd)[:-1])

        mblossvals = [mblossvals[0]]  # report losses from the first minibatch only
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = (MPI.COMM_WORLD.Get_size() * self.rollout.nsteps *
                       self.nenvs / (tnow - self.t_last_update))
        self.t_last_update = tnow

        # New
        if 'err' in self.policy_mode:
            info["error"] = np.sqrt(np.power(self.rollout.buf_errs, 2).mean())

        if (self.n_updates % self.tboard_period == 0
                and MPI.COMM_WORLD.Get_rank() == 0):
            if self.full_tensorboard_log:
                summary = getsess().run(self.merged_summary_op, fd)  # New
                self.summary_writer.add_summary(
                    summary, self.rollout.stats["tcount"])  # New
            for k, v in info.items():
                summary = tf.Summary(value=[
                    tf.Summary.Value(tag=k, simple_value=v),
                ])
                self.summary_writer.add_summary(summary,
                                                self.rollout.stats["tcount"])

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.trainpol.get_var_values()

    def set_var_values(self, vv):
        self.trainpol.set_var_values(vv)
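
The optimizers in these examples rely on RewardForwardFilter and RunningMeanStd for reward normalization, but neither helper is shown. Below is a minimal NumPy sketch of how such helpers are commonly implemented, in the style of OpenAI's RND / large-scale-curiosity baselines; the class and method names are taken from the calls above, everything else is an assumption.

import numpy as np


class RewardForwardFilter(object):
    """Keeps a discounted running sum of rewards, one entry per environment."""

    def __init__(self, gamma):
        self.gamma = gamma
        self.rewems = None  # running discounted reward sum

    def update(self, rews):
        if self.rewems is None:
            self.rewems = rews
        else:
            self.rewems = self.rewems * self.gamma + rews
        return self.rewems


class RunningMeanStd(object):
    """Tracks a running mean and variance with the parallel-update formula."""

    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        new_mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
        self.mean = new_mean
        self.var = m_2 / tot_count
        self.count = tot_count

With helpers like these, rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var) in update() divides each reward by the running standard deviation of its discounted sum, keeping the intrinsic reward scale stable over training.
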
Example #7
0
    def train(self, T_max, graph_name=None):
        step = 0
        self.num_lookahead = 5

        self.reset_workers()
        self.wait_for_workers()

        stat = {
            'ploss': [],
            'vloss': [],
            'score': [],
            'int_reward': [],
            'entropy': [],
            'fwd_kl_div': [],
            'running_loss': 0
        }

        reward_tracker = RunningMeanStd()
        reward_buffer = np.empty((self.batch_size, self.num_lookahead),
                                 dtype=np.float32)
        while step < T_max:

            # these will keep tensors, which we'll use later for backpropagation
            values = []
            log_probs = []
            rewards = []
            entropies = []

            actions = []
            actions_pred = []
            features = []
            features_pred = []

            state = torch.from_numpy(self.sh_state).to(self.device)

            for i in range(self.num_lookahead):
                step += self.batch_size

                logit, value = self.model(state)
                prob = torch.softmax(logit, dim=1)
                log_prob = torch.log_softmax(logit, dim=1)
                entropy = -(prob * log_prob).sum(1, keepdim=True)

                action = prob.multinomial(1)
                sampled_lp = log_prob.gather(1, action)

                # one-hot action
                oh_action = torch.zeros(self.batch_size,
                                        self.num_actions,
                                        device=self.device).scatter_(
                                            1, action, 1)

                self.broadcast_actions(action)
                self.wait_for_workers()

                next_state = torch.from_numpy(self.sh_state).to(self.device)
                s1, s1_pred, action_pred = self.icm(state, oh_action,
                                                    next_state)

                with torch.no_grad():
                    int_reward = 0.5 * (s1 - s1_pred).pow(2).sum(dim=1,
                                                                 keepdim=True)
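                # Curiosity bonus: the forward model's prediction error on the
                # next-state features is used as the intrinsic reward.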
                reward_buffer[:, i] = int_reward.cpu().numpy().ravel()

                state = next_state

                # save variables for gradient descent
                values.append(value)
                log_probs.append(sampled_lp)
                rewards.append(int_reward)
                entropies.append(entropy)

                if not self.random:
                    actions.append(action.flatten())
                    actions_pred.append(action_pred)
                features.append(s1)
                features_pred.append(s1_pred)

                stat['entropy'].append(entropy.sum(dim=1).mean().item())
                stat['fwd_kl_div'].append(
                    torch.kl_div(s1_pred, s1).mean().item())

            # may have to update reward_buffer with gamma first
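            # (the moments below feed a RunningMeanStd; each reward is then
            # divided by the running standard deviation of the reward stream)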
            reward_mean, reward_std, count = mpi_moments(reward_buffer.ravel())
            reward_tracker.update_from_moments(reward_mean, reward_std**2,
                                               count)
            std = np.sqrt(reward_tracker.var)
            rewards = [rwd / std for rwd in rewards]
            for rwd in rewards:
                stat['int_reward'].append(rwd.mean().item())

            state = torch.from_numpy(self.sh_state.astype(np.float32)).to(
                self.device)
            with torch.no_grad():
                _, R = self.model(state)  # R is the estimated return

            values.append(R)
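            # Walk the lookahead window backwards: R accumulates the discounted,
            # bootstrapped return for the critic loss, while delta is a one-step
            # advantage estimate for the policy gradient.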

            ploss = 0
            vloss = 0
            fwd_loss = 0
            inv_loss = 0

            delta = torch.zeros((self.batch_size, 1),
                                dtype=torch.float,
                                device=self.device)
            for i in reversed(range(self.num_lookahead)):
                R = rewards[i] + self.gamma * R
                advantage = R - values[i]
                vloss += (0.5 * advantage.pow(2)).mean()

                delta = rewards[i] + self.gamma * values[
                    i + 1].detach() - values[i].detach()
                ploss += -(log_probs[i] * delta +
                           0.01 * entropies[i]).mean()  # beta = 0.01

                fwd_loss += 0.5 * (features[i] -
                                   features_pred[i]).pow(2).sum(dim=1).mean()
                if not self.random:
                    inv_loss += self.cross_entropy(actions_pred[i], actions[i])

            self.optim.zero_grad()

            # inv_loss is 0 if using random features
            loss = ploss + vloss + fwd_loss + inv_loss  # 2018 Large scale curiosity paper simply sums them (no lambda and beta anymore)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(
                list(self.model.parameters()) + list(self.icm.parameters()),
                40)
            self.optim.step()

            while not self.channel.empty():
                score = self.channel.get()
                stat['score'].append(score)

            stat['ploss'].append(ploss.item() / self.num_lookahead)
            stat['vloss'].append(vloss.item() / self.num_lookahead)
            stat['running_loss'] = (0.99 * stat['running_loss'] +
                                    0.01 * loss.item() / self.num_lookahead)

            if len(stat['score']) > 20 and step % (self.batch_size *
                                                   1000) == 0:
                now = datetime.datetime.now().strftime("%H:%M")
                print(
                    f"Step {step: <10} | Running loss: {stat['running_loss']:.4f} | Running score: {np.mean(stat['score'][-10:]):.2f} | Time: {now}"
                )
                if graph_name is not None and step % (self.batch_size *
                                                      10000) == 0:
                    plot(step,
                         stat['score'],
                         stat['int_reward'],
                         stat['ploss'],
                         stat['vloss'],
                         stat['entropy'],
                         name=graph_name)
Example #8
0
class PpoOptimizer(object):
    envs = None

    def __init__(self,
                 *,
                 scope,
                 ob_space,
                 ac_space,
                 stochpol,
                 ent_coef,
                 gamma,
                 lam,
                 nepochs,
                 lr,
                 cliprange,
                 nminibatches,
                 normrew,
                 normadv,
                 use_news,
                 ext_coeff,
                 int_coeff,
                 nsteps_per_seg,
                 nsegs_per_env,
                 dynamics,
                 load=False,
                 exp_name):
        self.dynamics = dynamics
        self.load = load
        self.model_path = os.path.join('models/', exp_name)
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)
        self._save_freq = 50000
        self._next_save = self._save_freq
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.stochpol = stochpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.ph_adv = tf.placeholder(tf.float32, [None, None])
            self.ph_ret = tf.placeholder(tf.float32, [None, None])
            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            vpred = self.stochpol.vpred

            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret)**2)
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = -self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(
                ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (-ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(
                tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(
                tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {
                'tot': self.total_loss,
                'pg': pg_loss,
                'vf': vf_loss,
                'ent': entropy,
                'approxkl': approxkl,
                'clipfrac': clipfrac
            }
            self.keep_checkpoints = 5

    def _initialize_graph(self):
        # with self.graph.as_default():
        #     self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
        #     init = tf.global_variables_initializer()
        #     self.sess.run(init)
        self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
        getsess().run(
            tf.variables_initializer(
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))

    def _load_graph(self):
        self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
        # logger.info('Loading Model for brain {}'.format(self.brain.brain_name))
        ckpt = tf.train.get_checkpoint_state(self.model_path)
        if ckpt is None:
            logger.info('The model {0} could not be found. Make '
                        'sure you specified the right '
                        '--run-id'.format(self.model_path))
            raise FileNotFoundError(
                'No checkpoint found in {}'.format(self.model_path))
        self.saver.restore(getsess(), ckpt.model_checkpoint_path)
        # Extract the step count from the checkpoint filename
        last_step = int(
            os.path.basename(
                ckpt.model_checkpoint_path).split('-')[1].split('.')[0])
        return last_step

    def save_if_ready(self):
        if MPI.COMM_WORLD.Get_rank() != 0:
            return
        steps = self.rollout.stats['tcount']
        if steps >= self._next_save:
            self.save_model(steps)
            self._next_save = steps + self._save_freq

    def save_model(self, steps):
        """
        Saves the model
        :param steps: The number of steps the model was trained for
        :return:
        """
        print("------ ** saving at step:", steps)
        # with self.graph.as_default():
        last_checkpoint = self.model_path + '/model-' + str(steps) + '.cptk'
        self.saver.save(getsess(), last_checkpoint)
        # Write the session's graph (not a fresh, empty tf.Graph()).
        tf.train.write_graph(getsess().graph,
                             self.model_path,
                             'raw_graph_def.pb',
                             as_text=False)

    def no_mpi_start_interaction(self, envs, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)
        last_step = 0
        if self.load:
            last_step = self._load_graph()
        else:
            self._initialize_graph()
            # bcast_tf_vars_from_root(getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(envs)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = envs

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)
        self.rollout.stats['tcount'] += last_step

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                       comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        last_step = 0
        if self.load:
            last_step = self._load_graph()
        else:
            if MPI.COMM_WORLD.Get_rank() == 0:
                self._initialize_graph()
            bcast_tf_vars_from_root(
                getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)
        self.rollout.stats['tcount'] += last_step

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = (self.rollout.buf_news[:, t + 1]
                       if t + 1 < nsteps else self.rollout.buf_new_last)
            if not use_news:
                nextnew = 0
            nextvals = (self.rollout.buf_vpreds[:, t + 1]
                        if t + 1 < nsteps else self.rollout.buf_vpred_last)
            nextnotnew = 1 - nextnew
            delta = (rews[:, t] + gamma * nextvals * nextnotnew
                     - self.rollout.buf_vpreds[:, t])
            self.buf_advs[:, t] = lastgaelam = (
                delta + gamma * lam * nextnotnew * lastgaelam)
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(advmean=self.buf_advs.mean(),
                    advstd=self.buf_advs.std(),
                    retmean=self.buf_rets.mean(),
                    retstd=self.buf_rets.std(),
                    vpredmean=self.rollout.buf_vpreds.mean(),
                    vpredstd=self.rollout.buf_vpreds.std(),
                    ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                          self.buf_rets.ravel()),
                    rew_mean=np.mean(self.rollout.buf_rews),
                    recent_best_ext_ret=self.rollout.current_max)
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([(self.dynamics.last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        mblossvals = []

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.ph_lr: self.lr,
                    self.ph_cliprange: self.cliprange
                })
                mblossvals.append(getsess().run(self._losses + (self._train, ),
                                                fd)[:-1])

        mblossvals = [mblossvals[0]]
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = (MPI.COMM_WORLD.Get_size() * self.rollout.nsteps *
                       self.nenvs / (tnow - self.t_last_update))
        self.t_last_update = tnow

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        self.save_if_ready()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)
Example #9
0
class PpoOptimizer(object):
    envs = None

    def __init__(self,
                 *,
                 hps,
                 scope,
                 ob_space,
                 env_ob_space,
                 ac_space,
                 stochpol,
                 ent_coef,
                 gamma,
                 lam,
                 nepochs,
                 lr,
                 cliprange,
                 nminibatches,
                 normrew,
                 normadv,
                 use_news,
                 ext_coeff,
                 int_coeff,
                 nsteps_per_seg,
                 nsegs_per_env,
                 dynamics,
                 exp_name,
                 env_name,
                 video_log_freq,
                 model_save_freq,
                 use_apples,
                 agent_num=None,
                 restore_name=None,
                 multi_envs=None,
                 lstm=False,
                 lstm1_size=512,
                 lstm2_size=0,
                 depth_pred=0,
                 beta_d=.1,
                 early_stop=0,
                 aux_input=0,
                 optim='adam',
                 decay=0,
                 grad_clip=0.0,
                 log_grads=0,
                 logdir='logs'):
        self.dynamics = dynamics
        self.exp_name = exp_name
        self.env_name = env_name
        self.video_log_freq = video_log_freq
        self.model_save_freq = model_save_freq
        self.use_apples = use_apples
        self.agent_num = agent_num
        self.multi_envs = multi_envs
        self.lstm = lstm
        self.lstm1_size = lstm1_size
        self.lstm2_size = lstm2_size
        self.depth_pred = depth_pred
        self.aux_input = aux_input
        self.early_stop = early_stop
        self.optim = optim
        self.decay = decay
        self.log_grads = log_grads
        self.grad_clip = grad_clip
        if log_grads:
            self.grad_writer = tf.summary.FileWriter(logdir + '/grads/' +
                                                     exp_name)
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.env_ob_space = env_ob_space
            self.ac_space = ac_space
            self.stochpol = stochpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.ent_coeff = ent_coef
            self.beta_d = beta_d

            def mask(target, mask):
                # Straight-through mask: for a 0/1 mask the forward value is
                # always target, but gradients only flow where mask == 1.
                mask_h = tf.abs(mask - 1)
                return tf.stop_gradient(mask_h * target) + mask * target

            if self.agent_num is None:
                self.ph_adv = tf.placeholder(tf.float32, [None, None],
                                             name='adv')
                self.ph_ret = tf.placeholder(tf.float32, [None, None],
                                             name='ret')
                self.ph_rews = tf.placeholder(tf.float32, [None, None],
                                              name='rews')
                self.ph_oldnlp = tf.placeholder(tf.float32, [None, None],
                                                name='oldnlp')
                self.ph_oldvpred = tf.placeholder(tf.float32, [None, None],
                                                  name='oldvpred')
                self.ph_lr = tf.placeholder(tf.float32, [], name='lr')
                self.ph_cliprange = tf.placeholder(tf.float32, [],
                                                   name='cliprange')
                self.ph_gradmask = tf.placeholder(tf.float32, [None, None],
                                                  name='gradmask')
                neglogpac = mask(self.stochpol.pd.neglogp(self.stochpol.ph_ac),
                                 self.ph_gradmask)
                entropy = tf.reduce_mean(self.stochpol.pd.entropy(),
                                         name='agent_entropy')
                vpred = mask(self.stochpol.vpred, self.ph_gradmask)
                vf_loss = 0.5 * tf.reduce_mean(
                    (vpred - mask(self.ph_ret, self.ph_gradmask))**2,
                    name='vf_loss')
                ratio = tf.exp(self.ph_oldnlp - neglogpac,
                               name='ratio')  # p_new / p_old
                negadv = -mask(self.ph_adv, self.ph_gradmask)
                pg_losses1 = negadv * ratio
                pg_losses2 = negadv * tf.clip_by_value(ratio,
                                                       1.0 - self.ph_cliprange,
                                                       1.0 + self.ph_cliprange,
                                                       name='pglosses2')
                pg_loss_surr = tf.maximum(pg_losses1,
                                          pg_losses2,
                                          name='loss_surr')
                pg_loss = tf.reduce_mean(pg_loss_surr, name='pg_loss')
                ent_loss = (-ent_coef) * entropy
                if self.depth_pred:
                    depth_loss = self.stochpol.depth_loss * beta_d
                approxkl = .5 * tf.reduce_mean(
                    tf.square(neglogpac - self.ph_oldnlp), name='approxkl')
                clipfrac = tf.reduce_mean(
                    tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6),
                    name='clipfrac')

                self.total_loss = pg_loss + ent_loss + vf_loss
                if self.depth_pred:
                    self.total_loss = self.total_loss + depth_loss
                    #self.total_loss = depth_loss
                    #print("adding depth loss to total loss for optimization")
                #self.total_loss = depth_loss
                self.to_report = {
                    'tot': self.total_loss,
                    'pg': pg_loss,
                    'vf': vf_loss,
                    'ent': entropy,
                    'approxkl': approxkl,
                    'clipfrac': clipfrac
                }
                if self.depth_pred:
                    self.to_report.update({'depth_loss': depth_loss})
                tf.add_to_collection('adv', self.ph_adv)
                tf.add_to_collection('ret', self.ph_ret)
                tf.add_to_collection('rews', self.ph_rews)
                tf.add_to_collection('oldnlp', self.ph_oldnlp)
                tf.add_to_collection('oldvpred', self.ph_oldvpred)
                tf.add_to_collection('lr', self.ph_lr)
                tf.add_to_collection('cliprange', self.ph_cliprange)
                tf.add_to_collection('agent_entropy', entropy)
                tf.add_to_collection('vf_loss', vf_loss)
                tf.add_to_collection('ratio', ratio)
                tf.add_to_collection('pg_losses2', pg_losses2)
                tf.add_to_collection('loss_surr', pg_loss_surr)
                tf.add_to_collection('pg_loss', pg_loss)
                if self.depth_pred:
                    tf.add_to_collection('depth_loss', depth_loss)
                tf.add_to_collection('approxkl', approxkl)
                tf.add_to_collection('clipfrac', clipfrac)
            else:
                self.restore()

    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))
        self.global_step = tf.Variable(0, trainable=False)
        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            if self.agent_num is None:
                trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                           comm=MPI.COMM_WORLD)

        else:
            if self.agent_num is None:
                if self.optim == 'adam':
                    trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
                elif self.optim == 'sgd':
                    print("using sgd")
                    print("________________________")
                    if self.decay:
                        self.decay_lr = tf.train.exponential_decay(
                            self.ph_lr,
                            self.global_step,
                            2500,
                            .96,
                            staircase=True)
                        trainer = tf.train.GradientDescentOptimizer(
                            learning_rate=self.decay_lr)
                    else:
                        trainer = tf.train.GradientDescentOptimizer(
                            learning_rate=self.ph_lr)
                elif self.optim == 'momentum':
                    print('using momentum')
                    print('________________________')
                    trainer = tf.train.MomentumOptimizer(
                        learning_rate=self.ph_lr, momentum=0.9)
        if self.agent_num is None:
            gradsandvars = trainer.compute_gradients(self.total_loss, params)
            l2_norm = lambda t: tf.sqrt(tf.reduce_sum(tf.pow(t, 2)))
            if self.log_grads:
                for grad, var in gradsandvars:
                    tf.summary.histogram(var.name + '/gradient', l2_norm(grad))
                    tf.summary.histogram(var.name + '/value', l2_norm(var))
                    grad_mean = tf.reduce_mean(tf.abs(grad))
                    tf.summary.scalar(var.name + '/grad_mean', grad_mean)
                if self.decay:
                    tf.summary.scalar('decay_lr', self.decay_lr)
                self._summary = tf.summary.merge_all()
                tf.add_to_collection("summary_op", self._summary)
            if self.grad_clip > 0:
                grads, gradvars = zip(*gradsandvars)
                grads, _ = tf.clip_by_global_norm(grads, self.grad_clip)
                gradsandvars = list(zip(grads, gradvars))

            self._train = trainer.apply_gradients(gradsandvars,
                                                  global_step=self.global_step)
            self._updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            self._train = tf.group(self._train, self._updates)
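            # Group any UPDATE_OPS (e.g. batch-norm moving averages) with the
            # train op so they run on every optimization step.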
            tf.add_to_collection("train_op", self._train)
        else:
            self._train = tf.get_collection("train_op")[0]
            if self.log_grads:
                self._summary = tf.get_collection("summary_op")[0]

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.env_ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics,
                               exp_name=self.exp_name,
                               env_name=self.env_name,
                               video_log_freq=self.video_log_freq,
                               model_save_freq=self.model_save_freq,
                               use_apples=self.use_apples,
                               multi_envs=self.multi_envs,
                               lstm=self.lstm,
                               lstm1_size=self.lstm1_size,
                               lstm2_size=self.lstm2_size,
                               depth_pred=self.depth_pred,
                               early_stop=self.early_stop,
                               aux_input=self.aux_input)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = (self.rollout.buf_news[:, t + 1]
                       if t + 1 < nsteps else self.rollout.buf_new_last)
            if not use_news:
                nextnew = 0
            nextvals = (self.rollout.buf_vpreds[:, t + 1]
                        if t + 1 < nsteps else self.rollout.buf_vpred_last)
            nextnotnew = 1 - nextnew
            delta = (rews[:, t] + gamma * nextvals * nextnotnew
                     - self.rollout.buf_vpreds[:, t])
            self.buf_advs[:, t] = lastgaelam = (
                delta + gamma * lam * nextnotnew * lastgaelam)
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(advmean=self.buf_advs.mean(),
                    advstd=self.buf_advs.std(),
                    retmean=self.buf_rets.mean(),
                    retstd=self.buf_rets.std(),
                    vpredmean=self.rollout.buf_vpreds.mean(),
                    vpredstd=self.rollout.buf_vpreds.std(),
                    ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                          self.buf_rets.ravel()),
                    rew_mean=np.mean(self.rollout.buf_rews),
                    recent_best_ext_ret=self.rollout.current_max)
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def mask(x, grad_mask):
            # Zero out entries whose gradient mask is 0 (only relevant when
            # early stopping is enabled).
            if not self.early_stop:
                return x
            sh = np.shape(x)
            if sh[1] < np.shape(grad_mask)[1]:
                return x
            # Broadcast the (nenvs, nsteps) mask over any trailing dimensions.
            for _ in range(len(sh) - 2):
                grad_mask = np.expand_dims(grad_mask, -1)
            return np.multiply(grad_mask, x)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        new_count = np.count_nonzero(self.rollout.buf_news)
        print(self.rollout.buf_news)
        if self.early_stop:
            print(self.rollout.grad_mask)
        print(new_count)
        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        if self.depth_pred:
            ph_buf.extend([
                (self.stochpol.ph_depths, resh(self.rollout.buf_depths)),
            ])
        if self.aux_input:
            ph_buf.extend([
                (self.stochpol.ph_vel, resh(self.rollout.buf_vels)),
                (self.stochpol.ph_prev_rew,
                 resh(self.rollout.buf_prev_ext_rews)),
                (self.stochpol.ph_prev_ac, resh(self.rollout.buf_prev_acs)),
            ])
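        # If the auxiliary task shares its feature extractor with the policy,
        # also feed the features cached during the rollout.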
        if self.dynamics.auxiliary_task.features_shared_with_policy:
            ph_buf.extend([
                (self.dynamics.auxiliary_task.ph_features,
                 resh(self.rollout.buf_feats)),
                (self.dynamics.auxiliary_task.ph_last_features,
                 resh(np.expand_dims(self.rollout.buf_feat_last, axis=1))),
            ])
        #print("Buff obs shape: {}".format(self.rollout.buf_obs.shape))
        #print("Buff rew shape: {}".format(self.rollout.buf_rews.shape))
        #print("Buff nlps shape: {}".format(self.rollout.buf_nlps.shape))
        #print("Buff vpreds shape: {}".format(self.rollout.buf_vpreds.shape))
        ph_buf.extend([(self.dynamics.last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        mblossvals = []
        #if self.lstm:
        #print("Train lstm 1 state: {}, {}".format(self.rollout.train_lstm1_c, self.rollout.train_lstm1_h))
        #if self.lstm2_size:
        #print("Train lstm2 state: {}, {}".format(self.rollout.train_lstm2_c, self.rollout.train_lstm2_h))
        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                #mbenvinds = tf.convert_to_tensor(mbenvinds)
                #fd = {ph: buf[mbenvinds] if type(buf) is np.ndarray else buf.eval()[mbenvinds] for (ph, buf) in ph_buf}
                if self.early_stop:
                    grad_mask = self.rollout.grad_mask[mbenvinds]
                    fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                    fd.update({self.ph_gradmask: grad_mask})
                else:
                    fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.ph_lr: self.lr,
                    self.ph_cliprange: self.cliprange
                })
                if self.lstm:
                    fd.update({
                        self.stochpol.c_in_1:
                        self.rollout.train_lstm1_c[mbenvinds, :],
                        self.stochpol.h_in_1:
                        self.rollout.train_lstm1_h[mbenvinds, :]
                    })
                if self.lstm and self.lstm2_size:
                    fd.update({
                        self.stochpol.c_in_2:
                        self.rollout.train_lstm2_c[mbenvinds, :],
                        self.stochpol.h_in_2:
                        self.rollout.train_lstm2_h[mbenvinds, :]
                    })
                if self.log_grads:
                    outs = getsess().run(
                        self._losses + (self._train, self._summary), fd)
                    losses = outs[:-2]
                    summary = outs[-1]
                    mblossvals.append(losses)
                    wandb.tensorflow.log(tf.summary.merge_all())
                    self.grad_writer.add_summary(
                        summary,
                        getsess().run(self.global_step))
                else:
                    mblossvals.append(getsess().run(
                        self._losses + (self._train, ), fd)[:-1])
        mblossvals = [mblossvals[0]]
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = (MPI.COMM_WORLD.Get_size() * self.rollout.nsteps *
                       self.nenvs / (tnow - self.t_last_update))
        self.t_last_update = tnow

        return info

    def step(self):
        #print("Collecting rollout")
        self.rollout.collect_rollout()
        #print("Performing update")
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)

    def restore(self):
        # self.stochpol.vpred = tf.get_collection("vpred")[0]
        # self.stochpol.a_samp = tf.get_collection("a_samp")[0]
        # self.stochpol.entropy = tf.get_collection("entropy")[0]
        # self.stochpol.nlp_samp = tf.get_collection("nlp_samp")[0]
        # self.stochpol.ph_ob = tf.get_collection("ph_ob")[0]
        self.ph_adv = tf.get_collection("adv")[0]
        self.ph_ret = tf.get_collection("ret")[0]
        self.ph_rews = tf.get_collection("rews")[0]
        self.ph_oldnlp = tf.get_collection("oldnlp")[0]
        self.ph_oldvpred = tf.get_collection("oldvpred")[0]
        self.ph_lr = tf.get_collection("lr")[0]
        self.ph_cliprange = tf.get_collection("cliprange")[0]
        neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
        entropy = tf.get_collection("agent_entropy")[0]
        vpred = self.stochpol.vpred
        vf_loss = tf.get_collection("vf_loss")[0]
        ratio = tf.get_collection("ratio")[0]
        negadv = -self.ph_adv
        pg_losses1 = negadv * ratio
        pg_losses2 = tf.get_collection("pg_losses2")[0]
        pg_loss_surr = tf.get_collection("loss_surr")[0]
        pg_loss = tf.get_collection("pg_loss")[0]
        ent_loss = (-self.ent_coeff) * entropy
        approxkl = tf.get_collection("approxkl")[0]
        clipfrac = tf.get_collection("clipfrac")[0]

        self.total_loss = pg_loss + ent_loss + vf_loss
        self.to_report = {
            'tot': self.total_loss,
            'pg': pg_loss,
            'vf': vf_loss,
            'ent': entropy,
            'approxkl': approxkl,
            'clipfrac': clipfrac
        }
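
The restore() method above looks tensors up by key with tf.get_collection, which only works if they were previously registered in the graph's collections (explicitly via tf.add_to_collection, or implicitly by importing a saved MetaGraph). A minimal sketch of that round-trip, assuming the TF1 graph/session API used throughout these snippets:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    ph_adv = tf.placeholder(tf.float32, [None, None], name='ph_adv')
    tf.add_to_collection('adv', ph_adv)           # done wherever the graph is built/saved
    restored_adv = tf.get_collection('adv')[0]    # what restore() does for each key
    assert restored_adv is ph_adv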
Example #10
class PpoOptimizer(object):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, stochpol,
                 ent_coef, gamma, lam, nepochs, lr, cliprange,
                 nminibatches,
                 normrew, normadv, use_news, ext_coeff, int_coeff,
                 nsteps_per_seg, nsegs_per_env, dynamics, flow_lr=None, update_periods=None):

        self.dynamics = dynamics
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.stochpol = stochpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.ph_adv = tf.placeholder(tf.float32, [None, None])
            self.ph_ret = tf.placeholder(tf.float32, [None, None])
            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            # flow-based intrinsic module hyperparameters
            self.flow_lr = flow_lr
            self.update_periods = update_periods
            self.flow_lr_scale = self.flow_lr / self.lr if self.flow_lr is not None else None

            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            vpred = self.stochpol.vpred

            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret) ** 2)
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = - self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (- ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))
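            # PPO clipped surrogate: ratio = pi_new(a|s) / pi_old(a|s) and the policy loss
            # is E[max(-A * ratio, -A * clip(ratio, 1 - eps, 1 + eps))]; approxkl is the
            # 0.5 * E[(logp_old - logp_new)^2] estimate of the KL to the old policy, and
            # clipfrac tracks how often the clipped term differs from that maximum.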

            print('===========================')
            print('Initializing PpoOptimizer...')
            print('flow lr', self.flow_lr)

            ### Flow-based intrinsic module part ###
            if update_periods is not None:
                print('Update periods in PpoOptimizer: ', update_periods)

                self.target_flow = int(update_periods.split(':')[0])
                self.target_all = int(update_periods.split(':')[1])
                self.period = self.target_flow + self.target_all
                self.target = self.target_flow
                self.update_flow = True
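                # 'update_periods' is parsed as "<target_flow>:<target_all>"; update()
                # alternates between the joint train op (flow + agent) and the agent-only
                # op on a schedule derived from these two counts, starting with the joint one.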
                
                print('update flow:   ', self.update_flow)
                print('target:        ', self.target)
            
            print('===========================')

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {'tot': self.total_loss, 'pg': pg_loss, 'vf': vf_loss, 'ent': entropy,
                              'approxkl': approxkl, 'clipfrac': clipfrac}

    def calculate_number_parameters(self, params):
        total_parameters = 0
        for variable in params:
            # shape is an array of tf.Dimension
            shape = variable.get_shape()
            variable_parameters = 1
            for dim in shape:
                variable_parameters *= dim.value
            total_parameters += variable_parameters

        print('Number of total parameters: ', total_parameters)


    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))
        
        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        self.calculate_number_parameters(params)

        flow_params = [v for v in params if 'flow' in v.name]
        other_params = [v for v in params if 'flow' not in v.name]

        print('length of flow params: ', len(flow_params))
        print('length of agent params: ', len(other_params))
        
        trainer_flow = tf.train.AdamOptimizer(learning_rate=self.flow_lr)
        trainer_agent = tf.train.AdamOptimizer(learning_rate=self.ph_lr)

        grads = tf.gradients(self.total_loss, flow_params + other_params)
        grads_flow = grads[:len(flow_params)]
        grads_agent = grads[len(flow_params):]

        train_flow = trainer_flow.apply_gradients(zip(grads_flow, flow_params))
        train_agent = trainer_agent.apply_gradients(zip(grads_agent, other_params))

        # keep handles on both ops: update() runs the joint op while the flow is being
        # trained and falls back to the agent-only op otherwise
        self._train = tf.group(train_flow, train_agent)
        self._train_agent = train_agent

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(tf.variables_initializer(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride: (l + 1) * self.lump_stride], spaces=[self.ob_space, self.ac_space]) for
            l in range(self.nlump)]

        self.rollout = Rollout(ob_space=self.ob_space, ac_space=self.ac_space, nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env, nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
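        # GAE: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t),
        # A_t = delta_t + gamma * lam * (1 - done_{t+1}) * A_{t+1}; returns = A + V.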
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array([self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std ** 2, rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
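            # rewards are divided by the running std of their discounted running sum,
            # which keeps their scale roughly constant over training; the mean is
            # not subtracted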
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews, use_news=self.use_news, gamma=self.gamma, lam=self.lam)

        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.rollout.buf_vpreds.mean(),
            vpredstd=self.rollout.buf_vpreds.std(),
            ev=explained_variance(self.rollout.buf_vpreds.ravel(), self.buf_rets.ravel()),
            rew_mean=np.mean(self.rollout.buf_rews),
            recent_best_ext_ret=self.rollout.current_max
        )
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env, self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([
            (self.dynamics.last_ob,
             self.rollout.buf_obs_last.reshape([self.nenvs * self.nsegs_per_env, 1, *self.ob_space.shape]))
        ])
        mblossvals = []

        if self.update_periods is not None:
            for _ in range(self.nepochs):
                np.random.shuffle(envinds)
                for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                    fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})

                    if self.update_flow: 
                        mblossvals.append(getsess().run(self._losses + (self._train,), fd)[:-1])
                    else:
                        mblossvals.append(getsess().run(self._losses + (self._train_agent,), fd)[:-1])
        else:
            for _ in range(self.nepochs):
                np.random.shuffle(envinds)
                for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                    fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})

                    mblossvals.append(getsess().run(self._losses + (self._train,), fd)[:-1])


        mblossvals = [mblossvals[0]]
        info.update(zip(['opt_' + ln for ln in self.loss_names], np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({dn: (np.mean(dvs) if len(dvs) > 0 else 0) for (dn, dvs) in self.rollout.statlists.items()})
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        if self.update_periods is not None:
            if self.n_updates % self.period == self.target:
                self.update_flow = not self.update_flow
                
                if self.update_flow:
                    self.target = self.target_flow
                else:
                    self.target = self.target_all

                print('--------- Reach target ---------')
                print('Update flow:  ', self.update_flow)
                print('New target:   ', self.target)

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)
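
A minimal, self-contained NumPy sketch of the reward normalization used when normrew is set in the optimizers here: a discounted running sum of per-step rewards is kept per environment and the raw rewards are divided by its standard deviation (no mean subtraction). The RewardForwardFilter below is a simplified stand-in for the imported one, and the batch variance stands in for the RunningMeanStd running estimate.

import numpy as np

class RewardForwardFilter:
    # discounted running sum of rewards, one accumulator per environment
    def __init__(self, gamma):
        self.gamma = gamma
        self.rewems = None

    def update(self, rews):
        if self.rewems is None:
            self.rewems = rews.copy()
        else:
            self.rewems = self.rewems * self.gamma + rews
        return self.rewems


rff = RewardForwardFilter(gamma=0.99)
buf_rews = np.random.rand(4, 128).astype(np.float32)       # (nenvs, nsteps)
rffs = np.array([rff.update(rew) for rew in buf_rews.T])    # walk over time steps
rews = buf_rews / (np.sqrt(rffs.var()) + 1e-8)              # scale only, mean untouched
print(rews.std())
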
class PpoOptimizer(object):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, stochpol, ent_coef, gamma,
                 lam, nepochs, lr, cliprange, nminibatches, normrew, normadv,
                 use_news, ext_coeff, int_coeff, nsteps_per_seg, nsegs_per_env,
                 dynamics, exp_name, env_name, to_eval):
        self.dynamics = dynamics
        self.exp_name = exp_name
        self.env_name = env_name
        self.to_eval = to_eval
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.stochpol = stochpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.ph_adv = tf.placeholder(tf.float32, [None, None])
            self.ph_ret = tf.placeholder(tf.float32, [None, None])
            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            vpred = self.stochpol.vpred

            self.pd_logstd_min = tf.math.reduce_min(self.stochpol.pd.logstd)
            self.pd_logstd_max = tf.math.reduce_max(self.stochpol.pd.logstd)
            self.pd_std_min = tf.math.reduce_min(self.stochpol.pd.std)
            self.pd_std_max = tf.math.reduce_max(self.stochpol.pd.std)
            self.pd_mean_min = tf.math.reduce_min(self.stochpol.pd.mean)
            self.pd_mean_max = tf.math.reduce_max(self.stochpol.pd.mean)
            self.stat_report = {
                'pd_logstd_max': self.pd_logstd_max,
                'pd_logstd_min': self.pd_logstd_min,
                'pd_std_max': self.pd_std_max,
                'pd_std_min': self.pd_std_min,
                'pd_mean_max': self.pd_mean_max,
                'pd_mean_min': self.pd_mean_min
            }

            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret)**2)
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = -self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(
                ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (-ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(
                tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(
                tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {
                'tot': self.total_loss,
                'pg': pg_loss,
                'vf': vf_loss,
                'ent': entropy,
                'approxkl': approxkl,
                'clipfrac': clipfrac
            }  #, 'pd_logstd':pd_logstd, 'pd_std':pd_std, 'pd_mean':pd_mean}

    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                       comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)

        #gvs = trainer.compute_gradients(self.total_loss, params)
        #self.gshape = gs
        #gs = [g for (g,v) in gvs]
        #self.normg = tf.linalg.global_norm(gs)
        #new_g = [tf.clip_by_norm(g,10.0) for g in gs i]
        #self.nnormg = tf.linalg.global_norm(new_g)
        def ClipIfNotNone(grad):
            return tf.clip_by_value(grad, -25.0,
                                    25.0) if grad is not None else grad

        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        #gs = [g for (g,v) in gradsandvars]
        #new_g = [tf.clip_by_norm(g,10.0) for g in gs if g is not None]
        gradsandvars = [(ClipIfNotNone(g), v) for g, v in gradsandvars]

        #new_g = [g for (g,v) in gradsandvars]
        #self.nnormg = tf.linalg.global_norm(new_g)
        #gradsandvars = [(ClipIfNotNone(grad), var) for grad, var in gradsandvars]
        self._train = trainer.apply_gradients(gradsandvars)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        print('-------NENVS-------', self.nenvs)
        self.nlump = nlump
        print('----------NLUMPS-------', self.nlump)
        self.lump_stride = nenvs // self.nlump
        print('-------LSTRIDE----', self.lump_stride)
        print('--------OBS SPACE ---------', self.ob_space)
        print('-------------AC SPACE-----', self.ac_space)
        #assert 1==2
        print('-----BEFORE VEC ENV------')
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]
        print('-----AFTER VEC ENV------')
        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics,
                               exp_name=self.exp_name,
                               env_name=self.env_name,
                               to_eval=self.to_eval)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()
        self.saver = tf.train.Saver(max_to_keep=5)

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.rollout.buf_vpreds.mean(),
            vpredstd=self.rollout.buf_vpreds.std(),
            ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                  self.buf_rets.ravel()),
            rew_mean=np.mean(self.rollout.buf_rews),
            recent_best_ext_ret=self.rollout.current_max
            if self.rollout.current_max is not None else 0,
        )
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([(self.dynamics.last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        mblossvals = []
        #gradvals = []
        statvals = []
        self.stat_names, self.stats = zip(*list(self.stat_report.items()))
        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.ph_lr: self.lr,
                    self.ph_cliprange: self.cliprange
                })
                mblossvals.append(getsess().run(self._losses + (self._train, ),
                                                fd)[:-1])
                #print(fd)
                statvals.append(tf.get_default_session().run(self.stats, fd))
                #_ = getsess().run(self.gradsandvars, fd)[:-1]
                #assert 1==2
                #gradvals.append(getsess().run(self.grads, fd))

        mblossvals = [mblossvals[0]]
        statvals = [statvals[0]]
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info.update(
            zip(['opt_' + ln for ln in self.stat_names],
                np.mean([statvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        info['video_log'] = self.rollout.buf_obs if self.n_updates % 50 == 0 else None
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info

    def step(self):
        print('---------INSIDE STEP--------------')
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)

    def save_model(self, logdir, exp_name, global_step=0):
        path = osp.join(logdir, exp_name + ".ckpt")
        self.saver.save(getsess(), path, global_step=global_step)
        print("Model saved to path", path)

    def restore_model(self, logdir, exp_name):
        path = logdir
        self.saver.restore(getsess(), path)
        print("Model Restored from path", path)
Example #12
class PPO(object):
    def __init__(self,
                 scope,
                 env,
                 test_env,
                 nenvs,
                 save_dir,
                 log_dir,
                 policy,
                 use_news,
                 gamma,
                 lam,
                 nepochs,
                 nminibatches,
                 nsteps,
                 vf_coef,
                 ent_coef,
                 max_grad_norm,
                 normrew,
                 cliprew,
                 normadv,
                 for_visuals,
                 transfer_load=False,
                 load_path=None,
                 freeze_weights=False):

        self.save_dir = save_dir
        self.log_dir = log_dir

        self.transfer_load = transfer_load
        self.freeze_weights = freeze_weights

        # save the random_idx of the random connections for the future
        # this is only to check that the same connections are established
        if policy.random_idx is not None:
            random_idx = np.asarray(policy.random_idx_dict['train_random_idx'])
            npz_path = os.path.join(self.save_dir, 'train_random_idx.npz')
            np.savez_compressed(npz_path, random_idx=random_idx)

        with tf.variable_scope(scope):
            # ob_space, ac_space is from policy
            self.ob_space = policy.ob_space
            self.ac_space = policy.ac_space
            self.env = env
            self.test_env = test_env
            self.nenvs = nenvs

            self.policy = policy
            self.for_visuals = for_visuals

            # use_news
            self.use_news = use_news
            self.normrew = normrew
            self.cliprew = cliprew
            self.normadv = normadv

            # gamma and lambda
            self.gamma = gamma
            self.lam = lam
            self.max_grad_norm = max_grad_norm

            # update epochs and minibatches
            self.nepochs = nepochs
            self.nminibatches = nminibatches
            # nsteps = number of timesteps per rollout per environment
            self.nsteps = nsteps

            # placeholders
            self.ph_adv = tf.placeholder(tf.float32, [None])
            # ret = advs + vpreds, R = ph_ret
            self.ph_ret = tf.placeholder(tf.float32, [None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None])
            self.ph_oldvpred = tf.placeholder(tf.float32, [None])

            self.ph_lr = tf.placeholder(tf.float32, [])

            self.ph_cliprange = tf.placeholder(tf.float32, [])

            neglogpac = self.policy.pd.neglogp(self.policy.ph_ac)

            ## add to summary
            entropy = tf.reduce_mean(self.policy.pd.entropy())

            # clipped vpred, same as coinrun
            vpred = self.policy.vpred
            vpredclipped = self.ph_oldvpred + tf.clip_by_value(
                self.policy.vpred - self.ph_oldvpred, -self.ph_cliprange,
                self.ph_cliprange)
            vf_losses1 = tf.square(vpred - self.ph_ret)
            vf_losses2 = tf.square(vpredclipped - self.ph_ret)

            ## add to summary
            vf_loss = vf_coef * (
                0.5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2)))

            ratio = tf.exp(self.ph_oldnlp - neglogpac)
            negadv = -self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(
                ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)

            ## add to summary
            pg_loss = tf.reduce_mean(pg_loss_surr)

            ent_loss = (-ent_coef) * entropy

            ## add to summary
            approxkl = 0.5 * tf.reduce_mean(
                tf.square(neglogpac - self.ph_oldnlp))

            ## add to summary
            clipfrac = tf.reduce_mean(
                tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            ## add to summary
            self.policy_loss = pg_loss + ent_loss + vf_loss

            # set summaries
            self.to_report = {
                'policy_loss': self.policy_loss,
                'pg_loss': pg_loss,
                'vf_loss': vf_loss,
                'ent': entropy,
                'approxkl': approxkl,
                'clipfrac': clipfrac
            }

            if self.transfer_load:
                self._pre_load(load_path)

            # initialize various parameters
            self._init()

    def _gradient_summaries(self, gradsandvar):
        for gradient, variable in gradsandvar:
            if isinstance(gradient, ops.IndexedSlices):
                grad_values = gradient.values
            else:
                grad_values = gradient
            tf.summary.histogram(variable.name, variable)
            tf.summary.histogram(variable.name + '/gradients', grad_values)
            tf.summary.histogram(variable.name + '/gradient_norms',
                                 clip_ops.global_norm([grad_values]))

    def _init(self):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        print('__init trainable variables in the collection...')
        for p in params:
            print(p)

        # changed epsilon value
        trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr,
                                         epsilon=1e-5)

        gradsandvar = trainer.compute_gradients(self.policy_loss, params)

        grads, var = zip(*gradsandvar)
        grads = list(grads)  # some entries are overwritten below, so make it mutable

        # we only do this operation if our features network is not feat_v0
        if self.policy.feat_spec == 'feat_rss_v0':
            if self.policy.policy_spec == 'ls_c_v0':
                # this is a gradient hack to make rss work
                end_idx = -8
            elif self.policy.policy_spec == 'ls_c_hh':
                end_idx = -6
            elif self.policy.policy_spec == 'cr_fc_v0':
                end_idx = -4

            # 'full_sparsity' is not handled carefully here yet; full sparsity will be
            # required if this is ever used for gradient predictions
            elif self.policy.policy_spec == 'full_sparse':
                end_idx = 0
            else:
                raise NotImplementedError()

            # changed from "+ 0" to "+ 2" to support completely sparse training
            sum_get = [(i + 1) for i in range(self.policy.num_layers)]
            mult = np.sum(sum_get) + 2 * (sum_get[-1] + 1)
            start_idx = end_idx - (mult * 2)

            print('start_idx: {} and end_idx: {}'.format(start_idx, end_idx))
            for g in grads:
                print(g)

            print('''


                ''')

            for i, g in enumerate(grads[start_idx:]):
                print('g: {}'.format(g))
                sparse_idx = self.policy.random_idx[i]
                full_dim = self.policy.full_dim[i]
                mult_conts = np.zeros(full_dim, dtype=np.float32)

                # this is the case for weights
                if isinstance(sparse_idx, list):
                    # we must separate (row, col) coords
                    sparse_idx = np.asarray(sparse_idx)
                    row_idx = sparse_idx[:, 0]
                    col_idx = sparse_idx[:, 1]
                    mult_conts[row_idx, col_idx] = 1.0

                elif isinstance(sparse_idx, int):
                    mult_conts[:] = 1.0

                else:
                    raise TypeError('sparse_idx has an unexpected type')

                # write the masked gradient back into the list; assigning to the
                # loop variable alone would leave the original gradient unchanged
                grads[start_idx + i] = tf.multiply(g, tf.convert_to_tensor(mult_conts))

        if self.max_grad_norm is not None:
            grads, _grad_norm = tf.clip_by_global_norm(grads,
                                                       self.max_grad_norm)

        gradsandvar = list(zip(grads, var))

        # add gradient summaries
        # self._gradient_summaries(gradsandvar)

        self._train = trainer.apply_gradients(gradsandvar)

        ## initialize variables
        sess().run(
            tf.variables_initializer(
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))

        ## runner
        self.runner = Runner(env=self.env,
                             test_env=self.test_env,
                             nenvs=self.nenvs,
                             policy=self.policy,
                             nsteps=self.nsteps,
                             cliprew=self.cliprew)

        self.buf_advs = np.zeros((self.nenvs, self.nsteps), np.float32)
        self.buf_rets = np.zeros((self.nenvs, self.nsteps), np.float32)

        # set saver
        self.saver = tf.train.Saver(max_to_keep=None)

        if self.transfer_load:
            self.saver = tf.train.Saver(var_list=self.vars_dict,
                                        max_to_keep=None)
        # self.summary_op = tf.summary.merge_all()
        # self.summary_writer = tf.summary.FileWriter(self.log_dir, sess().graph)

        # reward normalization
        if self.normrew:
            self.rff = utils.RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = self.runner.buf_news[:, t + 1] if t + 1 < nsteps else self.runner.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.runner.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.runner.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.runner.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.runner.buf_vpreds

    def update(self, lr, cliprange):
        # fill rollout buffers
        self.runner.rollout()

        ## TODO: normalized rewards
        # coinrun does NOT normalize its rewards
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.runner.buf_rews.T])
            # print('optimizers.py, class PPO, def update, rffs.shape: {}'.format(rffs.shape))
            rffs_mean, rffs_std, rffs_count = utils.get_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.runner.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.runner.buf_rews)

        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        # this is a little bit different than the original coinrun implementation
        # they only normalize advantages using batch mean and std instead of
        # entire data & we add 1e-7 instead of 1e-8
        if self.normadv:
            mean, std = np.mean(self.buf_advs), np.std(self.buf_advs)
            self.buf_advs = (self.buf_advs - mean) / (std + 1e-7)

        ## this only works for non-recurrent version
        nbatch = self.nenvs * self.nsteps
        nbatch_train = nbatch // self.nminibatches

        # BUG FIXED: np.arange(nbatch_train) to np.arange(nbatch)
        # might be the cause of unstable training performance
        train_idx = np.arange(nbatch)

        # another thing is that they completely shuffle the experiences
        # flatten axes 0 and 1 (we do not swap)
        def f01(x):
            sh = x.shape
            return x.reshape(sh[0] * sh[1], *sh[2:])

        flattened_obs = f01(self.runner.buf_obs)

        ph_buf = [(self.policy.ph_ob, flattened_obs),
                  (self.policy.ph_ac, f01(self.runner.buf_acs)),
                  (self.ph_oldvpred, f01(self.runner.buf_vpreds)),
                  (self.ph_oldnlp, f01(self.runner.buf_nlps)),
                  (self.ph_ret, f01(self.buf_rets)),
                  (self.ph_adv, f01(self.buf_advs))]

        # when we begin to work with curiosity, we might need make a couple of
        # changes to this training strategy

        mblossvals = []

        for e in range(self.nepochs):
            np.random.shuffle(train_idx)

            for start in range(0, nbatch, nbatch_train):
                end = start + nbatch_train
                mbidx = train_idx[start:end]
                fd = {ph: buf[mbidx] for (ph, buf) in ph_buf}
                fd.update({self.ph_lr: lr, self.ph_cliprange: cliprange})

                mblossvals.append(sess().run(self._losses + (self._train, ),
                                             feed_dict=fd)[:-1])

        mblossvals = [mblossvals[0]]

        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.runner.buf_vpreds.mean(),
            vpredstd=self.runner.buf_vpreds.std(),
            ev=explained_variance(self.runner.buf_vpreds.ravel(),
                                  self.buf_rets.ravel()),
            rew_mean=np.mean(self.runner.buf_rews),
        )

        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))

        return info

    def evaluate(self, nlevels, save_video):
        return self.runner.evaluate(nlevels, save_video)

    def save(self, curr_iter, cliprange):
        save_path = os.path.join(self.save_dir, 'model')
        self.saver.save(sess(), save_path, global_step=curr_iter)

        def f01(x):
            sh = x.shape
            return x.reshape(sh[0] * sh[1], *sh[2:])

        if self.for_visuals:
            obs = f01(self.runner.buf_obs)
            acs = f01(self.runner.buf_acs)
            nlps = f01(self.runner.buf_nlps)
            advs = f01(self.buf_advs)
            oldvpreds = f01(self.runner.buf_vpreds)
            rets = f01(self.buf_rets)

            npz_path = os.path.join(self.save_dir,
                                    'extra-{}.npz'.format(curr_iter))
            np.savez_compressed(npz_path,
                                obs=obs,
                                acs=acs,
                                nlps=nlps,
                                advs=advs,
                                oldvpreds=oldvpreds,
                                rets=rets,
                                cliprange=cliprange)

    def _pre_load(self, load_path):
        print('''

            PRE LOADING...

            ''')

        trainable_variables = tf.get_collection_ref(
            tf.GraphKeys.TRAINABLE_VARIABLES)

        self.vars_dict = {}
        for var_ckpt in tf.train.list_variables(load_path):
            # remove the variables in the ckpt from trainable variables
            for t in trainable_variables:
                if var_ckpt[0] == t.op.name:
                    self.vars_dict[var_ckpt[0]] = t
                    if self.freeze_weights:
                        trainable_variables.remove(t)

    # load buffers
    def load_ph_bufs(self, bufs):
        self.load_fd = {
            self.policy.ph_ob: bufs['obs'],
            self.policy.ph_ac: bufs['acs'],
            self.ph_oldvpred: bufs['oldvpreds'],
            self.ph_oldnlp: bufs['nlps'],
            self.ph_ret: bufs['rets'],
            self.ph_adv: bufs['advs'],
            self.ph_lr: 0.0,
            self.ph_cliprange: bufs['cliprange']
        }

    # eventually write everything here, because the ckpt contains Adam variables too;
    # for now this is a stub that only looks up the trainable variables
    def variable_assignment(self, load_ckpt):
        trainable_variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES)

    # instead of restore, we use this to quickly update our variables
    def _assign_op(self, v_dict, dir_dict, alpha, beta):
        trainable_variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES)
        assign_op = []
        for k in v_dict.keys():
            for t in trainable_variables:
                if k == t.op.name:
                    assign_op.append(
                        tf.assign(
                            t, v_dict[t.op.name] +
                            alpha * dir_dict[0][t.op.name] +
                            beta * dir_dict[1][t.op.name]))
        return sess().run(assign_op)

    def get_loss(self, v_dict, dir_dict, alpha, beta):
        self._assign_op(v_dict, dir_dict, alpha, beta)
        return sess().run(self.policy_loss, feed_dict=self.load_fd)

    def re_run_loss(self, v_dict, dir_dict, alpha, beta, bufs):
        self._assign_op(v_dict, dir_dict, alpha, beta)
        self.runner.rollout()

        ## TODO: normalized rewards
        # coinrun does NOT normalize its rewards
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.runner.buf_rews.T])
            # print('optimizers.py, class PPO, def update, rffs.shape: {}'.format(rffs.shape))
            rffs_mean, rffs_std, rffs_count = utils.get_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.runner.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.runner.buf_rews)

        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        # this is a little bit different than the original coinrun implementation
        # they only normalize advantages using batch mean and std instead of
        # entire data & we add 1e-7 instead of 1e-8
        if self.normadv:
            mean, std = np.mean(self.buf_advs), np.std(self.buf_advs)
            self.buf_advs = (self.buf_advs - mean) / (std + 1e-7)

        def f01(x):
            sh = x.shape
            return x.reshape(sh[0] * sh[1], *sh[2:])

        flattened_obs = f01(self.runner.buf_obs)

        ph_buf = {
            self.policy.ph_ob: flattened_obs,
            self.policy.ph_ac: bufs['acs'],
            self.ph_oldvpred: f01(self.runner.buf_vpreds),
            self.ph_oldnlp: f01(self.runner.buf_nlps),
            self.ph_ret: f01(self.buf_rets),
            self.ph_adv: f01(self.buf_advs),
            self.ph_lr: 0.0,
            self.ph_cliprange: bufs['cliprange']
        }

        return sess().run(self.policy_loss, feed_dict=ph_buf)

    def load(self, load_path):

        print('''

            load variables

            ''')
        for variable in tf.train.list_variables(load_path):
            print(variable[0])

        print('''

            global variables

            ''')
        for variable in tf.global_variables():
            print(variable.name)

        self.saver.restore(sess(), load_path)
        print('loaded already trained model from {}'.format(load_path))
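
A small NumPy sketch of the gradient-masking idea in _init above for 'feat_rss_v0' policies: a 0/1 mask is built from a list of surviving (row, col) connections and multiplied elementwise into the dense gradient, so only those entries keep gradient. The shape and index list below are made up for illustration.

import numpy as np

full_dim = (5, 4)                              # shape of one weight matrix
sparse_idx = [[0, 1], [2, 3], [4, 0]]          # surviving (row, col) connections
grad = np.random.randn(*full_dim).astype(np.float32)

mult_conts = np.zeros(full_dim, dtype=np.float32)
idx = np.asarray(sparse_idx)
mult_conts[idx[:, 0], idx[:, 1]] = 1.0         # same construction as in _init

masked_grad = grad * mult_conts                # gradient survives only at the chosen entries
print(int(np.count_nonzero(masked_grad)), 'of', masked_grad.size, 'entries kept')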