Example 1
 def forward(
     self,
     state_with_presence: Tuple[torch.Tensor, torch.Tensor],
     candidate_with_presence_list: List[Tuple[torch.Tensor, torch.Tensor]],
 ):
     assert (
         len(candidate_with_presence_list) == self.num_candidates
     ), f"{len(candidate_with_presence_list)} != {self.num_candidates}"
     preprocessed_state = self.state_preprocessor(*state_with_presence)
     # each is batch_size x candidate_dim, result is batch_size x num_candidates x candidate_dim
     preprocessed_candidates = torch.stack(
         [
             self.candidate_preprocessor(*x)
             for x in candidate_with_presence_list
         ],
         dim=1,
     )
     input = rlt.FeatureData(
         float_features=preprocessed_state,
         candidate_docs=rlt.DocList(
             float_features=preprocessed_candidates,
             mask=torch.tensor(-1),
             value=torch.tensor(-1),
         ),
     )
     input = rlt._embed_states(input)
     action = self.model(input).action
     if self.action_postprocessor is not None:
         # pyre-fixme[29]: `Optional[Postprocessor]` is not a function.
         action = self.action_postprocessor(action)
     return action
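This forward pass runs the state and each candidate (value, presence) pair through its preprocessor, stacks the per-candidate outputs along a new dimension, and wraps the result in rlt.FeatureData before scoring. A minimal shape sketch of the stacking step, using plain torch with made-up sizes in place of the actual preprocessors:

    import torch

    batch_size, num_candidates, candidate_dim = 2, 3, 4

    # Stand-ins for self.candidate_preprocessor(*x): one (batch_size, candidate_dim)
    # tensor per candidate slot.
    per_candidate = [torch.randn(batch_size, candidate_dim) for _ in range(num_candidates)]

    # dim=1 gives batch_size x num_candidates x candidate_dim, as the comment above notes.
    preprocessed_candidates = torch.stack(per_candidate, dim=1)
    assert preprocessed_candidates.shape == (batch_size, num_candidates, candidate_dim)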
Example 2
    def __call__(self, obs):
        user = torch.tensor(obs["user"]).float().unsqueeze(0)

        doc_obs = obs["doc"]

        if self.discrete_keys or self.box_keys:
            # Dict space
            discrete_features: List[torch.Tensor] = []
            for k, n in self.discrete_keys:
                vals = torch.tensor([v[k] for v in doc_obs.values()])
                assert vals.shape == (self.num_docs, )
                discrete_features.append(F.one_hot(vals, n).float())

            box_features: List[torch.Tensor] = []
            for k, d in self.box_keys:
                vals = np.vstack([v[k] for v in doc_obs.values()])
                assert vals.shape == (self.num_docs, d)
                box_features.append(torch.tensor(vals).float())

            doc_features = torch.cat(discrete_features + box_features,
                                     dim=1).unsqueeze(0)
        else:
            # Simply a Box space
            vals = np.vstack(list(doc_obs.values()))
            doc_features = torch.tensor(vals).float().unsqueeze(0)

        # This comes from ValueWrapper
        value = (torch.tensor([
            v["value"] for v in obs["augmentation"].values()
        ]).float().unsqueeze(0))

        candidate_docs = rlt.DocList(float_features=doc_features, value=value)
        return rlt.FeatureData(float_features=user,
                               candidate_docs=candidate_docs)
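The observation layout here is inferred from the code itself: obs["user"] is the user feature vector, obs["doc"] maps document IDs to per-document features, and obs["augmentation"] holds the per-document "value" injected by ValueWrapper. A rough sketch of the Box-space branch with a toy observation (the exact dict structure is an assumption, not taken from any environment's documentation):

    import numpy as np
    import torch

    # Hypothetical observation: 2-dim user vector, 3 docs with 4 features each.
    obs = {
        "user": np.array([0.1, 0.2], dtype=np.float32),
        "doc": {str(i): np.random.rand(4).astype(np.float32) for i in range(3)},
        "augmentation": {str(i): {"value": float(i)} for i in range(3)},
    }

    user = torch.tensor(obs["user"]).float().unsqueeze(0)  # 1 x 2
    doc_features = torch.tensor(np.vstack(list(obs["doc"].values()))).float().unsqueeze(0)  # 1 x 3 x 4
    value = torch.tensor([v["value"] for v in obs["augmentation"].values()]).float().unsqueeze(0)  # 1 x 3

These tensors are exactly what the snippet feeds into rlt.DocList(float_features=doc_features, value=value) and rlt.FeatureData.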
Example 3
 def _stack(slates):
     obs = rlt.FeatureData(
         float_features=torch.from_numpy(
             np.stack(np.array([slate["user"] for slate in slates]))),
         candidate_docs=rlt.DocList(float_features=torch.from_numpy(
             np.stack(np.array([slate["doc"] for slate in slates])))),
     )
     return obs
Example 4
 def forward(self, obs):
     if self.log_transform:
         obs = rlt.FeatureData(
             float_features=obs.float_features.clip(EPS).log(),
             candidate_docs=rlt.DocList(
                 float_features=obs.candidate_docs.float_features.clip(EPS).log(),
             ),
         )
     mlp_input = self._concat_features(obs)
     scores = self.mlp(mlp_input)
     return scores.squeeze(-1)
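The optional log transform clamps features to EPS before taking the log, so zero-valued features map to log(EPS) instead of -inf. A tiny torch illustration (the EPS value here is a placeholder; the module defines its own constant):

    import torch

    EPS = 1e-6  # placeholder value
    x = torch.tensor([0.0, 0.5, 2.0])
    x.clip(EPS).log()  # tensor([-13.8155, -0.6931, 0.6931])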
Example 5
 def input_prototype(self):
     # Sample config for input
     batch_size = 2
     state_dim = 5
     num_docs = 3
     candidate_dim = 4
     return rlt.FeatureData(
         float_features=torch.randn((batch_size, state_dim)),
         candidate_docs=rlt.DocList(float_features=torch.randn(
             batch_size, num_docs, candidate_dim)),
     )
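input_prototype returns a small, randomly filled FeatureData with the nesting the network expects at forward time; prototypes like this are commonly used to smoke-test a model or trace it for export, though that is not shown here. A hedged usage sketch, assuming model is an instance of the class above and accepts the prototype directly:

    proto = model.input_prototype()
    assert proto.float_features.shape == (2, 5)                    # batch_size x state_dim
    assert proto.candidate_docs.float_features.shape == (2, 3, 4)  # batch_size x num_docs x candidate_dim
    out = model(proto)  # quick shape/compatibility check with random inputs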
Example 6
 def forward(self, state_vp, candidate_vp):
     batch_size, num_candidates, candidate_dim = candidate_vp[0].shape
     state_feats = self.state_preprocessor(*state_vp)
     candidate_feats = self.candidate_preprocessor(
         candidate_vp[0].view(
             batch_size * num_candidates,
             len(self.candidate_preprocessor.sorted_features),
         ),
         candidate_vp[1].view(
             batch_size * num_candidates,
             len(self.candidate_preprocessor.sorted_features),
         ),
     ).view(batch_size, num_candidates, -1)
     input = rlt.FeatureData(float_features=state_feats,
                             candidate_docs=rlt.DocList(candidate_feats))
     scores = self.mlp(input).view(batch_size, num_candidates)
     return scores
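The candidate preprocessor works on 2-D, row-per-example inputs, so the 3-D candidate tensors are flattened to (batch_size * num_candidates, feature_dim), preprocessed, and reshaped back. The same flatten/unflatten pattern in isolation, with plain torch and made-up sizes:

    import torch

    batch_size, num_candidates, feature_dim = 2, 3, 5
    candidates = torch.randn(batch_size, num_candidates, feature_dim)

    flat = candidates.view(batch_size * num_candidates, feature_dim)  # 6 x 5
    processed = flat * 2.0  # stand-in for the row-wise preprocessor
    restored = processed.view(batch_size, num_candidates, -1)         # back to 2 x 3 x 5
    assert restored.shape == candidates.shape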