Exemplo n.º 1
0
# Logging: tabular progress goes to <args.log>/progress.csv and TensorBoard
# output is written under args.log.
score_file = os.path.join(args.log, 'progress.csv')
logger.add_tabular_output(score_file)
logger.add_tensorboard_output(args.log)

# Environment: GymEnv gets <args.log>/movie as its log_dir; video recording is
# toggled by args.record.
env = GymEnv(args.env_name, log_dir=os.path.join(
    args.log, 'movie'), record_video=args.record)
# Seeds the object stored in env.env (presumably the wrapped gym environment
# -- confirm against GymEnv).
env.env.seed(args.seed)

observation_space = env.observation_space
action_space = env.action_space

# Recurrent (LSTM) Gaussian policy; rnn=True marks the policy as recurrent.
pol_net = PolNetLSTM(observation_space, action_space)
pol = GaussianPol(observation_space, action_space, pol_net, rnn=True)

# First recurrent state-action value function plus a target copy whose weights
# start identical to the online network (via load_state_dict).
# NOTE(review): keep construction order unchanged -- if these nets draw random
# initial weights, reordering would change every run's initialisation.
qf_net1 = QNetLSTM(observation_space, action_space)
qf1 = DeterministicSAVfunc(observation_space, action_space, qf_net1, rnn=True)
targ_qf_net1 = QNetLSTM(observation_space, action_space)
targ_qf_net1.load_state_dict(qf_net1.state_dict())
targ_qf1 = DeterministicSAVfunc(
    observation_space, action_space, targ_qf_net1, rnn=True)

# Second Q-function / target pair, built the same way as the first
# (presumably for SAC-style twin-Q training; the training call is not in view).
qf_net2 = QNetLSTM(observation_space, action_space)
qf2 = DeterministicSAVfunc(observation_space, action_space, qf_net2, rnn=True)
targ_qf_net2 = QNetLSTM(observation_space, action_space)
targ_qf_net2.load_state_dict(qf_net2.state_dict())
targ_qf2 = DeterministicSAVfunc(
    observation_space, action_space, targ_qf_net2, rnn=True)

# Paired lists: targ_qfs[i] is the target copy of qfs[i].
# The individual qf_net*/targ_qf_net* names are likely also used by code after
# this chunk (e.g. optimizers) -- verify before renaming.
qfs = [qf1, qf2]
targ_qfs = [targ_qf1, targ_qf2]
Exemplo n.º 2
0
    def test_learning(self):
        """Smoke-test one ``r2d2_sac.train`` call with LSTM policy/Q-functions.

        Builds a recurrent Gaussian policy and two recurrent Q-functions, each
        paired with a target copy initialised from its weights. It then
        samples up to 32 steps from ``self.env``, prepares the trajectory
        (next observations, priorities, hidden-state masks, per-Q hidden
        states) and runs ``r2d2_sac.train`` once.

        The test passes if no exception is raised; the training result is not
        inspected. The sampler is released in a ``finally`` clause so it does
        not outlive a failing run.
        """
        ob_space = self.env.observation_space
        ac_space = self.env.action_space
        # Small recurrent nets keep the test fast.
        net_kwargs = dict(h_size=32, cell_size=32)

        pol_net = PolNetLSTM(ob_space, ac_space, **net_kwargs)
        pol = GaussianPol(ob_space, ac_space, pol_net, rnn=True)

        # NOTE: construction order matches the original; if the nets draw
        # random initial weights, reordering would change them.
        qf_net1 = QNetLSTM(ob_space, ac_space, **net_kwargs)
        qf1 = DeterministicSAVfunc(ob_space, ac_space, qf_net1, rnn=True)
        targ_qf_net1 = QNetLSTM(ob_space, ac_space, **net_kwargs)
        # Target starts with exactly the online network's weights.
        targ_qf_net1.load_state_dict(qf_net1.state_dict())
        targ_qf1 = DeterministicSAVfunc(
            ob_space, ac_space, targ_qf_net1, rnn=True)

        qf_net2 = QNetLSTM(ob_space, ac_space, **net_kwargs)
        qf2 = DeterministicSAVfunc(ob_space, ac_space, qf_net2, rnn=True)
        targ_qf_net2 = QNetLSTM(ob_space, ac_space, **net_kwargs)
        targ_qf_net2.load_state_dict(qf_net2.state_dict())
        targ_qf2 = DeterministicSAVfunc(
            ob_space, ac_space, targ_qf_net2, rnn=True)

        # Paired lists: targ_qfs[i] is the target copy of qfs[i].
        qfs = [qf1, qf2]
        targ_qfs = [targ_qf1, targ_qf2]

        # Entropy temperature is learned in log space, starting at log(1) = 0.
        log_alpha = nn.Parameter(torch.zeros(()))

        sampler = EpiSampler(self.env, pol, num_parallel=1)
        try:
            optim_pol = torch.optim.Adam(pol_net.parameters(), 3e-4)
            optim_qf1 = torch.optim.Adam(qf_net1.parameters(), 3e-4)
            optim_qf2 = torch.optim.Adam(qf_net2.parameters(), 3e-4)
            optim_qfs = [optim_qf1, optim_qf2]
            optim_alpha = torch.optim.Adam([log_alpha], 3e-4)

            epis = sampler.sample(pol, max_steps=32)

            traj = Traj()
            traj.add_epis(epis)

            traj = ef.add_next_obs(traj)
            # Every transition starts at the current maximum priority.
            max_pri = traj.get_max_pri()
            traj = ef.set_all_pris(traj, max_pri)
            # NOTE(review): 4 presumably equals the sequence length passed to
            # train below -- keep the two in sync.
            traj = ef.compute_seq_pris(traj, 4)
            traj = ef.compute_h_masks(traj)
            # Store hidden states per Q-function as 'q_hs<i>' / 'targ_q_hs<i>'.
            for i, (qf, targ_qf) in enumerate(zip(qfs, targ_qfs)):
                traj = ef.compute_hs(
                    traj, qf, hs_name='q_hs' + str(i), input_acs=True)
                traj = ef.compute_hs(
                    traj, targ_qf, hs_name='targ_q_hs' + str(i),
                    input_acs=True)
            traj.register_epis()

            # Positional hyper-parameters; presumably (epoch, batch_size,
            # seq_length, burn_in_length) then (tau, gamma, sampling) --
            # confirm against r2d2_sac.train's signature.
            r2d2_sac.train(
                traj,
                pol, qfs, targ_qfs, log_alpha,
                optim_pol, optim_qfs, optim_alpha,
                2, 32, 4, 2,
                0.01, 0.99, 2,
            )
        finally:
            # Drop the sampler (and presumably any workers it owns) even when
            # a step above raises, so it is not kept alive by the traceback.
            del sampler