def learn(policy, env, test_env, seed, master_ts = 1, worker_ts = 8, cell = 256,
          ent_coef = 0.01, vf_coef = 0.5, max_grad_norm = 0.5, lr = 7e-4,
          alpha = 0.99, epsilon = 1e-5, total_timesteps = int(80e6), lrschedule = 'linear',
          log_interval = 10, gamma = 0.99, load_path="saved_nets-data/hrl_a2c/%s/data"%start_time,
          algo='regular', beta=1e-3):

    tf.reset_default_graph()
    set_global_seeds(seed)
    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    print(str(nenvs)+"------------"+str(ob_space)+"-----------"+str(ac_space))
    max_grad_norm_tune = max_grad_norm
    max_grad_norm_tune_env_list = ['BreakoutNoFrameskip-v4', 'MsPacmanNoFrameskip-v4']
    global env_name
    if env_name in max_grad_norm_tune_env_list:
        print('tune max grad norm')
        max_grad_norm_tune = 1.0
    model = Model(policy = policy, ob_space = ob_space, ac_space = ac_space, nenvs = nenvs, master_ts=master_ts, worker_ts=worker_ts,
                  ent_coef = ent_coef, vf_coef = vf_coef, max_grad_norm = max_grad_norm_tune, lr = lr, cell = cell,
                  alpha = alpha, epsilon = epsilon, total_timesteps = total_timesteps, lrschedule = lrschedule,
                  algo=algo, beta=beta)
    try:
        model.load(load_path)
    except Exception as e:
        print("no data to load!!"+str(e))
    runner = Runner(env = env, model = model, nsteps=master_ts*worker_ts, gamma=gamma)
    test_runner = test_Runner(env = test_env, model = model)

    tf.get_default_graph().finalize()
    nbatch = nenvs * master_ts * worker_ts
    tstart = time.time()
    reward_list = []
    prev_r = 0.
    exp_coef = .1
    for update in range(1, total_timesteps//nbatch+1):
        b_obs, b_whs, states, b_rewards, b_wmasks, b_actions, b_values = runner.run()
        tloss, value_loss, policy_loss, policy_entropy = model.train(b_obs, b_whs, states, b_rewards, b_wmasks, b_actions, b_values)
        nseconds = time.time()-tstart
        fps = int((update*nbatch)/nseconds)
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(b_values, b_rewards)
            logger.record_tabular("fps", fps)
            logger.record_tabular("tloss", float(tloss))
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("value_loss", float(value_loss))
            logger.record_tabular("policy_loss", float(policy_loss))
            logger.record_tabular("explained_variance", float(ev))
            logger.dump_tabular()
            if update % (200*log_interval) == 0:
                save_th = update//(200*log_interval)
                model.save("saved_nets-data/%s/hrl_a2c/%s/%s/data" % (env_name, start_time, save_th))
                episode_r = exp_coef*(test_runner.run()) + (1-exp_coef)*prev_r
                prev_r = np.copy(episode_r)
                reward_list.append(episode_r)
                logger.record_tabular('episode_r', float(episode_r))
                logger.dump_tabular()
    env.close()
    return reward_list
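
All of the examples on this page call explained_variance(values, returns) to report how much of the variance in the empirical returns the value function explains. A minimal sketch of that helper, assuming the OpenAI Baselines-style definition (the projects above import their own implementations):

import numpy as np

def explained_variance(ypred, y):
    """1 - Var(y - ypred) / Var(y): 1 means a perfect value fit, 0 means no
    better than predicting a constant, negative means worse than that."""
    ypred, y = np.asarray(ypred).ravel(), np.asarray(y).ravel()
    var_y = np.var(y)
    return np.nan if var_y == 0 else 1.0 - np.var(y - ypred) / var_y
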
Example #2
File: vpg.py Project: sisl/MPHRL
    def learn(self, experience):
        trajectories = experience_to_traj(experience)
        returns = get_returns(trajectories)
        time_steps, obvs, acts, rews, next_obvs, dones, internal_states = zip(
            *experience)
        advs = self.calculate_advantage(returns=returns, observations=obvs)
        _, self.bl_loss, self.baseline = self.sess.run(
            [self.train_bl_op, self.bl_loss_tr, self.bl_tr],
            feed_dict={
                self.obv_ph: obvs,
                self.bl_target_ph: returns
            })

        self.next_s_probs = self.calculate_next_s_probs(experience)

        _, self.belief_loss = self.sess.run(
            [self.train_belief_op, self.belief_loss_tr], feed_dict={
                self.obv_ph: obvs,
                self.next_s_probs_ph: self.next_s_probs
            })

        self.fd = {
            self.obv_ph: obvs,
            self.act_ph: acts,
            self.adv_ph: advs,
            self.bl_target_ph: returns,
            self.next_s_probs_ph: self.next_s_probs
        }

        _, self.policy_loss, belief_prob, mean1, std1 = self.sess.run(
            [self.train_policy_op, self.pi_loss_tr,
             self.belief_prob_target, self.weighted_action_means,
             self.action_std], feed_dict=self.fd)
        mean2, std2 = self.sess.run([self.weighted_action_means,
                                     self.action_std], feed_dict=self.fd)
        assert np.allclose(np.sum(belief_prob, axis=1), 1)

        if self.t % c.update_master_interval == 0:
            print('Updating Belief Network...')
            self.sess.run(self.assign_raw_belief_op)
            print('Updated')

        self.results['baseline_losses'][self.t] = self.bl_loss
        self.results['policy_losses'][self.t] = self.policy_loss
        self.results['belief_losses'][self.t] = self.belief_loss
        self.results['explained_variance'][self.t] = explained_variance(
            self.baseline, returns)
        self.results['policy_kl_divergence'][self.t] = self.sess.run(
            self.kl_divergence, feed_dict={
                self.weighted_action_means_ph1: mean1,
                self.weighted_action_means_ph2: mean2,
                self.action_std_ph1: std1,
                self.action_std_ph2: std2
            })
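
The learn() above relies on experience_to_traj, get_returns and calculate_advantage, none of which are shown here. A hedged sketch of the discounted Monte-Carlo return computation that a get_returns-style helper typically performs per trajectory (the function name, gamma value and reward layout are illustrative assumptions, not taken from sisl/MPHRL):

import numpy as np

def discounted_returns(rewards, gamma=0.99):
    """G_t = r_t + gamma * G_{t+1}, computed backwards over one trajectory."""
    returns = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# e.g. discounted_returns([1.0, 0.0, 1.0], gamma=0.9) -> [1.81, 0.9, 1.0]
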
Example #3
    def train(self, states, actions, rewards, dones, values, log_probs,
              next_values):
        returns = self.get_gae(rewards, values.copy(), next_values, dones)
        values = np.vstack(values)  # .reshape((len(values[0]) * self.n_workers,))
        advs = returns - values
        advs = (advs - advs.mean(1).reshape((-1, 1))) / \
               (advs.std(1).reshape((-1, 1)) + 1e-8)
        for epoch in range(self.epochs):
            for state, action, q_value, adv, old_value, old_log_prob in self.choose_mini_batch(
                    states, actions, returns, advs, values, log_probs):
                state = torch.ByteTensor(state).permute([0, 3, 1, 2]).to(self.device)
                action = torch.Tensor(action).to(self.device)
                adv = torch.Tensor(adv).to(self.device)
                q_value = torch.Tensor(q_value).to(self.device)
                old_value = torch.Tensor(old_value).to(self.device)
                old_log_prob = torch.Tensor(old_log_prob).to(self.device)

                dist, value = self.current_policy(state)
                entropy = dist.entropy().mean()
                new_log_prob = self.calculate_log_probs(
                    self.current_policy, state, action)
                ratio = (new_log_prob - old_log_prob).exp()
                actor_loss = self.compute_ac_loss(ratio, adv)

                clipped_value = old_value + torch.clamp(
                    value.squeeze() - old_value, -self.epsilon, self.epsilon)
                clipped_v_loss = (clipped_value - q_value).pow(2)
                unclipped_v_loss = (value.squeeze() - q_value).pow(2)
                critic_loss = 0.5 * torch.max(clipped_v_loss,
                                              unclipped_v_loss).mean()

                total_loss = critic_loss + actor_loss - 0.01 * entropy
                self.optimize(total_loss)

        return total_loss.item(), entropy.item(), \
               explained_variance(values.reshape((len(returns[0]) * self.n_workers,)),
                                  returns.reshape((len(returns[0]) * self.n_workers,)))
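
compute_ac_loss and get_gae are not defined in this snippet. Given the probability ratio, the advantage argument and the clipped value loss next to it, compute_ac_loss is presumably a PPO clipped surrogate; a hedged sketch under that assumption (epsilon plays the role of self.epsilon above):

import torch

def ppo_clipped_actor_loss(ratio, adv, epsilon=0.2):
    """Negative PPO surrogate: -E[min(r_t * A_t, clip(r_t, 1-eps, 1+eps) * A_t)]."""
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * adv
    return -torch.min(unclipped, clipped).mean()
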
    def update(self):

        summary = tf.Summary()

        #Some logic gathering best ret, rooms etc using MPI.
        temp = sum(MPI.COMM_WORLD.allgather(self.local_rooms), [])
        temp = sorted(list(set(temp)))
        self.rooms = temp

        temp = sum(MPI.COMM_WORLD.allgather(self.scores), [])
        temp = sorted(list(set(temp)))
        self.scores = temp

        temp = sum(MPI.COMM_WORLD.allgather([self.local_best_ret]), [])
        self.best_ret = max(temp)

        eprews = MPI.COMM_WORLD.allgather(safemean(list(self.I.statlists["eprew"])))
        self.ep_rews.append(eprews[0])

        local_best_rets = MPI.COMM_WORLD.allgather(self.local_best_ret)
        n_rooms = sum(MPI.COMM_WORLD.allgather([len(self.local_rooms)]), [])

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.info(f"Rooms visited {self.rooms}")
            logger.info(f"Best return {self.best_ret}")
            logger.info(f"Best local return {sorted(local_best_rets)}")
            logger.info(f"eprews {sorted(eprews)}")
            logger.info(f"n_rooms {sorted(n_rooms)}")
            logger.info(f"Extrinsic coefficient {self.ext_coeff}")
            logger.info(f"Gamma {self.gamma}")
            logger.info(f"Gamma ext {self.gamma_ext}")
            # logger.info(f"All scores {sorted(self.scores)}")
            logger.info(f"Experiment name {self.exp_name}")
            summary.value.add(tag='Episode_mean_reward', simple_value=eprews[0])


        # Normalize intrinsic rewards.
        rffs_int = np.array([self.I.rff_int.update(rew) for rew in self.I.buf_rews_int.T])
        self.I.rff_rms_int.update(rffs_int.ravel())
        rews_int = self.I.buf_rews_int / np.sqrt(self.I.rff_rms_int.var)
        self.mean_int_rew = safemean(rews_int)
        self.max_int_rew = np.max(rews_int)

        # Don't normalize extrinsic rewards.
        rews_ext = self.I.buf_rews_ext

        rewmean, rewstd, rewmax = self.I.buf_rews_int.mean(), self.I.buf_rews_int.std(), np.max(self.I.buf_rews_int)

        # Calculate intrinsic returns and advantages.
        lastgaelam = 0
        for t in range(self.nsteps-1, -1, -1): # nsteps-1 ... 0
            if self.use_news:
                nextnew = self.I.buf_news[:, t + 1] if t + 1 < self.nsteps else self.I.buf_new_last
            else:
                nextnew = 0.0 # No dones for intrinsic reward.
            nextvals = self.I.buf_vpreds_int[:, t + 1] if t + 1 < self.nsteps else self.I.buf_vpred_int_last
            nextnotnew = 1 - nextnew
            delta = rews_int[:, t] + self.gamma * nextvals * nextnotnew - self.I.buf_vpreds_int[:, t]
            self.I.buf_advs_int[:, t] = lastgaelam = delta + self.gamma * self.lam * nextnotnew * lastgaelam
        rets_int = self.I.buf_advs_int + self.I.buf_vpreds_int

        # Calculate extrinsic returns and advantages.
        lastgaelam = 0
        for t in range(self.nsteps-1, -1, -1): # nsteps-1 ... 0
            nextnew = self.I.buf_news[:, t + 1] if t + 1 < self.nsteps else self.I.buf_new_last
            # Use dones for extrinsic reward.
            nextvals = self.I.buf_vpreds_ext[:, t + 1] if t + 1 < self.nsteps else self.I.buf_vpred_ext_last
            nextnotnew = 1 - nextnew
            delta = rews_ext[:, t] + self.gamma_ext * nextvals * nextnotnew - self.I.buf_vpreds_ext[:, t]
            self.I.buf_advs_ext[:, t] = lastgaelam = delta + self.gamma_ext * self.lam * nextnotnew * lastgaelam
        rets_ext = self.I.buf_advs_ext + self.I.buf_vpreds_ext

        # Combine the extrinsic and intrinsic advantages.
        self.I.buf_advs = self.int_coeff*self.I.buf_advs_int + self.ext_coeff*self.I.buf_advs_ext

        # Collects info for reporting.
        info = dict(
            advmean = self.I.buf_advs.mean(),
            advstd  = self.I.buf_advs.std(),
            retintmean = rets_int.mean(), # previously retmean
            retintstd  = rets_int.std(), # previously retstd
            retextmean = rets_ext.mean(), # previously not there
            retextstd  = rets_ext.std(), # previously not there
            rewintmean_unnorm = rewmean, # previously rewmean
            rewintmax_unnorm = rewmax, # previously not there
            rewintmean_norm = self.mean_int_rew, # previously rewintmean
            rewintmax_norm = self.max_int_rew, # previously rewintmax
            rewintstd_unnorm  = rewstd, # previously rewstd
            vpredintmean = self.I.buf_vpreds_int.mean(), # previously vpredmean
            vpredintstd  = self.I.buf_vpreds_int.std(), # previously vrpedstd
            vpredextmean = self.I.buf_vpreds_ext.mean(), # previously not there
            vpredextstd  = self.I.buf_vpreds_ext.std(), # previously not there
            ev_int = np.clip(explained_variance(self.I.buf_vpreds_int.ravel(), rets_int.ravel()), -1, None),
            ev_ext = np.clip(explained_variance(self.I.buf_vpreds_ext.ravel(), rets_ext.ravel()), -1, None),
            rooms = SemicolonList(self.rooms),
            n_rooms = len(self.rooms),
            best_ret = self.best_ret,
            reset_counter = self.I.reset_counter,
            max_table = self.stochpol.max_table
        )

        info['mem_available'] = psutil.virtual_memory().available

        to_record = {'acs': self.I.buf_acs,
                     'rews_int': self.I.buf_rews_int,
                     'rews_int_norm': rews_int,
                     'rews_ext': self.I.buf_rews_ext,
                     'vpred_int': self.I.buf_vpreds_int,
                     'vpred_ext': self.I.buf_vpreds_ext,
                     'adv_int': self.I.buf_advs_int,
                     'adv_ext': self.I.buf_advs_ext,
                     'ent': self.I.buf_ent,
                     'ret_int': rets_int,
                     'ret_ext': rets_ext,
                     }
        if self.I.venvs[0].record_obs:
            to_record['obs'] = self.I.buf_obs[None]

        # Create feeddict for optimization.
        envsperbatch = self.I.nenvs // self.nminibatches

        ph_buf = [
            (self.stochpol.ph_ac, self.I.buf_acs),
            (self.stochpol.ph_ret_ext, rets_ext),
            (self.ph_ret_int, rets_int),
            (self.ph_ret_ext, rets_ext),
            (self.ph_oldnlp, self.I.buf_nlps),
            (self.ph_adv, self.I.buf_advs),
        ]
        if self.I.mem_state is not NO_STATES:
            ph_buf.extend([
                (self.stochpol.ph_istate, self.I.seg_init_mem_state),
                (self.stochpol.ph_new, self.I.buf_news),
            ])

        verbose = False
        if verbose and self.is_log_leader:
            samples = np.prod(self.I.buf_advs.shape)
            logger.info("buffer shape %s, samples_per_mpi=%i, mini_per_mpi=%i, samples=%i, mini=%i " % (
                    str(self.I.buf_advs.shape),
                    samples, samples//self.nminibatches,
                    samples*self.comm_train_size, samples*self.comm_train_size//self.nminibatches))
            logger.info(" "*6 + fmt_row(13, self.loss_names))


        epoch = 0
        start = 0
        # Optimizes on current data for several epochs.
        while epoch < self.nepochs:
            end = start + envsperbatch
            mbenvinds = slice(start, end, None)

            fd = {ph : buf[mbenvinds] for (ph, buf) in ph_buf}
            fd.update({self.ph_lr : self.lr, self.ph_cliprange : self.cliprange})
            all_obs = np.concatenate([self.I.buf_obs[None][mbenvinds], self.I.buf_ob_last[None][mbenvinds, None]], 1)
            fd[self.stochpol.ph_ob[None]] = all_obs
            assert list(fd[self.stochpol.ph_ob[None]].shape) == [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.shape), \
                [fd[self.stochpol.ph_ob[None]].shape, [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.shape)]
            fd.update({self.stochpol.ph_mean:self.stochpol.ob_rms.mean, self.stochpol.ph_std:self.stochpol.ob_rms.var**0.5})

            ret = tf.get_default_session().run(self._losses+[self._train], feed_dict=fd)[:-1]

            if not self.testing:
                lossdict = dict(zip(self.loss_names, ret))
            else:
                lossdict = {}
            # Synchronize the lossdict across mpi processes, otherwise weights may be rolled back on one process but not another.
            _maxkl = lossdict.pop('maxkl')
            lossdict = dict_gather(self.comm_train, lossdict, op='mean')
            maxmaxkl = dict_gather(self.comm_train, {"maxkl":_maxkl}, op='max')
            lossdict["maxkl"] = maxmaxkl["maxkl"]
            if verbose and self.is_log_leader:
                logger.info("%i:%03i %s" % (epoch, start, fmt_row(13, [lossdict[n] for n in self.loss_names])))
            start += envsperbatch
            if start == self.I.nenvs:
                epoch += 1
                start = 0

        if self.is_train_leader:
            self.I.stats["n_updates"] += 1
            info.update([('opt_'+n, lossdict[n]) for n in self.loss_names])
            tnow = time.time()
            info['tps'] = self.nsteps * self.I.nenvs / (tnow - self.I.t_last_update)
            info['time_elapsed'] = time.time() - self.t0
            self.I.t_last_update = tnow
        self.stochpol.update_normalization( # Necessary for continuous control tasks with odd obs ranges, only implemented in mlp policy,
            ob=self.I.buf_obs               # NOTE: not shared via MPI
        )

        self.summary_writer.add_summary(summary, self.I.stats['n_updates'])
        self.summary_writer.flush()

        return info
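
The two backward loops in update() above are Generalized Advantage Estimation, run once with the intrinsic value head and once with the extrinsic one. The same recursion, pulled out into a standalone sketch (argument names are illustrative, not part of the RND codebase):

import numpy as np

def gae_advantages(rews, vpreds, vpred_last, news, new_last, gamma, lam):
    """delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t);
    A_t = delta_t + gamma * lam * (1 - done_{t+1}) * A_{t+1}, computed backwards."""
    nsteps = rews.shape[-1]
    advs = np.zeros_like(rews, dtype=np.float64)
    lastgaelam = 0.0
    for t in range(nsteps - 1, -1, -1):
        nextnew = news[..., t + 1] if t + 1 < nsteps else new_last
        nextvals = vpreds[..., t + 1] if t + 1 < nsteps else vpred_last
        nextnotnew = 1 - nextnew
        delta = rews[..., t] + gamma * nextvals * nextnotnew - vpreds[..., t]
        advs[..., t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
    return advs  # returns-to-go are advs + vpreds, as with rets_int/rets_ext above
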
    def update(self):

        #Some logic gathering best ret, rooms etc using MPI.
        temp = sum(MPI.COMM_WORLD.allgather(self.local_rooms), [])
        temp = sorted(list(set(temp)))
        self.rooms = temp

        temp = sum(MPI.COMM_WORLD.allgather(self.scores), [])
        temp = sorted(list(set(temp)))
        self.scores = temp

        temp = sum(MPI.COMM_WORLD.allgather([self.local_best_ret]), [])
        self.best_ret = max(temp)

        eprews = MPI.COMM_WORLD.allgather(
            np.mean(list(self.I.statlists["eprew"])))
        local_best_rets = MPI.COMM_WORLD.allgather(self.local_best_ret)
        n_rooms = sum(MPI.COMM_WORLD.allgather([len(self.local_rooms)]), [])

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.info(f"Rooms visited {self.rooms}")
            logger.info(f"Best return {self.best_ret}")
            logger.info(f"Best local return {sorted(local_best_rets)}")
            logger.info(f"eprews {sorted(eprews)}")
            logger.info(f"n_rooms {sorted(n_rooms)}")
            logger.info(f"Extrinsic coefficient {self.ext_coeff}")
            logger.info(f"Gamma {self.gamma}")
            logger.info(f"Gamma ext {self.gamma_ext}")
            logger.info(f"All scores {sorted(self.scores)}")

        #Normalize intrinsic rewards.
        rffs_int = np.array(
            [self.I.rff_int.update(rew) for rew in self.I.buf_rews_int.T])
        self.I.rff_rms_int.update(rffs_int.ravel())
        rews_int = self.I.buf_rews_int / np.sqrt(self.I.rff_rms_int.var)
        self.mean_int_rew = np.mean(rews_int)
        self.max_int_rew = np.max(rews_int)

        #Don't normalize extrinsic rewards.
        rews_ext = self.I.buf_rews_ext

        rewmean, rewstd, rewmax = (self.I.buf_rews_int.mean(),
                                   self.I.buf_rews_int.std(),
                                   np.max(self.I.buf_rews_int))

        #Calculate intrinsic returns and advantages.
        lastgaelam = 0
        for t in range(self.nsteps - 1, -1, -1):  # nsteps-1 ... 0
            if self.use_news:
                nextnew = self.I.buf_news[:, t + 1] if t + 1 < self.nsteps else self.I.buf_new_last
            else:
                nextnew = 0.0  # No dones for intrinsic reward.
            nextvals = self.I.buf_vpreds_int[:, t + 1] if t + 1 < self.nsteps else self.I.buf_vpred_int_last
            nextnotnew = 1 - nextnew
            delta = rews_int[:, t] + self.gamma * nextvals * nextnotnew - self.I.buf_vpreds_int[:, t]
            self.I.buf_advs_int[:, t] = lastgaelam = delta + self.gamma * self.lam * nextnotnew * lastgaelam
        rets_int = self.I.buf_advs_int + self.I.buf_vpreds_int

        #Calculate extrinsic returns and advantages.
        lastgaelam = 0
        for t in range(self.nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = self.I.buf_news[:, t + 1] if t + 1 < self.nsteps else self.I.buf_new_last
            # Use dones for extrinsic reward.
            nextvals = self.I.buf_vpreds_ext[:, t + 1] if t + 1 < self.nsteps else self.I.buf_vpred_ext_last
            nextnotnew = 1 - nextnew
            delta = rews_ext[:, t] + self.gamma_ext * nextvals * nextnotnew - self.I.buf_vpreds_ext[:, t]
            self.I.buf_advs_ext[:, t] = lastgaelam = delta + self.gamma_ext * self.lam * nextnotnew * lastgaelam
        rets_ext = self.I.buf_advs_ext + self.I.buf_vpreds_ext

        #Combine the extrinsic and intrinsic advantages.
        self.I.buf_advs = self.int_coeff * self.I.buf_advs_int + self.ext_coeff * self.I.buf_advs_ext

        #Collects info for reporting.
        info = dict(
            advmean=self.I.buf_advs.mean(),
            advstd=self.I.buf_advs.std(),
            retintmean=rets_int.mean(),  # previously retmean
            retintstd=rets_int.std(),  # previously retstd
            retextmean=rets_ext.mean(),  # previously not there
            retextstd=rets_ext.std(),  # previously not there
            rewintmean_unnorm=rewmean,  # previously rewmean
            rewintmax_unnorm=rewmax,  # previously not there
            rewintmean_norm=self.mean_int_rew,  # previously rewintmean
            rewintmax_norm=self.max_int_rew,  # previously rewintmax
            rewintstd_unnorm=rewstd,  # previously rewstd
            vpredintmean=self.I.buf_vpreds_int.mean(),  # previously vpredmean
            vpredintstd=self.I.buf_vpreds_int.std(),  # previously vrpedstd
            vpredextmean=self.I.buf_vpreds_ext.mean(),  # previously not there
            vpredextstd=self.I.buf_vpreds_ext.std(),  # previously not there
            ev_int=np.clip(
                explained_variance(self.I.buf_vpreds_int.ravel(),
                                   rets_int.ravel()), -1, None),
            ev_ext=np.clip(
                explained_variance(self.I.buf_vpreds_ext.ravel(),
                                   rets_ext.ravel()), -1, None),
            rooms=SemicolonList(self.rooms),
            n_rooms=len(self.rooms),
            best_ret=self.best_ret,
            reset_counter=self.I.reset_counter)

        info['mem_available'] = psutil.virtual_memory().available

        to_record = {
            'acs': self.I.buf_acs,
            'rews_int': self.I.buf_rews_int,
            'rews_int_norm': rews_int,
            'rews_ext': self.I.buf_rews_ext,
            'vpred_int': self.I.buf_vpreds_int,
            'vpred_ext': self.I.buf_vpreds_ext,
            'adv_int': self.I.buf_advs_int,
            'adv_ext': self.I.buf_advs_ext,
            'ent': self.I.buf_ent,
            'ret_int': rets_int,
            'ret_ext': rets_ext,
        }

        if self.I.venvs[0].record_obs:
            if None in self.I.buf_obs:
                to_record['obs'] = self.I.buf_obs[None]
            else:
                to_record['obs'] = self.I.buf_obs['normal']

        self.recorder.record(bufs=to_record, infos=self.I.buf_epinfos)

        #Create feeddict for optimization.
        envsperbatch = self.I.nenvs // self.nminibatches
        ph_buf = [
            (self.stochpol.ph_ac, self.I.buf_acs),
            (self.ph_ret_int, rets_int),
            (self.ph_ret_ext, rets_ext),
            (self.ph_oldnlp, self.I.buf_nlps),
            (self.ph_adv, self.I.buf_advs),
        ]
        if self.I.mem_state is not NO_STATES:
            ph_buf.extend([
                (self.stochpol.ph_istate, self.I.seg_init_mem_state),
                (self.stochpol.ph_new, self.I.buf_news),
            ])

        #verbose = True
        verbose = False
        if verbose and self.is_log_leader:
            samples = np.prod(self.I.buf_advs.shape)
            logger.info(
                "buffer shape %s, samples_per_mpi=%i, mini_per_mpi=%i, samples=%i, mini=%i "
                % (str(self.I.buf_advs.shape), samples, samples //
                   self.nminibatches, samples * self.comm_train_size,
                   samples * self.comm_train_size // self.nminibatches))
            logger.info(" " * 6 + fmt_row(13, self.loss_names))

        to_record_attention = None
        attention_output = None
        if os.environ['EXPERIMENT_LVL'] in ('attention', 'ego'):
            try:
                #attention_output = tf.get_default_graph().get_tensor_by_name("ppo/pol/augmented2/attention_output_combined:0")
                #attention_output = tf.get_default_graph().get_tensor_by_name("ppo/pol/augmented2/attention_output_combined/kernel:0")
                attention_output = tf.get_default_graph().get_tensor_by_name(
                    "ppo/pol/augmented2/attention_output_combined/Conv2D:0")
            except Exception as e:
                logger.error("Exception in attention_output: {}".format(e))
                attention_output = None

        epoch = 0
        start = 0
        #Optimizes on current data for several epochs.
        while epoch < self.nepochs:
            end = start + envsperbatch
            mbenvinds = slice(start, end, None)

            fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
            fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})

            if None in self.stochpol.ph_ob:
                fd[self.stochpol.ph_ob[None]] = np.concatenate([
                    self.I.buf_obs[None][mbenvinds],
                    self.I.buf_ob_last[None][mbenvinds, None]
                ], 1)
                assert list(fd[self.stochpol.ph_ob[None]].shape) == [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.shape), \
                [fd[self.stochpol.ph_ob[None]].shape, [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.shape)]

            else:
                fd[self.stochpol.ph_ob['normal']] = np.concatenate([
                    self.I.buf_obs['normal'][mbenvinds],
                    self.I.buf_ob_last['normal'][mbenvinds, None]
                ], 1)
                fd[self.stochpol.ph_ob['ego']] = np.concatenate([
                    self.I.buf_obs['ego'][mbenvinds],
                    self.I.buf_ob_last['ego'][mbenvinds, None]
                ], 1)

                assert list(fd[self.stochpol.ph_ob['normal']].shape) == [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.spaces['normal'].shape), \
                [fd[self.stochpol.ph_ob['normal']].shape, [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.spaces['normal'].shape)]
                assert list(fd[self.stochpol.ph_ob['ego']].shape) == [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.spaces['ego'].shape), \
                [fd[self.stochpol.ph_ob['ego']].shape, [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.spaces['ego'].shape)]

            fd.update({
                self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                self.stochpol.ph_std: self.stochpol.ob_rms.var**0.5
            })

            if attention_output is not None:
                _train_losses = [attention_output, self._train]
            else:
                _train_losses = [self._train]

            ret = tf.get_default_session().run(self._losses + _train_losses,
                                               feed_dict=fd)[:-1]

            if attention_output is not None:
                attn_output = ret[-1]
                ret = ret[:-1]
                if None in self.I.buf_obs:
                    outshape = list(
                        self.I.buf_obs[None][mbenvinds].shape[:2]) + list(
                            attn_output.shape[1:])
                else:
                    # does not matter if it's normal or ego, the first 2 axes are the same
                    outshape = list(
                        self.I.buf_obs['normal'][mbenvinds].shape[:2]) + list(
                            attn_output.shape[1:])
                attn_output = np.reshape(attn_output, outshape)
                attn_output = attn_output[:, :, :, :, :64]

            if not self.testing:
                lossdict = dict(zip(self.loss_names, ret))
            else:
                lossdict = {}
            #Synchronize the lossdict across mpi processes, otherwise weights may be rolled back on one process but not another.
            _maxkl = lossdict.pop('maxkl')
            lossdict = dict_gather(self.comm_train, lossdict, op='mean')
            maxmaxkl = dict_gather(self.comm_train, {"maxkl": _maxkl},
                                   op='max')
            lossdict["maxkl"] = maxmaxkl["maxkl"]
            if verbose and self.is_log_leader:
                logger.info(
                    "%i:%03i %s" %
                    (epoch, start,
                     fmt_row(13, [lossdict[n] for n in self.loss_names])))
            start += envsperbatch
            if start == self.I.nenvs:
                epoch += 1
                start = 0

                if attention_output is not None:
                    if to_record_attention is None:
                        to_record_attention = attn_output
                    else:
                        to_record_attention = np.concatenate(
                            [to_record_attention, attn_output])

        # if to_record_attention is not None:
        #     if None in self.I.buf_obs:
        #         to_record['obs'] = self.I.buf_obs[None]
        #     else:
        #         to_record['obs'] = self.I.buf_obs['normal']

        #     to_record['attention'] = to_record_attention

        to_record_attention = None

        if self.is_train_leader:
            self.I.stats["n_updates"] += 1
            info.update([('opt_' + n, lossdict[n]) for n in self.loss_names])
            tnow = time.time()
            info['tps'] = self.nsteps * self.I.nenvs / (tnow -
                                                        self.I.t_last_update)
            info['time_elapsed'] = time.time() - self.t0
            self.I.t_last_update = tnow
        self.stochpol.update_normalization(  # Necessary for continuous control tasks with odd obs ranges, only implemented in mlp policy,
            ob=self.I.buf_obs  # NOTE: not shared via MPI
        )
        return info
Example #6
import time
Example #7
    def train(self):

        start_time = time.time()

        self.episodes = self.env.generate_episodes(config.NUM_EPISODES, self)

        # Computing returns and estimating advantage function.
        for episode in self.episodes:
            episode["baseline"] = self.value_func.predict(episode)
            episode["returns"] = utils.discount(episode["rewards"],
                                                config.GAMMA)
            episode["advantage"] = episode["returns"] - episode["baseline"]

        # Updating policy.
        actions_dist_n = np.concatenate(
            [episode["actions_dist"] for episode in self.episodes])
        states_n = np.concatenate(
            [episode["states"] for episode in self.episodes])
        actions_n = np.concatenate(
            [episode["actions"] for episode in self.episodes])
        baseline_n = np.concatenate(
            [episode["baseline"] for episode in self.episodes])
        returns_n = np.concatenate(
            [episode["returns"] for episode in self.episodes])

        # Standardize the advantage function to have mean=0 and std=1.
        advantage_n = np.concatenate(
            [episode["advantage"] for episode in self.episodes])
        advantage_n -= advantage_n.mean()
        advantage_n /= (advantage_n.std() + 1e-8)

        # Computing baseline function for next iter.
        print(states_n.shape, actions_n.shape, advantage_n.shape,
              actions_dist_n.shape)
        feed = {
            self.policy.state: states_n,
            self.action: actions_n,
            self.advantage: advantage_n,
            self.policy.pi_theta_old: actions_dist_n
        }

        episoderewards = np.array(
            [episode["rewards"].sum() for episode in self.episodes])

        #print("\n********** Iteration %i ************" % i)

        self.value_func.fit(self.episodes)
        self.theta_old = self.current_theta()

        def fisher_vector_product(p):
            feed[self.flat_tangent] = p
            return self.session.run(self.fisher_vect_prod,
                                    feed) + config.CG_DAMP * p

        self.g = self.session.run(self.surr_loss_grad, feed_dict=feed)

        self.grad_step = utils.conjugate_gradient(fisher_vector_product,
                                                  -self.g)

        self.sAs = .5 * self.grad_step.dot(
            fisher_vector_product(self.grad_step))

        self.beta_inv = np.sqrt(self.sAs / config.MAX_KL)
        self.full_grad_step = self.grad_step / self.beta_inv

        self.negdot_grad_step = -self.g.dot(self.grad_step)

        def loss(th):
            self.set_theta(th)
            return self.session.run(self.surr_loss, feed_dict=feed)

        self.theta = utils.line_search(loss, self.theta_old,
                                       self.full_grad_step,
                                       self.negdot_grad_step / self.beta_inv)
        self.set_theta(self.theta)

        surr_loss_new = -self.session.run(self.surr_loss, feed_dict=feed)
        KL_old_new = self.session.run(self.KL, feed_dict=feed)
        entropy = self.session.run(self.entropy, feed_dict=feed)

        old_new_norm = np.sum((self.theta - self.theta_old)**2)

        if np.abs(KL_old_new) > 2.0 * config.MAX_KL:
            print("Keeping old theta")
            self.set_theta(self.theta_old)

        stats = {}
        stats["L2 of old - new"] = old_new_norm
        stats["Total number of episodes"] = len(self.episodes)
        stats["Average sum of rewards per episode"] = episoderewards.mean()
        stats["Entropy"] = entropy
        exp = utils.explained_variance(np.array(baseline_n),
                                       np.array(returns_n))
        stats["Baseline explained"] = exp
        stats["Time elapsed"] = "%.2f mins" % (
            (time.time() - start_time) / 60.0)
        stats["KL between old and new distribution"] = KL_old_new
        stats["Surrogate loss"] = surr_loss_new
        self.stats.append(stats)
        utils.write_dict(stats)
        save_path = self.saver.save(self.session, "./checkpoints/model.ckpt")
        print('Saved checkpoint to %s' % save_path)
        for k, v in stats.items():
            print(k + ": " + " " * (40 - len(k)) + str(v))
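
utils.conjugate_gradient and utils.line_search are used by this TRPO-style train() but not shown. A hedged sketch of the standard conjugate-gradient solver it relies on, which approximately solves A x = b using only the Fisher-vector product f_Ax (the iteration count and tolerance are assumptions):

import numpy as np

def conjugate_gradient(f_Ax, b, iters=10, tol=1e-10):
    """Solve A x ~= b given only the matrix-vector product f_Ax(v) = A v."""
    x = np.zeros_like(b)
    r = b.copy()   # residual b - A x, with x initialised to 0
    p = b.copy()   # current search direction
    rdotr = r.dot(r)
    for _ in range(iters):
        Ap = f_Ax(p)
        alpha = rdotr / p.dot(Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = r.dot(r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x
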
    def update(self):
        # Some logic gathering best ret, rooms etc using MPI.
        temp = sum(MPI.COMM_WORLD.allgather(self.local_rooms), [])
        temp = sorted(list(set(temp)))
        self.rooms = temp

        temp = sum(MPI.COMM_WORLD.allgather(self.scores), [])
        temp = sorted(list(set(temp)))
        self.scores = temp

        temp = sum(MPI.COMM_WORLD.allgather([self.local_best_ret]), [])
        self.best_ret = max(temp)

        eprews = MPI.COMM_WORLD.allgather(
            np.mean(list(self.I.statlists["eprew"])))
        local_best_rets = MPI.COMM_WORLD.allgather(self.local_best_ret)
        n_rooms = sum(MPI.COMM_WORLD.allgather([len(self.local_rooms)]), [])

        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.info(f"Rooms visited {self.rooms}")
            logger.info(f"Best return {self.best_ret}")
            logger.info(f"Best local return {sorted(local_best_rets)}")
            logger.info(f"eprews {sorted(eprews)}")
            logger.info(f"n_rooms {sorted(n_rooms)}")
            logger.info(f"Extrinsic coefficient {self.ext_coeff}")
            logger.info(f"Gamma {self.gamma}")
            logger.info(f"Gamma ext {self.gamma_ext}")
            logger.info(f"All scores {sorted(self.scores)}")

        # Normalize intrinsic rewards.
        rffs_int = np.array(
            [self.I.rff_int.update(rew) for rew in self.I.buf_rews_int.T])
        self.I.rff_rms_int.update(rffs_int.ravel())
        rews_int = self.I.buf_rews_int / np.sqrt(self.I.rff_rms_int.var)
        self.mean_int_rew = np.mean(rews_int)
        self.max_int_rew = np.max(rews_int)

        # Don't normalize extrinsic rewards.
        rews_ext = self.I.buf_rews_ext

        rewmean, rewstd, rewmax = (
            self.I.buf_rews_int.mean(),
            self.I.buf_rews_int.std(),
            np.max(self.I.buf_rews_int),
        )

        # Calculate intrinsic returns and advantages.
        lastgaelam = 0
        for t in range(self.nsteps - 1, -1, -1):  # nsteps-1 ... 0
            if self.use_news:
                nextnew = (self.I.buf_news[:, t + 1]
                           if t + 1 < self.nsteps else self.I.buf_new_last)
            else:
                nextnew = 0.0  # No dones for intrinsic reward.
            nextvals = (self.I.buf_vpreds_int[:, t + 1]
                        if t + 1 < self.nsteps else self.I.buf_vpred_int_last)
            nextnotnew = 1 - nextnew
            delta = (rews_int[:, t] + self.gamma * nextvals * nextnotnew -
                     self.I.buf_vpreds_int[:, t])
            self.I.buf_advs_int[:, t] = lastgaelam = (
                delta + self.gamma * self.lam * nextnotnew * lastgaelam)
        rets_int = self.I.buf_advs_int + self.I.buf_vpreds_int

        # Calculate extrinsic returns and advantages.
        lastgaelam = 0
        for t in range(self.nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = (self.I.buf_news[:, t + 1]
                       if t + 1 < self.nsteps else self.I.buf_new_last)
            # Use dones for extrinsic reward.
            nextvals = (self.I.buf_vpreds_ext[:, t + 1]
                        if t + 1 < self.nsteps else self.I.buf_vpred_ext_last)
            nextnotnew = 1 - nextnew
            delta = (rews_ext[:, t] + self.gamma_ext * nextvals * nextnotnew -
                     self.I.buf_vpreds_ext[:, t])
            self.I.buf_advs_ext[:, t] = lastgaelam = (
                delta + self.gamma_ext * self.lam * nextnotnew * lastgaelam)
        rets_ext = self.I.buf_advs_ext + self.I.buf_vpreds_ext

        # Combine the extrinsic and intrinsic advantages.
        self.I.buf_advs = (self.int_coeff * self.I.buf_advs_int +
                           self.ext_coeff * self.I.buf_advs_ext)

        # Collects info for reporting.
        info = dict(
            advmean=self.I.buf_advs.mean(),
            advstd=self.I.buf_advs.std(),
            retintmean=rets_int.mean(),  # previously retmean
            retintstd=rets_int.std(),  # previously retstd
            retextmean=rets_ext.mean(),  # previously not there
            retextstd=rets_ext.std(),  # previously not there
            rewintmean_unnorm=rewmean,  # previously rewmean
            rewintmax_unnorm=rewmax,  # previously not there
            rewintmean_norm=self.mean_int_rew,  # previously rewintmean
            rewintmax_norm=self.max_int_rew,  # previously rewintmax
            rewintstd_unnorm=rewstd,  # previously rewstd
            vpredintmean=self.I.buf_vpreds_int.mean(),  # previously vpredmean
            vpredintstd=self.I.buf_vpreds_int.std(),  # previously vrpedstd
            vpredextmean=self.I.buf_vpreds_ext.mean(),  # previously not there
            vpredextstd=self.I.buf_vpreds_ext.std(),  # previously not there
            ev_int=np.clip(
                explained_variance(self.I.buf_vpreds_int.ravel(),
                                   rets_int.ravel()),
                -1,
                None,
            ),
            ev_ext=np.clip(
                explained_variance(self.I.buf_vpreds_ext.ravel(),
                                   rets_ext.ravel()),
                -1,
                None,
            ),
            rooms=SemicolonList(self.rooms),
            n_rooms=len(self.rooms),
            best_ret=self.best_ret,
            reset_counter=self.I.reset_counter,
        )

        info["mem_available"] = psutil.virtual_memory().available

        to_record = {
            "acs": self.I.buf_acs,
            "rews_int": self.I.buf_rews_int,
            "rews_int_norm": rews_int,
            "rews_ext": self.I.buf_rews_ext,
            "vpred_int": self.I.buf_vpreds_int,
            "vpred_ext": self.I.buf_vpreds_ext,
            "adv_int": self.I.buf_advs_int,
            "adv_ext": self.I.buf_advs_ext,
            "ent": self.I.buf_ent,
            "ret_int": rets_int,
            "ret_ext": rets_ext,
        }
        if self.I.venvs[0].record_obs:
            to_record["obs"] = self.I.buf_obs['obs']
        self.recorder.record(bufs=to_record, infos=self.I.buf_epinfos)

        # Create feeddict for optimization.
        envsperbatch = self.I.nenvs // self.nminibatches
        ph_buf = [
            (self.stochpol.ph_ac, self.I.buf_acs),
            (self.ph_ret_int, rets_int),
            (self.ph_ret_ext, rets_ext),
            (self.ph_oldnlp, self.I.buf_nlps),
            (self.ph_adv, self.I.buf_advs),
        ]
        if self.I.mem_state is not NO_STATES:
            ph_buf.extend([
                (self.stochpol.ph_istate, self.I.seg_init_mem_state),
                (self.stochpol.ph_new, self.I.buf_news),
            ])

        verbose = True
        if verbose and self.is_log_leader:
            samples = np.prod(self.I.buf_advs.shape)
            logger.info(
                f"buffer shape {self.I.buf_advs.shape}, "
                f"samples_per_mpi={samples:d}, "
                f"mini_per_mpi={samples // self.nminibatches:d}, "
                f"samples={samples * self.comm_train_size:d}, "
                f"mini={samples * self.comm_train_size // self.nminibatches:d} "
            )
            logger.info(" " * 6 + fmt_row(13, self.loss_names))

        epoch = 0
        start = 0
        # Optimizes on current data for several epochs.
        while epoch < self.nepochs:
            end = start + envsperbatch
            mbenvinds = slice(start, end, None)

            fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}
            fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})
            fd[self.stochpol.ph_ob['obs']] = np.concatenate(
                [
                    self.I.buf_obs['obs'][mbenvinds],
                    self.I.buf_ob_last['obs'][mbenvinds, None],
                ],
                1,
            )

            if self.meta_rl:
                fd[self.stochpol.ph_ob['prev_acs']] = one_hot(
                    self.I.buf_acs[mbenvinds], self.ac_space.n)
                fd[self.stochpol.ph_ob['prev_rew']] = self.I.buf_rews_ext[
                    mbenvinds, ..., None]

            assert list(fd[self.stochpol.ph_ob['obs']].shape) == [
                self.I.nenvs // self.nminibatches,
                self.nsteps + 1,
            ] + list(self.ob_space.shape), [
                fd[self.stochpol.ph_ob['obs']].shape,
                [self.I.nenvs // self.nminibatches, self.nsteps + 1] +
                list(self.ob_space.shape),
            ]
            fd.update({
                self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                self.stochpol.ph_std: self.stochpol.ob_rms.var**0.5,
            })

            ret = tf.get_default_session().run(self._losses + [self._train],
                                               feed_dict=fd)[:-1]
            if not self.testing:
                lossdict = dict(zip(self.loss_names, ret))
            else:
                lossdict = {}
            # Synchronize the lossdict across mpi processes, otherwise weights may be rolled back on one process but not another.
            _maxkl = lossdict.pop("maxkl")
            lossdict = dict_gather(self.comm_train, lossdict, op="mean")
            maxmaxkl = dict_gather(self.comm_train, {"maxkl": _maxkl},
                                   op="max")
            lossdict["maxkl"] = maxmaxkl["maxkl"]
            if verbose and self.is_log_leader:
                logger.info(
                    f"{epoch:d}:{start:03d} {fmt_row(13, [lossdict[n] for n in self.loss_names])}"
                )
            start += envsperbatch
            if start == self.I.nenvs:
                epoch += 1
                start = 0

        if self.is_train_leader:
            self.I.stats["n_updates"] += 1
            info.update([("opt_" + n, lossdict[n]) for n in self.loss_names])
            tnow = time.time()
            info["tps"] = self.nsteps * self.I.nenvs / (tnow -
                                                        self.I.t_last_update)
            info["time_elapsed"] = time.time() - self.t0
            self.I.t_last_update = tnow
        self.stochpol.update_normalization(
            # Necessary for continuous control tasks with odd obs ranges, only implemented in mlp policy,
            ob=self.I.buf_obs  # NOTE: not shared via MPI
        )
        return info
Example #9
def main():
    env = make_env()
    set_global_seeds(env, args.seed)

    agent = PPO(env=env)

    batch_steps = args.n_envs * args.batch_steps  # number of steps per update

    if args.save_interval and logger.get_dir():
        # some saving jobs
        pass

    ep_info_buffer = deque(maxlen=100)
    t_train_start = time.time()
    n_updates = args.n_steps // batch_steps
    runner = Runner(env, agent)

    for update in range(1, n_updates + 1):
        t_start = time.time()
        frac = 1.0 - (update - 1.0) / n_updates
        lr_now = args.lr  # maybe dynamic change
        clip_range_now = args.clip_range  # maybe dynamic change
        obs, returns, masks, acts, vals, neglogps, advs, rewards, ep_infos = \
            runner.run(args.batch_steps, frac)
        ep_info_buffer.extend(ep_infos)
        loss_infos = []

        idxs = np.arange(batch_steps)
        for _ in range(args.n_epochs):
            np.random.shuffle(idxs)
            for start in range(0, batch_steps, args.minibatch):
                end = start + args.minibatch
                mb_idxs = idxs[start:end]
                minibatch = [
                    arr[mb_idxs] for arr in
                    [obs, returns, masks, acts, vals, neglogps, advs]
                ]
                loss_infos.append(
                    agent.train(lr_now, clip_range_now, *minibatch))

        t_now = time.time()
        time_this_batch = t_now - t_start
        if update % args.log_interval == 0:
            ev = float(explained_variance(vals, returns))
            logger.logkv('updates', str(update) + '/' + str(n_updates))
            logger.logkv('serial_steps', update * args.batch_steps)
            logger.logkv('total_steps', update * batch_steps)
            logger.logkv('time', time_this_batch)
            logger.logkv('fps', int(batch_steps / (t_now - t_start)))
            logger.logkv('total_time', t_now - t_train_start)
            logger.logkv("explained_variance", ev)
            logger.logkv('avg_reward',
                         np.mean([e['r'] for e in ep_info_buffer]))
            logger.logkv('avg_ep_len',
                         np.mean([e['l'] for e in ep_info_buffer]))
            logger.logkv('adv_mean', np.mean(returns - vals))
            logger.logkv('adv_variance', np.std(returns - vals)**2)
            loss_infos = np.mean(loss_infos, axis=0)
            for loss_name, loss_info in zip(agent.loss_names, loss_infos):
                logger.logkv(loss_name, loss_info)
            logger.dumpkvs()

        if args.save_interval and update % args.save_interval == 0 and logger.get_dir():
            pass
    env.close()
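
In main() above, frac is computed every update but, as the inline comments note, lr_now and clip_range_now stay constant. If linear annealing is wanted, a minimal sketch (decaying both values to zero over training is an assumption, not something this project necessarily does):

def linear_anneal(initial_value, frac):
    """Scale a hyperparameter by the remaining-progress fraction frac in [0, 1]."""
    return initial_value * frac

# inside the update loop:
#   lr_now = linear_anneal(args.lr, frac)
#   clip_range_now = linear_anneal(args.clip_range, frac)
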
Example #10
def learn(policy, env, test_env, seed, master_ts = 1, worker_ts=8, cell = 256,
          ent_coef = 0.01, vf_coef = 0.5, max_grad_norm = 2.5, lr = 7e-4,
          alpha = 0.99, epsilon = 1e-5, total_timesteps = int(80e6), lrschedule = 'linear',
          ib_alpha = 1e-3, sv_M = 32, algo='use_svib_uniform',
          log_interval = 10, gamma = 0.99, load_path="saved_nets-data/hrl_a2c/%s/data"%start_time):

    tf.reset_default_graph()
    set_global_seeds(seed)
    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    print(str(nenvs)+"------------"+str(ob_space)+"-----------"+str(ac_space))
    model = Model_A2C_SVIB(policy = policy, ob_space = ob_space, ac_space = ac_space, nenvs = nenvs, master_ts=master_ts, worker_ts=worker_ts,
                           ent_coef = ent_coef, vf_coef = vf_coef, max_grad_norm = max_grad_norm, lr = lr, cell = cell,
                           ib_alpha = ib_alpha, sv_M = sv_M, algo=algo,
                           alpha = alpha, epsilon = epsilon, total_timesteps = total_timesteps, lrschedule = lrschedule)
    try:
        model.load(load_path)
    except Exception as e:
        print("no data to load!!"+str(e))
    runner = Runner_svib(env = env, model = model, nsteps=master_ts*worker_ts, gamma=gamma)
    test_runner = test_Runner(env = test_env, model = model)

    tf.get_default_graph().finalize()
    nbatch = nenvs * master_ts * worker_ts
    tstart = time.time()
    reward_list = []
    value_list = []
    prev_r = 0.
    exp_coef = .1
    for update in range(1, total_timesteps//nbatch+1):
        b_obs, b_whs, states, b_rewards, b_wmasks, b_actions, b_values, b_noises = runner.run()
        # print(np.max(b_obs[0]))
        # if algo == 'use_svib_gaussian':
        #     tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, gaussian_gradients, repr_grad_norm =\
        #         model.train(b_obs, b_whs, states, b_rewards, b_wmasks, b_actions, b_values)
        #     # print(gaussian_gradients[0, 3:5, 0:30])
        #     # print('1')
        # else:
        tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, repr_grad_norm, represent_loss, anchor_loss, sv_loss = \
            model.train(b_obs, b_whs, states, b_rewards, b_wmasks, b_actions, b_values, b_noises)
        # print('b_whs:', b_whs[0, 0:60])
        # print('sv_grad:',SV_GRAD[0, 0:40])
        # print('exploit:',EXPLOIT[0, 0:40])
        # print('log_p_grads:',LOG_P_GRADS[0, 3:5, 0:40])
        # print('explore:',EXPLORE[0, 0:40])
        nseconds = time.time()-tstart
        fps = int((update*nbatch)/nseconds)
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(b_values, b_rewards)
            logger.record_tabular("fps", fps)
            logger.record_tabular("tloss", float(tloss))
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("value_loss", float(value_loss))
            logger.record_tabular("policy_loss", float(policy_loss))
            logger.record_tabular("explained_variance", float(ev))
            logger.record_tabular('repr_grad_norm', float(repr_grad_norm))
            logger.record_tabular('rl_grad_norm', float(rl_grad_norm))
            logger.record_tabular('repr_loss', float(represent_loss))
            logger.record_tabular('anchor_loss',float(anchor_loss))
            logger.record_tabular('sv_loss',float(np.mean(sv_loss)))
            # if algo == 'use_svib_gaussian':
            #     logger.record_tabular('gaussian_grad_norm_without_clip', float(np.mean(np.abs(gaussian_gradients[0]))))
            # logger.record_tabular('represent_loss', float(represent_loss))
            logger.record_tabular('represent_mean', float(np.mean(b_whs[0])))
            logger.dump_tabular()
            if update % (200*log_interval) == 0 or update == 1:
                save_th = update//(200*log_interval)
                model.save("saved_nets-data/%s/hrl_a2c_svib/%s/%s/data" % (env_name, start_time, save_th))
                # model.train_mine(b_obs, b_whs)
                # episode_r = exp_coef*(test_runner.run()) + (1-exp_coef)*prev_r
                episode_r = exp_coef*(test_runner.run()) + (1-exp_coef)*prev_r
                prev_r = np.copy(episode_r)
                reward_list.append(episode_r)
                value_list.append(value_loss)
                logger.record_tabular('episode_r', float(episode_r))
                logger.dump_tabular()
    env.close()
    tf.reset_default_graph()
    return reward_list, value_list
Example #11
def train(
    env_name,
    batch_size,
    minibatch_size,
    updates,
    epochs,
    hparam,
    hp_summary_writer,
    save_model=False,
    load_path=None,
):
    """
    Main learning function

    Args:
        batch_size: size of the buffer, may have multiple trajecties inside
        minibatch_size: one batch is seperated into several minibatches. Each has this size.
        epochs: in one epoch, buffer is fully filled, and trained multiple times with minibatches.
    """
    actor_critic = Actor_Critic(hparam)

    if load_path is not None:
        print("Loading model ...")
        load_path = osp.expanduser(load_path)
        ckpt = tf.train.Checkpoint(model=actor_critic)
        manager = tf.train.CheckpointManager(ckpt, load_path, max_to_keep=5)
        ckpt.restore(manager.latest_checkpoint)

    # set env
    with SC2EnvWrapper(
        map_name=env_name,
        players=[sc2_env.Agent(sc2_env.Race.random)],
        agent_interface_format=sc2_env.parse_agent_interface_format(
            feature_minimap=MINIMAP_RES, feature_screen=MINIMAP_RES
        ),
        step_mul=FLAGS.step_mul,
        game_steps_per_episode=FLAGS.game_steps_per_episode,
        disable_fog=FLAGS.disable_fog,
    ) as env:
        actor_critic.set_act_spec(env.action_spec()[0])  # assume one agent

        def train_one_update(step, epochs, tracing_on):
            # initialize replay buffer
            buffer = Buffer(
                batch_size,
                minibatch_size,
                MINIMAP_RES,
                MINIMAP_RES,
                env.action_spec()[0],
            )

            # initial observation
            timestep = env.reset()
            step_type, reward, _, obs = timestep[0]
            obs = preprocess(obs)

            ep_ret = []  # episode return (score)
            ep_rew = 0

            # fill in recorded trajectories
            while True:
                tf_obs = (
                    tf.constant(each_obs, shape=(1, *each_obs.shape))
                    for each_obs in obs
                )

                val, act_id, arg_spatial, arg_nonspatial, logp_a = actor_critic.step(
                    *tf_obs
                )

                sc2act_args = translateActionToSC2(
                    arg_spatial, arg_nonspatial, MINIMAP_RES, MINIMAP_RES
                )

                act_mask = get_mask(act_id.numpy().item(), actor_critic.action_spec)
                buffer.add(
                    *obs,
                    act_id.numpy().item(),
                    sc2act_args,
                    act_mask,
                    logp_a.numpy().item(),
                    val.numpy().item()
                )
                step_type, reward, _, obs = env.step(
                    [actions.FunctionCall(act_id.numpy().item(), sc2act_args)]
                )[0]
                # print("action:{}: {} reward {}".format(act_id.numpy().item(), sc2act_args, reward))
                buffer.add_rew(reward)
                obs = preprocess(obs)

                ep_rew += reward

                if step_type == step_type.LAST or buffer.is_full():
                    if step_type == step_type.LAST:
                        buffer.finalize(0)
                    else:
                        # trajectory is cut off, bootstrap last state with estimated value
                        tf_obs = (
                            tf.constant(each_obs, shape=(1, *each_obs.shape))
                            for each_obs in obs
                        )
                        val, _, _, _, _ = actor_critic.step(*tf_obs)
                        buffer.finalize(val)

                    ep_rew += reward
                    ep_ret.append(ep_rew)
                    ep_rew = 0

                    if buffer.is_full():
                        break

                    # respawn env
                    env.render(True)
                    timestep = env.reset()
                    _, _, _, obs = timestep[0]
                    obs = preprocess(obs)

            # train in minibatches
            buffer.post_process()

            mb_loss = []
            for ep in range(epochs):
                buffer.shuffle()

                for ind in range(batch_size // minibatch_size):
                    (
                        player,
                        available_act,
                        minimap,
                        # screen,
                        act_id,
                        act_args,
                        act_mask,
                        logp,
                        val,
                        ret,
                        adv,
                    ) = buffer.minibatch(ind)

                    assert ret.shape == val.shape
                    assert logp.shape == adv.shape
                    if tracing_on:
                        tf.summary.trace_on(graph=True, profiler=False)

                    mb_loss.append(
                        actor_critic.train_step(
                            tf.constant(step, dtype=tf.int64),
                            player,
                            available_act,
                            minimap,
                            # screen,
                            act_id,
                            act_args,
                            act_mask,
                            logp,
                            val,
                            ret,
                            adv,
                        )
                    )
                    step += 1

                    if tracing_on:
                        tracing_on = False
                        with train_summary_writer.as_default():
                            tf.summary.trace_export(name="train_step", step=0)

            batch_loss = np.mean(mb_loss)

            return (
                batch_loss,
                ep_ret,
                buffer.batch_ret,
                np.asarray(buffer.batch_vals, dtype=np.float32),
            )

        num_train_per_update = epochs * (batch_size // minibatch_size)
        for i in range(updates):
            if i == 0:
                tracing_on = True
            else:
                tracing_on = False
            batch_loss, cumulative_rew, batch_ret, batch_vals = train_one_update(
                i * num_train_per_update, epochs, tracing_on
            )
            ev = explained_variance(batch_vals, batch_ret)
            with train_summary_writer.as_default():
                tf.summary.scalar(
                    "batch/cumulative_rewards", np.mean(cumulative_rew), step=i
                )
                tf.summary.scalar("batch/ev", ev, step=i)
                tf.summary.scalar("loss/batch_loss", batch_loss, step=i)
            with hp_summary_writer.as_default():
                tf.summary.scalar("rewards", np.mean(cumulative_rew), step=i)
            print("----------------------------")
            print(
                "epoch {0:2d} loss {1:.3f} batch_ret {2:.3f}".format(
                    i, batch_loss, np.mean(cumulative_rew)
                )
            )
            print("----------------------------")

            # save model
            if save_model and i % 15 == 0:
                print("saving model ...")
                save_path = osp.expanduser(saved_model_dir)
                ckpt = tf.train.Checkpoint(model=actor_critic)
                manager = tf.train.CheckpointManager(ckpt, save_path, max_to_keep=3)
                manager.save()