示例#1
0
def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _,
    train_batch: SampleBatch,
) -> TensorType:
    """Computes SlateQ's two training losses: user-choice and Q-value.

    The Q-value loss is a TD loss on the Q-value of the clicked document
    (the TD error is forced to zero for rows without any click). The choice
    loss is a cross-entropy between the choice model's output for the
    replayed slate and the observed click / no-click outcome. Intermediate
    tensors are stored in `model.tower_stats` for later stats reporting.

    Args:
        policy: The Policy to calculate the loss for. Reads `slates`,
            `slates_indices`, `target_models` and the config keys "gamma",
            "use_huber" and "huber_threshold".
        model: The Model to calculate the loss for.
        _: Placeholder argument; its value is never read.
        train_batch: The training data.

    Returns:
        A 2-tuple `(choice_loss, q_value_loss)`.
    """
    # Shape legend used in the comments below:
    #   B=batch size, S=slate size, C=num candidates,
    #   E=embedding size, A=number of all possible slates

    # ------------------------------------------------------------------
    # Q-values of the replayed slates.
    # ------------------------------------------------------------------
    # actions.shape: [B, S]
    actions = train_batch[SampleBatch.ACTIONS]
    device = actions.device

    obs = convert_to_torch_tensor(train_batch[SampleBatch.OBS], device=device)
    # user_emb.shape: [B, E]
    user_emb = obs["user"]
    batch_size, embedding_size = user_emb.shape
    # One entry per value of the "doc" dict.
    doc_embs = list(obs["doc"].values())

    num_slates, slate_size = policy.slates.shape

    responses = obs["response"]
    # clicks.shape: [B, S]
    clicks = torch.stack([resp["click"] for resp in responses], 1).float()
    # watch_times.shape: [B, S]
    watch_times = torch.stack([resp["watch_time"] for resp in responses], 1)

    # q_values.shape: [B, C]
    q_values = model.get_q_values(user_emb, doc_embs)
    # slate_q_values.shape: [B, S]
    slate_q_values = torch.take_along_dim(q_values, actions.long(), dim=-1)
    # Keep only the Q-value of the clicked document.
    # replay_click_q.shape: [B]
    replay_click_q = (slate_q_values * clicks).sum(dim=1)

    # ------------------------------------------------------------------
    # Bellman target.
    # ------------------------------------------------------------------
    next_obs = convert_to_torch_tensor(
        train_batch[SampleBatch.NEXT_OBS], device=device
    )
    # next_user_emb.shape: [B, E]
    next_user_emb = next_obs["user"]
    next_doc_embs = list(next_obs["doc"].values())
    # Reward is the watch time of the clicked item only.
    reward = (watch_times * clicks).sum(dim=1)

    # TODO: Find out, whether it's correct here to use obs, not next_obs!
    # Dopamine uses obs, then next_obs only for the score. The target
    # Q-values below are therefore computed from the *current* observation.
    target_net = policy.target_models[model]
    next_q_values = target_net.get_q_values(user_emb, doc_embs)
    scores, score_no_click = score_documents(next_user_emb, next_doc_embs)

    slate_idx = policy.slates_indices.to(next_q_values.device)
    per_slate_shape = [-1, num_slates, slate_size]
    # next_q_values_slate.shape: [B, A, S]
    next_q_values_slate = torch.take_along_dim(
        next_q_values, slate_idx, dim=1
    ).reshape(per_slate_shape)
    # scores_slate.shape: [B, A, S]
    scores_slate = torch.take_along_dim(scores, slate_idx, dim=1).reshape(
        per_slate_shape
    )
    # score_no_click_slate.shape: [B, A]
    score_no_click_slate = torch.tile(
        score_no_click, policy.slates.shape[:1]
    ).reshape([batch_size, -1])

    # Score-weighted Q-value of every possible slate, normalized by the
    # total score mass (including the no-click score).
    # next_q_target_slate.shape: [B, A]
    weighted_q = (next_q_values_slate * scores_slate).sum(dim=2)
    normalizer = scores_slate.sum(dim=2) + score_no_click_slate
    next_q_target_slate = weighted_q / normalizer
    next_q_target_max = torch.max(next_q_target_slate, dim=1).values

    not_done = 1.0 - train_batch["dones"].float()
    target = (
        reward + policy.config["gamma"] * next_q_target_max * not_done
    ).detach()

    # ------------------------------------------------------------------
    # TD error and Q-loss.
    # ------------------------------------------------------------------
    clicked = clicks.sum(dim=1)
    has_click = clicked > 0
    # Batch-row indices of the slates with at least one click (stats only).
    clicked_rows = torch.masked_select(
        torch.arange(batch_size).to(has_click.device), has_click
    )
    q_clicked = replay_click_q.gather(0, clicked_rows)
    target_clicked = target.gather(0, clicked_rows)

    td_error = torch.where(
        clicked.bool(),
        replay_click_q - target,
        torch.zeros_like(train_batch[SampleBatch.REWARDS]),
    )
    per_sample_loss = (
        huber_loss(td_error, delta=policy.config["huber_threshold"])
        if policy.config["use_huber"]
        else torch.pow(td_error, 2.0)
    )
    loss = torch.mean(per_sample_loss)
    td_error = td_error.abs()
    mean_td_error = td_error.mean()

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    averaged_stats = {
        "q_values": q_values,
        "q_clicked": q_clicked,
        "scores": scores,
        "score_no_click": score_no_click,
        "slate_q_values": slate_q_values,
        "replay_click_q": replay_click_q,
        "bellman_reward": reward,
        "next_q_values": next_q_values,
        "target": target,
        "next_q_target_slate": next_q_target_slate,
        "next_q_target_max": next_q_target_max,
        "target_clicked": target_clicked,
    }
    for stat_name, tensor in averaged_stats.items():
        model.tower_stats[stat_name] = torch.mean(tensor)
    model.tower_stats["q_loss"] = loss
    model.tower_stats["td_error"] = td_error
    model.tower_stats["mean_td_error"] = mean_td_error
    model.tower_stats["mean_actions"] = torch.mean(actions.float())

    # ------------------------------------------------------------------
    # User-choice loss.
    # ------------------------------------------------------------------
    # candidates.shape: [batch_size, num_docs, embedding_size]
    candidates = torch.stack(doc_embs, 1)
    # gather_index.shape: [batch_size, slate_size, embedding_size]
    gather_index = actions.unsqueeze(2).expand(-1, -1, embedding_size).long()
    # slate_docs.shape: [batch_size, slate_size, embedding_size]
    slate_docs = torch.gather(candidates, 1, gather_index)

    choice_logits = model.choice_model(user_emb, slate_docs)

    # no_click.shape: [batch_size, 1]
    no_click = 1 - clicks.sum(1, keepdim=True)
    # choice_targets.shape: [batch_size, slate_size+1]
    choice_targets = torch.cat([clicks, no_click], dim=1)
    choice_loss = nn.functional.cross_entropy(
        choice_logits, torch.argmax(choice_targets, dim=1)
    )

    model.tower_stats["choice_loss"] = choice_loss

    return choice_loss, loss
