def sparse_input_prototype(
    model: ModelBase,
    state_preprocessor: Preprocessor,
    state_feature_config: rlt.ModelFeatureConfig,
):
    """Assemble an example serving input that includes sparse features.

    Float features come from ``state_preprocessor.input_prototype()``.
    Sparse features are taken from ``model.input_prototype()`` when that is
    an ``rlt.FeatureData`` with non-empty sparse dicts, re-keyed from feature
    name to feature id through ``state_feature_config.name2id``. Otherwise
    the module-level ``FAKE_*`` placeholders are used.

    Returns:
        A one-element tuple containing the ``rlt.ServingFeatureData``.
    """
    feature_name_to_id = state_feature_config.name2id
    prototype = model.input_prototype()

    # HACK: keeps JIT tracing working. A plain Python dict carries no type
    # information, so the placeholders give the tracer something to infer
    # the dict types from when no real sparse features are available.
    id_list = FAKE_STATE_ID_LIST_FEATURES
    id_score_list = FAKE_STATE_ID_SCORE_LIST_FEATURES

    has_feature_data = isinstance(prototype, rlt.FeatureData)
    if has_feature_data and prototype.id_list_features:
        id_list = {
            feature_name_to_id[name]: value
            for name, value in prototype.id_list_features.items()
        }
    if has_feature_data and prototype.id_score_list_features:
        id_score_list = {
            feature_name_to_id[name]: value
            for name, value in prototype.id_score_list_features.items()
        }

    serving_input = rlt.ServingFeatureData(
        float_features_with_presence=state_preprocessor.input_prototype(),
        id_list_features=id_list,
        id_score_list_features=id_score_list,
    )
    return (serving_input,)
def forward(
    self,
    state_with_presence: Tuple[torch.Tensor, torch.Tensor],
    state_id_list_features: Dict[int, Tuple[torch.Tensor, torch.Tensor]],
    state_id_score_list_features: Dict[
        int, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ],
) -> Tuple[List[str], torch.Tensor]:
    """Bundle the three inputs and run the wrapped model on them.

    Args:
        state_with_presence: Passed through as ``float_features_with_presence``.
        state_id_list_features: Passed through as ``id_list_features``.
        state_id_score_list_features: Passed through as
            ``id_score_list_features``.

    Returns:
        Whatever ``self.model`` returns for the assembled
        ``rlt.ServingFeatureData``.
    """
    serving_features = rlt.ServingFeatureData(
        float_features_with_presence=state_with_presence,
        id_list_features=state_id_list_features,
        id_score_list_features=state_id_score_list_features,
    )
    return self.model(serving_features)
def _test_seq2reward_with_preprocessor(self, plan_short_sequence):
    """Check the q-values produced by a seq2reward model wrapped with a preprocessor.

    Args:
        plan_short_sequence: When truthy, exercise
            ``Seq2RewardPlanShortSeqWithPreprocessor`` (with a fake step
            prediction network); otherwise exercise
            ``Seq2RewardWithPreprocessor``.
    """
    state_dim, action_dim, seq_len = 4, 2, 3
    model = FakeSeq2RewardNetwork()

    state_normalization_parameters = {}
    for feature_id in range(1, state_dim):
        state_normalization_parameters[feature_id] = NormalizationParameters(
            feature_type=DO_NOT_PREPROCESS, mean=0.0, stddev=1.0
        )
    state_preprocessor = Preprocessor(state_normalization_parameters, False)

    if not plan_short_sequence:
        model_with_preprocessor = Seq2RewardWithPreprocessor(
            model,
            state_preprocessor,
            seq_len,
            action_dim,
        )
    else:
        step_prediction_model = FakeStepPredictionNetwork(seq_len)
        model_with_preprocessor = Seq2RewardPlanShortSeqWithPreprocessor(
            model,
            step_prediction_model,
            state_preprocessor,
            seq_len,
            action_dim,
        )

    float_prototype = state_preprocessor.input_prototype()
    input_prototype = rlt.ServingFeatureData(
        float_features_with_presence=float_prototype,
        id_list_features=FAKE_STATE_ID_LIST_FEATURES,
        id_score_list_features=FAKE_STATE_ID_SCORE_LIST_FEATURES,
    )
    q_values = model_with_preprocessor(input_prototype)

    if plan_short_sequence:
        # Planning 1, 2 and 3 steps ahead is expected to give q-values
        # [0, 1], [1, 11] and [11, 111] respectively. Weighting them by the
        # predicted step probabilities [0.33, 0.33, 0.33] gives [4, 41].
        expected_q_values = torch.tensor([[4.0, 41.0]])
    else:
        expected_q_values = torch.tensor([[11.0, 111.0]])

    matches = expected_q_values == q_values
    assert torch.all(matches)
