def process_tgt_seq(action):
    # Note: this helper closes over batch_size, candidate_dim, candidates,
    # and device from the enclosing scope.
    if action is not None:
        _, output_size = action.shape
        # Account for decoder starting symbol and padding symbol
        candidates_augment = torch.cat(
            (
                torch.zeros(batch_size, 2, candidate_dim, device=device),
                candidates,
            ),
            dim=1,
        )
        tgt_out_idx = action + 2
        tgt_in_idx = torch.full(
            (batch_size, output_size), DECODER_START_SYMBOL, device=device
        )
        tgt_in_idx[:, 1:] = tgt_out_idx[:, :-1]
        tgt_out_seq = gather(candidates_augment, tgt_out_idx)
        tgt_in_seq = torch.zeros(
            batch_size, output_size, candidate_dim, device=device
        )
        tgt_in_seq[:, 1:] = tgt_out_seq[:, :-1]
        tgt_tgt_mask = subsequent_mask(output_size, device)
    else:
        tgt_in_idx = None
        tgt_out_idx = None
        tgt_in_seq = None
        tgt_out_seq = None
        tgt_tgt_mask = None
    return tgt_in_idx, tgt_out_idx, tgt_in_seq, tgt_out_seq, tgt_tgt_mask
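# A minimal standalone sketch of the index bookkeeping in process_tgt_seq,
# assuming PADDING_SYMBOL = 0 and DECODER_START_SYMBOL = 1 (the two reserved
# symbols referenced throughout this file). Purely illustrative.
import torch

DECODER_START_SYMBOL = 1

action = torch.tensor([[2, 0, 1]])       # a logged permutation of 3 candidates
tgt_out_idx = action + 2                 # shift past the two reserved symbols
tgt_in_idx = torch.full_like(tgt_out_idx, DECODER_START_SYMBOL)
tgt_in_idx[:, 1:] = tgt_out_idx[:, :-1]  # decoder input = output shifted right
# tgt_out_idx: tensor([[4, 2, 3]]); tgt_in_idx: tensor([[1, 4, 2]])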
def rank_on_policy_and_eval(
    seq2slate_net, batch: rlt.PreprocessedRankingInput, tgt_seq_len: int, greedy: bool
):
    model_propensity, model_action = rank_on_policy(
        seq2slate_net, batch, tgt_seq_len, greedy=greedy
    )
    ranked_cities = gather(batch.src_seq.float_features, model_action)
    reward = compute_reward(ranked_cities)
    return model_propensity, model_action, reward
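# compute_reward is not defined in this file. For the TSP-style setup used in
# these tests, a plausible sketch (an assumption, not the repo's exact code)
# is the total Euclidean tour length over consecutive ranked cities; lower is
# better, which is why slate_reward is negated in create_batch below.
import torch

def compute_reward_sketch(ranked_cities: torch.Tensor) -> torch.Tensor:
    # ranked_cities: batch_size x candidate_num x 2 (city coordinates)
    deltas = ranked_cities[:, 1:] - ranked_cities[:, :-1]
    # shape: batch_size x 1
    return torch.linalg.norm(deltas, dim=2).sum(dim=1, keepdim=True)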
def forward(
    self,
    state_with_presence: Tuple[torch.Tensor, torch.Tensor],
    candidate_with_presence: Tuple[torch.Tensor, torch.Tensor],
):
    # state_value.shape == state_presence.shape == batch_size x state_feat_num
    # candidate_value.shape == candidate_presence.shape ==
    #   batch_size x max_src_seq_len x candidate_feat_num
    batch_size = state_with_presence[0].shape[0]
    max_tgt_seq_len = self.model.max_tgt_seq_len
    max_src_seq_len = self.model.max_src_seq_len

    # we use a fake slate_idx_with_presence to retrieve the first
    # max_tgt_seq_len candidates from the candidate set
    # len(slate_idx_with_presence) == batch_size
    # each component: a 1d tensor of length max_tgt_seq_len
    slate_idx_with_presence = [
        (torch.arange(max_tgt_seq_len), torch.ones(max_tgt_seq_len))
    ] * batch_size

    preprocessed_state = self.state_preprocessor(
        state_with_presence[0], state_with_presence[1]
    )
    preprocessed_candidates = self.candidate_preprocessor(
        candidate_with_presence[0].view(
            batch_size * max_src_seq_len, len(self.candidate_sorted_features)
        ),
        candidate_with_presence[1].view(
            batch_size * max_src_seq_len, len(self.candidate_sorted_features)
        ),
    ).view(batch_size, max_src_seq_len, -1)

    src_src_mask = torch.ones(batch_size, max_src_seq_len, max_src_seq_len)
    tgt_out_idx = torch.cat(
        [slate_idx[0] for slate_idx in slate_idx_with_presence]
    ).view(batch_size, max_tgt_seq_len)
    tgt_out_seq = gather(preprocessed_candidates, tgt_out_idx)

    ranking_input = rlt.PreprocessedRankingInput.from_tensors(
        state=preprocessed_state,
        src_seq=preprocessed_candidates,
        src_src_mask=src_src_mask,
        tgt_out_seq=tgt_out_seq,
        # +2 shifts indices past the two reserved symbols:
        # PADDING_SYMBOL = 0
        # DECODER_START_SYMBOL = 1
        tgt_out_idx=tgt_out_idx + 2,
    )
    output = self.model(ranking_input)
    return output.predicted_reward
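# Tiny check of the tgt_out_idx construction above: concatenating batch_size
# copies of arange(max_tgt_seq_len) and reshaping yields a "first
# max_tgt_seq_len candidates" slate for every row. Illustrative values only.
import torch

batch_size, max_tgt_seq_len = 2, 3
slate_idx_with_presence = [
    (torch.arange(max_tgt_seq_len), torch.ones(max_tgt_seq_len))
] * batch_size
tgt_out_idx = torch.cat(
    [slate_idx[0] for slate_idx in slate_idx_with_presence]
).view(batch_size, max_tgt_seq_len)
# tgt_out_idx: tensor([[0, 1, 2], [0, 1, 2]])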
def compute_best_reward(input_cities):
    batch_size, candidate_num, _ = input_cities.shape
    all_perm = torch.tensor(
        list(permutations(torch.arange(candidate_num), candidate_num))
    )
    res = [
        compute_reward(gather(input_cities, perm.repeat(batch_size, 1)))
        for perm in all_perm
    ]
    # res shape: batch_size, num_perm
    res = torch.cat(res, dim=1)
    # reward here is tour length, so the best permutation has the minimum
    best_possible_reward = torch.min(res, dim=1).values
    best_possible_reward_mean = torch.mean(best_possible_reward)
    return best_possible_reward_mean
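# The enumeration above is O(candidate_num!), so it is only feasible for small
# instances. A tiny illustration of the shapes involved (assumed values):
import torch
from itertools import permutations

batch_size, candidate_num = 2, 3
all_perm = torch.tensor(list(permutations(range(candidate_num))))
# all_perm.shape: torch.Size([6, 3]) -- 3! = 6 candidate orderings
# perm.repeat(batch_size, 1) broadcasts one ordering to every batch row:
# all_perm[0].repeat(2, 1) -> tensor([[0, 1, 2], [0, 1, 2]])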
def encoder_output_to_scores(self, state, src_seq, src_src_mask, tgt_out_idx):
    # encoder_output shape: batch_size, src_seq_len, dim_model
    encoder_output = self.encode(state, src_seq, src_src_mask)

    # tgt_out_idx shape: batch_size, tgt_seq_len
    batch_size, tgt_seq_len = tgt_out_idx.shape

    # order encoder_output by tgt_out_idx; -2 undoes the reserved-symbol offset
    # slate_encoder_output shape: batch_size, tgt_seq_len, dim_model
    slate_encoder_output = gather(encoder_output, tgt_out_idx - 2)

    # encoder_scores shape: batch_size, tgt_seq_len
    return self.encoder_scorer(slate_encoder_output).squeeze()
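# Sketch of the `gather` helper's assumed semantics (batched row selection
# along dim 1): out[i, j] = data[i, indices[i, j]]. Purely illustrative.
import torch

def batched_gather_sketch(data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # data: batch_size x n x d, indices: batch_size x k -> batch_size x k x d
    return data[torch.arange(data.shape[0]).unsqueeze(1), indices]

encoder_output = torch.arange(12.0).view(1, 4, 3)
tgt_out_idx = torch.tensor([[3, 2]])  # indices offset by the two reserved symbols
slate_encoder_output = batched_gather_sketch(encoder_output, tgt_out_idx - 2)
# slate_encoder_output.shape: torch.Size([1, 2, 3])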
def create_batch(
    batch_size,
    candidate_num,
    candidate_dim,
    device,
    learning_method,
    diverse_input=False,
):
    # fake state; we only use candidates
    state = torch.zeros(batch_size, 1)
    if diverse_input:
        # city coordinates are spread in [0, 4]
        candidates = torch.randint(
            5, (batch_size, candidate_num, candidate_dim)
        ).float()
    else:
        # every training example uses the same fixed set of input cities
        global FIX_CANDIDATES
        if FIX_CANDIDATES is None or FIX_CANDIDATES.shape != (
            batch_size,
            candidate_num,
            candidate_dim,
        ):
            candidates = torch.randint(
                5, (batch_size, candidate_num, candidate_dim)
            ).float()
            candidates[1:] = candidates[0]
            FIX_CANDIDATES = candidates
        else:
            candidates = FIX_CANDIDATES

    batch_dict = {
        "state": state,
        "candidates": candidates,
        "device": device,
    }
    if learning_method == OFF_POLICY:
        # log data from a uniform sampling policy
        action = torch.stack(
            [torch.randperm(candidate_num) for _ in range(batch_size)]
        )
        propensity = torch.full(
            (batch_size, 1), 1.0 / math.factorial(candidate_num)
        )
        ranked_cities = gather(candidates, action)
        reward = compute_reward(ranked_cities)
        batch_dict["action"] = action
        batch_dict["logged_propensities"] = propensity
        # negate: reward is tour length, and shorter tours are better
        batch_dict["slate_reward"] = -reward

    batch = rlt.PreprocessedRankingInput.from_input(**batch_dict)
    logger.info("Generate one batch")
    return batch
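# Quick sanity check on the logging policy above: each row of `action` is an
# independent uniform random permutation, so every slate has propensity
# 1 / candidate_num!. Illustrative values only.
import math
import torch

candidate_num, batch_size = 4, 3
action = torch.stack([torch.randperm(candidate_num) for _ in range(batch_size)])
# action.shape: torch.Size([3, 4]); each row is a permutation of 0..3
uniform_propensity = 1.0 / math.factorial(candidate_num)  # 0.041666...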
def _convert_seq2slate_to_reward_model_format(
    self, input: rlt.PreprocessedRankingInput
):
    device = next(self.parameters()).device
    # pyre-fixme[16]: Optional type has no attribute `float_features`.
    batch_size, tgt_seq_len, candidate_dim = input.tgt_out_seq.float_features.shape
    src_seq_len = input.src_seq.float_features.shape[1]
    assert self.max_tgt_seq_len == tgt_seq_len
    assert self.max_src_seq_len == src_seq_len

    # unselected_idx stores indices of items that are not included in the slate
    unselected_idx = torch.ones(batch_size, src_seq_len, device=device)
    unselected_idx[
        # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`.
        torch.arange(batch_size, device=device).repeat_interleave(
            torch.tensor(tgt_seq_len, device=device)
        ),
        # pyre-fixme[16]: Optional type has no attribute `flatten`.
        input.tgt_out_idx.flatten() - 2,
    ] = 0
    # shape: batch_size, (src_seq_len - tgt_seq_len)
    unselected_idx = torch.nonzero(unselected_idx, as_tuple=True)[1].reshape(
        batch_size, src_seq_len - tgt_seq_len
    )
    # shape: batch_size, (src_seq_len - tgt_seq_len), candidate_dim
    unselected_candidate_features = gather(
        input.src_seq.float_features, unselected_idx
    )
    # shape: batch_size, src_seq_len + 1, candidate_dim
    tgt_in_seq = torch.cat(
        (
            input.tgt_out_seq.float_features,
            unselected_candidate_features,
            self.end_of_seq_vec.repeat(batch_size, 1, 1),
        ),
        dim=1,
    )

    return rlt.PreprocessedRankingInput.from_tensors(
        state=input.state.float_features,
        src_seq=input.src_seq.float_features,
        src_src_mask=input.src_src_mask,
        tgt_in_seq=tgt_in_seq,
    )
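# Standalone sketch of the unselected_idx mask trick above: zero out selected
# slate positions in a ones-mask, then nonzero() reads off the remaining
# (unselected) candidate indices in row-major order. Values are made up.
import torch

batch_size, src_seq_len, tgt_seq_len = 2, 5, 3
tgt_out_idx = torch.tensor([[2, 4, 3], [6, 2, 5]])  # already offset by +2
mask = torch.ones(batch_size, src_seq_len)
mask[
    torch.arange(batch_size).repeat_interleave(tgt_seq_len),
    tgt_out_idx.flatten() - 2,
] = 0
unselected_idx = torch.nonzero(mask, as_tuple=True)[1].reshape(
    batch_size, src_seq_len - tgt_seq_len
)
# unselected_idx: tensor([[3, 4], [1, 2]])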
def encoder_output_to_scores(
    self, state: torch.Tensor, src_seq: torch.Tensor, tgt_out_idx: torch.Tensor
) -> Seq2SlateTransformerOutput:
    # encoder_output shape: batch_size, src_seq_len, dim_model
    encoder_output = self.encode(state, src_seq)

    # tgt_out_idx shape: batch_size, tgt_seq_len
    batch_size, tgt_seq_len = tgt_out_idx.shape

    # order encoder_output by tgt_out_idx; -2 undoes the reserved-symbol offset
    # slate_encoder_output shape: batch_size, tgt_seq_len, dim_model
    slate_encoder_output = gather(encoder_output, tgt_out_idx - 2)

    # encoder_scores shape: batch_size, tgt_seq_len
    encoder_scores = self.encoder_scorer(slate_encoder_output).squeeze()
    return Seq2SlateTransformerOutput(
        ranked_per_symbol_probs=None,
        ranked_per_seq_probs=None,
        ranked_tgt_out_idx=None,
        per_symbol_log_probs=None,
        per_seq_log_probs=None,
        encoder_scores=encoder_scores,
    )
def _simulated_training_input(self, training_input: rlt.PreprocessedRankingInput):
    # precision error may cause invalid actions, so resample until valid
    valid_output = False
    while not valid_output:
        rank_output = self.seq2slate_net(
            training_input,
            mode=Seq2SlateMode.RANK_MODE,
            tgt_seq_len=self.seq2slate_net.max_tgt_seq_len,
            greedy=False,
        )
        model_propensities = rank_output.ranked_per_seq_probs
        model_actions_with_offset = rank_output.ranked_tgt_out_idx
        model_actions = model_actions_with_offset - 2
        if torch.all(model_actions >= 0):
            valid_output = True

    batch_size = model_actions_with_offset.shape[0]
    simulated_slate_features = gather(
        training_input.src_seq.float_features, model_actions
    )

    if not self.reward_name_and_net:
        self.reward_name_and_net = _load_reward_net(
            self.sim_param.reward_name_path, self.use_gpu
        )

    sim_slate_reward = torch.zeros(batch_size, 1, device=self.device)
    for name, reward_net in self.reward_name_and_net.items():
        weight = self.sim_param.reward_name_weight[name]
        power = self.sim_param.reward_name_power[name]
        sr = reward_net(
            training_input.state.float_features,
            training_input.src_seq.float_features,
            simulated_slate_features,
            training_input.src_src_mask,
            model_actions_with_offset,
        ).detach()
        assert sr.ndim == 2, f"Slate reward {name} output should be 2-D tensor"
        sim_slate_reward += weight * (sr ** power)

    # guard-rail reward prediction range
    reward_clamp = self.sim_param.reward_clamp
    if reward_clamp is not None:
        sim_slate_reward = torch.clamp(
            sim_slate_reward,
            min=reward_clamp.clamp_min,
            max=reward_clamp.clamp_max,
        )
    # guard-rail sequence similarity
    distance_penalty = self.sim_param.distance_penalty
    if distance_penalty is not None:
        sim_distance = (
            torch.tensor(
                # pyre-fixme[16]: `int` has no attribute `__iter__`.
                [swap_dist(x.tolist()) for x in model_actions],
                device=self.device,
            )
            .unsqueeze(1)
            .float()
        )
        sim_slate_reward += distance_penalty * (self.MAX_DISTANCE - sim_distance)

    assert (
        len(sim_slate_reward.shape) == 2 and sim_slate_reward.shape[1] == 1
    ), f"{sim_slate_reward.shape}"

    on_policy_input = rlt.PreprocessedRankingInput.from_input(
        state=training_input.state.float_features,
        candidates=training_input.src_seq.float_features,
        device=self.device,
        # pyre-fixme[6]: Expected `Optional[torch.Tensor]` for 4th param but got
        #  `int`.
        action=model_actions,
        slate_reward=sim_slate_reward,
        logged_propensities=model_propensities,
    )
    return on_policy_input
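# Standalone sketch of the reward aggregation above: a weighted sum of each
# reward net's detached output raised to a per-net power. Net names, weights,
# and values here are invented for illustration.
import torch

sr_by_name = {
    "net_a": torch.tensor([[0.8], [0.5]]),
    "net_b": torch.tensor([[2.0], [1.0]]),
}
weight = {"net_a": 1.0, "net_b": 0.5}
power = {"net_a": 1.0, "net_b": 2.0}

sim_slate_reward = torch.zeros(2, 1)
for name, sr in sr_by_name.items():
    sim_slate_reward += weight[name] * (sr ** power[name])
# sim_slate_reward: tensor([[2.8000], [1.0000]])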
def _rank(self, state, src_seq, src_src_mask, tgt_seq_len, greedy):
    """Decode sequences based on given inputs"""
    device = src_seq.device
    batch_size, src_seq_len, candidate_dim = src_seq.shape
    candidate_size = src_seq_len + 2

    # candidate_features is used as a look-up table for candidate features.
    # The second dim is src_seq_len + 2 because we also want to include
    # features of the start symbol and padding symbol.
    candidate_features = torch.zeros(
        batch_size, src_seq_len + 2, candidate_dim, device=device
    )
    # TODO: T62502977 create learnable feature vectors for start symbol
    # and padding symbol
    candidate_features[:, 2:, :] = src_seq

    # memory shape: batch_size, src_seq_len, dim_model
    memory = self.encode(state, src_seq, src_src_mask)

    ranked_per_symbol_probs = torch.zeros(
        batch_size, tgt_seq_len, candidate_size, device=device
    )
    ranked_per_seq_probs = torch.zeros(batch_size, 1)

    if self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
        # encoder_scores shape: batch_size, src_seq_len
        encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
        tgt_out_idx = torch.argsort(encoder_scores, dim=1, descending=True)[
            :, :tgt_seq_len
        ]
        # +2 to account for start symbol and padding symbol
        tgt_out_idx += 2
        # every position has propensity of 1 because we are just using argsort
        ranked_per_symbol_probs = ranked_per_symbol_probs.scatter(
            2, tgt_out_idx.unsqueeze(2), 1.0
        )
        ranked_per_seq_probs[:, :] = 1.0
        return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx

    tgt_in_idx = (
        torch.ones(batch_size, 1, device=device)
        .fill_(self._DECODER_START_SYMBOL)
        .type(torch.long)
    )
    assert greedy is not None
    for l in range(tgt_seq_len):
        tgt_in_seq = gather(candidate_features, tgt_in_idx)
        tgt_tgt_mask, tgt_src_mask = pytorch_decoder_mask(
            memory, tgt_in_idx, self.num_heads
        )
        # probs shape: batch_size, l + 1, candidate_size
        probs = self.decode(
            memory=memory,
            state=state,
            tgt_src_mask=tgt_src_mask,
            tgt_in_idx=tgt_in_idx,
            tgt_in_seq=tgt_in_seq,
            tgt_tgt_mask=tgt_tgt_mask,
        )
        # next_candidate shape: batch_size, 1
        # next_candidate_sample_prob shape: batch_size, candidate_size
        next_candidate, next_candidate_sample_prob = self.generator(probs, greedy)
        ranked_per_symbol_probs[:, l, :] = next_candidate_sample_prob
        tgt_in_idx = torch.cat([tgt_in_idx, next_candidate], dim=1)

    # remove the decoder start symbol
    # tgt_out_idx shape: batch_size, tgt_seq_len
    tgt_out_idx = tgt_in_idx[:, 1:]

    ranked_per_seq_probs = per_symbol_to_per_seq_probs(
        ranked_per_symbol_probs, tgt_out_idx
    )

    # ranked_per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
    # ranked_per_seq_probs shape: batch_size, 1
    # tgt_out_idx shape: batch_size, tgt_seq_len
    return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx
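# Sketch of per_symbol_to_per_seq_probs' assumed semantics: the probability of
# a ranked sequence is the product, over decoding steps, of the probability
# assigned to the symbol actually chosen at that step. Illustrative only.
import torch

def per_symbol_to_per_seq_probs_sketch(per_symbol_probs, tgt_out_idx):
    # per_symbol_probs: batch_size x tgt_seq_len x candidate_size
    # tgt_out_idx: batch_size x tgt_seq_len
    chosen_probs = torch.gather(
        per_symbol_probs, 2, tgt_out_idx.unsqueeze(2)
    ).squeeze(2)
    # shape: batch_size x 1
    return chosen_probs.prod(dim=1, keepdim=True)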