Example #1
0
    def train(self, epis):
        """Run one PPO-clip update on the given episodes.

        Args:
            epis: Episodes to train on; they are added to a fresh ``Traj``.

        Returns:
            The result dict from ``ppo_clip.train``, extended with
            ``"traj_num_step"`` and ``"traj_num_epi"`` taken from the
            trajectory that was built here.
        """
        # Every hyper-parameter is read from ``self.args`` so the method does
        # not depend on a module-level ``args`` happening to be defined.
        hparams = self.args

        traj = Traj(ddp=True, traj_device=self.device)
        traj.add_epis(epis)

        # Values come from ``self.vf``; the update below uses the ``ddp_``
        # wrapped policy and value function instead.
        traj = ef.compute_vs(traj, self.vf)
        traj = ef.compute_rets(traj, hparams.gamma)
        traj = ef.compute_advs(traj, hparams.gamma, hparams.lam)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        result_dict = ppo_clip.train(traj=traj,
                                     pol=self.ddp_pol,
                                     vf=self.ddp_vf,
                                     clip_param=hparams.clip_param,
                                     optim_pol=self.optim_pol,
                                     optim_vf=self.optim_vf,
                                     epoch=hparams.epoch_per_iter,
                                     batch_size=hparams.batch_size,
                                     max_grad_norm=hparams.max_grad_norm,
                                     # logging is enabled on rank 0 only
                                     log_enable=self.rank == 0)

        result_dict["traj_num_step"] = traj.num_step
        result_dict["traj_num_epi"] = traj.num_epi
        return result_dict
Example #2
0
    def test_learning_rnn(self):
        """Smoke-test one ``trpo.train`` call with ``PolNetLSTM``/``VNetLSTM``.

        The test makes no assertions; it passes when sampling, trajectory
        preprocessing and the training call complete without raising.
        """
        ob_space = self.env.observation_space
        ac_space = self.env.action_space

        policy_net = PolNetLSTM(ob_space, ac_space, h_size=32, cell_size=32)
        policy = CategoricalPol(ob_space, ac_space, policy_net, rnn=True)

        value_net = VNetLSTM(ob_space, h_size=32, cell_size=32)
        value_fn = DeterministicSVfunc(ob_space, value_net, rnn=True)

        sampler = EpiSampler(self.env, policy, num_parallel=1)

        vf_optimizer = torch.optim.Adam(value_net.parameters(), 3e-4)

        episodes = sampler.sample(policy, max_steps=400)

        # Build the trajectory and attach values, returns, advantages, masks.
        traj = Traj()
        traj.add_epis(episodes)
        traj = ef.compute_vs(traj, value_fn)
        traj = ef.compute_rets(traj, 0.99)
        traj = ef.compute_advs(traj, 0.99, 0.95)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        result_dict = trpo.train(traj, policy, value_fn, vf_optimizer, 1, 2)

        del sampler
Example #3
0
    def test_learning_rnn(self):
        """Smoke-test ``ppo_clip.train`` then ``ppo_kl.train`` with LSTM nets.

        The test makes no assertions; it passes when sampling, trajectory
        preprocessing and both training calls complete without raising.
        """
        ob_space = self.env.observation_space
        ac_space = self.env.action_space

        policy_net = PolNetLSTM(ob_space, ac_space, h_size=32, cell_size=32)
        policy = GaussianPol(ob_space, ac_space, policy_net, rnn=True)

        value_net = VNetLSTM(ob_space, h_size=32, cell_size=32)
        value_fn = DeterministicSVfunc(ob_space, value_net, rnn=True)

        sampler = EpiSampler(self.env, policy, num_parallel=1)

        pol_optimizer = torch.optim.Adam(policy_net.parameters(), 3e-4)
        vf_optimizer = torch.optim.Adam(value_net.parameters(), 3e-4)

        episodes = sampler.sample(policy, max_steps=400)

        # Build the trajectory and attach values, returns, advantages, masks.
        traj = Traj()
        traj.add_epis(episodes)
        traj = ef.compute_vs(traj, value_fn)
        traj = ef.compute_rets(traj, 0.99)
        traj = ef.compute_advs(traj, 0.99, 0.95)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        # Keyword arguments common to both PPO variants.
        shared_kwargs = dict(
            traj=traj,
            pol=policy,
            vf=value_fn,
            optim_pol=pol_optimizer,
            optim_vf=vf_optimizer,
            epoch=1,
            batch_size=2,
        )
        result_dict = ppo_clip.train(clip_param=0.2, **shared_kwargs)
        result_dict = ppo_kl.train(kl_beta=0.1, kl_targ=0.2,
                                   max_grad_norm=20, **shared_kwargs)

        del sampler
