Example #1
0
    def run_feature_identification(
        self, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        """Identify dense normalization parameters for state and action features.

        Args:
            input_table_spec: Table the feature columns are read from.

        Returns:
            A dict keyed by ``NormalizationKey.STATE`` and
            ``NormalizationKey.ACTION``. Each value is a ``NormalizationData``
            wrapping the parameters identified for ``InputColumn.STATE_FEATURES``
            or ``InputColumn.ACTION``. The options come from
            ``get_state_preprocessing_options()`` and
            ``get_action_preprocessing_options()`` respectively.
        """
        state_params = identify_normalization_parameters(
            input_table_spec,
            InputColumn.STATE_FEATURES,
            self.get_state_preprocessing_options(),
        )
        action_params = identify_normalization_parameters(
            input_table_spec,
            InputColumn.ACTION,
            self.get_action_preprocessing_options(),
        )

        state_data = NormalizationData(dense_normalization_parameters=state_params)
        action_data = NormalizationData(dense_normalization_parameters=action_params)
        return {
            NormalizationKey.STATE: state_data,
            NormalizationKey.ACTION: action_data,
        }
Example #2
0
    def run_feature_identification(
        self, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        """Identify dense normalization parameters for state and action features.

        State: start from ``self._state_preprocessing_options`` (or a default
        ``PreprocessingOptions()``). Set ``whitelist_features`` to the ids in
        ``self.state_feature_config.float_feature_infos``.

        Action: same approach, using ``self._action_preprocessing_options`` and
        ``self.action_feature_config``. In addition, ``feature_overrides`` maps
        every action feature id to a single override value. That value is
        ``self.action_feature_override`` when it is not None; otherwise it is
        ``self.actor_net_builder.value.default_action_preprocessing``.

        Args:
            input_table_spec: Table the feature columns are read from.

        Raises:
            AssertionError: If the action preprocessing options already carry
                ``feature_overrides`` (checked with ``assert``, so the check is
                skipped under ``python -O``).

        Returns:
            Mapping of ``NormalizationKey.STATE`` / ``NormalizationKey.ACTION``
            to the corresponding ``NormalizationData``.
        """
        # State features.
        # pyre-fixme[16]: `ActorCriticBase` has no attribute
        #  `_state_preprocessing_options`.
        state_opts = self._state_preprocessing_options or PreprocessingOptions()
        state_ids = [
            info.feature_id for info in self.state_feature_config.float_feature_infos
        ]
        logger.info(f"state whitelist_features: {state_ids}")
        state_opts = state_opts._replace(whitelist_features=state_ids)
        state_params = identify_normalization_parameters(
            input_table_spec, InputColumn.STATE_FEATURES, state_opts
        )

        # Action features.
        # pyre-fixme[16]: `ActorCriticBase` has no attribute
        #  `_action_preprocessing_options`.
        action_opts = self._action_preprocessing_options or PreprocessingOptions()
        action_ids = [
            info.feature_id for info in self.action_feature_config.float_feature_infos
        ]
        logger.info(f"action whitelist_features: {action_ids}")

        # pyre-fixme[16]: `ActorCriticBase` has no attribute `actor_net_builder`.
        builder = self.actor_net_builder.value
        default_override = builder.default_action_preprocessing
        logger.info(f"Default action_feature_override is {default_override}")
        # An explicitly configured override takes precedence over the builder default.
        override = (
            self.action_feature_override
            if self.action_feature_override is not None
            else default_override
        )

        # feature_overrides is set below; refuse to clobber pre-existing ones.
        assert action_opts.feature_overrides is None
        action_opts = action_opts._replace(
            whitelist_features=action_ids,
            # The same override value for every action feature id.
            feature_overrides=dict.fromkeys(action_ids, override),
        )
        action_params = identify_normalization_parameters(
            input_table_spec, InputColumn.ACTION, action_opts
        )

        return {
            NormalizationKey.STATE: NormalizationData(
                dense_normalization_parameters=state_params
            ),
            NormalizationKey.ACTION: NormalizationData(
                dense_normalization_parameters=action_params
            ),
        }
    def run_feature_identification(
        self, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        """Identify dense normalization parameters for state (and action) features.

        Options are taken from ``self.model_manager``, falling back to a default
        ``PreprocessingOptions()``. For each feature group,
        ``dataclasses.replace`` sets ``allowedlist_features`` to the ids of that
        group's ``float_feature_infos``.

        When ``self.model_manager.discrete_action_names`` is truthy, only the
        state entry is returned and action identification is not run.

        Args:
            input_table_spec: Table the feature columns are read from.

        Returns:
            Mapping from ``NormalizationKey.STATE`` (and, when there are no
            discrete action names, ``NormalizationKey.ACTION``) to
            ``NormalizationData``.
        """
        manager = self.model_manager

        state_opts = manager.state_preprocessing_options or PreprocessingOptions()
        state_ids = [
            info.feature_id
            for info in manager.state_feature_config.float_feature_infos
        ]
        logger.info(f"state allowedlist_features: {state_ids}")
        state_opts = replace(state_opts, allowedlist_features=state_ids)
        state_params = identify_normalization_parameters(
            input_table_spec, InputColumn.STATE_FEATURES, state_opts
        )

        result = {
            NormalizationKey.STATE: NormalizationData(
                dense_normalization_parameters=state_params
            )
        }
        # With discrete action names, only state normalization is produced.
        if manager.discrete_action_names:
            return result

        action_opts = manager.action_preprocessing_options or PreprocessingOptions()
        action_ids = [
            info.feature_id
            for info in manager.action_feature_config.float_feature_infos
        ]
        logger.info(f"action allowedlist_features: {action_ids}")
        action_opts = replace(action_opts, allowedlist_features=action_ids)
        action_params = identify_normalization_parameters(
            input_table_spec, InputColumn.ACTION, action_opts
        )
        result[NormalizationKey.ACTION] = NormalizationData(
            dense_normalization_parameters=action_params
        )
        return result
Example #4
0
    def run_feature_identification(
        self, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        """Identify dense normalization parameters for state and action features.

        For each group, start from ``self._state_preprocessing_options`` /
        ``self._action_preprocessing_options`` (or a default
        ``PreprocessingOptions()``). Set ``whitelist_features`` to the ids of
        that group's ``float_feature_infos``.

        Args:
            input_table_spec: Table the feature columns are read from.

        Returns:
            Mapping of ``NormalizationKey.STATE`` / ``NormalizationKey.ACTION``
            to the corresponding ``NormalizationData``.
        """
        # State features.
        base_state_options = self._state_preprocessing_options or PreprocessingOptions()
        state_feature_ids = [
            info.feature_id for info in self.state_feature_config.float_feature_infos
        ]
        logger.info(f"state whitelist_features: {state_feature_ids}")
        state_options = base_state_options._replace(whitelist_features=state_feature_ids)
        state_normalization = identify_normalization_parameters(
            input_table_spec, InputColumn.STATE_FEATURES, state_options
        )

        # Action features.
        base_action_options = (
            self._action_preprocessing_options or PreprocessingOptions()
        )
        action_feature_ids = [
            info.feature_id for info in self.action_feature_config.float_feature_infos
        ]
        logger.info(f"action whitelist_features: {action_feature_ids}")
        action_options = base_action_options._replace(
            whitelist_features=action_feature_ids
        )
        action_normalization = identify_normalization_parameters(
            input_table_spec, InputColumn.ACTION, action_options
        )

        state_data = NormalizationData(
            dense_normalization_parameters=state_normalization
        )
        action_data = NormalizationData(
            dense_normalization_parameters=action_normalization
        )
        return {NormalizationKey.STATE: state_data, NormalizationKey.ACTION: action_data}
Example #5
0
 def run_feature_identification(
     self, input_table_spec: TableSpec
 ) -> Dict[str, NormalizationData]:
     """Identify dense normalization parameters for state and slate-item features.

     State: ``self._state_preprocessing_options`` (or a default
     ``PreprocessingOptions()``) with ``allowedlist_features`` set to the ids
     in ``self.state_feature_config.float_feature_infos``, read from
     ``InputColumn.STATE_FEATURES``.

     Items: ``self._item_preprocessing_options`` (or a default) with
     ``allowedlist_features`` set to the ids in
     ``self.item_feature_config.float_feature_infos`` and
     ``sequence_feature_id`` set to ``self.slate_feature_id``, read from
     ``InputColumn.STATE_SEQUENCE_FEATURES``.

     Args:
         input_table_spec: Table the feature columns are read from.

     Returns:
         Mapping of ``NormalizationKey.STATE`` / ``NormalizationKey.ITEM`` to
         the corresponding ``NormalizationData``.
     """
     # State features.
     base_state_opts = self._state_preprocessing_options or PreprocessingOptions()
     state_ids = [
         info.feature_id for info in self.state_feature_config.float_feature_infos
     ]
     logger.info(f"state allowedlist_features: {state_ids}")
     state_opts = base_state_opts._replace(allowedlist_features=state_ids)
     state_params = identify_normalization_parameters(
         input_table_spec, InputColumn.STATE_FEATURES, state_opts
     )

     # Slate item features, stored as a sequence feature.
     base_item_opts = self._item_preprocessing_options or PreprocessingOptions()
     item_ids = [
         info.feature_id for info in self.item_feature_config.float_feature_infos
     ]
     logger.info(f"item allowedlist_features: {item_ids}")
     item_opts = base_item_opts._replace(
         allowedlist_features=item_ids,
         sequence_feature_id=self.slate_feature_id,
     )
     item_params = identify_normalization_parameters(
         input_table_spec, InputColumn.STATE_SEQUENCE_FEATURES, item_opts
     )

     state_data = NormalizationData(dense_normalization_parameters=state_params)
     item_data = NormalizationData(dense_normalization_parameters=item_params)
     return {NormalizationKey.STATE: state_data, NormalizationKey.ITEM: item_data}
Example #6
0
 def run_feature_identification(
     self, input_table_spec: TableSpec
 ) -> Dict[str, NormalizationData]:
     """Identify dense normalization parameters for the state features only.

     Uses ``self.preprocessing_options`` (or a default ``PreprocessingOptions()``)
     with ``whitelist_features`` overridden by the ids in
     ``self.state_feature_config.float_feature_infos``.

     Args:
         input_table_spec: Table the state feature column is read from.

     Returns:
         ``{NormalizationKey.STATE: NormalizationData(...)}``.
     """
     base_options = self.preprocessing_options or PreprocessingOptions()
     logger.info("Overriding whitelist_features")
     feature_ids = [
         info.feature_id for info in self.state_feature_config.float_feature_infos
     ]
     options = base_options._replace(whitelist_features=feature_ids)
     state_params = identify_normalization_parameters(
         input_table_spec, InputColumn.STATE_FEATURES, options
     )
     return {
         NormalizationKey.STATE: NormalizationData(
             dense_normalization_parameters=state_params
         )
     }
 def run_feature_identification(
     self, input_table_spec: TableSpec
 ) -> Dict[str, NormalizationData]:
     """Identify dense normalization parameters for the state features only.

     Starts from ``self.model_manager.preprocessing_options`` (or a default
     ``PreprocessingOptions()``). ``dataclasses.replace`` then sets
     ``allowedlist_features`` to the ids in
     ``self.model_manager.state_feature_config.float_feature_infos``.

     Args:
         input_table_spec: Table the state feature column is read from.

     Returns:
         ``{NormalizationKey.STATE: NormalizationData(...)}``.
     """
     manager = self.model_manager
     base_options = manager.preprocessing_options or PreprocessingOptions()
     feature_ids = [
         info.feature_id for info in manager.state_feature_config.float_feature_infos
     ]
     logger.info(f"Overriding allowedlist_features: {feature_ids}")
     options = replace(base_options, allowedlist_features=feature_ids)
     state_params = identify_normalization_parameters(
         input_table_spec, InputColumn.STATE_FEATURES, options
     )
     state_data = NormalizationData(dense_normalization_parameters=state_params)
     return {NormalizationKey.STATE: state_data}
Example #8
0
    def test_preprocessing(self):
        """Check identify_normalization_parameters on synthetic Gaussian features.

        Registers a temp view TABLE_NAME with NUM_ROWS rows. Column COL_NAME
        holds a dict mapping feature "0" to a draw from N(0, 1) and feature "1"
        to a draw from N(4, 3). Parameters are identified with
        ``num_samples=NUM_ROWS // 2``. Each feature must come back CONTINUOUS,
        with mean within 0.05 and stddev within 0.2 of the generating values.
        """
        distributions = {}
        distributions["0"] = {"mean": 0, "stddev": 1}
        distributions["1"] = {"mean": 4, "stddev": 3}

        def get_random_feature():
            # One row's feature dict: an independent normal draw per feature.
            return {
                k: np.random.normal(loc=info["mean"], scale=info["stddev"])
                for k, info in distributions.items()
            }

        data = [(i, get_random_feature()) for i in range(NUM_ROWS)]
        df = self.sc.parallelize(data).toDF(["i", COL_NAME])
        df.show()

        df.createOrReplaceTempView(TABLE_NAME)

        num_samples = NUM_ROWS // 2
        preprocessing_options = PreprocessingOptions(num_samples=num_samples)

        table_spec = TableSpec(table_name=TABLE_NAME)

        normalization_params = identify_normalization_parameters(
            table_spec,
            COL_NAME,
            preprocessing_options,
            seed=self.test_class_seed)

        logger.info(normalization_params)
        for k, info in distributions.items():
            logger.info(f"Expect {k} to be normal with "
                        f"mean {info['mean']}, stddev {info['stddev']}.")
            assert normalization_params[k].feature_type == CONTINUOUS
            assert (
                abs(normalization_params[k].mean - info["mean"]) < 0.05
            ), f"{normalization_params[k].mean} not close to {info['mean']}"
            # The comparison must sit outside abs(). The former
            # abs(a - b < tol) took abs() of a bool, which never checked
            # whether the estimate was too low.
            assert (
                abs(normalization_params[k].stddev - info["stddev"]) < 0.2
            ), f"{normalization_params[k].stddev} not close to {info['stddev']}"
        logger.info("identify_normalization_parameters seems fine.")