def forward(
    self,
    state_with_presence: Tuple[torch.Tensor, torch.Tensor],
    candidate_with_presence_list: List[Tuple[torch.Tensor, torch.Tensor]],
):
    """Preprocess a state and its candidates, run the model, return its action.

    Args:
        state_with_presence: pair unpacked into ``self.state_preprocessor``.
        candidate_with_presence_list: one pair per candidate, each unpacked
            into ``self.candidate_preprocessor``; its length must equal
            ``self.num_candidates``.

    Returns:
        ``self.model(...).action``, passed through
        ``self.action_postprocessor`` when one is configured.
    """
    num_given = len(candidate_with_presence_list)
    assert num_given == self.num_candidates, f"{num_given} != {self.num_candidates}"

    state_feats = self.state_preprocessor(*state_with_presence)
    # Each preprocessed candidate is batch_size x candidate_dim; stacking on
    # dim=1 yields batch_size x num_candidates x candidate_dim.
    candidate_feats = torch.stack(
        [self.candidate_preprocessor(*pair) for pair in candidate_with_presence_list],
        dim=1,
    )
    # NOTE(review): scalar -1 tensors are passed explicitly for mask/value;
    # presumably placeholders the model does not read — confirm.
    docs = rlt.DocList(
        float_features=candidate_feats,
        mask=torch.tensor(-1),
        value=torch.tensor(-1),
    )
    model_input = rlt._embed_states(
        rlt.FeatureData(float_features=state_feats, candidate_docs=docs)
    )
    action = self.model(model_input).action
    if self.action_postprocessor is not None:
        # pyre-fixme[29]: `Optional[Postprocessor]` is not a function.
        action = self.action_postprocessor(action)
    return action
def __call__(self, obs):
    """Convert one raw observation dict into a batch-of-one ``rlt.FeatureData``.

    ``obs["user"]`` becomes the float state features (with a leading batch
    axis of size 1). ``obs["doc"]`` is a mapping of per-document observations:
    - if ``self.discrete_keys`` or ``self.box_keys`` is non-empty (Dict space),
      each discrete key is one-hot encoded and each box key is vertically
      stacked, then all parts are concatenated along the feature axis;
    - otherwise (plain Box space) the document vectors are stacked directly.
    Each entry of ``obs["augmentation"]`` supplies a ``"value"`` that becomes
    the candidate ``value`` tensor.
    """
    user_feats = torch.tensor(obs["user"]).float().unsqueeze(0)
    docs = obs["doc"]

    if not (self.discrete_keys or self.box_keys):
        # Simply a Box space
        doc_matrix = torch.tensor(np.vstack(list(docs.values()))).float()
    else:
        # Dict space: discrete parts first, then box parts, in key order.
        parts: List[torch.Tensor] = []
        for key, num_classes in self.discrete_keys:
            ids = torch.tensor([doc[key] for doc in docs.values()])
            assert ids.shape == (self.num_docs, )
            parts.append(F.one_hot(ids, num_classes).float())
        for key, width in self.box_keys:
            block = np.vstack([doc[key] for doc in docs.values()])
            assert block.shape == (self.num_docs, width)
            parts.append(torch.tensor(block).float())
        doc_matrix = torch.cat(parts, dim=1)
    doc_feats = doc_matrix.unsqueeze(0)

    # This comes from ValueWrapper
    doc_values = torch.tensor(
        [entry["value"] for entry in obs["augmentation"].values()]
    ).float().unsqueeze(0)

    return rlt.FeatureData(
        float_features=user_feats,
        candidate_docs=rlt.DocList(float_features=doc_feats, value=doc_values),
    )
def _stack(slates):
    """Batch a sequence of slate observations into one ``rlt.FeatureData``.

    Each slate is indexed with ``"user"`` and ``"doc"``. The per-slate values
    are stacked along a new leading (batch) axis with ``np.stack`` and wrapped
    via ``torch.from_numpy``, which keeps the NumPy dtype.

    Args:
        slates: a non-empty sequence of slate observations.

    Returns:
        ``rlt.FeatureData`` whose ``float_features`` are the stacked users and
        whose ``candidate_docs.float_features`` are the stacked docs.

    Raises:
        ValueError: from ``np.stack`` when ``slates`` is empty or the
            per-slate shapes disagree.
    """
    # np.stack accepts the list directly; wrapping it in np.array first only
    # made an extra full copy of the data.
    users = np.stack([slate["user"] for slate in slates])
    docs = np.stack([slate["doc"] for slate in slates])
    return rlt.FeatureData(
        float_features=torch.from_numpy(users),
        candidate_docs=rlt.DocList(float_features=torch.from_numpy(docs)),
    )
