def feature_config(self):
    """Return the model feature config for this object.

    The config maps the name ``"page"`` to ``self.embedding_size``
    consecutive IDs starting at 100 and sets ``SequenceFeatures`` as the
    sequence features type.
    """
    first_page_id = 100
    page_ids = list(range(first_page_id, first_page_id + self.embedding_size))
    page_mapping = rlt.IdMapping(ids=page_ids)
    return rlt.ModelFeatureConfig(
        id_mapping_config={"page": page_mapping},
        sequence_features_type=SequenceFeatures,
    )
def feature_config(self):
    """Return the model feature config for this object.

    The config maps the name ``"page"`` to ``self.embedding_size``
    consecutive IDs starting at 100 and declares a single ID-list feature,
    ``"page_id"`` (feature ID 2002), that refers to that mapping.
    """
    first_page_id = 100
    page_ids = list(range(first_page_id, first_page_id + self.embedding_size))
    page_mapping = rlt.IdMapping(ids=page_ids)
    page_id_feature = rlt.IdFeatureConfig(
        name="page_id",
        feature_id=2002,
        id_mapping_name="page",
    )
    return rlt.ModelFeatureConfig(
        id_mapping_config={"page": page_mapping},
        id_list_feature_configs=[page_id_feature],
    )
def test_discrete_wrapper_with_id_list(self):
    """Check the ID-list predictor wrapper against a direct model call.

    Builds a DQN with embeddings over four float features and one ID-list
    feature, wraps it for serving, and asserts that the wrapper returns the
    configured action names and exactly the Q-values the model produces
    when called directly on the preprocessed prototype input.
    """
    float_feature_ids = range(1, 5)
    norm_params = {fid: _cont_norm() for fid in float_feature_ids}
    preprocessor = Preprocessor(norm_params, False)
    action_dim = 2

    float_infos = [
        rlt.FloatFeatureInfo(name=str(fid), feature_id=fid)
        for fid in float_feature_ids
    ]
    id_list_configs = [
        rlt.IdListFeatureConfig(
            name="A", feature_id=10, id_mapping_name="A_mapping"
        )
    ]
    feature_config = rlt.ModelFeatureConfig(
        float_feature_infos=float_infos,
        id_list_feature_configs=id_list_configs,
        id_mapping_config={"A_mapping": rlt.IdMapping(ids=[0, 1, 2])},
    )

    model = FullyConnectedDQNWithEmbedding(
        state_dim=len(norm_params),
        action_dim=action_dim,
        sizes=[16],
        activations=["relu"],
        model_feature_config=feature_config,
        embedding_dim=8,
    )
    model_with_preprocessor = DiscreteDqnWithPreprocessorWithIdList(
        model, preprocessor, feature_config
    )
    action_names = ["L", "R"]
    serving_wrapper = DiscreteDqnPredictorWrapperWithIdList(
        model_with_preprocessor, action_names, feature_config
    )

    # Run the wrapper on the prototype input supplied by the wrapped model.
    prototype = model_with_preprocessor.input_prototype()
    returned_action_names, q_values = serving_wrapper(*prototype)
    self.assertEqual(action_names, returned_action_names)
    self.assertEqual(q_values.shape, (1, 2))

    # The prototype's ID-list entries are keyed by feature ID; the direct
    # model call below is given them keyed by feature name instead.
    name_by_feature_id = {}
    for id_list_config in feature_config.id_list_feature_configs:
        name_by_feature_id[id_list_config.feature_id] = id_list_config.name
    id_list_by_name = {
        name_by_feature_id[feature_id]: value
        for feature_id, value in prototype[1].items()
    }

    preprocessed_floats = preprocessor(*prototype[0])
    feature_vector = rlt.PreprocessedFeatureVector(
        float_features=preprocessed_floats,
        id_list_features=id_list_by_name,
    )
    expected_q_values = model(rlt.PreprocessedState(state=feature_vector)).q_values
    self.assertTrue((expected_q_values == q_values).all())
def test_fully_connected_with_id_list(self):
    """Run the shared net-builder test for the embedding-based builder.

    Uses a feature config with four float features and one ID-list
    feature, and ``DiscreteDqnPredictorWrapperWithIdList`` as the serving
    module class.
    """
    # The fully qualified path is deliberate: it makes sure the module is
    # included in __init__.py.
    builder_config = (
        discrete_dqn.fully_connected_with_embedding.FullyConnectedWithEmbedding.config_type()()
    )
    chooser = DiscreteDQNNetBuilderChooser(
        FullyConnectedWithEmbedding=builder_config
    )

    float_infos = [
        rlt.FloatFeatureInfo(name=str(fid), feature_id=fid)
        for fid in range(1, 5)
    ]
    id_list_configs = [
        rlt.IdListFeatureConfig(
            name="A", feature_id=10, id_mapping_name="A_mapping"
        )
    ]
    state_feature_config = rlt.ModelFeatureConfig(
        float_feature_infos=float_infos,
        id_list_feature_configs=id_list_configs,
        id_mapping_config={"A_mapping": rlt.IdMapping(ids=[0, 1, 2])},
    )

    self._test_discrete_dqn_net_builder(
        chooser,
        state_feature_config=state_feature_config,
        serving_module_class=DiscreteDqnPredictorWrapperWithIdList,
    )