def test_discrete_wrapper(self):
    """Check DiscreteDqnPredictorWrapper against a manual forward pass.

    The wrapper is called on the input prototype of the wrapped
    model; it must return the configured action names and Q-values
    equal to those obtained by running the preprocessor and the DQN
    by hand on the same prototype.
    """
    norm_params = {feature_id: _cont_norm() for feature_id in range(1, 5)}
    preprocessor = Preprocessor(norm_params, False)
    num_actions = 2
    network = FullyConnectedDQN(
        state_dim=len(norm_params),
        action_dim=num_actions,
        sizes=[16],
        activations=["relu"],
    )
    model = DiscreteDqnWithPreprocessor(network, preprocessor)
    action_names = ["L", "R"]
    wrapper = DiscreteDqnPredictorWrapper(model, action_names)

    # Run the serving wrapper on the model's own example input.
    prototype = model.input_prototype()
    returned_names, q_values = wrapper(*prototype)

    self.assertEqual(action_names, returned_names)
    self.assertEqual(q_values.shape, (1, 2))

    # Recompute the Q-values without the wrapper: preprocess the first
    # prototype entry, wrap it as a PreprocessedState and feed the DQN.
    preprocessed = preprocessor(*prototype[0])
    network_input = rlt.PreprocessedState.from_tensor(preprocessed)
    expected_q_values = network(network_input).q_values
    self.assertTrue((expected_q_values == q_values).all())
def get_predictor(self, trainer, environment):
    """Wrap the trainer's Q-network in a DiscreteDqnTorchPredictor.

    The network is converted with ``cpu_model().eval()`` and paired with
    a Preprocessor built from ``environment.normalization``; the action
    names come from ``environment.ACTIONS``.
    """
    preprocessor = Preprocessor(environment.normalization, False)
    cpu_network = trainer.q_network.cpu_model().eval()
    wrapped_model = DiscreteDqnWithPreprocessor(cpu_network, preprocessor)
    return DiscreteDqnTorchPredictor(
        DiscreteDqnPredictorWrapper(
            dqn_with_preprocessor=wrapped_model,
            action_names=environment.ACTIONS,
        )
    )
def test_predictor_torch_export(self):
    """Verify that q-values before model export equal q-values after
    model export. Meant to catch issues with export logic.

    A single Gridworld sample is pushed through the trainer's Q-network
    directly, then through a serving module that has been exported to a
    file and loaded back with ``torch.jit.load``; per-action Q-values
    must agree to 4 decimal places.
    """
    environment = Gridworld()
    samples = Samples(
        mdp_ids=["0"],
        sequence_numbers=[0],
        sequence_number_ordinals=[1],
        states=[{0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0, 5: 1.0, 15: 1.0, 24: 1.0}],
        actions=["D"],
        action_probabilities=[0.5],
        rewards=[0],
        possible_actions=[["R", "D"]],
        next_states=[{5: 1.0}],
        next_actions=["U"],
        terminals=[False],
        possible_next_actions=[["R", "U", "D"]],
    )
    tdps = environment.preprocess_samples(samples, 1)
    # unittest assertion rather than a bare ``assert`` so the check is
    # not stripped when Python runs with -O.
    self.assertEqual(len(tdps), 1, "Invalid number of data pages")

    # The exporter half of the returned pair is not needed here.
    trainer, _ = self.get_modular_sarsa_trainer_exporter(environment, {}, False)

    # Q-values straight from the in-memory network.
    state_input = rlt.PreprocessedState.from_tensor(tdps[0].states)
    pre_export_q_values = trainer.q_network(state_input).q_values.detach().numpy()

    # Build the serving module that will be exported.
    preprocessor = Preprocessor(environment.normalization, False)
    cpu_q_network = trainer.q_network.cpu_model()
    cpu_q_network.eval()
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(cpu_q_network, preprocessor)
    serving_module = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor, action_names=environment.ACTIONS
    )

    with tempfile.TemporaryDirectory() as tmpdirname:
        buf = export_module_to_buffer(serving_module)
        tmp_path = os.path.join(tmpdirname, "model")
        with open(tmp_path, "wb") as f:
            f.write(buf.getvalue())
        # The file is flushed and closed once the ``with`` above exits,
        # so it is safe to read it back; it must still happen before the
        # temporary directory is removed.
        predictor = DiscreteDqnTorchPredictor(torch.jit.load(tmp_path))

    post_export_q_values = predictor.predict([samples.states[0]])

    for i, action in enumerate(environment.ACTIONS):
        self.assertAlmostEqual(
            float(pre_export_q_values[0][i]),
            float(post_export_q_values[0][action]),
            places=4,
        )
def build_serving_module(
    self,
    q_network: ModelBase,
    state_normalization_parameters: Dict[int, NormalizationParameters],
    action_names: List[str],
    state_feature_config: rlt.ModelFeatureConfig,
) -> torch.nn.Module:
    """
    Returns a TorchScript predictor module

    The Q-network is converted with ``cpu_model().eval()``, combined with
    a Preprocessor built from ``state_normalization_parameters``, and
    wrapped in a DiscreteDqnPredictorWrapper together with
    ``action_names`` and ``state_feature_config``.
    """
    preprocessor = Preprocessor(state_normalization_parameters, False)
    cpu_network = q_network.cpu_model().eval()
    wrapped_model = DiscreteDqnWithPreprocessor(cpu_network, preprocessor)
    serving_module = DiscreteDqnPredictorWrapper(
        wrapped_model, action_names, state_feature_config
    )
    return serving_module
def save_models(self, path: str):
    """Save the trainer and a TorchScript serving module under ``path``.

    Two files are written into the same directory, both stamped with the
    current time rounded to whole seconds:

    * ``trainer_<time>.pt`` via ``save_model_to_file``
    * ``model_<time>.torchscript`` via ``self.save_torchscript_model``

    Args:
        path: Output directory. A leading ``~`` is expanded with
            ``os.path.expanduser`` and the expanded directory is used
            for both files.
    """
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
        self.trainer.q_network.cpu_model().eval(),
        Preprocessor(self.state_normalization, False),
    )
    serving_module = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor=dqn_with_preprocessor,
        action_names=self.model_params.actions,
    )

    export_time = round(time.time())
    output_path = os.path.expanduser(path)
    pytorch_output_path = os.path.join(output_path, f"trainer_{export_time}.pt")
    # Join onto the expanded directory, not the raw ``path``; otherwise a
    # "~"-prefixed path sends the TorchScript file somewhere other than
    # the trainer file.
    torchscript_output_path = os.path.join(
        output_path, f"model_{export_time}.torchscript"
    )

    logger.info("Saving PyTorch trainer to %s", pytorch_output_path)
    save_model_to_file(self.trainer, pytorch_output_path)
    self.save_torchscript_model(serving_module, torchscript_output_path)