def test_discrete_wrapper_with_id_list(self):
    state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
    state_preprocessor = Preprocessor(state_normalization_parameters, False)
    action_dim = 2
    # Four dense float features plus one sparse id-list feature ("A"),
    # whose raw ids are mapped to embedding-table rows via "A_mapping".
    state_feature_config = rlt.ModelFeatureConfig(
        float_feature_infos=[
            rlt.FloatFeatureInfo(name=str(i), feature_id=i) for i in range(1, 5)
        ],
        id_list_feature_configs=[
            rlt.IdListFeatureConfig(
                name="A", feature_id=10, id_mapping_name="A_mapping"
            )
        ],
        id_mapping_config={"A_mapping": rlt.IdMapping(ids=[0, 1, 2])},
    )
    embedding_concat = models.EmbeddingBagConcat(
        state_dim=len(state_normalization_parameters),
        model_feature_config=state_feature_config,
        embedding_dim=8,
    )
    dqn = models.Sequential(
        embedding_concat,
        rlt.TensorFeatureData(),
        models.FullyConnectedDQN(
            embedding_concat.output_dim,
            action_dim=action_dim,
            sizes=[16],
            activations=["relu"],
        ),
    )
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
        dqn, state_preprocessor, state_feature_config
    )
    action_names = ["L", "R"]
    wrapper = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor, action_names, state_feature_config
    )
    input_prototype = dqn_with_preprocessor.input_prototype()[0]
    output_action_names, q_values = wrapper(input_prototype)
    self.assertEqual(action_names, output_action_names)
    self.assertEqual(q_values.shape, (1, 2))
    # Recompute the expected Q-values by running the unwrapped model on the
    # same preprocessed inputs; the wrapper must match exactly.
    feature_id_to_name = {
        config.feature_id: config.name
        for config in state_feature_config.id_list_feature_configs
    }
    state_id_list_features = {
        feature_id_to_name[k]: v
        for k, v in input_prototype.id_list_features.items()
    }
    state_with_presence = input_prototype.float_features_with_presence
    expected_output = dqn(
        rlt.FeatureData(
            float_features=state_preprocessor(*state_with_presence),
            id_list_features=state_id_list_features,
        )
    )
    self.assertTrue((expected_output == q_values).all())
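# Usage sketch: turning the wrapper's (action_names, q_values) output into a
# greedy action at serving time. Plain torch; the tensors are illustrative
# stand-ins for the wrapper outputs asserted in the test above.
import torch

action_names = ["L", "R"]
q_values = torch.tensor([[0.1, 0.7]])  # shape (1, num_actions), as asserted above
greedy = action_names[q_values.argmax(dim=1).item()]
assert greedy == "R"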
def build_actor(
    self,
    state_feature_config: rlt.ModelFeatureConfig,
    state_normalization_data: NormalizationData,
    action_normalization_data: NormalizationData,
) -> ModelBase:
    state_dim = get_num_output_features(
        state_normalization_data.dense_normalization_parameters
    )
    action_dim = get_num_output_features(
        action_normalization_data.dense_normalization_parameters
    )
    input_dim = state_dim
    embedding_dim = self.embedding_dim
    embedding_concat = None
    if embedding_dim is not None:
        # Sparse id-list features get embedded and concatenated with the
        # dense state, so the actor consumes the wider concatenated vector.
        embedding_concat = models.EmbeddingBagConcat(
            state_dim=state_dim,
            model_feature_config=state_feature_config,
            embedding_dim=embedding_dim,
        )
        input_dim = embedding_concat.output_dim
    gaussian_fc_actor = GaussianFullyConnectedActor(
        state_dim=input_dim,
        action_dim=action_dim,
        sizes=self.sizes,
        activations=self.activations,
        use_batch_norm=self.use_batch_norm,
        use_layer_norm=self.use_layer_norm,
        use_l2_normalization=self.use_l2_normalization,
    )
    # Mirror the check above so both branches agree when embedding_dim is None.
    if embedding_dim is None:
        return gaussian_fc_actor
    assert embedding_concat is not None
    return models.Sequential(  # type: ignore
        embedding_concat,
        rlt.TensorFeatureData(),
        gaussian_fc_actor,
    )
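# Hedged sketch of the dimension bookkeeping in build_actor above. The
# output_dim formula is an assumption (dense state width plus one pooled
# embedding per id-list feature); the real code reads it off embedding_concat.
state_dim = 4
embedding_dim = 8
num_id_list_features = 1
assumed_output_dim = state_dim + num_id_list_features * embedding_dim  # 12
# With embedding_dim set, the actor's state_dim becomes assumed_output_dim;
# with embedding_dim None, it stays at state_dim.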
def get_predictor(self, trainer, environment):
    state_preprocessor = Preprocessor(environment.normalization, False)
    q_network = trainer.q_network
    if isinstance(trainer, QRDQNTrainer):
        # QR-DQN outputs a distribution over quantiles per action; average
        # over the quantile dimension to recover scalar Q-values for serving.
        class _Mean(torch.nn.Module):
            def forward(self, input):
                assert input.ndim == 3
                return input.mean(dim=2)

        q_network = models.Sequential(q_network, _Mean())
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
        q_network.cpu_model().eval(), state_preprocessor
    )
    serving_module = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor=dqn_with_preprocessor,
        action_names=environment.ACTIONS,
    )
    predictor = DiscreteDqnTorchPredictor(serving_module)
    return predictor
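# A minimal, self-contained check of the quantile-averaging trick above, in
# plain torch: a (batch, action, quantile) tensor collapses to one scalar
# Q-value per action. The sizes are illustrative.
import torch

quantiles = torch.randn(1, 2, 51)  # batch=1, 2 actions, 51 quantiles
q_values = quantiles.mean(dim=2)   # shape (1, 2), matching _Mean's forward
assert q_values.shape == (1, 2)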
def build_q_network(
    self,
    state_feature_config: rlt.ModelFeatureConfig,
    state_normalization_data: NormalizationData,
    output_dim: int,
) -> models.ModelBase:
    state_dim = self._get_input_dim(state_normalization_data)
    embedding_concat = models.EmbeddingBagConcat(
        state_dim=state_dim,
        model_feature_config=state_feature_config,
        embedding_dim=self.embedding_dim,
    )
    return models.Sequential(  # type: ignore
        embedding_concat,
        rlt.TensorFeatureData(),
        models.FullyConnectedDQN(
            embedding_concat.output_dim,
            action_dim=output_dim,
            sizes=self.sizes,
            activations=self.activations,
            dropout_ratio=self.dropout_ratio,
        ),
    )
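# Hedged sketch (not the ReAgent implementation): what EmbeddingBagConcat is
# assumed to do conceptually, in plain PyTorch. The "sum" pooling mode and the
# (ids, offsets) encoding of the id-list feature are assumptions for
# illustration only.
import torch
import torch.nn as nn

state_dim, embedding_dim, vocab_size = 4, 8, 3
bag = nn.EmbeddingBag(vocab_size, embedding_dim, mode="sum")

state = torch.randn(1, state_dim)           # dense float features, batch of 1
ids = torch.tensor([0, 2])                  # one id-list feature's values
offsets = torch.tensor([0])                 # start offset of each example
pooled = bag(ids, offsets)                  # (1, embedding_dim) pooled embedding
concat = torch.cat([state, pooled], dim=1)  # (1, state_dim + embedding_dim)
assert concat.shape == (1, state_dim + embedding_dim)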