Example #1
0
 def feature_config(self):
     """Return a feature config with a single "page" id mapping.

     The "page" mapping covers ids ``100`` through
     ``100 + self.embedding_size - 1``; sequence features are described by
     ``SequenceFeatures``.
     """
     page_ids = list(range(100, 100 + self.embedding_size))
     config = rlt.ModelFeatureConfig(
         id_mapping_config={"page": rlt.IdMapping(ids=page_ids)},
         sequence_features_type=SequenceFeatures,
     )
     return config
    def _test_discrete_dqn_net_builder(
        self,
        chooser: DiscreteDQNNetBuilder__Union,
        state_feature_config: Optional[rlt.ModelFeatureConfig] = None,
        serving_module_class=DiscreteDqnPredictorWrapper,
    ) -> None:
        """Build a Q-network and serving module via ``chooser`` and check them.

        Args:
            chooser: union wrapper whose ``.value`` is the net builder under test.
            state_feature_config: state feature config to build with. When not
                supplied, a config with three float features ``f0``..``f2``
                (feature ids 0..2) is used.
            serving_module_class: class the built serving module is expected
                to be an instance of.
        """
        builder = chooser.value
        default_state_dim = 3
        state_feature_config = state_feature_config or rlt.ModelFeatureConfig(
            float_feature_infos=[
                rlt.FloatFeatureInfo(name=f"f{i}", feature_id=i)
                for i in range(default_state_dim)
            ]
        )

        # Every configured float feature gets standard continuous normalization.
        state_norm_params = {
            fi.feature_id: NormalizationParameters(
                feature_type=CONTINUOUS, mean=0.0, stddev=1.0
            )
            for fi in state_feature_config.float_feature_infos
        }

        action_names = ["L", "R"]
        q_network = builder.build_q_network(
            state_feature_config, state_norm_params, len(action_names)
        )
        x = q_network.input_prototype()
        y = q_network(x).q_values
        # Single-row prototype input -> one Q-value per action.
        self.assertEqual(y.shape, (1, len(action_names)))

        serving_module = builder.build_serving_module(
            q_network, state_norm_params, action_names, state_feature_config
        )
        self.assertIsInstance(serving_module, serving_module_class)
Example #3
0
 def feature_config(self):
     """Return a feature config with one id-list feature, "page_id".

     "page_id" (feature id 2002) is looked up through the "page" id mapping,
     which covers ids ``100`` through ``100 + self.embedding_size - 1``.
     """
     page_mapping = rlt.IdMapping(ids=list(range(100, 100 + self.embedding_size)))
     page_feature = rlt.IdFeatureConfig(
         name="page_id", feature_id=2002, id_mapping_name="page"
     )
     return rlt.ModelFeatureConfig(
         id_mapping_config={"page": page_mapping},
         id_list_feature_configs=[page_feature],
     )
    def test_discrete_wrapper_with_id_list(self):
        """Check the id-list DQN serving wrapper against running the DQN by hand.

        The model uses four continuous float features (ids 1-4) and one
        id-list feature "A" (feature id 10, mapping "A_mapping"). The wrapper
        must return the action names unchanged. Its Q-values must equal those
        from preprocessing the wrapper's input prototype and calling the DQN
        directly.
        """
        float_ids = range(1, 5)
        state_normalization_parameters = {fid: _cont_norm() for fid in float_ids}
        state_preprocessor = Preprocessor(state_normalization_parameters, False)
        action_dim = 2
        state_feature_config = rlt.ModelFeatureConfig(
            float_feature_infos=[
                rlt.FloatFeatureInfo(name=str(fid), feature_id=fid)
                for fid in float_ids
            ],
            id_list_feature_configs=[
                rlt.IdListFeatureConfig(
                    name="A", feature_id=10, id_mapping_name="A_mapping"
                )
            ],
            id_mapping_config={"A_mapping": rlt.IdMapping(ids=[0, 1, 2])},
        )
        # Module construction order is kept as-is so parameter initialization
        # consumes randomness in the same sequence.
        dqn = FullyConnectedDQNWithEmbedding(
            state_dim=len(state_normalization_parameters),
            action_dim=action_dim,
            sizes=[16],
            activations=["relu"],
            model_feature_config=state_feature_config,
            embedding_dim=8,
        )
        dqn_with_preprocessor = DiscreteDqnWithPreprocessorWithIdList(
            dqn, state_preprocessor, state_feature_config
        )
        action_names = ["L", "R"]
        wrapper = DiscreteDqnPredictorWrapperWithIdList(
            dqn_with_preprocessor, action_names, state_feature_config
        )

        prototype = dqn_with_preprocessor.input_prototype()
        output_action_names, q_values = wrapper(*prototype)
        self.assertEqual(action_names, output_action_names)
        self.assertEqual(q_values.shape, (1, 2))

        # The prototype keys id-list inputs by feature id; the DQN expects
        # them keyed by feature name.
        name_by_feature_id = {
            cfg.feature_id: cfg.name
            for cfg in state_feature_config.id_list_feature_configs
        }
        id_list_by_name = {
            name_by_feature_id[fid]: value for fid, value in prototype[1].items()
        }
        preprocessed_state = rlt.PreprocessedState(
            state=rlt.PreprocessedFeatureVector(
                float_features=state_preprocessor(*prototype[0]),
                id_list_features=id_list_by_name,
            )
        )
        expected_output = dqn(preprocessed_state).q_values
        self.assertTrue((expected_output == q_values).all())
