Example #1
    def test_learning(self):
        pol_net = PolNet(self.env.observation_space,
                         self.env.action_space, h1=32, h2=32)
        pol = GaussianPol(self.env.observation_space,
                          self.env.action_space, pol_net)

        vf_net = VNet(self.env.observation_space, h1=32, h2=32)
        vf = DeterministicSVfunc(self.env.observation_space, vf_net)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_vf = torch.optim.Adam(vf_net.parameters(), 3e-4)

        epis = sampler.sample(pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)

        traj = ef.compute_vs(traj, vf)
        traj = ef.compute_rets(traj, 0.99)
        traj = ef.compute_advs(traj, 0.99, 0.95)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        result_dict = trpo.train(traj, pol, vf, optim_vf, 1, 24)

        del sampler
Example #2
    def test_learning(self):
        pol_net = PolNet(self.env.ob_space, self.env.ac_space, h1=32, h2=32)
        pol = CategoricalPol(self.env.ob_space, self.env.ac_space, pol_net)

        vf_net = VNet(self.env.ob_space, h1=32, h2=32)
        vf = DeterministicSVfunc(self.env.ob_space, vf_net)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_pol = torch.optim.Adam(pol_net.parameters(), 3e-4)
        optim_vf = torch.optim.Adam(vf_net.parameters(), 3e-4)

        epis = sampler.sample(pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)

        traj = ef.compute_vs(traj, vf)
        traj = ef.compute_rets(traj, 0.99)
        traj = ef.compute_advs(traj, 0.99, 0.95)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        result_dict = ppo_clip.train(traj=traj, pol=pol, vf=vf, clip_param=0.2,
                                     optim_pol=optim_pol, optim_vf=optim_vf, epoch=1, batch_size=32)
        result_dict = ppo_kl.train(traj=traj, pol=pol, vf=vf, kl_beta=0.1, kl_targ=0.2,
                                   optim_pol=optim_pol, optim_vf=optim_vf, epoch=1, batch_size=32, max_grad_norm=10)

        del sampler
Example #3
    def test_learning(self):
        t_pol_net = PolNet(self.env.observation_space,
                           self.env.action_space, h1=200, h2=100)
        s_pol_net = PolNet(self.env.observation_space,
                           self.env.action_space, h1=190, h2=90)

        t_pol = GaussianPol(
            self.env.observation_space, self.env.action_space, t_pol_net)
        s_pol = GaussianPol(
            self.env.observation_space, self.env.action_space, s_pol_net)

        student_sampler = EpiSampler(self.env, s_pol, num_parallel=1)

        optim_pol = torch.optim.Adam(s_pol.parameters(), 3e-4)

        epis = student_sampler.sample(s_pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)

        traj = ef.compute_h_masks(traj)
        traj.register_epis()
        result_dict = on_pol_teacher_distill.train(
            traj=traj,
            student_pol=s_pol,
            teacher_pol=t_pol,
            student_optim=optim_pol,
            epoch=1,
            batchsize=32)

        del student_sampler
Example #4
    def test_learning_rnn(self):
        pol_net = PolNetLSTM(
            self.env.observation_space, self.env.action_space, h_size=32, cell_size=32)
        pol = CategoricalPol(
            self.env.observation_space, self.env.action_space, pol_net, rnn=True)

        vf_net = VNetLSTM(self.env.observation_space, h_size=32, cell_size=32)
        vf = DeterministicSVfunc(self.env.observation_space, vf_net, rnn=True)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_vf = torch.optim.Adam(vf_net.parameters(), 3e-4)

        epis = sampler.sample(pol, max_steps=400)

        traj = Traj()
        traj.add_epis(epis)

        traj = ef.compute_vs(traj, vf)
        traj = ef.compute_rets(traj, 0.99)
        traj = ef.compute_advs(traj, 0.99, 0.95)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        result_dict = trpo.train(traj, pol, vf, optim_vf, 1, 2)

        del sampler
Example #5
    def test_learning_rnn(self):
        pol_net = PolNetLSTM(
            self.env.observation_space, self.env.action_space, h_size=32, cell_size=32)
        pol = GaussianPol(self.env.observation_space,
                          self.env.action_space, pol_net, rnn=True)

        vf_net = VNetLSTM(self.env.observation_space, h_size=32, cell_size=32)
        vf = DeterministicSVfunc(self.env.observation_space, vf_net, rnn=True)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_pol = torch.optim.Adam(pol_net.parameters(), 3e-4)
        optim_vf = torch.optim.Adam(vf_net.parameters(), 3e-4)

        epis = sampler.sample(pol, max_steps=400)

        traj = Traj()
        traj.add_epis(epis)

        traj = ef.compute_vs(traj, vf)
        traj = ef.compute_rets(traj, 0.99)
        traj = ef.compute_advs(traj, 0.99, 0.95)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        result_dict = ppo_clip.train(traj=traj, pol=pol, vf=vf, clip_param=0.2,
                                     optim_pol=optim_pol, optim_vf=optim_vf, epoch=1, batch_size=2)
        result_dict = ppo_kl.train(traj=traj, pol=pol, vf=vf, kl_beta=0.1, kl_targ=0.2,
                                   optim_pol=optim_pol, optim_vf=optim_vf, epoch=1, batch_size=2, max_grad_norm=20)

        del sampler
Example #6
    def __init__(self, world_size, rank=-1, env=None, pol=None, num_parallel=8, prepro=None, seed=256):
        if rank < 0:
            assert env is not None and pol is not None

        self.world_size = world_size
        self.rank = rank

        self.r = get_redis()

        if rank < 0:
            self.env = env
            self.pol = pol
            self.num_parallel = num_parallel // world_size
            self.prepro = prepro
            self.seed = seed

            self.original_num_parallel = num_parallel

        self.scatter_from_master('env')
        self.scatter_from_master('pol')
        self.scatter_from_master('num_parallel')
        self.scatter_from_master('prepro')
        self.scatter_from_master('seed')

        self.seed = self.seed * (self.rank + 23000)

        if not rank < 0:
            self.in_node_sampler = EpiSampler(
                self.env, self.pol, self.num_parallel, self.prepro, self.seed)
            self.launch_sampler()
Example #7
    def test_learning(self):
        pol_net = PolNet(self.env.ob_space, self.env.ac_space, h1=32, h2=32)
        pol = GaussianPol(self.env.ob_space, self.env.ac_space, pol_net)

        targ_pol_net = PolNet(self.env.ob_space, self.env.ac_space, 32, 32)
        targ_pol_net.load_state_dict(pol_net.state_dict())
        targ_pol = GaussianPol(
            self.env.ob_space, self.env.ac_space, targ_pol_net)

        qf_net = QNet(self.env.ob_space, self.env.ac_space, h1=32, h2=32)
        qf = DeterministicSAVfunc(self.env.ob_space, self.env.ac_space, qf_net)

        targ_qf_net = QNet(self.env.ob_space, self.env.ac_space, 32, 32)
        targ_qf_net.load_state_dict(qf_net.state_dict())
        targ_qf = DeterministicSAVfunc(
            self.env.ob_space, self.env.ac_space, targ_qf_net)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_pol = torch.optim.Adam(pol_net.parameters(), 3e-4)
        optim_qf = torch.optim.Adam(qf_net.parameters(), 3e-4)

        epis = sampler.sample(pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)

        traj = ef.add_next_obs(traj)
        traj.register_epis()

        result_dict = svg.train(
            traj, pol, targ_pol, qf, targ_qf, optim_pol, optim_qf, 1, 32, 0.01, 0.9, 1)

        del sampler
Example #8
    def test_learning(self):
        ob_space = self.env.real_observation_space
        skill_space = self.env.skill_space
        ob_skill_space = self.env.observation_space
        ac_space = self.env.action_space
        ob_dim = ob_skill_space.shape[0] - 4
        f_dim = ob_dim
        def discrim_f(x): return x

        pol_net = PolNet(ob_skill_space, ac_space)
        pol = GaussianPol(ob_skill_space, ac_space, pol_net)
        qf_net1 = QNet(ob_skill_space, ac_space)
        qf1 = DeterministicSAVfunc(ob_skill_space, ac_space, qf_net1)
        targ_qf_net1 = QNet(ob_skill_space, ac_space)
        targ_qf_net1.load_state_dict(qf_net1.state_dict())
        targ_qf1 = DeterministicSAVfunc(ob_skill_space, ac_space, targ_qf_net1)
        qf_net2 = QNet(ob_skill_space, ac_space)
        qf2 = DeterministicSAVfunc(ob_skill_space, ac_space, qf_net2)
        targ_qf_net2 = QNet(ob_skill_space, ac_space)
        targ_qf_net2.load_state_dict(qf_net2.state_dict())
        targ_qf2 = DeterministicSAVfunc(ob_skill_space, ac_space, targ_qf_net2)
        qfs = [qf1, qf2]
        targ_qfs = [targ_qf1, targ_qf2]
        log_alpha = nn.Parameter(torch.ones(()))

        high = np.array([np.finfo(np.float32).max]*f_dim)
        f_space = gym.spaces.Box(-high, high, dtype=np.float32)
        discrim_net = DiaynDiscrimNet(
            f_space, skill_space, h_size=100, discrim_f=discrim_f)
        discrim = DeterministicSVfunc(f_space, discrim_net)

        optim_pol = torch.optim.Adam(pol_net.parameters(), 1e-4)
        optim_qf1 = torch.optim.Adam(qf_net1.parameters(), 3e-4)
        optim_qf2 = torch.optim.Adam(qf_net2.parameters(), 3e-4)
        optim_qfs = [optim_qf1, optim_qf2]
        optim_alpha = torch.optim.Adam([log_alpha], 1e-4)
        optim_discrim = torch.optim.SGD(discrim.parameters(),
                                        lr=0.001, momentum=0.9)

        off_traj = Traj()
        sampler = EpiSampler(self.env, pol, num_parallel=1)

        epis = sampler.sample(pol, max_steps=200)
        on_traj = Traj()
        on_traj.add_epis(epis)
        on_traj = ef.add_next_obs(on_traj)
        on_traj = ef.compute_diayn_rews(
            on_traj, lambda x: diayn_sac.calc_rewards(x, 4, discrim))
        on_traj.register_epis()
        off_traj.add_traj(on_traj)
        step = on_traj.num_step
        log_alpha = nn.Parameter(np.log(0.1)*torch.ones(()))  # fix alpha
        result_dict = diayn_sac.train(
            off_traj, pol, qfs, targ_qfs, log_alpha,
            optim_pol, optim_qfs, optim_alpha,
            step, 128, 5e-3, 0.99, 1, discrim, 4, True)
        discrim_losses = diayn.train(
            discrim, optim_discrim, on_traj, 32, 100, 4)

        del sampler
Example #9
    def test_learning(self):
        pol_net = PolNet(self.env.ob_space, self.env.ac_space, h1=32, h2=32)
        pol = GaussianPol(self.env.ob_space, self.env.ac_space, pol_net)

        qf_net1 = QNet(self.env.ob_space, self.env.ac_space)
        qf1 = DeterministicSAVfunc(self.env.ob_space, self.env.ac_space,
                                   qf_net1)
        targ_qf_net1 = QNet(self.env.ob_space, self.env.ac_space)
        targ_qf_net1.load_state_dict(qf_net1.state_dict())
        targ_qf1 = DeterministicSAVfunc(self.env.ob_space, self.env.ac_space,
                                        targ_qf_net1)

        qf_net2 = QNet(self.env.ob_space, self.env.ac_space)
        qf2 = DeterministicSAVfunc(self.env.ob_space, self.env.ac_space,
                                   qf_net2)
        targ_qf_net2 = QNet(self.env.ob_space, self.env.ac_space)
        targ_qf_net2.load_state_dict(qf_net2.state_dict())
        targ_qf2 = DeterministicSAVfunc(self.env.ob_space, self.env.ac_space,
                                        targ_qf_net2)

        qfs = [qf1, qf2]
        targ_qfs = [targ_qf1, targ_qf2]

        log_alpha = nn.Parameter(torch.zeros(()))

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_pol = torch.optim.Adam(pol_net.parameters(), 3e-4)
        optim_qf1 = torch.optim.Adam(qf_net1.parameters(), 3e-4)
        optim_qf2 = torch.optim.Adam(qf_net2.parameters(), 3e-4)
        optim_qfs = [optim_qf1, optim_qf2]
        optim_alpha = torch.optim.Adam([log_alpha], 3e-4)

        epis = sampler.sample(pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)

        traj = ef.add_next_obs(traj)
        traj.register_epis()

        result_dict = sac.train(
            traj,
            pol,
            qfs,
            targ_qfs,
            log_alpha,
            optim_pol,
            optim_qfs,
            optim_alpha,
            2,
            32,
            0.01,
            0.99,
            2,
        )

        del sampler
Example #10
    def test_learning(self):
        pol_net = PolNet(self.env.observation_space,
                         self.env.action_space,
                         h1=32,
                         h2=32)
        pol = GaussianPol(self.env.observation_space, self.env.action_space,
                          pol_net)

        vf_net = VNet(self.env.observation_space)
        vf = DeterministicSVfunc(self.env.observation_space, vf_net)

        discrim_net = DiscrimNet(self.env.observation_space,
                                 self.env.action_space,
                                 h1=32,
                                 h2=32)
        discrim = DeterministicSAVfunc(self.env.observation_space,
                                       self.env.action_space, discrim_net)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_vf = torch.optim.Adam(vf_net.parameters(), 3e-4)
        optim_discrim = torch.optim.Adam(discrim_net.parameters(), 3e-4)

        with open(os.path.join('data/expert_epis', 'Pendulum-v0_2epis.pkl'),
                  'rb') as f:
            expert_epis = pickle.load(f)
        expert_traj = Traj()
        expert_traj.add_epis(expert_epis)
        expert_traj.register_epis()

        epis = sampler.sample(pol, max_steps=32)

        agent_traj = Traj()
        agent_traj.add_epis(epis)
        agent_traj = ef.compute_pseudo_rews(agent_traj, discrim)
        agent_traj = ef.compute_vs(agent_traj, vf)
        agent_traj = ef.compute_rets(agent_traj, 0.99)
        agent_traj = ef.compute_advs(agent_traj, 0.99, 0.95)
        agent_traj = ef.centerize_advs(agent_traj)
        agent_traj = ef.compute_h_masks(agent_traj)
        agent_traj.register_epis()

        result_dict = gail.train(agent_traj,
                                 expert_traj,
                                 pol,
                                 vf,
                                 discrim,
                                 optim_vf,
                                 optim_discrim,
                                 rl_type='trpo',
                                 epoch=1,
                                 batch_size=32,
                                 discrim_batch_size=32,
                                 discrim_step=1,
                                 pol_ent_beta=1e-3,
                                 discrim_ent_beta=1e-5)

        del sampler
Example #11
    def setUpClass(cls):
        cls.env = GymEnv('Pendulum-v0')
        pol = RandomPol(cls.env.observation_space, cls.env.action_space)
        sampler = EpiSampler(cls.env, pol, num_parallel=1)
        epis = sampler.sample(pol, max_steps=32)

        cls.traj = Traj()
        cls.traj.add_epis(epis)
        cls.traj.register_epis()
Example #12
    def test_learning_rnn(self):
        def rew_func(next_obs,
                     acs,
                     mean_obs=0.,
                     std_obs=1.,
                     mean_acs=0.,
                     std_acs=1.):
            next_obs = next_obs * std_obs + mean_obs
            acs = acs * std_acs + mean_acs
            # Pendulum
            rews = -(torch.acos(next_obs[:, 0].clamp(min=-1, max=1))**2 + 0.1 *
                     (next_obs[:, 2].clamp(min=-8, max=8)**2) +
                     0.001 * acs.squeeze(-1)**2)
            rews = rews.squeeze(0)

            return rews

        # init models
        dm_net = ModelNetLSTM(self.env.observation_space,
                              self.env.action_space)
        dm = DeterministicSModel(self.env.observation_space,
                                 self.env.action_space,
                                 dm_net,
                                 rnn=True,
                                 data_parallel=False,
                                 parallel_dim=0)

        mpc_pol = MPCPol(self.env.observation_space,
                         self.env.action_space,
                         dm_net,
                         rew_func,
                         1,
                         1,
                         mean_obs=0.,
                         std_obs=1.,
                         mean_acs=0.,
                         std_acs=1.,
                         rnn=True)
        optim_dm = torch.optim.Adam(dm_net.parameters(), 1e-3)

        # sample with mpc policy
        sampler = EpiSampler(self.env, mpc_pol, num_parallel=1)
        epis = sampler.sample(mpc_pol, max_epis=1)

        traj = Traj()
        traj.add_epis(epis)
        traj = ef.add_next_obs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()
        traj.add_traj(traj)

        # train
        result_dict = mpc.train_dm(traj, dm, optim_dm, epoch=1, batch_size=1)

        del sampler
Example #13
    def test_learning(self):
        pol_net = PolNet(self.env.observation_space,
                         self.env.action_space,
                         h1=32,
                         h2=32,
                         deterministic=True)
        noise = OUActionNoise(self.env.action_space)
        pol = DeterministicActionNoisePol(self.env.observation_space,
                                          self.env.action_space, pol_net,
                                          noise)

        targ_pol_net = PolNet(self.env.observation_space,
                              self.env.action_space,
                              32,
                              32,
                              deterministic=True)
        targ_pol_net.load_state_dict(pol_net.state_dict())
        targ_noise = OUActionNoise(self.env.action_space)
        targ_pol = DeterministicActionNoisePol(self.env.observation_space,
                                               self.env.action_space,
                                               targ_pol_net, targ_noise)

        qf_net = QNet(self.env.observation_space,
                      self.env.action_space,
                      h1=32,
                      h2=32)
        qf = DeterministicSAVfunc(self.env.observation_space,
                                  self.env.action_space, qf_net)

        targ_qf_net = QNet(self.env.observation_space, self.env.action_space,
                           32, 32)
        targ_qf_net.load_state_dict(qf_net.state_dict())
        targ_qf = DeterministicSAVfunc(self.env.observation_space,
                                       self.env.action_space, targ_qf_net)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_pol = torch.optim.Adam(pol_net.parameters(), 3e-4)
        optim_qf = torch.optim.Adam(qf_net.parameters(), 3e-4)

        epis = sampler.sample(pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)

        traj = ef.add_next_obs(traj)
        traj.register_epis()

        result_dict = ddpg.train(traj, pol, targ_pol, qf, targ_qf, optim_pol,
                                 optim_qf, 1, 32, 0.01, 0.9)

        del sampler
Example #14
    def test_learning(self):
        pol_net = PolNet(self.env.ob_space, self.env.ac_space, h1=32, h2=32)
        pol = GaussianPol(self.env.ob_space, self.env.ac_space, pol_net)

        vf_net = VNet(self.env.ob_space)
        vf = DeterministicSVfunc(self.env.ob_space, vf_net)

        rewf_net = VNet(self.env.ob_space, h1=32, h2=32)
        rewf = DeterministicSVfunc(self.env.ob_space, rewf_net)
        shaping_vf_net = VNet(self.env.ob_space, h1=32, h2=32)
        shaping_vf = DeterministicSVfunc(self.env.ob_space, shaping_vf_net)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_vf = torch.optim.Adam(vf_net.parameters(), 3e-4)
        optim_discrim = torch.optim.Adam(
            list(rewf_net.parameters()) + list(shaping_vf_net.parameters()), 3e-4)

        with open(os.path.join('data/expert_epis', 'Pendulum-v0_2epis.pkl'), 'rb') as f:
            expert_epis = pickle.load(f)
        expert_traj = Traj()
        expert_traj.add_epis(expert_epis)
        expert_traj = ef.add_next_obs(expert_traj)
        expert_traj.register_epis()

        epis = sampler.sample(pol, max_steps=32)

        agent_traj = Traj()
        agent_traj.add_epis(epis)
        agent_traj = ef.add_next_obs(agent_traj)
        agent_traj = ef.compute_pseudo_rews(
            agent_traj, rew_giver=rewf, state_only=True)
        agent_traj = ef.compute_vs(agent_traj, vf)
        agent_traj = ef.compute_rets(agent_traj, 0.99)
        agent_traj = ef.compute_advs(agent_traj, 0.99, 0.95)
        agent_traj = ef.centerize_advs(agent_traj)
        agent_traj = ef.compute_h_masks(agent_traj)
        agent_traj.register_epis()

        result_dict = airl.train(agent_traj, expert_traj, pol, vf, optim_vf, optim_discrim,
                                 rewf=rewf, shaping_vf=shaping_vf,
                                 rl_type='trpo',
                                 epoch=1,
                                 batch_size=32, discrim_batch_size=32,
                                 discrim_step=1,
                                 pol_ent_beta=1e-3, gamma=0.99)

        del sampler
Example #15
    def test_learning(self):
        pol_net = PolNet(self.env.observation_space,
                         self.env.action_space, h1=32, h2=32)
        pol = GaussianPol(self.env.observation_space,
                          self.env.action_space, pol_net)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_pol = torch.optim.Adam(pol_net.parameters(), 3e-4)

        with open(os.path.join('data/expert_epis', 'Pendulum-v0_2epis.pkl'), 'rb') as f:
            expert_epis = pickle.load(f)
        train_epis, test_epis = ef.train_test_split(
            expert_epis, train_size=0.7)
        train_traj = Traj()
        train_traj.add_epis(train_epis)
        train_traj.register_epis()
        test_traj = Traj()
        test_traj.add_epis(test_epis)
        test_traj.register_epis()

        result_dict = behavior_clone.train(
            train_traj, pol, optim_pol,
            256
        )

        del sampler
Example #16
    def test_learning(self):
        qf_net = QNet(self.env.observation_space, self.env.action_space, 32,
                      32)
        lagged_qf_net = QNet(self.env.observation_space, self.env.action_space,
                             32, 32)
        lagged_qf_net.load_state_dict(qf_net.state_dict())
        targ_qf1_net = QNet(self.env.observation_space, self.env.action_space,
                            32, 32)
        targ_qf1_net.load_state_dict(qf_net.state_dict())
        targ_qf2_net = QNet(self.env.observation_space, self.env.action_space,
                            32, 32)
        targ_qf2_net.load_state_dict(lagged_qf_net.state_dict())
        qf = DeterministicSAVfunc(self.env.observation_space,
                                  self.env.action_space, qf_net)
        lagged_qf = DeterministicSAVfunc(self.env.observation_space,
                                         self.env.action_space, lagged_qf_net)
        targ_qf1 = CEMDeterministicSAVfunc(self.env.observation_space,
                                           self.env.action_space,
                                           targ_qf1_net,
                                           num_sampling=60,
                                           num_best_sampling=6,
                                           num_iter=2,
                                           multivari=False)
        targ_qf2 = DeterministicSAVfunc(self.env.observation_space,
                                        self.env.action_space, targ_qf2_net)

        pol = ArgmaxQfPol(self.env.observation_space,
                          self.env.action_space,
                          targ_qf1,
                          eps=0.2)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_qf = torch.optim.Adam(qf_net.parameters(), 3e-4)

        epis = sampler.sample(pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)
        traj = ef.add_next_obs(traj)
        traj.register_epis()

        result_dict = qtopt.train(traj, qf, lagged_qf, targ_qf1, targ_qf2,
                                  optim_qf, 1000, 32, 0.9999, 0.995, 'mse')

        del sampler
Example #17
    def setUpClass(cls):
        env = GymEnv('Pendulum-v0')
        random_pol = RandomPol(env.observation_space, env.action_space)
        sampler = EpiSampler(env, random_pol, num_parallel=1)
        epis = sampler.sample(random_pol, max_steps=32)
        traj = Traj()
        traj.add_epis(epis)
        traj.register_epis()

        cls.num_step = traj.num_step

        make_redis('localhost', '6379')
        cls.r = get_redis()

        cls.r.set('env', env)
        cls.r.set('traj', traj)

        pol_net = PolNet(env.observation_space, env.action_space)
        gpol = GaussianPol(env.observation_space, env.action_space, pol_net)
        pol_net = PolNet(env.observation_space,
                         env.action_space, deterministic=True)
        dpol = DeterministicActionNoisePol(
            env.observation_space, env.action_space, pol_net)
        model_net = ModelNet(env.observation_space, env.action_space)
        mpcpol = MPCPol(env.observation_space,
                        env.action_space, model_net, rew_func)
        q_net = QNet(env.observation_space, env.action_space)
        qfunc = DeterministicSAVfunc(
            env.observation_space, env.action_space, q_net)
        aqpol = ArgmaxQfPol(env.observation_space, env.action_space, qfunc)
        v_net = VNet(env.observation_space)
        vfunc = DeterministicSVfunc(env.observation_space, v_net)

        cls.r.set('gpol', cloudpickle.dumps(gpol))
        cls.r.set('dpol', cloudpickle.dumps(dpol))
        cls.r.set('mpcpol', cloudpickle.dumps(mpcpol))
        cls.r.set('qfunc', cloudpickle.dumps(qfunc))
        cls.r.set('aqpol', cloudpickle.dumps(aqpol))
        cls.r.set('vfunc', cloudpickle.dumps(vfunc))

        c2d = C2DEnv(env)
        pol_net = PolNet(c2d.observation_space, c2d.action_space)
        mcpol = MultiCategoricalPol(
            env.observation_space, env.action_space, pol_net)

        cls.r.set('mcpol', cloudpickle.dumps(mcpol))
Example #18
    def __init__(self,
                 world_size,
                 rank=-1,
                 env=None,
                 pol=None,
                 num_parallel=8,
                 prepro=None,
                 seed=256,
                 flush_db=False):
        if rank < 0:
            assert env is not None and pol is not None

        self.world_size = world_size
        self.rank = rank

        self.r = get_redis()

        if flush_db:
            # reset DB
            keys = self.r.keys(pattern="*_trigger_*")
            if keys:
                self.r.delete(*keys)

        if rank < 0:
            self.env = env
            self.pol = pol
            self.num_parallel = num_parallel // world_size
            self.prepro = prepro
            self.seed = seed

            self.original_num_parallel = num_parallel

        self.scatter_from_master('env')
        self.scatter_from_master('pol')
        self.scatter_from_master('num_parallel')
        self.scatter_from_master('prepro')
        self.scatter_from_master('seed')

        self.seed = self.seed * (self.rank + 23000)

        if not rank < 0:
            self.in_node_sampler = EpiSampler(self.env, self.pol,
                                              self.num_parallel, self.prepro,
                                              self.seed)
            self.launch_sampler()
Example #19
    env = C2DEnv(env)

ob_space = env.observation_space
ac_space = env.action_space

random_pol = RandomPol(ob_space, ac_space)

######################
### Model-Based RL ###
######################

### Prepare the dataset D_RAND ###

# Performing rollouts to collect training data
rand_sampler = EpiSampler(env,
                          random_pol,
                          num_parallel=args.num_parallel,
                          seed=args.seed)

epis = rand_sampler.sample(random_pol, max_epis=args.num_random_rollouts)
epis = add_noise_to_init_obs(epis, args.noise_to_init_obs)
traj = Traj(traj_device='cpu')
traj.add_epis(epis)
traj = ef.add_next_obs(traj)
traj = ef.compute_h_masks(traj)
# obs, next_obs, and acs should become mean 0, std 1
traj, mean_obs, std_obs, mean_acs, std_acs = ef.normalize_obs_and_acs(traj)
traj.register_epis()

del rand_sampler

### Train Dynamics Model ###
Example #20
    if args.rnn:
        pol_net = PolNetLSTM(observation_space,
                             action_space,
                             h_size=256,
                             cell_size=256)
    else:
        pol_net = PolNet(observation_space, action_space)
    if isinstance(action_space, gym.spaces.Box):
        pol = GaussianPol(observation_space, action_space, pol_net, args.rnn)
    elif isinstance(action_space, gym.spaces.Discrete):
        pol = CategoricalPol(observation_space, action_space, pol_net,
                             args.rnn)
    elif isinstance(action_space, gym.spaces.MultiDiscrete):
        pol = MultiCategoricalPol(observation_space, action_space, pol_net,
                                  args.rnn)
    else:
        raise ValueError('Only Box, Discrete, and MultiDiscrete are supported')

sampler = EpiSampler(env, pol, num_parallel=1, seed=args.seed)

with open(os.path.join(args.pol_dir, 'models', args.pol_fname), 'rb') as f:
    pol.load_state_dict(
        torch.load(f, map_location=lambda storage, location: storage))

epis = sampler.sample(pol, max_epis=args.num_epis)

rewards = [np.sum(epi['rews']) for epi in epis]
mean_rew = np.mean(rewards)
logger.log('score={}'.format(mean_rew))
del sampler
Example #21
    def test_epi_sampler(self):
        sampler = EpiSampler(self.env, self.pol, num_parallel=1)
        epis = sampler.sample(self.pol, max_epis=2)
        assert len(epis) >= 2
Example #22
class DistributedEpiSampler(object):
    """
    A sampler which samples episodes.

    Parameters
    ----------
    world_size : int
        Number of nodes
    rank : int
        -1 represents the master node.
    env : gym.Env
    pol : Pol
    num_parallel : int
        Number of processes
    prepro : Prepro
    seed : int
    """

    def __init__(self, world_size, rank=-1, env=None, pol=None, num_parallel=8, prepro=None, seed=256):
        if rank < 0:
            assert env is not None and pol is not None

        self.world_size = world_size
        self.rank = rank

        self.r = get_redis()

        if rank < 0:
            self.env = env
            self.pol = pol
            self.num_parallel = num_parallel // world_size
            self.prepro = prepro
            self.seed = seed

            self.original_num_parallel = num_parallel

        self.scatter_from_master('env')
        self.scatter_from_master('pol')
        self.scatter_from_master('num_parallel')
        self.scatter_from_master('prepro')
        self.scatter_from_master('seed')

        self.seed = self.seed * (self.rank + 23000)

        if not rank < 0:
            self.in_node_sampler = EpiSampler(
                self.env, self.pol, self.num_parallel, self.prepro, self.seed)
            self.launch_sampler()

    def __del__(self):
        if not self.rank < 0:
            del self.in_node_sampler

    def launch_sampler(self):
        while True:
            self.scatter_from_master('pol')
            self.scatter_from_master('max_epis')
            self.scatter_from_master('max_steps')
            self.scatter_from_master('deterministic')

            self.epis = self.in_node_sampler.sample(
                self.pol, self.max_epis, self.max_steps, self.deterministic)

            self.gather_to_master('epis')

    def scatter_from_master(self, key):

        if self.rank < 0:
            obj = getattr(self, key)
            self.r.set(key, cloudpickle.dumps(obj))
            triggers = {key + '_trigger' +
                        "_{}".format(rank): '1' for rank in range(self.world_size)}
            self.r.mset(triggers)
            while True:
                time.sleep(0.1)
                values = self.r.mget(triggers)
                if all([_int(v) == 0 for v in values]):
                    break
        else:
            while True:
                time.sleep(0.1)
                trigger = self.r.get(key + '_trigger' +
                                     "_{}".format(self.rank))
                if _int(trigger) == 1:
                    break
            obj = cloudpickle.loads(self.r.get(key))
            setattr(self, key, obj)
            self.r.set(key + '_trigger' + "_{}".format(self.rank), '0')

    def gather_to_master(self, key):
        """
        This method assumes that the gathered objects can be concatenated into a single list.
        """

        if self.rank < 0:
            num_done = 0
            objs = []
            while True:
                time.sleep(0.1)
                # This loop over ranks could be made faster.
                for rank in range(self.world_size):
                    trigger = self.r.get(key + '_trigger' + "_{}".format(rank))
                    if _int(trigger) == 1:
                        obj = cloudpickle.loads(
                            self.r.get(key + "_{}".format(rank)))
                        objs += obj
                        self.r.set(key + '_trigger' + "_{}".format(rank), '0')
                        num_done += 1
                if num_done == self.world_size:
                    break
            setattr(self, key, objs)
        else:
            obj = getattr(self, key)
            self.r.set(key + "_{}".format(self.rank), cloudpickle.dumps(obj))
            self.r.set(key + '_trigger' + "_{}".format(self.rank), '1')
            while True:
                time.sleep(0.1)
                if _int(self.r.get(key + '_trigger' + "_{}".format(self.rank))) == 0:
                    break

    def sample(self, pol, max_epis=None, max_steps=None, deterministic=False):
        """
        This method should be called on the master node.
        """
        self.pol = pol
        self.max_epis = max_epis // self.world_size if max_epis is not None else None
        self.max_steps = max_steps // self.world_size if max_steps is not None else None
        self.deterministic = deterministic

        self.scatter_from_master('pol')
        self.scatter_from_master('max_epis')
        self.scatter_from_master('max_steps')
        self.scatter_from_master('deterministic')

        self.gather_to_master('epis')

        return self.epis
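
The class above coordinates one in-node EpiSampler per worker through Redis: the master (rank < 0) scatters the environment, policy, and sampling limits, each worker samples locally, and the episode lists are gathered back on the master. Below is a minimal usage sketch of the master side; it is an assumption-laden illustration rather than library documentation, relying on the machina-style GymEnv, RandomPol, and make_redis helpers seen in the other examples here, plus worker processes launched separately with the same world_size and ranks 0..world_size-1.

# Hypothetical sketch (master node only). Workers constructed with rank >= 0
# block inside launch_sampler() and serve sample requests until killed.
make_redis('localhost', '6379')   # every node must point at the same Redis

env = GymEnv('Pendulum-v0')
pol = RandomPol(env.observation_space, env.action_space)

sampler = DistributedEpiSampler(world_size=2, rank=-1, env=env, pol=pol,
                                num_parallel=8, seed=256)
epis = sampler.sample(pol, max_steps=1000)   # episodes gathered from all workers
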
Example #23
class DistributedEpiSampler(object):
    """
    A sampler which samples episodes.

    Parameters
    ----------
    world_size : int
        Number of nodes
    rank : int
        -1 represents the master node.
    env : gym.Env
    pol : Pol
    num_parallel : int
        Number of processes
    prepro : Prepro
    seed : int
    """
    def __init__(self,
                 world_size,
                 rank=-1,
                 env=None,
                 pol=None,
                 num_parallel=8,
                 prepro=None,
                 seed=256,
                 flush_db=False):
        if rank < 0:
            assert env is not None and pol is not None

        self.world_size = world_size
        self.rank = rank

        self.r = get_redis()

        if flush_db:
            # reset DB
            keys = self.r.keys(pattern="*_trigger_*")
            if keys:
                self.r.delete(*keys)

        if rank < 0:
            self.env = env
            self.pol = pol
            self.num_parallel = num_parallel // world_size
            self.prepro = prepro
            self.seed = seed

            self.original_num_parallel = num_parallel

        self.scatter_from_master('env')
        self.scatter_from_master('pol')
        self.scatter_from_master('num_parallel')
        self.scatter_from_master('prepro')
        self.scatter_from_master('seed')

        self.seed = self.seed * (self.rank + 23000)

        if not rank < 0:
            self.in_node_sampler = EpiSampler(self.env, self.pol,
                                              self.num_parallel, self.prepro,
                                              self.seed)
            self.launch_sampler()

    def __del__(self):
        if not self.rank < 0:
            del self.in_node_sampler

    def launch_sampler(self):
        while True:
            self.scatter_from_master('pol')
            self.scatter_from_master('max_epis')
            self.scatter_from_master('max_steps')
            self.scatter_from_master('deterministic')

            self.epis = self.in_node_sampler.sample(self.pol, self.max_epis,
                                                    self.max_steps,
                                                    self.deterministic)

            self.gather_to_master('epis')

    def sync(self, keys, target_value):
        """Wait until all `keys` become `target_value`
        """
        while True:
            values = self.r.mget(keys)
            if all([_int(v) == target_value for v in values]):
                break
            time.sleep(0.1)

    def wait_trigger(self, trigger):
        """Wait until `trigger` become 1
        """
        self.sync(trigger, 1)

    def wait_trigger_processed(self, trigger):
        """Wait until `trigger` become 0
        """
        self.sync(trigger, 0)

    def set_trigger(self, trigger, value='1'):
        """Set all triggers to `value`
        """
        if not isinstance(trigger, (list, tuple)):
            trigger = [trigger]
        mapping = {k: value for k in trigger}
        self.r.mset(mapping)

    def reset_trigger(self, trigger):
        """Set all triggers to 0
        """
        self.set_trigger(trigger, value='0')

    def wait_trigger_completion(self, trigger):
        """Set trigger to 1, then wait until it become 0
        """
        self.set_trigger(trigger)
        self.wait_trigger_processed(trigger)

    def scatter_from_master(self, key):
        """
        master: set `key` in the DB, then set the triggers and wait for the samplers to finish
        sampler: wait for the trigger, then get `key` and reset the trigger
        """

        if self.rank < 0:
            obj = getattr(self, key)
            self.r.set(key, cloudpickle.dumps(obj))
            trigger = [
                '{}_trigger_{}'.format(key, rank)
                for rank in range(self.world_size)
            ]
            self.wait_trigger_completion(trigger)
        else:
            trigger = '{}_trigger_{}'.format(key, self.rank)
            self.wait_trigger(trigger)
            obj = cloudpickle.loads(self.r.get(key))
            setattr(self, key, obj)
            self.reset_trigger(trigger)

    def gather_to_master(self, key):
        """
        master: wait for the triggers, then fetch the values from the DB
        sampler: set `key` in the DB, then wait for the master to fetch it

        This method assumes that the gathered objects can be concatenated into a single list.
        """

        if self.rank < 0:
            num_done = 0
            objs = []
            while True:
                time.sleep(0.1)
                # This loop over ranks could be made faster.
                for rank in range(self.world_size):
                    trigger = self.r.get(key + '_trigger' + "_{}".format(rank))
                    if _int(trigger) == 1:
                        obj = cloudpickle.loads(
                            self.r.get(key + "_{}".format(rank)))
                        objs += obj
                        self.r.set(key + '_trigger' + "_{}".format(rank), '0')
                        num_done += 1
                if num_done == self.world_size:
                    break
            setattr(self, key, objs)
        else:
            obj = getattr(self, key)
            self.r.set(key + "_{}".format(self.rank), cloudpickle.dumps(obj))
            trigger = '{}_trigger_{}'.format(key, self.rank)
            self.wait_trigger_completion(trigger)

    def sample(self, pol, max_epis=None, max_steps=None, deterministic=False):
        """
        This method should be called on the master node.
        """
        self.pol = pol
        self.max_epis = max_epis // self.world_size if max_epis is not None else None
        self.max_steps = max_steps // self.world_size if max_steps is not None else None
        self.deterministic = deterministic

        self.scatter_from_master('pol')
        self.scatter_from_master('max_epis')
        self.scatter_from_master('max_steps')
        self.scatter_from_master('deterministic')

        self.gather_to_master('epis')

        return self.epis
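
This variant factors the busy-wait handshake into the sync/wait_trigger/set_trigger helpers. The snippet below replays that scatter handshake with an in-memory dict standing in for the Redis client, purely to illustrate the key/trigger protocol; the key naming mirrors the class above, everything else is an illustrative assumption.

# Illustration only: the scatter handshake with a dict in place of Redis.
store = {}

def master_scatter(key, obj, world_size):
    store[key] = obj                                   # r.set(key, cloudpickle.dumps(obj))
    for rank in range(world_size):                     # raise every worker's trigger
        store['{}_trigger_{}'.format(key, rank)] = '1'
    # the real master then polls until every trigger is reset to '0'

def worker_receive(key, rank):
    trigger = '{}_trigger_{}'.format(key, rank)
    assert store[trigger] == '1'                       # the real worker polls for this
    obj = store[key]                                   # cloudpickle.loads(r.get(key))
    store[trigger] = '0'                               # acknowledge receipt
    return obj

master_scatter('seed', 256, world_size=2)
print(worker_receive('seed', rank=0))                  # -> 256
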
Example #24
    def test_learning(self):
        pol_net = PolNetLSTM(
            self.env.observation_space, self.env.action_space, h_size=32, cell_size=32)
        pol = GaussianPol(self.env.observation_space,
                          self.env.action_space, pol_net, rnn=True)

        qf_net1 = QNetLSTM(self.env.observation_space,
                           self.env.action_space, h_size=32, cell_size=32)
        qf1 = DeterministicSAVfunc(
            self.env.observation_space, self.env.action_space, qf_net1, rnn=True)
        targ_qf_net1 = QNetLSTM(
            self.env.observation_space, self.env.action_space, h_size=32, cell_size=32)
        targ_qf_net1.load_state_dict(qf_net1.state_dict())
        targ_qf1 = DeterministicSAVfunc(
            self.env.observation_space, self.env.action_space, targ_qf_net1, rnn=True)

        qf_net2 = QNetLSTM(self.env.observation_space,
                           self.env.action_space, h_size=32, cell_size=32)
        qf2 = DeterministicSAVfunc(
            self.env.observation_space, self.env.action_space, qf_net2, rnn=True)
        targ_qf_net2 = QNetLSTM(
            self.env.observation_space, self.env.action_space, h_size=32, cell_size=32)
        targ_qf_net2.load_state_dict(qf_net2.state_dict())
        targ_qf2 = DeterministicSAVfunc(
            self.env.observation_space, self.env.action_space, targ_qf_net2, rnn=True)

        qfs = [qf1, qf2]
        targ_qfs = [targ_qf1, targ_qf2]

        log_alpha = nn.Parameter(torch.zeros(()))

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_pol = torch.optim.Adam(pol_net.parameters(), 3e-4)
        optim_qf1 = torch.optim.Adam(qf_net1.parameters(), 3e-4)
        optim_qf2 = torch.optim.Adam(qf_net2.parameters(), 3e-4)
        optim_qfs = [optim_qf1, optim_qf2]
        optim_alpha = torch.optim.Adam([log_alpha], 3e-4)

        epis = sampler.sample(pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)

        traj = ef.add_next_obs(traj)
        max_pri = traj.get_max_pri()
        traj = ef.set_all_pris(traj, max_pri)
        traj = ef.compute_seq_pris(traj, 4)
        traj = ef.compute_h_masks(traj)
        for i in range(len(qfs)):
            traj = ef.compute_hs(
                traj, qfs[i], hs_name='q_hs'+str(i), input_acs=True)
            traj = ef.compute_hs(
                traj, targ_qfs[i], hs_name='targ_q_hs'+str(i), input_acs=True)
        traj.register_epis()

        result_dict = r2d2_sac.train(
            traj,
            pol, qfs, targ_qfs, log_alpha,
            optim_pol, optim_qfs, optim_alpha,
            2, 32, 4, 2,
            0.01, 0.99, 2,
        )

        del sampler
Example #25
    pol = CategoricalPol(ob_space, ac_space, pol_net, data_parallel=args.data_parallel)
elif isinstance(ac_space, gym.spaces.MultiDiscrete):
    pol = MultiCategoricalPol(ob_space, ac_space, pol_net, data_parallel=args.data_parallel)
else:
    raise ValueError('Only Box, Discrete, and MultiDiscrete are supported')

if args.pol:
    pol.load_state_dict(torch.load(args.pol, map_location=lambda storage, loc: storage))

vf_net = VNetSNAILConstant(ob_space, args.timestep, args.num_channels, num_keys=args.num_keys, num_tc_fils=args.num_tc_fils, no_attention=args.no_attention, use_pe=args.use_pe)
vf = DeterministicSVfunc(ob_space, vf_net, data_parallel=args.data_parallel)

if args.vf:
    vf.load_state_dict(torch.load(args.vf, map_location=lambda storage, loc: storage))

sampler = EpiSampler(env, pol, num_parallel=args.num_parallel, seed=args.seed)
if center_env is not None:
    center_sampler = EpiSampler(center_env, pol, num_parallel=1, seed=args.seed)

optim_pol = torch.optim.Adam(pol_net.parameters(), args.pol_lr)
if args.optim_pol:
    optim_pol.load_state_dict(torch.load(args.optim_pol, map_location=lambda storage, loc: storage))
optim_vf = torch.optim.Adam(vf_net.parameters(), args.vf_lr)
if args.optim_vf:
    optim_vf.load_state_dict(torch.load(args.optim_vf, map_location=lambda storage, loc: storage))

total_epi = 0
total_step = 0
max_mean_rew = -1e6
max_min_rew = -1e6
max_center_rew = -1e6
Example #26
    def train(self):
        args = self.args

        # TODO: cuda seems to be broken, I don't care about it right now
        # if args.cuda:
        #     # current_obs = current_obs.cuda()
        #     rollouts.cuda()

        self.train_start_time = time.time()
        total_epi = 0
        total_step = 0
        max_rew = -1e6
        sampler = None

        score_file = os.path.join(self.logger.get_logdir(), "progress.csv")
        logger.add_tabular_output(score_file)

        num_total_frames = args.num_total_frames

        mirror_function = None
        if args.mirror_tuples and hasattr(self.env.unwrapped,
                                          "mirror_indices"):
            mirror_function = get_mirror_function(
                **self.env.unwrapped.mirror_indices)
            num_total_frames *= 2
            if not args.tanh_finish:
                warnings.warn(
                    "When `mirror_tuples` is `True`,"
                    " `tanh_finish` should be set to `True` as well."
                    " Otherwise there is a chance of the training blowing up.")

        while num_total_frames > total_step:
            # setup the correct curriculum learning environment/parameters
            new_curriculum = self.curriculum_handler(total_step /
                                                     args.num_total_frames)

            if total_step == 0 or new_curriculum:
                if sampler is not None:
                    del sampler
                sampler = EpiSampler(
                    self.env,
                    self.pol,
                    num_parallel=self.args.num_processes,
                    seed=self.args.seed + total_step,  # TODO: better fix?
                )

            with measure("sample"):
                epis = sampler.sample(self.pol,
                                      max_steps=args.num_steps *
                                      args.num_processes)

            with measure("train"):
                with measure("epis"):
                    traj = Traj()
                    traj.add_epis(epis)

                    traj = ef.compute_vs(traj, self.vf)
                    traj = ef.compute_rets(traj, args.decay_gamma)
                    traj = ef.compute_advs(traj, args.decay_gamma,
                                           args.gae_lambda)
                    traj = ef.centerize_advs(traj)
                    traj = ef.compute_h_masks(traj)
                    traj.register_epis()

                    if mirror_function:
                        traj.add_traj(mirror_function(traj))

                # if args.data_parallel:
                #     self.pol.dp_run = True
                #     self.vf.dp_run = True

                result_dict = ppo_clip.train(
                    traj=traj,
                    pol=self.pol,
                    vf=self.vf,
                    clip_param=args.clip_eps,
                    optim_pol=self.optim_pol,
                    optim_vf=self.optim_vf,
                    epoch=args.epoch_per_iter,
                    batch_size=args.batch_size
                    if not args.rnn else args.rnn_batch_size,
                    max_grad_norm=args.max_grad_norm,
                )

                # if args.data_parallel:
                #     self.pol.dp_run = False
                #     self.vf.dp_run = False

            # append the metrics to `result_dict` (reported in progress.csv)
            result_dict.update(self.get_extra_metrics(epis))

            total_epi += traj.num_epi
            step = traj.num_step
            total_step += step
            rewards = [np.sum(epi["rews"]) for epi in epis]
            mean_rew = np.mean(rewards)
            logger.record_results(
                self.logger.get_logdir(),
                result_dict,
                score_file,
                total_epi,
                step,
                total_step,
                rewards,
                plot_title=args.env,
            )

            if mean_rew > max_rew:
                self.save_models("max")
                max_rew = mean_rew

            self.save_models("last")

            self.scheduler_pol.step()
            self.scheduler_vf.step()

            del traj
Example #27
pol = MultiCategoricalPol(ob_space,
                          ac_space,
                          pol_net,
                          True,
                          data_parallel=args.data_parallel,
                          parallel_dim=1)

vf_net = VNetLSTM(ob_space, h_size=args.h_size, cell_size=args.cell_size)
vf = DeterministicSVfunc(ob_space,
                         vf_net,
                         True,
                         data_parallel=args.data_parallel,
                         parallel_dim=1)

sampler1 = EpiSampler(env1,
                      pol,
                      num_parallel=args.num_parallel,
                      seed=args.seed)
sampler2 = EpiSampler(env2,
                      pol,
                      num_parallel=args.num_parallel,
                      seed=args.seed)

optim_pol = torch.optim.Adam(pol_net.parameters(), args.pol_lr)
optim_vf = torch.optim.Adam(vf_net.parameters(), args.vf_lr)

total_epi = 0
total_step = 0
max_rew = -1e6
while args.max_epis > total_epi:
    with measure('sample'):
        epis1 = sampler1.sample(pol, max_epis=args.max_epis_per_iter)
Example #28
else:
    raise ValueError('Only Box, Discrete, and MultiDiscrete are supported')

if args.teacher_pol:
    t_pol.load_state_dict(torch.load(
        os.path.join(args.teacher_dir, args.teacher_fname)))

if args.rnn:
    s_vf_net = VNetLSTM(observation_space, h_size=256, cell_size=256)
else:
    s_vf_net = VNet(observation_space)

if args.sampling_policy == 'teacher':
    teacher_sampler = EpiSampler(
        env,
        t_pol,
        num_parallel=args.num_parallel,
        seed=args.seed)

student_sampler = EpiSampler(
    env,
    s_pol,
    num_parallel=args.num_parallel,
    seed=args.seed)

optim_pol = torch.optim.Adam(s_pol_net.parameters(), args.pol_lr)

total_epi = 0
total_step = 0
max_rew = -1e6
Example #29
                                data_parallel=args.data_parallel)

# optimizer for the Q-network
print('optimizer')
optim_qf = torch.optim.Adam(qf_net.parameters(), args.qf_lr)

# epsilon-greedy policy
print('Policy')
pol = ArgmaxQfPol(flattend_observation_space,
                  action_space,
                  targ_qf1,
                  eps=args.eps)

# sampling from the replay buffer?
print('sampler')
sampler = EpiSampler(env, pol, num_parallel=args.num_parallel, seed=args.seed)

# off-policy experience. Traj=(s,a,r,s')
off_traj = Traj(args.max_steps_off, traj_device='cpu')

total_epi = 0
total_step = 0
total_grad_step = 0  # number of parameter updates
num_update_lagged = 0  # number of lagged-net updates
max_rew = -1000

print('start')
while args.max_epis > total_epi:
    with measure('sample'):
        print('sampling')
        # act according to the policy and collect experience (env.step is called inside one_epi's __init__)
Example #30
vf = DeterministicSVfunc(observation_space, vf_net)

# optimizer to both models
optim_pol = torch.optim.Adam(pol_net.parameters(), lr=1e-4)
optim_vf = torch.optim.Adam(vf_net.parameters(), lr=3e-4)

#  arguments of PPO
gamma = 0.99
lam = 0.95
clip_param = 0.2
epoch_per_iter = 4
batch_size = 64
max_grad_norm = 0.5
num_parallel = 16

sampler = EpiSampler(env, pol, num_parallel=num_parallel, seed=seed)

# machina automatically write log (model ,scores, etc..)
if not os.path.exists(log_dir_name):
    os.mkdir(log_dir_name)
if not os.path.exists(f'{log_dir_name}/models'):
    os.mkdir(f'{log_dir_name}/models')
score_file = os.path.join(log_dir_name, 'progress.csv')
logger.add_tabular_output(score_file)

# counter and record for loop
total_epi = 0
total_step = 0
max_rew = -inf
max_episodes = 1000000
max_steps_per_iter = 3000