コード例 #1
0
    def test_actor_wrapper(self):
        """Check that ActorPredictorWrapper agrees with the hand-run pipeline.

        Builds a state preprocessor, a fully connected actor and an action
        postprocessor, wraps them, feeds the wrapper the pipeline's own input
        prototype, and compares the output against the same stages applied
        manually.
        """
        state_norm = {feature_id: _cont_norm() for feature_id in range(1, 5)}
        action_norm = {
            feature_id: _cont_action_norm() for feature_id in range(101, 105)
        }
        num_actions = len(action_norm)

        preprocess_state = Preprocessor(state_norm, False)
        postprocess_action = Postprocessor(action_norm, False)

        # Test with FullyConnectedActor to make behavior deterministic
        net = models.FullyConnectedActor(
            state_dim=len(state_norm),
            action_dim=num_actions,
            sizes=[16],
            activations=["relu"],
        )
        pipeline = ActorWithPreprocessor(net, preprocess_state, postprocess_action)
        predictor = ActorPredictorWrapper(pipeline)

        prototype = pipeline.input_prototype()
        predicted = predictor(*prototype)
        # One row per prototype input, one column per action feature.
        self.assertEqual(predicted.shape, (1, num_actions))

        # Replay the pipeline by hand, one stage at a time.
        processed_state = preprocess_state(*prototype[0])
        raw_action = net(rlt.FeatureData(processed_state)).action
        expected = postprocess_action(raw_action)
        self.assertTrue((expected == predicted).all())
コード例 #2
0
    def build_serving_module(
        self,
        actor: ModelBase,
        state_normalization_data: NormalizationData,
        action_normalization_data: NormalizationData,
    ) -> torch.nn.Module:
        """
        Build the serving module for an actor.

        The actor's CPU copy (in eval mode) is combined with a state
        preprocessor and an action postprocessor, then wrapped in an
        ActorPredictorWrapper together with the sorted action features.
        Per the original author, the result is a TorchScript predictor module.

        Both normalization data objects must carry dense normalization
        parameters; an AssertionError is raised otherwise.
        """
        state_params = state_normalization_data.dense_normalization_parameters
        action_params = action_normalization_data.dense_normalization_parameters
        assert state_params is not None
        assert action_params is not None

        preprocess_state = Preprocessor(state_params, use_gpu=False)
        postprocess_action = Postprocessor(action_params, use_gpu=False)

        cpu_actor = actor.cpu_model().eval()
        pipeline = ActorWithPreprocessor(
            cpu_actor, preprocess_state, postprocess_action
        )

        # A Preprocessor over the action parameters is built only to read its
        # feature ordering.
        sorted_action_features = Preprocessor(
            action_params, use_gpu=False
        ).sorted_features
        return ActorPredictorWrapper(pipeline, sorted_action_features)
コード例 #3
0
    def build_serving_module(
        self,
        actor: ModelBase,
        state_feature_config: rlt.ModelFeatureConfig,
        state_normalization_data: NormalizationData,
        action_normalization_data: NormalizationData,
        serve_mean_policy: bool = False,
    ) -> torch.nn.Module:
        """
        Build the serving module for an actor.

        The actor's CPU copy (in eval mode) is combined with a state
        preprocessor, the state feature config and an action postprocessor,
        then wrapped in an ActorPredictorWrapper together with the state
        feature config and the sorted action features. ``serve_mean_policy``
        is forwarded unchanged to ActorWithPreprocessor. Per the original
        author, the result is a TorchScript predictor module.
        """
        state_params = state_normalization_data.dense_normalization_parameters
        action_params = action_normalization_data.dense_normalization_parameters

        preprocess_state = Preprocessor(state_params, use_gpu=False)
        postprocess_action = Postprocessor(action_params, use_gpu=False)

        cpu_actor = actor.cpu_model().eval()
        pipeline = ActorWithPreprocessor(
            cpu_actor,
            preprocess_state,
            state_feature_config,
            postprocess_action,
            serve_mean_policy=serve_mean_policy,
        )

        # A Preprocessor over the action parameters is built only to read its
        # feature ordering.
        sorted_action_features = Preprocessor(
            action_params, use_gpu=False
        ).sorted_features
        return ActorPredictorWrapper(
            pipeline, state_feature_config, sorted_action_features
        )
コード例 #4
0
ファイル: test_gridworld_td3.py プロジェクト: zwcdp/ReAgent
 def get_actor_predictor(self, trainer, environment):
     """Wrap the trainer's actor network in an ActorTorchPredictor.

     The actor's CPU copy (in eval mode) is combined with a state
     preprocessor and an action postprocessor built from the environment's
     normalization data, wrapped for serving, and paired with the first
     element returned by ``sort_features_by_normalization`` for the
     continuous-action normalization.
     """
     preprocess_state = Preprocessor(environment.normalization, False)
     postprocess_action = Postprocessor(
         environment.normalization_continuous_action, False
     )

     cpu_actor = trainer.actor_network.cpu_model().eval()
     pipeline = ActorWithPreprocessor(
         cpu_actor, preprocess_state, postprocess_action
     )
     wrapped = ActorPredictorWrapper(pipeline)

     # Only the first element of the sort result is handed to the predictor.
     sort_result = sort_features_by_normalization(
         environment.normalization_continuous_action
     )
     return ActorTorchPredictor(wrapped, sort_result[0])
コード例 #5
0
    def build_serving_module(
        self,
        actor: ModelBase,
        state_normalization_data: NormalizationData,
        action_feature_ids: List[int],
    ) -> torch.nn.Module:
        """
        Build the serving module for an actor.

        The actor's CPU copy (in eval mode) is paired with a state
        preprocessor only (no action postprocessor is passed), then wrapped in
        an ActorPredictorWrapper together with ``action_feature_ids``. Per the
        original author, the result is a TorchScript predictor module.
        """
        state_params = state_normalization_data.dense_normalization_parameters
        preprocess_state = Preprocessor(state_params, use_gpu=False)

        cpu_actor = actor.cpu_model().eval()
        pipeline = ActorWithPreprocessor(cpu_actor, preprocess_state)
        return ActorPredictorWrapper(pipeline, action_feature_ids)