def test_parametric_wrapper(self):
    """Verify ParametricDqnPredictorWrapper matches a hand-run DQN forward pass.

    A small FullyConnectedParametricDQN is built over state feature ids 1-4
    and action feature ids 5-8. It is wrapped with state/action
    preprocessors and a ParametricDqnPredictorWrapper, then fed the
    wrapper's own input prototype. The wrapper must report a single output
    name "Q", a (1, 1) Q-value tensor, and a value equal to running the
    preprocessors and the DQN manually on the same prototype.
    """
    state_norm = {feature_id: _cont_norm() for feature_id in range(1, 5)}
    action_norm = {feature_id: _cont_norm() for feature_id in range(5, 9)}
    state_prep = Preprocessor(state_norm, False)
    action_prep = Preprocessor(action_norm, False)

    model = FullyConnectedParametricDQN(
        state_dim=len(state_norm),
        action_dim=len(action_norm),
        sizes=[16],
        activations=["relu"],
    )
    model_with_prep = ParametricDqnWithPreprocessor(
        model,
        state_preprocessor=state_prep,
        action_preprocessor=action_prep,
    )
    predictor_wrapper = ParametricDqnPredictorWrapper(model_with_prep)

    prototype = model_with_prep.input_prototype()
    names, q_values = predictor_wrapper(*prototype)
    self.assertEqual(names, ["Q"])
    self.assertEqual(q_values.shape, (1, 1))

    # Reference result: apply each preprocessor by hand, then the bare DQN.
    preprocessed_state = state_prep(*prototype[0])
    preprocessed_action = action_prep(*prototype[1])
    manual_input = rlt.PreprocessedStateAction.from_tensors(
        state=preprocessed_state,
        action=preprocessed_action,
    )
    reference_q = model(manual_input).q_value
    self.assertTrue((reference_q == q_values).all())
def get_predictor(self, trainer, environment):
    """Wrap ``self.current_predictor_network`` in a ParametricDqnTorchPredictor.

    State and action preprocessors are built from
    ``environment.normalization`` and ``environment.normalization_action``.
    The network is converted with ``cpu_model().eval()`` before being
    combined with them.

    Args:
        trainer: Not read by this method. NOTE(review): presumably kept for
            interface compatibility with other ``get_predictor``
            implementations; confirm.
        environment: Object exposing ``normalization`` and
            ``normalization_action`` normalization parameters.

    Returns:
        A ParametricDqnTorchPredictor around a ParametricDqnPredictorWrapper.
    """
    preprocessors = (
        Preprocessor(environment.normalization, False),
        Preprocessor(environment.normalization_action, False),
    )
    eval_network = self.current_predictor_network.cpu_model().eval()
    wrapped = ParametricDqnWithPreprocessor(eval_network, *preprocessors)
    serving = ParametricDqnPredictorWrapper(dqn_with_preprocessor=wrapped)
    return ParametricDqnTorchPredictor(serving)
def save_models(self, path: str):
    """Save the trainer and a TorchScript serving module into ``path``.

    Both artifacts are stamped with the same export time
    (``round(time.time())``):

    * ``trainer_<time>.pt`` holds ``self.trainer``, written with
      ``save_model_to_file``.
    * ``model_<time>.torchscript`` holds a ParametricDqnPredictorWrapper
      around ``self.trainer.q_network`` (after ``cpu_model().eval()``) plus
      state/action preprocessors built from ``self.state_normalization`` and
      ``self.action_normalization``. It is written with
      ``self.save_torchscript_model``.

    Args:
        path: Output directory. A leading ``~`` is expanded, and the expanded
            path is used for both files.
    """
    export_time = round(time.time())
    # Expand "~" once and build BOTH artifact paths from the result, so the
    # two files always land in the same directory. Joining the TorchScript
    # path onto the raw ``path`` would put it in a literal "~" folder
    # relative to the CWD.
    output_path = os.path.expanduser(path)
    pytorch_output_path = os.path.join(
        output_path, "trainer_{}.pt".format(export_time)
    )
    torchscript_output_path = os.path.join(
        output_path, "model_{}.torchscript".format(export_time)
    )

    state_preprocessor = Preprocessor(self.state_normalization, False)
    action_preprocessor = Preprocessor(self.action_normalization, False)
    q_network = self.trainer.q_network
    # NOTE(review): cpu_model() presumably yields a CPU copy suitable for
    # export; eval() switches it to inference mode before tracing/scripting.
    dqn_with_preprocessor = ParametricDqnWithPreprocessor(
        q_network.cpu_model().eval(), state_preprocessor, action_preprocessor
    )
    serving_module = ParametricDqnPredictorWrapper(
        dqn_with_preprocessor=dqn_with_preprocessor
    )

    logger.info("Saving PyTorch trainer to %s", pytorch_output_path)
    save_model_to_file(self.trainer, pytorch_output_path)
    self.save_torchscript_model(serving_module, torchscript_output_path)