Example #1
0
def build_q_models(policy: Policy, obs_space: gym.Space,
                   action_space: gym.Space,
                   config: TrainerConfigDict) -> ModelV2:
    """Create the Q-model and the target Q-model and attach both to `policy`.

    Both models are requested from `ModelCatalog.get_model_v2` with the same
    spaces, `config["model"]`, `config["framework"]` and
    `num_outputs=action_space.n`; they differ only in their name
    (`Q_SCOPE` vs. `Q_TARGET_SCOPE`).

    Args:
        policy (Policy): Object that receives the `q_model`,
            `target_q_model`, `q_func_vars` and `target_q_func_vars`
            attributes.
        obs_space (gym.Space): Observation space handed to the catalog.
        action_space (gym.Space): Action space; must be a
            `gym.spaces.Discrete`.
        config (TrainerConfigDict): Config providing the "model" and
            "framework" entries.

    Returns:
        ModelV2: The model stored in `policy.q_model`. The target model is
            only assigned to `policy.target_q_model`, not returned.

    Raises:
        UnsupportedSpaceException: If `action_space` is not a
            `gym.spaces.Discrete`.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    # Everything except the name is identical for the two networks.
    shared_kwargs = dict(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=config["model"],
        framework=config["framework"],
    )

    policy.q_model = ModelCatalog.get_model_v2(name=Q_SCOPE, **shared_kwargs)
    policy.target_q_model = ModelCatalog.get_model_v2(
        name=Q_TARGET_SCOPE, **shared_kwargs)

    # Keep handles to each network's variables on the policy.
    policy.q_func_vars = policy.q_model.variables()
    policy.target_q_func_vars = policy.target_q_model.variables()

    return policy.q_model
Example #2
0
def _build_q_models(policy: Policy, obs_space: gym.spaces.Space,
                    action_space: gym.spaces.Space,
                    config: TrainerConfigDict) -> ModelV2:
    """Build q_model and target_q_model for Simple Q learning.

    Note that this function works for both Tensorflow and PyTorch: the
    models are only moved to the GPU when the configured framework is
    "torch" and CUDA is available.

    Args:
        policy (Policy): The Policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space. Must be
            a `gym.spaces.Discrete`.
        config (TrainerConfigDict): The config; its "model" and "framework"
            entries are passed on to `ModelCatalog.get_model_v2`.

    Returns:
        ModelV2: The Model for the Policy to use.
            Note: The target q model will not be returned, just assigned to
            `policy.target_q_model`.

    Raises:
        UnsupportedSpaceException: If `action_space` is not a
            `gym.spaces.Discrete`.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    framework = config["framework"]
    # `.to("cuda")` is a torch-module call, so it must not run for other
    # frameworks. The framework test comes first so that `torch` is not
    # touched at all on the non-torch path.
    # NOTE(review): assumes "torch" is the framework string that denotes
    # PyTorch -- confirm against the config schema.
    move_to_cuda = framework == "torch" and torch.cuda.is_available()

    policy.q_model = ModelCatalog.get_model_v2(obs_space=obs_space,
                                               action_space=action_space,
                                               num_outputs=action_space.n,
                                               model_config=config["model"],
                                               framework=framework,
                                               name=Q_SCOPE)
    if move_to_cuda:
        policy.q_model = policy.q_model.to("cuda")

    policy.target_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=config["model"],
        framework=framework,
        name=Q_TARGET_SCOPE)
    if move_to_cuda:
        policy.target_q_model = policy.target_q_model.to("cuda")

    policy.q_func_vars = policy.q_model.variables()
    policy.target_q_func_vars = policy.target_q_model.variables()

    return policy.q_model