def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Identify normalization parameters for state and action features.

    Runs feature identification over the state-features column and the
    action column of the input table, and wraps each result in a
    :class:`NormalizationData`.

    Returns:
        Dict keyed by ``NormalizationKey.STATE`` and ``NormalizationKey.ACTION``.
    """
    # (key, input column, preprocessing options) — identified in this order.
    column_specs = [
        (
            NormalizationKey.STATE,
            InputColumn.STATE_FEATURES,
            self.get_state_preprocessing_options(),
        ),
        (
            NormalizationKey.ACTION,
            InputColumn.ACTION,
            self.get_action_preprocessing_options(),
        ),
    ]
    return {
        key: NormalizationData(
            dense_normalization_parameters=identify_normalization_parameters(
                input_table_spec, column, options
            )
        )
        for key, column, options in column_specs
    }
def _test_parametric_dqn_net_builder(
    self, chooser: ParametricDQNNetBuilder__Union
) -> None:
    """Build a parametric-DQN q-network and its serving module; sanity-check shapes."""
    builder = chooser.value

    def _unit_gaussian_norm(dim: int) -> NormalizationData:
        # Continuous features 0..dim-1, normalized as unit gaussians.
        return NormalizationData(
            dense_normalization_parameters={
                fid: NormalizationParameters(
                    feature_type=CONTINUOUS, mean=0.0, stddev=1.0
                )
                for fid in range(dim)
            }
        )

    state_dim, action_dim = 3, 2
    state_normalization_data = _unit_gaussian_norm(state_dim)
    action_normalization_data = _unit_gaussian_norm(action_dim)

    q_network = builder.build_q_network(
        state_normalization_data, action_normalization_data
    )
    # The prototype is a tuple of inputs; the q-network yields one value per row.
    y = q_network(*q_network.input_prototype())
    self.assertEqual(y.shape, (1, 1))

    serving_module = builder.build_serving_module(
        q_network, state_normalization_data, action_normalization_data
    )
    self.assertIsInstance(serving_module, ParametricDqnPredictorWrapper)
def build_normalizer(env: EnvWrapper) -> Dict[str, NormalizationData]:
    """Return normalization data for *env*.

    Prefers the environment's own ``normalization_data`` attribute when
    present; otherwise derives normalizers from the observation/action
    spaces (with a special case for RecSim environments).
    """
    try:
        # TODO: make this a property of EnvWrapper?
        return env.normalization_data
    except AttributeError:
        pass
    # pyre-fixme[16]: Module `envs` has no attribute `RecSim`.
    if HAS_RECSIM and isinstance(env, RecSim):
        user_dim = env.observation_space["user"].shape[0]
        doc_dim = env.observation_space["doc"]["0"].shape[0]
        return {
            NormalizationKey.STATE: NormalizationData(
                dense_normalization_parameters=only_continuous_normalizer(
                    list(range(user_dim))
                )
            ),
            NormalizationKey.ITEM: NormalizationData(
                dense_normalization_parameters=only_continuous_normalizer(
                    list(range(doc_dim))
                )
            ),
        }
    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=build_state_normalizer(env)
        ),
        NormalizationKey.ACTION: NormalizationData(
            dense_normalization_parameters=build_action_normalizer(env)
        ),
    }
def _test_actor_net_builder(
    self, chooser: ContinuousActorNetBuilder__Union
) -> None:
    """Build a continuous actor network and its serving module; check output shapes."""
    builder = chooser.value

    state_dim = 3
    action_dim = 2
    state_normalization_data = NormalizationData(
        dense_normalization_parameters={
            fid: NormalizationParameters(
                feature_type=CONTINUOUS, mean=0.0, stddev=1.0
            )
            for fid in range(state_dim)
        }
    )
    # Actions use the builder's default preprocessing with a [0, 1] range.
    action_normalization_data = NormalizationData(
        dense_normalization_parameters={
            fid: NormalizationParameters(
                feature_type=builder.default_action_preprocessing,
                min_value=0.0,
                max_value=1.0,
            )
            for fid in range(action_dim)
        }
    )

    actor_network = builder.build_actor(
        state_normalization_data, action_normalization_data
    )
    prototype = actor_network.input_prototype()
    output = actor_network(prototype)
    self.assertEqual(output.action.shape, (1, action_dim))
    self.assertEqual(output.log_prob.shape, (1, 1))

    serving_module = builder.build_serving_module(
        actor_network, state_normalization_data, action_normalization_data
    )
    self.assertIsInstance(serving_module, ActorPredictorWrapper)
def _test_discrete_dqn_net_builder(
    self,
    chooser: DiscreteDQNNetBuilder__Union,
    state_feature_config: Optional[rlt.ModelFeatureConfig] = None,
    serving_module_class=DiscreteDqnPredictorWrapper,
) -> None:
    """Build a discrete-action q-network and its serving module; check shapes."""
    builder = chooser.value

    # Default to a 3-float-feature state config when none is supplied.
    if state_feature_config is None:
        state_feature_config = rlt.ModelFeatureConfig(
            float_feature_infos=[
                rlt.FloatFeatureInfo(name=f"f{i}", feature_id=i) for i in range(3)
            ]
        )
    state_dim = len(state_feature_config.float_feature_infos)

    state_normalization_data = NormalizationData(
        dense_normalization_parameters={
            info.feature_id: NormalizationParameters(
                feature_type=CONTINUOUS, mean=0.0, stddev=1.0
            )
            for info in state_feature_config.float_feature_infos
        }
    )

    action_names = ["L", "R"]
    q_network = builder.build_q_network(
        state_feature_config, state_normalization_data, len(action_names)
    )
    y = q_network(q_network.input_prototype())
    self.assertEqual(y.shape, (1, 2))

    serving_module = builder.build_serving_module(
        q_network, state_normalization_data, action_names, state_feature_config
    )
    self.assertIsInstance(serving_module, serving_module_class)
def normalization_data(self):
    """State normalization: one continuous feature per arm, ranged [MU_LOW, MU_HIGH]."""
    arm_feature_ids = list(range(self.num_arms))
    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=only_continuous_normalizer(
                arm_feature_ids, MU_LOW, MU_HIGH
            )
        )
    }
def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Identify normalization parameters for state — and, for parametric
    actions, action — features from the input table.

    When the model manager declares discrete action names, only the state
    normalization is computed.
    """

    def _identify(label, column, options, feature_config):
        # Restrict identification to the features declared in the config.
        allowed = [ffi.feature_id for ffi in feature_config.float_feature_infos]
        logger.info(f"{label} allowedlist_features: {allowed}")
        options = replace(
            options or PreprocessingOptions(), allowedlist_features=allowed
        )
        return identify_normalization_parameters(input_table_spec, column, options)

    state_params = _identify(
        "state",
        InputColumn.STATE_FEATURES,
        self.model_manager.state_preprocessing_options,
        self.model_manager.state_feature_config,
    )
    if self.model_manager.discrete_action_names:
        # Discrete-action models need no action normalization.
        return {
            NormalizationKey.STATE: NormalizationData(
                dense_normalization_parameters=state_params
            )
        }

    action_params = _identify(
        "action",
        InputColumn.ACTION,
        self.model_manager.action_preprocessing_options,
        self.model_manager.action_feature_config,
    )
    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=state_params
        ),
        NormalizationKey.ACTION: NormalizationData(
            dense_normalization_parameters=action_params
        ),
    }
