Example #1
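All four examples below come from the same PPO/Atari data-collection module and lean on project helpers (make_vec_envs, download_run, evaluate) that are not shown here. A sketch of the standard-library and third-party imports the snippets rely on, assuming those helpers are importable from the surrounding codebase:

# Imports the snippets below rely on; make_vec_envs, download_run and
# evaluate are project-specific helpers assumed to exist elsewhere.
import os
import time
from collections import deque
from itertools import chain

import numpy as np
import torch
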
def get_random_agent_episodes(args, device, steps):
    envs = make_vec_envs(args, args.num_processes)
    obs = envs.reset()
    episode_rewards = deque(maxlen=10)
    print('-------Collecting samples----------')
    episodes = [[[]] for _ in range(args.num_processes)]  # (n_processes * n_episodes * episode_len)
    episode_labels = [[[]] for _ in range(args.num_processes)]
    for step in range(steps // args.num_processes):
        # Take action using a random policy
        # (note: np.random.randint(1, n) never samples action 0, the NOOP)
        action = torch.tensor(
            np.array([np.random.randint(1, envs.action_space.n) for _ in range(args.num_processes)])) \
            .unsqueeze(dim=1).to(device)
        obs, reward, done, infos = envs.step(action)
        for i, info in enumerate(infos):
            if 'episode' in info.keys():
                episode_rewards.append(info['episode']['r'])

            if done[i] != 1:
                episodes[i][-1].append(obs[i].clone())
                if "labels" in info.keys():
                    episode_labels[i][-1].append(info["labels"])
            else:
                episodes[i].append([obs[i].clone()])
                if "labels" in info.keys():
                    episode_labels[i].append([info["labels"]])

    # Convert to 2d list from 3d list
    episodes = list(chain.from_iterable(episodes))
    # Convert to 2d list from 3d list
    episode_labels = list(chain.from_iterable(episode_labels))
    envs.close()
    return episodes, episode_labels
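The nested bookkeeping above is process → episode → step; chain.from_iterable only drops the outer process level, leaving a flat list of episodes. A toy illustration with strings in place of observation tensors:

# Toy illustration of the "3d list -> 2d list" flattening used above.
from itertools import chain

nested = [[["p0_e0_s0"], ["p0_e1_s0", "p0_e1_s1"]],  # process 0: two episodes
          [["p1_e0_s0"]]]                            # process 1: one episode
flat = list(chain.from_iterable(nested))
# flat == [['p0_e0_s0'], ['p0_e1_s0', 'p0_e1_s1'], ['p1_e0_s0']]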
Example #2
def get_ppo_representations(args, steps, checkpoint_step):
    # Gives PPO representations over data collected by a random agent
    # (or by the PPO policy itself, depending on args.probe_collect_mode)
    filepath = download_run(args, checkpoint_step)
    while not os.path.exists(filepath):
        time.sleep(5)

    envs = make_vec_envs(args, args.num_processes)

    actor_critic, ob_rms = \
        torch.load(filepath, map_location=lambda storage, loc: storage)
    mean_reward = evaluate(actor_critic,
                           env_name=args.env_name,
                           seed=args.seed,
                           num_processes=args.num_processes,
                           eval_log_dir="./tmp",
                           device="cpu",
                           num_evals=args.num_rew_evals)
    print(mean_reward)
    episode_labels = [[[]] for _ in range(args.num_processes)]
    episode_rewards = deque(maxlen=10)
    episode_features = [[[]] for _ in range(args.num_processes)]

    masks = torch.zeros(1, 1)
    obs = envs.reset()
    for step in range(steps // args.num_processes):
        # Take a random action, or the PPO policy's action with 20% random mixing
        if args.probe_collect_mode == 'random_agent':
            action = torch.tensor(
                np.array([np.random.randint(1, envs.action_space.n) for _ in range(args.num_processes)])) \
                .unsqueeze(dim=1)
            # Run the policy forward anyway so actor_features (the PPO
            # representation stored below) is defined for this step
            with torch.no_grad():
                _, _, _, _, actor_features, _ = actor_critic.act(
                    obs, None, masks, deterministic=False)
        else:
            with torch.no_grad():
                _, action, _, _, actor_features, _ = actor_critic.act(
                    obs, None, masks, deterministic=False)
            # Replace the policy action with a uniformly random one 20% of the time
            action = torch.tensor([
                envs.action_space.sample()
                if np.random.uniform(0, 1) < 0.2 else action[i]
                for i in range(args.num_processes)
            ]).unsqueeze(dim=1)

        obs, reward, done, infos = envs.step(action)
        for i, info in enumerate(infos):
            if 'episode' in info.keys():
                episode_rewards.append(info['episode']['r'])

            if done[i] != 1:
                episode_features[i][-1].append(actor_features[i].clone())
                if "labels" in info.keys():
                    episode_labels[i][-1].append(info["labels"])
            else:
                episode_features[i].append([actor_features[i].clone()])
                if "labels" in info.keys():
                    episode_labels[i].append([info["labels"]])

    # Convert to 2d list from 3d list
    episode_labels = list(chain.from_iterable(episode_labels))
    episode_features = list(chain.from_iterable(episode_features))
    return episode_features, episode_labels, mean_reward
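A minimal usage sketch for the collector above; the args field names are inferred from their use in the function, and the concrete values are only illustrative:

# Hypothetical invocation: collect PPO features for probing.
from types import SimpleNamespace

args = SimpleNamespace(env_name="PongNoFrameskip-v4", seed=42, num_processes=8,
                       num_rew_evals=10, probe_collect_mode="random_agent")
features, labels, mean_reward = get_ppo_representations(args, steps=50000, checkpoint_step=1000000)
print(len(features), "episodes collected; checkpoint mean reward:", mean_reward)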
Example #3
def get_ppo_rollouts(args, steps, checkpoint_step):
    filepath = download_run(args, checkpoint_step)
    while not os.path.exists(filepath):
        time.sleep(5)

    envs = make_vec_envs(args, args.num_processes)

    actor_critic, ob_rms = \
        torch.load(filepath, map_location=lambda storage, loc: storage)

    episodes = [[[]] for _ in range(args.num_processes)]  # (n_processes * n_episodes * episode_len)
    episode_labels = [[[]] for _ in range(args.num_processes)]
    episode_rewards = deque(maxlen=10)

    step = 0
    masks = torch.zeros(1, 1)
    obs = envs.reset()
    entropies = []
    for step in range(steps // args.num_processes):
        # Take action using the PPO policy, mixing in a random action 20% of the time
        with torch.no_grad():
            # Discard the first return value instead of overwriting obs with it
            _, action, _, _, actor_features, dist_entropy = actor_critic.act(obs, None, masks, deterministic=False)
        action = torch.tensor([envs.action_space.sample() if np.random.uniform(0, 1) < 0.2 else action[i]
                               for i in range(args.num_processes)]).unsqueeze(dim=1)
        entropies.append(dist_entropy.clone())
        obs, reward, done, infos = envs.step(action)
        for i, info in enumerate(infos):
            if 'episode' in info.keys():
                episode_rewards.append(info['episode']['r'])

            if done[i] != 1:
                episodes[i][-1].append(obs[i].clone())
                if "labels" in info.keys():
                    episode_labels[i][-1].append(info["labels"])
            else:
                episodes[i].append([obs[i].clone()])
                if "labels" in info.keys():
                    episode_labels[i].append([info["labels"]])

    # Convert to 2d list from 3d list
    episodes = list(chain.from_iterable(episodes))
    # Convert to 2d list from 3d list
    episode_labels = list(chain.from_iterable(episode_labels))
    mean_entropy = torch.stack(entropies).mean()
    return episodes, episode_labels, np.mean(episode_rewards), mean_entropy
Example #4
def get_envs(
    env_name, seed=42, num_processes=1, num_frame_stack=1, downsample=False, color=False
):
    return make_vec_envs(
        env_name, seed, num_processes, num_frame_stack, downsample, color
    )
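A hedged usage sketch for the wrapper above; the environment id and printed shape are only examples:

# Hypothetical usage: a single full-resolution, grayscale Pong environment.
envs = get_envs("PongNoFrameskip-v4", seed=42, num_processes=1)
obs = envs.reset()
print(obs.shape)  # e.g. (num_processes, num_frame_stack * channels, H, W)
envs.close()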