Example #1
def _run_job(cfg: DictConfig) -> None:
    """Runs a regular maze job.

    :param cfg: Hydra configuration for the rollout.
    """
    set_matplotlib_backend()

    # If no env or agent base seed is given, generate the seeds randomly and add them to the resolved Hydra config.
    if cfg.seeding.env_base_seed is None:
        cfg.seeding.env_base_seed = MazeSeeding.generate_seed_from_random_state(
            np.random.RandomState(None))
    if cfg.seeding.agent_base_seed is None:
        cfg.seeding.agent_base_seed = MazeSeeding.generate_seed_from_random_state(
            np.random.RandomState(None))

    # print and log config
    config_str = yaml.dump(OmegaConf.to_container(cfg, resolve=True),
                           sort_keys=False)
    with open("hydra_config.yaml", "w") as fp:
        fp.write("\n" + config_str)
    BColors.print_colored(config_str, color=BColors.HEADER)
    print("Output directory: {}\n".format(os.path.abspath(".")))

    # run job
    runner = Factory(base_type=Runner).instantiate(cfg.runner)
    runner.setup(cfg)
    runner.run()
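For context on what `_run_job` expects from its Hydra config, here is a minimal sketch that builds only the `seeding` group the function reads and fills in missing base seeds the same way it does. The MazeSeeding import path is an assumption, and a real run config would of course also carry the env, wrappers and runner groups.

from omegaconf import OmegaConf
import numpy as np
from maze.core.utils.seeding import MazeSeeding  # import path assumed

# Only the seeding group read by _run_job; env/wrappers/runner groups are omitted here.
cfg = OmegaConf.create({"seeding": {"env_base_seed": None,
                                    "agent_base_seed": None,
                                    "cudnn_determinism_flag": False}})

# Mirror the None-handling at the top of _run_job: draw any missing base seed at random.
if cfg.seeding.env_base_seed is None:
    cfg.seeding.env_base_seed = MazeSeeding.generate_seed_from_random_state(np.random.RandomState(None))
if cfg.seeding.agent_base_seed is None:
    cfg.seeding.agent_base_seed = MazeSeeding.generate_seed_from_random_state(np.random.RandomState(None))

print(OmegaConf.to_yaml(cfg))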
Example #2
def test_rollouts_from_python():
    env, agent = GymMazeEnv("CartPole-v0"), DummyCartPolePolicy()

    sequential = SequentialRolloutRunner(n_episodes=2,
                                         max_episode_steps=2,
                                         record_trajectory=False,
                                         record_event_logs=False,
                                         render=False)
    sequential.maze_seeding = MazeSeeding(1234, 4321, False)
    sequential.run_with(env=env, wrappers={}, agent=agent)

    parallel = ParallelRolloutRunner(n_episodes=2,
                                     max_episode_steps=2,
                                     record_trajectory=False,
                                     record_event_logs=False,
                                     n_processes=2)
    parallel.maze_seeding = MazeSeeding(1234, 4321, False)
    # Test with a wrapper config as well
    parallel.run_with(env=env,
                      wrappers={
                          MazeEnvMonitoringWrapper: {
                              "observation_logging": True,
                              "action_logging": False,
                              "reward_logging": False
                          }
                      },
                      agent=agent)
Example #3
def main(n_epochs: int) -> None:
    """Trains the cart pole environment with the ES implementation.

    :param n_epochs: Number of epochs to train.
    """

    env = GymMazeEnv(env="CartPole-v0")
    distribution_mapper = DistributionMapper(action_space=env.action_space,
                                             distribution_mapper_config={})

    obs_shapes = observation_spaces_to_in_shapes(env.observation_spaces_dict)
    action_shapes = {
        step_key: {
            action_head: distribution_mapper.required_logits_shape(action_head)
            for action_head in env.action_spaces_dict[step_key].spaces.keys()
        }
        for step_key in env.action_spaces_dict.keys()
    }

    # initialize policies
    policies = [
        PolicyNet(obs_shapes=obs_shapes[0],
                  action_logits_shapes=action_shapes[0],
                  non_lin=nn.SELU)
    ]

    # initialize optimizer
    policy = TorchPolicy(networks=list_to_dict(policies),
                         distribution_mapper=distribution_mapper,
                         device="cpu")

    shared_noise = SharedNoiseTable(count=1_000_000)

    algorithm_config = ESAlgorithmConfig(n_rollouts_per_update=100,
                                         n_timesteps_per_update=0,
                                         max_steps=0,
                                         optimizer=Adam(step_size=0.01),
                                         l2_penalty=0.005,
                                         noise_stddev=0.02,
                                         n_epochs=n_epochs,
                                         policy_wrapper=None)

    trainer = ESTrainer(algorithm_config=algorithm_config,
                        torch_policy=policy,
                        shared_noise=shared_noise,
                        normalization_stats=None)

    setup_logging(job_config=None)

    maze_rng = np.random.RandomState(None)

    # run with pseudo-distribution, without worker processes
    trainer.train(ESDummyDistributedRollouts(
        env=env,
        n_eval_rollouts=10,
        shared_noise=shared_noise,
        agent_instance_seed=MazeSeeding.generate_seed_from_random_state(
            maze_rng)),
                  model_selection=None)
Example #4
def test_sequential_rollout_with_rendering():
    env, agent = GymMazeEnv("CartPole-v0"), DummyCartPolePolicy()
    sequential = SequentialRolloutRunner(n_episodes=2,
                                         max_episode_steps=2,
                                         record_trajectory=True,
                                         record_event_logs=False,
                                         render=True)
    sequential.maze_seeding = MazeSeeding(1234, 4321, False)
    sequential.run_with(env=env, wrappers={}, agent=agent)
Example #5
    def seed(self, seed: int) -> None:
        """Apply seed to wrappers rng, and pass the seed forward to the env
        """
        # Create new random state for sampling the random steps
        self.wrapper_rng = np.random.RandomState(seed)
        # Set seed of action space for sampling actions
        self.action_space.seed(
            MazeSeeding.generate_seed_from_random_state(self.wrapper_rng))

        return self.env.seed(seed)
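The wrapper above illustrates the seed-derivation pattern that runs through all of these examples: a single incoming seed initializes a RandomState, and each sub-component receives its own child seed drawn from it, so seeding stays reproducible end to end. A standalone sketch of just that pattern (the MazeSeeding import path and the Gym space are assumptions for illustration):

import numpy as np
from gym import spaces
from maze.core.utils.seeding import MazeSeeding  # import path assumed

base_seed = 1234
rng = np.random.RandomState(base_seed)

# Every component gets an independent, reproducible child seed derived from the same rng.
action_space = spaces.Discrete(2)
action_space.seed(MazeSeeding.generate_seed_from_random_state(rng))
downstream_env_seed = MazeSeeding.generate_seed_from_random_state(rng)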
Example #6
    def setup(self, cfg: DictConfig) -> None:
        """
        Sets up prerequisites to rollouts.
        :param cfg: DictConfig defining components to initialize.
        """

        self._cfg = cfg
        self.input_dir = cfg.input_dir

        # Generate a random state used for sampling random seeds for the envs and agents
        self.maze_seeding = MazeSeeding(cfg.seeding.env_base_seed,
                                        cfg.seeding.agent_base_seed,
                                        cfg.seeding.cudnn_determinism_flag)
Example #7
    def __init__(self, n_episodes: int, max_episode_steps: int,
                 record_trajectory: bool, record_event_logs: bool):
        self.n_episodes = n_episodes
        self.max_episode_steps = max_episode_steps
        self.record_trajectory = record_trajectory
        self.record_event_logs = record_event_logs
        self._cfg: Optional[DictConfig] = None

        # keep track of the input directory
        self.input_dir = None

        # Generate a random state used for sampling random seeds for the envs and agents
        self.maze_seeding = MazeSeeding(
            np.random.randint(np.iinfo(np.int32).max),
            np.random.randint(np.iinfo(np.int32).max), False)
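When base seeds are passed in explicitly instead of drawn at random (as in Example #6 and the training runners further down), the resulting MazeSeeding object hands out per-instance env seeds and a global agent seed. A hedged sketch of that flow, reusing only calls that appear in these examples and assuming the same import path as above:

from maze.core.utils.seeding import MazeSeeding  # import path assumed

# Same positional signature as in Examples #2, #6 and #7: env base seed, agent base seed, cudnn determinism flag.
maze_seeding = MazeSeeding(1234, 4321, False)

# One env seed per instance, e.g. per rollout worker (cf. generate_env_instance_seed in Example #10).
env_instance_seeds = [maze_seeding.generate_env_instance_seed() for _ in range(4)]

# Global agent seed applied before model initialization (cf. set_seeds_globally in Examples #10/#11).
agent_seed = maze_seeding.agent_global_seed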
Example #8
def perform_seeding_test(env: MazeEnv, policy: Policy, is_deterministic_env: bool, is_deterministic_agent: bool,
                         n_steps: int = 100) \
        -> None:
    """Perform a test on the seeding capabilities of a given env and agent.
        Within this method a rollout is generated with sampled seeds where the observation and actions are recorded and
        hashed. Then a second rollouts is generated with the SAME seeds to check if the results stay the same. Finally
        the seeds are changes and it is checks if the resulting observations and actions change as well to ensure
        randomness in the env.

    :param env: The Env to test.
    :param policy: The policy to compute the actions with.
    :param is_deterministic_env: Specify whether the given env is deterministic.
    :param is_deterministic_agent: Specify whether the given policy is deterministic.
    :param n_steps: Number of steps to perform the comparison on.
    """

    maze_rng = np.random.RandomState(1234)

    agent_seed = MazeSeeding.generate_seed_from_random_state(maze_rng)
    env_seed = MazeSeeding.generate_seed_from_random_state(maze_rng)

    # Perform a rollout and get hash values to compare other runs to.
    base_obs_hash, base_action_hash = get_obs_action_hash_for_env_agent(
        env, policy, env_seed, agent_seed, n_steps)

    # Perform a second rollout with the same seeds, and check the results are exactly the same.
    second_obs_hash, second_action_hash = get_obs_action_hash_for_env_agent(
        env, policy, env_seed, agent_seed, n_steps)

    assert base_obs_hash == second_obs_hash
    assert base_action_hash == second_action_hash

    # Change the agent seed and check that the values change if the agent is not deterministic,
    # and stay the same if it is.
    agent_seed_2 = MazeSeeding.generate_seed_from_random_state(maze_rng)
    agent_2_obs_hash, agent_2_action_hash = get_obs_action_hash_for_env_agent(
        env, policy, env_seed, agent_seed_2, n_steps)
    if is_deterministic_agent:
        assert base_obs_hash == agent_2_obs_hash
        assert base_action_hash == agent_2_action_hash
    else:
        assert base_action_hash != agent_2_action_hash

    # Change the env seed and check that the values change if the env is not deterministic, and stay the same if it is.
    env_seed_2 = MazeSeeding.generate_seed_from_random_state(maze_rng)
    env_2_obs_hash, env_2_action_hash = get_obs_action_hash_for_env_agent(
        env, policy, env_seed_2, agent_seed, n_steps)
    if is_deterministic_env:
        assert base_obs_hash == env_2_obs_hash
        assert base_action_hash == env_2_action_hash
    else:
        assert base_obs_hash != env_2_obs_hash

    env_agent_2_obs_hash, env_agent_2_action_hash = get_obs_action_hash_for_env_agent(
        env, policy, env_seed_2, agent_seed_2, n_steps)
    if is_deterministic_env and is_deterministic_agent:
        assert base_obs_hash == env_agent_2_obs_hash
        assert base_action_hash == env_agent_2_action_hash
    elif is_deterministic_env:
        assert base_action_hash != env_agent_2_action_hash
    elif is_deterministic_agent:
        assert base_obs_hash != env_agent_2_obs_hash
    else:
        assert base_action_hash != env_agent_2_action_hash
        assert base_obs_hash != env_agent_2_obs_hash
Example #9
    def seed(self, seed: int) -> None:
        """Seed the policy by setting the action space seeds."""
        rng = np.random.RandomState(seed)
        for action_space in self.action_spaces_dict.values():
            action_space.seed(MazeSeeding.generate_seed_from_random_state(rng))
Example #10
    def setup(self, cfg: DictConfig) -> None:
        """
        Sets up prerequisites to training.
        Includes wrapping the environment for observation normalization, instantiating the model composer etc.
        :param cfg: DictConfig defining components to initialize.
        """

        self._cfg = cfg

        # Generate a random state used for sampling random seeds for the envs and agents
        self.maze_seeding = MazeSeeding(cfg.seeding.env_base_seed,
                                        cfg.seeding.agent_base_seed,
                                        cfg.seeding.cudnn_determinism_flag)

        with SwitchWorkingDirectoryToInput(cfg.input_dir):
            assert isinstance(cfg.env, DictConfig) or isinstance(
                cfg.env, Callable)
            wrapper_cfg = omegaconf.OmegaConf.to_object(
                cfg["wrappers"]) if "wrappers" in cfg else {}

            # if the observation normalization is already available, read it from the input directory
            if isinstance(cfg.env, DictConfig):
                self.env_factory = EnvFactory(
                    omegaconf.OmegaConf.to_object(cfg["env"]), wrapper_cfg)
            elif isinstance(cfg.env, Callable):
                env_fn = omegaconf.OmegaConf.to_container(cfg)["env"]
                self.env_factory = lambda: WrapperFactory.wrap_from_config(
                    env_fn(), wrapper_cfg)

            normalization_env = self.env_factory()
            normalization_env.seed(
                self.maze_seeding.generate_env_instance_seed())

        # Observation normalization
        self._normalization_statistics = obtain_normalization_statistics(
            normalization_env, n_samples=self.normalization_samples)
        if self._normalization_statistics:
            self.env_factory = make_normalized_env_factory(
                self.env_factory, self._normalization_statistics)
            # dump statistics to current working directory
            assert isinstance(normalization_env,
                              ObservationNormalizationWrapper)
            normalization_env.dump_statistics()

        # Generate an agent seed and set the seed globally for the model initialization
        set_seeds_globally(self.maze_seeding.agent_global_seed,
                           self.maze_seeding.cudnn_determinism_flag,
                           info_txt=f'training runner (Pid:{os.getpid()})')

        # init model composer
        composer_type = Factory(base_type=BaseModelComposer).type_from_name(
            cfg.model['_target_'])
        composer_type.check_model_config(cfg.model)

        # todo Factory.instantiate returns specified dicts as DictConfig, i.e. many specified types are wrong. How do we
        #  go about this? DictConfig behaves similarly to Dict for all intents and purposes, but typing is still off/
        #  misleading. This is independent from our Python training API and can apparently not be changed, i.e. kwargs
        #  seems to be always converted to DictConfig/ListConfig.
        self._model_composer = Factory(
            base_type=BaseModelComposer).instantiate(
                cfg.model,
                action_spaces_dict=normalization_env.action_spaces_dict,
                observation_spaces_dict=normalization_env.
                observation_spaces_dict,
                agent_counts_dict=normalization_env.agent_counts_dict)

        SpacesConfig(self._model_composer.action_spaces_dict,
                     self._model_composer.observation_spaces_dict,
                     self._model_composer.agent_counts_dict).save(
                         self.spaces_config_dump_file)

        # Should be done after the normalization runs, otherwise stats from those will get logged as well.
        setup_logging(job_config=cfg)

        # close normalization env
        normalization_env.close()
Example #11
class TrainingRunner(Runner):
    """
    Base class for training runner implementations.
    """

    state_dict_dump_file: str
    """Where to save the best model (output directory handled by hydra)."""
    dump_interval: Optional[int]
    """If provided the state dict will be dumped ever 'dump_interval' epochs."""
    spaces_config_dump_file: str
    """Where to save the env spaces configuration (output directory handled by hydra)."""
    normalization_samples: int
    """Number of samples (=steps) to collect normalization statistics at the beginning of the
    training."""

    env_factory: Optional[Union[EnvFactory, Callable[[], Union[
        StructuredEnv, StructuredEnvSpacesMixin,
        ObservationNormalizationWrapper]]]] = dataclasses.field(default=None,
                                                                init=False)
    _model_composer: Optional[BaseModelComposer] = dataclasses.field(
        default=None, init=False)
    _model_selection: Optional[BestModelSelection] = dataclasses.field(
        default=None, init=False)
    _normalization_statistics: Optional[
        StructuredStatisticsType] = dataclasses.field(default=None, init=False)
    _trainer: Optional[Trainer] = dataclasses.field(default=None, init=False)
    _cfg: Optional[DictConfig] = dataclasses.field(default=None, init=False)

    def setup(self, cfg: DictConfig) -> None:
        """
        Sets up prerequisites to training.
        Includes wrapping the environment for observation normalization, instantiating the model composer etc.
        :param cfg: DictConfig defining components to initialize.
        """

        self._cfg = cfg

        # Generate a random state used for sampling random seeds for the envs and agents
        self.maze_seeding = MazeSeeding(cfg.seeding.env_base_seed,
                                        cfg.seeding.agent_base_seed,
                                        cfg.seeding.cudnn_determinism_flag)

        with SwitchWorkingDirectoryToInput(cfg.input_dir):
            assert isinstance(cfg.env, DictConfig) or isinstance(
                cfg.env, Callable)
            wrapper_cfg = omegaconf.OmegaConf.to_object(
                cfg["wrappers"]) if "wrappers" in cfg else {}

            # if the observation normalization is already available, read it from the input directory
            if isinstance(cfg.env, DictConfig):
                self.env_factory = EnvFactory(
                    omegaconf.OmegaConf.to_object(cfg["env"]), wrapper_cfg)
            elif isinstance(cfg.env, Callable):
                env_fn = omegaconf.OmegaConf.to_container(cfg)["env"]
                self.env_factory = lambda: WrapperFactory.wrap_from_config(
                    env_fn(), wrapper_cfg)

            normalization_env = self.env_factory()
            normalization_env.seed(
                self.maze_seeding.generate_env_instance_seed())

        # Observation normalization
        self._normalization_statistics = obtain_normalization_statistics(
            normalization_env, n_samples=self.normalization_samples)
        if self._normalization_statistics:
            self.env_factory = make_normalized_env_factory(
                self.env_factory, self._normalization_statistics)
            # dump statistics to current working directory
            assert isinstance(normalization_env,
                              ObservationNormalizationWrapper)
            normalization_env.dump_statistics()

        # Generate an agent seed and set the seed globally for the model initialization
        set_seeds_globally(self.maze_seeding.agent_global_seed,
                           self.maze_seeding.cudnn_determinism_flag,
                           info_txt=f'training runner (Pid:{os.getpid()})')

        # init model composer
        composer_type = Factory(base_type=BaseModelComposer).type_from_name(
            cfg.model['_target_'])
        composer_type.check_model_config(cfg.model)

        # todo Factory.instantiate returns specified dicts as DictConfig, i.e. many specified types are wrong. How do we
        #  go about this? DictConfig behaves similarly to Dict for all intents and purposes, but typing is still off/
        #  misleading. This is independent from our Python training API and can apparently not be changed, i.e. kwargs
        #  seems to be always converted to DictConfig/ListConfig.
        self._model_composer = Factory(
            base_type=BaseModelComposer).instantiate(
                cfg.model,
                action_spaces_dict=normalization_env.action_spaces_dict,
                observation_spaces_dict=normalization_env.
                observation_spaces_dict,
                agent_counts_dict=normalization_env.agent_counts_dict)

        SpacesConfig(self._model_composer.action_spaces_dict,
                     self._model_composer.observation_spaces_dict,
                     self._model_composer.agent_counts_dict).save(
                         self.spaces_config_dump_file)

        # Should be done after the normalization runs, otherwise stats from those will get logged as well.
        setup_logging(job_config=cfg)

        # close normalization env
        normalization_env.close()

    def run(self, n_epochs: Optional[int] = None, **train_kwargs) -> None:
        """
        Runs training.
        While this method is designed to be overridden by individual subclasses, it provides some functionality
        that is useful in general:

        - Building the env factory for env + wrappers
        - Estimating normalization statistics from the env
        - If successfully estimated, wrapping the env factory so that envs are already built with the statistics
        - Building the model composer from model config and env spaces config
        - Serializing the env spaces configuration (so that the model composer can be re-loaded for future rollout)
        - Initializing logging setup

        :param n_epochs: Number of epochs to train.
        :param train_kwargs: Additional arguments for trainer.train().
        """

        self._trainer.train(n_epochs=self._cfg.algorithm.n_epochs
                            if n_epochs is None else n_epochs,
                            **train_kwargs)

    @classmethod
    def _init_trainer_from_input_dir(cls, trainer: Trainer,
                                     state_dict_dump_file: str,
                                     input_dir: str) -> None:
        """Initialize trainer from given state dict and input directory.
        :param trainer: The trainer to initialize.
        :param state_dict_dump_file: The state dict dump file relative to input_dir.
        :param input_dir: The directory to load the state dict from.
        """
        with SwitchWorkingDirectoryToInput(input_dir):
            if os.path.exists(state_dict_dump_file):
                BColors.print_colored(
                    f"Trainer and model initialized from '{state_dict_dump_file}' of run '{input_dir}'!",
                    BColors.OKGREEN)
                trainer.load_state(state_dict_dump_file)
            else:
                BColors.print_colored(
                    "Model initialized with random weights! ", BColors.OKGREEN)

    @property
    def model_composer(self) -> BaseModelComposer:
        """
        Returns model composer.
        :return: Model composer.
        """

        return self._model_composer

    @property
    def cfg(self) -> DictConfig:
        """
        Returns Hydra config.
        :return: Hydra config.
        """

        return self._cfg

    @property
    def trainer(self) -> Trainer:
        """Returns the runners training instance.
        :return: Instantiated trainer.
        """
        return self._trainer
Example #12
    def setup(self, cfg: DictConfig):
        """
        This method initializes and registers all necessary maze components with RLlib

        :param cfg: Full Hydra run job config
        """
        # Generate a random state used for sampling random seeds for the envs and agents
        self.maze_seeding = MazeSeeding(cfg.seeding.env_base_seed,
                                        cfg.seeding.agent_base_seed,
                                        cfg.seeding.cudnn_determinism_flag)

        self._cfg = cfg

        # Initialize env factory (with rllib monkey patches)
        self.env_factory = build_maze_rllib_env_factory(cfg)

        # Register maze env factory with rllib
        tune.register_env("maze_env", lambda x: self.env_factory())

        # Register maze model and distribution mapper if a maze model should be used
        # Check whether we are using the rllib default model composer or a maze model
        using_rllib_model_composer = '_target_' not in cfg.model.keys()
        if not using_rllib_model_composer:
            # Get model class
            model_cls = Factory(MazeRLlibBaseModel).type_from_name(
                cfg.algorithm.model_cls)
            # Register maze model
            ModelCatalog.register_custom_model("maze_model", model_cls)

            if 'policy' in cfg.model and "networks" in cfg.model.policy:
                assert len(cfg.model.policy.networks
                           ) == 1, 'Hierarchical envs are not yet supported'

            # register maze action distribution
            ModelCatalog.register_custom_action_dist(
                'maze_dist', MazeRLlibActionDistribution)
            model_config = {
                "custom_action_dist": 'maze_dist',
                "custom_model": "maze_model",
                "vf_share_layers": False,
                "custom_model_config": {
                    "maze_model_composer_config": cfg.model,
                    'spaces_config_dump_file': self.spaces_config_dump_file,
                    'state_dict_dump_file': self.state_dict_dump_file
                }
            }
        else:
            # If specified use the default rllib model builder
            model_config = OmegaConf.to_container(cfg.model, resolve=True)

        # Build rllib config
        maze_rllib_config = {
            "env": "maze_env",
            # Store env config for possible later use
            "env_config": {
                'env': cfg.env,
                'wrappers': cfg.wrappers
            },
            "model": model_config,
            'callbacks': MazeRLlibLoggingCallbacks,
            "framework": "torch"
        }
        # Load the algorithm config and update the custom parameters
        rllib_config: Dict = OmegaConf.to_container(cfg.algorithm.config,
                                                    resolve=True)
        assert 'model' not in rllib_config, 'The config should be removed from the default yaml files since it will ' \
                                            'be dynamically written'
        assert self.num_workers == rllib_config['num_workers']
        rllib_config.update(maze_rllib_config)

        if rllib_config['seed'] is None:
            rllib_config[
                'seed'] = self.maze_seeding.generate_env_instance_seed()

        # Initialize ray with the passed ray_config parameters
        ray_config: Dict = OmegaConf.to_container(self.ray_config,
                                                  resolve=True)

        # Load tune parameters
        tune_config = OmegaConf.to_container(self.tune_params, resolve=True)
        tune_config['callbacks'] = [MazeRLlibSaveModelCallback()]

        # Start tune experiment
        assert 'config' not in tune_config, 'The config should be removed from the default yaml files since it will ' \
                                            'be dynamically written'

        self.ray_config = ray_config
        self.rllib_config = rllib_config
        self.tune_config = tune_config
Example #13
class MazeRLlibRunner(Runner):
    """Base class for rllib runner

    :param spaces_config_dump_file: The path to the spaces config dump file.
    :param normalization_samples: The number of normalization samples that should be computed in order to estimate
        observation statistics.
    :param num_workers: The number of workers that should be used.
    :param tune_config: The tune config, used as arguments when starting tune.run().
    :param ray_config: The ray config, used as arguments when initializing ray (ray.init()).
    :param state_dict_dump_file: The path to the state dict dump file.
    """

    normalization_samples: int
    """Number of samples (=steps) to collect normalization statistics at the beginning of the training."""

    spaces_config_dump_file: str
    """Where to save the env spaces configuration (output directory handled by hydra)"""

    num_workers: int
    """Specifies the number of workers for the ray rllib distribution"""

    tune_params: Dict[str, Any]
    """Specify the parameters for the ray.tune method"""

    ray_config: Dict[str, Any]
    """Specify the parameters for the ray.init method"""

    state_dict_dump_file: str
    """Where to save the best model (output directory handled by hydra)"""
    def __init__(self, spaces_config_dump_file: str,
                 normalization_samples: int, num_workers: int,
                 tune_config: Dict[str, Any], ray_config: Dict[str, Any],
                 state_dict_dump_file: str):
        self.spaces_config_dump_file = spaces_config_dump_file
        self.normalization_samples = normalization_samples
        self.num_workers = num_workers
        self.ray_config = ray_config
        self.tune_params = tune_config
        self.state_dict_dump_file = state_dict_dump_file

        self.env_factory = None
        self.model_composer = None
        self.normalization_statistics = None

        self._cfg: Optional[DictConfig] = None
        self.rllib_config: Optional[Dict[str, Any]] = None
        self.tune_config: Optional[Dict[str, Any]] = None

    def setup(self, cfg: DictConfig):
        """
        This method initializes and registers all necessary maze components with RLlib

        :param cfg: Full Hydra run job config
        """
        # Generate a random state used for sampling random seeds for the envs and agents
        self.maze_seeding = MazeSeeding(cfg.seeding.env_base_seed,
                                        cfg.seeding.agent_base_seed,
                                        cfg.seeding.cudnn_determinism_flag)

        self._cfg = cfg

        # Initialize env factory (with rllib monkey patches)
        self.env_factory = build_maze_rllib_env_factory(cfg)

        # Register maze env factory with rllib
        tune.register_env("maze_env", lambda x: self.env_factory())

        # Register maze model and distribution mapper if a maze model should be used
        # Check whether we are using the rllib default model composer or a maze model
        using_rllib_model_composer = '_target_' not in cfg.model.keys()
        if not using_rllib_model_composer:
            # Get model class
            model_cls = Factory(MazeRLlibBaseModel).type_from_name(
                cfg.algorithm.model_cls)
            # Register maze model
            ModelCatalog.register_custom_model("maze_model", model_cls)

            if 'policy' in cfg.model and "networks" in cfg.model.policy:
                assert len(cfg.model.policy.networks
                           ) == 1, 'Hierarchical envs are not yet supported'

            # register maze action distribution
            ModelCatalog.register_custom_action_dist(
                'maze_dist', MazeRLlibActionDistribution)
            model_config = {
                "custom_action_dist": 'maze_dist',
                "custom_model": "maze_model",
                "vf_share_layers": False,
                "custom_model_config": {
                    "maze_model_composer_config": cfg.model,
                    'spaces_config_dump_file': self.spaces_config_dump_file,
                    'state_dict_dump_file': self.state_dict_dump_file
                }
            }
        else:
            # If specified use the default rllib model builder
            model_config = OmegaConf.to_container(cfg.model, resolve=True)

        # Build rllib config
        maze_rllib_config = {
            "env": "maze_env",
            # Store env config for possible later use
            "env_config": {
                'env': cfg.env,
                'wrappers': cfg.wrappers
            },
            "model": model_config,
            'callbacks': MazeRLlibLoggingCallbacks,
            "framework": "torch"
        }
        # Load the algorithm config and update the custom parameters
        rllib_config: Dict = OmegaConf.to_container(cfg.algorithm.config,
                                                    resolve=True)
        assert 'model' not in rllib_config, 'The config should be removed from the default yaml files since it will ' \
                                            'be dynamically written'
        assert self.num_workers == rllib_config['num_workers']
        rllib_config.update(maze_rllib_config)

        if rllib_config['seed'] is None:
            rllib_config[
                'seed'] = self.maze_seeding.generate_env_instance_seed()

        # Initialize ray with the passed ray_config parameters
        ray_config: Dict = OmegaConf.to_container(self.ray_config,
                                                  resolve=True)

        # Load tune parameters
        tune_config = OmegaConf.to_container(self.tune_params, resolve=True)
        tune_config['callbacks'] = [MazeRLlibSaveModelCallback()]

        # Start tune experiment
        assert 'config' not in tune_config, 'The config should be removed from the default yaml files since it will ' \
                                            'be dynamically written'

        self.ray_config = ray_config
        self.rllib_config = rllib_config
        self.tune_config = tune_config

    @override(Runner)
    def run(self) -> None:
        """
        This method initializes and registers all necessary maze components with RLlib before initializing ray and
        starting a tune.run with the config parameters.
        """

        # Init ray
        ray.init(**self.ray_config)

        # Run tune
        tune.run(self._cfg.algorithm.algorithm,
                 config=self.rllib_config,
                 **self.tune_config)

        # Shutdown ray
        ray.shutdown()