Example #1
def sparse_input_prototype(
    model: ModelBase,
    state_preprocessor: Preprocessor,
    state_feature_config: rlt.ModelFeatureConfig,
):
    name2id = state_feature_config.name2id
    model_prototype = model.input_prototype()
    # Terrible hack to make JIT tracing work: a plain Python dict carries no
    # type information, so we insert a placeholder entry so the JIT tracer can
    # infer the dict's key/value types.
    state_id_list_features = {
        42: (torch.zeros(1, dtype=torch.long), torch.tensor([], dtype=torch.long))
    }
    state_id_score_list_features = {
        42: (
            torch.zeros(1, dtype=torch.long),
            torch.tensor([], dtype=torch.long),
            torch.tensor([], dtype=torch.float),
        )
    }
    if isinstance(model_prototype, rlt.FeatureData):
        if model_prototype.id_list_features:
            state_id_list_features = {
                name2id[k]: v for k, v in model_prototype.id_list_features.items()
            }
        if model_prototype.id_score_list_features:
            state_id_score_list_features = {
                name2id[k]: v for k, v in model_prototype.id_score_list_features.items()
            }

    input = rlt.ServingFeatureData(
        float_features_with_presence=state_preprocessor.input_prototype(),
        id_list_features=state_id_list_features,
        id_score_list_features=state_id_score_list_features,
    )
    return (input,)
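
For context, the prototype returned here is presumably what gets passed as the example inputs when JIT-tracing a serving wrapper. A minimal sketch, assuming a model_with_preprocessor module (not shown in these snippets) whose forward takes a single ServingFeatureData argument:

# Sketch only: export a hypothetical serving wrapper using the sparse input prototype.
example_inputs = sparse_input_prototype(model, state_preprocessor, state_feature_config)
traced = torch.jit.trace(model_with_preprocessor, example_inputs)
traced.save("serving_module.pt")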
Example #2
def sparse_input_prototype(
    model: ModelBase,
    state_preprocessor: Preprocessor,
    state_feature_config: rlt.ModelFeatureConfig,
):
    name2id = state_feature_config.name2id
    model_prototype = model.input_prototype()
    # Terrible hack to make JIT tracing work: a plain Python dict carries no
    # type information, so we insert a placeholder entry so the JIT tracer can
    # infer the dict's key/value types.
    state_id_list_features = FAKE_STATE_ID_LIST_FEATURES
    state_id_score_list_features = FAKE_STATE_ID_SCORE_LIST_FEATURES
    if isinstance(model_prototype, rlt.FeatureData):
        if model_prototype.id_list_features:
            state_id_list_features = {
                name2id[k]: v
                for k, v in model_prototype.id_list_features.items()
            }
        if model_prototype.id_score_list_features:
            state_id_score_list_features = {
                name2id[k]: v
                for k, v in model_prototype.id_score_list_features.items()
            }

    input = rlt.ServingFeatureData(
        float_features_with_presence=state_preprocessor.input_prototype(),
        id_list_features=state_id_list_features,
        id_score_list_features=state_id_score_list_features,
    )
    return (input,)
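
Example #2 references FAKE_STATE_ID_LIST_FEATURES and FAKE_STATE_ID_SCORE_LIST_FEATURES without defining them. Judging from the inline placeholders in Example #1, they are presumably module-level constants along these lines (a sketch, not the verbatim definitions):

# Presumed definitions, mirroring the inline placeholders in Example #1.
FAKE_STATE_ID_LIST_FEATURES = {
    42: (torch.zeros(1, dtype=torch.long), torch.tensor([], dtype=torch.long))
}
FAKE_STATE_ID_SCORE_LIST_FEATURES = {
    42: (
        torch.zeros(1, dtype=torch.long),
        torch.tensor([], dtype=torch.long),
        torch.tensor([], dtype=torch.float),
    )
}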
Example #3
    def forward(
        self,
        state_with_presence: Tuple[torch.Tensor, torch.Tensor],
        state_id_list_features: Dict[int, Tuple[torch.Tensor, torch.Tensor]],
        state_id_score_list_features: Dict[
            int, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        ],
    ) -> Tuple[List[str], torch.Tensor]:
        return self.model(
            rlt.ServingFeatureData(
                float_features_with_presence=state_with_presence,
                id_list_features=state_id_list_features,
                id_score_list_features=state_id_score_list_features,
            )
        )
Example #4
    def act(
        self, obs: Union[rlt.ServingFeatureData, Tuple[torch.Tensor, torch.Tensor]]
    ) -> rlt.ActorOutput:
        """Input is either state_with_presence, or
        ServingFeatureData (in the case of sparse features)."""
        # ServingFeatureData is a NamedTuple, so this assert holds for both forms.
        assert isinstance(obs, tuple)
        if isinstance(obs, rlt.ServingFeatureData):
            state: rlt.ServingFeatureData = obs
        else:
            state = rlt.ServingFeatureData(
                float_features_with_presence=obs,
                id_list_features={},
                id_score_list_features={},
            )
        scores = self.scorer(state)
        return self.sampler.sample_action(scores).cpu().detach()
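
A small usage sketch for the two accepted observation forms; the policy instance, tensor shapes, and call sites below are hypothetical, not taken from the snippet:

# Hypothetical call sites; ServingFeatureData is a NamedTuple, so both satisfy the tuple assert.
dense_obs = (torch.randn(1, 4), torch.ones(1, 4, dtype=torch.uint8))  # (values, presence)
out_dense = policy.act(dense_obs)  # wrapped into ServingFeatureData with empty sparse dicts
sparse_obs = rlt.ServingFeatureData(
    float_features_with_presence=dense_obs,
    id_list_features={},
    id_score_list_features={},
)
out_sparse = policy.act(sparse_obs)  # used as-is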
Example #5
    def _test_seq2reward_with_preprocessor(self, plan_short_sequence):
        state_dim = 4
        action_dim = 2
        seq_len = 3
        model = FakeSeq2RewardNetwork()
        state_normalization_parameters = {
            i: NormalizationParameters(feature_type=DO_NOT_PREPROCESS,
                                       mean=0.0,
                                       stddev=1.0)
            for i in range(1, state_dim)
        }
        state_preprocessor = Preprocessor(state_normalization_parameters,
                                          False)

        if plan_short_sequence:
            step_prediction_model = FakeStepPredictionNetwork(seq_len)
            model_with_preprocessor = Seq2RewardPlanShortSeqWithPreprocessor(
                model,
                step_prediction_model,
                state_preprocessor,
                seq_len,
                action_dim,
            )
        else:
            model_with_preprocessor = Seq2RewardWithPreprocessor(
                model,
                state_preprocessor,
                seq_len,
                action_dim,
            )
        input_prototype = rlt.ServingFeatureData(
            float_features_with_presence=state_preprocessor.input_prototype(),
            id_list_features=FAKE_STATE_ID_LIST_FEATURES,
            id_score_list_features=FAKE_STATE_ID_SCORE_LIST_FEATURES,
        )
        q_values = model_with_preprocessor(input_prototype)
        if plan_short_sequence:
            # When planning for 1, 2, and 3 steps ahead,
            # the expected q values are respectively:
            # [0, 1], [1, 11], [11, 111]
            # Weighting the expected q values by predicted step
            # probabilities [0.33, 0.33, 0.33], we have [4, 41]
            expected_q_values = torch.tensor([[4.0, 41.0]])
        else:
            expected_q_values = torch.tensor([[11.0, 111.0]])
        assert torch.all(expected_q_values == q_values)
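
As a quick check of the arithmetic in the comment above, weighting the per-step q values by the uniform step probabilities reproduces the expected tensor; a standalone sketch, not part of the test:

import torch

q_per_step = torch.tensor([[0.0, 1.0], [1.0, 11.0], [11.0, 111.0]])  # 1, 2, 3 steps ahead
step_probs = torch.tensor([1 / 3, 1 / 3, 1 / 3])
print(step_probs @ q_per_step)  # tensor([ 4., 41.])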
Example #6
    def serving_obs_preprocessor(self, obs: np.ndarray) -> rlt.ServingFeatureData:
        dense_val, id_list_val, id_score_list_val = self._split_state(obs)
        return rlt.ServingFeatureData(
            float_features_with_presence=(
                dense_val,
                torch.ones_like(dense_val, dtype=torch.uint8),
            ),
            id_list_features={
                100: (torch.tensor([0], dtype=torch.long), id_list_val + ID_LIST_OFFSET)
            },
            id_score_list_features={
                1000: (
                    torch.tensor([0], dtype=torch.long),
                    torch.arange(self.num_arms, dtype=torch.long) + ID_SCORE_LIST_OFFSET,
                    id_score_list_val,
                )
            },
        )