def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Identify normalization parameters for the state and action columns.

    Returns a dict mapping NormalizationKey.STATE / NormalizationKey.ACTION
    to NormalizationData built from the given input table.
    """
    result: Dict[str, NormalizationData] = {}
    # Same identification routine for both columns; only the column name and
    # the preprocessing options differ.
    for norm_key, column, options in (
        (
            NormalizationKey.STATE,
            InputColumn.STATE_FEATURES,
            self.get_state_preprocessing_options(),
        ),
        (
            NormalizationKey.ACTION,
            InputColumn.ACTION,
            self.get_action_preprocessing_options(),
        ),
    ):
        params = identify_normalization_parameters(
            input_table_spec, column, options
        )
        result[norm_key] = NormalizationData(
            dense_normalization_parameters=params
        )
    return result
def build_normalizer(env: EnvWrapper) -> Dict[str, NormalizationData]:
    """Return normalization data for ``env``.

    Prefers the env's own ``normalization_data`` attribute; otherwise uses a
    RecSim-specific continuous normalizer over user/doc observation dims, or
    falls back to generic state/action normalizers.
    """
    try:
        # EAFP: use env-provided normalization data when the attribute exists.
        # TODO: make this a property of EnvWrapper?
        return env.normalization_data
    except AttributeError:
        # pyre-fixme[16]: Module `envs` has no attribute `RecSim`.
        if HAS_RECSIM and isinstance(env, RecSim):
            user_dim = env.observation_space["user"].shape[0]
            doc_dim = env.observation_space["doc"]["0"].shape[0]
            return {
                NormalizationKey.STATE: NormalizationData(
                    dense_normalization_parameters=only_continuous_normalizer(
                        list(range(user_dim))
                    )
                ),
                NormalizationKey.ITEM: NormalizationData(
                    dense_normalization_parameters=only_continuous_normalizer(
                        list(range(doc_dim))
                    )
                ),
            }
        return {
            NormalizationKey.STATE: NormalizationData(
                dense_normalization_parameters=build_state_normalizer(env)
            ),
            NormalizationKey.ACTION: NormalizationData(
                dense_normalization_parameters=build_action_normalizer(env)
            ),
        }
def _test_actor_net_builder(
    self, chooser: ContinuousActorNetBuilder__Union
) -> None:
    """Smoke-test a continuous actor net builder.

    Builds the actor network from synthetic state/action normalization data,
    runs its input prototype through it, checks output shapes, and verifies
    the serving module wrapper type.
    """
    builder = chooser.value
    state_dim = 3
    action_dim = 2
    # All state features: plain continuous, zero mean, unit stddev.
    state_norm_data = NormalizationData(
        dense_normalization_parameters={
            feature_id: NormalizationParameters(
                feature_type=CONTINUOUS, mean=0.0, stddev=1.0
            )
            for feature_id in range(state_dim)
        }
    )
    # Action features use the builder's preferred preprocessing in [0, 1].
    action_norm_data = NormalizationData(
        dense_normalization_parameters={
            feature_id: NormalizationParameters(
                feature_type=builder.default_action_preprocessing,
                min_value=0.0,
                max_value=1.0,
            )
            for feature_id in range(action_dim)
        }
    )
    actor_network = builder.build_actor(state_norm_data, action_norm_data)
    prototype = actor_network.input_prototype()
    output = actor_network(prototype)
    self.assertEqual(output.action.shape, (1, action_dim))
    self.assertEqual(output.log_prob.shape, (1, 1))
    serving_module = builder.build_serving_module(
        actor_network, state_norm_data, action_norm_data
    )
    self.assertIsInstance(serving_module, ActorPredictorWrapper)
def _test_parametric_dqn_net_builder(
    self, chooser: ParametricDQNNetBuilder__Union
) -> None:
    """Smoke-test a parametric DQN net builder.

    Builds a Q-network over synthetic continuous state/action features,
    checks the scalar Q output shape, and verifies the serving wrapper type.
    """
    builder = chooser.value

    def _unit_continuous(dim):
        # Continuous features with zero mean / unit stddev for ids 0..dim-1.
        return {
            feature_id: NormalizationParameters(
                feature_type=CONTINUOUS, mean=0.0, stddev=1.0
            )
            for feature_id in range(dim)
        }

    state_dim = 3
    action_dim = 2
    state_normalization_data = NormalizationData(
        dense_normalization_parameters=_unit_continuous(state_dim)
    )
    action_normalization_data = NormalizationData(
        dense_normalization_parameters=_unit_continuous(action_dim)
    )
    q_network = builder.build_q_network(
        state_normalization_data, action_normalization_data
    )
    prototype = q_network.input_prototype()
    output = q_network(*prototype)
    self.assertEqual(output.shape, (1, 1))
    serving_module = builder.build_serving_module(
        q_network, state_normalization_data, action_normalization_data
    )
    self.assertIsInstance(serving_module, ParametricDqnPredictorWrapper)
def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Identify state and action normalization parameters from the table.

    State features are whitelisted from the state feature config. Action
    features are also whitelisted, and every action feature's type is
    overridden — by the actor net builder's default preprocessing, unless an
    explicit ``action_feature_override`` was configured.
    """
    # ---- state features ----
    # pyre-fixme[16]: `ActorCriticBase` has no attribute
    #  `_state_preprocessing_options`.
    state_opts = self._state_preprocessing_options or PreprocessingOptions()
    state_features = [
        ffi.feature_id
        for ffi in self.state_feature_config.float_feature_infos
    ]
    logger.info(f"state whitelist_features: {state_features}")
    state_opts = state_opts._replace(whitelist_features=state_features)
    state_normalization_parameters = identify_normalization_parameters(
        input_table_spec, InputColumn.STATE_FEATURES, state_opts
    )
    # ---- action features ----
    # pyre-fixme[16]: `ActorCriticBase` has no attribute
    #  `_action_preprocessing_options`.
    action_opts = self._action_preprocessing_options or PreprocessingOptions()
    action_features = [
        ffi.feature_id
        for ffi in self.action_feature_config.float_feature_infos
    ]
    logger.info(f"action whitelist_features: {action_features}")
    # The actor net builder decides the default action preprocessing; an
    # explicit config value takes precedence.
    # pyre-fixme[16]: `ActorCriticBase` has no attribute `actor_net_builder`.
    actor_net_builder = self.actor_net_builder.value
    action_feature_override = actor_net_builder.default_action_preprocessing
    logger.info(
        f"Default action_feature_override is {action_feature_override}"
    )
    if self.action_feature_override is not None:
        action_feature_override = self.action_feature_override
    # Overrides must not already be set — we replace them wholesale below.
    assert action_opts.feature_overrides is None
    action_opts = action_opts._replace(
        whitelist_features=action_features,
        feature_overrides={
            fid: action_feature_override for fid in action_features
        },
    )
    action_normalization_parameters = identify_normalization_parameters(
        input_table_spec, InputColumn.ACTION, action_opts
    )
    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=state_normalization_parameters
        ),
        NormalizationKey.ACTION: NormalizationData(
            dense_normalization_parameters=action_normalization_parameters
        ),
    }
