Example #1
def build_r2d2_model(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> Tuple[ModelV2, ActionDistribution]:
    """Build q_model and target_model for DQN

    Args:
        policy (Policy): The policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Trainer's config dict.

    Returns:
        q_model
            Note: The target q model will not be returned, just assigned to
            `policy.target_model`.
            NOTE(review): the annotated return type is a Tuple, but only the
            model is returned — presumably a stale annotation; confirm
            against callers before changing it.
    """

    # Create the policy's models.
    model = build_q_model(policy, obs_space, action_space, config)

    # Assert correct model type by checking the init state to be present.
    # For attention nets: These don't necessarily publish their init state
    # via Model.get_initial_state (it may be empty), but may only use the
    # trajectory view API (view_requirements) — checking only
    # get_initial_state() would wrongly reject such models.
    assert (model.get_initial_state() != [] or
            model.view_requirements.get("state_in_0") is not None), \
        "R2D2 requires its model to be a recurrent one! Try using " \
        "`model.use_lstm` or `model.use_attention` in your config " \
        "to auto-wrap your model with an LSTM- or attention net."

    return model
Example #2
def build_r2d2_model(policy: Policy, obs_space: gym.spaces.Space,
                     action_space: gym.spaces.Space, config: TrainerConfigDict
                     ) -> Tuple[ModelV2, ActionDistribution]:
    """Construct the Q-model (and its target twin) for R2D2/DQN.

    Args:
        policy (Policy): The policy that will optimize with this model.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Trainer's config dict.

    Returns:
        q_model
            Note: The target q model will not be returned, just assigned to
            `policy.target_model`.
    """

    # Build the policy's model(s) via the shared DQN helper.
    q_model = build_q_model(policy, obs_space, action_space, config)

    # R2D2 only works with recurrent models. A model counts as recurrent if
    # it either exposes a non-empty initial state, or — as attention nets
    # may do — registers its state solely through the trajectory view API
    # (view_requirements) without publishing it via get_initial_state.
    publishes_init_state = q_model.get_initial_state() != []
    has_state_view = q_model.view_requirements.get("state_in_0") is not None
    assert publishes_init_state or has_state_view, \
        "R2D2 requires its model to be a recurrent one! Try using " \
        "`model.use_lstm` or `model.use_attention` in your config " \
        "to auto-wrap your model with an LSTM- or attention net."

    return q_model