Example #5
0
 def test_fully_connected_with_id_list(self):
     """Run the shared net-builder check for FullyConnectedWithEmbedding.

     The state has four float features (ids 1-4) and one id-list feature "A".
     The serving module is expected to be a
     ``DiscreteDqnPredictorWrapperWithIdList``.
     """
     # The fully qualified path is deliberate: it verifies the builder is
     # exported through the package's __init__.py.
     builder_config_cls = (
         discrete_dqn.fully_connected_with_embedding.FullyConnectedWithEmbedding.config_type()
     )
     chooser = DiscreteDQNNetBuilderChooser(
         FullyConnectedWithEmbedding=builder_config_cls()
     )
     float_infos = [
         rlt.FloatFeatureInfo(name=str(fid), feature_id=fid) for fid in range(1, 5)
     ]
     id_list_configs = [
         rlt.IdListFeatureConfig(name="A", feature_id=10, id_mapping_name="A_mapping")
     ]
     state_feature_config = rlt.ModelFeatureConfig(
         float_feature_infos=float_infos,
         id_list_feature_configs=id_list_configs,
         id_mapping_config={"A_mapping": rlt.IdMapping(ids=[0, 1, 2])},
     )
     self._test_discrete_dqn_net_builder(
         chooser,
         state_feature_config=state_feature_config,
         serving_module_class=DiscreteDqnPredictorWrapperWithIdList,
     )