Example #4
0
    def test_learning(self):
        """Smoke-test ``ppo_clip.train`` then ``ppo_kl.train`` with ``CategoricalPol``.

        The test makes no assertions; it passes when sampling, trajectory
        preprocessing and both training calls complete without raising.
        """
        ob_space = self.env.ob_space
        ac_space = self.env.ac_space

        policy_net = PolNet(ob_space, ac_space, h1=32, h2=32)
        policy = CategoricalPol(ob_space, ac_space, policy_net)

        value_net = VNet(ob_space, h1=32, h2=32)
        value_fn = DeterministicSVfunc(ob_space, value_net)

        sampler = EpiSampler(self.env, policy, num_parallel=1)

        pol_optimizer = torch.optim.Adam(policy_net.parameters(), 3e-4)
        vf_optimizer = torch.optim.Adam(value_net.parameters(), 3e-4)

        episodes = sampler.sample(policy, max_steps=32)

        # Build the trajectory and attach values, returns, advantages, masks.
        traj = Traj()
        traj.add_epis(episodes)
        traj = ef.compute_vs(traj, value_fn)
        traj = ef.compute_rets(traj, 0.99)
        traj = ef.compute_advs(traj, 0.99, 0.95)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        # Keyword arguments common to both PPO variants.
        shared_kwargs = dict(
            traj=traj,
            pol=policy,
            vf=value_fn,
            optim_pol=pol_optimizer,
            optim_vf=vf_optimizer,
            epoch=1,
            batch_size=32,
        )
        result_dict = ppo_clip.train(clip_param=0.2, **shared_kwargs)
        result_dict = ppo_kl.train(kl_beta=0.1, kl_targ=0.2,
                                   max_grad_norm=10, **shared_kwargs)

        del sampler
Example #5
0
    def test_learning(self):
        """Smoke-test one ``trpo.train`` call with ``GaussianPol``.

        The test makes no assertions; it passes when sampling, trajectory
        preprocessing and the training call complete without raising.
        """
        ob_space = self.env.observation_space
        ac_space = self.env.action_space

        policy_net = PolNet(ob_space, ac_space, h1=32, h2=32)
        policy = GaussianPol(ob_space, ac_space, policy_net)

        value_net = VNet(ob_space, h1=32, h2=32)
        value_fn = DeterministicSVfunc(ob_space, value_net)

        sampler = EpiSampler(self.env, policy, num_parallel=1)

        vf_optimizer = torch.optim.Adam(value_net.parameters(), 3e-4)

        episodes = sampler.sample(policy, max_steps=32)

        # Build the trajectory and attach values, returns, advantages, masks.
        traj = Traj()
        traj.add_epis(episodes)
        traj = ef.compute_vs(traj, value_fn)
        traj = ef.compute_rets(traj, 0.99)
        traj = ef.compute_advs(traj, 0.99, 0.95)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        result_dict = trpo.train(traj, policy, value_fn, vf_optimizer, 1, 24)

        del sampler
