def get_random_agent_episodes(args, device, steps):
    """Roll out a uniform-random policy and collect observations per episode.

    Args:
        args: run configuration; must provide `num_processes` and whatever
            `make_vec_envs` needs.
        device: torch device the sampled action tensor is moved to.
        steps: total number of env steps to collect (split across processes).

    Returns:
        (episodes, episode_labels): two 2-D lists, one entry per finished or
        in-progress episode; `episodes[e]` is a list of observation tensors,
        `episode_labels[e]` the matching per-step label dicts (when provided
        by the env's info).
    """
    envs = make_vec_envs(args, args.num_processes)
    obs = envs.reset()
    episode_rewards = deque(maxlen=10)
    print('-------Collecting samples----------')
    # Nested as (process -> episode -> step); flattened before returning.
    episodes = [[[]] for _ in range(args.num_processes)]
    episode_labels = [[[]] for _ in range(args.num_processes)]
    num_iters = steps // args.num_processes
    for _ in range(num_iters):
        # One random action per parallel env. NOTE(review): range starts at 1,
        # so action 0 (presumably NOOP) is never sampled — kept as-is.
        sampled = np.array([np.random.randint(1, envs.action_space.n)
                            for _ in range(args.num_processes)])
        action = torch.tensor(sampled).unsqueeze(dim=1).to(device)
        obs, reward, done, infos = envs.step(action)
        for proc, info in enumerate(infos):
            if 'episode' in info.keys():
                episode_rewards.append(info['episode']['r'])
            has_labels = "labels" in info.keys()
            if done[proc] != 1:
                # Episode continues: extend the current episode's buffers.
                episodes[proc][-1].append(obs[proc].clone())
                if has_labels:
                    episode_labels[proc][-1].append(info["labels"])
            else:
                # Episode ended: the new obs belongs to a fresh episode.
                episodes[proc].append([obs[proc].clone()])
                if has_labels:
                    episode_labels[proc].append([info["labels"]])
    # Flatten (process, episode) into a flat list of episodes.
    episodes = list(chain.from_iterable(episodes))
    episode_labels = list(chain.from_iterable(episode_labels))
    envs.close()
    return episodes, episode_labels
