def test_learning(self):
    """Smoke-test a single DDPG update on freshly sampled episodes.

    Builds a deterministic policy with Ornstein-Uhlenbeck action noise, a
    state-action value function, and target copies of both; samples up to
    32 steps from ``self.env``; converts the episodes into a ``Traj`` with
    next observations; and runs ``ddpg.train`` once. The test passes if no
    exception is raised.
    """
    obs_space = self.env.observation_space
    act_space = self.env.action_space

    # Online policy and its target; the target starts as an exact copy.
    pol_net = PolNet(obs_space, act_space, h1=32, h2=32, deterministic=True)
    noise = OUActionNoise(act_space)
    pol = DeterministicActionNoisePol(obs_space, act_space, pol_net, noise)

    targ_pol_net = PolNet(obs_space, act_space, h1=32, h2=32,
                          deterministic=True)
    targ_pol_net.load_state_dict(pol_net.state_dict())
    targ_noise = OUActionNoise(act_space)
    targ_pol = DeterministicActionNoisePol(
        obs_space, act_space, targ_pol_net, targ_noise)

    # Online Q-function and its target. The target is synced from the
    # online qf_net so both start with identical weights, mirroring the
    # policy/target-policy pair above.
    qf_net = QNet(obs_space, act_space, h1=32, h2=32)
    qf = DeterministicSAVfunc(obs_space, act_space, qf_net)

    targ_qf_net = QNet(obs_space, act_space, h1=32, h2=32)
    targ_qf_net.load_state_dict(qf_net.state_dict())
    targ_qf = DeterministicSAVfunc(obs_space, act_space, targ_qf_net)

    sampler = EpiSampler(self.env, pol, num_parallel=1)
    try:
        optim_pol = torch.optim.Adam(pol_net.parameters(), 3e-4)
        optim_qf = torch.optim.Adam(qf_net.parameters(), 3e-4)

        epis = sampler.sample(pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)
        traj = ef.add_next_obs(traj)
        traj.register_epis()

        # Trailing positional args: 1, batch size 32, tau 0.01, gamma 0.9,
        # in the same order as the training-loop call
        # (step, batch_size, tau, gamma).
        ddpg.train(traj, pol, targ_pol, qf, targ_qf, optim_pol, optim_qf,
                   1, 32, 0.01, 0.9)
    finally:
        # Release the sampler even if sampling/training raises, so its
        # worker(s) can be torn down (presumably via EpiSampler.__del__).
        del sampler
# One iteration of the DDPG training loop.
# NOTE(review): this is a loop body collapsed onto a single line; the loop
# header and the original indentation are not visible here, so the code is
# left untouched. Whether `rewards = ...` onward sits inside the
# `with measure('train')` block cannot be recovered from this line.
#
# Inside the 'train' measurement:
#   - wrap the freshly sampled `epis` in a new Traj, add next observations
#     via ef.add_next_obs, and register the episodes;
#   - append that Traj to the accumulated buffer `off_traj`;
#   - advance `total_epi` by the episode count and `total_step` by `step`
#     (the number of steps in this batch);
#   - run ddpg.train on the whole `off_traj` buffer, passing `step`,
#     args.batch_size, args.tau and args.gamma.
# Then: sum epi['rews'] per episode, take the mean, log everything via
# logger.record_results, and when the mean exceeds `max_rew` save the
# policy and Q-function state dicts to <args.log>/models/pol_max.pkl and
# qf_max.pkl.
# NOTE(review): `max_rew` is not updated in the visible portion — confirm
# it is set to `mean_rew` after saving (presumably in code past this
# chunk); otherwise the "best" checkpoints are overwritten on every
# iteration that beats the initial value.
with measure('train'): on_traj = Traj() on_traj.add_epis(epis) on_traj = ef.add_next_obs(on_traj) on_traj.register_epis() off_traj.add_traj(on_traj) total_epi += on_traj.num_epi step = on_traj.num_step total_step += step result_dict = ddpg.train( off_traj, pol, targ_pol, qf, targ_qf, optim_pol, optim_qf, step, args.batch_size, args.tau, args.gamma ) rewards = [np.sum(epi['rews']) for epi in epis] mean_rew = np.mean(rewards) logger.record_results(args.log, result_dict, score_file, total_epi, step, total_step, rewards, plot_title=args.env_name) if mean_rew > max_rew: torch.save(pol.state_dict(), os.path.join( args.log, 'models', 'pol_max.pkl')) torch.save(qf.state_dict(), os.path.join( args.log, 'models', 'qf_max.pkl'))