Example #6
0
class DiscreteDQNBase(ModelManager):
    """Shared model-manager logic for discrete-action DQN workflows.

    Provides state-feature normalization identification, data querying and a
    generic train-and-evaluate loop wired to a ``DiscreteDQNReporter``.
    Subclasses must implement ``build_batch_preprocessor``.
    """

    # Optional target distribution over actions, forwarded to the reporter.
    target_action_distribution: Optional[List[float]] = None
    # Float state features used as the normalization whitelist; defaults to none.
    state_feature_config: Optional[rlt.ModelFeatureConfig] = field(
        default_factory=lambda: rlt.ModelFeatureConfig(float_feature_infos=[])
    )
    # Options for normalization identification; defaults are used when None.
    preprocessing_options: Optional[PreprocessingOptions] = None
    # Passed through to train_and_evaluate_generic.
    reader_options: Optional[ReaderOptions] = None

    def __post_init__(self):
        super().__init__()
        # Lazily filled cache for the ``metrics_to_score`` property.
        self._metrics_to_score = None
        # NOTE(review): only initialized here; presumably populated by a subclass.
        self._q_network: Optional[ModelBase] = None

    @classmethod
    def normalization_key(cls) -> str:
        """Return the key under which state normalization data is stored."""
        return DiscreteNormalizationParameterKeys.STATE

    @property
    def metrics_to_score(self) -> List[str]:
        """Metric names derived from the reward options' metric reward values.

        Computed on first access and cached afterwards.
        """
        # Read the options once so the None-check and the use below refer to
        # the same attribute.
        reward_options = self.reward_options
        assert reward_options is not None
        if self._metrics_to_score is None:
            self._metrics_to_score = get_metrics_to_score(
                reward_options.metric_reward_values
            )
        return self._metrics_to_score

    @property
    def should_generate_eval_dataset(self) -> bool:
        """Mirror ``eval_parameters.calc_cpe_in_training``."""
        return self.eval_parameters.calc_cpe_in_training

    def _set_normalization_parameters(
        self, normalization_data_map: Dict[str, NormalizationData]
    ):
        """
        Set normalization parameters on current instance

        Stores the dense normalization parameters found under
        ``normalization_key()`` as ``self.state_normalization_parameters``.
        Asserts that the entry and its dense parameters are present.
        """
        state_norm_data = normalization_data_map.get(self.normalization_key(), None)
        assert state_norm_data is not None
        assert state_norm_data.dense_normalization_parameters is not None
        self.state_normalization_parameters = (
            state_norm_data.dense_normalization_parameters
        )

    def run_feature_identification(
        self, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        """Identify normalization parameters for the configured state features.

        The whitelist in the preprocessing options is replaced by the feature
        ids of ``state_feature_config.float_feature_infos``.

        Returns:
            Mapping from the STATE key to the identified dense normalization
            parameters of the ``"state_features"`` column.
        """
        preprocessing_options = self.preprocessing_options or PreprocessingOptions()
        logger.info("Overriding whitelist_features")
        state_features = [
            ffi.feature_id for ffi in self.state_feature_config.float_feature_infos
        ]
        preprocessing_options = preprocessing_options._replace(
            whitelist_features=state_features
        )

        state_normalization_parameters = identify_normalization_parameters(
            input_table_spec, "state_features", preprocessing_options
        )
        return {
            DiscreteNormalizationParameterKeys.STATE: NormalizationData(
                dense_normalization_parameters=state_normalization_parameters
            )
        }

    def query_data(
        self,
        input_table_spec: TableSpec,
        sample_range: Optional[Tuple[float, float]],
        reward_options: RewardOptions,
        eval_dataset: bool,
    ) -> Dataset:
        """Query the training/eval dataset for this model's actions and rewards.

        ``eval_dataset`` is accepted for interface compatibility but is not
        used here.
        """
        # sort is set to False because EvaluationPageHandler sort the data anyway
        return query_data(
            input_table_spec,
            self.action_names,
            self.rl_parameters.use_seq_num_diff_as_time_diff,
            sample_range=sample_range,
            metric_reward_values=reward_options.metric_reward_values,
            custom_reward_expression=reward_options.custom_reward_expression,
            additional_reward_expression=reward_options.additional_reward_expression,
            multi_steps=self.multi_steps,
            gamma=self.rl_parameters.gamma,
            sort=False,
        )

    @property
    def multi_steps(self) -> Optional[int]:
        """Shortcut for ``rl_parameters.multi_steps``."""
        return self.rl_parameters.multi_steps

    def build_batch_preprocessor(self) -> BatchPreprocessor:
        """Return the batch preprocessor used during training (subclass hook)."""
        raise NotImplementedError

    def train(
        self, train_dataset: Dataset, eval_dataset: Optional[Dataset], num_epochs: int
    ) -> RLTrainingOutput:
        """
        Train the model

        Returns a partially filled RLTrainingOutput. The fields that should not
        be filled are:
        - output_path
        - warmstart_output_path
        - vis_metrics
        - validation_output
        """
        logger.info("Creating reporter")
        # NOTE(review): the reporter uses trainer_param.actions while the
        # evaluator uses self.action_names; presumably identical — confirm.
        reporter = DiscreteDQNReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )
        logger.info("Adding reporter to trainer")
        self.trainer.add_observer(reporter)

        training_page_handler = TrainingPageHandler(self.trainer)
        training_page_handler.add_observer(reporter)
        evaluator = Evaluator(
            self.action_names,
            self.rl_parameters.gamma,
            self.trainer,
            metrics_to_score=self.metrics_to_score,
        )
        logger.info("Adding reporter to evaluator")
        evaluator.add_observer(reporter)
        evaluation_page_handler = EvaluationPageHandler(
            self.trainer, evaluator, reporter
        )

        batch_preprocessor = self.build_batch_preprocessor()
        train_and_evaluate_generic(
            train_dataset,
            eval_dataset,
            self.trainer,
            num_epochs,
            self.use_gpu,
            batch_preprocessor,
            training_page_handler,
            evaluation_page_handler,
            reader_options=self.reader_options,
        )
        return RLTrainingOutput(training_report=reporter.generate_training_report())