Example No. 1
def test_env_fill_places_correct_num_foods(env: Env) -> None:
    """ Tests that we get exactly ``self.initial_num_foods`` in the grid. """
    env.fill()
    num_grid_foods = 0
    for pos in itertools.product(range(env.width), range(env.height)):
        if env.grid[pos][env.obj_type_ids["food"]] == 1:
            num_grid_foods += 1
    assert num_grid_foods == env.initial_num_foods
Example No. 2
def test_env_fill_generates_id_map_positions_correctly(env: Env) -> None:
    """ Tests that ``self.id_map`` is correct after ``self.fill`` is called. """
    env.fill()
    for x, y in itertools.product(range(env.width), range(env.height)):
        object_map = env.id_map[x][y]
        for obj_type_id, obj_ids in object_map.items():
            if len(obj_ids) > 0:
                assert env.grid[x][y][obj_type_id] == 1
Example No. 3
def test_env_fill_generates_id_map_ids_correctly(env: Env) -> None:
    """ Tests that ``self.id_map`` is correct after ``self.fill`` is called. """
    env.fill()
    for x, y in itertools.product(range(env.width), range(env.height)):
        object_map = env.id_map[x][y]
        for obj_type_id, obj_ids in object_map.items():
            if obj_type_id == env.obj_type_ids["agent"]:
                for obj_id in obj_ids:
                    assert obj_id in env.agents
Example No. 4
def test_env_fill_sets_all_agent_positions_correctly(env: Env) -> None:
    """ Tests that ``agent.pos`` is set correctly. """
    env.fill()
    agent_positions = [agent.pos for agent in env.agents.values()]
    for pos in itertools.product(range(env.width), range(env.height)):
        if pos in agent_positions:
            assert env.grid[pos][env.obj_type_ids["agent"]] == 1
        if env.grid[pos][env.obj_type_ids["agent"]] == 1:
            assert pos in agent_positions
Example No. 5
def test_env_fill_places_correct_number_of_agents(env: Env) -> None:
    """ Tests that we find a place for each agent in ``self.agents``. """
    env.fill()
    num_grid_agents = 0
    for row in env.grid:
        for square in row:
            if square[env.obj_type_ids["agent"]] == 1:
                num_grid_agents += 1
    assert num_grid_agents == len(env.agents)
Example No. 6
def empty_positions(_draw: Callable[[SearchStrategy], Any], env: Env,
                    obj_type_id: int) -> Optional[Tuple[int, int]]:
    """ Strategy for grid positions with any objects of type ``obj_type_id``. """
    for x in range(env.width):
        for y in range(env.height):
            if not env._obj_exists(obj_type_id, (x, y)):
                return (x, y)
    return None
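As written, ``empty_positions`` has the signature of a function meant to be wrapped as a hypothesis composite strategy (note the unused ``_draw`` argument). A minimal usage sketch under that assumption, combining it with the ``envs`` strategy from Example No. 9 below; the wrapping and the test name are assumptions, not part of the original code:

import hypothesis.strategies as st
from hypothesis import given

# Assumed wiring: wrap the raw helpers as composite strategies.
empty_positions_strategy = st.composite(empty_positions)
envs_strategy = st.composite(envs)


@given(data=st.data())
def test_empty_positions_are_in_bounds(data) -> None:
    """ Illustrative test: drawn empty positions lie inside the grid. """
    env = data.draw(envs_strategy())
    env.reset()
    pos = data.draw(empty_positions_strategy(env, env.obj_type_ids["food"]))
    if pos is not None:
        x, y = pos
        assert 0 <= x < env.width and 0 <= y < env.height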
Example No. 7
def test_env_reset_sees_correct_number_of_objects(env: Env) -> None:
    """ Tests that each observation has the correct number of each object type. """

    obs = env.reset()
    for agent_id, agent_obs in obs.items():

        # Calculate correct number of each object type.
        correct_obj_nums = {obj_type: 0 for obj_type in env.obj_type_ids.values()}
        for dx, dy in product(range(-env.sight_len, env.sight_len + 1), repeat=2):
            x = env.agents[agent_id].pos[0] + dx
            y = env.agents[agent_id].pos[1] + dy
            if not (0 <= x < env.width and 0 <= y < env.height):
                continue
            for obj_type in env.obj_type_ids.values():
                correct_obj_nums[obj_type] += int(env.grid[x][y][obj_type])

        # Calculate number of each object type in returned observations.
        observed_obj_nums = {obj_type: 0 for obj_type in env.obj_type_ids.values()}
        for dx, dy in product(range(-env.sight_len, env.sight_len + 1), repeat=2):
            for obj_type in env.obj_type_ids.values():
                observed_obj_nums[obj_type] += int(agent_obs[obj_type][dx][dy])

        assert correct_obj_nums == observed_obj_nums
Example No. 8
def train(args: argparse.Namespace) -> float:
    """
    Runs a training loop in the environment.

    Accepts three command line arguments: ``--settings``, ``--load-from``, and
    ``--save-root``.

    If you want to run training from scratch, you must pass ``--settings``, so that
    the program knows what to do, and you may optionally pass ``--save-root`` in the
    case where you would like to save elsewhere from the canonical directory.

    Passing ``--load-from`` will tell the program to attempt to load from a previously
    saved training run at the directory specified by the value of the argument. If
    ``--save-root`` is passed, then all saves during the current run will be saved in
    that root directory, regardless of the root directory implied by ``--load-from``.
    If ``--settings`` is passed, then the settings file, if any, in the ``--load-from``
    directory will be ignored. If no ``--settings`` argument is passed and there is no
    settings file in the ``--load-from`` directory, then the program will raise an
    error. If no ``--save-root`` is passed, the root will be implicitly set to the
    parent directory of ``--load-from``, i.e. the ``--save-root`` from the training run
    being loaded in. It will NOT default to the canonical root unless this is also the
    parent directory of ``--load-from``.

    Since there are a lot of cases, we should add a ``validate_args()`` function
    which raises errors when needed (see the sketch after this example).

    --save-root : ALWAYS OPTIONAL -> Canonical rootdir default.

    --load-from : ALWAYS OPTIONAL -> Empty default.

    --settings : OPTIONAL IF --load-from ELSE REQUIRED -> Empty default.

    Parameters
    ----------
    args : ``argparse.Namespace``.
        Contains arguments as described above.
    """

    # Create metrics and timer.
    metrics = Metrics()
    timer = Timer()

    # TIMER
    timer.start_interval("initialization")

    setup = Setup(args)
    config: Config = setup.config
    save_dir: str = setup.save_dir
    codename: str = setup.codename
    env_log: TextIO = setup.env_log
    visual_log: TextIO = setup.visual_log
    metrics_log: TextIO = setup.metrics_log
    env_state_path: str = setup.env_state_path
    trainer_state: Dict[str, Any] = setup.trainer_state

    # Create environment.
    if config.print_repr:
        print("Arguments:", str(config))
    env = Env(config)

    if not config.reuse_state_dicts:
        print(
            "Warning: ``config.reuse_state_dicts`` is False. This is slower, but the "
            "alternative bounds the number of unique policy initializations, i.e. "
            "policy initializations will be reused for multiple agents.")

    # Set random seed for all packages.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    # GPU setup.
    torch.set_num_threads(2)
    device = torch.device("cuda:0" if config.cuda else "cpu")
    if config.cuda and torch.cuda.is_available() and config.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(config.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    # Create multiagent maps.
    agents: Dict[int, Algo] = {}
    rollout_map: Dict[int, RolloutStorage] = {}
    minted_agents: Set[int] = set()

    # Save dead objects to make creation faster.
    dead_agents: Set[Algo] = set()
    dead_pipes: Dict[int, Pipe] = {}
    state_dicts: List[collections.OrderedDict] = []
    optim_state_dicts: List[collections.OrderedDict] = []

    # Multiprocessing maps.
    workers: Dict[int, mp.Process] = {}
    devices: Dict[int, torch.device] = {}
    pipes: Dict[int, Pipe] = {}

    # Set spawn start method for compatibility with torch.
    mp.set_start_method("spawn")

    # TODO: Implement this.
    if args.load_from:

        # Load the environment state from file.
        env.load(env_state_path)

        # Load in multiagent maps.
        agents = trainer_state["agents"]
        rollout_map = trainer_state["rollout_map"]
        minted_agents = trainer_state["minted_agents"]
        metrics = trainer_state["metrics"]

        # Load in dead objects.
        dead_agents = trainer_state["dead_agents"]
        state_dicts = trainer_state["state_dicts"]
        optim_state_dicts = trainer_state["optim_state_dicts"]

        # Don't reset environment if we are resuming a previous run.
        obs = {
            agent_id: agent.observation
            for agent_id, agent in env.agents.items()
        }

    else:
        obs = env.reset()

    # Initialize first policies.
    env_done = False
    step_ema = 1.0
    last_time = time.time()
    for agent_id, ob in obs.items():
        agent, rollouts, worker, device, pipe = get_agent(
            agent_id,
            env.iteration,
            env.agents[agent_id].age,
            ob,
            config,
            env.observation_space,
            env.action_space,
            agents,
            rollout_map,
            dead_agents,
            dead_pipes,
            state_dicts,
            optim_state_dicts,
        )

        # Optionally save a copy of each state dict for reuse in dead policies.
        if config.reuse_state_dicts and agent_id not in agents:
            state_dict = agent.actor_critic.state_dict()
            optim_state_dict = agent.optimizer.state_dict()
            state_dicts.append(copy.deepcopy(state_dict))
            optim_state_dicts.append(copy.deepcopy(optim_state_dict))

        # Copy first observations to rollouts, and send to device.
        if not config.mp:
            initial_observation = torch.FloatTensor([ob])
            rollouts.obs[0].copy_(initial_observation)
            rollouts.to(device)

        agents[agent_id] = agent
        workers[agent_id] = worker
        devices[agent_id] = device
        pipes[agent_id] = pipe
        rollout_map[agent_id] = rollouts

    # Whether or not we make a weight update on this iteration.
    backward_pass: bool = False
    while env.iteration < config.time_steps:

        # Should these all be defined up above with other maps?
        minted_agents = set()
        action_dict: Dict[int, int] = {}
        act_map: Dict[int, Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                                 torch.Tensor, torch.Tensor]] = {}
        timestep_scores: Dict[int, float] = {}

        # Get actions.
        if config.mp:
            for agent_id in pipes:
                action_dict[agent_id] = pipes[agent_id].action_spout.recv()
        else:
            decay = config.use_linear_lr_decay and backward_pass
            for agent_id in agents:
                act_map[agent_id] = act(
                    env.iteration,
                    decay,
                    agents[agent_id],
                    rollout_map[agent_id],
                    config,
                    env.agents[agent_id].age,
                    None,
                )
                action_dict[agent_id] = int(act_map[agent_id][1][0])

        # Execute environment step.
        obs, rewards, dones, infos = env.step(action_dict)
        backward_pass = env.iteration % config.num_steps == 0 and env.iteration > 0

        # TODO: Check for keyerror: for agent_id in obs:
        if config.mp:
            for agent_id in pipes:
                pipes[agent_id].env_funnel.send((
                    env.iteration,
                    obs[agent_id],
                    rewards[agent_id],
                    dones[agent_id],
                    infos[agent_id],
                    backward_pass,
                ))

        # Write env state and metrics to log.
        env.log_state(env_log, visual_log)
        metrics_log.write(str(metrics.get_summary()) + "\n")

        # Update the policy score.
        if (env.iteration + 1) % config.policy_score_frequency == 0:
            # TODO: Check for keyerror: for agent_id in infos:
            if config.mp:
                for agent_id in pipes:
                    timestep_scores[agent_id] = pipes[
                        agent_id].action_dist_spout.recv()
            else:
                for agent_id in act_map:
                    action_dist = act_map[agent_id][4]
                    timestep_score = get_policy_score(action_dist,
                                                      infos[agent_id])
                    timestep_scores[agent_id] = timestep_score

            metrics = update_policy_score(
                env=env,
                config=config,
                timestep_scores=timestep_scores,
                metrics=metrics,
            )

            # This block will run if train() was called with optuna for parameter
            # optimization. If policy score loss explodes, end the training run
            # early.
            if hasattr(args, "trial"):
                args.trial.report(metrics.policy_score, env.iteration)
                if args.trial.should_prune() or metrics.policy_score == float(
                        "inf"):
                    print(
                        "\nEnding training because ``policy_score_loss`` diverged."
                    )
                    return metrics.policy_score

        step_ema = (config.ema_alpha * step_ema) + ((1 - config.ema_alpha) *
                                                    (time.time() - last_time))
        last_time = time.time()

        # Print debug output.
        end = "\n" if config.print_repr else "\r"

        print("Iteration: %d| " % env.iteration, end="")
        print("Num agents: %d| " % len(agents), end="")
        print("Policy score loss: %.6f" % metrics.policy_score, end="")
        print("/%.6f| " % metrics.initial_policy_score, end="")
        print("Step EMA: %.6f" % step_ema, end="")
        print("||||||", end=end)

        # Agent creation and termination, rollout stacking.
        for agent_id in obs:
            ob = obs[agent_id]
            reward = rewards[agent_id]
            done = dones[agent_id]
            info = infos[agent_id]

            # Initialize new policies.
            if agent_id not in agents:
                agent, rollouts, worker, device, pipe = get_agent(
                    agent_id,
                    env.iteration,
                    env.agents[agent_id].age,
                    ob,
                    config,
                    env.observation_space,
                    env.action_space,
                    agents,
                    rollout_map,
                    dead_agents,
                    dead_pipes,
                    state_dicts,
                    optim_state_dicts,
                )

                # Copy first observations to rollouts, and send to device.
                if not config.mp:
                    initial_observation = torch.FloatTensor([ob])
                    rollouts.obs[0].copy_(initial_observation)
                    rollouts.to(device)

                agents[agent_id] = agent
                workers[agent_id] = worker
                devices[agent_id] = device
                pipes[agent_id] = pipe
                rollout_map[agent_id] = rollouts
                minted_agents.add(agent_id)

                # Optionally save a copy of each state dict for reuse in dead policies.
                if config.reuse_state_dicts:
                    state_dict = agent.actor_critic.state_dict()
                    optim_state_dict = agent.optimizer.state_dict()
                    state_dicts.append(copy.deepcopy(state_dict))
                    optim_state_dicts.append(copy.deepcopy(optim_state_dict))

            else:

                # If done then remove from environment.
                if done:
                    agent = agents.pop(agent_id)
                    # TODO: Should we save a reference to dead ``rollouts``?
                    # TODO: Does garbage collection get these? Use ``del``?
                    rollout_map.pop(agent_id)
                    pipes.pop(agent_id)
                    dead_agents.add(agent)

                elif not config.mp:
                    rollouts = rollout_map[agent_id]
                    fwds = act_map[agent_id]
                    stack_rollouts(rollouts, ob, reward, done, info, fwds)

        # Print out environment state.
        if all(dones.values()):
            if config.print_repr:
                print("All agents have died.")
            env_done = True

        # Only update losses and save on backward passes.
        if env.iteration % config.num_steps == 0 and env.iteration > 0:

            value_losses: Dict[int, float] = {}
            action_losses: Dict[int, float] = {}
            dist_entropies: Dict[int, float] = {}

            # Should we iterate over a different object?
            for agent_id, agent in agents.items():
                if agent_id not in minted_agents:
                    if config.mp:
                        losses = pipes[agent_id].loss_spout.recv()
                    else:
                        rollouts = rollout_map[agent_id]
                        losses = update(agent, rollouts, config)
                    value_losses[agent_id] = losses[0]
                    action_losses[agent_id] = losses[1]
                    dist_entropies[agent_id] = losses[2]

            metrics = update_losses(
                env=env,
                config=config,
                losses=(value_losses, action_losses, dist_entropies),
                metrics=metrics,
                minted_agents=minted_agents,
            )

        # Save for every ``config.save_interval``-th step or on the last update.
        # TODO: Ensure that we aren't saving out an empty state on the last iteration.
        save_state: bool = env.iteration % config.save_interval == 0
        if save_state or env.iteration == config.time_steps - 1:

            # Update ``agents`` and ``rollouts`` from worker processes.
            if config.mp:
                for agent_id, agent in agents.items():
                    agent, rollouts = pipes[agent_id].save_spout.recv()
                    agents[agent_id] = agent
                    rollout_map[agent_id] = rollouts

            # Save trainer state objects
            trainer_state = {
                "agents": agents,
                "rollout_map": rollout_map,
                "minted_agents": minted_agents,
                "metrics": metrics,
                "dead_agents": dead_agents,
                "state_dicts": state_dicts,
                "optim_state_dicts": optim_state_dicts,
            }
            trainer_state_path = os.path.join(save_dir,
                                              "%s_trainer.pkl" % codename)
            with open(trainer_state_path, "wb") as trainer_file:
                pickle.dump(trainer_state, trainer_file)

            # Save out environment state.
            state_path = os.path.join(save_dir, "%s_env.pkl" % codename)
            env.save(state_path)

            # Save out settings, removing log files (not paths) from object.
            settings_path = os.path.join(save_dir,
                                         "%s_settings.json" % codename)
            with open(settings_path, "w") as settings_file:
                json.dump(config.settings, settings_file)

            if env_done:
                break

        env.iteration += 1

    # Prints a single line to reset carriage.
    print("")

    return metrics.policy_score
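The docstring of ``train`` above lays out the argument cases in prose and notes that a ``validate_args()`` helper should exist. Below is a minimal sketch of what such a validator could look like, based only on the rules stated in that docstring; the function itself and its error messages are assumptions, not part of the original code (the ``_settings.json`` suffix mirrors the save logic above):

import argparse
import os


def validate_args(args: argparse.Namespace) -> None:
    """ Raise an error for invalid combinations of command line arguments. """
    # ``--settings`` is required unless we are loading from a previous run.
    if not args.load_from and not args.settings:
        raise ValueError("Must pass '--settings' when not using '--load-from'.")

    if args.load_from:
        if not os.path.isdir(args.load_from):
            raise ValueError("'--load-from' directory not found: %s" % args.load_from)

        # If no ``--settings`` is passed, a settings file must exist in the
        # ``--load-from`` directory.
        settings_files = [
            name for name in os.listdir(args.load_from)
            if name.endswith("_settings.json")
        ]
        if not args.settings and not settings_files:
            raise ValueError(
                "No '--settings' passed and no settings file found in '--load-from'."
            )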
Example No. 9
def envs(draw: Callable[[SearchStrategy], Any]) -> Env:
    """ A hypothesis strategy for generating ``Env`` objects. """

    sample: Dict[str, Any] = {}

    sample["width"] = draw(st.integers(min_value=1, max_value=9))
    sample["height"] = draw(st.integers(min_value=1, max_value=9))
    num_squares = sample["width"] * sample["height"]
    sample["num_agents"] = draw(st.integers(min_value=1,
                                            max_value=num_squares))
    sample["sight_len"] = draw(st.integers(min_value=1, max_value=4))
    sample["food_density"] = draw(st.floats(min_value=0.0, max_value=1.0))
    sample["food_size_mean"] = draw(st.floats(min_value=0.0, max_value=1.0))
    sample["food_size_stddev"] = draw(st.floats(min_value=0.0, max_value=1.0))
    sample["food_plant_retries"] = draw(st.integers(min_value=0, max_value=5))
    sample["aging_rate"] = draw(st.floats(min_value=1e-6, max_value=1.0))
    sample["mating_cooldown_len"] = draw(
        st.integers(min_value=0, max_value=100))
    sample["target_agent_density"] = draw(
        st.floats(min_value=0.0, max_value=1.0))
    sample["print_repr"] = draw(st.booleans())
    sample["time_steps"] = draw(st.integers(min_value=0, max_value=1000))
    sample["reuse_state_dicts"] = draw(st.booleans())
    sample["policy_score_frequency"] = draw(
        st.integers(min_value=1, max_value=1000))
    sample["ema_alpha"] = draw(st.floats(min_value=0.0, max_value=1.0))
    sample["n_layers"] = draw(st.integers(min_value=1, max_value=3))
    sample["hidden_dim"] = draw(st.integers(min_value=1, max_value=64))
    sample["reward_weight_mean"] = draw(
        st.floats(min_value=-2.0, max_value=2.0))
    sample["reward_weight_stddev"] = draw(
        st.floats(min_value=0.0, max_value=1.0))
    reward_inputs = st.sampled_from(["actions", "obs", "health"])
    sample["reward_inputs"] = list(
        draw(st.frozensets(reward_inputs, min_size=1)))
    sample["mut_sigma"] = draw(st.floats(min_value=0.0, max_value=1.0))
    sample["mut_p"] = draw(st.floats(min_value=0.0, max_value=1.0))
    sample["algo"] = draw(st.sampled_from(["ppo", "a2c", "acktr"]))
    sample["lr"] = draw(st.floats(min_value=1e-6, max_value=0.2))
    sample["min_lr"] = draw(st.floats(min_value=1e-6, max_value=0.2))
    sample["eps"] = draw(st.floats(min_value=0.0, max_value=1e-2))
    sample["alpha"] = draw(st.floats(min_value=0.0, max_value=1e-2))
    sample["gamma"] = draw(st.floats(min_value=0.0, max_value=1.0))
    sample["use_gae"] = draw(st.booleans())
    sample["gae_lambda"] = draw(st.floats(min_value=0.0, max_value=1.0))
    sample["value_loss_coef"] = draw(st.floats(min_value=0.0, max_value=1.0))
    sample["max_grad_norm"] = draw(st.floats(min_value=0.0, max_value=1.0))
    sample["seed"] = draw(st.integers(min_value=0, max_value=10))
    sample["cuda_deterministic"] = draw(st.booleans())
    sample["num_processes"] = draw(st.integers(min_value=1, max_value=10))
    sample["num_steps"] = draw(st.integers(min_value=1, max_value=1000))
    sample["ppo_epoch"] = draw(st.integers(min_value=1, max_value=8))
    sample["num_mini_batch"] = draw(st.integers(min_value=1, max_value=8))
    sample["clip_param"] = draw(st.floats(min_value=0.0, max_value=1.0))
    sample["log_interval"] = draw(st.integers(min_value=1, max_value=100))
    sample["save_interval"] = draw(st.integers(min_value=1, max_value=100))
    sample["cuda"] = draw(st.booleans())
    sample["use_proper_time_limits"] = draw(st.booleans())
    sample["recurrent_policy"] = draw(st.booleans())
    sample["use_linear_lr_decay"] = draw(st.booleans())

    # Read settings file for defaults.
    settings_path = "bees/settings/settings.json"
    with open(settings_path, "r") as settings_file:
        settings = json.load(settings_file)

    # Fill settings with values from arguments.
    for key, value in sample.items():
        settings[key] = value

    config = Config(settings)
    env = Env(config)

    return env
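The ``envs`` function above has the shape of a hypothesis composite strategy: it takes a ``draw`` callable and returns a fully constructed ``Env``. A minimal sketch of how it would typically be wrapped and used to drive the tests from the earlier examples; the wrapping, the ``settings(deadline=None)`` choice, and the test body (a restatement of Example No. 5) are assumptions:

import itertools

import hypothesis.strategies as st
from hypothesis import given, settings

# Assumed wiring: wrap the raw ``envs`` function as a composite strategy.
envs_strategy = st.composite(envs)


@settings(deadline=None)
@given(env=envs_strategy())
def test_env_fill_places_correct_number_of_agents(env: Env) -> None:
    """ Same check as Example No. 5, driven by generated environments. """
    env.fill()
    num_grid_agents = sum(
        int(env.grid[pos][env.obj_type_ids["agent"]] == 1)
        for pos in itertools.product(range(env.width), range(env.height))
    )
    assert num_grid_agents == len(env.agents)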