Beispiel #1
0
 def state_feature_config_provider(
         self) -> ModelFeatureConfigProvider__Union:
     """Build the state feature config provider (for online gym).

     The raw provider declares:

     * five float features named ``arm0_sample`` .. ``arm4_sample`` whose
       feature ids are 0 .. 4;
     * one id-list feature ``legal`` (feature id 100) that uses the
       ``legal_actions`` id mapping (ids 0 .. 5);
     * one id-score-list feature ``mu_changes`` (feature id 1000) that uses
       the ``arms_list`` id mapping (ids 0 .. 4).

     Returns:
         The raw provider wrapped via
         ``ModelFeatureConfigProvider__Union.make_union_instance``.
     """
     arm_sample_infos = [
         rlt.FloatFeatureInfo(name=f"arm{arm}_sample", feature_id=arm)
         for arm in range(5)
     ]
     legal_feature = rlt.IdListFeatureConfig(
         name="legal",
         feature_id=100,
         id_mapping_name="legal_actions",
     )
     mu_changes_feature = rlt.IdScoreListFeatureConfig(
         name="mu_changes",
         feature_id=1000,
         id_mapping_name="arms_list",
     )
     # Keys must match the ``id_mapping_name`` values used just above.
     id_mappings = {
         "legal_actions": rlt.IdMapping(ids=list(range(6))),
         "arms_list": rlt.IdMapping(ids=list(range(5))),
     }
     raw_provider = RawModelFeatureConfigProvider(
         float_feature_infos=arm_sample_infos,
         id_list_feature_configs=[legal_feature],
         id_score_list_feature_configs=[mu_changes_feature],
         id_mapping_config=id_mappings,
     )
     # pyre-fixme[16]: `ModelFeatureConfigProvider__Union` has no attribute
     #  `make_union_instance`.
     return ModelFeatureConfigProvider__Union.make_union_instance(raw_provider)
Beispiel #2
0
    def _test_discrete_dqn_net_builder(
        self,
        chooser: DiscreteDQNNetBuilder__Union,
        state_feature_config: Optional[rlt.ModelFeatureConfig] = None,
        serving_module_class=DiscreteDqnPredictorWrapper,
    ) -> None:
        """Exercise a discrete-DQN net builder end to end.

        Builds a Q-network from ``chooser.value``, runs it on its own input
        prototype, checks the Q-value shape, then builds the serving module
        and checks its type.

        Args:
            chooser: Union whose ``value`` is the net builder under test.
            state_feature_config: Feature config to build with. When it is
                falsy, a config with three float features ``f0`` .. ``f2``
                (feature ids 0 .. 2) is used instead.
            serving_module_class: Class the built serving module is expected
                to be an instance of.
        """
        net_builder = chooser.value

        if not state_feature_config:
            default_state_dim = 3
            state_feature_config = rlt.ModelFeatureConfig(
                float_feature_infos=[
                    rlt.FloatFeatureInfo(name=f"f{idx}", feature_id=idx)
                    for idx in range(default_state_dim)
                ])
        float_infos = state_feature_config.float_feature_infos
        # Mirrors the config actually in use; not read again in this method.
        state_dim = len(float_infos)

        # One identical continuous normalization entry per float feature id.
        state_norm_params = {}
        for info in float_infos:
            state_norm_params[info.feature_id] = NormalizationParameters(
                feature_type=CONTINUOUS,
                mean=0.0,
                stddev=1.0,
            )

        action_names = ["L", "R"]
        num_actions = len(action_names)
        q_network = net_builder.build_q_network(
            state_feature_config,
            state_norm_params,
            num_actions,
        )
        prototype_input = q_network.input_prototype()
        q_values = q_network(prototype_input).q_values
        self.assertEqual(q_values.shape, (1, 2))

        serving_module = net_builder.build_serving_module(
            q_network,
            state_norm_params,
            action_names,
            state_feature_config,
        )
        self.assertIsInstance(serving_module, serving_module_class)