def _create_norm(dim, offset=0):
    """Unit-gaussian continuous NormalizationData for feature ids
    ``offset .. offset + dim - 1``."""
    params = {}
    for feature_id in range(offset, offset + dim):
        params[feature_id] = NormalizationParameters(
            feature_type=CONTINUOUS, mean=0.0, stddev=1.0
        )
    return NormalizationData(dense_normalization_parameters=params)
def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Identify normalization parameters for state features and for item
    (slate sequence) features from the input table."""
    state_options = self._state_preprocessing_options or PreprocessingOptions()
    state_allowed = [
        ffi.feature_id for ffi in self.state_feature_config.float_feature_infos
    ]
    logger.info(f"state allowedlist_features: {state_allowed}")
    state_params = identify_normalization_parameters(
        input_table_spec,
        InputColumn.STATE_FEATURES,
        state_options._replace(allowedlist_features=state_allowed),
    )

    item_options = self._item_preprocessing_options or PreprocessingOptions()
    item_allowed = [
        ffi.feature_id for ffi in self.item_feature_config.float_feature_infos
    ]
    logger.info(f"item allowedlist_features: {item_allowed}")
    # Item features live in the state-sequence column, keyed by the slate id.
    item_params = identify_normalization_parameters(
        input_table_spec,
        InputColumn.STATE_SEQUENCE_FEATURES,
        item_options._replace(
            allowedlist_features=item_allowed,
            sequence_feature_id=self.slate_feature_id,
        ),
    )

    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=state_params
        ),
        NormalizationKey.ITEM: NormalizationData(
            dense_normalization_parameters=item_params
        ),
    }
def test_fully_connected(self):
    """FullyConnected value-net builder maps a state batch to one value per row."""
    builder = ValueNetBuilder__Union(
        FullyConnected=value.fully_connected.FullyConnected()
    ).value
    state_dim = 3
    norm_data = NormalizationData(
        dense_normalization_parameters={
            fid: NormalizationParameters(feature_type=CONTINUOUS)
            for fid in range(state_dim)
        }
    )
    value_network = builder.build_value_network(norm_data)
    batch_size = 5
    output = value_network(torch.randn(batch_size, state_dim))
    self.assertEqual(output.shape, (batch_size, 1))
def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Identify state-feature normalization parameters, restricted to the
    features declared in the model manager's state feature config."""
    options = self.model_manager.preprocessing_options or PreprocessingOptions()
    logger.info("Overriding allowedlist_features")
    allowed = [
        ffi.feature_id
        for ffi in self.model_manager.state_feature_config.float_feature_infos
    ]
    options = options._replace(allowedlist_features=allowed)
    state_params = identify_normalization_parameters(
        input_table_spec, InputColumn.STATE_FEATURES, options
    )
    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=state_params
        )
    }