示例#2
0
def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> TensorType:
    """Builds the user-choice loss and the Q-value loss for SlateQ (torch).

    The bootstrap term added to the rewards is selected by
    `policy.config["slateq_strategy"]`:

    - "MYOP": a constant zero (only the immediate reward counts).
    - "QL": the second return value of `choose_slate` evaluated on the next
      observation. With `policy.config["double_q"]` the target model
      provides the per-slate Q-values and `model` does the choosing;
      otherwise the target model does both.
    - "SARSA": the choice-weighted Q-value of `train_batch["next_actions"]`
      on the next observation.

    For "QL" and "SARSA" the bootstrap term is zeroed where `dones` is set.
    Intermediate tensors are stored in `model.tower_stats`.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        _: Placeholder argument; its value is never read.
        train_batch: The training data.

    Returns:
        Tuple consisting of 1) the choice loss- and 2) the Q-value loss tensors.

    Raises:
        ValueError: If "slateq_strategy" is not one of "MYOP", "QL", "SARSA".
    """
    start = time.time()

    def stack_docs(doc_dict):
        # Concatenates the per-document tensors along a new dim 1:
        # -> [batch_size, num_docs, embedding_size]
        return torch.cat([d.unsqueeze(1) for d in doc_dict.values()], 1)

    def pick_slate_docs(all_docs, slate):
        # Gathers the embeddings of the documents listed in `slate`:
        # -> [batch_size, slate_size, embedding_size]
        _batch, _num_docs, emb_size = all_docs.shape
        index = slate.unsqueeze(2).expand(-1, -1, emb_size).long()
        return torch.gather(input=all_docs, dim=1, index=index)

    def choice_weighted_q(user_emb, slate_docs):
        # Per-item Q-values averaged with exp(choice score) weights.
        # per_item_q / logits shape: [batch_size, slate_size+1]
        per_item_q = model.q_model(user_emb, slate_docs)
        logits = model.choice_model(user_emb, slate_docs)
        # Subtract the row-wise max before exponentiating.
        peak, _argmax = torch.max(logits, dim=1, keepdim=True)
        weights = torch.exp(logits - peak)
        # blended.shape: [batch_size]
        blended = torch.sum(per_item_q * weights, dim=1) / torch.sum(
            weights, dim=1
        )
        return blended, weights, logits

    obs = restore_original_dimensions(
        train_batch[SampleBatch.OBS], policy.observation_space, tensorlib=torch
    )
    # user.shape: [batch_size, embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size, num_docs, embedding_size]
    doc = stack_docs(obs["doc"])
    # actions.shape: [batch_size, slate_size]
    actions = train_batch[SampleBatch.ACTIONS]

    next_obs = restore_original_dimensions(
        train_batch[SampleBatch.NEXT_OBS],
        policy.observation_space,
        tensorlib=torch,
    )

    # ------------------------------------------------------------------
    # Step 1: user-choice model loss.
    # ------------------------------------------------------------------
    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = pick_slate_docs(doc, actions)
    choice_logits = model.choice_model(user, selected_doc)

    # clicks.shape: [batch_size, slate_size]
    clicks = torch.stack(
        [resp["click"][:, 1] for resp in next_obs["response"]], dim=1
    )
    no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
    # click_targets.shape: [batch_size, slate_size+1]
    click_targets = torch.cat([clicks, no_clicks], dim=1)
    choice_loss = nn.CrossEntropyLoss()(
        choice_logits, torch.argmax(click_targets, dim=1)
    )

    # ------------------------------------------------------------------
    # Step 2: Q-value loss.
    # ------------------------------------------------------------------
    # Fields available in train_batch: ['t', 'eps_id', 'agent_index',
    # 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions',
    # 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights',
    # 'batch_indexes']
    learning_strategy = policy.config["slateq_strategy"]

    if learning_strategy == "MYOP":
        # Myopic agent: ignores the value of the next state entirely.
        next_q_values = torch.tensor(0.0, requires_grad=False)
    elif learning_strategy == "QL":
        # Q-learning (SlateQ default): DQN-style bootstrap.
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = stack_docs(next_obs["doc"])
        next_user = next_obs["user"]
        dones = train_batch[SampleBatch.DONES]
        target_net_lookup = policy.target_models
        with torch.no_grad():
            if policy.config["double_q"]:
                per_slate_target_q = target_net_lookup[
                    model
                ].get_per_slate_q_values(next_user, next_doc)
                _first, next_q_values, _third = model.choose_slate(
                    next_user, next_doc, per_slate_target_q
                )
            else:
                _first, next_q_values, _third = target_net_lookup[
                    model
                ].choose_slate(next_user, next_doc)
        next_q_values = next_q_values.detach()
        next_q_values[dones.bool()] = 0.0
    elif learning_strategy == "SARSA":
        # SARS'A': on-policy bootstrap from the actually taken next slate.
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = stack_docs(next_obs["doc"])
        next_actions = train_batch["next_actions"]
        # next_selected_doc.shape: [batch_size, slate_size, embedding_size]
        next_selected_doc = pick_slate_docs(next_doc, next_actions)
        next_user = next_obs["user"]
        dones = train_batch[SampleBatch.DONES]
        with torch.no_grad():
            # next_q_values.shape: [batch_size]
            next_q_values, _weights, _logits = choice_weighted_q(
                next_user, next_selected_doc
            )
            next_q_values[dones.bool()] = 0.0
    else:
        raise ValueError(learning_strategy)

    # target_q_values.shape: [batch_size]
    discounted_next = policy.config["gamma"] * next_q_values
    target_q_values = train_batch[SampleBatch.REWARDS] + discounted_next

    # q_values.shape: [batch_size]
    # scores / raw_scores shape: [batch_size, slate_size+1]
    q_values, scores, raw_scores = choice_weighted_q(user, selected_doc)
    td_error = torch.abs(q_values - target_q_values)
    q_value_loss = torch.mean(huber_loss(td_error))

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    tensor_stats = {
        "q_loss": q_value_loss,
        "q_values": q_values,
        "next_q_values": next_q_values,
        "next_q_minus_q": next_q_values - q_values,
        "td_error": td_error,
        "target_q_values": target_q_values,
        "scores": scores,
        "raw_scores": raw_scores,
        "choice_loss": choice_loss,
    }
    for stat_name, stat_value in tensor_stats.items():
        model.tower_stats[stat_name] = stat_value
    model.tower_stats["choice_beta"] = model.choice_model.beta
    model.tower_stats["choice_score_no_click"] = (
        model.choice_model.score_no_click
    )

    logger.debug(f"loss calculation took {time.time()-start}s")
    return choice_loss, q_value_loss