Example #1
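A unit test that builds a discrete-action DQN over four continuous state features plus one id-list feature, embeds the id-list feature with EmbeddingBagConcat, wraps the model for serving, and checks the wrapper's Q-values against a direct forward pass through the unwrapped model.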
    def test_discrete_wrapper_with_id_list(self):
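        # Normalization parameters for four continuous state features (ids 1-4).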
        state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
        state_preprocessor = Preprocessor(state_normalization_parameters,
                                          False)
        action_dim = 2
        state_feature_config = rlt.ModelFeatureConfig(
            float_feature_infos=[
                rlt.FloatFeatureInfo(name=str(i), feature_id=i)
                for i in range(1, 5)
            ],
            id_list_feature_configs=[
                rlt.IdListFeatureConfig(name="A",
                                        feature_id=10,
                                        id_mapping_name="A_mapping")
            ],
            id_mapping_config={"A_mapping": rlt.IdMapping(ids=[0, 1, 2])},
        )
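        # Embed the id-list feature ("A") and concatenate it with the preprocessed dense state.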
        embedding_concat = models.EmbeddingBagConcat(
            state_dim=len(state_normalization_parameters),
            model_feature_config=state_feature_config,
            embedding_dim=8,
        )
        dqn = models.Sequential(
            embedding_concat,
            rlt.TensorFeatureData(),
            models.FullyConnectedDQN(
                embedding_concat.output_dim,
                action_dim=action_dim,
                sizes=[16],
                activations=["relu"],
            ),
        )

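        # Bundle the model with the state preprocessor so it accepts raw inputs, then expose it through the predictor wrapper.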
        dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
            dqn, state_preprocessor, state_feature_config)
        action_names = ["L", "R"]
        wrapper = DiscreteDqnPredictorWrapper(dqn_with_preprocessor,
                                              action_names,
                                              state_feature_config)
        input_prototype = dqn_with_preprocessor.input_prototype()[0]
        output_action_names, q_values = wrapper(input_prototype)
        self.assertEqual(action_names, output_action_names)
        self.assertEqual(q_values.shape, (1, 2))

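        # Recompute the expected Q-values by preprocessing manually and calling the model directly, then compare with the wrapper's output.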
        feature_id_to_name = {
            config.feature_id: config.name
            for config in state_feature_config.id_list_feature_configs
        }
        state_id_list_features = {
            feature_id_to_name[k]: v
            for k, v in input_prototype.id_list_features.items()
        }
        state_with_presence = input_prototype.float_features_with_presence
        expected_output = dqn(
            rlt.FeatureData(
                float_features=state_preprocessor(*state_with_presence),
                id_list_features=state_id_list_features,
            ))
        self.assertTrue((expected_output == q_values).all())
Example #2
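An actor-network builder: when an embedding dimension is configured, id-list features are embedded and concatenated with the dense state before a Gaussian fully connected actor head; otherwise the actor consumes the dense state directly.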
    def build_actor(
        self,
        state_feature_config: rlt.ModelFeatureConfig,
        state_normalization_data: NormalizationData,
        action_normalization_data: NormalizationData,
    ) -> ModelBase:
        state_dim = get_num_output_features(
            state_normalization_data.dense_normalization_parameters)
        action_dim = get_num_output_features(
            action_normalization_data.dense_normalization_parameters)
        input_dim = state_dim
        embedding_dim = self.embedding_dim

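        # With an embedding dimension configured, embed id-list features and concatenate them with the dense state; the actor then takes the wider input.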
        embedding_concat = None
        if embedding_dim is not None:
            embedding_concat = models.EmbeddingBagConcat(
                state_dim=state_dim,
                model_feature_config=state_feature_config,
                embedding_dim=embedding_dim,
            )
            input_dim = embedding_concat.output_dim

        gaussian_fc_actor = GaussianFullyConnectedActor(
            state_dim=input_dim,
            action_dim=action_dim,
            sizes=self.sizes,
            activations=self.activations,
            use_batch_norm=self.use_batch_norm,
            use_layer_norm=self.use_layer_norm,
            use_l2_normalization=self.use_l2_normalization,
        )

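        # No embeddings configured: the actor consumes the dense state directly.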
        if not embedding_dim:
            return gaussian_fc_actor

        assert embedding_concat is not None
        return models.Sequential(  # type: ignore
            embedding_concat,
            rlt.TensorFeatureData(),
            gaussian_fc_actor,
        )
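For orientation, a minimal sketch of how this builder might be called; `builder`, `state_feature_config`, and the two `*_normalization_parameters` dicts are placeholders assumed to exist, and only the `NormalizationData.dense_normalization_parameters` field is taken from the code above.

# Hypothetical call site: `builder` is an instance of the class defining build_actor,
# and the feature config plus normalization-parameter dicts are assumed to exist.
actor = builder.build_actor(
    state_feature_config=state_feature_config,
    state_normalization_data=NormalizationData(
        dense_normalization_parameters=state_normalization_parameters
    ),
    action_normalization_data=NormalizationData(
        dense_normalization_parameters=action_normalization_parameters
    ),
)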
Example #3
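Builds a serving-time predictor from a trainer's Q-network; for a QRDQNTrainer, the per-action quantile outputs are first averaged into scalar Q-values.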
    def get_predictor(self, trainer, environment):
        state_preprocessor = Preprocessor(environment.normalization, False)
        q_network = trainer.q_network
        if isinstance(trainer, QRDQNTrainer):

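            # QR-DQN outputs per-action quantiles; averaging over the quantile dimension yields one scalar Q-value per action for serving.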
            class _Mean(torch.nn.Module):
                def forward(self, input):
                    assert input.ndim == 3
                    return input.mean(dim=2)

            q_network = models.Sequential(q_network, _Mean())

        dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
            q_network.cpu_model().eval(), state_preprocessor
        )
        serving_module = DiscreteDqnPredictorWrapper(
            dqn_with_preprocessor=dqn_with_preprocessor,
            action_names=environment.ACTIONS,
        )
        predictor = DiscreteDqnTorchPredictor(serving_module)
        return predictor
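To make the _Mean step concrete, a small self-contained sketch of the shape collapse it performs; the batch, action, and quantile counts below are arbitrary assumptions.

import torch

# A quantile-regression DQN head emits one value per (batch, action, quantile).
quantile_q_values = torch.randn(4, 2, 51)   # assumed sizes: batch=4, actions=2, quantiles=51
q_values = quantile_q_values.mean(dim=2)    # shape (4, 2): one scalar Q-value per action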
Example #4
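A Q-network builder that embeds id-list features with EmbeddingBagConcat and feeds the concatenated representation into a fully connected DQN.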
    def build_q_network(
        self,
        state_feature_config: rlt.ModelFeatureConfig,
        state_normalization_data: NormalizationData,
        output_dim: int,
    ) -> models.ModelBase:
        state_dim = self._get_input_dim(state_normalization_data)
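        # Embed id-list features and concatenate them with the dense state before the DQN trunk.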
        embedding_concat = models.EmbeddingBagConcat(
            state_dim=state_dim,
            model_feature_config=state_feature_config,
            embedding_dim=self.embedding_dim,
        )
        return models.Sequential(  # type: ignore
            embedding_concat,
            rlt.TensorFeatureData(),
            models.FullyConnectedDQN(
                embedding_concat.output_dim,
                action_dim=output_dim,
                sizes=self.sizes,
                activations=self.activations,
                dropout_ratio=self.dropout_ratio,
            ),
        )
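Note that, unlike the actor builder in Example #2, this builder has no non-embedding fallback: it always constructs the EmbeddingBagConcat stage with self.embedding_dim.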