def build_normalizer(env) -> Dict[str, NormalizationData]:
    """Build state and action normalization data for ``env``."""
    state_data = NormalizationData(
        dense_normalization_parameters=build_state_normalizer(env)
    )
    action_data = NormalizationData(
        dense_normalization_parameters=build_action_normalizer(env)
    )
    return {
        NormalizationKey.STATE: state_data,
        NormalizationKey.ACTION: action_data,
    }
def build_normalizer(env: Env) -> Dict[str, NormalizationData]:
    """Return env-provided normalization data, or build it from the env."""
    try:
        # EAFP: many envs already expose prebuilt normalization data.
        # pyre-fixme[16]: `Env` has no attribute `normalization_data`.
        return env.normalization_data
    except AttributeError:
        return {
            norm_key: NormalizationData(dense_normalization_parameters=params)
            for norm_key, params in (
                (NormalizationKey.STATE, build_state_normalizer(env)),
                (NormalizationKey.ACTION, build_action_normalizer(env)),
            )
        }
def _test_discrete_dqn_net_builder(
    self,
    chooser: DiscreteDQNNetBuilder__Union,
    state_feature_config: Optional[rlt.ModelFeatureConfig] = None,
    serving_module_class=DiscreteDqnPredictorWrapper,
) -> None:
    """Smoke-test a discrete DQN net builder.

    Builds a two-action Q-network over the given (or a default 3-feature)
    state feature config, checks the output shape, and verifies the serving
    module type.
    """
    builder = chooser.value
    if state_feature_config is None:
        # Default config: three continuous float features f0..f2.
        state_feature_config = rlt.ModelFeatureConfig(
            float_feature_infos=[
                rlt.FloatFeatureInfo(name=f"f{i}", feature_id=i)
                for i in range(3)
            ]
        )
    state_normalization_data = NormalizationData(
        dense_normalization_parameters={
            fi.feature_id: NormalizationParameters(
                feature_type=CONTINUOUS, mean=0.0, stddev=1.0
            )
            for fi in state_feature_config.float_feature_infos
        }
    )
    action_names = ["L", "R"]
    q_network = builder.build_q_network(
        state_feature_config, state_normalization_data, len(action_names)
    )
    prototype = q_network.input_prototype()
    output = q_network(prototype)
    self.assertEqual(output.shape, (1, 2))
    serving_module = builder.build_serving_module(
        q_network, state_normalization_data, action_names, state_feature_config
    )
    self.assertIsInstance(serving_module, serving_module_class)
