Пример #1
0
def build_r2d2_model_and_distribution(
    policy: Policy, obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict) -> \
        Tuple[ModelV2, TorchDistributionWrapper]:
    """Build the Q-model and action distribution class for R2D2.

    Delegates the actual construction to `build_q_model_and_distribution`
    and then verifies that the resulting model is a recurrent one.

    Args:
        policy (Policy): The policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The trainer config; passed through
            unchanged to `build_q_model_and_distribution`.

    Returns:
        (q_model, distribution class), exactly as returned by
            `build_q_model_and_distribution`.
            Note: The target q model will not be returned; it is expected to
            be assigned to `policy.target_q_model` by the wrapped builder.

    Raises:
        AssertionError: If the model neither returns a non-empty initial
            state nor declares a "state_in_0" view requirement.
    """

    # Create the policy's models and action dist class.
    model, distribution_cls = build_q_model_and_distribution(
        policy, obs_space, action_space, config)

    # Assert correct model type by checking the init state to be present.
    # Attention nets don't necessarily publish their init state via
    # `get_initial_state`, but may only declare it through the trajectory
    # view API (view_requirements), so accept a "state_in_0" entry as well.
    # `getattr` keeps the check safe for models that do not define
    # `view_requirements` at all (these then fail the assertion as before).
    assert (model.get_initial_state() != [] or
            (getattr(model, "view_requirements", None) or {})
            .get("state_in_0") is not None), \
        "R2D2 requires its model to be a recurrent one! Try using " \
        "`model.use_lstm` or `model.use_attention` in your config " \
        "to auto-wrap your model with an LSTM- or attention net."

    return model, distribution_cls
Пример #2
0
def build_r2d2_model_and_distribution(
        policy: Policy,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
) -> Tuple[ModelV2, TorchDistributionWrapper]:
    """Create the R2D2 Q-model together with its action distribution class.

    Construction is delegated to `build_q_model_and_distribution`; this
    function only adds a sanity check that the built model is recurrent.

    Args:
        policy (Policy): The policy that will optimize with the model.
        obs_space (gym.spaces.Space): Observation space of the policy.
        action_space (gym.spaces.Space): Action space of the policy.
        config (TrainerConfigDict): Trainer config, forwarded as-is to
            `build_q_model_and_distribution`.

    Returns:
        The (q_model, action distribution class) pair produced by
            `build_q_model_and_distribution`.
            Note: The target q model is not part of the return value; it is
            expected to be assigned to `policy.target_model` by the wrapped
            builder.

    Raises:
        AssertionError: If the model has an empty initial state and no
            "state_in_0" entry in its view requirements.
    """
    not_recurrent_msg = (
        "R2D2 requires its model to be a recurrent one! "
        "Try using `model.use_lstm` or `model.use_attention` in your "
        "config to auto-wrap your model with an LSTM- or attention net."
    )

    # Let the generic DQN builder produce the model and the dist class.
    q_model, action_dist_cls = build_q_model_and_distribution(
        policy, obs_space, action_space, config)

    # A recurrent model either reports a non-empty initial state or, as
    # attention nets may do, only declares its state input through the
    # trajectory view API ("state_in_0" in `view_requirements`).
    # The whole check lives inside the assert so that the second lookup is
    # only evaluated when the initial state is empty.
    assert (
        q_model.get_initial_state() != []
        or q_model.view_requirements.get("state_in_0") is not None
    ), not_recurrent_msg

    return q_model, action_dist_cls