Example #1
    def test_learning(self):
        pol_net = PolNet(self.env.observation_space,
                         self.env.action_space,
                         h1=32,
                         h2=32)
        pol = GaussianPol(self.env.observation_space, self.env.action_space,
                          pol_net)

        vf_net = VNet(self.env.observation_space)
        vf = DeterministicSVfunc(self.env.observation_space, vf_net)

        discrim_net = DiscrimNet(self.env.observation_space,
                                 self.env.action_space,
                                 h1=32,
                                 h2=32)
        discrim = DeterministicSAVfunc(self.env.observation_space,
                                       self.env.action_space, discrim_net)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_vf = torch.optim.Adam(vf_net.parameters(), 3e-4)
        optim_discrim = torch.optim.Adam(discrim_net.parameters(), 3e-4)

        with open(os.path.join('data/expert_epis', 'Pendulum-v0_2epis.pkl'),
                  'rb') as f:
            expert_epis = pickle.load(f)
        expert_traj = Traj()
        expert_traj.add_epis(expert_epis)
        expert_traj.register_epis()

        epis = sampler.sample(pol, max_steps=32)

        agent_traj = Traj()
        agent_traj.add_epis(epis)
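        # Relabel the sampled episodes with discriminator-based pseudo-rewards,
        # then compute value estimates, discounted returns (gamma=0.99),
        # GAE-style advantages (lambda=0.95), centered advantages, and masks
        # for recurrent policies before registering the trajectory.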
        agent_traj = ef.compute_pseudo_rews(agent_traj, discrim)
        agent_traj = ef.compute_vs(agent_traj, vf)
        agent_traj = ef.compute_rets(agent_traj, 0.99)
        agent_traj = ef.compute_advs(agent_traj, 0.99, 0.95)
        agent_traj = ef.centerize_advs(agent_traj)
        agent_traj = ef.compute_h_masks(agent_traj)
        agent_traj.register_epis()

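        # One GAIL iteration: one discriminator update step and a TRPO update
        # of the policy on the pseudo-reward advantages computed above.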
        result_dict = gail.train(agent_traj,
                                 expert_traj,
                                 pol,
                                 vf,
                                 discrim,
                                 optim_vf,
                                 optim_discrim,
                                 rl_type='trpo',
                                 epoch=1,
                                 batch_size=32,
                                 discrim_batch_size=32,
                                 discrim_step=1,
                                 pol_ent_beta=1e-3,
                                 discrim_ent_beta=1e-5)

        del sampler
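
For reference, the expert pickle loaded above is simply a serialized list of episodes, so it can be produced with the same sampling machinery. A minimal sketch, assuming an environment instance env and an already trained expert policy expert_pol (the step budget is illustrative; the file path is the one read by the test):

# Collect episodes with a trained expert policy and serialize them for GAIL.
expert_sampler = EpiSampler(env, expert_pol, num_parallel=1)
expert_epis = expert_sampler.sample(expert_pol, max_steps=2000)  # illustrative budget
with open(os.path.join('data/expert_epis', 'Pendulum-v0_2epis.pkl'), 'wb') as f:
    pickle.dump(expert_epis, f)
del expert_sampler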
Example #2
                         parallel_dim=1 if args.rnn else 0)

discrim_net = DiscrimNet(ob_space,
                         ac_space,
                         h1=args.discrim_h1,
                         h2=args.discrim_h2)
discrim = DeterministicSAVfunc(ob_space,
                               ac_space,
                               discrim_net,
                               data_parallel=args.data_parallel)

sampler = EpiSampler(env, pol, num_parallel=args.num_parallel, seed=args.seed)

optim_pol = torch.optim.Adam(pol_net.parameters(), args.pol_lr)
optim_vf = torch.optim.Adam(vf_net.parameters(), args.vf_lr)
optim_discrim = torch.optim.Adam(discrim_net.parameters(), args.discrim_lr)

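# Load the recorded expert episodes, register them as a Traj, and log the
# mean expert return and the number of expert episodes.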
with open(os.path.join(args.expert_dir, args.expert_fname), 'rb') as f:
    expert_epis = pickle.load(f)
expert_traj = Traj()
expert_traj.add_epis(expert_epis)
expert_traj.register_epis()
expert_rewards = [np.sum(epi['rews']) for epi in expert_epis]
expert_mean_rew = np.mean(expert_rewards)
logger.log('expert_score={}'.format(expert_mean_rew))
logger.log('expert_num_epi={}'.format(expert_traj.num_epi))

total_epi = 0
total_step = 0
max_rew = -1e6
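
The counters above are typically consumed by a training loop like the following. A minimal sketch, assuming the same helpers shown in Example #1; the episode and step budgets and the checkpoint path are illustrative, and the gail.train arguments mirror Example #1:

while total_epi < 10000:                            # illustrative episode budget
    epis = sampler.sample(pol, max_steps=5000)      # illustrative per-iteration budget

    agent_traj = Traj()
    agent_traj.add_epis(epis)
    agent_traj = ef.compute_pseudo_rews(agent_traj, discrim)
    agent_traj = ef.compute_vs(agent_traj, vf)
    agent_traj = ef.compute_rets(agent_traj, 0.99)
    agent_traj = ef.compute_advs(agent_traj, 0.99, 0.95)
    agent_traj = ef.centerize_advs(agent_traj)
    agent_traj = ef.compute_h_masks(agent_traj)
    agent_traj.register_epis()

    result_dict = gail.train(agent_traj, expert_traj, pol, vf, discrim,
                             optim_vf, optim_discrim, rl_type='trpo',
                             epoch=1, batch_size=32, discrim_batch_size=32,
                             discrim_step=1, pol_ent_beta=1e-3,
                             discrim_ent_beta=1e-5)

    total_epi += agent_traj.num_epi
    total_step += sum(len(epi['rews']) for epi in epis)
    mean_rew = np.mean([np.sum(epi['rews']) for epi in epis])
    logger.log('total_epi={} total_step={} mean_rew={}'.format(
        total_epi, total_step, mean_rew))

    # Keep a checkpoint of the best-performing policy network seen so far.
    if mean_rew > max_rew:
        max_rew = mean_rew
        torch.save(pol_net.state_dict(), 'pol_net_max.pkl')  # illustrative path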