Example 1
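These snippets come from QT-Opt examples and tests built on the machina reinforcement-learning library. A sketch of the imports they assume, following the module layout used in machina's published example scripts (the QNet / QTOptNet network classes live in example-local modules, so their exact import path is an assumption):

import torch

from machina.algos import qtopt
from machina.vfuncs import DeterministicSAVfunc, CEMDeterministicSAVfunc
from machina.pols import ArgmaxQfPol
from machina.samplers import EpiSampler
from machina.traj import Traj
from machina.traj import epi_functional as ef

from simple_net import QNet  # assumed location of the example network classes
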
    def test_learning(self):
        qf_net = QNet(self.env.observation_space, self.env.action_space, 32,
                      32)
        lagged_qf_net = QNet(self.env.observation_space, self.env.action_space,
                             32, 32)
        lagged_qf_net.load_state_dict(qf_net.state_dict())
        targ_qf1_net = QNet(self.env.observation_space, self.env.action_space,
                            32, 32)
        targ_qf1_net.load_state_dict(qf_net.state_dict())
        targ_qf2_net = QNet(self.env.observation_space, self.env.action_space,
                            32, 32)
        targ_qf2_net.load_state_dict(lagged_qf_net.state_dict())
        qf = DeterministicSAVfunc(self.env.observation_space,
                                  self.env.action_space, qf_net)
        lagged_qf = DeterministicSAVfunc(self.env.observation_space,
                                         self.env.action_space, lagged_qf_net)
        targ_qf1 = CEMDeterministicSAVfunc(self.env.observation_space,
                                           self.env.action_space,
                                           targ_qf1_net,
                                           num_sampling=60,
                                           num_best_sampling=6,
                                           num_iter=2,
                                           multivari=False)
        targ_qf2 = DeterministicSAVfunc(self.env.observation_space,
                                        self.env.action_space, targ_qf2_net)

        pol = ArgmaxQfPol(self.env.observation_space,
                          self.env.action_space,
                          targ_qf1,
                          eps=0.2)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_qf = torch.optim.Adam(qf_net.parameters(), 3e-4)

        epis = sampler.sample(pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)
        traj = ef.add_next_obs(traj)
        traj.register_epis()

        result_dict = qtopt.train(traj, qf, lagged_qf, targ_qf1, targ_qf2,
                                  optim_qf, 1000, 32, 0.9999, 0.995, 'mse')

        del sampler
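
The positional arguments in the qtopt.train call above are easier to read with names. A keyword-style rendering, assuming the parameters are (epoch, batch_size, tau, gamma, loss_type) as in machina's QT-Opt example scripts:

result_dict = qtopt.train(traj, qf, lagged_qf, targ_qf1, targ_qf2, optim_qf,
                          epoch=1000,        # number of training epochs
                          batch_size=32,
                          tau=0.9999,        # Polyak rate for the slow target
                          gamma=0.995,       # discount factor
                          loss_type='mse')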
Example 2
qf_net = QTOptNet(observation_space, action_space)
qf = DeterministicSAVfunc(
    flattend_observation_space,
    action_space,
    qf_net,
    data_parallel=args.data_parallel)  # deterministic state-action value function; lightly reshapes the q-net output

# target Q network theta1
print('target1_net')
targ_qf1_net = QTOptNet(observation_space, action_space)
targ_qf1_net.load_state_dict(qf_net.state_dict())  # load the model weights (copied from qf_net)
targ_qf1 = CEMDeterministicSAVfunc(
    flattend_observation_space,
    action_space,
    targ_qf1_net,
    num_sampling=args.num_sampling,
    num_best_sampling=args.num_best_sampling,
    num_iter=args.num_iter,
    multivari=args.multivari,
    data_parallel=args.data_parallel,
    save_memory=args.save_memory)  # maximizes Q over actions with the Cross-Entropy Method (CEM)

# lagged network
print('lagged_net')
lagged_qf_net = QTOptNet(observation_space, action_space)
lagged_qf_net.load_state_dict(
    qf_net.state_dict())  # load the model weights (copied from qf_net)
lagged_qf = DeterministicSAVfunc(flattend_observation_space,
                                 action_space,
                                 lagged_qf_net,
                                 data_parallel=args.data_parallel)
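
Together these objects implement QT-Opt's clipped double-Q target: targ_qf1 (CEM-maximized) tracks the online network, targ_qf2 wraps a lagged copy, and the Bellman target takes the minimum of the two target values at the CEM-maximizing next action. A schematic sketch of that target computation, not machina's actual internals; rews, dones, q1_next, and q2_next are assumed tensors:

import torch

def bellman_target(rews, dones, q1_next, q2_next, gamma):
    # clipped double-Q target: elementwise minimum of the two target
    # networks' values at the next state, zeroed on terminal transitions
    return rews + gamma * (1.0 - dones) * torch.min(q1_next, q2_next)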
Example 3
action_space = env.action_space

qf_net = QNet(observation_space, action_space, args.h1, args.h2)
lagged_qf_net = QNet(observation_space, action_space, args.h1, args.h2)
lagged_qf_net.load_state_dict(qf_net.state_dict())
targ_qf1_net = QNet(observation_space, action_space, args.h1, args.h2)
targ_qf1_net.load_state_dict(qf_net.state_dict())
targ_qf2_net = QNet(observation_space, action_space, args.h1, args.h2)
targ_qf2_net.load_state_dict(lagged_qf_net.state_dict())
qf = DeterministicSAVfunc(observation_space, action_space, qf_net)
lagged_qf = DeterministicSAVfunc(observation_space, action_space,
                                 lagged_qf_net)
targ_qf1 = CEMDeterministicSAVfunc(observation_space,
                                   action_space,
                                   targ_qf1_net,
                                   num_sampling=args.num_sampling,
                                   num_best_sampling=args.num_best_sampling,
                                   num_iter=args.num_iter,
                                   multivari=args.multivari,
                                   save_memory=args.save_memory)
targ_qf2 = DeterministicSAVfunc(observation_space, action_space, targ_qf2_net)

pol = ArgmaxQfPol(observation_space, action_space, targ_qf1, eps=args.eps)

sampler = EpiSampler(env, pol, num_parallel=args.num_parallel, seed=args.seed)

optim_qf = torch.optim.Adam(qf_net.parameters(), args.qf_lr)

off_traj = Traj(args.max_steps_off, traj_device='cpu')

total_epi = 0
total_step = 0
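
Example 3 stops just before the sampling/training loop that the total_epi and total_step counters feed. A minimal sketch of how the pieces above are typically driven, assuming a loop in the style of machina's example scripts (args.max_epis, args.max_steps_per_iter, args.batch_size, args.tau, and args.gamma are assumed command-line arguments):

while total_epi < args.max_epis:
    # collect fresh episodes with the current epsilon-greedy policy
    epis = sampler.sample(pol, max_steps=args.max_steps_per_iter)

    on_traj = Traj(traj_device='cpu')
    on_traj.add_epis(epis)
    on_traj = ef.add_next_obs(on_traj)
    on_traj.register_epis()
    off_traj.add_traj(on_traj)  # merge into the off-policy replay trajectory

    total_epi += on_traj.num_epi
    step = on_traj.num_step
    total_step += step

    result_dict = qtopt.train(off_traj, qf, lagged_qf, targ_qf1, targ_qf2,
                              optim_qf, step, args.batch_size,
                              args.tau, args.gamma, 'mse')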
Example 4
targ_qf1_net.load_state_dict(qf_net.state_dict())
targ_qf2_net = QNet(ob_space, ac_space, args.h1, args.h2)
targ_qf2_net.load_state_dict(lagged_qf_net.state_dict())
qf = DeterministicSAVfunc(ob_space,
                          ac_space,
                          qf_net,
                          data_parallel=args.data_parallel)
lagged_qf = DeterministicSAVfunc(ob_space,
                                 ac_space,
                                 lagged_qf_net,
                                 data_parallel=args.data_parallel)
targ_qf1 = CEMDeterministicSAVfunc(ob_space,
                                   ac_space,
                                   targ_qf1_net,
                                   num_sampling=args.num_sampling,
                                   num_best_sampling=args.num_best_sampling,
                                   num_iter=args.num_iter,
                                   multivari=args.multivari,
                                   data_parallel=args.data_parallel,
                                   save_memory=args.save_memory)
targ_qf2 = DeterministicSAVfunc(ob_space,
                                ac_space,
                                targ_qf2_net,
                                data_parallel=args.data_parallel)

pol = ArgmaxQfPol(ob_space, ac_space, targ_qf1, eps=args.eps)

sampler = EpiSampler(env, pol, num_parallel=args.num_parallel, seed=args.seed)

optim_qf = torch.optim.Adam(qf_net.parameters(), args.qf_lr)
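
After training, a greedy evaluation policy can be built with the same constructor by disabling exploration; a hypothetical usage, reusing only calls already shown in these examples:

eval_pol = ArgmaxQfPol(ob_space, ac_space, targ_qf1, eps=0)  # no epsilon-greedy noise
eval_sampler = EpiSampler(env, eval_pol, num_parallel=1, seed=args.seed)
eval_epis = eval_sampler.sample(eval_pol, max_steps=1000)  # hypothetical step budget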
Example 5
action_space = env.action_space

# Q-Network
qf_net = QNet(observation_space, action_space, args.h1, args.h2)
qf = DeterministicSAVfunc(
    observation_space, action_space, qf_net,
    data_parallel=args.data_parallel)  # deterministic state-action value function; lightly reshapes the q-net output

# target Q network theta1
targ_qf1_net = QNet(observation_space, action_space, args.h1, args.h2)
targ_qf1_net.load_state_dict(qf_net.state_dict())  # load the model weights (copied from qf_net)
targ_qf1 = CEMDeterministicSAVfunc(
    observation_space,
    action_space,
    targ_qf1_net,
    num_sampling=args.num_sampling,
    num_best_sampling=args.num_best_sampling,
    num_iter=args.num_iter,
    multivari=args.multivari,
    data_parallel=args.data_parallel,
    save_memory=args.save_memory)  # maximizes Q over actions with the Cross-Entropy Method (CEM)

# lagged network
lagged_qf_net = QNet(observation_space, action_space, args.h1, args.h2)
lagged_qf_net.load_state_dict(
    qf_net.state_dict())  # load the model weights (copied from qf_net)
lagged_qf = DeterministicSAVfunc(observation_space,
                                 action_space,
                                 lagged_qf_net,
                                 data_parallel=args.data_parallel)

# target network theta2
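# Example 5 is cut off here; following the pattern of Examples 3 and 4, the
# theta2 setup would presumably continue as (assumed completion):
targ_qf2_net = QNet(observation_space, action_space, args.h1, args.h2)
targ_qf2_net.load_state_dict(lagged_qf_net.state_dict())  # copy weights from the lagged network
targ_qf2 = DeterministicSAVfunc(observation_space,
                                action_space,
                                targ_qf2_net,
                                data_parallel=args.data_parallel)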