def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Identify state-feature normalization parameters, restricted to the
    features declared in the model manager's state feature config."""
    allowed = [
        ffi.feature_id
        for ffi in self.model_manager.state_feature_config.float_feature_infos
    ]
    logger.info(f"Overriding state allowedlist_features: {allowed}")
    options = replace(PreprocessingOptions(), allowedlist_features=allowed)
    state_params = identify_normalization_parameters(
        input_table_spec, InputColumn.STATE_FEATURES, options
    )
    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=state_params
        )
    }
def train_seq2reward_compress_model(
    training_data, seq2reward_network, learning_rate=0.1, num_epochs=5
):
    """Train a compress (value) model against a seq2reward network.

    Args:
        training_data: iterable of batches whose ``action`` tensor has shape
            (seq_len, batch_size, num_actions); this helper requires
            seq_len == 6 and num_actions == 2.
        seq2reward_network: the seq2reward model supplying targets.
        learning_rate: learning rate for the compress model.
        num_epochs: number of training epochs.

    Returns:
        The fitted ``CompressModelTrainer``.
    """
    # Peek at the first batch to discover the data dimensions.
    SEQ_LEN, batch_size, NUM_ACTION = next(iter(training_data)).action.shape
    assert SEQ_LEN == 6 and NUM_ACTION == 2
    compress_net_builder = FullyConnected(sizes=[8, 8])
    # Two features passed through unchanged (no preprocessing).
    state_normalization_data = NormalizationData(
        dense_normalization_parameters={
            0: NormalizationParameters(feature_type=DO_NOT_PREPROCESS),
            1: NormalizationParameters(feature_type=DO_NOT_PREPROCESS),
        }
    )
    compress_model_network = compress_net_builder.build_value_network(
        state_normalization_data,
        output_dim=NUM_ACTION,
    )
    trainer_param = Seq2RewardTrainerParameters(
        # NOTE(review): learning_rate=0.0 looks intentional — the function's
        # learning_rate argument feeds compress_model_learning_rate below,
        # not this field; confirm against Seq2RewardTrainerParameters docs.
        learning_rate=0.0,
        multi_steps=SEQ_LEN,
        action_names=["0", "1"],
        compress_model_learning_rate=learning_rate,
        gamma=1.0,
        view_q_value=True,
    )
    trainer = CompressModelTrainer(
        compress_model_network=compress_model_network,
        seq2reward_network=seq2reward_network,
        params=trainer_param,
    )
    # Seed before fitting so the (deterministic) run is reproducible.
    pl.seed_everything(SEED)
    pl_trainer = pl.Trainer(max_epochs=num_epochs, deterministic=True)
    pl_trainer.fit(trainer, training_data)
    return trainer
def test_seq2slate_scriptable(self):
    """Both the raw Seq2Slate model and the preprocessor-wrapped variant
    must be TorchScript-compatible."""
    candidate_size = 8
    slate_size = 8
    # Shared transformer hyper-parameters for both model variants.
    transformer_kwargs = dict(
        state_dim=2,
        candidate_dim=3,
        num_stacked_layers=2,
        num_heads=2,
        dim_model=128,
        dim_feedforward=128,
        max_src_seq_len=candidate_size,
        max_tgt_seq_len=slate_size,
        output_arch=Seq2SlateOutputArch.AUTOREGRESSIVE,
        temperature=1.0,
    )

    # The raw Seq2Slate model must be script-able on its own.
    seq2slate = Seq2SlateTransformerModel(**transformer_kwargs)
    torch.jit.script(seq2slate)

    seq2slate_net = Seq2SlateTransformerNet(**transformer_kwargs)

    def _passthrough_norm(feature_ids):
        # DO_NOT_PREPROCESS: features flow through the preprocessor unchanged.
        return NormalizationData(
            dense_normalization_parameters={
                fid: NormalizationParameters(feature_type=DO_NOT_PREPROCESS)
                for fid in feature_ids
            }
        )

    state_normalization_data = _passthrough_norm([0, 1])
    candidate_normalization_data = _passthrough_norm([5, 6, 7])
    state_preprocessor = Preprocessor(
        state_normalization_data.dense_normalization_parameters, False
    )
    candidate_preprocessor = Preprocessor(
        candidate_normalization_data.dense_normalization_parameters, False
    )

    # The preprocessor-wrapped (greedy-serving) model must also be script-able.
    greedy_serving = True
    seq2slate_with_preprocessor = Seq2SlateWithPreprocessor(
        seq2slate_net.eval(),
        state_preprocessor,
        candidate_preprocessor,
        greedy_serving,
    )
    torch.jit.script(seq2slate_with_preprocessor)