def act(
    self,
    obs: Union[rlt.ServingFeatureData, Tuple[torch.Tensor, torch.Tensor]],
    possible_actions_mask: Optional[np.ndarray],
) -> rlt.ActorOutput:
    """Score the observation and sample an action from the scores.

    Args:
        obs: Either a ``state_with_presence`` tuple or an
            ``rlt.ServingFeatureData`` (the sparse-feature case). A plain
            tuple is wrapped into a ``ServingFeatureData`` with empty
            sparse-feature dicts.
        possible_actions_mask: Forwarded unchanged to ``self.scorer``.

    Returns:
        The sampled action from ``self.sampler``, moved to CPU and detached.
    """
    # NOTE(review): this assert also guards the ServingFeatureData branch, so
    # ServingFeatureData is presumably a tuple subclass (NamedTuple) - confirm.
    assert isinstance(obs, tuple)

    state: rlt.ServingFeatureData = (
        obs
        if isinstance(obs, rlt.ServingFeatureData)
        else rlt.ServingFeatureData(
            float_features_with_presence=obs,
            id_list_features={},
            id_score_list_features={},
        )
    )

    scores = self.scorer(state, possible_actions_mask)
    sampled = self.sampler.sample_action(scores)
    return sampled.cpu().detach()
def serving_obs_preprocessor(self, obs: np.ndarray) -> rlt.ServingFeatureData:
    """Convert a raw observation into ``rlt.ServingFeatureData``.

    The observation is split by ``self._split_state`` into a dense part, an
    id-list part and an id-score-list part. The dense part is paired with an
    all-ones ``uint8`` presence tensor of the same shape. Each sparse part is
    stored under its module-level feature id, preceded by a single-element
    ``[0]`` long tensor.
    """
    dense_val, id_list_val, id_score_list_val = self._split_state(obs)

    presence = torch.ones_like(dense_val, dtype=torch.uint8)
    float_features = (dense_val, presence)

    id_list_entry = (
        torch.tensor([0], dtype=torch.long),
        id_list_val + ID_LIST_OFFSET,
    )

    # One id per arm, shifted by the id-score-list offset.
    arm_ids = torch.arange(self.num_arms, dtype=torch.long) + ID_SCORE_LIST_OFFSET
    id_score_list_entry = (
        torch.tensor([0], dtype=torch.long),
        arm_ids,
        id_score_list_val,
    )

    return rlt.ServingFeatureData(
        float_features_with_presence=float_features,
        id_list_features={ID_LIST_FEATURE_ID: id_list_entry},
        id_score_list_features={ID_SCORE_LIST_FEATURE_ID: id_score_list_entry},
    )
def act(
    self,
    obs: Union[rlt.ServingFeatureData, Tuple[torch.Tensor, torch.Tensor]],
    possible_actions_mask: Optional[torch.Tensor] = None,
) -> rlt.ActorOutput:
    """Run the predictor on the observation and wrap its output.

    Args:
        obs: Either a ``state_with_presence`` tuple or an
            ``rlt.ServingFeatureData`` (the sparse-feature case). A plain
            tuple is wrapped into a ``ServingFeatureData`` with empty
            sparse-feature dicts.
        possible_actions_mask: Accepted but not used by this method.

    Returns:
        An ``rlt.ActorOutput`` holding the action on CPU. When the predictor
        returns an ``(action, log_prob)`` tuple, the log-probability is
        clamped to ``[LOG_PROB_MIN, LOG_PROB_MAX]`` and included as well.
    """
    # NOTE(review): this assert also guards the ServingFeatureData branch, so
    # ServingFeatureData is presumably a tuple subclass (NamedTuple) - confirm.
    assert isinstance(obs, tuple)

    state: rlt.ServingFeatureData = (
        obs
        if isinstance(obs, rlt.ServingFeatureData)
        else rlt.ServingFeatureData(
            float_features_with_presence=obs,
            id_list_features={},
            id_score_list_features={},
        )
    )

    # The predictor takes the ServingFeatureData fields as positional arguments.
    output = self.predictor(*state)

    if not isinstance(output, tuple):
        return rlt.ActorOutput(action=output.cpu())

    action, raw_log_prob = output
    clamped_log_prob = raw_log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
    return rlt.ActorOutput(action=action.cpu(), log_prob=clamped_log_prob.cpu())