Example #6
0
    def test_learning(self):
        """Smoke-test one ``gail.train`` call (``rl_type='trpo'``).

        Expert episodes are unpickled from the ``data/expert_epis`` fixture
        and agent episodes are sampled from the environment. The test makes
        no assertions; it passes when every step completes without raising.
        """
        ob_space = self.env.observation_space
        ac_space = self.env.action_space

        policy_net = PolNet(ob_space, ac_space, h1=32, h2=32)
        policy = GaussianPol(ob_space, ac_space, policy_net)

        value_net = VNet(ob_space)
        value_fn = DeterministicSVfunc(ob_space, value_net)

        discrim_net = DiscrimNet(ob_space, ac_space, h1=32, h2=32)
        discrim = DeterministicSAVfunc(ob_space, ac_space, discrim_net)

        sampler = EpiSampler(self.env, policy, num_parallel=1)

        vf_optimizer = torch.optim.Adam(value_net.parameters(), 3e-4)
        discrim_optimizer = torch.optim.Adam(discrim_net.parameters(), 3e-4)

        # Expert demonstrations come from a pickled fixture on disk.
        expert_path = os.path.join('data/expert_epis', 'Pendulum-v0_2epis.pkl')
        with open(expert_path, 'rb') as fixture:
            expert_episodes = pickle.load(fixture)
        expert_traj = Traj()
        expert_traj.add_epis(expert_episodes)
        expert_traj.register_epis()

        episodes = sampler.sample(policy, max_steps=32)

        # Agent trajectory: pseudo rewards from the discriminator first,
        # then values, returns, advantages and masks.
        agent_traj = Traj()
        agent_traj.add_epis(episodes)
        agent_traj = ef.compute_pseudo_rews(agent_traj, discrim)
        agent_traj = ef.compute_vs(agent_traj, value_fn)
        agent_traj = ef.compute_rets(agent_traj, 0.99)
        agent_traj = ef.compute_advs(agent_traj, 0.99, 0.95)
        agent_traj = ef.centerize_advs(agent_traj)
        agent_traj = ef.compute_h_masks(agent_traj)
        agent_traj.register_epis()

        result_dict = gail.train(
            agent_traj, expert_traj, policy, value_fn, discrim,
            vf_optimizer, discrim_optimizer,
            rl_type='trpo',
            epoch=1,
            batch_size=32,
            discrim_batch_size=32,
            discrim_step=1,
            pol_ent_beta=1e-3,
            discrim_ent_beta=1e-5,
        )

        del sampler
