def test_learning(self):
    # Teacher and student share the same I/O spaces, but the student net is
    # deliberately smaller than the teacher net.
    t_pol_net = PolNet(self.env.observation_space,
                       self.env.action_space, h1=200, h2=100)
    s_pol_net = PolNet(self.env.observation_space,
                       self.env.action_space, h1=190, h2=90)
    t_pol = GaussianPol(
        self.env.observation_space, self.env.action_space, t_pol_net)
    s_pol = GaussianPol(
        self.env.observation_space, self.env.action_space, s_pol_net)

    student_sampler = EpiSampler(self.env, s_pol, num_parallel=1)

    # Only the student is optimized; the teacher stays fixed.
    optim_pol = torch.optim.Adam(s_pol.parameters(), 3e-4)

    # Collect a short on-policy rollout with the student.
    epis = student_sampler.sample(s_pol, max_steps=32)

    traj = Traj()
    traj.add_epis(epis)
    traj = ef.compute_h_masks(traj)
    traj.register_epis()

    # One distillation epoch on the sampled trajectory.
    result_dict = on_pol_teacher_distill.train(
        traj=traj,
        student_pol=s_pol,
        teacher_pol=t_pol,
        student_optim=optim_pol,
        epoch=1,
        batchsize=32)

    del student_sampler
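
# --- Assumed test harness (sketch; not part of the excerpt above) ---
# test_learning references self.env, so it presumably lives inside a
# unittest.TestCase that attaches an environment. The imports, the GymEnv
# wrapper, the simple_net helper, and the Pendulum-v0 choice below are
# assumptions made only to render the excerpt self-contained.
import unittest

import torch

from simple_net import PolNet  # small MLP policy-net helper (assumed location)
from machina.pols import GaussianPol
from machina.algos import on_pol_teacher_distill
from machina.traj import Traj
from machina.traj import epi_functional as ef
from machina.samplers import EpiSampler
from machina.envs import GymEnv


class TestOnPolTeacherDistill(unittest.TestCase):
    def setUp(self):
        # Any continuous-action env works; a small one keeps the rollout cheap.
        self.env = GymEnv('Pendulum-v0')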
with measure('sample'):
    # Rollouts can come from either the teacher or the student policy.
    if args.sampling_policy == 'teacher':
        epis = teacher_sampler.sample(
            t_pol, max_epis=args.max_epis_per_iter)
    else:
        epis = student_sampler.sample(
            s_pol, max_epis=args.max_epis_per_iter)

with measure('train'):
    traj = Traj()
    traj.add_epis(epis)
    traj = ef.compute_h_masks(traj)
    traj.register_epis()

    # Distill the teacher policy into the student on the sampled trajectory.
    result_dict = on_pol_teacher_distill.train(
        traj=traj,
        student_pol=s_pol,
        teacher_pol=t_pol,
        student_optim=optim_pol,
        epoch=args.epoch_per_iter,
        batchsize=args.batch_size)

logger.log('Testing Student-policy')

# Evaluate the distilled student on its own rollouts.
with measure('sample'):
    epis_measure = student_sampler.sample(
        s_pol, max_epis=args.max_epis_per_iter)

with measure('measure'):
    traj_measure = Traj()
    traj_measure.add_epis(epis_measure)
    traj_measure = ef.compute_h_masks(traj_measure)
    traj_measure.register_epis()
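
# --- Assumed preamble for the loop body above (sketch, not the original script) ---
# The excerpt references an `args` namespace plus teacher/student policies,
# samplers, and an optimizer. A minimal setup consistent with those names is
# sketched here; the flag defaults, the environment, and the teacher-checkpoint
# handling are placeholders rather than the example script's real values.
import argparse

import torch

from simple_net import PolNet  # assumed helper, as above
from machina.pols import GaussianPol
from machina.samplers import EpiSampler
from machina.envs import GymEnv
from machina.utils import measure
from machina import logger

parser = argparse.ArgumentParser()
parser.add_argument('--sampling_policy', type=str, default='teacher',
                    help="'teacher' samples with the teacher, anything else with the student")
parser.add_argument('--max_epis_per_iter', type=int, default=64)
parser.add_argument('--epoch_per_iter', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=256)
args = parser.parse_args()

env = GymEnv('Pendulum-v0')  # placeholder environment

t_pol_net = PolNet(env.observation_space, env.action_space, h1=200, h2=100)
s_pol_net = PolNet(env.observation_space, env.action_space, h1=190, h2=90)
# A real run would restore pretrained teacher weights here, e.g. via
# t_pol_net.load_state_dict(torch.load('<teacher_checkpoint>.pkl')).
t_pol = GaussianPol(env.observation_space, env.action_space, t_pol_net)
s_pol = GaussianPol(env.observation_space, env.action_space, s_pol_net)

teacher_sampler = EpiSampler(env, t_pol, num_parallel=1)
student_sampler = EpiSampler(env, s_pol, num_parallel=1)

# Only the student's parameters are updated during distillation.
optim_pol = torch.optim.Adam(s_pol.parameters(), 3e-4)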