Example 1
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch

# `SlateQTorchModel`, `score_documents`, and `get_per_slate_q_values` are
# defined alongside this function in RLlib's SlateQ implementation (their
# module paths vary across Ray versions).


def action_distribution_fn(
    policy: Policy,
    model: SlateQTorchModel,
    input_dict,
    *,
    explore,
    is_training,
    **kwargs,
):
    """Determine which action to take."""

    observation = input_dict[SampleBatch.OBS]

    # User embedding; shape [B, E] (batch size x embedding dim).
    user_obs = observation["user"]
    # Per-candidate document features.
    doc_obs = list(observation["doc"].values())

    # Compute user-document relevance scores per candidate, plus the
    # score of the "no click" alternative.
    scores, score_no_click = score_documents(user_obs, doc_obs)
    # Compute Q-values per candidate.
    q_values = model.get_q_values(user_obs, doc_obs)

    # Aggregate per-candidate Q-values into one Q-value per possible slate.
    per_slate_q_values = get_per_slate_q_values(
        policy, score_no_click, scores, q_values
    )
    # Cache the enumerated candidate slates on the model for later use
    # (e.g., mapping a sampled slate index back to document indices).
    if not hasattr(model, "slates"):
        model.slates = policy.slates
    return per_slate_q_values, TorchCategorical, []
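
For context, here is a minimal sketch of how such a hook is typically registered when assembling a Torch policy, assuming Ray 1.x's `build_policy_class` template (the import path and accepted keyword arguments differ across Ray versions; `build_slateq_model_and_distribution` is the builder shown in Example 2 below). A real policy would also pass a loss function, default config, optimizer, and so on.

from ray.rllib.policy.policy_template import build_policy_class

SlateQTorchPolicy = build_policy_class(
    name="SlateQTorchPolicy",
    framework="torch",
    # Called at action-sampling time; returns distribution inputs
    # (per-slate Q-values), the distribution class, and RNN state.
    action_distribution_fn=action_distribution_fn,
    # Builds the Q-model/target-model pair (see Example 2 below).
    make_model_and_action_dist=build_slateq_model_and_distribution,
)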
Example 2
from typing import Tuple, Type

import gym

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import (
    TorchCategorical,
    TorchDistributionWrapper,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict

# `SlateQTorchModel` is defined in RLlib's SlateQ implementation (its
# module path varies across Ray versions).


def build_slateq_model_and_distribution(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
    """Build models for SlateQ.

    Args:
        policy: The policy, which will use the model for optimization.
        obs_space: The policy's observation space.
        action_space: The policy's action space.
        config: The Trainer's config dict.

    Returns:
        Tuple consisting of 1) the Q-model and 2) an action distribution class.
    """
    # The action space is MultiDiscrete: one choice per slate position, each
    # over the same set of candidate documents, so nvec[0] is the number of
    # candidates and thus the number of Q-value outputs.
    model = SlateQTorchModel(
        obs_space,
        action_space,
        num_outputs=action_space.nvec[0],
        model_config=config["model"],
        name="slateq_model",
        fcnet_hiddens_per_candidate=config["fcnet_hiddens_per_candidate"],
    )

    # A structurally identical target network, used to compute the
    # bootstrapped Q-targets in the TD loss.
    policy.target_model = SlateQTorchModel(
        obs_space,
        action_space,
        num_outputs=action_space.nvec[0],
        model_config=config["model"],
        name="target_slateq_model",
        fcnet_hiddens_per_candidate=config["fcnet_hiddens_per_candidate"],
    )

    return model, TorchCategorical
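
As a small illustration (with hypothetical sizes) of why `num_outputs` is taken from `action_space.nvec[0]`: a slate action is a MultiDiscrete choice that fills `slate_size` positions, each picking one of `num_candidates` documents, so every component of `nvec` equals the candidate count.

import gym

# Hypothetical sizes: 10 candidate documents, slates of 3 positions.
num_candidates, slate_size = 10, 3
action_space = gym.spaces.MultiDiscrete([num_candidates] * slate_size)

# Each slate position chooses among the same candidates, so nvec[0]
# gives the number of per-candidate Q-value outputs the model needs.
assert action_space.nvec[0] == num_candidates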