Example #1
    @classmethod
    def prepare_for_trainer(
        cls,
        env_spec: EnvironmentSpec,
        config: Dict
    ) -> "AlgorithmSpec":
        # shallow copy so the top-level keys of the caller's config
        # are not touched
        config_ = config.copy()
        agents_config = config_["agents"]

        # build the actor and critic networks through the AGENTS registry,
        # forwarding the environment spec to each
        actor_params = agents_config["actor"]
        actor = AGENTS.get_from_params(
            **actor_params,
            env_spec=env_spec,
        )

        critic_params = agents_config["critic"]
        critic = AGENTS.get_from_params(
            **critic_params,
            env_spec=env_spec,
        )

        # remaining "algorithm" keys become keyword arguments of the class
        algorithm = cls(
            **config_["algorithm"],
            actor=actor,
            critic=critic,
        )

        return algorithm
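This variant resolves both networks through the AGENTS registry and hands each the environment spec. Below is a minimal sketch of the config layout it reads, inferred from the keys accessed above; the agent names, the "agent" key, and SomeAlgorithm are hypothetical stand-ins, not confirmed API:

    # Hypothetical config shape for the prepare_for_trainer variant above;
    # only the keys the method actually reads are shown.
    config = {
        "agents": {
            "actor": {"agent": "ActorNet"},    # kwargs for AGENTS.get_from_params
            "critic": {"agent": "CriticNet"},
        },
        "algorithm": {"gamma": 0.99},          # forwarded as kwargs to cls(...)
    }
    algorithm = SomeAlgorithm.prepare_for_trainer(env_spec=env_spec, config=config)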
Example #2
    @classmethod
    def prepare_for_trainer(
        cls,
        env_spec: EnvironmentSpec,
        config: Dict
    ) -> "AlgorithmSpec":
        # shallow copy so the top-level keys of the caller's config
        # are not touched
        config_ = config.copy()
        agents_config = config_["agents"]

        # build the actor and critic networks through the AGENTS registry
        actor_params = agents_config["actor"]
        actor = AGENTS.get_from_params(
            **actor_params,
            env_spec=env_spec,
        )

        critic_params = agents_config["critic"]
        critic = AGENTS.get_from_params(
            **critic_params,
            env_spec=env_spec,
        )

        # this algorithm needs explicit action bounds, so the action space
        # must be continuous (a Box); only the first component of low/high
        # is used, i.e. uniform bounds across dimensions are assumed
        action_space = env_spec.action_space
        assert isinstance(action_space, Box)
        action_boundaries = [
            action_space.low[0],
            action_space.high[0]
        ]

        algorithm = cls(
            **config_["algorithm"],
            action_boundaries=action_boundaries,
            actor=actor,
            critic=critic,
        )

        return algorithm
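This variant additionally derives action bounds from the environment, so it only works with continuous action spaces. A small standalone sketch of the boundary extraction, assuming gym's Box space; note that only the first component of low/high is read, so the bounds are implicitly assumed equal across action dimensions:

    import numpy as np
    from gym.spaces import Box

    # a 6-dimensional continuous action space with uniform bounds
    action_space = Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32)
    assert isinstance(action_space, Box)
    action_boundaries = [action_space.low[0], action_space.high[0]]  # [-1.0, 1.0]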
Example #3
    @classmethod
    def prepare_for_trainer(cls, config):
        # note: .copy() is shallow, so the pop() below still mutates the
        # caller's nested "algorithm" dict
        config_ = config.copy()

        # shapes come from the "shared" block: the agents consume
        # history_len stacked observations, while the trainer works with
        # single-step observation/action shapes
        actor_state_shape = (
            config_["shared"]["history_len"],
            config_["shared"]["observation_size"],
        )
        actor_action_size = config_["shared"]["action_size"]
        n_step = config_["shared"]["n_step"]
        gamma = config_["shared"]["gamma"]
        history_len = config_["shared"]["history_len"]
        trainer_state_shape = (config_["shared"]["observation_size"], )
        trainer_action_shape = (config_["shared"]["action_size"], )

        actor_params = config_["actor"]
        actor = AGENTS.get_from_params(**actor_params,
                                       state_shape=actor_state_shape,
                                       action_size=actor_action_size)

        critic_params = config_["critic"]
        critic = AGENTS.get_from_params(**critic_params,
                                        state_shape=actor_state_shape,
                                        action_size=actor_action_size)

        # build n_critics - 1 extra critics from the same parameters
        # (an ensemble-critic setup); pop() removes n_critics so it is not
        # forwarded to cls(**config_["algorithm"]) below
        n_critics = config_["algorithm"].pop("n_critics", 2)
        critics = [
            AGENTS.get_from_params(**critic_params,
                                   state_shape=actor_state_shape,
                                   action_size=actor_action_size)
            for _ in range(n_critics - 1)
        ]

        algorithm = cls(**config_["algorithm"],
                        actor=actor,
                        critic=critic,
                        critics=critics,
                        n_step=n_step,
                        gamma=gamma)

        # unlike the newer variants, this one returns trainer kwargs
        # rather than the algorithm instance itself
        kwargs = {
            "algorithm": algorithm,
            "state_shape": trainer_state_shape,
            "action_shape": trainer_action_shape,
            "n_step": n_step,
            "gamma": gamma,
            "history_len": history_len
        }

        return kwargs
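This older variant takes no EnvironmentSpec: shapes come from a "shared" config block, and the method returns trainer kwargs rather than the algorithm itself. The extra critics built from the same parameters suggest an ensemble-critic setup, as in TD3-style algorithms. A hypothetical config for it, with placeholder values, derived only from the keys read above:

    # Hypothetical config for the legacy variant above; values are placeholders.
    config = {
        "shared": {
            "observation_size": 24,  # size of a single observation vector
            "action_size": 4,
            "history_len": 3,        # number of stacked observations fed to the agents
            "n_step": 5,             # n-step return horizon
            "gamma": 0.99,           # discount factor
        },
        "actor": {"agent": "ActorNet"},     # kwargs for AGENTS.get_from_params
        "critic": {"agent": "CriticNet"},
        "algorithm": {"n_critics": 2},      # n_critics is popped before cls(...)
    }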
Example #4
    @classmethod
    def prepare_for_sampler(cls, env_spec: EnvironmentSpec,
                            config: Dict) -> Union[ActorSpec, CriticSpec]:
        # the sampler only runs the policy, so just the actor is built
        config_ = config.copy()
        agents_config = config_["agents"]
        actor_params = agents_config["actor"]
        actor = AGENTS.get_from_params(
            **actor_params,
            env_spec=env_spec,
        )

        return actor
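The sampler-side counterpart of Example #1: only the actor is resolved, since the sampler just needs a policy to act with. A hypothetical call, with SomeAlgorithm and env_spec standing in for the concrete class and environment spec:

    # Only config["agents"]["actor"] is read by this variant.
    config = {"agents": {"actor": {"agent": "ActorNet"}}}
    actor = SomeAlgorithm.prepare_for_sampler(env_spec=env_spec, config=config)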
Example #5
    @classmethod
    def prepare_for_sampler(cls, config):
        # shallow copy so the top-level keys of the caller's config
        # are not touched
        config_ = config.copy()

        # the actor consumes history_len stacked observations
        actor_state_shape = (
            config_["shared"]["history_len"],
            config_["shared"]["observation_size"],
        )
        actor_action_size = config_["shared"]["action_size"]

        actor_params = config_["actor"]
        actor = AGENTS.get_from_params(**actor_params,
                                       state_shape=actor_state_shape,
                                       action_size=actor_action_size)

        history_len = config_["shared"]["history_len"]

        # legacy variant: return sampler kwargs instead of the actor itself
        kwargs = {"actor": actor, "history_len": history_len}

        return kwargs
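The legacy sampler variant mirrors Example #3: the actor is built from the "shared" shapes and returned inside a kwargs dict together with history_len. A sketch of how a caller might consume it; Sampler is a hypothetical stand-in, not a confirmed API:

    # Hypothetical consumption of the returned kwargs.
    kwargs = SomeAlgorithm.prepare_for_sampler(config=config)
    sampler = Sampler(actor=kwargs["actor"],
                      history_len=kwargs["history_len"])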