def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _,
    train_batch: SampleBatch,
) -> TensorType:
    """Constructs the choice- and Q-value losses for the SlateQTorchPolicy.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        train_batch: The training data.

    Returns:
        The user-choice- and Q-value loss tensors.
    """
    # B=batch size
    # S=slate size
    # C=num candidates
    # E=embedding size
    # A=number of all possible slates

    # Q-value computations.
    # ---------------------
    # actions.shape: [B, S]
    actions = train_batch[SampleBatch.ACTIONS]

    observation = convert_to_torch_tensor(
        train_batch[SampleBatch.OBS], device=actions.device
    )
    # user_obs.shape: [B, E]
    user_obs = observation["user"]
    batch_size, embedding_size = user_obs.shape
    # doc_obs.shape: [B, C, E]
    doc_obs = list(observation["doc"].values())

    A, S = policy.slates.shape

    # click_indicator.shape: [B, S]
    click_indicator = torch.stack(
        [k["click"] for k in observation["response"]], 1
    ).float()
    # item_reward.shape: [B, S]
    item_reward = torch.stack([k["watch_time"] for k in observation["response"]], 1)
    # q_values.shape: [B, C]
    q_values = model.get_q_values(user_obs, doc_obs)
    # slate_q_values.shape: [B, S]
    slate_q_values = torch.take_along_dim(q_values, actions.long(), dim=-1)
    # Only get the Q from the clicked document.
    # replay_click_q.shape: [B]
    replay_click_q = torch.sum(slate_q_values * click_indicator, dim=1)

    # Target computations.
    # --------------------
    next_obs = convert_to_torch_tensor(
        train_batch[SampleBatch.NEXT_OBS], device=actions.device
    )
    # user_next_obs.shape: [B, E]
    user_next_obs = next_obs["user"]
    # doc_next_obs.shape: [B, C, E]
    doc_next_obs = list(next_obs["doc"].values())

    # Only compute the watch time reward of the clicked item.
    reward = torch.sum(item_reward * click_indicator, dim=1)

    # TODO: Find out whether it's correct here to use obs, not next_obs!
    #  Dopamine uses obs, then next_obs only for the score.
    # next_q_values = policy.target_model.get_q_values(user_next_obs, doc_next_obs)
    next_q_values = policy.target_models[model].get_q_values(user_obs, doc_obs)
    scores, score_no_click = score_documents(user_next_obs, doc_next_obs)

    # next_q_values_slate.shape: [B, A, S]
    indices = policy.slates_indices.to(next_q_values.device)
    next_q_values_slate = torch.take_along_dim(next_q_values, indices, dim=1).reshape(
        [-1, A, S]
    )
    # scores_slate.shape: [B, A, S]
    scores_slate = torch.take_along_dim(scores, indices, dim=1).reshape([-1, A, S])
    # score_no_click_slate.shape: [B, A]
    score_no_click_slate = torch.reshape(
        torch.tile(score_no_click, policy.slates.shape[:1]), [batch_size, -1]
    )

    # next_q_target_slate.shape: [B, A]
    next_q_target_slate = torch.sum(next_q_values_slate * scores_slate, dim=2) / (
        torch.sum(scores_slate, dim=2) + score_no_click_slate
    )
    next_q_target_max, _ = torch.max(next_q_target_slate, dim=1)

    target = reward + policy.config["gamma"] * next_q_target_max * (
        1.0 - train_batch["dones"].float()
    )
    target = target.detach()

    clicked = torch.sum(click_indicator, dim=1)
    mask_clicked_slates = clicked > 0
    clicked_indices = torch.arange(batch_size).to(mask_clicked_slates.device)
    clicked_indices = torch.masked_select(clicked_indices, mask_clicked_slates)
    # `clicked_indices` is a vector, so torch.gather selects along the batch dimension.
    q_clicked = torch.gather(replay_click_q, 0, clicked_indices)
    target_clicked = torch.gather(target, 0, clicked_indices)

    td_error = torch.where(
        clicked.bool(),
        replay_click_q - target,
        torch.zeros_like(train_batch[SampleBatch.REWARDS]),
    )
    if policy.config["use_huber"]:
        loss = huber_loss(td_error, delta=policy.config["huber_threshold"])
    else:
        loss = torch.pow(td_error, 2.0)
    loss = torch.mean(loss)
    td_error = torch.abs(td_error)
    mean_td_error = torch.mean(td_error)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_values"] = torch.mean(q_values)
    model.tower_stats["q_clicked"] = torch.mean(q_clicked)
    model.tower_stats["scores"] = torch.mean(scores)
    model.tower_stats["score_no_click"] = torch.mean(score_no_click)
    model.tower_stats["slate_q_values"] = torch.mean(slate_q_values)
    model.tower_stats["replay_click_q"] = torch.mean(replay_click_q)
    model.tower_stats["bellman_reward"] = torch.mean(reward)
    model.tower_stats["next_q_values"] = torch.mean(next_q_values)
    model.tower_stats["target"] = torch.mean(target)
    model.tower_stats["next_q_target_slate"] = torch.mean(next_q_target_slate)
    model.tower_stats["next_q_target_max"] = torch.mean(next_q_target_max)
    model.tower_stats["target_clicked"] = torch.mean(target_clicked)
    model.tower_stats["q_loss"] = loss
    model.tower_stats["td_error"] = td_error
    model.tower_stats["mean_td_error"] = mean_td_error
    model.tower_stats["mean_actions"] = torch.mean(actions.float())

    # User-choice model loss.
    # -----------------------
    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        torch.stack(doc_obs, 1),
        1,
        # index.shape: [batch_size, slate_size, embedding_size]
        actions.unsqueeze(2).expand(-1, -1, embedding_size).long(),
    )

    scores = model.choice_model(user_obs, selected_doc)

    # click_indicator.shape: [batch_size, slate_size]
    # no_clicks.shape: [batch_size, 1]
    no_clicks = 1 - torch.sum(click_indicator, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([click_indicator, no_clicks], dim=1)
    choice_loss = nn.functional.cross_entropy(scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.a.item(), model.choice_model.b.item())

    model.tower_stats["choice_loss"] = choice_loss

    return choice_loss, loss
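
# NOTE: Purely illustrative sketch (never called by the policy above); all
# tensors below are made-up toy values. It shows how the Bellman target above
# combines per-item Q-values with user-choice scores into a single slate value,
#     Q(s, A) = sum_i v(s, i) * Q(s, i) / (sum_i v(s, i) + v(s, no_click)),
# which is computed for every candidate slate before taking the max over slates.
def _slate_value_sketch():
    import torch

    # Per-item Q-values of one slate with S=3 items, shape [B=1, S].
    q_items = torch.tensor([[0.5, 1.0, 0.2]])
    # Unnormalized user-choice scores v(s, i) for the same items.
    scores = torch.tensor([[2.0, 1.0, 1.0]])
    # Score of the "no click" alternative.
    score_no_click = torch.tensor([1.0])

    # (0.5*2.0 + 1.0*1.0 + 0.2*1.0) / (2.0 + 1.0 + 1.0 + 1.0) = 2.2 / 5 = 0.44
    slate_value = (q_items * scores).sum(dim=1) / (
        scores.sum(dim=1) + score_no_click
    )
    return slate_value  # tensor([0.4400])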

def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> TensorType:
    """Constructs the choice- and Q-value losses for the SlateQTorchPolicy.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        train_batch: The training data.

    Returns:
        Tuple consisting of 1) the choice-loss and 2) the Q-value-loss tensors.
    """
    start = time.time()
    obs = restore_original_dimensions(
        train_batch[SampleBatch.OBS], policy.observation_space, tensorlib=torch
    )
    # user.shape: [batch_size, embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size, num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)
    # actions.shape: [batch_size, slate_size]
    actions = train_batch[SampleBatch.ACTIONS]

    next_obs = restore_original_dimensions(
        train_batch[SampleBatch.NEXT_OBS], policy.observation_space, tensorlib=torch
    )

    # Step 1: Build the user-choice model loss.
    _, _, embedding_size = doc.shape
    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        input=doc,
        dim=1,
        # index.shape: [batch_size, slate_size, embedding_size]
        index=actions.unsqueeze(2).expand(-1, -1, embedding_size).long(),
    )
    scores = model.choice_model(user, selected_doc)
    choice_loss_fn = nn.CrossEntropyLoss()

    # clicks.shape: [batch_size, slate_size]
    clicks = torch.stack(
        [resp["click"][:, 1] for resp in next_obs["response"]], dim=1
    )
    no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([clicks, no_clicks], dim=1)
    choice_loss = choice_loss_fn(scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.a.item(), model.choice_model.b.item())

    # Step 2: Build the Q-value loss.
    # Fields available in `train_batch`: ['t', 'eps_id', 'agent_index',
    # 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions',
    # 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights',
    # 'batch_indexes']
    learning_strategy = policy.config["slateq_strategy"]

    # Myopic agent: Doesn't care about the value of the next state;
    # acts based only on the immediate reward.
    if learning_strategy == "MYOP":
        next_q_values = torch.tensor(0.0, requires_grad=False)
    # Q-learning: Default setting for SlateQ -> Use a DQN-style loss function.
    elif learning_strategy == "QL":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1
        )
        next_user = next_obs["user"]
        dones = train_batch[SampleBatch.DONES]
        with torch.no_grad():
            if policy.config["double_q"]:
                next_target_per_slate_q_values = policy.target_models[
                    model
                ].get_per_slate_q_values(next_user, next_doc)
                _, next_q_values, _ = model.choose_slate(
                    next_user, next_doc, next_target_per_slate_q_values
                )
            else:
                _, next_q_values, _ = policy.target_models[model].choose_slate(
                    next_user, next_doc
                )
        next_q_values = next_q_values.detach()
        next_q_values[dones.bool()] = 0.0
    # SARS'A': Use an on-policy SARSA loss.
    elif learning_strategy == "SARSA":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1
        )
        next_actions = train_batch["next_actions"]
        _, _, embedding_size = next_doc.shape
        # next_selected_doc.shape: [batch_size, slate_size, embedding_size]
        next_selected_doc = torch.gather(
            # input.shape: [batch_size, num_docs, embedding_size]
            input=next_doc,
            dim=1,
            # index.shape: [batch_size, slate_size, embedding_size]
            index=next_actions.unsqueeze(2).expand(-1, -1, embedding_size).long(),
        )
        next_user = next_obs["user"]
        dones = train_batch[SampleBatch.DONES]
        with torch.no_grad():
            # q_values.shape: [batch_size, slate_size+1]
            q_values = model.q_model(next_user, next_selected_doc)
            # raw_scores.shape: [batch_size, slate_size+1]
            raw_scores = model.choice_model(next_user, next_selected_doc)
            max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
            scores = torch.exp(raw_scores - max_raw_scores)
            # next_q_values.shape: [batch_size]
            next_q_values = torch.sum(q_values * scores, dim=1) / torch.sum(
                scores, dim=1
            )
            next_q_values[dones.bool()] = 0.0
    else:
        raise ValueError(learning_strategy)

    # target_q_values.shape: [batch_size]
    target_q_values = (
        train_batch[SampleBatch.REWARDS] + policy.config["gamma"] * next_q_values
    )

    # q_values.shape: [batch_size, slate_size+1]
    q_values = model.q_model(user, selected_doc)
    # raw_scores.shape: [batch_size, slate_size+1]
    raw_scores = model.choice_model(user, selected_doc)
    max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
    scores = torch.exp(raw_scores - max_raw_scores)
    # q_values.shape: [batch_size]
    q_values = torch.sum(q_values * scores, dim=1) / torch.sum(scores, dim=1)
    td_error = torch.abs(q_values - target_q_values)
    q_value_loss = torch.mean(huber_loss(td_error))

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_loss"] = q_value_loss
    model.tower_stats["q_values"] = q_values
    model.tower_stats["next_q_values"] = next_q_values
    model.tower_stats["next_q_minus_q"] = next_q_values - q_values
    model.tower_stats["td_error"] = td_error
    model.tower_stats["target_q_values"] = target_q_values
    model.tower_stats["scores"] = scores
    model.tower_stats["raw_scores"] = raw_scores
    model.tower_stats["choice_loss"] = choice_loss
    model.tower_stats["choice_beta"] = model.choice_model.beta
    model.tower_stats["choice_score_no_click"] = model.choice_model.score_no_click

    logger.debug(f"loss calculation took {time.time() - start}s")
    return choice_loss, q_value_loss
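
# NOTE: Purely illustrative sketch (never called by the policy above); the
# tensors are made-up toy values. It shows how the Q-value aggregation above
# turns per-option Q-values into a single value: the raw choice scores are
# shifted by their max for numerical stability and exponentiated, which is
# equivalent to weighting the per-option Q-values by softmax(raw_scores) over
# the slate_size+1 options (the last option being "no click").
def _expected_q_from_choice_scores_sketch():
    import torch

    # Per-option Q-values and raw choice scores, shape [B=1, slate_size+1=4].
    q_values = torch.tensor([[0.5, 1.0, 0.2, 0.0]])
    raw_scores = torch.tensor([[1.0, 2.0, 0.5, 0.0]])

    max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
    scores = torch.exp(raw_scores - max_raw_scores)
    expected_q = torch.sum(q_values * scores, dim=1) / torch.sum(scores, dim=1)

    # Identical (up to floating point) to a softmax-weighted expectation.
    assert torch.allclose(
        expected_q, torch.sum(q_values * torch.softmax(raw_scores, dim=1), dim=1)
    )
    return expected_q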