Example #1
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    args_dir, logs_dir, models_dir, samples_dir = get_all_save_paths(
        args, 'pretrain', combine_action=args.combine_action)
    eval_log_dir = logs_dir + "_eval"
    utils.cleanup_log_dir(logs_dir)
    utils.cleanup_log_dir(eval_log_dir)

    _, _, intrinsic_models_dir, _ = get_all_save_paths(args,
                                                       'learn_reward',
                                                       load_only=True)
    if args.load_iter != 'final':
        intrinsic_model_file_name = os.path.join(
            intrinsic_models_dir,
            args.env_name + '_{}.pt'.format(args.load_iter))
    else:
        intrinsic_model_file_name = os.path.join(intrinsic_models_dir,
                                                 args.env_name + '.pt')
    intrinsic_arg_file_name = os.path.join(args_dir, 'command.txt')

    # save args to arg_file
    with open(intrinsic_arg_file_name, 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, logs_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)
    else:
        raise NotImplementedError

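    # Optionally load a pretrained intrinsic-reward module (ICM or VAE-based)
    # saved by an earlier 'learn_reward' run.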
    if args.use_intrinsic:
        obs_shape = envs.observation_space.shape
        if len(obs_shape) == 3:
            action_dim = envs.action_space.n
        elif len(obs_shape) == 1:
            action_dim = envs.action_space.shape[0]

        if 'NoFrameskip' in args.env_name:
            env_id = args.env_name.split('-')[0].replace('NoFrameskip',
                                                         '').lower()
        else:
            env_id = args.env_name.split('-')[0].lower()
        file_name = os.path.join(args.experts_dir,
                                 "trajs_ppo_{}.pt".format(env_id))

        rff = RewardForwardFilter(args.gamma)
        intrinsic_rms = RunningMeanStd(shape=())

        if args.intrinsic_module == 'icm':
            print('Loading pretrained intrinsic module: %s' %
                  intrinsic_model_file_name)
            inverse_model, forward_dynamics_model, encoder = torch.load(
                intrinsic_model_file_name)
            icm = IntrinsicCuriosityModule(envs, device, inverse_model,
                                           forward_dynamics_model,
                                           inverse_lr=args.intrinsic_lr,
                                           forward_lr=args.intrinsic_lr)

        elif args.intrinsic_module == 'vae':
            print('Loading pretrained intrinsic module: %s' %
                  intrinsic_model_file_name)
            vae = torch.load(intrinsic_model_file_name)
            icm = GenerativeIntrinsicRewardModule(envs, device, vae,
                                                  lr=args.intrinsic_lr)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            obs, reward, done, infos = envs.step(action)
            next_obs = obs

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, next_obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

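        # Overwrite the stored environment rewards with intrinsic rewards
        # computed from the collected transitions before computing returns.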
        if args.use_intrinsic:
            for step in range(args.num_steps):
                state = rollouts.obs[step]
                action = rollouts.actions[step]
                next_state = rollouts.next_obs[step]
                if args.intrinsic_module == 'icm':
                    state = encoder(state)
                    next_state = encoder(next_state)
                with torch.no_grad():
                    intrinsic_reward, pred_next_state = \
                        icm.calculate_intrinsic_reward(
                            state, action, next_state, args.lambda_true_action)
                    rollouts.rewards[step] = intrinsic_reward
            if args.standardize == 'True':
                buf_rews = rollouts.rewards.cpu().numpy()
                intrinsic_rffs = np.array(
                    [rff.update(rew) for rew in buf_rews.T])
                rffs_mean, rffs_std, rffs_count = mpi_moments(
                    intrinsic_rffs.ravel())
                intrinsic_rms.update_from_moments(rffs_mean, rffs_std**2,
                                                  rffs_count)
                mean = intrinsic_rms.mean
                std = np.asarray(np.sqrt(intrinsic_rms.var))
                rollouts.rewards = rollouts.rewards / torch.from_numpy(std).to(
                    device)

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(models_dir, args.algo)
            policy_file_name = os.path.join(save_path, args.env_name + '.pt')

            os.makedirs(save_path, exist_ok=True)

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], policy_file_name)

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "{} Updates {}, num timesteps {}, FPS {} \n"
                " Last {} training episodes: mean/median reward {:.1f}/{:.1f}, "
                "min/max reward {:.1f}/{:.1f}, "
                "dist entropy {:.3f}, value/action loss {:.3f}/{:.3f}\n".format(
                    args.env_name, j, total_num_steps,
                    int(total_num_steps / (end - start)), len(episode_rewards),
                    np.mean(episode_rewards), np.median(episode_rewards),
                    np.min(episode_rewards), np.max(episode_rewards),
                    dist_entropy, value_loss, action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
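
A note on the standardization branch above (args.standardize == 'True'): it relies on RewardForwardFilter and RunningMeanStd, whose definitions are not shown in this example. The following is a minimal NumPy sketch of the usual formulation of these two helpers, offered as an assumption for reference rather than this project's actual implementation.

import numpy as np


class RewardForwardFilter:
    """Running discounted sum of rewards, one slot per environment (sketch)."""

    def __init__(self, gamma):
        self.gamma = gamma
        self.rewems = None

    def update(self, rews):
        # rews: array of per-environment rewards for one time step
        if self.rewems is None:
            self.rewems = np.array(rews, dtype=np.float64)
        else:
            self.rewems = self.rewems * self.gamma + rews
        return self.rewems


class RunningMeanStd:
    """Streaming mean/variance updated from batch moments (sketch)."""

    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m_2 = (self.var * self.count + batch_var * batch_count +
               delta ** 2 * self.count * batch_count / total)
        self.mean, self.var, self.count = new_mean, m_2 / total, total

Used as in the loop above: each step of rewards is pushed through rff.update, the resulting moments update the RunningMeanStd, and the stored rewards are divided by the square root of its variance.
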
Example #2
class PpoOptimizer(object):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, stochpol, ent_coef, gamma,
                 lam, nepochs, lr, cliprange, nminibatches, normrew, normadv,
                 use_news, ext_coeff, int_coeff, nsteps_per_seg, nsegs_per_env,
                 dynamics):
        self.dynamics = dynamics
        self.use_recorder = True
        self.n_updates = 0
        self.scope = scope
        self.ob_space = ob_space
        self.ac_space = ac_space
        self.stochpol = stochpol
        self.nepochs = nepochs
        self.lr = lr
        self.cliprange = cliprange
        self.nsteps_per_seg = nsteps_per_seg
        self.nsegs_per_env = nsegs_per_env
        self.nminibatches = nminibatches
        self.gamma = gamma
        self.lam = lam
        self.normrew = normrew
        self.normadv = normadv
        self.use_news = use_news
        self.ent_coef = ent_coef
        self.ext_coeff = ext_coeff
        self.int_coeff = int_coeff

    def start_interaction(self, env_fns, dynamics, nlump=2):
        # Shared parameter references across the policy, the dynamics model,
        # and its auxiliary task (not a deep copy).
        param_list = (self.stochpol.param_list + self.dynamics.param_list +
                      self.dynamics.auxiliary_task.param_list)
        self.optimizer = torch.optim.Adam(param_list, lr=self.lr)
        self.optimizer.zero_grad()

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)
        self.nlump = nlump
        self.lump_stride = nenvs // self.nlump
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride:(l + 1) * self.lump_stride],
                   spaces=[self.ob_space, self.ac_space])
            for l in range(self.nlump)
        ]

        self.rollout = Rollout(ob_space=self.ob_space,
                               ac_space=self.ac_space,
                               nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env,
                               nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamics=dynamics)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

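    # GAE(lambda): backward recursion over the rollout segment.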
    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = (self.rollout.buf_news[:, t + 1]
                       if t + 1 < nsteps else self.rollout.buf_new_last)
            if not use_news:
                nextnew = 0
            nextvals = (self.rollout.buf_vpreds[:, t + 1]
                        if t + 1 < nsteps else self.rollout.buf_vpred_last)
            nextnotnew = 1 - nextnew
            delta = (rews[:, t] + gamma * nextvals * nextnotnew -
                     self.rollout.buf_vpreds[:, t])
            self.buf_advs[:, t] = lastgaelam = (
                delta + gamma * lam * nextnotnew * lastgaelam)
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

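    # One PPO update over the collected rollout: optionally normalize rewards,
    # compute advantages, then run nepochs of clipped-surrogate minibatch steps.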
    def update(self):
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(advmean=self.buf_advs.mean(),
                    advstd=self.buf_advs.std(),
                    retmean=self.buf_rets.mean(),
                    retstd=self.buf_rets.std(),
                    vpredmean=self.rollout.buf_vpreds.mean(),
                    vpredstd=self.rollout.buf_vpreds.std(),
                    ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                          self.buf_rets.ravel()),
                    rew_mean=np.mean(self.rollout.buf_rews),
                    recent_best_ext_ret=self.rollout.current_max)
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        to_report = {
            'total': 0.0,
            'pg': 0.0,
            'vf': 0.0,
            'ent': 0.0,
            'approxkl': 0.0,
            'clipfrac': 0.0,
            'aux': 0.0,
            'dyn_loss': 0.0,
            'feat_var': 0.0
        }

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        mblossvals = []

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, self.nenvs * self.nsegs_per_env,
                               envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]

                acs = self.rollout.buf_acs[mbenvinds]
                rews = self.rollout.buf_rews[mbenvinds]
                vpreds = self.rollout.buf_vpreds[mbenvinds]
                nlps = self.rollout.buf_nlps[mbenvinds]
                obs = self.rollout.buf_obs[mbenvinds]
                rets = self.buf_rets[mbenvinds]
                advs = self.buf_advs[mbenvinds]
                last_obs = self.rollout.buf_obs_last[mbenvinds]

                lr = self.lr
                cliprange = self.cliprange

                self.stochpol.update_features(obs, acs)
                self.dynamics.auxiliary_task.update_features(obs, last_obs)
                self.dynamics.update_features(obs, last_obs)

                feat_loss = torch.mean(self.dynamics.auxiliary_task.get_loss())
                dyn_loss = torch.mean(self.dynamics.get_loss())

                acs = torch.tensor(flatten_dims(acs, len(self.ac_space.shape)))
                neglogpac = self.stochpol.pd.neglogp(acs)
                entropy = torch.mean(self.stochpol.pd.entropy())
                vpred = self.stochpol.vpred
                vf_loss = 0.5 * torch.mean(
                    (vpred.squeeze() - torch.tensor(rets))**2)

                nlps = torch.tensor(flatten_dims(nlps, 0))
                ratio = torch.exp(nlps - neglogpac.squeeze())

                advs = flatten_dims(advs, 0)
                negadv = torch.tensor(-advs)
                pg_losses1 = negadv * ratio
                pg_losses2 = negadv * torch.clamp(
                    ratio, min=1.0 - cliprange, max=1.0 + cliprange)
                pg_loss_surr = torch.max(pg_losses1, pg_losses2)
                pg_loss = torch.mean(pg_loss_surr)
                ent_loss = (-self.ent_coef) * entropy

                approxkl = 0.5 * torch.mean((neglogpac - nlps)**2)
                clipfrac = torch.mean(
                    (torch.abs(pg_losses2 - pg_loss_surr) > 1e-6).float())
                feat_var = torch.std(self.dynamics.auxiliary_task.features)

                total_loss = pg_loss + ent_loss + vf_loss + feat_loss + dyn_loss

                total_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                denom = self.nminibatches * self.nepochs
                to_report['total'] += total_loss.data.numpy() / denom
                to_report['pg'] += pg_loss.data.numpy() / denom
                to_report['vf'] += vf_loss.data.numpy() / denom
                to_report['ent'] += ent_loss.data.numpy() / denom
                to_report['approxkl'] += approxkl.data.numpy() / denom
                to_report['clipfrac'] += clipfrac.data.numpy() / denom
                to_report['feat_var'] += feat_var.data.numpy() / denom
                to_report['aux'] += feat_loss.data.numpy() / denom
                to_report['dyn_loss'] += dyn_loss.data.numpy() / denom

        info.update(to_report)
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        # Single process; an MPI version would multiply by MPI.COMM_WORLD.Get_size().
        info['tps'] = self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info

    def step(self):
        self.rollout.collect_rollout()
        update_info = self.update()
        return {'update': update_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)
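
For reference, the backward pass in calculate_advantages above is the standard GAE(lambda) recursion. Below is a self-contained NumPy sketch of the same computation, assuming reward, value, and episode-start buffers shaped (nenvs, nsteps) plus bootstrap values for the step after the segment; the function name and argument layout are illustrative and not part of the class above.

import numpy as np


def gae_advantages(rews, vpreds, news, vpred_last, new_last,
                   gamma, lam, use_news=True):
    """GAE(lambda) over one rollout segment; mirrors calculate_advantages."""
    nenvs, nsteps = rews.shape
    advs = np.zeros((nenvs, nsteps), np.float32)
    lastgaelam = 0.0
    for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
        nextnew = news[:, t + 1] if t + 1 < nsteps else new_last
        if not use_news:
            nextnew = 0  # optionally ignore episode boundaries
        nextvals = vpreds[:, t + 1] if t + 1 < nsteps else vpred_last
        nextnotnew = 1 - nextnew
        delta = rews[:, t] + gamma * nextvals * nextnotnew - vpreds[:, t]
        lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        advs[:, t] = lastgaelam
    rets = advs + vpreds  # lambda-returns used as value-function targets
    return advs, rets

# Example:
# advs, rets = gae_advantages(rews, vpreds, news, vpred_last, new_last,
#                             gamma=0.99, lam=0.95)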