Exemplo n.º 1
0
 def __init__(
     self,
     model: ModelBase,
     state_preprocessor: Preprocessor,
     state_feature_config: rlt.ModelFeatureConfig,
 ):
     """Store a model together with its state-preprocessing configuration.

     A sparse preprocessor is built from ``state_feature_config`` on the
     CPU device and kept as ``self.sparse_preprocessor``.

     Args:
         model: Model to wrap; stored as ``self.model``.
         state_preprocessor: Preprocessor for state inputs; stored as
             ``self.state_preprocessor``.
         state_feature_config: Feature configuration from which the sparse
             preprocessor is built; stored as ``self.state_feature_config``.
     """
     # Base-class init must run before attributes are assigned (needed if the
     # base is torch.nn.Module -- presumably it is; confirm against the class).
     super().__init__()
     # Assignment order is kept as-is: for nn.Module it fixes submodule
     # registration order.
     self.model = model
     self.state_preprocessor = state_preprocessor
     self.state_feature_config = state_feature_config
     cpu_device = torch.device("cpu")
     self.sparse_preprocessor = make_sparse_preprocessor(
         feature_config=state_feature_config, device=cpu_device
     )
Exemplo n.º 2
0
 def __init__(
     self,
     id_list_keys: List[str],
     id_score_list_keys: List[str],
     feature_config: rlt.ModelFeatureConfig,
     device: torch.device,
 ):
     """Store sparse-feature key lists and build a sparse preprocessor.

     Args:
         id_list_keys: Keys stored as ``self.id_list_keys``.
         id_score_list_keys: Keys stored as ``self.id_score_list_keys``;
             must not share any key with ``id_list_keys``.
         feature_config: Feature configuration; stored as
             ``self.feature_config`` and used to build the sparse
             preprocessor.
         device: Device passed to ``make_sparse_preprocessor``.

     Raises:
         ValueError: If a key appears in both ``id_list_keys`` and
             ``id_score_list_keys``.
     """
     # Validate up front with an explicit raise: an ``assert`` would be
     # stripped under ``python -O``, silently accepting overlapping keys.
     overlap = set(id_list_keys).intersection(id_score_list_keys)
     if overlap:
         raise ValueError(
             "id_list_keys and id_score_list_keys must be disjoint; "
             f"overlapping keys: {sorted(overlap)}"
         )
     self.id_list_keys = id_list_keys
     self.id_score_list_keys = id_score_list_keys
     self.feature_config = feature_config
     self.sparse_preprocessor = make_sparse_preprocessor(
         feature_config=feature_config, device=device
     )
Exemplo n.º 3
0
 def __init__(
     self,
     model: ModelBase,
     state_preprocessor: Preprocessor,
     seq_len: int,
     num_action: int,
     state_feature_config: Optional[rlt.ModelFeatureConfig] = None,
 ):
     """Store a model with its state preprocessing and sequence parameters.

     Args:
         model: Model to wrap; stored as ``self.model``.
         state_preprocessor: Preprocessor for state inputs; stored as
             ``self.state_preprocessor``.
         seq_len: Stored as ``self.seq_len`` (its use is not visible here).
         num_action: Stored as ``self.num_action`` (its use is not visible
             here).
         state_feature_config: Feature configuration used to build the
             sparse preprocessor on the CPU device. When ``None``, a
             default ``rlt.ModelFeatureConfig()`` is used.
     """
     # Base-class init must run before attributes are assigned (needed if the
     # base is torch.nn.Module -- presumably it is; confirm against the class).
     super().__init__()
     self.model = model
     self.state_preprocessor = state_preprocessor
     # Explicit None check: only substitute the default when no config was
     # passed, rather than whenever the passed object happens to be falsy.
     if state_feature_config is None:
         state_feature_config = rlt.ModelFeatureConfig()
     self.state_feature_config = state_feature_config
     self.sparse_preprocessor = make_sparse_preprocessor(
         self.state_feature_config, device=torch.device("cpu")
     )
     self.seq_len = seq_len
     self.num_action = num_action
Exemplo n.º 4
0
 def __init__(
     self,
     model: ModelBase,
     state_preprocessor: Preprocessor,
     state_feature_config: rlt.ModelFeatureConfig,
     action_postprocessor: Optional[Postprocessor] = None,
     serve_mean_policy: bool = False,
 ):
     """Store a model with state preprocessing and optional action postprocessing.

     Args:
         model: Model to wrap; stored as ``self.model``.
         state_preprocessor: Preprocessor for state inputs; stored as
             ``self.state_preprocessor``.
         state_feature_config: Feature configuration from which a sparse
             preprocessor is built on the CPU device; stored as
             ``self.state_feature_config``.
         action_postprocessor: Optional postprocessor stored unchanged as
             ``self.action_postprocessor`` (may be ``None``).
         serve_mean_policy: Flag stored as ``self.serve_mean_policy``; how it
             is consumed is not visible here.
     """
     # Base-class init must run before attributes are assigned (needed if the
     # base is torch.nn.Module -- presumably it is; confirm against the class).
     super().__init__()
     # Assignment order is kept as-is: for nn.Module it fixes submodule
     # registration order.
     self.model = model
     self.state_preprocessor = state_preprocessor
     self.state_feature_config = state_feature_config
     sparse_device = torch.device("cpu")
     self.sparse_preprocessor = make_sparse_preprocessor(
         feature_config=state_feature_config, device=sparse_device
     )
     self.action_postprocessor = action_postprocessor
     self.serve_mean_policy = serve_mean_policy