def normalization_data(self):
    """Continuous normalization data over one feature per arm (MU_LOW..MU_HIGH)."""
    arm_feature_ids = list(range(self.num_arms))
    params = only_continuous_normalizer(arm_feature_ids, MU_LOW, MU_HIGH)
    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=params
        )
    }
def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Identify state and action normalization parameters from the table.

    Both columns are whitelisted to their configured float feature ids
    before identification.
    """
    # ---- state ----
    state_opts = self._state_preprocessing_options or PreprocessingOptions()
    state_features = [
        ffi.feature_id
        for ffi in self.state_feature_config.float_feature_infos
    ]
    logger.info(f"state whitelist_features: {state_features}")
    state_opts = state_opts._replace(whitelist_features=state_features)
    state_normalization_parameters = identify_normalization_parameters(
        input_table_spec, InputColumn.STATE_FEATURES, state_opts
    )
    # ---- action ----
    action_opts = self._action_preprocessing_options or PreprocessingOptions()
    action_features = [
        ffi.feature_id
        for ffi in self.action_feature_config.float_feature_infos
    ]
    logger.info(f"action whitelist_features: {action_features}")
    action_opts = action_opts._replace(whitelist_features=action_features)
    action_normalization_parameters = identify_normalization_parameters(
        input_table_spec, InputColumn.ACTION, action_opts
    )
    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=state_normalization_parameters
        ),
        NormalizationKey.ACTION: NormalizationData(
            dense_normalization_parameters=action_normalization_parameters
        ),
    }
def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Identify state and slate-item normalization parameters from the table.

    Item features are read from the state sequence column, keyed by
    ``self.slate_feature_id``.
    """
    # ---- state ----
    state_opts = self._state_preprocessing_options or PreprocessingOptions()
    state_features = [
        ffi.feature_id
        for ffi in self.state_feature_config.float_feature_infos
    ]
    logger.info(f"state whitelist_features: {state_features}")
    state_opts = state_opts._replace(whitelist_features=state_features)
    state_normalization_parameters = identify_normalization_parameters(
        input_table_spec, InputColumn.STATE_FEATURES, state_opts
    )
    # ---- slate items ----
    item_opts = self._item_preprocessing_options or PreprocessingOptions()
    item_features = [
        ffi.feature_id
        for ffi in self.item_feature_config.float_feature_infos
    ]
    logger.info(f"item whitelist_features: {item_features}")
    # Items live inside a sequence feature; point identification at it.
    item_opts = item_opts._replace(
        whitelist_features=item_features,
        sequence_feature_id=self.slate_feature_id,
    )
    item_normalization_parameters = identify_normalization_parameters(
        input_table_spec,
        InputColumn.STATE_SEQUENCE_FEATURES,
        item_opts,
    )
    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=state_normalization_parameters
        ),
        NormalizationKey.ITEM: NormalizationData(
            dense_normalization_parameters=item_normalization_parameters
        ),
    }