Beispiel #3
0
    def test_discrete_wrapper(self):
        """Compare ``DiscreteDqnPredictorWrapper`` output with the bare DQN.

        A ``FullyConnectedDQN`` over float features with ids 1 .. 4 is wrapped
        with a preprocessor and a predictor wrapper. The wrapper is called on
        the model's input prototype and must return the action names
        unchanged, Q-values of shape ``(1, 2)``, and values equal to those
        obtained by preprocessing the same input and calling the DQN directly.
        """
        feature_ids = range(1, 5)
        norm_params = {fid: _cont_norm() for fid in feature_ids}
        preprocessor = Preprocessor(norm_params, False)
        action_dim = 2
        q_net = models.FullyConnectedDQN(
            state_dim=len(norm_params),
            action_dim=action_dim,
            sizes=[16],
            activations=["relu"],
        )
        float_infos = [
            rlt.FloatFeatureInfo(feature_id=fid, name=f"feat_{fid}")
            for fid in feature_ids
        ]
        feature_config = rlt.ModelFeatureConfig(float_feature_infos=float_infos)
        q_net_with_preprocessor = DiscreteDqnWithPreprocessor(
            q_net,
            preprocessor,
            feature_config,
        )
        action_names = ["L", "R"]
        wrapper = DiscreteDqnPredictorWrapper(
            q_net_with_preprocessor,
            action_names,
            feature_config,
        )

        # Only the first element of the prototype is fed to the wrapper.
        prototype = q_net_with_preprocessor.input_prototype()[0]
        returned_names, q_values = wrapper(prototype)
        self.assertEqual(action_names, returned_names)
        self.assertEqual(q_values.shape, (1, 2))

        # Reference path: preprocess the same input and call the DQN directly.
        preprocessed_state = preprocessor(
            *prototype.float_features_with_presence)
        expected_q_values = q_net(rlt.FeatureData(preprocessed_state))
        elementwise_equal = expected_q_values == q_values
        self.assertTrue(elementwise_equal.all())
Beispiel #4
0
    def test_discrete_wrapper_with_id_list(self):
        """Compare the predictor wrapper with the bare model, id-list included.

        The model is an ``EmbeddingBagConcat`` followed by a
        ``FullyConnectedDQN``. Its feature config has float features with ids
        1 .. 4 and one id-list feature ``A`` (feature id 10, mapping
        ``A_mapping``). The wrapper is called on the model's input prototype
        and must return the action names unchanged, Q-values of shape
        ``(1, 2)``, and values equal to those obtained by calling the model
        directly on the preprocessed float features plus the id-list features
        re-keyed by feature name.
        """
        float_feature_ids = range(1, 5)
        norm_params = {fid: _cont_norm() for fid in float_feature_ids}
        preprocessor = Preprocessor(norm_params, False)
        action_dim = 2

        float_infos = [
            rlt.FloatFeatureInfo(name=str(fid), feature_id=fid)
            for fid in float_feature_ids
        ]
        id_list_config = rlt.IdListFeatureConfig(
            name="A",
            feature_id=10,
            id_mapping_name="A_mapping",
        )
        feature_config = rlt.ModelFeatureConfig(
            float_feature_infos=float_infos,
            id_list_feature_configs=[id_list_config],
            id_mapping_config={"A_mapping": rlt.IdMapping(ids=[0, 1, 2])},
        )

        embedder = models.EmbeddingBagConcat(
            state_dim=len(norm_params),
            model_feature_config=feature_config,
            embedding_dim=8,
        )
        q_head = models.FullyConnectedDQN(
            embedder.output_dim,
            action_dim=action_dim,
            sizes=[16],
            activations=["relu"],
        )
        q_net = models.Sequential(embedder, rlt.TensorFeatureData(), q_head)

        q_net_with_preprocessor = DiscreteDqnWithPreprocessor(
            q_net,
            preprocessor,
            feature_config,
        )
        action_names = ["L", "R"]
        wrapper = DiscreteDqnPredictorWrapper(
            q_net_with_preprocessor,
            action_names,
            feature_config,
        )

        # Only the first element of the prototype is fed to the wrapper.
        prototype = q_net_with_preprocessor.input_prototype()[0]
        returned_names, q_values = wrapper(prototype)
        self.assertEqual(action_names, returned_names)
        self.assertEqual(q_values.shape, (1, 2))

        # The prototype keys id-list features by feature id; the direct model
        # call below is given them keyed by feature name instead.
        name_by_feature_id = {}
        for config in feature_config.id_list_feature_configs:
            name_by_feature_id[config.feature_id] = config.name
        id_list_by_name = {}
        for feature_id, value in prototype.id_list_features.items():
            id_list_by_name[name_by_feature_id[feature_id]] = value

        preprocessed_state = preprocessor(
            *prototype.float_features_with_presence)
        reference_input = rlt.FeatureData(
            float_features=preprocessed_state,
            id_list_features=id_list_by_name,
        )
        expected_q_values = q_net(reference_input)
        elementwise_equal = expected_q_values == q_values
        self.assertTrue(elementwise_equal.all())
