Example #1
    def test_model(self, episode=None):  # no exploration noise is added during testing
        if episode is None:
            episode = self.tested_episodes


        # Set models to eval mode
        self.transition_model.eval()
        self.observation_model.eval()
        self.reward_model.eval()
        self.encoder.eval()
        
        # Initialise parallelised test environments
        test_envs = EnvBatcher(ControlSuiteEnv, (self.parms.env_name, self.parms.seed, self.parms.max_episode_length, self.parms.bit_depth), {}, self.parms.test_episodes)
        total_steps = self.parms.max_episode_length // test_envs.action_repeat
        rewards = np.zeros(self.parms.test_episodes)
        
        real_rew = torch.zeros([total_steps, self.parms.test_episodes])
        predicted_rew = torch.zeros([total_steps, self.parms.test_episodes])

        with torch.no_grad():
            observation = test_envs.reset()
            total_rewards = np.zeros((self.parms.test_episodes, ))
            video_frames = []
            belief = torch.zeros(self.parms.test_episodes, self.parms.belief_size, device=self.parms.device)
            posterior_state = torch.zeros(self.parms.test_episodes, self.parms.state_size, device=self.parms.device)
            action = torch.zeros(self.parms.test_episodes, self.env.action_size, device=self.parms.device)
            tqdm.write("Testing model.")
            for t in range(total_steps):     
                belief, posterior_state, action, next_observation, rewards, done, pred_next_rew = self.update_belief_and_act(
                    test_envs, belief, posterior_state, action, observation.to(device=self.parms.device),
                    list(rewards), self.env.action_range[0], self.env.action_range[1])
                total_rewards += rewards.numpy()
                real_rew[t] = rewards
                predicted_rew[t]  = pred_next_rew

                # Collect the environment's original frame next to the model's reconstruction for the video
                observation = self.env.get_original_frame().unsqueeze(dim=0)
                video_frames.append(make_grid(torch.cat([observation, self.observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy())  # Decentre
                observation = next_observation
                if done.sum().item() == self.parms.test_episodes:
                    break
            
        real_rew = torch.transpose(real_rew, 0, 1)
        predicted_rew = torch.transpose(predicted_rew, 0, 1)
        
        # Save and plot metrics
        self.tested_episodes += 1
        self.metrics['test_episodes'].append(episode)
        self.metrics['test_rewards'].append(total_rewards.tolist())

        lineplot(self.metrics['test_episodes'], self.metrics['test_rewards'], 'test_rewards', self.statistics_path)
        
        write_video(video_frames, 'test_episode_%s' % str(episode), self.video_path)  # Lossy compression
        # Set models to train mode
        self.transition_model.train()
        self.observation_model.train()
        self.reward_model.train()
        self.encoder.train()
        # Close test environments
        test_envs.close()
        return self.metrics
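
All of the examples on this page step a batch of test environments through an EnvBatcher whose implementation is not shown. The sketch below illustrates the interface those loops assume (batched reset/step/close, rewards masked once an episode has finished, and an action_repeat attribute forwarded from the wrapped environment). ToyEnv and every name inside the sketch are placeholders added for illustration, not the original classes.

# Minimal sketch of the batched-environment interface assumed by the test loops above.
# ToyEnv is a stand-in used only to make the sketch runnable; it is not the real environment.
import torch


class ToyEnv:
    action_size = 1
    action_repeat = 1

    def reset(self):
        self.t = 0
        return torch.zeros(1, 3)  # one observation of shape (1, observation_dim)

    def step(self, action):
        self.t += 1
        reward, done = float(action.sum()), self.t >= 5
        return torch.zeros(1, 3), reward, done

    def close(self):
        pass


class EnvBatcher:
    """Steps n copies of an environment in lockstep, masking rewards of finished episodes."""

    def __init__(self, env_class, env_args, env_kwargs, n):
        self.envs = [env_class(*env_args, **env_kwargs) for _ in range(n)]
        self.action_repeat = getattr(self.envs[0], 'action_repeat', 1)
        self.dones = [False] * n

    def reset(self):
        self.dones = [False] * len(self.envs)
        return torch.cat([env.reset() for env in self.envs])

    def step(self, actions):
        obs, rewards, dones = zip(*[env.step(a) for env, a in zip(self.envs, actions)])
        # Zero out rewards for environments whose episode already ended earlier
        rewards = torch.tensor(rewards) * (1 - torch.tensor(self.dones, dtype=torch.float32))
        self.dones = [prev or d for prev, d in zip(self.dones, dones)]
        return torch.cat(obs), rewards, torch.tensor(self.dones)

    def close(self):
        for env in self.envs:
            env.close()


if __name__ == '__main__':
    test_envs = EnvBatcher(ToyEnv, (), {}, 4)
    observation = test_envs.reset()  # shape (4, 3)
    observation, rewards, done = test_envs.step(torch.zeros(4, 1))
    print(rewards, done.sum().item())
    test_envs.close()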
Example #2
        if not args.symbolic_env:
            episode_str = str(episode).zfill(len(str(args.episodes)))
            write_video(video_frames, 'test_episode_%s' % episode_str,
                        results_dir)  # Lossy compression
            save_image(
                torch.as_tensor(video_frames[-1]),
                os.path.join(results_dir, 'test_episode_%s.png' % episode_str))
        torch.save(metrics, os.path.join(results_dir, 'metrics.pth'))

        # Set models to train mode
        transition_model.train()
        observation_model.train()
        reward_model.train()
        encoder.train()
        # Close test environments
        test_envs.close()

    # Checkpoint models
    print("Completed episode {}".format(episode))
    if episode % args.checkpoint_interval == 0:
        print("Saving!")
        torch.save(
            {
                'transition_model': transition_model.state_dict(),
                'observation_model': observation_model.state_dict(),
                'reward_model': reward_model.state_dict(),
                'encoder': encoder.state_dict(),
                'optimiser': optimiser.state_dict()
            }, os.path.join(results_dir, 'models_%d.pth' % episode))
        if args.checkpoint_experience:
            # Truncated in the original snippet; presumably the experience replay buffer
            # (here called `experience_buffer` as a placeholder) is saved the same way:
            torch.save(experience_buffer,
                       os.path.join(results_dir, 'experience.pth'))
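
The checkpoint dictionary written above can be restored with the usual torch.load / load_state_dict round trip. A minimal sketch, assuming the caller has already constructed the same model and optimiser objects; the function name is ours, not part of the original code.

# Sketch: restoring a checkpoint saved as models_%d.pth by the code above.
import torch


def load_checkpoint(path, transition_model, observation_model, reward_model,
                    encoder, optimiser, device='cpu'):
    """Load the state dicts saved under the keys used in the snippet above."""
    checkpoint = torch.load(path, map_location=device)
    transition_model.load_state_dict(checkpoint['transition_model'])
    observation_model.load_state_dict(checkpoint['observation_model'])
    reward_model.load_state_dict(checkpoint['reward_model'])
    encoder.load_state_dict(checkpoint['encoder'])
    optimiser.load_state_dict(checkpoint['optimiser'])
    return checkpoint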
Example #3
    def test(self, episode):
        print("Test model")
        # Set models to eval mode
        self.transition_model.eval()
        self.observation_model.eval()
        self.reward_model.eval()
        self.encoder.eval()
        self.algorithms.train_to_eval()
        # self.actor_model_g.eval()
        # self.value_model_g.eval()
        # Initialise parallelised test environments
        test_envs = EnvBatcher(
            Env, (args.env, args.symbolic_env, args.seed,
                  args.max_episode_length, args.action_repeat, args.bit_depth),
            {}, args.test_episodes)

        with torch.no_grad():
            observation = test_envs.reset()
            total_rewards = np.zeros((args.test_episodes, ))
            video_frames = []
            belief = torch.zeros(args.test_episodes, args.belief_size, device=args.device)
            posterior_state = torch.zeros(args.test_episodes, args.state_size, device=args.device)
            action = torch.zeros(args.test_episodes, self.env.action_size, device=args.device)
            pbar = tqdm(range(args.max_episode_length // args.action_repeat))
            for t in pbar:
                belief, posterior_state, action, next_observation, reward, done = self.update_belief_and_act(
                    args, test_envs, belief, posterior_state, action,
                    observation.to(device=args.device))
                total_rewards += reward.numpy()
                if not args.symbolic_env:  # Collect real vs. predicted frames for video
                    video_frames.append(make_grid(torch.cat(
                        [observation, self.observation_model(belief, posterior_state).cpu()],
                        dim=3) + 0.5, nrow=5).numpy())  # Decentre
                observation = next_observation
                if done.sum().item() == args.test_episodes:
                    pbar.close()
                    break

        # Update and plot reward metrics (and write video if applicable) and save metrics
        self.metrics['test_episodes'].append(episode)
        self.metrics['test_rewards'].append(total_rewards.tolist())

        Save_Txt(self.metrics['test_episodes'][-1],
                 self.metrics['test_rewards'][-1], 'test_rewards',
                 args.results_dir)
        # Save_Txt(np.asarray(metrics['steps'])[np.asarray(metrics['test_episodes']) - 1], metrics['test_rewards'],'test_rewards_steps', results_dir, xaxis='step')

        # lineplot(metrics['test_episodes'], metrics['test_rewards'], 'test_rewards', results_dir)
        # lineplot(np.asarray(metrics['steps'])[np.asarray(metrics['test_episodes']) - 1], metrics['test_rewards'], 'test_rewards_steps', results_dir, xaxis='step')
        if not args.symbolic_env:
            episode_str = str(episode).zfill(len(str(args.episodes)))
            write_video(video_frames, 'test_episode_%s' % episode_str,
                        args.results_dir)  # Lossy compression
            save_image(
                torch.as_tensor(video_frames[-1]),
                os.path.join(args.results_dir,
                             'test_episode_%s.png' % episode_str))

        torch.save(self.metrics, os.path.join(args.results_dir, 'metrics.pth'))

        # Set models to train mode
        self.transition_model.train()
        self.observation_model.train()
        self.reward_model.train()
        self.encoder.train()
        # self.actor_model_g.train()
        # self.value_model_g.train()
        self.algorithms.eval_to_train()
        # Close test environments
        test_envs.close()
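
Examples 2 and 3 persist the metrics dictionary to metrics.pth, so the test-reward curve can be rebuilt offline without the lineplot or Save_Txt helpers. A minimal sketch, assuming matplotlib is installed, that 'test_rewards' holds one list of per-environment returns per evaluation, and that the file sits in a local results directory (the path is illustrative):

# Sketch: re-plotting test rewards from a metrics.pth file written by the examples above.
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

results_dir = 'results'  # illustrative; point this at the run's actual results directory
metrics = torch.load(os.path.join(results_dir, 'metrics.pth'))

episodes = np.asarray(metrics['test_episodes'])
rewards = np.asarray(metrics['test_rewards'])  # shape: (num_evaluations, test_episodes)

plt.plot(episodes, rewards.mean(axis=1), label='mean test reward')
plt.fill_between(episodes, rewards.min(axis=1), rewards.max(axis=1), alpha=0.3)
plt.xlabel('training episode')
plt.ylabel('test reward')
plt.legend()
plt.savefig(os.path.join(results_dir, 'test_rewards_replot.png'))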