def test_fully_connected(self):
    """FullyConnected value network maps (batch, state_dim) -> (batch, 1)."""
    chooser = ValueNetBuilder__Union(
        FullyConnected=value.fully_connected.FullyConnected()
    )
    builder = chooser.value
    state_dim = 3
    normalization_data = NormalizationData(
        dense_normalization_parameters={
            feature_id: NormalizationParameters(feature_type=CONTINUOUS)
            for feature_id in range(state_dim)
        }
    )
    value_network = builder.build_value_network(normalization_data)
    batch_size = 5
    inputs = torch.randn(batch_size, state_dim)
    outputs = value_network(inputs)
    self.assertEqual(outputs.shape, (batch_size, 1))
def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Identify state normalization parameters, whitelisted to the
    configured float features."""
    options = self.preprocessing_options or PreprocessingOptions()
    logger.info("Overriding whitelist_features")
    whitelist = [
        ffi.feature_id
        for ffi in self.state_feature_config.float_feature_infos
    ]
    options = options._replace(whitelist_features=whitelist)
    params = identify_normalization_parameters(
        input_table_spec, InputColumn.STATE_FEATURES, options
    )
    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=params
        )
    }
def build_normalizer(env):
    """Build a state normalizer for a 1-D Box observation space.

    Dict observation spaces are treated as image observations and get no
    normalizer (None); anything else is unsupported.
    """
    obs_space = env.observation_space
    if isinstance(obs_space, gym.spaces.Box):
        # Only flat (rank-1) Box spaces are supported here.
        assert (
            len(obs_space.shape) == 1
        ), f"{env.observation_space} not supported."
        return {
            "state": NormalizationData(
                dense_normalization_parameters=only_continuous_normalizer(
                    list(range(obs_space.shape[0])),
                    obs_space.low,
                    obs_space.high,
                )
            )
        }
    if isinstance(obs_space, gym.spaces.Dict):
        # assuming env.observation_space is image
        return None
    raise NotImplementedError(f"{env.observation_space} not supported")
def action_normalization_data(self) -> NormalizationData:
    """Normalization data treating each named action as a discrete feature."""
    num_actions = len(self.action_names)
    return NormalizationData(
        dense_normalization_parameters={
            action_id: NormalizationParameters(feature_type="DISCRETE_ACTION")
            for action_id in range(num_actions)
        }
    )
def test_seq2slate_scriptable(self):
    """Seq2Slate model and its preprocessor wrapper are torch.jit scriptable."""
    candidate_size = 8
    slate_size = 8
    # Shared hyperparameters for both the raw model and the net wrapper.
    model_kwargs = dict(
        state_dim=2,
        candidate_dim=3,
        num_stacked_layers=2,
        num_heads=2,
        dim_model=128,
        dim_feedforward=128,
        max_src_seq_len=candidate_size,
        max_tgt_seq_len=slate_size,
        output_arch=Seq2SlateOutputArch.AUTOREGRESSIVE,
        temperature=1.0,
    )
    # test the raw Seq2Slate model is script-able
    seq2slate = Seq2SlateTransformerModel(**model_kwargs)
    torch.jit.script(seq2slate)

    seq2slate_net = Seq2SlateTransformerNet(**model_kwargs)

    def _passthrough_params(feature_ids):
        # DO_NOT_PREPROCESS: features flow through the preprocessor untouched.
        return {
            fid: NormalizationParameters(feature_type=DO_NOT_PREPROCESS)
            for fid in feature_ids
        }

    state_normalization_data = NormalizationData(
        dense_normalization_parameters=_passthrough_params([0, 1])
    )
    candidate_normalization_data = NormalizationData(
        dense_normalization_parameters=_passthrough_params([5, 6, 7])
    )
    state_preprocessor = Preprocessor(
        state_normalization_data.dense_normalization_parameters, False
    )
    candidate_preprocessor = Preprocessor(
        candidate_normalization_data.dense_normalization_parameters, False
    )
    # test seq2slate with preprocessor is scriptable
    greedy_serving = True
    seq2slate_with_preprocessor = Seq2SlateWithPreprocessor(
        seq2slate_net.eval(),
        state_preprocessor,
        candidate_preprocessor,
        greedy_serving,
    )
    torch.jit.script(seq2slate_with_preprocessor)