def get_ppo_representations(args, steps, checkpoint_step):
    """Collect PPO actor features over rollouts, grouped per episode.

    Downloads the PPO checkpoint for `checkpoint_step`, evaluates it, then
    steps the envs while recording `actor_features` from the policy network.
    Actions are either fully random (`args.probe_collect_mode ==
    'random_agent'`) or epsilon-greedy around the PPO policy (20% random).

    Args:
        args: run configuration (env name, seed, num_processes, etc.).
        steps: total env steps to collect (split across processes).
        checkpoint_step: which training checkpoint to download/load.

    Returns:
        (episode_features, episode_labels, mean_reward): 2-D lists of
        per-step feature tensors / label dicts per episode, plus the
        checkpoint's mean evaluation reward.
    """
    filepath = download_run(args, checkpoint_step)
    while not os.path.exists(filepath):
        time.sleep(5)
    envs = make_vec_envs(args, args.num_processes)
    actor_critic, ob_rms = \
        torch.load(filepath, map_location=lambda storage, loc: storage)
    mean_reward = evaluate(actor_critic, env_name=args.env_name,
                           seed=args.seed,
                           num_processes=args.num_processes,
                           eval_log_dir="./tmp",
                           device="cpu", num_evals=args.num_rew_evals)
    print(mean_reward)
    episode_labels = [[[]] for _ in range(args.num_processes)]
    episode_rewards = deque(maxlen=10)
    episode_features = [[[]] for _ in range(args.num_processes)]
    masks = torch.zeros(1, 1)
    obs = envs.reset()
    for step in range(steps // args.num_processes):
        # BUG FIX: the original only ran actor_critic.act in the non-random
        # branch, so `actor_features` was undefined (NameError) in
        # 'random_agent' mode. The features must always come from the policy
        # network, regardless of how the executed action is chosen.
        with torch.no_grad():
            _, policy_action, _, _, actor_features, _ = actor_critic.act(
                obs, None, masks, deterministic=False)
        if args.probe_collect_mode == 'random_agent':
            # Fully random (non-zero) actions.
            action = torch.tensor(
                np.array([np.random.randint(1, envs.action_space.n)
                          for _ in range(args.num_processes)])) \
                .unsqueeze(dim=1)
        else:
            # Epsilon-greedy: 20% random action, otherwise the PPO action.
            action = torch.tensor([
                envs.action_space.sample()
                if np.random.uniform(0, 1) < 0.2 else policy_action[i]
                for i in range(args.num_processes)
            ]).unsqueeze(dim=1)
        obs, reward, done, infos = envs.step(action)
        for i, info in enumerate(infos):
            if 'episode' in info.keys():
                episode_rewards.append(info['episode']['r'])
            if done[i] != 1:
                # Episode continues: append to the current episode.
                episode_features[i][-1].append(actor_features[i].clone())
                if "labels" in info.keys():
                    episode_labels[i][-1].append(info["labels"])
            else:
                # Episode ended: start a new episode buffer.
                episode_features[i].append([actor_features[i].clone()])
                if "labels" in info.keys():
                    episode_labels[i].append([info["labels"]])
    # Convert to 2d list from 3d list
    episode_labels = list(chain.from_iterable(episode_labels))
    episode_features = list(chain.from_iterable(episode_features))
    # Release env worker processes (the sibling collectors do this too).
    envs.close()
    return episode_features, episode_labels, mean_reward
def get_ppo_rollouts(args, steps, checkpoint_step):
    """Collect observation rollouts from a (mostly) PPO policy per episode.

    Loads the PPO checkpoint for `checkpoint_step` and steps the envs with
    an epsilon-greedy policy (20% random action), recording observations,
    labels, and the policy's entropy at each step.

    Args:
        args: run configuration (num_processes, etc.).
        steps: total env steps to collect (split across processes).
        checkpoint_step: which training checkpoint to download/load.

    Returns:
        (episodes, episode_labels, mean_reward, mean_entropy): 2-D lists of
        per-step observation tensors / label dicts per episode, the mean of
        the last recorded episode rewards, and the mean policy entropy.
    """
    filepath = download_run(args, checkpoint_step)
    while not os.path.exists(filepath):
        time.sleep(5)
    envs = make_vec_envs(args, args.num_processes)
    actor_critic, ob_rms = \
        torch.load(filepath, map_location=lambda storage, loc: storage)
    episodes = [[[]] for _ in range(args.num_processes)]
    episode_labels = [[[]] for _ in range(args.num_processes)]
    episode_rewards = deque(maxlen=10)
    masks = torch.zeros(1, 1)
    obs = envs.reset()
    entropies = []
    for step in range(steps // args.num_processes):
        with torch.no_grad():
            # BUG FIX: the first return of actor_critic.act is the critic's
            # VALUE, not an observation. The original bound it to `obs`,
            # clobbering the real observations fed to the next act() call
            # (compare get_ppo_representations, which discards it with `_`).
            _, action, _, _, actor_features, dist_entropy = actor_critic.act(
                obs, None, masks, deterministic=False)
        # Epsilon-greedy: 20% random action, otherwise the PPO action.
        action = torch.tensor([
            envs.action_space.sample()
            if np.random.uniform(0, 1) < 0.2 else action[i]
            for i in range(args.num_processes)
        ]).unsqueeze(dim=1)
        entropies.append(dist_entropy.clone())
        obs, reward, done, infos = envs.step(action)
        for i, info in enumerate(infos):
            if 'episode' in info.keys():
                episode_rewards.append(info['episode']['r'])
            if done[i] != 1:
                # Episode continues: append to the current episode.
                episodes[i][-1].append(obs[i].clone())
                if "labels" in info.keys():
                    episode_labels[i][-1].append(info["labels"])
            else:
                # Episode ended: start a new episode buffer.
                episodes[i].append([obs[i].clone()])
                if "labels" in info.keys():
                    episode_labels[i].append([info["labels"]])
    # Convert to 2d list from 3d list
    episodes = list(chain.from_iterable(episodes))
    # Convert to 2d list from 3d list
    episode_labels = list(chain.from_iterable(episode_labels))
    mean_entropy = torch.stack(entropies).mean()
    # Release env worker processes (the sibling collectors do this too).
    envs.close()
    return episodes, episode_labels, np.mean(episode_rewards), mean_entropy
def get_envs(env_name, seed=42, num_processes=1, num_frame_stack=1,
             downsample=False, color=False):
    """Convenience wrapper: build vectorized envs with sensible defaults.

    Forwards every argument positionally to `make_vec_envs` and returns
    its result unchanged.
    """
    envs = make_vec_envs(env_name, seed, num_processes,
                         num_frame_stack, downsample, color)
    return envs