Beispiel #5
0
def get_feature_config(
    float_features: Optional[List[Tuple[int, str]]]
) -> rlt.ModelFeatureConfig:
    """Build a ``ModelFeatureConfig`` from ``(feature_id, name)`` pairs.

    Args:
        float_features: ``(feature_id, name)`` tuples, one per float feature.
            ``None`` or an empty list gives a config with no float feature
            infos.

    Returns:
        A ``ModelFeatureConfig`` whose ``float_feature_infos`` holds one
        ``FloatFeatureInfo`` per input pair, in input order.
    """
    id_name_pairs = float_features if float_features else []
    return rlt.ModelFeatureConfig(
        float_feature_infos=[
            rlt.FloatFeatureInfo(name=feature_name, feature_id=feature_id)
            for feature_id, feature_name in id_name_pairs
        ]
    )
Beispiel #6
0
    def test_fully_connected_with_embedding(self):
        """Run the shared builder check for ``FullyConnectedWithEmbedding``.

        Three configurations are checked through
        ``_test_discrete_dqn_net_builder``: its default feature config, a
        config with float features plus one id-list feature, and the same
        config with an additional id-score-list feature.
        """
        # Intentionally used this long path to make sure we included it in __init__.py
        net_builder = (
            discrete_dqn.fully_connected_with_embedding
            .FullyConnectedWithEmbedding()
        )
        chooser = DiscreteDQNNetBuilder__Union(
            FullyConnectedWithEmbedding=net_builder)
        self._test_discrete_dqn_net_builder(chooser)

        def shared_config_kwargs():
            # Fresh objects on every call so the two configs share no state.
            return {
                "float_feature_infos": [
                    rlt.FloatFeatureInfo(name=str(fid), feature_id=fid)
                    for fid in range(1, 5)
                ],
                "id_list_feature_configs": [
                    rlt.IdListFeatureConfig(
                        name="A",
                        feature_id=10,
                        id_mapping_name="A_mapping",
                    )
                ],
                "id_mapping_config": {
                    "A_mapping": rlt.IdMapping(ids=[0, 1, 2]),
                },
            }

        # only id_list
        id_list_only_config = rlt.ModelFeatureConfig(**shared_config_kwargs())
        self._test_discrete_dqn_net_builder(
            chooser, state_feature_config=id_list_only_config)

        # with id_score_list
        id_score_list_config = rlt.ModelFeatureConfig(
            id_score_list_feature_configs=[
                rlt.IdScoreListFeatureConfig(
                    name="B",
                    feature_id=100,
                    id_mapping_name="A_mapping",
                )
            ],
            **shared_config_kwargs(),
        )
        self._test_discrete_dqn_net_builder(
            chooser, state_feature_config=id_score_list_config)