Example #1
    def test_learning(self):
        pol_net = PolNet(self.env.ob_space, self.env.ac_space, h1=32, h2=32)
        pol = GaussianPol(self.env.ob_space, self.env.ac_space, pol_net)

        # Two Q-functions, each paired with a target network initialized as a
        # copy of its online counterpart (clipped double-Q, as in SAC).
        qf_net1 = QNet(self.env.ob_space, self.env.ac_space)
        qf1 = DeterministicSAVfunc(self.env.ob_space, self.env.ac_space,
                                   qf_net1)
        targ_qf_net1 = QNet(self.env.ob_space, self.env.ac_space)
        targ_qf_net1.load_state_dict(qf_net1.state_dict())
        targ_qf1 = DeterministicSAVfunc(self.env.ob_space, self.env.ac_space,
                                        targ_qf_net1)

        qf_net2 = QNet(self.env.ob_space, self.env.ac_space)
        qf2 = DeterministicSAVfunc(self.env.ob_space, self.env.ac_space,
                                   qf_net2)
        targ_qf_net2 = QNet(self.env.ob_space, self.env.ac_space)
        targ_qf_net2.load_state_dict(qf_net2.state_dict())
        targ_qf2 = DeterministicSAVfunc(self.env.ob_space, self.env.ac_space,
                                        targ_qf_net2)

        qfs = [qf1, qf2]
        targ_qfs = [targ_qf1, targ_qf2]

        log_alpha = nn.Parameter(torch.zeros(()))

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_pol = torch.optim.Adam(pol_net.parameters(), 3e-4)
        optim_qf1 = torch.optim.Adam(qf_net1.parameters(), 3e-4)
        optim_qf2 = torch.optim.Adam(qf_net2.parameters(), 3e-4)
        optim_qfs = [optim_qf1, optim_qf2]
        optim_alpha = torch.optim.Adam([log_alpha], 3e-4)

        epis = sampler.sample(pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)

        # Attach next-observation entries to each episode and register them so
        # the trajectory can be sampled for off-policy updates.
        traj = ef.add_next_obs(traj)
        traj.register_epis()

        result_dict = sac.train(
            traj,
            pol,
            qfs,
            targ_qfs,
            log_alpha,
            optim_pol,
            optim_qfs,
            optim_alpha,
            2,      # epoch: number of update iterations
            32,     # batch_size
            0.01,   # tau
            0.99,   # gamma
            2,      # sampling
        )

        del sampler
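
The test above omits its imports. Below is a minimal sketch of what a standalone version would likely need, assuming machina's package layout and a local simple_net helper providing PolNet and QNet (both module paths are assumptions, not shown in the original snippet):

    import torch
    import torch.nn as nn

    from machina.pols import GaussianPol
    from machina.algos import sac
    from machina.vfuncs import DeterministicSAVfunc
    from machina.traj import Traj
    from machina.traj import epi_functional as ef
    from machina.samplers import EpiSampler

    from simple_net import PolNet, QNet  # assumed local test helper
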
Example #2
        on_traj.register_epis()

        off_traj.add_traj(on_traj)

        total_epi += on_traj.num_epi
        step = on_traj.num_step
        total_step += step

        # When data-parallel execution is requested, have the policy and
        # Q-functions run their networks under DataParallel for the update.
        if args.data_parallel:
            pol.dp_run = True
            for qf, targ_qf in zip(qfs, targ_qfs):
                qf.dp_run = True
                targ_qf.dp_run = True

        result_dict = sac.train(off_traj, pol, qfs, targ_qfs, log_alpha,
                                optim_pol, optim_qfs, optim_alpha, step,
                                args.batch_size, args.tau, args.gamma,
                                args.sampling, not args.no_reparam)

        if args.data_parallel:
            pol.dp_run = False
            for qf, targ_qf in zip(qfs, targ_qfs):
                qf.dp_run = False
                targ_qf.dp_run = False

    rewards = [np.sum(epi['rews']) for epi in epis]
    mean_rew = np.mean(rewards)
    logger.record_results(args.log,
                          result_dict,
                          score_file,
                          total_epi,
                          step,
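
Example #2 is an excerpt from inside a sampling-and-training loop; off_traj, total_epi, total_step, and sampler are created earlier in that script. A hedged sketch of the skeleton it appears to continue from, reconstructed from the pattern in Example #1 (max_epis and max_steps_per_iter are hypothetical names):

    off_traj = Traj()
    total_epi = 0
    total_step = 0
    while total_epi < max_epis:  # hypothetical stopping criterion
        epis = sampler.sample(pol, max_steps=max_steps_per_iter)
        on_traj = Traj()
        on_traj.add_epis(epis)
        on_traj = ef.add_next_obs(on_traj)
        # ... the excerpt above (register_epis, off_traj.add_traj, sac.train,
        # logging) continues from here ...
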
Example #3
                                      max_grad_norm=args.max_grad_norm)

        total_epi += on_traj.num_epi
        step = on_traj.num_step
        total_step += step

        off_traj.add_traj(on_traj)

        result_dict2 = sac.train(
            off_traj,
            pol,
            [qf],
            [targ_qf],
            log_alpha,
            optim_pol,
            [optim_qf],
            optim_alpha,
            100,
            args.batch_size,
            args.tau,
            args.gamma,
            args.sampling,
        )

    result_dict1.update(result_dict2)

    rewards = [np.sum(epi['rews']) for epi in epis]
    mean_rew = np.mean(rewards)
    logger.record_results(args.log,
                          result_dict1,
                          score_file,
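
In all three examples the small constant passed as tau (0.01 in Example #1, args.tau in the others) sets the rate of the soft update applied to the target Q-networks inside sac.train. Conceptually that update is a Polyak average; a sketch of the standard rule, not machina's literal implementation:

    # Soft (Polyak) target update with coefficient tau, applied to each
    # Q-network / target-network pair after the gradient steps.
    for p, targ_p in zip(qf_net1.parameters(), targ_qf_net1.parameters()):
        targ_p.data.copy_(tau * p.data + (1.0 - tau) * targ_p.data)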