def forward(self, obs):
    """Score candidates with ``self.mlp``, optionally on log-transformed features.

    When ``self.log_transform`` is truthy, the user float features and the
    candidate float features are clipped from below at ``EPS`` and then
    log-transformed. The candidate ``mask`` and ``value`` are carried over
    unchanged; previously the rebuilt ``DocList`` silently dropped them.

    Args:
        obs: ``rlt.FeatureData`` with ``float_features`` and ``candidate_docs``.

    Returns:
        ``self.mlp(self._concat_features(obs))`` with its last dimension
        squeezed (removed if it has size 1).
    """
    if self.log_transform:
        docs = obs.candidate_docs
        obs = rlt.FeatureData(
            float_features=obs.float_features.clip(EPS).log(),
            candidate_docs=rlt.DocList(
                float_features=docs.float_features.clip(EPS).log(),
                # The transform is elementwise, so the existing mask/value
                # still line up with the transformed features.
                mask=docs.mask,
                value=docs.value,
            ),
        )
    return self.mlp(self._concat_features(obs)).squeeze(-1)
def input_prototype(self):
    """Return a small random sample input for this model.

    The sample holds 2 states with 5 float features each. Each state has 3
    candidate docs with 4 float features each, drawn from ``torch.randn``.
    The original body built this object but never returned it, so callers
    always received ``None``.

    NOTE(review): these sizes are hard-coded and not read from the model's
    configuration — confirm they match what the model expects.
    """
    # Sample config for input
    batch_size = 2
    state_dim = 5
    num_docs = 3
    candidate_dim = 4
    return rlt.FeatureData(
        float_features=torch.randn((batch_size, state_dim)),
        candidate_docs=rlt.DocList(
            float_features=torch.randn(batch_size, num_docs, candidate_dim)
        ),
    )
def forward(self, state_vp, candidate_vp):
    """Score every candidate for every state in the batch.

    Args:
        state_vp: pair unpacked into ``self.state_preprocessor``
            (presumably (values, presence) — confirm against caller).
        candidate_vp: pair whose first tensor must be 3-D
            (batch_size, num_candidates, feature_dim). Both tensors are
            flattened to (batch_size * num_candidates, num_sorted_features)
            before preprocessing.

    Returns:
        Scores from ``self.mlp`` reshaped to (batch_size, num_candidates).
    """
    # The 3-way unpack also enforces that candidate values are rank 3; the
    # trailing dim was previously bound to an unused local.
    batch_size, num_candidates, _ = candidate_vp[0].shape
    state_feats = self.state_preprocessor(*state_vp)

    # Fold the candidate axis into the batch axis so the preprocessor sees
    # one row per (state, candidate) pair.
    num_rows = batch_size * num_candidates
    num_features = len(self.candidate_preprocessor.sorted_features)
    candidate_feats = self.candidate_preprocessor(
        candidate_vp[0].view(num_rows, num_features),
        candidate_vp[1].view(num_rows, num_features),
    ).view(batch_size, num_candidates, -1)

    input = rlt.FeatureData(
        float_features=state_feats,
        candidate_docs=rlt.DocList(candidate_feats),
    )
    scores = self.mlp(input).view(batch_size, num_candidates)
    return scores
def forward(self, state_vp, candidate_vp):
    """Rank candidates per state and return the top ``self.slate_size`` indices.

    Args:
        state_vp: pair unpacked into ``self.state_preprocessor``; its first
            tensor's leading dimension is taken as the batch size.
        candidate_vp: pair of tensors, each reshaped to
            (batch_size * self.num_candidates, num_sorted_features).

    Returns:
        Candidate indices sorted by descending ``self.mlp`` score, truncated to
        the first ``self.slate_size`` columns; shape
        (batch_size, min(slate_size, num_candidates)).
    """
    n_batch = state_vp[0].shape[0]
    n_cand = self.num_candidates
    state_feats = self.state_preprocessor(*state_vp)

    rows = n_batch * n_cand
    width = len(self.candidate_preprocessor.sorted_features)
    flat_values = candidate_vp[0].view(rows, width)
    flat_presence = candidate_vp[1].view(rows, width)
    cand_feats = self.candidate_preprocessor(flat_values, flat_presence).view(
        n_batch, n_cand, -1
    )

    model_input = rlt.FeatureData(
        float_features=state_feats,
        candidate_docs=rlt.DocList(cand_feats),
    )
    scores = self.mlp(model_input).view(n_batch, n_cand)
    ranking = scores.argsort(dim=1, descending=True)
    return ranking[:, : self.slate_size]