# Update and plot train reward metrics
metrics['steps'].append(t + metrics['steps'][-1])
metrics['episodes'].append(episode)
metrics['train_rewards'].append(total_reward)
lineplot(metrics['episodes'][-len(metrics['train_rewards']):], metrics['train_rewards'], 'train_rewards', results_dir)

# Test model
print("Test model")
if episode % args.test_interval == 0:
  # Set models to eval mode
  transition_model.eval()
  observation_model.eval()
  reward_model.eval()
  encoder.eval()
  actor_model.eval()
  value_model.eval()
  # Initialise parallelised test environments
  test_envs = EnvBatcher(Env, (args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat, args.bit_depth), {}, args.test_episodes)

  with torch.no_grad():
    observation, total_rewards, video_frames = test_envs.reset(), np.zeros((args.test_episodes, )), []
    belief, posterior_state, action = torch.zeros(args.test_episodes, args.belief_size, device=args.device), torch.zeros(args.test_episodes, args.state_size, device=args.device), torch.zeros(args.test_episodes, env.action_size, device=args.device)
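EnvBatcher itself is defined elsewhere in the codebase. As a rough, hypothetical sketch (an assumption based only on how it is constructed and used above, not the repository's actual class), it can be read as a thin wrapper that runs args.test_episodes environment instances in lockstep and stacks their outputs into batch tensors:

import torch

# Hypothetical EnvBatcher sketch (illustration only; the real class also masks out
# environments whose episodes have already finished).
class EnvBatcherSketch():
  def __init__(self, env_class, env_args, env_kwargs, n):
    self.envs = [env_class(*env_args, **env_kwargs) for _ in range(n)]

  def reset(self):
    # Assumes each Env.reset() returns an observation tensor with a leading batch dimension
    return torch.cat([env.reset() for env in self.envs], dim=0)

  def step(self, actions):
    # Step every environment with its own action and stack the resulting transitions
    observations, rewards, dones = zip(*[env.step(action) for env, action in zip(self.envs, actions)])
    return torch.cat(observations, dim=0), torch.tensor(rewards), torch.tensor(dones)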
# Test model
print("Test model")
if episode % args.test_interval == 0:
  # PlaNet, Dreamer: use the planner that is optimized along with the world model
  # (the world model is trained on data from the reward-driven planner).
  # Plan2Explore zero-shot: use the planner that is optimized along with the world model
  # (the world model is trained on data from the curiosity-driven planner).
  # Plan2Explore few-shot: use a planner that is not trained until adaptation_step is reached;
  # after adaptation_step it behaves the same as the PlaNet/Dreamer policy.
  policy = planner

  # Set models to eval mode
  transition_model.eval()
  observation_model.eval()
  reward_model.eval()
  encoder.eval()
  if args.algo == "p2e" or args.algo == "dreamer":
    actor_model.eval()
    value_model.eval()
  if args.algo == "p2e":
    curious_actor_model.eval()
    curious_value_model.eval()

  # Initialise parallelised test environments
  with torch.no_grad():
    observation, total_rewards, video_frames = test_envs.reset(), np.zeros((args.test_episodes, )), []
    belief, posterior_state, action = torch.zeros(args.test_episodes, args.belief_size, device=args.device), torch.zeros(args.test_episodes, args.state_size, device=args.device), torch.zeros(args.test_episodes, env.action_size, device=args.device)

    pbar = tqdm(range(args.max_episode_length // args.action_repeat))
    for t in pbar:
      belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, policy, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device))
      total_rewards += reward.numpy()
      if not args.symbolic_env:  # Collect real vs. predicted frames for video
        video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy())  # Decentre
      else:
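The per-step logic lives in update_belief_and_act, which is called above but not shown in this excerpt. A minimal sketch of what such a step does in this PlaNet/Dreamer-style code follows; the transition model's return signature and the exploration-noise handling are assumptions inferred from the call site, not the repository's exact function.

# Hypothetical sketch of update_belief_and_act (illustration only; the exact
# transition_model outputs and noise handling are assumptions).
def update_belief_and_act(args, env, policy, transition_model, encoder,
                          belief, posterior_state, action, observation, explore=False):
  # Advance the recurrent belief one step with the previous action and the encoded
  # observation (the models expect a leading time dimension, hence unsqueeze/squeeze)
  belief, _, _, _, posterior_state, _, _ = transition_model(
      posterior_state, action.unsqueeze(dim=0), belief, encoder(observation).unsqueeze(dim=0))
  belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)
  # Query the current policy (CEM planner or learned actor) for the next action
  action = policy(belief, posterior_state)
  if explore:
    # Exploration noise is only added when collecting training data, not at test time
    action = torch.clamp(torch.distributions.Normal(action, args.action_noise).rsample(), -1, 1)
  # Step the (batched) test environments on CPU and return the transition
  next_observation, reward, done = env.step(action.cpu())
  return belief, posterior_state, action, next_observation, reward, done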