Example #1
    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                       comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(hps=self.hps,
                               ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_advs_int = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_advs_ext = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets_int = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets_ext = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()
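
Example #1 only constructs `RewardForwardFilter` and `RunningMeanStd`; the `update()` methods shown later (e.g. Example #6) push the discounted reward stream through the filter and divide the raw rewards by the resulting running standard deviation. A minimal stand-alone sketch of that pattern, with an assumed `RewardForwardFilter` implementation and a plain numpy std standing in for the MPI-aware `RunningMeanStd`:

import numpy as np

class RewardForwardFilter:
    """Assumed implementation: per-env running discounted sum of rewards."""
    def __init__(self, gamma):
        self.gamma = gamma
        self.rewems = None  # shape (nenvs,) once initialized

    def update(self, rews):
        if self.rewems is None:
            self.rewems = rews.copy()
        else:
            self.rewems = self.rewems * self.gamma + rews
        return self.rewems

# Mirrors the normalization in update(): scale by the std of the discounted
# reward stream, not of the raw rewards (dummy data, nenvs=4, nsteps=128).
rff = RewardForwardFilter(gamma=0.99)
buf_rews = np.random.rand(4, 128).astype(np.float32)
rffs = np.array([rff.update(rew) for rew in buf_rews.T])  # iterate over time steps
rews = buf_rews / (rffs.std() + 1e-8)                     # stand-in for np.sqrt(rff_rms.var)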
Example #2
    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))
        
        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        self.caculate_number_parameters(params)

        flow_params = [v for v in params if 'flow' in v.name]
        other_params = [v for v in params if 'flow' not in v.name]

        print('length of flow params: ', len(flow_params))
        print('length of agent params: ', len(other_params))
        
        trainer_flow = tf.train.AdamOptimizer(learning_rate=self.flow_lr)
        trainer_agent = tf.train.AdamOptimizer(learning_rate=self.ph_lr)

        grads = tf.gradients(self.total_loss, flow_params + other_params)
        grads_flow = grads[:len(flow_params)]
        grads_agent = grads[len(flow_params):]

        train_flow = trainer_flow.apply_gradients(zip(grads_flow, flow_params))
        train_agent = trainer_agent.apply_gradients(zip(grads_agent, other_params))

        self._train = tf.group(train_flow, train_agent)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(tf.variables_initializer(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride: (l + 1) * self.lump_stride], spaces=[self.ob_space, self.ac_space]) for
            l in range(self.nlump)]

        self.rollout = Rollout(ob_space=self.ob_space, ac_space=self.ac_space, nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env, nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()
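
Example #2 calls a `caculate_number_parameters` helper (spelling as in the source) that is not shown. Judging from the inline parameter-count prints in Example #3 below, it most likely just reports how many trainable parameters the graph contains; a minimal sketch under that assumption:

import numpy as np

def caculate_number_parameters(params):
    # params: the list returned by tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    total = int(np.sum([np.prod(v.get_shape().as_list()) for v in params]))
    print('total number of trainable parameters:', total)
    return total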
Example #3
    def start_interaction(self, env_fns, dynamics, nlump=2):
        # Define the variables and the computation graph when interaction with the environment starts; initialize the Rollout class
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        # Define the loss, gradients, and backprop op. During training, call sess.run(self._train) to run an update step
        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        params_dvae = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="dvae_reward")
        print("total params:", np.sum([np.prod(v.get_shape().as_list()) for v in params]))      # 6629459
        print("dvae params:", np.sum([np.prod(v.get_shape().as_list()) for v in params_dvae]))  # 2726144
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr, comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        # add bai.  Compute the DVAE gradients separately
        gradsandvars_dvae = trainer.compute_gradients(self.dynamics_loss, params_dvae)
        self._train_dvae = trainer.apply_gradients(gradsandvars_dvae)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(tf.variables_initializer(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)        # default: 128
        self.nlump = nlump                       # default: 1
        self.lump_stride = nenvs // self.nlump   # 128/1 = 128
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride: (l + 1) * self.lump_stride], spaces=[self.ob_space, self.ac_space]) for
            l in range(self.nlump)]

        # This class is defined in rollouts.py
        self.rollout = Rollout(ob_space=self.ob_space, ac_space=self.ac_space, nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env, nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        # shape: (number of envs (i.e. parallel workers), rollout length T)
        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()
Example #4
    def no_mpi_start_interaction(self, envs, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)
        last_step = 0
        if self.load:
            last_step = self._load_graph()
        else:
            self._initialize_graph()
            # bcast_tf_vars_from_root(getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(envs)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = envs

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)
        self.rollout.stats['tcount'] += last_step

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()
Example #5
    def start_interaction(self, env_fns, dynamics, nlump=2):
        param_list = self.stochpol.param_list + self.dynamics.param_list + self.dynamics.auxiliary_task.param_list  # copy a link, not deepcopy.
        self.optimizer = torch.optim.Adam(param_list, lr=self.lr)
        self.optimizer.zero_grad()

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()
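
Example #5 is the PyTorch variant and only shows optimizer setup. The matching update step is not included in the snippet; under the standard PyTorch pattern, and assuming a scalar `total_loss` computed elsewhere from the PPO and dynamics objectives, it would look roughly like this (the gradient clipping is an optional assumption, not something the source shows):

import torch

def train_step(optimizer, total_loss, param_list, max_grad_norm=None):
    # Hypothetical single optimization step for the setup in Example #5.
    optimizer.zero_grad()                 # clear gradients accumulated so far
    total_loss.backward()                 # backprop through policy + dynamics losses
    if max_grad_norm is not None:         # optional gradient clipping (assumption)
        torch.nn.utils.clip_grad_norm_(param_list, max_grad_norm)
    optimizer.step()                      # apply the Adam update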
Example #6
class PpoOptimizer(object):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, stochpol, ent_coef, gamma, lam,
                 nepochs, lr, cliprange, nminibatches, normrew, normadv,
                 use_news, ext_coeff, int_coeff, nsteps_per_seg, nsegs_per_env, dynamics, nepochs_dvae):
        self.dynamics = dynamics
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space                  # Box(84,84,4)
            self.ac_space = ac_space                  # Discrete(4)
            self.stochpol = stochpol                  # CNN policy object
            self.nepochs = nepochs                    # 3
            self.lr = lr                              # 1e-4
            self.cliprange = cliprange                # 0.1
            self.nsteps_per_seg = nsteps_per_seg      # 128
            self.nsegs_per_env = nsegs_per_env        # 1
            self.nminibatches = nminibatches          # 8
            self.gamma = gamma                        # 0.99  PPO discount factor
            self.lam = lam                            # 0.95  GAE lambda
            self.normrew = normrew                    # 1
            self.normadv = normadv                    # 1
            self.use_news = use_news                  # False
            self.ext_coeff = ext_coeff                # 0.0     explore purely with intrinsic reward
            self.int_coeff = int_coeff                # 1.0
            self.ph_adv = tf.placeholder(tf.float32, [None, None])
            self.ph_ret = tf.placeholder(tf.float32, [None, None])
            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])    # stores -log pi(a|s)
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)    # -log prob of the stored actions under the current policy
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            vpred = self.stochpol.vpred

            # Define the PPO losses: critic (value) loss, actor loss, entropy loss, plus approximate KL and the clip fraction
            # value-function (critic) loss
            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret) ** 2)
            # policy (actor) surrogate loss
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = - self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (- ent_coef) * entropy            # entropy bonus, ent_coef = 0.001
            approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - self.ph_oldnlp))   # approximate KL divergence
            clipfrac = tf.reduce_mean(tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {'tot': self.total_loss, 'pg': pg_loss, 'vf': vf_loss, 'ent': entropy,
                              'approxkl': approxkl, 'clipfrac': clipfrac}

            # add bai.
            self.dynamics_loss = None
            self.nepochs_dvae = nepochs_dvae

    def start_interaction(self, env_fns, dynamics, nlump=2):
        # Define the variables and the computation graph when interaction with the environment starts; initialize the Rollout class
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        # Define the loss, gradients, and backprop op. During training, call sess.run(self._train) to run an update step
        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        params_dvae = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="dvae_reward")
        print("total params:", np.sum([np.prod(v.get_shape().as_list()) for v in params]))      # 6629459
        print("dvae params:", np.sum([np.prod(v.get_shape().as_list()) for v in params_dvae]))  # 2726144
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr, comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        # add bai.  Compute the DVAE gradients separately
        gradsandvars_dvae = trainer.compute_gradients(self.dynamics_loss, params_dvae)
        self._train_dvae = trainer.apply_gradients(gradsandvars_dvae)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(tf.variables_initializer(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)        # default: 128
        self.nlump = nlump                       # default: 1
        self.lump_stride = nenvs // self.nlump   # 128/1 = 128
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride: (l + 1) * self.lump_stride], spaces=[self.ob_space, self.ac_space]) for
            l in range(self.nlump)]

        # This class is defined in rollouts.py
        self.rollout = Rollout(ob_space=self.ob_space, ac_space=self.ac_space, nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env, nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        # shape: (number of envs (i.e. parallel workers), rollout length T)
        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        # Compute returns and advantages (GAE) from the stored rewards; the indexing is somewhat convoluted.
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # t = nsteps-1 ... 0, iterating backwards
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:         # normalize rewards, using moments gathered from the other workers via MPI
            rffs = np.array([self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std ** 2, rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)

        # Compute the advantage function from the reward sequence rews
        self.calculate_advantages(rews=rews, use_news=self.use_news, gamma=self.gamma, lam=self.lam)

        # Collect some statistics for logging
        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.rollout.buf_vpreds.mean(),
            vpredstd=self.rollout.buf_vpreds.std(),
            ev=explained_variance(self.rollout.buf_vpreds.ravel(), self.buf_rets.ravel()),
            rew_mean=np.mean(self.rollout.buf_rews),
            rew_mean_norm=np.mean(rews),
            recent_best_ext_ret=self.rollout.current_max
        )
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages by their mean and std
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env, self.nsteps_per_seg) + sh[2:])

        # Pair the placeholders defined in this class with the numpy buffers collected by the Rollout, to build the feed_dict
        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),   # the buffers above were recorded by the rollout while interacting with the envs
            (self.ph_ret, resh(self.buf_rets)),                  # returns computed from the rollout buffers
            (self.ph_adv, resh(self.buf_advs)),                  # advantages computed from the rollout buffers
        ]
        ph_buf.extend([
            (self.dynamics.last_ob,
             self.rollout.buf_obs_last.reshape([self.nenvs * self.nsegs_per_env, 1, *self.ob_space.shape]))
        ])
        mblossvals = []          # training losses, for logging

        # Train on the agent (PPO) loss
        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}     # build the feed_dict
                fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})
                mblossvals.append(getsess().run(self._losses + (self._train,), fd)[:-1])    # compute the losses and apply the update in one run

        # add bai.  Additionally train the DVAE on its own
        for tmp in range(self.nepochs_dvae):
            print("extra DVAE training epoch", tmp)
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):     # 8 minibatches per epoch (nminibatches)
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}                       # build the feed_dict
                fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})
                d_loss, _ = getsess().run([self.dynamics_loss, self._train_dvae], fd)   # compute the DVAE loss and apply the update
                print(d_loss, end=", ")
            print("\n")

        mblossvals = [mblossvals[0]]
        info.update(zip(['opt_' + ln for ln in self.loss_names], np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({dn: (np.mean(dvs) if len(dvs) > 0 else 0) for (dn, dvs) in self.rollout.statlists.items()})
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info

    def step(self):
        self.rollout.collect_rollout()    # collect samples and compute intrinsic rewards
        update_info = self.update()       # update the weights
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)
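
The GAE recursion in `calculate_advantages` above sweeps backwards over the whole (nenvs, nsteps) buffer, which the comment admits is somewhat convoluted. The same recursion on a single environment with three made-up steps, as a self-contained numpy sketch (values are illustrative only):

import numpy as np

gamma, lam = 0.99, 0.95
rews   = np.array([0.0, 0.0, 1.0], dtype=np.float32)   # made-up rewards, T = 3
vpreds = np.array([0.5, 0.6, 0.7], dtype=np.float32)   # V(s_0..s_2)
vpred_last = 0.4                                        # bootstrap value V(s_3)
news   = np.array([0, 0, 0, 0])                         # no episode boundaries

advs = np.zeros(3, dtype=np.float32)
lastgaelam = 0.0
for t in reversed(range(3)):
    nextnotnew = 1 - news[t + 1]
    nextval = vpreds[t + 1] if t + 1 < 3 else vpred_last
    delta = rews[t] + gamma * nextval * nextnotnew - vpreds[t]   # TD residual
    advs[t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
rets = advs + vpreds   # same as self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds
print(advs, rets)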
Example #7
class PpoOptimizer(object):
    def __init__(self, scope, ob_space, ac_space, policy, use_news, gamma, lam,
                 nepochs, nminibatches, lr, cliprange, nsteps_per_seg,
                 nsegs_per_env, ent_coeff, normrew, normadv, ext_coeff,
                 int_coeff, dynamics):
        self.dynamics = dynamics
        with tf.variable_scope(scope):

            self.bootstrapped = self.dynamics.bootstrapped
            self.flipout = self.dynamics.flipout

            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.policy = policy
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.ent_coeff = ent_coeff

            self.placeholder_advantage = tf.placeholder(
                tf.float32, [None, None])
            self.placeholder_ret = tf.placeholder(tf.float32, [None, None])
            self.placeholder_rews = tf.placeholder(tf.float32, [None, None])
            self.placeholder_oldnlp = tf.placeholder(tf.float32, [None, None])
            self.placeholder_oldvpred = tf.placeholder(tf.float32,
                                                       [None, None])
            self.placeholder_lr = tf.placeholder(tf.float32, [])
            self.placeholder_cliprange = tf.placeholder(tf.float32, [])

            # if self.flipout:
            #     self.placeholder_dyn_mean = tf.placeholder(tf.float32, [None,None])
            #     self.dyn_mean = tf.reduce_max(self.placeholder_dyn_mean)

            neglogpa = self.policy.pd.neglogp(self.policy.placeholder_action)
            entropy = tf.reduce_mean(self.policy.pd.entropy())
            vpred = self.policy.vpred

            c1 = .5

            vf_loss = c1 * tf.reduce_mean(
                tf.square(vpred - self.placeholder_ret))
            ratio = tf.exp(self.placeholder_oldnlp - neglogpa)
            negadv = -self.placeholder_advantage

            polgrad_losses1 = negadv * ratio
            polgrad_losses2 = negadv * tf.clip_by_value(
                ratio, 1.0 - self.placeholder_cliprange,
                1.0 + self.placeholder_cliprange)
            polgrad_loss_surr = tf.maximum(polgrad_losses1, polgrad_losses2)
            polgrad_loss = tf.reduce_mean(polgrad_loss_surr)
            entropy_loss = (-self.ent_coeff) * entropy

            approxkl = .5 * tf.reduce_mean(
                tf.square(neglogpa - self.placeholder_oldnlp))
            clipfrac = tf.reduce_mean(
                tf.to_float(
                    tf.abs(polgrad_losses2 - polgrad_loss_surr) > 1e-6))
            if self.dynamics.dropout:
                regloss = tf.reduce_sum(
                    tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
                self.total_loss = polgrad_loss + entropy_loss + vf_loss + regloss  #TODO i tried with negative for fun
                self.to_report = {
                    'tot': self.total_loss,
                    'pg': polgrad_loss,
                    'vf': vf_loss,
                    'ent': entropy,
                    'approxkl': approxkl,
                    'clipfrac': clipfrac,
                    'regloss': regloss
                }
                self.dropout_rates = tf.get_collection('DROPOUT_RATES')
            else:
                self.total_loss = polgrad_loss + entropy_loss + vf_loss
                self.to_report = {
                    'tot': self.total_loss,
                    'pg': polgrad_loss,
                    'vf': vf_loss,
                    'ent': entropy,
                    'approxkl': approxkl,
                    'clipfrac': clipfrac
                }

    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.placeholder_lr,
                                       comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.placeholder_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        if MPI.COMM_WORLD.Get_rank() == 0:
            tf.get_default_session().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            tf.get_default_session(),
            tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.policy,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()
            if self.dynamics.dropout:
                self.rff2 = RewardForwardFilter(self.gamma)
                self.rff_rms2 = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # t = nsteps-1 ... 0, iterating backwards
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
            if self.dynamics.dropout:
                rffs2 = np.array([
                    self.rff2.update(rew)
                    for rew in self.rollout.buf_rews_mean.T
                ])
                rffs2_mean, rffs2_std, rffs2_count = mpi_moments(rffs2.ravel())
                self.rff_rms2.update_from_moments(rffs2_mean, rffs2_std**2,
                                                  rffs2_count)
                rews_m = self.rollout.buf_rews_mean / np.sqrt(
                    self.rff_rms2.var)
                rews = rews_m + rews

        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(advmean=self.buf_advs.mean(),
                    advstd=self.buf_advs.std(),
                    retmean=self.buf_rets.mean(),
                    retstd=self.buf_rets.std(),
                    vpredmean=self.rollout.buf_vpreds.mean(),
                    vpredstd=self.rollout.buf_vpreds.std(),
                    ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                          self.buf_rets.ravel()),
                    rew_mean=np.mean(self.rollout.buf_rews),
                    recent_best_ext_ret=self.rollout.current_max)
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # if self.flipout:
        #     info['dyn_mean'] = np.mean(self.rollout.buf_dyn_rew)
        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.policy.placeholder_action, resh(self.rollout.buf_acs)),
            (self.placeholder_rews, resh(self.rollout.buf_rews)),
            (self.placeholder_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.placeholder_oldnlp, resh(self.rollout.buf_nlps)),
            (self.policy.placeholder_observation, resh(self.rollout.buf_obs)),
            (self.placeholder_ret, resh(self.buf_rets)),
            (self.placeholder_advantage, resh(self.buf_advs)),
        ]
        ph_buf.extend([(self.dynamics.last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        # if self.flipout:
        #     ph_buf.extend([(self.placeholder_dyn_mean, resh(self.buf_n_dyn_rew))])

        if self.bootstrapped:
            ph_buf.extend([
                (self.dynamics.mask_placeholder,
                 self.rollout.buf_mask.reshape(-1, self.dynamics.n_heads, 1))
            ])
        mblossvals = []

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.placeholder_lr: self.lr,
                    self.placeholder_cliprange: self.cliprange
                })
                if self.dynamics.dropout:
                    fd.update({self.dynamics.is_training: True})
                mblossvals.append(tf.get_default_session().run(
                    self._losses + (self._train, ), fd)[:-1])

        mblossvals = [mblossvals[0]]
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.policy.get_var_values()

    def set_var_values(self, vv):
        self.policy.set_var_values(vv)
Example #8
class RnnPpoOptimizer(SaveLoad):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, actionpol, trainpol,
                 ent_coef, gamma, lam, nepochs, lr, cliprange, nminibatches,
                 normrew, normadv, use_news, ext_coeff, int_coeff,
                 nsteps_per_seg, nsegs_per_env, action_dynamics,
                 train_dynamics, policy_mode, logdir, full_tensorboard_log,
                 tboard_period):
        self.action_dynamics = action_dynamics
        self.train_dynamics = train_dynamics
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.actionpol = actionpol
            self.trainpol = trainpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.policy_mode = policy_mode  # New
            self.full_tensorboard_log = full_tensorboard_log  # New
            self.tboard_period = tboard_period  # New
            self.ph_adv = tf.placeholder(tf.float32, [None, None],
                                         name='ph_adv')
            self.ph_ret = tf.placeholder(tf.float32, [None, None],
                                         name='ph_ret')
            self.ph_rews = tf.placeholder(tf.float32, [None, None],
                                          name='ph_rews')
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None],
                                            name='ph_oldnlp')
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None],
                                              name='ph_oldvpred')
            self.ph_lr = tf.placeholder(tf.float32, [], name='ph_lr')
            self.ph_cliprange = tf.placeholder(tf.float32, [],
                                               name='ph_cliprange')
            neglogpac = self.trainpol.pd.neglogp(self.trainpol.ph_ac)
            entropy = tf.reduce_mean(self.trainpol.pd.entropy())
            vpred = self.trainpol.vpred

            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret)**2)
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = -self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(
                ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (-ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(
                tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(
                tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {
                'tot': self.total_loss,
                'pg': pg_loss,
                'vf': vf_loss,
                'ent': entropy,
                'approxkl': approxkl,
                'clipfrac': clipfrac
            }

            self.logdir = logdir  #logger.get_dir()
            params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            if self.full_tensorboard_log:  # full Tensorboard logging
                for var in params:
                    tf.summary.histogram(var.name, var)
            if MPI.COMM_WORLD.Get_rank() == 0:
                self.summary_writer = tf.summary.FileWriter(
                    self.logdir, graph=getsess())  # New
                print("tensorboard dir : ", self.logdir)
                self.merged_summary_op = tf.summary.merge_all()  # New

    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                       comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nminibatches=self.nminibatches,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.actionpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               train_dynamics=self.train_dynamics,
                               action_dynamics=self.action_dynamics,
                               policy_mode=self.policy_mode)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # t = nsteps-1 ... 0, iterating backwards
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(advmean=self.buf_advs.mean(),
                    advstd=self.buf_advs.std(),
                    retmean=self.buf_rets.mean(),
                    retstd=self.buf_rets.std(),
                    vpredmean=self.rollout.buf_vpreds.mean(),
                    vpredstd=self.rollout.buf_vpreds.std(),
                    ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                          self.buf_rets.ravel()),
                    rew_mean=np.mean(self.rollout.buf_rews),
                    recent_best_ext_ret=self.rollout.current_max)
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.trainpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.trainpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([(self.train_dynamics.last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        ph_buf.extend([
            (self.trainpol.states_ph,
             resh(self.rollout.buf_states_first)),  # rnn inputs
            (self.trainpol.masks_ph, resh(self.rollout.buf_news))
        ])
        if 'err' in self.policy_mode:
            ph_buf.extend([(self.trainpol.pred_error,
                            resh(self.rollout.buf_errs))])  # New
        if 'ac' in self.policy_mode:
            ph_buf.extend([(self.trainpol.ph_ac, resh(self.rollout.buf_acs)),
                           (self.trainpol.ph_ac_first,
                            resh(self.rollout.buf_acs_first))])
        if 'pred' in self.policy_mode:
            ph_buf.extend([(self.trainpol.obs_pred,
                            resh(self.rollout.buf_obpreds))])

        # with open(os.getcwd() + "/record_instruction.txt", 'r') as rec_inst:
        #     rec_n = []
        #     rec_all_n = []
        #     while True:
        #         line = rec_inst.readline()
        #         if not line: break
        #         args = line.split()
        #         rec_n.append(int(args[0]))
        #         if len(args) > 1:
        #             rec_all_n.append(int(args[0]))
        #     if self.n_updates in rec_n and MPI.COMM_WORLD.Get_rank() == 0:
        #         print("Enter!")
        #         with open(self.logdir + '/full_log' + str(self.n_updates) + '.pk', 'wb') as full_log:
        #             import pickle
        #             debug_data = {"buf_obs" : self.rollout.buf_obs,
        #                           "buf_obs_last" : self.rollout.buf_obs_last,
        #                           "buf_acs" : self.rollout.buf_acs,
        #                           "buf_acs_first" : self.rollout.buf_acs_first,
        #                           "buf_news" : self.rollout.buf_news,
        #                           "buf_news_last" : self.rollout.buf_new_last,
        #                           "buf_rews" : self.rollout.buf_rews,
        #                           "buf_ext_rews" : self.rollout.buf_ext_rews}
        #             if self.n_updates in rec_all_n:
        #                 debug_data.update({"buf_err": self.rollout.buf_errs,
        #                                     "buf_err_last": self.rollout.buf_errs_last,
        #                                     "buf_obpreds": self.rollout.buf_obpreds,
        #                                     "buf_obpreds_last": self.rollout.buf_obpreds_last,
        #                                     "buf_vpreds": self.rollout.buf_vpreds,
        #                                     "buf_vpred_last": self.rollout.buf_vpred_last,
        #                                     "buf_states": self.rollout.buf_states,
        #                                     "buf_states_first": self.rollout.buf_states_first,
        #                                     "buf_nlps": self.rollout.buf_nlps,})
        #             pickle.dump(debug_data, full_log)

        mblossvals = []

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.ph_lr: self.lr,
                    self.ph_cliprange: self.cliprange
                })
                mblossvals.append(getsess().run(self._losses + (self._train, ),
                                                fd)[:-1])

        mblossvals = [mblossvals[0]]
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        # New
        if 'err' in self.policy_mode:
            info["error"] = np.sqrt(np.power(self.rollout.buf_errs, 2).mean())

        if self.n_updates % self.tboard_period == 0 and MPI.COMM_WORLD.Get_rank() == 0:
            if self.full_tensorboard_log:
                summary = getsess().run(self.merged_summary_op, fd)  # New
                self.summary_writer.add_summary(
                    summary, self.rollout.stats["tcount"])  # New
            for k, v in info.items():
                summary = tf.Summary(value=[
                    tf.Summary.Value(tag=k, simple_value=v),
                ])
                self.summary_writer.add_summary(summary,
                                                self.rollout.stats["tcount"])

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.trainpol.get_var_values()

    def set_var_values(self, vv):
        self.trainpol.set_var_values(vv)
Example #9
class PpoOptimizer(object):
    envs = None

    def __init__(self,
                 *,
                 scope,
                 ob_space,
                 ac_space,
                 stochpol,
                 ent_coef,
                 gamma,
                 lam,
                 nepochs,
                 lr,
                 cliprange,
                 nminibatches,
                 normrew,
                 normadv,
                 use_news,
                 ext_coeff,
                 int_coeff,
                 nsteps_per_seg,
                 nsegs_per_env,
                 dynamics,
                 load=False,
                 exp_name):
        self.dynamics = dynamics
        self.load = load
        self.model_path = os.path.join('models/', exp_name)
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)
        self._save_freq = 50000
        self._next_save = self._save_freq
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.stochpol = stochpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.ph_adv = tf.placeholder(tf.float32, [None, None])
            self.ph_ret = tf.placeholder(tf.float32, [None, None])
            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            vpred = self.stochpol.vpred

            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret)**2)
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = -self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(
                ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (-ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(
                tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(
                tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {
                'tot': self.total_loss,
                'pg': pg_loss,
                'vf': vf_loss,
                'ent': entropy,
                'approxkl': approxkl,
                'clipfrac': clipfrac
            }
            self.keep_checkpoints = 5

    def _initialize_graph(self):
        # with self.graph.as_default():
        #     self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
        #     init = tf.global_variables_initializer()
        #     self.sess.run(init)
        self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
        getsess().run(
            tf.variables_initializer(
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))

    def _load_graph(self):
        self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
        # logger.info('Loading Model for brain {}'.format(self.brain.brain_name))
        ckpt = tf.train.get_checkpoint_state(self.model_path)
        if ckpt is None:
            logger.info('The model {0} could not be found. Make '
                        'sure you specified the right '
                        '--run-id'.format(self.model_path))
            raise FileNotFoundError(
                'No checkpoint found in {0}'.format(self.model_path))
        self.saver.restore(getsess(), ckpt.model_checkpoint_path)
        # Extract the step count from the checkpoint filename (model-<steps>.cptk).
        last_step = int(
            os.path.basename(
                ckpt.model_checkpoint_path).split('-')[1].split('.')[0])
        return last_step

    def save_if_ready(self):
        if MPI.COMM_WORLD.Get_rank() != 0:
            return
        steps = self.rollout.stats['tcount']
        if steps >= self._next_save:
            self.save_model(steps)
            self._next_save = steps + self._save_freq

    def save_model(self, steps):
        """
        Saves the model
        :param steps: The number of steps the model was trained for
        :return:
        """
        print("------ ** saving at step:", steps)
        # with self.graph.as_default():
        last_checkpoint = self.model_path + '/model-' + str(steps) + '.cptk'
        self.saver.save(getsess(), last_checkpoint)
        # Write out the session's graph rather than a fresh, empty tf.Graph().
        tf.train.write_graph(getsess().graph,
                             self.model_path,
                             'raw_graph_def.pb',
                             as_text=False)

    def no_mpi_start_interaction(self, envs, dynamics, nlump=2):
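        # Single-process variant of start_interaction: takes already-constructed
        # envs instead of env_fns and skips the MPI optimizer/broadcast.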
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)
        last_step = 0
        if self.load:
            last_step = self._load_graph()
        else:
            self._initialize_graph()
            # bcast_tf_vars_from_root(getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(envs)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = envs

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)
        self.rollout.stats['tcount'] += last_step

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                       comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        last_step = 0
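        # Either restore variables from the latest checkpoint (when self.load is set)
        # or initialize them on MPI rank 0 and broadcast to the other workers.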
        if self.load:
            last_step = self._load_graph()
        else:
            if MPI.COMM_WORLD.Get_rank() == 0:
                self._initialize_graph()
            bcast_tf_vars_from_root(
                getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)
        self.rollout.stats['tcount'] += last_step

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
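        # Generalized Advantage Estimation: sweep the rollout backwards, accumulating
        # delta_t + gamma * lam * lastgaelam, and zero the bootstrap at episode
        # boundaries when use_news is enabled.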
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
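            # Normalize rewards by the running standard deviation of a discounted
            # reward "forward filter" (RewardForwardFilter + RunningMeanStd), with
            # moments aggregated across MPI workers.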
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(advmean=self.buf_advs.mean(),
                    advstd=self.buf_advs.std(),
                    retmean=self.buf_rets.mean(),
                    retstd=self.buf_rets.std(),
                    vpredmean=self.rollout.buf_vpreds.mean(),
                    vpredstd=self.rollout.buf_vpreds.std(),
                    ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                          self.buf_rets.ravel()),
                    rew_mean=np.mean(self.rollout.buf_rews),
                    recent_best_ext_ret=self.rollout.current_max)
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([(self.dynamics.last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        mblossvals = []

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
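            # Optimize on minibatches of envsperbatch environments; each sess.run
            # returns the reported losses and applies one training step.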
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.ph_lr: self.lr,
                    self.ph_cliprange: self.cliprange
                })
                mblossvals.append(getsess().run(self._losses + (self._train, ),
                                                fd)[:-1])

        mblossvals = [mblossvals[0]]
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        self.save_if_ready()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)
Example #10
0
    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))
        self.global_step = tf.Variable(0, trainable=False)
        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
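        # Build the optimizer only for the first agent (agent_num is None):
        # MpiAdamOptimizer under multi-rank MPI, otherwise Adam, SGD (optionally
        # with exponentially decayed lr), or momentum, selected by self.optim.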
        if MPI.COMM_WORLD.Get_size() > 1:
            if self.agent_num is None:
                trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                           comm=MPI.COMM_WORLD)

        else:
            if self.agent_num is None:
                if self.optim == 'adam':
                    trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
                elif self.optim == 'sgd':
                    print("using sgd")
                    print("________________________")
                    if self.decay:
                        self.decay_lr = tf.train.exponential_decay(
                            self.ph_lr,
                            self.global_step,
                            2500,
                            .96,
                            staircase=True)
                        trainer = tf.train.GradientDescentOptimizer(
                            learning_rate=self.decay_lr)
                    else:
                        trainer = tf.train.GradientDescentOptimizer(
                            learning_rate=self.ph_lr)
                elif self.optim == 'momentum':
                    print('using momentum')
                    print('________________________')
                    trainer = tf.train.MomentumOptimizer(
                        learning_rate=self.ph_lr, momentum=0.9)
        if self.agent_num is None:
            gradsandvars = trainer.compute_gradients(self.total_loss, params)
            l2_norm = lambda t: tf.sqrt(tf.reduce_sum(tf.pow(t, 2)))
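            # Optionally log per-variable gradient/value norms as TensorBoard
            # summaries, and clip gradients by global norm when grad_clip > 0.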
            if self.log_grads:
                for grad, var in gradsandvars:
                    tf.summary.histogram(var.name + '/gradient', l2_norm(grad))
                    tf.summary.histogram(var.name + '/value', l2_norm(var))
                    grad_mean = tf.reduce_mean(tf.abs(grad))
                    tf.summary.scalar(var.name + '/grad_mean', grad_mean)
                if self.decay:
                    tf.summary.scalar('decay_lr', self.decay_lr)
                self._summary = tf.summary.merge_all()
                tf.add_to_collection("summary_op", self._summary)
            if self.grad_clip > 0:
                grads, gradvars = zip(*gradsandvars)
                grads, _ = tf.clip_by_global_norm(grads, self.grad_clip)
                gradsandvars = list(zip(grads, gradvars))

            self._train = trainer.apply_gradients(gradsandvars,
                                                  global_step=self.global_step)
            self._updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            self._train = tf.group(self._train, self._updates)
            tf.add_to_collection("train_op", self._train)
        else:
            self._train = tf.get_collection("train_op")[0]
            if self.log_grads:
                self._summary = tf.get_collection("summary_op")[0]

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.env_ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics,
                               exp_name=self.exp_name,
                               env_name=self.env_name,
                               video_log_freq=self.video_log_freq,
                               model_save_freq=self.model_save_freq,
                               use_apples=self.use_apples,
                               multi_envs=self.multi_envs,
                               lstm=self.lstm,
                               lstm1_size=self.lstm1_size,
                               lstm2_size=self.lstm2_size,
                               depth_pred=self.depth_pred,
                               early_stop=self.early_stop,
                               aux_input=self.aux_input)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()
Example #11
0
class PpoOptimizer(object):
    envs = None

    def __init__(self,
                 *,
                 hps,
                 scope,
                 ob_space,
                 env_ob_space,
                 ac_space,
                 stochpol,
                 ent_coef,
                 gamma,
                 lam,
                 nepochs,
                 lr,
                 cliprange,
                 nminibatches,
                 normrew,
                 normadv,
                 use_news,
                 ext_coeff,
                 int_coeff,
                 nsteps_per_seg,
                 nsegs_per_env,
                 dynamics,
                 exp_name,
                 env_name,
                 video_log_freq,
                 model_save_freq,
                 use_apples,
                 agent_num=None,
                 restore_name=None,
                 multi_envs=None,
                 lstm=False,
                 lstm1_size=512,
                 lstm2_size=0,
                 depth_pred=0,
                 beta_d=.1,
                 early_stop=0,
                 aux_input=0,
                 optim='adam',
                 decay=0,
                 grad_clip=0.0,
                 log_grads=0,
                 logdir='logs'):
        self.dynamics = dynamics
        self.exp_name = exp_name
        self.env_name = env_name
        self.video_log_freq = video_log_freq
        self.model_save_freq = model_save_freq
        self.use_apples = use_apples
        self.agent_num = agent_num
        self.multi_envs = multi_envs
        self.lstm = lstm
        self.lstm1_size = lstm1_size
        self.lstm2_size = lstm2_size
        self.depth_pred = depth_pred
        self.aux_input = aux_input
        self.early_stop = early_stop
        self.optim = optim
        self.decay = decay
        self.log_grads = log_grads
        self.grad_clip = grad_clip
        if log_grads:
            self.grad_writer = tf.summary.FileWriter(logdir + '/grads/' +
                                                     exp_name)
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.env_ob_space = env_ob_space
            self.ac_space = ac_space
            self.stochpol = stochpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.ent_coeff = ent_coef
            self.beta_d = beta_d

            def mask(target, mask):
                mask_h = tf.abs(mask - 1)
                return tf.stop_gradient(mask_h * target) + mask * target
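            # For a binary mask, mask() leaves the forward value unchanged but blocks
            # the gradient wherever mask == 0 (used with ph_gradmask when early_stop
            # is set).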

            if self.agent_num is None:
                self.ph_adv = tf.placeholder(tf.float32, [None, None],
                                             name='adv')
                self.ph_ret = tf.placeholder(tf.float32, [None, None],
                                             name='ret')
                self.ph_rews = tf.placeholder(tf.float32, [None, None],
                                              name='rews')
                self.ph_oldnlp = tf.placeholder(tf.float32, [None, None],
                                                name='oldnlp')
                self.ph_oldvpred = tf.placeholder(tf.float32, [None, None],
                                                  name='oldvpred')
                self.ph_lr = tf.placeholder(tf.float32, [], name='lr')
                self.ph_cliprange = tf.placeholder(tf.float32, [],
                                                   name='cliprange')
                self.ph_gradmask = tf.placeholder(tf.float32, [None, None],
                                                  name='gradmask')
                neglogpac = mask(self.stochpol.pd.neglogp(self.stochpol.ph_ac),
                                 self.ph_gradmask)
                entropy = tf.reduce_mean(self.stochpol.pd.entropy(),
                                         name='agent_entropy')
                vpred = mask(self.stochpol.vpred, self.ph_gradmask)
                vf_loss = 0.5 * tf.reduce_mean(
                    (vpred - mask(self.ph_ret, self.ph_gradmask))**2,
                    name='vf_loss')
                ratio = tf.exp(self.ph_oldnlp - neglogpac,
                               name='ratio')  # p_new / p_old
                negadv = -mask(self.ph_adv, self.ph_gradmask)
                pg_losses1 = negadv * ratio
                pg_losses2 = negadv * tf.clip_by_value(ratio,
                                                       1.0 - self.ph_cliprange,
                                                       1.0 + self.ph_cliprange,
                                                       name='pglosses2')
                pg_loss_surr = tf.maximum(pg_losses1,
                                          pg_losses2,
                                          name='loss_surr')
                pg_loss = tf.reduce_mean(pg_loss_surr, name='pg_loss')
                ent_loss = (-ent_coef) * entropy
                if self.depth_pred:
                    depth_loss = self.stochpol.depth_loss * beta_d
                approxkl = .5 * tf.reduce_mean(
                    tf.square(neglogpac - self.ph_oldnlp), name='approxkl')
                clipfrac = tf.reduce_mean(
                    tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6),
                    name='clipfrac')

                self.total_loss = pg_loss + ent_loss + vf_loss
                if self.depth_pred:
                    self.total_loss = self.total_loss + depth_loss
                    #self.total_loss = depth_loss
                    #print("adding depth loss to total loss for optimization")
                #self.total_loss = depth_loss
                self.to_report = {
                    'tot': self.total_loss,
                    'pg': pg_loss,
                    'vf': vf_loss,
                    'ent': entropy,
                    'approxkl': approxkl,
                    'clipfrac': clipfrac
                }
                if self.depth_pred:
                    self.to_report.update({'depth_loss': depth_loss})
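                # Register the placeholders and loss tensors in tf collections so a
                # second optimizer instance (agent_num is not None) can reuse them
                # via restore() instead of rebuilding the graph.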
                tf.add_to_collection('adv', self.ph_adv)
                tf.add_to_collection('ret', self.ph_ret)
                tf.add_to_collection('rews', self.ph_rews)
                tf.add_to_collection('oldnlp', self.ph_oldnlp)
                tf.add_to_collection('oldvpred', self.ph_oldvpred)
                tf.add_to_collection('lr', self.ph_lr)
                tf.add_to_collection('cliprange', self.ph_cliprange)
                tf.add_to_collection('agent_entropy', entropy)
                tf.add_to_collection('vf_loss', vf_loss)
                tf.add_to_collection('ratio', ratio)
                tf.add_to_collection('pg_losses2', pg_losses2)
                tf.add_to_collection('loss_surr', pg_loss_surr)
                tf.add_to_collection('pg_loss', pg_loss)
                if self.depth_pred:
                    tf.add_to_collection('depth_loss', depth_loss)
                tf.add_to_collection('approxkl', approxkl)
                tf.add_to_collection('clipfrac', clipfrac)
            else:
                self.restore()

    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))
        self.global_step = tf.Variable(0, trainable=False)
        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            if self.agent_num is None:
                trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                           comm=MPI.COMM_WORLD)

        else:
            if self.agent_num is None:
                if self.optim == 'adam':
                    trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
                elif self.optim == 'sgd':
                    print("using sgd")
                    print("________________________")
                    if self.decay:
                        self.decay_lr = tf.train.exponential_decay(
                            self.ph_lr,
                            self.global_step,
                            2500,
                            .96,
                            staircase=True)
                        trainer = tf.train.GradientDescentOptimizer(
                            learning_rate=self.decay_lr)
                    else:
                        trainer = tf.train.GradientDescentOptimizer(
                            learning_rate=self.ph_lr)
                elif self.optim == 'momentum':
                    print('using momentum')
                    print('________________________')
                    trainer = tf.train.MomentumOptimizer(
                        learning_rate=self.ph_lr, momentum=0.9)
        if self.agent_num is None:
            gradsandvars = trainer.compute_gradients(self.total_loss, params)
            l2_norm = lambda t: tf.sqrt(tf.reduce_sum(tf.pow(t, 2)))
            if self.log_grads:
                for grad, var in gradsandvars:
                    tf.summary.histogram(var.name + '/gradient', l2_norm(grad))
                    tf.summary.histogram(var.name + '/value', l2_norm(var))
                    grad_mean = tf.reduce_mean(tf.abs(grad))
                    tf.summary.scalar(var.name + '/grad_mean', grad_mean)
                if self.decay:
                    tf.summary.scalar('decay_lr', self.decay_lr)
                self._summary = tf.summary.merge_all()
                tf.add_to_collection("summary_op", self._summary)
            if self.grad_clip > 0:
                grads, gradvars = zip(*gradsandvars)
                grads, _ = tf.clip_by_global_norm(grads, self.grad_clip)
                gradsandvars = list(zip(grads, gradvars))

            self._train = trainer.apply_gradients(gradsandvars,
                                                  global_step=self.global_step)
            self._updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            self._train = tf.group(self._train, self._updates)
            tf.add_to_collection("train_op", self._train)
        else:
            self._train = tf.get_collection("train_op")[0]
            if self.log_grads:
                self._summary = tf.get_collection("summary_op")[0]

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.env_ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics,
                               exp_name=self.exp_name,
                               env_name=self.env_name,
                               video_log_freq=self.video_log_freq,
                               model_save_freq=self.model_save_freq,
                               use_apples=self.use_apples,
                               multi_envs=self.multi_envs,
                               lstm=self.lstm,
                               lstm1_size=self.lstm1_size,
                               lstm2_size=self.lstm2_size,
                               depth_pred=self.depth_pred,
                               early_stop=self.early_stop,
                               aux_input=self.aux_input)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(advmean=self.buf_advs.mean(),
                    advstd=self.buf_advs.std(),
                    retmean=self.buf_rets.mean(),
                    retstd=self.buf_rets.std(),
                    vpredmean=self.rollout.buf_vpreds.mean(),
                    vpredstd=self.rollout.buf_vpreds.std(),
                    ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                          self.buf_rets.ravel()),
                    rew_mean=np.mean(self.rollout.buf_rews),
                    recent_best_ext_ret=self.rollout.current_max)
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def mask(x, grad_mask):
            # When early_stop is enabled, zero out the entries of x where grad_mask
            # is 0 by broadcasting grad_mask across x's trailing dimensions.
            if not self.early_stop:
                return x
            sh = np.shape(x)
            if sh[1] < np.shape(grad_mask)[1]:
                return x
            for _ in range(len(sh) - 2):
                grad_mask = np.expand_dims(grad_mask, -1)
            return np.multiply(grad_mask, x)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        new_count = np.count_nonzero(self.rollout.buf_news)
        print(self.rollout.buf_news)
        if self.early_stop:
            print(self.rollout.grad_mask)
        print(new_count)
        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        if self.depth_pred:
            ph_buf.extend([
                (self.stochpol.ph_depths, resh(self.rollout.buf_depths)),
            ])
        if self.aux_input:
            ph_buf.extend([
                (self.stochpol.ph_vel, resh(self.rollout.buf_vels)),
                (self.stochpol.ph_prev_rew,
                 resh(self.rollout.buf_prev_ext_rews)),
                (self.stochpol.ph_prev_ac, resh(self.rollout.buf_prev_acs)),
            ])
        if self.dynamics.auxiliary_task.features_shared_with_policy:
            ph_buf.extend([
                (self.dynamics.auxiliary_task.ph_features,
                 resh(self.rollout.buf_feats)),
                (self.dynamics.auxiliary_task.ph_last_features,
                 resh(np.expand_dims(self.rollout.buf_feat_last, axis=1))),
            ])
        #print("Buff obs shape: {}".format(self.rollout.buf_obs.shape))
        #print("Buff rew shape: {}".format(self.rollout.buf_rews.shape))
        #print("Buff nlps shape: {}".format(self.rollout.buf_nlps.shape))
        #print("Buff vpreds shape: {}".format(self.rollout.buf_vpreds.shape))
        ph_buf.extend([(self.dynamics.last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        mblossvals = []
        #if self.lstm:
        #print("Train lstm 1 state: {}, {}".format(self.rollout.train_lstm1_c, self.rollout.train_lstm1_h))
        #if self.lstm2_size:
        #print("Train lstm2 state: {}, {}".format(self.rollout.train_lstm2_c, self.rollout.train_lstm2_h))
        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                #mbenvinds = tf.convert_to_tensor(mbenvinds)
                #fd = {ph: buf[mbenvinds] if type(buf) is np.ndarray else buf.eval()[mbenvinds] for (ph, buf) in ph_buf}
                if self.early_stop:
                    grad_mask = self.rollout.grad_mask[mbenvinds]
                    fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                    fd.update({self.ph_gradmask: grad_mask})
                else:
                    fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.ph_lr: self.lr,
                    self.ph_cliprange: self.cliprange
                })
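                # Feed the LSTM cell/hidden states recorded during the rollout for
                # this minibatch's environments.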
                if self.lstm:
                    fd.update({
                        self.stochpol.c_in_1:
                        self.rollout.train_lstm1_c[mbenvinds, :],
                        self.stochpol.h_in_1:
                        self.rollout.train_lstm1_h[mbenvinds, :]
                    })
                if self.lstm and self.lstm2_size:
                    fd.update({
                        self.stochpol.c_in_2:
                        self.rollout.train_lstm2_c[mbenvinds, :],
                        self.stochpol.h_in_2:
                        self.rollout.train_lstm2_h[mbenvinds, :]
                    })
                if self.log_grads:
                    outs = getsess().run(
                        self._losses + (self._train, self._summary), fd)
                    losses = outs[:-2]
                    summary = outs[-1]
                    mblossvals.append(losses)
                    wandb.tensorflow.log(tf.summary.merge_all())
                    self.grad_writer.add_summary(
                        summary,
                        getsess().run(self.global_step))
                else:
                    mblossvals.append(getsess().run(
                        self._losses + (self._train, ), fd)[:-1])
        mblossvals = [mblossvals[0]]
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info

    def step(self):
        #print("Collecting rollout")
        self.rollout.collect_rollout()
        #print("Performing update")
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)

    def restore(self):
        # self.stochpol.vpred = tf.get_collection("vpred")[0]
        # self.stochpol.a_samp = tf.get_collection("a_samp")[0]
        # self.stochpol.entropy = tf.get_collection("entropy")[0]
        # self.stochpol.nlp_samp = tf.get_collection("nlp_samp")[0]
        # self.stochpol.ph_ob = tf.get_collection("ph_ob")[0]
        self.ph_adv = tf.get_collection("adv")[0]
        self.ph_ret = tf.get_collection("ret")[0]
        self.ph_rews = tf.get_collection("rews")[0]
        self.ph_oldnlp = tf.get_collection("oldnlp")[0]
        self.ph_oldvpred = tf.get_collection("oldvpred")[0]
        self.ph_lr = tf.get_collection("lr")[0]
        self.ph_cliprange = tf.get_collection("cliprange")[0]
        neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
        entropy = tf.get_collection("agent_entropy")[0]
        vpred = self.stochpol.vpred
        vf_loss = tf.get_collection("vf_loss")[0]
        ratio = tf.get_collection("ratio")[0]
        negadv = -self.ph_adv
        pg_losses1 = negadv * ratio
        pg_losses2 = tf.get_collection("pg_losses2")[0]
        pg_loss_surr = tf.get_collection("loss_surr")[0]
        pg_loss = tf.get_collection("pg_loss")[0]
        ent_loss = (-self.ent_coeff) * entropy
        approxkl = tf.get_collection("approxkl")[0]
        clipfrac = tf.get_collection("clipfrac")[0]

        self.total_loss = pg_loss + ent_loss + vf_loss
        self.to_report = {
            'tot': self.total_loss,
            'pg': pg_loss,
            'vf': vf_loss,
            'ent': entropy,
            'approxkl': approxkl,
            'clipfrac': clipfrac
        }
Example #12
0
class PpoOptimizer(object):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, stochpol,
                 ent_coef, gamma, lam, nepochs, lr, cliprange,
                 nminibatches,
                 normrew, normadv, use_news, ext_coeff, int_coeff,
                 nsteps_per_seg, nsegs_per_env, dynamics, flow_lr=None, update_periods=None):

        self.dynamics = dynamics
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.stochpol = stochpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.ph_adv = tf.placeholder(tf.float32, [None, None])
            self.ph_ret = tf.placeholder(tf.float32, [None, None])
            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            # Flow-based intrinsic-reward module hyperparameters.
            self.flow_lr = flow_lr
            self.update_periods = update_periods
            self.flow_lr_scale = self.flow_lr / self.lr

            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            vpred = self.stochpol.vpred

            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret) ** 2)
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = - self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (- ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            print('===========================')
            print('Initializing PpoOptimizer...')
            print('flow lr', self.flow_lr)

            ### Flow-based intrinsic module part ###
            if update_periods is not None:
                print('Update periods in PpoOptimizer: ', update_periods)

                self.target_flow = int(update_periods.split(':')[0])
                self.target_all = int(update_periods.split(':')[1])
                self.period = self.target_flow + self.target_all
                self.target = self.target_flow
                self.update_flow = True
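                # update_periods has the form 'N:M'; update() toggles between running
                # the combined flow+agent train op and the agent-only op whenever
                # n_updates % period reaches the current target.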
                
                print('update flow:   ', self.update_flow)
                print('target:        ', self.target)
            
            print('===========================')

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {'tot': self.total_loss, 'pg': pg_loss, 'vf': vf_loss, 'ent': entropy,
                              'approxkl': approxkl, 'clipfrac': clipfrac}

    def caculate_number_parameters(self, params):
        total_parameters = 0
        for variable in params:
            # shape is an array of tf.Dimension
            shape = variable.get_shape()
            variable_parameters = 1
            for dim in shape:
                variable_parameters *= dim.value
            total_parameters += variable_parameters

        print('Number of total parameters: ', total_parameters)


    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))
        
        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        self.caculate_number_parameters(params)

        flow_params = [v for v in params if 'flow' in v.name]
        other_params = [v for v in params if 'flow' not in v.name]

        print('length of flow params: ', len(flow_params))
        print('length of agent params: ', len(other_params))
        
        trainer_flow = tf.train.AdamOptimizer(learning_rate=self.flow_lr)
        trainer_agent = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
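        # The flow-based intrinsic module trains with its own fixed flow_lr, while
        # the rest of the agent uses the ph_lr placeholder fed at update time.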

        grads = tf.gradients(self.total_loss, flow_params + other_params)
        grads_flow = grads[:len(flow_params)]
        grads_agent = grads[len(flow_params):]

        train_flow = trainer_flow.apply_gradients(zip(grads_flow, flow_params))
        train_agent = trainer_agent.apply_gradients(zip(grads_agent, other_params))

        # update() alternates between the combined op and the agent-only op, so keep
        # a handle to the agent-only train op as well.
        self._train = tf.group(train_flow, train_agent)
        self._train_agent = train_agent

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(tf.variables_initializer(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride: (l + 1) * self.lump_stride], spaces=[self.ob_space, self.ac_space]) for
            l in range(self.nlump)]

        self.rollout = Rollout(ob_space=self.ob_space, ac_space=self.ac_space, nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env, nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array([self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std ** 2, rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews, use_news=self.use_news, gamma=self.gamma, lam=self.lam)

        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.rollout.buf_vpreds.mean(),
            vpredstd=self.rollout.buf_vpreds.std(),
            ev=explained_variance(self.rollout.buf_vpreds.ravel(), self.buf_rets.ravel()),
            rew_mean=np.mean(self.rollout.buf_rews),
            recent_best_ext_ret=self.rollout.current_max
        )
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env, self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([
            (self.dynamics.last_ob,
             self.rollout.buf_obs_last.reshape([self.nenvs * self.nsegs_per_env, 1, *self.ob_space.shape]))
        ])
        mblossvals = []

        if self.update_periods is not None:
            for _ in range(self.nepochs):
                np.random.shuffle(envinds)
                for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                    fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})

                    if self.update_flow: 
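                        # Current phase trains the flow module and the agent together;
                        # the else branch runs the agent-only op.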
                        mblossvals.append(getsess().run(self._losses + (self._train,), fd)[:-1])
                    else:
                        mblossvals.append(getsess().run(self._losses + (self._train_agent,), fd)[:-1])
        else:
            for _ in range(self.nepochs):
                np.random.shuffle(envinds)
                for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                    fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})

                    mblossvals.append(getsess().run(self._losses + (self._train,), fd)[:-1])


        mblossvals = [mblossvals[0]]
        info.update(zip(['opt_' + ln for ln in self.loss_names], np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({dn: (np.mean(dvs) if len(dvs) > 0 else 0) for (dn, dvs) in self.rollout.statlists.items()})
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        if self.update_periods is not None:
            if self.n_updates % self.period == self.target:
                self.update_flow = not self.update_flow
                
                if self.update_flow:
                    self.target = self.target_flow
                else:
                    self.target = self.target_all

                print('--------- Reach target ---------')
                print('Update flow:  ', self.update_flow)
                print('New target:   ', self.target)

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)
Example #13
0
class PpoOptimizer(object):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, stochpol, ent_coef, gamma,
                 lam, nepochs, lr, cliprange, nminibatches, normrew, normadv,
                 use_news, ext_coeff, int_coeff, nsteps_per_seg, nsegs_per_env,
                 dynamics, exp_name, env_name, to_eval):
        self.dynamics = dynamics
        self.exp_name = exp_name
        self.env_name = env_name
        self.to_eval = to_eval
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.stochpol = stochpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.ph_adv = tf.placeholder(tf.float32, [None, None])
            self.ph_ret = tf.placeholder(tf.float32, [None, None])
            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            vpred = self.stochpol.vpred

            self.pd_logstd_min = tf.math.reduce_min(self.stochpol.pd.logstd)
            self.pd_logstd_max = tf.math.reduce_max(self.stochpol.pd.logstd)
            self.pd_std_min = tf.math.reduce_min(self.stochpol.pd.std)
            self.pd_std_max = tf.math.reduce_max(self.stochpol.pd.std)
            self.pd_mean_min = tf.math.reduce_min(self.stochpol.pd.mean)
            self.pd_mean_max = tf.math.reduce_max(self.stochpol.pd.mean)
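            # Min/max statistics of the policy distribution's mean, std and logstd,
            # collected in stat_report purely for logging/diagnostics.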
            self.stat_report = {
                'pd_logstd_max': self.pd_logstd_max,
                'pd_logstd_min': self.pd_logstd_min,
                'pd_std_max': self.pd_std_max,
                'pd_std_min': self.pd_std_min,
                'pd_mean_max': self.pd_mean_max,
                'pd_mean_min': self.pd_mean_min
            }

            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret)**2)
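            # PPO clipped surrogate: ratio = pi_new / pi_old; take the
            # element-wise maximum of the unclipped and clipped objectives
            # (a pessimistic bound), then average over the batch.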
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = -self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(
                ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (-ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(
                tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(
                tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {
                'tot': self.total_loss,
                'pg': pg_loss,
                'vf': vf_loss,
                'ent': entropy,
                'approxkl': approxkl,
                'clipfrac': clipfrac
            }

    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                       comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)

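        # Clip each gradient element to [-25, 25]; variables whose gradient is
        # None are passed through untouched.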
        def ClipIfNotNone(grad):
            return tf.clip_by_value(grad, -25.0,
                                    25.0) if grad is not None else grad

        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        gradsandvars = [(ClipIfNotNone(g), v) for g, v in gradsandvars]
        self._train = trainer.apply_gradients(gradsandvars)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        print('-------NENVS-------', self.nenvs)
        self.nlump = nlump
        print('----------NLUMPS-------', self.nlump)
        self.lump_stride = nenvs // self.nlump
        print('-------LSTRIDE----', self.lump_stride)
        print('--------OBS SPACE ---------', self.ob_space)
        print('-------------AC SPACE-----', self.ac_space)
        print('-----BEFORE VEC ENV------')
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]
        print('-----AFTER VEC ENV------')
        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics,
                               exp_name=self.exp_name,
                               env_name=self.env_name,
                               to_eval=self.to_eval)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()
        self.saver = tf.train.Saver(max_to_keep=5)

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
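        # Generalized Advantage Estimation (GAE), computed backwards in time:
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
        #   A_t     = delta_t + gamma * lam * (1 - done_{t+1}) * A_{t+1}
        # Returns used as value targets are R_t = A_t + V(s_t).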
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
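        # If reward normalization is enabled, rewards are divided by the running
        # standard deviation of a discounted reward filter, with the moments
        # aggregated across MPI workers.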
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.rollout.buf_vpreds.mean(),
            vpredstd=self.rollout.buf_vpreds.std(),
            ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                  self.buf_rets.ravel()),
            rew_mean=np.mean(self.rollout.buf_rews),
            recent_best_ext_ret=self.rollout.current_max
            if self.rollout.current_max is not None else 0,
        )
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([(self.dynamics.last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        mblossvals = []
        statvals = []
        self.stat_names, self.stats = zip(*list(self.stat_report.items()))
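        # Several epochs of minibatch updates over the collected rollout; the
        # environment indices are reshuffled at the start of each epoch.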
        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.ph_lr: self.lr,
                    self.ph_cliprange: self.cliprange
                })
                mblossvals.append(getsess().run(self._losses + (self._train, ),
                                                fd)[:-1])
                statvals.append(tf.get_default_session().run(self.stats, fd))

        mblossvals = [mblossvals[0]]
        statvals = [statvals[0]]
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info.update(
            zip(['opt_' + ln for ln in self.stat_names],
                np.mean([statvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        info['video_log'] = self.rollout.buf_obs if self.n_updates % 50 == 0 else None
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info

    def step(self):
        print('---------INSIDE STEP--------------')
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)

    def save_model(self, logdir, exp_name, global_step=0):
        path = osp.join(logdir, exp_name + ".ckpt")
        self.saver.save(getsess(), path, global_step=global_step)
        print("Model saved to path", path)

    def restore_model(self, logdir, exp_name):
        path = logdir
        self.saver.restore(getsess(), path)
        print("Model Restored from path", path)
Example #14
0
    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                       comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)

        def ClipIfNotNone(grad):
            return tf.clip_by_value(grad, -25.0,
                                    25.0) if grad is not None else grad

        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        gradsandvars = [(ClipIfNotNone(g), v) for g, v in gradsandvars]
        self._train = trainer.apply_gradients(gradsandvars)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        print('-------NENVS-------', self.nenvs)
        self.nlump = nlump
        print('----------NLUMPS-------', self.nlump)
        self.lump_stride = nenvs // self.nlump
        print('-------LSTRIDE----', self.lump_stride)
        print('--------OBS SPACE ---------', self.ob_space)
        print('-------------AC SPACE-----', self.ac_space)
        print('-----BEFORE VEC ENV------')
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]
        print('-----AFTER VEC ENV------')
        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics,
                               exp_name=self.exp_name,
                               env_name=self.env_name,
                               to_eval=self.to_eval)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()
        self.saver = tf.train.Saver(max_to_keep=5)
class PpoOptimizer(object):
    envs = None

    def __init__(self, *, hps, scope, ob_space, ac_space, stochpol,
                 ent_coef, gamma, gamma_ext, lam, nepochs, lr, cliprange,
                 nminibatches,
                 normrew, normadv, use_news, ext_coeff, int_coeff,
                 nsteps_per_seg, nsegs_per_env, dynamics):
        self.dynamics = dynamics
        with tf.variable_scope(scope):
            self.hps = hps
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.stochpol = stochpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.gamma_ext = gamma_ext
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff

            self.ph_adv = tf.placeholder(tf.float32, [None, None])        
     
            self.ph_ret_int = tf.placeholder(tf.float32, [None, None])            
            self.ph_ret_ext = tf.placeholder(tf.float32, [None, None])            
            self.ph_ret = tf.placeholder(tf.float32, [None, None])

            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            
            vpred = self.stochpol.vpred
            
            if hps['num_vf']==2: 
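                # Two value heads: one critic predicts intrinsic returns and the
                # other extrinsic returns, each regressed against its own target.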
                # Separate vf_loss for intrinsic and extrinsic rewards
                vf_loss_int = 0.5 * tf.reduce_mean(tf.square(self.stochpol.vpred_int - self.ph_ret_int))
                vf_loss_ext = 0.5 * tf.reduce_mean(tf.square(self.stochpol.vpred_ext - self.ph_ret_ext))
                vf_loss = vf_loss_int + vf_loss_ext
            else:
                vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret) ** 2)

            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = - self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (- ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {'tot': self.total_loss, 'pg': pg_loss, 'vf': vf_loss, 'ent': entropy,
                              'approxkl': approxkl, 'clipfrac': clipfrac}

    def start_interaction(self, env_fns, dynamics, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr, comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(tf.variables_initializer(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride: (l + 1) * self.lump_stride], spaces=[self.ob_space, self.ac_space]) for
            l in range(self.nlump)]

        self.rollout = Rollout(hps=self.hps, ob_space=self.ob_space, ac_space=self.ac_space, nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env, nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_advs_int = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_advs_ext = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets_int = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets_ext = np.zeros((nenvs, self.rollout.nsteps), np.float32)


        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()


    def update(self):
        # Intrinsic Rewards Normalization
        if self.normrew:
            rffs_int = np.array([self.rff.update(rew) for rew in self.rollout.buf_int_rews.T])
            self.rff_rms.update(rffs_int.ravel())        
            int_rews = self.rollout.buf_int_rews / np.sqrt(self.rff_rms.var)
        else:
            int_rews = np.copy(self.rollout.buf_int_rews)
        
        mean_int_rew = np.mean(int_rews)
        max_int_rew = np.max(int_rews)
        
        # Do not normalize extrinsic rewards 
        ext_rews = self.rollout.buf_ext_rews

        nsteps = self.rollout.nsteps

        # If separate value functions are used for intrinsic and extrinsic rewards
        if self.hps['num_vf']==2:
            #Calculate intrinsic returns and advantages.
            lastgaelam = 0
            for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
                if self.use_news:
                    nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
                else:
                    nextnew = 0 # No dones for intrinsic rewards with self.use_news=False
                nextvals = self.rollout.buf_vpreds_int[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_int_last
                nextnotnew = 1 - nextnew
                delta = int_rews[:, t] + self.gamma * nextvals * nextnotnew - self.rollout.buf_vpreds_int[:, t]
                self.buf_advs_int[:, t] = lastgaelam = delta + self.gamma * self.lam * nextnotnew * lastgaelam
            self.buf_rets_int[:] = self.buf_advs_int + self.rollout.buf_vpreds_int

            #Calculate extrinsic returns and advantages.
            lastgaelam = 0

            for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
                nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
                nextvals = self.rollout.buf_vpreds_ext[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_ext_last
                nextnotnew = 1 - nextnew
                delta = ext_rews[:, t] + self.gamma_ext * nextvals * nextnotnew - self.rollout.buf_vpreds_ext[:, t]
                self.buf_advs_ext[:, t] = lastgaelam = delta + self.gamma_ext * self.lam * nextnotnew * lastgaelam
            self.buf_rets_ext[:] = self.buf_advs_ext + self.rollout.buf_vpreds_ext
            
            #Combine the extrinsic and intrinsic advantages.
            self.buf_advs = self.int_coeff*self.buf_advs_int + self.ext_coeff*self.buf_advs_ext
        else:
            #Calculate mixed intrinsic and extrinsic returns and advantages.
            rews = self.rollout.buf_rews = self.rollout.reward_fun(int_rew=int_rews, ext_rew=ext_rews)            
            lastgaelam = 0
            for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
                nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
                nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
                nextnotnew = 1 - nextnew
                delta = rews[:, t] + self.gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
                self.buf_advs[:, t] = lastgaelam = delta + self.gamma * self.lam * nextnotnew * lastgaelam
            self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds
        
        info = dict(
            recent_best_ext_ret=self.rollout.current_max,
            recent_best_eplen=self.rollout.current_minlen,
            recent_worst_eplen=self.rollout.current_maxlen,
        )

        if self.hps['num_vf'] == 2:
            info['retmean_int'] = self.buf_rets_int.mean()
            info['retmean_ext'] = self.buf_rets_ext.mean()
            info['retstd_int'] = self.buf_rets_int.std()
            info['retstd_ext'] = self.buf_rets_ext.std()
            info['vpredmean_int'] = self.rollout.buf_vpreds_int.mean()
            info['vpredmean_ext'] = self.rollout.buf_vpreds_ext.mean()
            info['vpredstd_int'] = self.rollout.buf_vpreds_int.std()
            info['vpredstd_ext'] = self.rollout.buf_vpreds_ext.std()
            info['ev_int'] = explained_variance(self.rollout.buf_vpreds_int.ravel(), self.buf_rets_int.ravel())
            info['ev_ext'] = explained_variance(self.rollout.buf_vpreds_ext.ravel(), self.buf_rets_ext.ravel())
            info['rew_int_mean'] = mean_int_rew
            info['recent_best_int_rew'] = max_int_rew
        else:
            info['rew_mean'] = np.mean(self.rollout.buf_rews)
            info['eplen_std'] = np.std(self.rollout.statlists['eplen'])
            info['eprew_std'] = np.std(self.rollout.statlists['eprew'])

        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret
            info['best_eplen'] = self.rollout.best_eplen

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env, self.nsteps_per_seg) + sh[2:])
        
        #Create feed_dict for optimization.
        ph_buf = [
                (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
                (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
                (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),
                (self.ph_adv, resh(self.buf_advs)),
                ]

        if self.hps['num_vf']==2:
            ph_buf.extend([                
                (self.ph_ret_int, resh(self.buf_rets_int)),
                (self.ph_ret_ext, resh(self.buf_rets_ext)),
            ])       
        else:
            ph_buf.extend([
                (self.ph_rews, resh(self.rollout.buf_rews)),
                (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
                (self.ph_ret, resh(self.buf_rets)),
            ])

        ph_buf.extend([
            (self.dynamics.last_ob,
             self.rollout.buf_obs_last.reshape([self.nenvs * self.nsegs_per_env, 1, *self.ob_space.shape]))
        ])

        #Optimizes on current data for several epochs.
        mblossvals = []

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}                
                fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})
                mblossvals.append(getsess().run(self._losses + (self._train,), fd)[:-1])

        mblossvals = [mblossvals[0]]
        # info.update(zip(['opt_' + ln for ln in self.loss_names], np.mean([mblossvals[0]], axis=0)))
        # info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({dn: (np.mean(dvs) if len(dvs) > 0 else 0) for (dn, dvs) in self.rollout.statlists.items()})
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        # info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        # info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)
class PpoOptimizer(object):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, stochpol, ent_coef, gamma,
                 lam, nepochs, lr, cliprange, nminibatches, normrew, normadv,
                 use_news, ext_coeff, int_coeff, nsteps_per_seg, nsegs_per_env,
                 unity, dynamics_list):
        self.dynamics_list = dynamics_list
        with tf.variable_scope(scope):
            self.unity = unity
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.stochpol = stochpol
            self.nepochs = nepochs
            self.lr = lr
            self.cliprange = cliprange
            self.nsteps_per_seg = nsteps_per_seg
            self.nsegs_per_env = nsegs_per_env
            self.nminibatches = nminibatches
            self.gamma = gamma
            self.lam = lam
            self.normrew = normrew
            self.normadv = normadv
            self.use_news = use_news
            self.ext_coeff = ext_coeff
            self.int_coeff = int_coeff
            self.ph_adv = tf.placeholder(tf.float32, [None, None])
            self.ph_ret = tf.placeholder(tf.float32, [None, None])
            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            vpred = self.stochpol.vpred

            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret)**2)
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = -self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(
                ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (-ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(
                tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(
                tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {
                'tot': self.total_loss,
                'pg': pg_loss,
                'vf': vf_loss,
                'ent': entropy,
                'approxkl': approxkl,
                'clipfrac': clipfrac
            }

    def start_interaction(self, env_fns, dynamics_list, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr,
                                       comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)
        self._train = trainer.apply_gradients(gradsandvars)

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(
                tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(
            getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]
        if self.unity:
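            # Unity-based environments need time to boot: wait roughly 2.5
            # seconds per environment plus a fixed margin before rolling out.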
            for i in tqdm(range(int(nenvs * 2.5 + 10))):
                time.sleep(1)
            print('... long overdue sleep ends now')
            sys.stdout.flush()

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics_list=dynamics_list)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.rollout.buf_vpreds.mean(),
            vpredstd=self.rollout.buf_vpreds.std(),
            ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                  self.buf_rets.ravel()),
            rew_mean=np.mean(self.rollout.buf_rews),
            recent_best_ext_ret=self.rollout.current_max
            if self.rollout.current_max is not None else 0,
        )
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env,
                              self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([(self.dynamics_list[0].last_ob,
                        self.rollout.buf_obs_last.reshape([
                            self.nenvs * self.nsegs_per_env, 1,
                            *self.ob_space.shape
                        ]))])
        mblossvals = []

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
                fd.update({
                    self.ph_lr: self.lr,
                    self.ph_cliprange: self.cliprange
                })
                mblossvals.append(getsess().run(self._losses + (self._train, ),
                                                fd)[:-1])

        mblossvals = [mblossvals[0]]
        info.update(
            zip(['opt_' + ln for ln in self.loss_names],
                np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)
Example #17
0
class PpoOptimizer(object):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, stochpol, ent_coef, gamma,
                 lam, nepochs, lr, cliprange, nminibatches, normrew, normadv,
                 use_news, ext_coeff, int_coeff, nsteps_per_seg, nsegs_per_env,
                 dynamics):
        self.dynamics = dynamics
        self.use_recorder = True
        self.n_updates = 0
        self.scope = scope
        self.ob_space = ob_space
        self.ac_space = ac_space
        self.stochpol = stochpol
        self.nepochs = nepochs
        self.lr = lr
        self.cliprange = cliprange
        self.nsteps_per_seg = nsteps_per_seg
        self.nsegs_per_env = nsegs_per_env
        self.nminibatches = nminibatches
        self.gamma = gamma
        self.lam = lam
        self.normrew = normrew
        self.normadv = normadv
        self.use_news = use_news
        self.ent_coef = ent_coef
        self.ext_coeff = ext_coeff
        self.int_coeff = int_coeff

    def start_interaction(self, env_fns, dynamics, nlump=2):
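        # PyTorch variant: policy, dynamics-model, and auxiliary-task parameters
        # are concatenated into one list and optimized jointly with a single Adam.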
        param_list = self.stochpol.param_list + self.dynamics.param_list + self.dynamics.auxiliary_task.param_list  # concatenation copies references to the parameters, not the parameters themselves
        self.optimizer = torch.optim.Adam(param_list, lr=self.lr)
        self.optimizer.zero_grad()

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(advmean=self.buf_advs.mean(),
                    advstd=self.buf_advs.std(),
                    retmean=self.buf_rets.mean(),
                    retstd=self.buf_rets.std(),
                    vpredmean=self.rollout.buf_vpreds.mean(),
                    vpredstd=self.rollout.buf_vpreds.std(),
                    ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                          self.buf_rets.ravel()),
                    rew_mean=np.mean(self.rollout.buf_rews),
                    recent_best_ext_ret=self.rollout.current_max)
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        to_report = {
            'total': 0.0,
            'pg': 0.0,
            'vf': 0.0,
            'ent': 0.0,
            'approxkl': 0.0,
            'clipfrac': 0.0,
            'aux': 0.0,
            'dyn_loss': 0.0,
            'feat_var': 0.0
        }

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        mblossvals = []

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]

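                # Slice the rollout buffers for this minibatch, recompute the
                # policy/auxiliary features, and combine the PPO, auxiliary, and
                # dynamics losses into a single objective before the backward pass.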
                acs = self.rollout.buf_acs[mbenvinds]
                rews = self.rollout.buf_rews[mbenvinds]
                vpreds = self.rollout.buf_vpreds[mbenvinds]
                nlps = self.rollout.buf_nlps[mbenvinds]
                obs = self.rollout.buf_obs[mbenvinds]
                rets = self.buf_rets[mbenvinds]
                advs = self.buf_advs[mbenvinds]
                last_obs = self.rollout.buf_obs_last[mbenvinds]

                lr = self.lr
                cliprange = self.cliprange

                self.stochpol.update_features(obs, acs)
                self.dynamics.auxiliary_task.update_features(obs, last_obs)
                self.dynamics.update_features(obs, last_obs)

                feat_loss = torch.mean(self.dynamics.auxiliary_task.get_loss())
                dyn_loss = torch.mean(self.dynamics.get_loss())

                acs = torch.tensor(flatten_dims(acs, len(self.ac_space.shape)))
                neglogpac = self.stochpol.pd.neglogp(acs)
                entropy = torch.mean(self.stochpol.pd.entropy())
                vpred = self.stochpol.vpred
                vf_loss = 0.5 * torch.mean(
                    (vpred.squeeze() - torch.tensor(rets))**2)

                nlps = torch.tensor(flatten_dims(nlps, 0))
                ratio = torch.exp(nlps - neglogpac.squeeze())

                advs = flatten_dims(advs, 0)
                negadv = torch.tensor(-advs)
                pg_losses1 = negadv * ratio
                pg_losses2 = negadv * torch.clamp(
                    ratio, min=1.0 - cliprange, max=1.0 + cliprange)
                pg_loss_surr = torch.max(pg_losses1, pg_losses2)
                pg_loss = torch.mean(pg_loss_surr)
                ent_loss = (-self.ent_coef) * entropy

                approxkl = 0.5 * torch.mean((neglogpac - nlps)**2)
                clipfrac = torch.mean(
                    (torch.abs(pg_losses2 - pg_loss_surr) > 1e-6).float())
                feat_var = torch.std(self.dynamics.auxiliary_task.features)

                total_loss = pg_loss + ent_loss + vf_loss + feat_loss + dyn_loss

                total_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                to_report['total'] += total_loss.data.numpy() / (
                    self.nminibatches * self.nepochs)
                to_report['pg'] += pg_loss.data.numpy() / (self.nminibatches *
                                                           self.nepochs)
                to_report['vf'] += vf_loss.data.numpy() / (self.nminibatches *
                                                           self.nepochs)
                to_report['ent'] += ent_loss.data.numpy() / (
                    self.nminibatches * self.nepochs)
                to_report['approxkl'] += approxkl.data.numpy() / (
                    self.nminibatches * self.nepochs)
                to_report['clipfrac'] += clipfrac.data.numpy() / (
                    self.nminibatches * self.nepochs)
                to_report['feat_var'] += feat_var.data.numpy() / (
                    self.nminibatches * self.nepochs)
                to_report['aux'] += feat_loss.data.numpy() / (
                    self.nminibatches * self.nepochs)
                to_report['dyn_loss'] += dyn_loss.data.numpy() / (
                    self.nminibatches * self.nepochs)

        info.update(to_report)
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)  # MPI.COMM_WORLD.Get_size() *
        self.t_last_update = tnow

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)