Example #7
0
    def test_learning(self):
        """Smoke-test one ``airl.train`` call (``rl_type='trpo'``).

        Expert episodes are unpickled from the ``data/expert_epis`` fixture
        and agent episodes are sampled from the environment. The test makes
        no assertions; it passes when every step completes without raising.
        """
        ob_space = self.env.ob_space
        ac_space = self.env.ac_space

        policy_net = PolNet(ob_space, ac_space, h1=32, h2=32)
        policy = GaussianPol(ob_space, ac_space, policy_net)

        value_net = VNet(ob_space)
        value_fn = DeterministicSVfunc(ob_space, value_net)

        # Reward function and shaping value function share one optimizer.
        rewf_net = VNet(ob_space, h1=32, h2=32)
        rewf = DeterministicSVfunc(ob_space, rewf_net)
        shaping_vf_net = VNet(ob_space, h1=32, h2=32)
        shaping_vf = DeterministicSVfunc(ob_space, shaping_vf_net)

        sampler = EpiSampler(self.env, policy, num_parallel=1)

        vf_optimizer = torch.optim.Adam(value_net.parameters(), 3e-4)
        discrim_params = [*rewf_net.parameters(),
                          *shaping_vf_net.parameters()]
        discrim_optimizer = torch.optim.Adam(discrim_params, 3e-4)

        # Expert demonstrations come from a pickled fixture on disk.
        expert_path = os.path.join('data/expert_epis', 'Pendulum-v0_2epis.pkl')
        with open(expert_path, 'rb') as fixture:
            expert_episodes = pickle.load(fixture)
        expert_traj = Traj()
        expert_traj.add_epis(expert_episodes)
        expert_traj = ef.add_next_obs(expert_traj)
        expert_traj.register_epis()

        episodes = sampler.sample(policy, max_steps=32)

        # Agent trajectory: next observations and pseudo rewards first,
        # then values, returns, advantages and masks.
        agent_traj = Traj()
        agent_traj.add_epis(episodes)
        agent_traj = ef.add_next_obs(agent_traj)
        agent_traj = ef.compute_pseudo_rews(
            agent_traj, rew_giver=rewf, state_only=True)
        agent_traj = ef.compute_vs(agent_traj, value_fn)
        agent_traj = ef.compute_rets(agent_traj, 0.99)
        agent_traj = ef.compute_advs(agent_traj, 0.99, 0.95)
        agent_traj = ef.centerize_advs(agent_traj)
        agent_traj = ef.compute_h_masks(agent_traj)
        agent_traj.register_epis()

        result_dict = airl.train(
            agent_traj, expert_traj, policy, value_fn,
            vf_optimizer, discrim_optimizer,
            rewf=rewf,
            shaping_vf=shaping_vf,
            rl_type='trpo',
            epoch=1,
            batch_size=32,
            discrim_batch_size=32,
            discrim_step=1,
            pol_ent_beta=1e-3,
            gamma=0.99,
        )

        del sampler
        # Behavior-cloning pretraining: one behavior_clone.train call per
        # epoch, args.bc_epoch times in total; the return value is discarded.
        for _ in range(args.bc_epoch):
            _ = behavior_clone.train(expert_traj, pol, optim_pol,
                                     args.bc_batch_size)
    # Save the policy weights as 'pol_bc.pkl' under <args.log>/models.
    torch.save(pol.state_dict(), os.path.join(args.log, 'models',
                                              'pol_bc.pkl'))

