Example #1
def test_records_next_observations():
    """Recording next observations."""
    env = build_dummy_structured_env()
    rollout_generator = RolloutGenerator(env=env,
                                         record_next_observations=True)
    policy = RandomPolicy(env.action_spaces_dict)
    trajectory = rollout_generator.rollout(policy, n_steps=10)

    assert len(trajectory) == 10

    sub_step_keys = env.action_spaces_dict.keys()
    last_next_obs = None
    for record in trajectory.step_records:
        assert sub_step_keys == record.observations_dict.keys()
        assert sub_step_keys == record.next_observations_dict.keys()
        assert record.batch_shape is None

        for step_key in sub_step_keys:
            curr_obs = record.observations_dict[step_key]

            # Next obs from the previous sub-step should be equal to the current observation
            if last_next_obs:
                assert list(curr_obs.keys()) == list(last_next_obs.keys())
                for obs_key in curr_obs.keys():
                    assert np.all(curr_obs[obs_key] == last_next_obs[obs_key])

            last_next_obs = record.next_observations_dict[step_key]
Example #2
    def init_replay_buffer(replay_buffer: BaseReplayBuffer, initial_sampling_policy: Union[DictConfig, Policy],
                           initial_buffer_size: int, replay_buffer_seed: int,
                           split_rollouts_into_transitions: bool, n_rollout_steps: int,
                           env_factory: Callable[[], MazeEnv]) -> None:
        """Fill the buffer with initial_buffer_size rollouts by rolling out the initial_sampling_policy.

        :param replay_buffer: The replay buffer to use.
        :param initial_sampling_policy: The initial sampling policy used to fill the buffer to the initial fill state.
        :param initial_buffer_size: The initial size of the replay buffer filled by sampling from the initial sampling
            policy.
        :param replay_buffer_seed: A seed for initializing and sampling from the replay buffer.
        :param split_rollouts_into_transitions: Specify whether to split rollouts into individual transitions.
        :param n_rollout_steps: Number of rollout steps to record in one rollout.
        :param env_factory: Factory function for envs to run rollouts on.
        """

        # Create the log stats aggregator for collecting kpis of initializing the replay buffer
        epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        replay_stats_logger = get_stats_logger('init_replay_buffer')
        epoch_stats.register_consumer(replay_stats_logger)

        dummy_env = env_factory()
        dummy_env.seed(replay_buffer_seed)
        sampling_policy: Policy = \
            Factory(Policy).instantiate(initial_sampling_policy, action_spaces_dict=dummy_env.action_spaces_dict)
        sampling_policy.seed(replay_buffer_seed)
        rollout_generator = RolloutGenerator(env=dummy_env,
                                             record_next_observations=True,
                                             record_episode_stats=True)

        print(f'******* Starting to fill the replay buffer with {initial_buffer_size} transitions *******')
        while len(replay_buffer) < initial_buffer_size:
            trajectory = rollout_generator.rollout(policy=sampling_policy, n_steps=n_rollout_steps)

            if split_rollouts_into_transitions:
                replay_buffer.add_rollout(trajectory)
            else:
                replay_buffer.add_transition(trajectory)

            # collect episode statistics
            for step_record in trajectory.step_records:
                if step_record.episode_stats is not None:
                    epoch_stats.receive(step_record.episode_stats)

        # Print the kpis from initializing the replay buffer
        epoch_stats.reduce()
        # Remove the consumer again from the aggregator
        epoch_stats.remove_consumer(replay_stats_logger)
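
For orientation, a call to this helper might look like the sketch below. The buffer instance, sampling-policy config, and env factory are placeholders (their names and concrete values are assumptions, not library defaults); in the source the helper is defined on a class, so it would normally be invoked via that class.

init_replay_buffer(
    replay_buffer=my_replay_buffer,                     # hypothetical BaseReplayBuffer implementation
    initial_sampling_policy=my_sampling_policy_config,  # hypothetical DictConfig or Policy
    initial_buffer_size=10_000,
    replay_buffer_seed=1234,
    split_rollouts_into_transitions=True,
    n_rollout_steps=100,
    env_factory=my_env_factory)                         # hypothetical callable returning a MazeEnv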
Example #3
    def __init__(self, env_factory: Callable[[],
                                             Union[StructuredEnv,
                                                   StructuredEnvSpacesMixin,
                                                   LogStatsEnv]],
                 policy: TorchPolicy, n_rollout_steps: int, n_actors: int,
                 batch_size: int, actor_env_seeds: List[int]):
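        """Set up one in-process actor per entry in actor_env_seeds: each actor wraps a freshly created, seeded env
        in a RolloutGenerator that records logits and episode statistics."""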
        super().__init__(env_factory, policy, n_rollout_steps, n_actors,
                         batch_size)

        self.broadcasting_container = BroadcastingContainer()
        self.current_actor_idx = 0

        self.actors: List[RolloutGenerator] = []
        self.policy_version_counter = 0

        for env_seed in actor_env_seeds:
            env = env_factory()
            env.seed(env_seed)
            actor = RolloutGenerator(env=env,
                                     record_logits=True,
                                     record_episode_stats=True)
            self.actors.append(actor)

        if self.n_actors > self.batch_size:
            BColors.print_colored(
                f'It does not make much sense to have more actors (given value: {n_actors}) than '
                f'the actor_batch_size (given value: {batch_size}) when using the DummyMultiprocessingModule.',
                color=BColors.WARNING)
Example #4
def train_function(n_epochs: int, distributed_env_cls) -> A2C:
    """Trains the cart pole environment with the multi-step a2c implementation.
    """

    # initialize distributed env
    envs = distributed_env_cls([lambda: GymMazeEnv(env="CartPole-v0") for _ in range(2)])

    # initialize the env and enable statistics collection
    eval_env = distributed_env_cls([lambda: GymMazeEnv(env="CartPole-v0") for _ in range(2)],
                                   logging_prefix='eval')

    # init distribution mapper
    env = GymMazeEnv(env="CartPole-v0")
    distribution_mapper = DistributionMapper(action_space=env.action_space, distribution_mapper_config={})

    # initialize policies
    policies = {0: FlattenConcatPolicyNet({'observation': (4,)}, {'action': (2,)}, hidden_units=[16], non_lin=nn.Tanh)}

    # initialize critic
    critics = {0: FlattenConcatStateValueNet({'observation': (4,)}, hidden_units=[16], non_lin=nn.Tanh)}

    # algorithm configuration
    algorithm_config = A2CAlgorithmConfig(
        n_epochs=n_epochs,
        epoch_length=2,
        patience=10,
        critic_burn_in_epochs=0,
        n_rollout_steps=20,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.0,
        max_grad_norm=0.0,
        device="cpu",
        rollout_evaluator=RolloutEvaluator(eval_env=eval_env, n_episodes=1, model_selection=None, deterministic=True)
    )

    # initialize actor critic model
    model = TorchActorCritic(
        policy=TorchPolicy(networks=policies, distribution_mapper=distribution_mapper, device=algorithm_config.device),
        critic=TorchSharedStateCritic(networks=critics, obs_spaces_dict=env.observation_spaces_dict,
                                      device=algorithm_config.device,
                                      stack_observations=False),
        device=algorithm_config.device)

    a2c = A2C(rollout_generator=RolloutGenerator(envs),
              algorithm_config=algorithm_config,
              evaluator=algorithm_config.rollout_evaluator,
              model=model,
              model_selection=None)

    # train agent
    a2c.train()

    return a2c
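
A hypothetical invocation, using the SequentialVectorEnv class that appears in the later examples as the distributed-env class (the epoch count is arbitrary):

a2c = train_function(n_epochs=2, distributed_env_cls=SequentialVectorEnv)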
Example #5
def test_standard_rollout_with_logits_and_step_stats():
    """Recording logits and step statistics."""
    env = build_dummy_structured_env()
    rollout_generator = RolloutGenerator(env=env,
                                         record_step_stats=True,
                                         record_logits=True)
    policy = flatten_concat_probabilistic_policy_for_env(
        env)  # We need a torch policy to be able to record logits
    trajectory = rollout_generator.rollout(policy, n_steps=10)

    assert len(trajectory) == 10

    sub_step_keys = env.action_spaces_dict.keys()
    for record in trajectory.step_records:
        assert record.step_stats is not None
        assert sub_step_keys == record.logits_dict.keys()

        for step_key in sub_step_keys:
            assert record.logits_dict[step_key].keys(
            ) == record.actions_dict[step_key].keys()
Example #6
def test_standard_rollout():
    """Rollout with a single structured env."""
    env = build_dummy_structured_env()
    rollout_generator = RolloutGenerator(env=env)
    policy = RandomPolicy(env.action_spaces_dict)
    trajectory = rollout_generator.rollout(policy, n_steps=10)

    assert len(trajectory) == 10

    sub_step_keys = env.action_spaces_dict.keys()
    for record in trajectory.step_records:
        assert sub_step_keys == record.actions_dict.keys()
        assert sub_step_keys == record.observations_dict.keys()
        assert sub_step_keys == record.rewards_dict.keys()

        assert record.batch_shape is None
        for step_key in sub_step_keys:
            assert record.observations_dict[
                step_key] in env.observation_spaces_dict[step_key]
            assert record.actions_dict[step_key] in env.action_spaces_dict[
                step_key]
Example #7
def test_handles_done_in_substep_with_recorded_episode_stats():
    """Recording episode stats and handling environments that return done during a (non-last) sub-step."""
    env = build_dummy_structured_env()
    env = _FiveSubstepsLimitWrapper.wrap(env)
    policy = RandomPolicy(env.action_spaces_dict)

    # -- Normal operation (should reset the env automatically and continue rollout) --
    rollout_generator = RolloutGenerator(env=env, record_episode_stats=True)
    trajectory = rollout_generator.rollout(policy, n_steps=10)
    assert len(trajectory) == 10

    # The done step records should have data for the first sub-step only
    dones = 0
    for step_record in trajectory.step_records:
        if step_record.is_done():
            assert [0] == list(step_record.observations_dict.keys())
            dones += 1
            assert step_record.episode_stats is not None
        else:
            assert [0, 1] == list(step_record.observations_dict.keys())
            assert step_record.episode_stats is None
    assert dones == 3  # Each episode is done after 5 sub-steps, i.e. 3 structured steps get recorded => 3 episodes fit

    # -- Terminate on done --
    rollout_generator = RolloutGenerator(env=env, terminate_on_done=True)
    trajectory = rollout_generator.rollout(policy, n_steps=10)
    assert len(trajectory) == 3
    assert trajectory.is_done()
    assert [0] == list(trajectory.step_records[-1].observations_dict.keys())
Example #8
def _actor_worker(pickled_env_factory: bytes, pickled_policy: bytes,
                  n_rollout_steps: int, actor_output_queue: multiprocessing.Queue,
                  broadcasting_container: BroadcastingContainer, env_seed: int, agent_seed: int):
    """Worker function for the actors. This Method is called with a new process. Its task is to initialize the,
        before going into a loop - updating its policy if necessary, computing a rollout and putting the result into
        the shared queue.

    :param pickled_env_factory: The pickled env factory
    :param pickled_policy: Pickled structured policy
    :param n_rollout_steps: the number of rollout steps to be computed for each rollout
    :param actor_output_queue: the queue to put the computed rollouts in
    :param broadcasting_container: the shared container, where actors can retrieve the newest version of the policies
    :param env_seed: The env seed to be used.
    :param agent_seed: The agent seed to be used.
    """
    try:
        env_factory = cloudpickle.loads(pickled_env_factory)
        env = env_factory()
        env.seed(env_seed)

        policy: TorchPolicy = cloudpickle.loads(pickled_policy)
        policy.seed(agent_seed)
        policy_version_counter = -1

        rollout_generator = RolloutGenerator(env=env, record_logits=True, record_episode_stats=True)

        while not broadcasting_container.stop_flag():
            # Update the policy if new version is available
            shared_policy_version_counter = broadcasting_container.policy_version()
            if policy_version_counter < shared_policy_version_counter:
                policy.load_state_dict(broadcasting_container.policy_state_dict())
                policy_version_counter = shared_policy_version_counter

            trajectory = rollout_generator.rollout(policy, n_steps=n_rollout_steps)
            actor_output_queue.put(trajectory)

    except Exception as e:
        actor_output_queue.put(ExceptionReport(e))
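
To make the calling convention concrete, a worker like this might be launched roughly as sketched below. The env factory, policy, seeds and step count are placeholders, and the way the broadcasting container is shared between processes is an assumption, not the library's actual spawning code:

actor_output_queue = multiprocessing.Queue()
broadcasting_container = BroadcastingContainer()  # must be shareable across processes in a real setup

worker = multiprocessing.Process(
    target=_actor_worker,
    args=(cloudpickle.dumps(my_env_factory),   # hypothetical env factory
          cloudpickle.dumps(my_torch_policy),  # hypothetical TorchPolicy instance
          100,                                 # n_rollout_steps
          actor_output_queue,
          broadcasting_container,
          1234,                                # env_seed
          5678))                               # agent_seed
worker.start()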
Example #9
def test_vectorized_rollout():
    """Rollout with a vector env."""
    concurrency = 3
    env = SequentialVectorEnv([build_dummy_structured_env] * concurrency)
    rollout_generator = RolloutGenerator(env=env)
    policy = DistributedRandomPolicy(env.action_spaces_dict,
                                     concurrency=concurrency)
    trajectory = rollout_generator.rollout(policy, n_steps=10)

    assert len(trajectory) == 10

    sub_step_keys = env.action_spaces_dict.keys()
    for record in trajectory.step_records:
        assert sub_step_keys == record.actions_dict.keys()
        assert sub_step_keys == record.observations_dict.keys()
        assert sub_step_keys == record.rewards_dict.keys()

        assert record.batch_shape == [concurrency]
        # The first dimension of the observations should correspond to the distributed env concurrency
        # (We just check the very first array present in the first observation)
        first_sub_step_obs: Dict = list(record.observations_dict.values())[0]
        first_obs_value = list(first_sub_step_obs.values())[0]
        assert first_obs_value.shape[0] == concurrency
Example #10
def estimate_observation_normalization_statistics(env: Union[
    MazeEnv, ObservationNormalizationWrapper], n_samples: int) -> None:
    """Helper function estimating normalization statistics.
    :param env: The observation normalization wrapped environment.
    :param n_samples: The number of samples (i.e., flat environment steps) to take for statistics computation.
    """

    # remove previous statistics dump
    if os.path.exists(env.statistics_dump):
        os.remove(env.statistics_dump)

    print(
        f'******* Starting to estimate observation normalization statistics for {n_samples} steps *******'
    )
    # collect normalization statistics
    env.set_observation_collection(True)
    rollout_generator = RolloutGenerator(env)

    for _ in tqdm(range(n_samples)):
        rollout_generator.rollout(policy=env.sampling_policy, n_steps=1)

    # finally estimate normalization statistics
    env.estimate_statistics()
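
A minimal usage sketch; the factory is a placeholder and is assumed to return an env already wrapped with ObservationNormalizationWrapper, as the type hint above requires:

env = my_env_factory()  # hypothetical factory yielding an ObservationNormalizationWrapper-wrapped MazeEnv
estimate_observation_normalization_statistics(env, n_samples=1000)
# any previous statistics dump has been removed and fresh statistics have been estimated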
Example #11
    def setup(self, cfg: DictConfig) -> None:
        """
        See :py:meth:`~maze.train.trainers.common.training_runner.TrainingRunner.setup`.
        """

        super().setup(cfg)

        # initialize distributed env
        envs = self.create_distributed_env(self.env_factory, self.concurrency, logging_prefix="train")
        train_env_instance_seeds = [self.maze_seeding.generate_env_instance_seed() for _ in range(self.concurrency)]
        envs.seed(train_env_instance_seeds)

        # initialize actor critic model
        model = TorchActorCritic(
            policy=self._model_composer.policy,
            critic=self._model_composer.critic,
            device=cfg.algorithm.device)

        # initialize best model selection
        self._model_selection = BestModelSelection(dump_file=self.state_dict_dump_file, model=model,
                                                   dump_interval=self.dump_interval)

        # initialize the env and enable statistics collection
        evaluator = None
        if cfg.algorithm.rollout_evaluator.n_episodes > 0:
            eval_env = self.create_distributed_env(self.env_factory, self.eval_concurrency, logging_prefix="eval")
            eval_env_instance_seeds = [self.maze_seeding.generate_env_instance_seed()
                                       for _ in range(self.eval_concurrency)]
            eval_env.seed(eval_env_instance_seeds)

            # initialize rollout evaluator
            evaluator = Factory(base_type=RolloutEvaluator).instantiate(cfg.algorithm.rollout_evaluator,
                                                                        eval_env=eval_env,
                                                                        model_selection=self._model_selection)

        # look up model class
        trainer_class = Factory(base_type=ActorCritic).type_from_name(self.trainer_class)

        # initialize trainer (from input directory)
        self._trainer = trainer_class(
            algorithm_config=cfg.algorithm,
            rollout_generator=RolloutGenerator(env=envs),
            evaluator=evaluator,
            model=model,
            model_selection=self._model_selection
        )

        self._init_trainer_from_input_dir(trainer=self._trainer, state_dict_dump_file=self.state_dict_dump_file,
                                          input_dir=cfg.input_dir)
Example #12
    def _init_workers(self):
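        """Set up one in-process worker per env instance seed: each worker wraps a freshly created, seeded env
        in a RolloutGenerator that records next observations and episode statistics."""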

        self.broadcasting_container = BroadcastingContainer()
        self.current_worker_idx = 0

        self.workers = []
        self.policy_version_counter = 0

        for env_seed in self.env_instance_seeds:
            env = self.env_factory()
            env.seed(env_seed)
            actor = RolloutGenerator(env=env,
                                     record_next_observations=True,
                                     record_episode_stats=True)
            self.workers.append(actor)
Example #13
def test_terminates_on_done():
    """Resetting the env or terminating rollout early when the env is done."""
    env = build_dummy_maze_env()
    env = TimeLimitWrapper.wrap(env, max_episode_steps=5)
    policy = RandomPolicy(env.action_spaces_dict)

    # Normal operation (should reset the env automatically and continue rollout)
    rollout_generator = RolloutGenerator(env=env)
    trajectory = rollout_generator.rollout(policy, n_steps=10)
    assert len(trajectory) == 10

    # Terminate on done
    rollout_generator = RolloutGenerator(env=env, terminate_on_done=True)
    trajectory = rollout_generator.rollout(policy, n_steps=10)
    assert len(trajectory) == 5
Example #14
def train(n_epochs):
    # Instantiate one environment. This will be used for convenient access to observation
    # and action spaces.
    env = cartpole_env_factory()
    observation_space = env.observation_space
    action_space = env.action_space

    # Policy Setup
    # ------------

    # Policy Network
    # ^^^^^^^^^^^^^^
    # Instantiate policy with the correct shapes of observation and action spaces.
    policy_net = CartpolePolicyNet(
        obs_shapes={'observation': observation_space.spaces['observation'].shape},
        action_logit_shapes={'action': (action_space.spaces['action'].n,)})

    maze_wrapped_policy_net = TorchModelBlock(
        in_keys='observation', out_keys='action',
        in_shapes=observation_space.spaces['observation'].shape, in_num_dims=[2],
        out_num_dims=2, net=policy_net)

    policy_networks = {0: maze_wrapped_policy_net}

    # Policy Distribution
    # ^^^^^^^^^^^^^^^^^^^
    distribution_mapper = DistributionMapper(
        action_space=action_space,
        distribution_mapper_config={})

    # Optionally, you can specify a different distribution via the distribution_mapper_config argument. Using a
    # Categorical distribution for a discrete action space would be configured like this:
    distribution_mapper = DistributionMapper(
        action_space=action_space,
        distribution_mapper_config=[{
            "action_space": gym.spaces.Discrete,
            "distribution": "maze.distributions.categorical.CategoricalProbabilityDistribution"}])

    # Instantiating the Policy
    # ^^^^^^^^^^^^^^^^^^^^^^^^
    torch_policy = TorchPolicy(networks=policy_networks, distribution_mapper=distribution_mapper, device='cpu')

    # Value Function Setup
    # --------------------

    # Value Network
    # ^^^^^^^^^^^^^
    value_net = CartpoleValueNet(obs_shapes={'observation': observation_space.spaces['observation'].shape})

    maze_wrapped_value_net = TorchModelBlock(
        in_keys='observation', out_keys='value',
        in_shapes=observation_space.spaces['observation'].shape, in_num_dims=[2],
        out_num_dims=2, net=value_net)

    value_networks = {0: maze_wrapped_value_net}

    # Instantiate the Value Function
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    torch_critic = TorchSharedStateCritic(networks=value_networks, obs_spaces_dict=env.observation_spaces_dict,
                                          device='cpu', stack_observations=False)

    # Initializing the ActorCritic Model.
    # -----------------------------------
    actor_critic_model = TorchActorCritic(policy=torch_policy, critic=torch_critic, device='cpu')

    # Instantiating the Trainer
    # =========================

    algorithm_config = A2CAlgorithmConfig(
        n_epochs=n_epochs,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([cartpole_env_factory]),
            n_episodes=1,
            model_selection=None,
            deterministic=True
        )
    )

    # Distributed Environments
    # ------------------------
    # In order to use the distributed trainers, the previously created env factory is supplied to one of Maze's
    # distribution classes:
    train_envs = SequentialVectorEnv([cartpole_env_factory for _ in range(2)], logging_prefix="train")
    eval_envs = SequentialVectorEnv([cartpole_env_factory for _ in range(2)], logging_prefix="eval")

    # Initialize best model selection.
    model_selection = BestModelSelection(dump_file="params.pt", model=actor_critic_model)

    a2c_trainer = A2C(rollout_generator=RolloutGenerator(train_envs),
                      evaluator=algorithm_config.rollout_evaluator,
                      algorithm_config=algorithm_config,
                      model=actor_critic_model,
                      model_selection=model_selection)

    # Train the Agent
    # ===============
    # Before starting the training, we enable logging by calling setup_logging:
    log_dir = '.'
    setup_logging(job_config=None, log_dir=log_dir)

    # Now, we can train the agent.
    a2c_trainer.train()

    return 0
Example #15
def test_custom_model_composer_with_shared_embedding():
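    """Custom model composer whose policy and step-critic networks share an embedding."""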
    env = build_dummy_structured_env()

    policies = {
        "_target_":
        "maze.perception.models.policies.ProbabilisticPolicyComposer",
        "networks": [{
            "_target_":
            "maze.perception.models.built_in.flatten_concat_shared_embedding.FlattenConcatSharedEmbeddingPolicyNet",
            "non_lin": "torch.nn.SELU",
            "hidden_units": [16],
            "head_units": [16]
        }, {
            "_target_":
            "maze.perception.models.built_in.flatten_concat_shared_embedding.FlattenConcatSharedEmbeddingPolicyNet",
            "non_lin": "torch.nn.SELU",
            "hidden_units": [16],
            "head_units": [16]
        }],
        "substeps_with_separate_agent_nets": []
    }

    step_critic = {
        "_target_":
        "maze.perception.models.critics.StepStateCriticComposer",
        "networks": [{
            "_target_":
            "maze.perception.models.built_in.flatten_concat_shared_embedding.FlattenConcatSharedEmbeddingStateValueNet",
            "non_lin": "torch.nn.SELU",
            "head_units": [16]
        }, {
            "_target_":
            "maze.perception.models.built_in.flatten_concat_shared_embedding.FlattenConcatSharedEmbeddingStateValueNet",
            "non_lin": "torch.nn.SELU",
            "head_units": [16]
        }]
    }

    # check if model config is fine
    CustomModelComposer.check_model_config({"critic": step_critic})

    composer = CustomModelComposer(
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        agent_counts_dict=env.agent_counts_dict,
        distribution_mapper_config=[],
        policy=policies,
        critic=step_critic)

    assert isinstance(composer.distribution_mapper, DistributionMapper)
    assert isinstance(composer.critic, TorchStepStateCritic)
    assert isinstance(composer.critic.networks, dict)

    # test saving models
    composer.save_models()

    try:
        import pygraphviz

        for model_file in [
                "critic_0.pdf", "critic_1.pdf", "policy_0.pdf", "policy_1.pdf"
        ]:
            file_path = os.path.join(os.getcwd(), model_file)
            assert os.path.exists(file_path)
            os.remove(file_path)
    except ImportError:
        pass  # no output generated as pygraphviz is not installed.

    rollout_generator = RolloutGenerator(env=env,
                                         record_next_observations=False)
    policy = RandomPolicy(env.action_spaces_dict)
    trajectory = rollout_generator.rollout(
        policy, n_steps=10).stack().to_torch(device='cpu')

    policy_output = composer.policy.compute_policy_output(trajectory)
    critic_input = StateCriticInput.build(policy_output, trajectory)
    _ = composer.critic.predict_values(critic_input)
Example #16
def main(n_epochs: int) -> None:
    """Trains the cart pole environment with the multi-step a2c implementation.
    """

    # initialize distributed env
    envs = SequentialVectorEnv(
        [lambda: GymMazeEnv(env="CartPole-v0") for _ in range(8)],
        logging_prefix="train")

    # initialize the env and enable statistics collection
    eval_env = SequentialVectorEnv(
        [lambda: GymMazeEnv(env="CartPole-v0") for _ in range(8)],
        logging_prefix="eval")

    # init distribution mapper
    env = GymMazeEnv(env="CartPole-v0")

    # init default distribution mapper
    distribution_mapper = DistributionMapper(action_space=env.action_space,
                                             distribution_mapper_config={})

    # initialize policies
    policies = {
        0: PolicyNet({'observation': (4, )}, {'action': (2, )},
                     non_lin=nn.Tanh)
    }

    # initialize critic
    critics = {0: ValueNet({'observation': (4, )})}

    # initialize optimizer
    algorithm_config = A2CAlgorithmConfig(n_epochs=n_epochs,
                                          epoch_length=10,
                                          patience=10,
                                          critic_burn_in_epochs=0,
                                          n_rollout_steps=20,
                                          lr=0.0005,
                                          gamma=0.98,
                                          gae_lambda=1.0,
                                          policy_loss_coef=1.0,
                                          value_loss_coef=0.5,
                                          entropy_coef=0.0,
                                          max_grad_norm=0.0,
                                          device="cpu",
                                          rollout_evaluator=RolloutEvaluator(
                                              eval_env=eval_env,
                                              n_episodes=1,
                                              model_selection=None,
                                              deterministic=True))

    # initialize actor critic model
    model = TorchActorCritic(policy=TorchPolicy(
        networks=policies,
        distribution_mapper=distribution_mapper,
        device=algorithm_config.device),
                             critic=TorchSharedStateCritic(
                                 networks=critics,
                                 obs_spaces_dict=env.observation_spaces_dict,
                                 device=algorithm_config.device,
                                 stack_observations=False),
                             device=algorithm_config.device)

    a2c = A2C(rollout_generator=RolloutGenerator(envs),
              evaluator=algorithm_config.rollout_evaluator,
              algorithm_config=algorithm_config,
              model=model,
              model_selection=None)

    setup_logging(job_config=None)

    # train agent
    a2c.train()

    # final evaluation run
    print("Final Evaluation Run:")
    a2c.evaluate()
Example #17
def main(n_epochs: int, rnn_steps: int) -> None:
    """Trains the cart pole environment with the multi-step a2c implementation.
    """
    env_name = "CartPole-v0"

    # initialize distributed env
    envs = SequentialVectorEnv([
        lambda: to_rnn_dict_space_environment(env=env_name,
                                              rnn_steps=rnn_steps)
        for _ in range(4)
    ],
                               logging_prefix="train")

    # initialize the env and enable statistics collection
    eval_env = SequentialVectorEnv([
        lambda: to_rnn_dict_space_environment(env=env_name,
                                              rnn_steps=rnn_steps)
        for _ in range(4)
    ],
                                   logging_prefix="eval")

    # map observations to a modality
    obs_modalities_mappings = {"observation": "feature"}

    # define how to process a modality
    modality_config = dict()
    modality_config["feature"] = {
        "block_type": "maze.perception.blocks.DenseBlock",
        "block_params": {
            "hidden_units": [32, 32],
            "non_lin": "torch.nn.Tanh"
        }
    }
    modality_config["hidden"] = {
        "block_type": "maze.perception.blocks.DenseBlock",
        "block_params": {
            "hidden_units": [64],
            "non_lin": "torch.nn.Tanh"
        }
    }
    modality_config["recurrence"] = {}
    if rnn_steps > 0:
        modality_config["recurrence"] = {
            "block_type": "maze.perception.blocks.LSTMLastStepBlock",
            "block_params": {
                "hidden_size": 8,
                "num_layers": 1,
                "bidirectional": False,
                "non_lin": "torch.nn.Tanh"
            }
        }

    template_builder = TemplateModelComposer(
        action_spaces_dict=envs.action_spaces_dict,
        observation_spaces_dict=envs.observation_spaces_dict,
        agent_counts_dict=envs.agent_counts_dict,
        distribution_mapper_config={},
        model_builder=ConcatModelBuilder(modality_config,
                                         obs_modalities_mappings, None),
        policy={
            '_target_':
            'maze.perception.models.policies.ProbabilisticPolicyComposer'
        },
        critic={
            '_target_': 'maze.perception.models.critics.StateCriticComposer'
        })

    algorithm_config = A2CAlgorithmConfig(n_epochs=n_epochs,
                                          epoch_length=10,
                                          patience=10,
                                          critic_burn_in_epochs=0,
                                          n_rollout_steps=20,
                                          lr=0.0005,
                                          gamma=0.98,
                                          gae_lambda=1.0,
                                          policy_loss_coef=1.0,
                                          value_loss_coef=0.5,
                                          entropy_coef=0.0,
                                          max_grad_norm=0.0,
                                          device="cpu",
                                          rollout_evaluator=RolloutEvaluator(
                                              eval_env=eval_env,
                                              n_episodes=1,
                                              model_selection=None,
                                              deterministic=True))

    model = TorchActorCritic(policy=TorchPolicy(
        networks=template_builder.policy.networks,
        distribution_mapper=template_builder.distribution_mapper,
        device=algorithm_config.device),
                             critic=template_builder.critic,
                             device=algorithm_config.device)

    a2c = A2C(rollout_generator=RolloutGenerator(envs),
              evaluator=algorithm_config.rollout_evaluator,
              algorithm_config=algorithm_config,
              model=model,
              model_selection=None)

    setup_logging(job_config=None)

    # train agent
    a2c.train()

    # final evaluation run
    print("Final Evaluation Run:")
    a2c.evaluate()
Example #18
def build_structured_with_critic_type(
    env, critics_composer_type: type(BaseStateCriticComposer),
    critics_type: type(TorchStateCritic),
    shared_embedding_keys: Optional[Union[List[str], Dict[StepKeyType,
                                                          List[str]]]]):
    """ helper function """

    # map observations to a modality
    obs_modalities = {
        "observation_0": "image",
        "observation_1": "feature",
        DeltaStateCriticComposer.prev_value_key: 'feature'
    }

    # define how to process a modality
    modality_config = dict()
    modality_config["feature"] = {
        "block_type": "maze.perception.blocks.DenseBlock",
        "block_params": {
            "hidden_units": [32, 32],
            "non_lin": "torch.nn.ReLU"
        }
    }
    modality_config['image'] = {
        'block_type': 'maze.perception.blocks.StridedConvolutionDenseBlock',
        'block_params': {
            'hidden_channels': [8, 16, 32],
            'hidden_kernels': [8, 4, 4],
            'convolution_dimension': 2,
            'hidden_strides': [4, 2, 2],
            'hidden_dilations': None,
            'hidden_padding': [1, 1, 1],
            'padding_mode': None,
            'hidden_units': [],
            'non_lin': 'torch.nn.SELU'
        }
    }

    modality_config["hidden"] = {
        "block_type": "maze.perception.blocks.DenseBlock",
        "block_params": {
            "hidden_units": [64],
            "non_lin": "torch.nn.ReLU"
        }
    }
    modality_config["recurrence"] = {}

    model_builder = {
        '_target_': 'maze.perception.builders.concat.ConcatModelBuilder',
        'modality_config': modality_config,
        'observation_modality_mapping': obs_modalities,
        'shared_embedding_keys': shared_embedding_keys
    }

    # initialize default model builder
    default_builder = TemplateModelComposer(
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        agent_counts_dict=env.agent_counts_dict,
        distribution_mapper_config={},
        model_builder=model_builder,
        policy={'_target_': policy_composer_type},
        critic={'_target_': critics_composer_type})

    # create model pdf
    default_builder.save_models()

    assert isinstance(default_builder.distribution_mapper, DistributionMapper)
    for pp in default_builder.policy.networks.values():
        assert isinstance(pp, nn.Module)
    for cc in default_builder.critic.networks.values():
        assert isinstance(cc, nn.Module)

    assert isinstance(default_builder.critic, critics_type)

    rollout_generator = RolloutGenerator(env=env,
                                         record_next_observations=False)
    policy = RandomPolicy(env.action_spaces_dict)
    trajectory = rollout_generator.rollout(
        policy, n_steps=10).stack().to_torch(device='cpu')

    policy_output = default_builder.policy.compute_policy_output(trajectory)
    critic_input = StateCriticInput.build(policy_output, trajectory)
    _ = default_builder.critic.predict_values(critic_input)
Example #19
def build_single_step_with_critic_type(
    critics_composer_type: type(BaseStateCriticComposer),
    critics_type: type(TorchStateCritic),
    shared_embedding_keys: Optional[Union[List[str], Dict[StepKeyType,
                                                          List[str]]]]):
    """ helper function """
    # init environment
    env = GymMazeEnv('CartPole-v0')
    observation_space = env.observation_space
    action_space = env.action_space

    # map observations to a modality
    obs_modalities = {
        obs_key: "feature"
        for obs_key in observation_space.spaces.keys()
    }
    # define how to process a modality
    modality_config = dict()
    modality_config["feature"] = {
        "block_type": "maze.perception.blocks.DenseBlock",
        "block_params": {
            "hidden_units": [32, 32],
            "non_lin": "torch.nn.ReLU"
        }
    }
    modality_config["hidden"] = {
        "block_type": "maze.perception.blocks.DenseBlock",
        "block_params": {
            "hidden_units": [64],
            "non_lin": "torch.nn.ReLU"
        }
    }
    modality_config["recurrence"] = {}

    model_builder = {
        '_target_': 'maze.perception.builders.concat.ConcatModelBuilder',
        'modality_config': modality_config,
        'observation_modality_mapping': obs_modalities,
        'shared_embedding_keys': shared_embedding_keys
    }

    # initialize default model builder
    default_builder = TemplateModelComposer(
        action_spaces_dict={0: action_space},
        observation_spaces_dict={0: observation_space},
        agent_counts_dict={0: 1},
        distribution_mapper_config={},
        model_builder=model_builder,
        policy={'_target_': policy_composer_type},
        critic={'_target_': critics_composer_type})

    # create model pdf
    default_builder.save_models()

    assert isinstance(default_builder.distribution_mapper, DistributionMapper)
    assert isinstance(default_builder.policy.networks[0], nn.Module)
    assert isinstance(default_builder.critic.networks[0], nn.Module)
    assert isinstance(default_builder.critic, critics_type)

    # test default policy gradient actor
    policy_net = default_builder.policy.networks[0]
    assert isinstance(policy_net, InferenceBlock)

    assert "action" in policy_net.out_keys
    assert policy_net.out_shapes()[0] == (2, )

    # test standalone critic
    value_net = default_builder.critic.networks[0]
    assert isinstance(value_net, InferenceBlock)
    assert "value" in value_net.out_keys
    assert value_net.out_shapes()[0] == (1, )

    if shared_embedding_keys is not None:
        if isinstance(shared_embedding_keys, list):
            assert all([
                shared_key in policy_net.out_keys
                for shared_key in shared_embedding_keys
            ])
            assert all([
                shared_key in value_net.in_keys
                for shared_key in shared_embedding_keys
            ])
        else:
            assert all([
                shared_key in policy_net.out_keys
                for shared_keylist in shared_embedding_keys.values()
                for shared_key in shared_keylist
            ])
            assert all([
                shared_key in value_net.in_keys
                for shared_keylist in shared_embedding_keys.values()
                for shared_key in shared_keylist
            ])
    else:
        assert value_net.in_keys == policy_net.in_keys

    rollout_generator = RolloutGenerator(env=env,
                                         record_next_observations=False)
    policy = RandomPolicy(env.action_spaces_dict)
    trajectory = rollout_generator.rollout(
        policy, n_steps=10).stack().to_torch(device='cpu')

    policy_output = default_builder.policy.compute_policy_output(trajectory)
    critic_input = StateCriticInput.build(policy_output, trajectory)
    _ = default_builder.critic.predict_values(critic_input)
Example #20
def test_redistributes_actor_reward_if_available():
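    """Rollouts on an env with a structured core env should record the individual actor rewards."""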
    env = build_dummy_maze_env_with_structured_core_env()
    rollout_generator = RolloutGenerator(env=env)
    policy = RandomPolicy(env.action_spaces_dict)
    trajectory = rollout_generator.rollout(policy, n_steps=1)
    assert np.all(trajectory.step_records[0].rewards == [1, 1])
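
Taken together, the examples above follow one pattern: wrap a (possibly vectorized) structured env in a RolloutGenerator, enable the recording flags you need, and call rollout() with a policy and a step budget. A condensed sketch using the dummy helpers from the tests above (this particular flag combination is arbitrary):

env = build_dummy_structured_env()
rollout_generator = RolloutGenerator(env=env,
                                     record_next_observations=True,  # fill next_observations_dict
                                     record_step_stats=True,         # fill step_stats
                                     record_episode_stats=True,      # fill episode_stats on done-records
                                     terminate_on_done=False)        # reset automatically and keep rolling out
policy = RandomPolicy(env.action_spaces_dict)
trajectory = rollout_generator.rollout(policy, n_steps=10)

# Training code typically stacks the records and converts them to torch tensors:
batch = trajectory.stack().to_torch(device='cpu')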