Example #1
def perform_update(config, env: TGEnv, team_policies: List[Policy],
                   avatar_storages: List[RolloutsStorage]):
    """ Collects rollouts and updates """

    # Accumulators used for logging
    total_rewards = EpisodeAccumulator(env.num_avatars)
    steps_alive = EpisodeAccumulator(env.num_avatars)
    first_step_probas = EpisodeAccumulator(env.num_avatars,
                                           max(env.num_actions))
    end_reasons = []

    # Will be filled in for each avatar when stepping the environment individually
    actions = [0] * env.num_avatars
    action_log_probs = [0] * env.num_avatars
    values = [0] * env.num_avatars

    avatar_policies = [
        team_policies[env.id2team[avatar_id]]
        for avatar_id in range(env.num_avatars)
    ]

    # Always start with a fresh env
    env_states = env.reset()  # shape: [num_avatars, *env_state_shape]
    rec_hs, rec_cs = _get_initial_recurrent_state(avatar_policies)
    # Copies, so that per-avatar writes to next_rec_* during a step do not overwrite the pre-step state
    next_rec_hs, next_rec_cs = list(rec_hs), list(rec_cs)
    first_episode_step = True

    # Set to eval mode because batch norm cannot be computed on a batch of size 1,
    # and actions are picked one at a time while collecting rollouts
    for policy in team_policies:
        policy.controller.eval()

    # Collect rollouts
    # TODO (?): collect in parallel
    for step in range(config.num_transitions):
        # Alive at the beginning of step
        avatar_alive = env.avatar_alive.copy()

        # Run each alive avatar individually
        for avatar_id in range(env.num_avatars):
            if avatar_alive[avatar_id]:
                # Choose an action based on the avatar's policy
                team = env.id2team[avatar_id]
                policy = team_policies[team]

                action_source = policy.scheduler.pick_action_source()
                if action_source == SCRIPTED:
                    scripted_action = env.scripted_action(avatar_id)
                else:
                    scripted_action = None

                (
                    actions[avatar_id],
                    action_log_probs[avatar_id],
                    actor_logits,
                    values[avatar_id],
                    next_rec_hs[avatar_id],
                    next_rec_cs[avatar_id],
                ) = policy.pick_action_and_info(
                    env_states[avatar_id],
                    rec_hs[avatar_id],
                    rec_cs[avatar_id],
                    sampling_method=action_source,
                    externally_chosen_action=scripted_action,
                )

                if first_episode_step:
                    probas = softmax(actor_logits.detach().numpy().flatten())
                    first_step_probas.current[avatar_id, :len(probas)] = probas

        # Step the environment with one action for each avatar
        next_env_states, rewards, dones, info = env.step(actions)

        # Insert transitions for alive avatars
        for avatar_id in range(env.num_avatars):
            if avatar_alive[avatar_id]:
                storage = avatar_storages[avatar_id]
                storage.insert(
                    env_states[avatar_id],
                    actions[avatar_id],
                    action_log_probs[avatar_id],
                    values[avatar_id],
                    rewards[avatar_id],
                    dones[avatar_id],
                    rec_hs[avatar_id],
                    rec_cs[avatar_id],
                )

        total_rewards.current += rewards
        steps_alive.current += avatar_alive

        # Episode is done
        if all(dones):
            env_states = env.reset()
            rec_hs, rec_cs = _get_initial_recurrent_state(avatar_policies)

            total_rewards.episode_over()
            steps_alive.episode_over()
            first_step_probas.episode_over()
            end_reasons.append(info['end_reason'])
            first_episode_step = True

        # The states are not overwritten immediately because we store the state that was
        # used to generate the action at the current time-step (env_states), not the one
        # we arrive in (next_env_states)
        else:
            env_states = next_env_states
            # Copy, so that the next step's writes to next_rec_* do not mutate the
            # pre-step recurrent state that gets stored
            rec_hs = list(next_rec_hs)
            rec_cs = list(next_rec_cs)

            first_episode_step = False

    # Compute returns for all storages
    for storage in avatar_storages:
        storage.compute_returns()

    # Report progress
    avatar_rewards = total_rewards.final_history(drop_last=True)
    # Shape [num_teams,]: team reward summed over episodes, averaged over the team's avatars
    avg_team_rewards = np.array([
        avatar_rewards[:, mask].sum() / sum(mask)  # average per avatar in the team
        for mask in env.team_masks
    ])
    relative_team_rewards = avg_team_rewards / avg_team_rewards.sum()
    for measure, policy in zip(relative_team_rewards, team_policies):
        policy.scheduler.end_iteration_report(measure)

    # Set to training mode
    for policy in team_policies:
        policy.controller.train()

    # Update policies
    losses_history = [[] for _ in range(env.num_teams)]
    for epoch in range(config.num_epochs):
        for avatar_id in range(env.num_avatars):
            team = env.id2team[avatar_id]
            policy = team_policies[team]
            storage = avatar_storages[avatar_id]

            for batch in storage.sample_batches():
                # A batch contains multidimensional env_states, rec_hs, rec_cs, actions, old_action_log_probs, returns
                losses = policy.update(*batch)
                losses_history[team].append(losses)

    # Prepare storages for the next update
    for storage in avatar_storages:
        storage.reset()

    scheduling_statuses = [
        policy.sync_scheduled_values() for policy in team_policies
    ]

    # Ignore last episode since it's most likely unfinished
    return (
        total_rewards.final_history(drop_last=True),
        steps_alive.final_history(drop_last=True),
        first_step_probas.final_history(drop_last=False),
        end_reasons,
        losses_history,
        scheduling_statuses,
    )
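
A minimal sketch of how perform_update might be driven from an outer training loop. Only perform_update's signature and return values come from the example above; the TGEnv/Policy/RolloutsStorage constructor arguments, the config fields, and the assumption that the returned reward history is a [num_episodes, num_avatars] array are illustrative guesses.

# Hypothetical driver loop: constructor arguments and config fields are assumptions
env = TGEnv(config.scenario)
team_policies = [Policy(config) for _ in range(env.num_teams)]
avatar_storages = [RolloutsStorage(config) for _ in range(env.num_avatars)]

for iteration in range(config.num_iterations):
    (rewards, steps_alive, first_step_probas, end_reasons,
     losses_history, scheduling_statuses) = perform_update(
        config, env, team_policies, avatar_storages)

    # Assuming rewards is shaped [num_episodes, num_avatars], log the per-avatar mean
    print(f'[iteration {iteration}] mean episode reward per avatar: {rewards.mean(axis=0)}')
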
Example #2
def simulate_episode(env: TGEnv, team_policies: List[Policy], sampling_method):
    """
    sampling_method: a single int shared by all teams, or a list with one int per team
    """
    if isinstance(sampling_method, int):
        sampling_method = [sampling_method] * env.num_teams

    map_history = []
    pos2id_history = []
    rewards_history = []
    actions_history = []

    avatar_policies = [
        team_policies[env.id2team[avatar_id]]
        for avatar_id in range(env.num_avatars)
    ]

    env_states = env.reset()  # shape: [num_avatars, *env_state_shape]
    rec_hs, rec_cs = _get_initial_recurrent_state(avatar_policies)
    dones = [False] * env.num_avatars
    cumulative_reward = np.zeros(env.num_avatars)

    actions = [0] * env.num_avatars
    action_log_probs = [0] * env.num_avatars

    # Set to evaluation mode
    for policy in team_policies:
        policy.controller.eval()

    while not all(dones):
        map_history.append(env._map.copy())
        pos2id_history.append(copy(env._pos2id))
        rewards_history.append(cumulative_reward.copy())

        # Alive at the beginning of step
        avatar_alive = env.avatar_alive.copy()

        # Run each alive avatar individually
        for avatar_id in range(env.num_avatars):
            if avatar_alive[avatar_id]:
                # Choose an action based on the avatar's policy
                team = env.id2team[avatar_id]
                policy = team_policies[team]

                if sampling_method[team] == SCRIPTED:
                    scripted_action = env.scripted_action(avatar_id)
                else:
                    scripted_action = None

                with torch.no_grad():
                    (
                        actions[avatar_id],
                        action_log_probs[avatar_id],
                        _,
                        _,
                        rec_hs[avatar_id],
                        rec_cs[avatar_id],
                    ) = policy.pick_action_and_info(
                        env_states[avatar_id],
                        rec_hs[avatar_id],
                        rec_cs[avatar_id],
                        sampling_method=sampling_method[team],
                        externally_chosen_action=scripted_action,
                    )
            else:
                actions[avatar_id] = DEAD

        # Step the environment with one action for each avatar
        env_states, rewards, dones, infos = env.step(actions)
        cumulative_reward += rewards
        actions_history.append([
            ACTION_IDX2SYMBOL[env._interpret_action(a, env.id2team[i])]
            for i, a in enumerate(actions)
        ])

    # Add final state as well
    map_history.append(env._map.copy())
    pos2id_history.append(copy(env._pos2id))
    rewards_history.append(cumulative_reward.copy())
    actions_history.append([ACTION_IDX2SYMBOL[DEAD]] * env.num_avatars)

    return (map_history, pos2id_history, rewards_history, actions_history,
            infos['end_reason'])
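
A minimal sketch of how simulate_episode might be used for evaluation after training. Only simulate_episode's signature and return values come from the example above; the GREEDY sampling constant and the idea of a project-specific renderer are assumptions for illustration.

# Hypothetical evaluation snippet: GREEDY and any rendering helper are assumptions
map_hist, pos2id_hist, reward_hist, action_hist, end_reason = simulate_episode(
    env, team_policies, sampling_method=GREEDY)

print(f'Episode ended because: {end_reason}')
for t, (map_state, cum_rewards, symbols) in enumerate(
        zip(map_hist, reward_hist, action_hist)):
    print(f't={t}  cumulative rewards={cum_rewards}  actions={symbols}')
    # A renderer could draw map_state here, e.g. one character per grid cell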