Example #1
        def process_tgt_seq(action):
            if action is not None:
                _, output_size = action.shape
                # Account for decoder starting symbol and padding symbol
                candidates_augment = torch.cat(
                    (
                        torch.zeros(
                            batch_size, 2, candidate_dim, device=device),
                        candidates,
                    ),
                    dim=1,
                )
                tgt_out_idx = action + 2
                tgt_in_idx = torch.full((batch_size, output_size),
                                        DECODER_START_SYMBOL,
                                        dtype=torch.long,
                                        device=device)
                tgt_in_idx[:, 1:] = tgt_out_idx[:, :-1]
                tgt_out_seq = gather(candidates_augment, tgt_out_idx)
                tgt_in_seq = torch.zeros(batch_size,
                                         output_size,
                                         candidate_dim,
                                         device=device)
                tgt_in_seq[:, 1:] = tgt_out_seq[:, :-1]
                tgt_tgt_mask = subsequent_mask(output_size, device)
            else:
                tgt_in_idx = None
                tgt_out_idx = None
                tgt_in_seq = None
                tgt_out_seq = None
                tgt_tgt_mask = None

            return tgt_in_idx, tgt_out_idx, tgt_in_seq, tgt_out_seq, tgt_tgt_mask
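# A minimal, self-contained sketch of the index bookkeeping above (an
# illustration, not part of the original code): logged actions are offset by 2
# to skip the reserved padding (0) and decoder-start (1) symbols, and the
# decoder input sequence is the output sequence shifted right by one step,
# starting with DECODER_START_SYMBOL.
import torch

DECODER_START_SYMBOL = 1
action = torch.tensor([[2, 0, 1]])       # logged ranking of 3 candidates
tgt_out_idx = action + 2                 # tensor([[4, 2, 3]])
tgt_in_idx = torch.full(action.shape, DECODER_START_SYMBOL, dtype=torch.long)
tgt_in_idx[:, 1:] = tgt_out_idx[:, :-1]  # tensor([[1, 4, 2]])
print(tgt_out_idx, tgt_in_idx)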
Example #2
def rank_on_policy_and_eval(seq2slate_net, batch: rlt.PreprocessedRankingInput,
                            tgt_seq_len: int, greedy: bool):
    model_propensity, model_action = rank_on_policy(seq2slate_net,
                                                    batch,
                                                    tgt_seq_len,
                                                    greedy=greedy)
    ranked_cities = gather(batch.src_seq.float_features, model_action)
    reward = compute_reward(ranked_cities)
    return model_propensity, model_action, reward
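# `compute_reward` is not shown in this excerpt; a plausible minimal sketch for
# this TSP-style setup (an assumption, not the original implementation) is the
# total tour length of the ranked city coordinates, returned as one value per
# batch row so that smaller rewards are better.
import torch

def compute_reward_sketch(ranked_cities: torch.Tensor) -> torch.Tensor:
    # ranked_cities shape: batch_size, seq_len, 2 (x/y coordinates in visit order)
    wraparound = torch.cat((ranked_cities, ranked_cities[:, :1]), dim=1)
    segment_lengths = (wraparound[:, 1:] - wraparound[:, :-1]).norm(dim=2)
    return segment_lengths.sum(dim=1, keepdim=True)  # shape: batch_size, 1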
Example #3
    def forward(
        self,
        state_with_presence: Tuple[torch.Tensor, torch.Tensor],
        candidate_with_presence: Tuple[torch.Tensor, torch.Tensor],
    ):
        # state_value.shape == state_presence.shape == batch_size x state_feat_num
        # candidate_value.shape == candidate_presence.shape ==
        # batch_size x max_src_seq_len x candidate_feat_num
        batch_size = state_with_presence[0].shape[0]
        max_tgt_seq_len = self.model.max_tgt_seq_len
        max_src_seq_len = self.model.max_src_seq_len

        # we use a fake slate_idx_with_presence to retrieve the first
        # max_tgt_seq_len candidates from the candidate set
        # len(slate_idx_with_presence) == batch_size
        # each component: 1d tensor of length max_tgt_seq_len
        slate_idx_with_presence = [
            (torch.arange(max_tgt_seq_len), torch.ones(max_tgt_seq_len))
        ] * batch_size

        preprocessed_state = self.state_preprocessor(
            state_with_presence[0], state_with_presence[1]
        )

        preprocessed_candidates = self.candidate_preprocessor(
            candidate_with_presence[0].view(
                batch_size * max_src_seq_len, len(self.candidate_sorted_features)
            ),
            candidate_with_presence[1].view(
                batch_size * max_src_seq_len, len(self.candidate_sorted_features)
            ),
        ).view(batch_size, max_src_seq_len, -1)

        src_src_mask = torch.ones(batch_size, max_src_seq_len, max_src_seq_len)

        tgt_out_idx = torch.cat(
            [slate_idx[0] for slate_idx in slate_idx_with_presence]
        ).view(batch_size, max_tgt_seq_len)

        tgt_out_seq = gather(preprocessed_candidates, tgt_out_idx)

        ranking_input = rlt.PreprocessedRankingInput.from_tensors(
            state=preprocessed_state,
            src_seq=preprocessed_candidates,
            src_src_mask=src_src_mask,
            tgt_out_seq=tgt_out_seq,
            # +2 is needed to avoid the two reserved symbols:
            # PADDING_SYMBOL = 0
            # DECODER_START_SYMBOL = 1
            tgt_out_idx=tgt_out_idx + 2,
        )

        output = self.model(ranking_input)
        return output.predicted_reward
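# A minimal sketch of the batched `gather` helper assumed throughout these
# examples (the real helper may differ in edge-case handling): for every batch
# row, select the rows of a batch_size x seq_len x dim tensor listed in a
# batch_size x num_idx index tensor, via torch.gather along dim 1.
import torch

def batched_gather_sketch(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    batch_size, num_idx = indices.shape
    dim = values.shape[2]
    return torch.gather(values, 1, indices.unsqueeze(2).expand(batch_size, num_idx, dim))

values = torch.arange(24, dtype=torch.float32).view(2, 4, 3)
indices = torch.tensor([[3, 0], [1, 1]])
print(batched_gather_sketch(values, indices).shape)  # torch.Size([2, 2, 3])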
Example #4
def compute_best_reward(input_cities):
    batch_size, candidate_num, _ = input_cities.shape
    all_perm = torch.tensor(
        list(permutations(torch.arange(candidate_num), candidate_num)))
    res = [
        compute_reward(gather(input_cities, perm.repeat(batch_size, 1)))
        for perm in all_perm
    ]
    # res shape: batch_size, num_perm
    res = torch.cat(res, dim=1)
    best_possible_reward = torch.min(res, dim=1).values
    best_possible_reward_mean = torch.mean(best_possible_reward)
    return best_possible_reward_mean
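# Usage sketch for compute_best_reward (assumes `compute_reward` and `gather`
# from this codebase are in scope): brute-force enumeration is only feasible
# for small candidate_num, since the number of permutations grows as
# candidate_num! (24 permutations for 4 cities, 40320 for 8).
import torch

input_cities = torch.randint(5, (16, 4, 2)).float()  # 16 examples, 4 cities in 2-D
best_mean = compute_best_reward(input_cities)        # mean of the per-example minimum tour length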
Example #5
    def encoder_output_to_scores(self, state, src_seq, src_src_mask,
                                 tgt_out_idx):
        # encoder_output shape: batch_size, src_seq_len, dim_model
        encoder_output = self.encode(state, src_seq, src_src_mask)

        # encoder_output shape: batch_size, src_seq_len, dim_model
        # tgt_out_idx shape: batch_size, tgt_seq_len
        batch_size, tgt_seq_len = tgt_out_idx.shape

        # order encoder_output by tgt_out_idx
        # slate_encoder_output shape: batch_size, tgt_seq_len, dim_model
        slate_encoder_output = gather(encoder_output, tgt_out_idx - 2)
        # encoder_scores shape: batch_size, tgt_seq_len
        return self.encoder_scorer(slate_encoder_output).squeeze()
Example #6
def create_batch(
    batch_size,
    candidate_num,
    candidate_dim,
    device,
    learning_method,
    diverse_input=False,
):
    # fake state, we only use candidates
    state = torch.zeros(batch_size, 1)
    if diverse_input:
        # city coordinates are spread in [0, 4]
        candidates = torch.randint(
            5, (batch_size, candidate_num, candidate_dim)).float()
    else:
        # every training example uses the same fixed set of input cities
        global FIX_CANDIDATES
        if FIX_CANDIDATES is None or FIX_CANDIDATES.shape != (
                batch_size,
                candidate_num,
                candidate_dim,
        ):
            candidates = torch.randint(
                5, (batch_size, candidate_num, candidate_dim)).float()
            candidates[1:] = candidates[0]
            FIX_CANDIDATES = candidates
        else:
            candidates = FIX_CANDIDATES

    batch_dict = {
        "state": state,
        "candidates": candidates,
        "device": device,
    }
    if learning_method == OFF_POLICY:
        # using data from a uniform sampling policy
        action = torch.stack(
            [torch.randperm(candidate_num) for _ in range(batch_size)])
        propensity = torch.full((batch_size, 1),
                                1.0 / math.factorial(candidate_num))
        ranked_cities = gather(candidates, action)
        reward = compute_reward(ranked_cities)
        batch_dict["action"] = action
        batch_dict["logged_propensities"] = propensity
        batch_dict["slate_reward"] = -reward

    batch = rlt.PreprocessedRankingInput.from_input(**batch_dict)
    logger.info("Generate one batch")
    return batch
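# Usage sketch for create_batch (assumes the OFF_POLICY constant and the `rlt`
# types module from this codebase are importable): one off-policy batch of 32
# examples, each with 6 candidate cities in 2-D coordinates, actions sampled
# from the uniform-permutation logging policy with propensity 1/6!.
import torch

batch = create_batch(
    batch_size=32,
    candidate_num=6,
    candidate_dim=2,
    device=torch.device("cpu"),
    learning_method=OFF_POLICY,
    diverse_input=True,
)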
Example #7
    def _convert_seq2slate_to_reward_model_format(
        self, input: rlt.PreprocessedRankingInput
    ):
        device = next(self.parameters()).device
        # pyre-fixme[16]: Optional type has no attribute `float_features`.
        batch_size, tgt_seq_len, candidate_dim = input.tgt_out_seq.float_features.shape
        src_seq_len = input.src_seq.float_features.shape[1]
        assert self.max_tgt_seq_len == tgt_seq_len
        assert self.max_src_seq_len == src_seq_len

        # unselected_idx stores indices of items that are not included in the slate
        unselected_idx = torch.ones(batch_size, src_seq_len, device=device)
        unselected_idx[
            # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`.
            torch.arange(batch_size, device=device).repeat_interleave(
                torch.tensor(tgt_seq_len, device=device)
            ),
            # pyre-fixme[16]: Optional type has no attribute `flatten`.
            input.tgt_out_idx.flatten() - 2,
        ] = 0
        # shape: batch_size, (src_seq_len - tgt_seq_len)
        unselected_idx = torch.nonzero(unselected_idx, as_tuple=True)[1].reshape(
            batch_size, src_seq_len - tgt_seq_len
        )
        # shape: batch_size, (src_seq_len - tgt_seq_len), candidate_dim
        unselected_candidate_features = gather(
            input.src_seq.float_features, unselected_idx
        )
        # shape: batch_size, src_seq_len + 1, candidate_dim
        tgt_in_seq = torch.cat(
            (
                input.tgt_out_seq.float_features,
                unselected_candidate_features,
                self.end_of_seq_vec.repeat(batch_size, 1, 1),
            ),
            dim=1,
        )

        return rlt.PreprocessedRankingInput.from_tensors(
            state=input.state.float_features,
            src_seq=input.src_seq.float_features,
            src_src_mask=input.src_src_mask,
            tgt_in_seq=tgt_in_seq,
        )
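# A minimal, self-contained sketch of the "unselected items" trick above, with
# toy sizes (batch_size=2, src_seq_len=4, tgt_seq_len=2) and no symbol offset:
# start from a mask of ones, zero out the selected positions row by row, then
# read back the column indices of the remaining ones as a fixed-size matrix.
import torch

batch_size, src_seq_len, tgt_seq_len = 2, 4, 2
selected_idx = torch.tensor([[1, 3], [0, 2]])  # hypothetical slate picks per row
mask = torch.ones(batch_size, src_seq_len)
mask[
    torch.arange(batch_size).repeat_interleave(tgt_seq_len),
    selected_idx.flatten(),
] = 0
unselected_idx = torch.nonzero(mask, as_tuple=True)[1].reshape(
    batch_size, src_seq_len - tgt_seq_len
)
print(unselected_idx)  # tensor([[0, 2], [1, 3]])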
Example #8
    def encoder_output_to_scores(
            self, state: torch.Tensor, src_seq: torch.Tensor,
            tgt_out_idx: torch.Tensor) -> Seq2SlateTransformerOutput:
        # encoder_output shape: batch_size, src_seq_len, dim_model
        encoder_output = self.encode(state, src_seq)

        # encoder_output shape: batch_size, src_seq_len, dim_model
        # tgt_out_idx shape: batch_size, tgt_seq_len
        batch_size, tgt_seq_len = tgt_out_idx.shape

        # order encoder_output by tgt_out_idx
        # slate_encoder_output shape: batch_size, tgt_seq_len, dim_model
        slate_encoder_output = gather(encoder_output, tgt_out_idx - 2)
        # encoder_scores shape: batch_size, tgt_seq_len
        encoder_scores = self.encoder_scorer(slate_encoder_output).squeeze()
        return Seq2SlateTransformerOutput(
            ranked_per_symbol_probs=None,
            ranked_per_seq_probs=None,
            ranked_tgt_out_idx=None,
            per_symbol_log_probs=None,
            per_seq_log_probs=None,
            encoder_scores=encoder_scores,
        )
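# A minimal sketch of the encoder-scorer step, assuming `encoder_scorer` is a
# Linear(dim_model, 1) head (an assumption; the real module may differ): each
# gathered slate position is scored independently, and squeezing the trailing
# singleton dim leaves a batch_size x tgt_seq_len score matrix.
import torch

batch_size, tgt_seq_len, dim_model = 3, 4, 8
slate_encoder_output = torch.randn(batch_size, tgt_seq_len, dim_model)
encoder_scorer = torch.nn.Linear(dim_model, 1)
scores = encoder_scorer(slate_encoder_output).squeeze(-1)
print(scores.shape)  # torch.Size([3, 4])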
Example #9
    def _simulated_training_input(
            self, training_input: rlt.PreprocessedRankingInput):
        # precision error may cause invalid actions
        valid_output = False
        while not valid_output:
            rank_output = self.seq2slate_net(
                training_input,
                mode=Seq2SlateMode.RANK_MODE,
                tgt_seq_len=self.seq2slate_net.max_tgt_seq_len,
                greedy=False,
            )
            model_propensities = rank_output.ranked_per_seq_probs
            model_actions_with_offset = rank_output.ranked_tgt_out_idx
            model_actions = model_actions_with_offset - 2
            if torch.all(model_actions >= 0):
                valid_output = True

        batch_size = model_actions_with_offset.shape[0]
        simulated_slate_features = gather(
            training_input.src_seq.float_features, model_actions)

        if not self.reward_name_and_net:
            self.reward_name_and_net = _load_reward_net(
                self.sim_param.reward_name_path, self.use_gpu)

        sim_slate_reward = torch.zeros(batch_size, 1, device=self.device)
        for name, reward_net in self.reward_name_and_net.items():
            weight = self.sim_param.reward_name_weight[name]
            power = self.sim_param.reward_name_power[name]
            sr = reward_net(
                training_input.state.float_features,
                training_input.src_seq.float_features,
                simulated_slate_features,
                training_input.src_src_mask,
                model_actions_with_offset,
            ).detach()
            assert sr.ndim == 2, f"Slate reward {name} output should be 2-D tensor"
            sim_slate_reward += weight * (sr**power)

        # guard-rail reward prediction range
        reward_clamp = self.sim_param.reward_clamp
        if reward_clamp is not None:
            sim_slate_reward = torch.clamp(sim_slate_reward,
                                           min=reward_clamp.clamp_min,
                                           max=reward_clamp.clamp_max)
        # guard-rail sequence similarity
        distance_penalty = self.sim_param.distance_penalty
        if distance_penalty is not None:
            sim_distance = (
                torch.tensor(
                    # pyre-fixme[16]: `int` has no attribute `__iter__`.
                    [swap_dist(x.tolist()) for x in model_actions],
                    device=self.device,
                ).unsqueeze(1).float())
            sim_slate_reward += distance_penalty * (self.MAX_DISTANCE -
                                                    sim_distance)

        assert (len(sim_slate_reward.shape) == 2 and sim_slate_reward.shape[1]
                == 1), f"{sim_slate_reward.shape}"

        on_policy_input = rlt.PreprocessedRankingInput.from_input(
            state=training_input.state.float_features,
            candidates=training_input.src_seq.float_features,
            device=self.device,
            # pyre-fixme[6]: Expected `Optional[torch.Tensor]` for 4th param but got
            #  `int`.
            action=model_actions,
            slate_reward=sim_slate_reward,
            logged_propensities=model_propensities,
        )
        return on_policy_input
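# A minimal sketch of the reward-shaping arithmetic above, with hypothetical
# reward names, weights and powers: the simulated slate reward is a weighted
# sum of each reward net's prediction raised to its per-net power, optionally
# clamped to a guard-rail range.
import torch

predictions = {"ctr": torch.tensor([[0.4], [0.9]]), "dwell": torch.tensor([[2.0], [3.0]])}
weights = {"ctr": 1.0, "dwell": 0.5}
powers = {"ctr": 1.0, "dwell": 2.0}
sim_reward = torch.zeros(2, 1)
for name, pred in predictions.items():
    sim_reward += weights[name] * (pred ** powers[name])
sim_reward = torch.clamp(sim_reward, min=0.0, max=5.0)  # hypothetical clamp bounds
print(sim_reward)  # tensor([[2.4000], [5.0000]])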
Example #10
    def _rank(self, state, src_seq, src_src_mask, tgt_seq_len, greedy):
        """ Decode sequences based on given inputs """
        device = src_seq.device
        batch_size, src_seq_len, candidate_dim = src_seq.shape
        candidate_size = src_seq_len + 2

        # candidate_features is used as a look-up table for candidate features.
        # the second dim is src_seq_len + 2 because we also want to include
        # features for the start symbol and the padding symbol
        candidate_features = torch.zeros(batch_size,
                                         src_seq_len + 2,
                                         candidate_dim,
                                         device=device)
        # TODO: T62502977 create learnable feature vectors for start symbol
        # and padding symbol
        candidate_features[:, 2:, :] = src_seq

        # memory shape: batch_size, src_seq_len, dim_model
        memory = self.encode(state, src_seq, src_src_mask)

        ranked_per_symbol_probs = torch.zeros(batch_size,
                                              tgt_seq_len,
                                              candidate_size,
                                              device=device)
        ranked_per_seq_probs = torch.zeros(batch_size, 1, device=device)

        if self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
            # encoder_scores shape: batch_size, src_seq_len
            encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
            tgt_out_idx = torch.argsort(encoder_scores, dim=1,
                                        descending=True)[:, :tgt_seq_len]
            # +2 to account for start symbol and padding symbol
            tgt_out_idx += 2
            # every position has propensity of 1 because we are just using argsort
            ranked_per_symbol_probs = ranked_per_symbol_probs.scatter(
                2, tgt_out_idx.unsqueeze(2), 1.0)
            ranked_per_seq_probs[:, :] = 1.0
            return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx

        tgt_in_idx = (torch.ones(batch_size, 1, device=device).fill_(
            self._DECODER_START_SYMBOL).type(torch.long))

        assert greedy is not None
        for l in range(tgt_seq_len):
            tgt_in_seq = gather(candidate_features, tgt_in_idx)
            tgt_tgt_mask, tgt_src_mask = pytorch_decoder_mask(
                memory, tgt_in_idx, self.num_heads)
            # shape batch_size, l + 1, candidate_size
            probs = self.decode(
                memory=memory,
                state=state,
                tgt_src_mask=tgt_src_mask,
                tgt_in_idx=tgt_in_idx,
                tgt_in_seq=tgt_in_seq,
                tgt_tgt_mask=tgt_tgt_mask,
            )
            # next candidate shape: batch_size, 1
            # prob shape: batch_size, candidate_size
            next_candidate, next_candidate_sample_prob = self.generator(
                probs, greedy)
            ranked_per_symbol_probs[:, l, :] = next_candidate_sample_prob
            tgt_in_idx = torch.cat([tgt_in_idx, next_candidate], dim=1)

        # remove the decoder start symbol
        # tgt_out_idx shape: batch_size, tgt_seq_len
        tgt_out_idx = tgt_in_idx[:, 1:]

        ranked_per_seq_probs = per_symbol_to_per_seq_probs(
            ranked_per_symbol_probs, tgt_out_idx)

        # ranked_per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
        # ranked_per_seq_probs shape: batch_size, 1
        # tgt_out_idx shape: batch_size, tgt_seq_len
        return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx
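# A minimal sketch of what `per_symbol_to_per_seq_probs` is assumed to compute
# (inferred from its usage here, not the original implementation): the
# per-sequence probability is the product, over decoding steps, of the
# probability assigned to the index that was actually chosen at that step.
import torch

def per_symbol_to_per_seq_probs_sketch(per_symbol_probs, tgt_out_idx):
    # per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
    # tgt_out_idx shape: batch_size, tgt_seq_len
    chosen = torch.gather(per_symbol_probs, 2, tgt_out_idx.unsqueeze(2)).squeeze(2)
    return chosen.prod(dim=1, keepdim=True)  # shape: batch_size, 1

per_symbol = torch.tensor([[[0.2, 0.8], [0.5, 0.5]]])
idx = torch.tensor([[1, 0]])
print(per_symbol_to_per_seq_probs_sketch(per_symbol, idx))  # tensor([[0.4000]])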