예제 #1
0
    def test_parametric_wrapper(self):
        """Check that the parametric predictor wrapper matches the bare DQN.

        Builds a FullyConnectedParametricDQN over state feature ids 1-4 and
        action feature ids 5-8 (each normalized with ``_cont_norm()``), wraps
        it with its preprocessors and ParametricDqnPredictorWrapper, then
        asserts that the wrapper, run on the module's input prototype, returns
        the action names ``["Q"]`` and a (1, 1) tensor equal to the bare
        network applied to the manually preprocessed prototype.
        """
        state_params = {fid: _cont_norm() for fid in range(1, 5)}
        action_params = {fid: _cont_norm() for fid in range(5, 9)}
        state_pp = Preprocessor(state_params, False)
        action_pp = Preprocessor(action_params, False)
        model = FullyConnectedParametricDQN(
            state_dim=len(state_params),
            action_dim=len(action_params),
            sizes=[16],
            activations=["relu"],
        )
        wrapped_model = ParametricDqnWithPreprocessor(
            model,
            state_preprocessor=state_pp,
            action_preprocessor=action_pp,
        )
        predictor = ParametricDqnPredictorWrapper(wrapped_model)

        # Run the serving wrapper on the module's own example input.
        prototype = wrapped_model.input_prototype()
        action_names, q_values = predictor(*prototype)
        self.assertEqual(action_names, ["Q"])
        self.assertEqual(q_values.shape, (1, 1))

        # Recompute by hand: preprocess the first two prototype entries
        # (state, then action) and feed them straight to the bare network.
        preprocessed = rlt.PreprocessedStateAction.from_tensors(
            state=state_pp(*prototype[0]),
            action=action_pp(*prototype[1]),
        )
        reference_q = model(preprocessed).q_value
        self.assertTrue((reference_q == q_values).all())
예제 #2
0
 def get_predictor(self, trainer, environment):
     """Build a ParametricDqnTorchPredictor around the current predictor net.

     Args:
         trainer: Unused by this implementation; the network served is
             ``self.current_predictor_network``.
         environment: Supplies ``normalization`` (state) and
             ``normalization_action`` (action) parameters for the two
             preprocessors.

     Returns:
         A ``ParametricDqnTorchPredictor`` wrapping a
         ``ParametricDqnPredictorWrapper`` serving module.
     """
     state_pp = Preprocessor(environment.normalization, False)
     action_pp = Preprocessor(environment.normalization_action, False)
     # Serve the CPU copy of the network in eval mode.
     serving_module = ParametricDqnPredictorWrapper(
         dqn_with_preprocessor=ParametricDqnWithPreprocessor(
             self.current_predictor_network.cpu_model().eval(),
             state_pp,
             action_pp,
         )
     )
     return ParametricDqnTorchPredictor(serving_module)
예제 #3
0
 def build_serving_module(
     self,
     q_network: ModelBase,
     state_normalization_parameters: Dict[int, NormalizationParameters],
     action_normalization_parameters: Dict[int, NormalizationParameters],
 ) -> torch.nn.Module:
     """Return a TorchScript predictor module for a parametric DQN.

     Args:
         q_network: Network to serve; ``q_network.cpu_model().eval()`` is
             what gets wrapped.
         state_normalization_parameters: Per-feature-id normalization for the
             state preprocessor.
         action_normalization_parameters: Per-feature-id normalization for the
             action preprocessor.

     Returns:
         A ``ParametricDqnPredictorWrapper`` around a
         ``ParametricDqnWithPreprocessor`` built from the inputs above.
     """
     # Build both preprocessors first, in (state, action) order.
     preprocessors = (
         Preprocessor(state_normalization_parameters, False),
         Preprocessor(action_normalization_parameters, False),
     )
     model = ParametricDqnWithPreprocessor(
         q_network.cpu_model().eval(), *preprocessors
     )
     return ParametricDqnPredictorWrapper(dqn_with_preprocessor=model)
    def save_models(self, path: str):
        """Save the trainer checkpoint and a TorchScript serving module.

        Both artifacts are written into the directory ``path`` (with a leading
        ``~`` expanded) and share a suffix equal to the current Unix time
        rounded to the nearest second:

        * ``trainer_<ts>.pt`` -- ``self.trainer``, via ``save_model_to_file``.
        * ``model_<ts>.torchscript`` -- a serving module wrapping
          ``self.trainer.q_network.cpu_model().eval()`` with preprocessors
          built from ``self.state_normalization`` and
          ``self.action_normalization``, via ``self.save_torchscript_model``.

        Args:
            path: Output directory; ``~`` / ``~user`` prefixes are expanded.
        """
        export_time = round(time.time())
        # Expand "~" once and join BOTH artifact paths onto the expanded
        # directory, so the trainer and the TorchScript model always land in
        # the same place.
        output_dir = os.path.expanduser(path)
        pytorch_output_path = os.path.join(
            output_dir, "trainer_{}.pt".format(export_time)
        )
        torchscript_output_path = os.path.join(
            output_dir, "model_{}.torchscript".format(export_time)
        )

        state_preprocessor = Preprocessor(self.state_normalization, False)
        action_preprocessor = Preprocessor(self.action_normalization, False)
        q_network = self.trainer.q_network
        dqn_with_preprocessor = ParametricDqnWithPreprocessor(
            q_network.cpu_model().eval(), state_preprocessor, action_preprocessor
        )
        serving_module = ParametricDqnPredictorWrapper(
            dqn_with_preprocessor=dqn_with_preprocessor
        )

        logger.info("Saving PyTorch trainer to %s", pytorch_output_path)
        save_model_to_file(self.trainer, pytorch_output_path)
        self.save_torchscript_model(serving_module, torchscript_output_path)