Example #1
    def run_feature_identification(
            self, input_table_spec: TableSpec) -> Dict[str, NormalizationData]:
        # Run state feature identification
        state_normalization_parameters = identify_normalization_parameters(
            input_table_spec,
            InputColumn.STATE_FEATURES,
            self.get_state_preprocessing_options(),
        )

        # Run action feature identification
        action_normalization_parameters = identify_normalization_parameters(
            input_table_spec,
            InputColumn.ACTION,
            self.get_action_preprocessing_options(),
        )

        return {
            NormalizationKey.STATE:
            NormalizationData(
                dense_normalization_parameters=state_normalization_parameters),
            NormalizationKey.ACTION:
            NormalizationData(
                dense_normalization_parameters=action_normalization_parameters
            ),
        }
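The mapping returned above is consumed by pulling out the per-key dense parameters. A minimal sketch of that step, assuming normalization_dict is the value returned by run_feature_identification (manager and input_table_spec are placeholder names) and reusing the Preprocessor call shown later in Example #16, where the second positional argument appears to be a use-GPU flag:

normalization_dict = manager.run_feature_identification(input_table_spec)
state_preprocessor = Preprocessor(
    normalization_dict[NormalizationKey.STATE].dense_normalization_parameters,
    False,  # mirrors Example #16; presumably a use-GPU flag
)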
Example #2
def build_normalizer(env: EnvWrapper) -> Dict[str, NormalizationData]:
    try:
        return env.normalization_data
    except AttributeError:
        # TODO: make this a property of EnvWrapper?
        # pyre-fixme[16]: Module `envs` has no attribute `RecSim`.
        if HAS_RECSIM and isinstance(env, RecSim):
            return {
                NormalizationKey.STATE: NormalizationData(
                    dense_normalization_parameters=only_continuous_normalizer(
                        list(range(env.observation_space["user"].shape[0]))
                    )
                ),
                NormalizationKey.ITEM: NormalizationData(
                    dense_normalization_parameters=only_continuous_normalizer(
                        list(range(env.observation_space["doc"]["0"].shape[0]))
                    )
                ),
            }
        return {
            NormalizationKey.STATE: NormalizationData(
                dense_normalization_parameters=build_state_normalizer(env)
            ),
            NormalizationKey.ACTION: NormalizationData(
                dense_normalization_parameters=build_action_normalizer(env)
            ),
        }
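only_continuous_normalizer appears here with a bare feature-id list and, in later examples (#9 and #14), with explicit bounds. A minimal sketch of both call shapes, assuming the second and third arguments are per-feature lower and upper bounds (the feature ids here are made up):

feature_ids = list(range(4))  # four dense features with ids 0..3
unbounded = only_continuous_normalizer(feature_ids)
bounded = only_continuous_normalizer(feature_ids, -1.0, 1.0)
norm_data = NormalizationData(dense_normalization_parameters=bounded)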
Example #3
 def _test_actor_net_builder(
         self, chooser: ContinuousActorNetBuilder__Union) -> None:
     builder = chooser.value
     state_dim = 3
     state_norm_data = NormalizationData(
         dense_normalization_parameters={
             i: NormalizationParameters(
                 feature_type=CONTINUOUS, mean=0.0, stddev=1.0)
             for i in range(state_dim)
         })
     action_dim = 2
     action_norm_data = NormalizationData(
         dense_normalization_parameters={
             i: NormalizationParameters(
                 feature_type=builder.default_action_preprocessing,
                 min_value=0.0,
                 max_value=1.0,
             )
             for i in range(action_dim)
         })
     actor_network = builder.build_actor(state_norm_data, action_norm_data)
     x = actor_network.input_prototype()
     y = actor_network(x)
     action = y.action
     log_prob = y.log_prob
     self.assertEqual(action.shape, (1, action_dim))
     self.assertEqual(log_prob.shape, (1, 1))
     serving_module = builder.build_serving_module(actor_network,
                                                   state_norm_data,
                                                   action_norm_data)
     self.assertIsInstance(serving_module, ActorPredictorWrapper)
Example #4
 def _test_parametric_dqn_net_builder(
         self, chooser: ParametricDQNNetBuilder__Union) -> None:
     builder = chooser.value
     state_dim = 3
     state_normalization_data = NormalizationData(
         dense_normalization_parameters={
             i: NormalizationParameters(
                 feature_type=CONTINUOUS, mean=0.0, stddev=1.0)
             for i in range(state_dim)
         })
     action_dim = 2
     action_normalization_data = NormalizationData(
         dense_normalization_parameters={
             i: NormalizationParameters(
                 feature_type=CONTINUOUS, mean=0.0, stddev=1.0)
             for i in range(action_dim)
         })
     q_network = builder.build_q_network(state_normalization_data,
                                         action_normalization_data)
     x = q_network.input_prototype()
     y = q_network(*x)
     self.assertEqual(y.shape, (1, 1))
     serving_module = builder.build_serving_module(
         q_network, state_normalization_data, action_normalization_data)
     self.assertIsInstance(serving_module, ParametricDqnPredictorWrapper)
Example #5
    def run_feature_identification(
            self, input_table_spec: TableSpec) -> Dict[str, NormalizationData]:
        # Run state feature identification
        state_preprocessing_options = (
            # pyre-fixme[16]: `ActorCriticBase` has no attribute
            #  `_state_preprocessing_options`.
            self._state_preprocessing_options or PreprocessingOptions())
        state_features = [
            ffi.feature_id
            for ffi in self.state_feature_config.float_feature_infos
        ]
        logger.info(f"state whitelist_features: {state_features}")
        state_preprocessing_options = state_preprocessing_options._replace(
            whitelist_features=state_features)

        state_normalization_parameters = identify_normalization_parameters(
            input_table_spec, InputColumn.STATE_FEATURES,
            state_preprocessing_options)

        # Run action feature identification
        action_preprocessing_options = (
            # pyre-fixme[16]: `ActorCriticBase` has no attribute
            #  `_action_preprocessing_options`.
            self._action_preprocessing_options or PreprocessingOptions())
        action_features = [
            ffi.feature_id
            for ffi in self.action_feature_config.float_feature_infos
        ]
        logger.info(f"action whitelist_features: {action_features}")

        # pyre-fixme[16]: `ActorCriticBase` has no attribute `actor_net_builder`.
        actor_net_builder = self.actor_net_builder.value
        action_feature_override = actor_net_builder.default_action_preprocessing
        logger.info(
            f"Default action_feature_override is {action_feature_override}")
        if self.action_feature_override is not None:
            action_feature_override = self.action_feature_override

        assert action_preprocessing_options.feature_overrides is None
        action_preprocessing_options = action_preprocessing_options._replace(
            whitelist_features=action_features,
            feature_overrides={
                fid: action_feature_override
                for fid in action_features
            },
        )
        action_normalization_parameters = identify_normalization_parameters(
            input_table_spec, InputColumn.ACTION, action_preprocessing_options)

        return {
            NormalizationKey.STATE:
            NormalizationData(
                dense_normalization_parameters=state_normalization_parameters),
            NormalizationKey.ACTION:
            NormalizationData(
                dense_normalization_parameters=action_normalization_parameters
            ),
        }
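PreprocessingOptions is used here like a NamedTuple: _replace returns a copy with only the named fields swapped out. A minimal sketch of that step in isolation, with made-up feature ids and CONTINUOUS used as the override value purely for illustration (the example above takes the override from the actor net builder instead):

options = PreprocessingOptions()
options = options._replace(
    whitelist_features=[1001, 1002],
    feature_overrides={1001: CONTINUOUS, 1002: CONTINUOUS},
)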
Example #6
def build_normalizer(env) -> Dict[str, NormalizationData]:
    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=build_state_normalizer(env)
        ),
        NormalizationKey.ACTION: NormalizationData(
            dense_normalization_parameters=build_action_normalizer(env)
        ),
    }
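A hypothetical follow-up showing how the mapping built above feeds a net builder, mirroring the build_q_network call from the parametric DQN test earlier (env and builder are assumed to already exist):

normalizer = build_normalizer(env)
q_network = builder.build_q_network(
    normalizer[NormalizationKey.STATE],
    normalizer[NormalizationKey.ACTION],
)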
Example #7
def build_normalizer(env: Env) -> Dict[str, NormalizationData]:
    try:
        # pyre-fixme[16]: `Env` has no attribute `normalization_data`.
        return env.normalization_data
    except AttributeError:
        return {
            NormalizationKey.STATE:
            NormalizationData(
                dense_normalization_parameters=build_state_normalizer(env)),
            NormalizationKey.ACTION:
            NormalizationData(
                dense_normalization_parameters=build_action_normalizer(env)),
        }
Example #8
    def _test_discrete_dqn_net_builder(
        self,
        chooser: DiscreteDQNNetBuilder__Union,
        state_feature_config: Optional[rlt.ModelFeatureConfig] = None,
        serving_module_class=DiscreteDqnPredictorWrapper,
    ) -> None:
        builder = chooser.value
        state_dim = 3
        state_feature_config = state_feature_config or rlt.ModelFeatureConfig(
            float_feature_infos=[
                rlt.FloatFeatureInfo(name=f"f{i}", feature_id=i)
                for i in range(state_dim)
            ])
        state_dim = len(state_feature_config.float_feature_infos)

        state_normalization_data = NormalizationData(
            dense_normalization_parameters={
                fi.feature_id: NormalizationParameters(
                    feature_type=CONTINUOUS, mean=0.0, stddev=1.0)
                for fi in state_feature_config.float_feature_infos
            })

        action_names = ["L", "R"]
        q_network = builder.build_q_network(state_feature_config,
                                            state_normalization_data,
                                            len(action_names))
        x = q_network.input_prototype()
        y = q_network(x)
        self.assertEqual(y.shape, (1, 2))
        serving_module = builder.build_serving_module(
            q_network, state_normalization_data, action_names,
            state_feature_config)
        self.assertIsInstance(serving_module, serving_module_class)
Example #9
 def normalization_data(self):
     return {
         NormalizationKey.STATE:
         NormalizationData(
             dense_normalization_parameters=only_continuous_normalizer(
                 list(range(self.num_arms)), MU_LOW, MU_HIGH))
     }
Example #10
    def run_feature_identification(
        self, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        # Run state feature identification
        state_preprocessing_options = (
            self._state_preprocessing_options or PreprocessingOptions()
        )
        state_features = [
            ffi.feature_id for ffi in self.state_feature_config.float_feature_infos
        ]
        logger.info(f"state whitelist_features: {state_features}")
        state_preprocessing_options = state_preprocessing_options._replace(
            whitelist_features=state_features
        )

        state_normalization_parameters = identify_normalization_parameters(
            input_table_spec, InputColumn.STATE_FEATURES, state_preprocessing_options
        )

        # Run action feature identification
        action_preprocessing_options = (
            self._action_preprocessing_options or PreprocessingOptions()
        )
        action_features = [
            ffi.feature_id for ffi in self.action_feature_config.float_feature_infos
        ]
        logger.info(f"action whitelist_features: {action_features}")
        action_preprocessing_options = action_preprocessing_options._replace(
            whitelist_features=action_features
        )
        action_normalization_parameters = identify_normalization_parameters(
            input_table_spec, InputColumn.ACTION, action_preprocessing_options
        )
        return {
            NormalizationKey.STATE: NormalizationData(
                dense_normalization_parameters=state_normalization_parameters
            ),
            NormalizationKey.ACTION: NormalizationData(
                dense_normalization_parameters=action_normalization_parameters
            ),
        }
Example #11
 def run_feature_identification(
     self, input_table_spec: TableSpec
 ) -> Dict[str, NormalizationData]:
     state_preprocessing_options = (
         self._state_preprocessing_options or PreprocessingOptions()
     )
     state_features = [
         ffi.feature_id for ffi in self.state_feature_config.float_feature_infos
     ]
     logger.info(f"state whitelist_features: {state_features}")
     state_preprocessing_options = state_preprocessing_options._replace(
         whitelist_features=state_features
     )
     state_normalization_parameters = identify_normalization_parameters(
         input_table_spec, InputColumn.STATE_FEATURES, state_preprocessing_options
     )
     item_preprocessing_options = (
         self._item_preprocessing_options or PreprocessingOptions()
     )
     item_features = [
         ffi.feature_id for ffi in self.item_feature_config.float_feature_infos
     ]
     logger.info(f"item whitelist_features: {item_features}")
     item_preprocessing_options = item_preprocessing_options._replace(
         whitelist_features=item_features, sequence_feature_id=self.slate_feature_id
     )
     item_normalization_parameters = identify_normalization_parameters(
         input_table_spec,
         InputColumn.STATE_SEQUENCE_FEATURES,
         item_preprocessing_options,
     )
     return {
         NormalizationKey.STATE: NormalizationData(
             dense_normalization_parameters=state_normalization_parameters
         ),
         NormalizationKey.ITEM: NormalizationData(
             dense_normalization_parameters=item_normalization_parameters
         ),
     }
Example #12
 def test_fully_connected(self):
     chooser = ValueNetBuilder__Union(
         FullyConnected=value.fully_connected.FullyConnected())
     builder = chooser.value
     state_dim = 3
     normalization_data = NormalizationData(
         dense_normalization_parameters={
             i: NormalizationParameters(feature_type=CONTINUOUS)
             for i in range(state_dim)
         })
     value_network = builder.build_value_network(normalization_data)
     batch_size = 5
     x = torch.randn(batch_size, state_dim)
     y = value_network(x)
     self.assertEqual(y.shape, (batch_size, 1))
Example #13
 def run_feature_identification(
         self, input_table_spec: TableSpec) -> Dict[str, NormalizationData]:
     preprocessing_options = self.preprocessing_options or PreprocessingOptions()
     logger.info("Overriding whitelist_features")
     state_features = [
         ffi.feature_id
         for ffi in self.state_feature_config.float_feature_infos
     ]
     preprocessing_options = preprocessing_options._replace(
         whitelist_features=state_features)
     return {
         NormalizationKey.STATE:
         NormalizationData(
             dense_normalization_parameters=identify_normalization_parameters(
                 input_table_spec,
                 InputColumn.STATE_FEATURES,
                 preprocessing_options,
             )
         )
     }
Example #14
def build_normalizer(env):
    if isinstance(env.observation_space, gym.spaces.Box):
        assert (
            len(env.observation_space.shape) == 1
        ), f"{env.observation_space} not supported."
        return {
            "state":
            NormalizationData(
                dense_normalization_parameters=only_continuous_normalizer(
                    list(range(env.observation_space.shape[0])),
                    env.observation_space.low,
                    env.observation_space.high,
                ))
        }
    elif isinstance(env.observation_space, gym.spaces.Dict):
        # assuming env.observation_space is image
        return None
    else:
        raise NotImplementedError(f"{env.observation_space} not supported")
Example #15
 def action_normalization_data(self) -> NormalizationData:
     return NormalizationData(
         dense_normalization_parameters={
             i: NormalizationParameters(feature_type="DISCRETE_ACTION")
             for i in range(len(self.action_names))
         })
Example #16
    def test_seq2slate_scriptable(self):
        state_dim = 2
        candidate_dim = 3
        num_stacked_layers = 2
        num_heads = 2
        dim_model = 128
        dim_feedforward = 128
        candidate_size = 8
        slate_size = 8
        output_arch = Seq2SlateOutputArch.AUTOREGRESSIVE
        temperature = 1.0
        greedy_serving = True

        # test the raw Seq2Slate model is script-able
        seq2slate = Seq2SlateTransformerModel(
            state_dim=state_dim,
            candidate_dim=candidate_dim,
            num_stacked_layers=num_stacked_layers,
            num_heads=num_heads,
            dim_model=dim_model,
            dim_feedforward=dim_feedforward,
            max_src_seq_len=candidate_size,
            max_tgt_seq_len=slate_size,
            output_arch=output_arch,
            temperature=temperature,
        )
        seq2slate_scripted = torch.jit.script(seq2slate)

        seq2slate_net = Seq2SlateTransformerNet(
            state_dim=state_dim,
            candidate_dim=candidate_dim,
            num_stacked_layers=num_stacked_layers,
            num_heads=num_heads,
            dim_model=dim_model,
            dim_feedforward=dim_feedforward,
            max_src_seq_len=candidate_size,
            max_tgt_seq_len=slate_size,
            output_arch=output_arch,
            temperature=temperature,
        )

        state_normalization_data = NormalizationData(
            dense_normalization_parameters={
                0: NormalizationParameters(feature_type=DO_NOT_PREPROCESS),
                1: NormalizationParameters(feature_type=DO_NOT_PREPROCESS),
            })

        candidate_normalization_data = NormalizationData(
            dense_normalization_parameters={
                5: NormalizationParameters(feature_type=DO_NOT_PREPROCESS),
                6: NormalizationParameters(feature_type=DO_NOT_PREPROCESS),
                7: NormalizationParameters(feature_type=DO_NOT_PREPROCESS),
            })
        state_preprocessor = Preprocessor(
            state_normalization_data.dense_normalization_parameters, False)
        candidate_preprocessor = Preprocessor(
            candidate_normalization_data.dense_normalization_parameters, False)

        # test seq2slate with preprocessor is scriptable
        seq2slate_with_preprocessor = Seq2SlateWithPreprocessor(
            seq2slate_net.eval(),
            state_preprocessor,
            candidate_preprocessor,
            greedy_serving,
        )
        torch.jit.script(seq2slate_with_preprocessor)
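For reference, the NormalizationParameters variants used across these examples, collected into one hypothetical mapping. Only field names shown in the examples are used, everything omitted is assumed to take its default, and Example #3 additionally shows a min/max-scaled variant whose feature_type comes from builder.default_action_preprocessing:

params = {
    # Gaussian-normalized continuous feature (Examples #3, #4, #12)
    0: NormalizationParameters(feature_type=CONTINUOUS, mean=0.0, stddev=1.0),
    # Pass-through feature (Example #16)
    1: NormalizationParameters(feature_type=DO_NOT_PREPROCESS),
    # Discrete-action feature, one per action index (Example #15)
    2: NormalizationParameters(feature_type="DISCRETE_ACTION"),
}
norm_data = NormalizationData(dense_normalization_parameters=params)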