# Main loop: sample episodes and train until args.max_epis episodes are used.
while args.max_epis > total_epi:
    with measure('sample'):
        epis = sampler.sample(pol, max_steps=args.max_steps_per_iter)
    with measure('train'):
        # Build a fresh trajectory from this iteration's episodes only.
        traj = Traj()
        traj.add_epis(epis)

        # Attach values, returns, advantages and masks before registering.
        traj = ef.compute_vs(traj, vf)
        traj = ef.compute_rets(traj, args.gamma)
        traj = ef.compute_advs(traj, args.gamma, args.lam)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        # Set the dp_run flag on both models when data parallelism is requested.
        if args.data_parallel:
            pol.dp_run = True
            vf.dp_run = True

        if args.ppo_type == 'clip':
            result_dict = ppo_clip.train(traj=traj,
                                         pol=pol,
                                         vf=vf,
                                         clip_param=args.clip_param,
                                         optim_pol=optim_pol,
                                         optim_vf=optim_vf,
Example #9
0
# Optional behavior-cloning pretraining; batch size and epoch count are
# passed to a single behavior_clone.train call.
if args.pretrain:
    with measure('bc pretrain'):
        _ = behavior_clone.train(expert_traj, pol, optim_pol,
                                 args.bc_batch_size, args.bc_epoch)

# Main loop: sample episodes and train until args.max_epis episodes are used.
while args.max_epis > total_epi:
    with measure('sample'):
        epis = sampler.sample(pol, max_steps=args.max_steps_per_iter)
    with measure('train'):
        # Build a fresh agent trajectory from this iteration's episodes.
        agent_traj = Traj()
        agent_traj.add_epis(epis)
        # Pseudo rewards from the discriminator are computed before values,
        # returns and advantages.
        agent_traj = ef.compute_pseudo_rews(agent_traj, discrim)
        agent_traj = ef.compute_vs(agent_traj, vf)
        agent_traj = ef.compute_rets(agent_traj, args.gamma)
        agent_traj = ef.compute_advs(agent_traj, args.gamma, args.lam)
        agent_traj = ef.centerize_advs(agent_traj)
        agent_traj = ef.compute_h_masks(agent_traj)
        agent_traj.register_epis()

        # Set the dp_run flag on all three models when data parallelism is
        # requested.
        if args.data_parallel:
            pol.dp_run = True
            vf.dp_run = True
            discrim.dp_run = True

        if args.rl_type == 'trpo':
            result_dict = gail.train(
                agent_traj,
                expert_traj,
                pol,
                vf,
Example #10
0
# Running counters and the best-reward baseline for the loop below.
total_epi = 0
total_step = 0
max_rew = -1e6
while args.max_epis > total_epi:
    with measure('sample'):
        # Two samplers are queried with the same policy.
        epis1 = sampler1.sample(pol, max_epis=args.max_epis_per_iter)
        epis2 = sampler2.sample(pol, max_epis=args.max_epis_per_iter)
    with measure('train'):
        traj1 = Traj()
        traj2 = Traj()

        # Each trajectory is preprocessed independently (values, returns,
        # advantages, advantage centering, masks) before being registered.
        traj1.add_epis(epis1)
        traj1 = ef.compute_vs(traj1, vf)
        traj1 = ef.compute_rets(traj1, args.gamma)
        traj1 = ef.compute_advs(traj1, args.gamma, args.lam)
        traj1 = ef.centerize_advs(traj1)
        traj1 = ef.compute_h_masks(traj1)
        traj1.register_epis()

        traj2.add_epis(epis2)
        traj2 = ef.compute_vs(traj2, vf)
        traj2 = ef.compute_rets(traj2, args.gamma)
        traj2 = ef.compute_advs(traj2, args.gamma, args.lam)
        traj2 = ef.centerize_advs(traj2)
        traj2 = ef.compute_h_masks(traj2)
        traj2.register_epis()

        # Merge the second trajectory into the first.
        traj1.add_traj(traj2)

        if args.data_parallel:
Example #11
0
    def train(self):
        """Run the PPO-clip training loop until the frame budget is spent.

        Each iteration samples episodes with an ``EpiSampler``, builds a
        ``Traj`` from them, calls ``ppo_clip.train``, records the results,
        saves the models, and steps ``self.scheduler_pol`` and
        ``self.scheduler_vf``. The sampler is re-created on the first
        iteration and whenever ``self.curriculum_handler`` returns a truthy
        value.
        """
        args = self.args

        # TODO: cuda seems to be broken, I don't care about it right now
        # if args.cuda:
        #     # current_obs = current_obs.cuda()
        #     rollouts.cuda()

        self.train_start_time = time.time()
        total_epi = 0
        total_step = 0
        max_rew = -1e6
        sampler = None

        # Results are written to progress.csv inside the logger's directory.
        score_file = os.path.join(self.logger.get_logdir(), "progress.csv")
        logger.add_tabular_output(score_file)

        num_total_frames = args.num_total_frames

        # Mirroring is used only when requested AND the unwrapped env exposes
        # ``mirror_indices``.
        mirror_function = None
        if args.mirror_tuples and hasattr(self.env.unwrapped,
                                          "mirror_indices"):
            mirror_function = get_mirror_function(
                **self.env.unwrapped.mirror_indices)
            # The step budget is doubled when mirroring; presumably because the
            # mirrored trajectory added below also counts towards
            # ``traj.num_step`` -- TODO confirm.
            num_total_frames *= 2
            if not args.tanh_finish:
                warnings.warn(
                    "When `mirror_tuples` is `True`,"
                    " `tanh_finish` should be set to `True` as well."
                    " Otherwise there is a chance of the training blowing up.")

        while num_total_frames > total_step:
            # setup the correct curriculum learning environment/parameters
            # NOTE(review): the progress ratio divides by the undoubled
            # ``args.num_total_frames``, so with mirroring enabled it can
            # exceed 1.0 -- confirm this is intended.
            new_curriculum = self.curriculum_handler(total_step /
                                                     args.num_total_frames)

            # (Re)create the sampler on the first iteration or after a
            # curriculum change; the previous one is dropped first.
            if total_step == 0 or new_curriculum:
                if sampler is not None:
                    del sampler
                sampler = EpiSampler(
                    self.env,
                    self.pol,
                    num_parallel=self.args.num_processes,
                    seed=self.args.seed + total_step,  # TODO: better fix?
                )

            with measure("sample"):
                epis = sampler.sample(self.pol,
                                      max_steps=args.num_steps *
                                      args.num_processes)

            with measure("train"):
                with measure("epis"):
                    traj = Traj()
                    traj.add_epis(epis)

                    # Attach values, returns, advantages and masks before
                    # registering the episodes.
                    traj = ef.compute_vs(traj, self.vf)
                    traj = ef.compute_rets(traj, args.decay_gamma)
                    traj = ef.compute_advs(traj, args.decay_gamma,
                                           args.gae_lambda)
                    traj = ef.centerize_advs(traj)
                    traj = ef.compute_h_masks(traj)
                    traj.register_epis()

                    # Append the mirrored copy of the trajectory when
                    # mirroring is active.
                    if mirror_function:
                        traj.add_traj(mirror_function(traj))

                # if args.data_parallel:
                #     self.pol.dp_run = True
                #     self.vf.dp_run = True

                # The RNN batch size replaces the regular one when
                # ``args.rnn`` is set.
                result_dict = ppo_clip.train(
                    traj=traj,
                    pol=self.pol,
                    vf=self.vf,
                    clip_param=args.clip_eps,
                    optim_pol=self.optim_pol,
                    optim_vf=self.optim_vf,
                    epoch=args.epoch_per_iter,
                    batch_size=args.batch_size
                    if not args.rnn else args.rnn_batch_size,
                    max_grad_norm=args.max_grad_norm,
                )

                # if args.data_parallel:
                #     self.pol.dp_run = False
                #     self.vf.dp_run = False

            ## append the metrics to the `results_dict` (reported in the progress.csv)
            result_dict.update(self.get_extra_metrics(epis))

            # Counters come from the trajectory (including any mirrored part
            # added above); rewards come from the sampled episodes only.
            total_epi += traj.num_epi
            step = traj.num_step
            total_step += step
            rewards = [np.sum(epi["rews"]) for epi in epis]
            mean_rew = np.mean(rewards)
            logger.record_results(
                self.logger.get_logdir(),
                result_dict,
                score_file,
                total_epi,
                step,
                total_step,
                rewards,
                plot_title=args.env,
            )

            # Save a "max" checkpoint whenever the mean episode reward
            # improves on the best seen so far.
            if mean_rew > max_rew:
                self.save_models("max")
                max_rew = mean_rew

            # The "last" checkpoint is saved on every iteration.
            self.save_models("last")

            self.scheduler_pol.step()
            self.scheduler_vf.step()

            # Drop the trajectory before the next iteration builds a new one.
            del traj
# Train loop: sample episodes and run PPO-clip until max_episodes is reached.
while max_episodes > total_epi:
    # sample trajectories
    with measure('sample'):
        epis = sampler.sample(pol, max_steps=max_steps_per_iter)

    # train from trajectories
    with measure('train'):
        # Build a fresh trajectory from this iteration's episodes only.
        traj = Traj()
        traj.add_epis(epis)

        # Calculate advantages: values first, then returns, then advantages,
        # which are centered; masks are computed last before registering.
        traj = ef.compute_vs(traj, vf)
        traj = ef.compute_rets(traj, gamma)
        traj = ef.compute_advs(traj, gamma, lam)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        result_dict = ppo_clip.train(traj=traj,
                                     pol=pol,
                                     vf=vf,
                                     clip_param=clip_param,
                                     optim_pol=optim_pol,
                                     optim_vf=optim_vf,
                                     epoch=epoch_per_iter,
                                     batch_size=batch_size,
                                     max_grad_norm=max_grad_norm)
    # update counter and record
    total_epi += traj.num_epi