コード例 #1
0
ファイル: run_r2d2_sac.py プロジェクト: yuishihara/machina
            on_traj = ef.compute_hs(
                on_traj, qfs[i], hs_name='q_hs'+str(i), input_acs=True)
            on_traj = ef.compute_hs(
                on_traj, targ_qfs[i], hs_name='targ_q_hs'+str(i), input_acs=True)
        on_traj.register_epis()

        off_traj.add_traj(on_traj)

        total_epi += on_traj.num_epi
        step = on_traj.num_step
        total_step += step

        result_dict = r2d2_sac.train(
            off_traj,
            pol, qfs, targ_qfs, log_alpha,
            optim_pol, optim_qfs, optim_alpha,
            step//50, args.rnn_batch_size, args.seq_length, args.burn_in_length,
            args.tau, args.gamma, args.sampling, not args.no_reparam
        )

    rewards = [np.sum(epi['rews']) for epi in epis]
    mean_rew = np.mean(rewards)
    logger.record_results(args.log, result_dict, score_file,
                          total_epi, step, total_step,
                          rewards,
                          plot_title=args.env_name)

    if mean_rew > max_rew:
        torch.save(pol.state_dict(), os.path.join(
            args.log, 'models', 'pol_max.pkl'))
        torch.save(qf1.state_dict(), os.path.join(
コード例 #2
0
    def test_learning(self):
        pol_net = PolNetLSTM(
            self.env.observation_space, self.env.action_space, h_size=32, cell_size=32)
        pol = GaussianPol(self.env.observation_space,
                          self.env.action_space, pol_net, rnn=True)

        qf_net1 = QNetLSTM(self.env.observation_space,
                           self.env.action_space, h_size=32, cell_size=32)
        qf1 = DeterministicSAVfunc(
            self.env.observation_space, self.env.action_space, qf_net1, rnn=True)
        targ_qf_net1 = QNetLSTM(
            self.env.observation_space, self.env.action_space, h_size=32, cell_size=32)
        targ_qf_net1.load_state_dict(qf_net1.state_dict())
        targ_qf1 = DeterministicSAVfunc(
            self.env.observation_space, self.env.action_space, targ_qf_net1, rnn=True)

        qf_net2 = QNetLSTM(self.env.observation_space,
                           self.env.action_space, h_size=32, cell_size=32)
        qf2 = DeterministicSAVfunc(
            self.env.observation_space, self.env.action_space, qf_net2, rnn=True)
        targ_qf_net2 = QNetLSTM(
            self.env.observation_space, self.env.action_space, h_size=32, cell_size=32)
        targ_qf_net2.load_state_dict(qf_net2.state_dict())
        targ_qf2 = DeterministicSAVfunc(
            self.env.observation_space, self.env.action_space, targ_qf_net2, rnn=True)

        qfs = [qf1, qf2]
        targ_qfs = [targ_qf1, targ_qf2]

        log_alpha = nn.Parameter(torch.zeros(()))

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_pol = torch.optim.Adam(pol_net.parameters(), 3e-4)
        optim_qf1 = torch.optim.Adam(qf_net1.parameters(), 3e-4)
        optim_qf2 = torch.optim.Adam(qf_net2.parameters(), 3e-4)
        optim_qfs = [optim_qf1, optim_qf2]
        optim_alpha = torch.optim.Adam([log_alpha], 3e-4)

        epis = sampler.sample(pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)

        traj = ef.add_next_obs(traj)
        max_pri = traj.get_max_pri()
        traj = ef.set_all_pris(traj, max_pri)
        traj = ef.compute_seq_pris(traj, 4)
        traj = ef.compute_h_masks(traj)
        for i in range(len(qfs)):
            traj = ef.compute_hs(
                traj, qfs[i], hs_name='q_hs'+str(i), input_acs=True)
            traj = ef.compute_hs(
                traj, targ_qfs[i], hs_name='targ_q_hs'+str(i), input_acs=True)
        traj.register_epis()

        result_dict = r2d2_sac.train(
            traj,
            pol, qfs, targ_qfs, log_alpha,
            optim_pol, optim_qfs, optim_alpha,
            2, 32, 4, 2,
            0.01, 0.99, 2,
        )

        del sampler