def test_discrete_wrapper(self):
    """The predictor wrapper must echo the action names and return the same
    q-values the bare DQN produces on the preprocessed prototype input."""
    # Four continuous features with ids 1..4.
    norm_params = {feature_id: _cont_norm() for feature_id in range(1, 5)}
    preprocessor = Preprocessor(norm_params, False)
    dqn = FullyConnectedDQN(
        state_dim=len(norm_params),
        action_dim=2,
        sizes=[16],
        activations=["relu"],
    )
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(dqn, preprocessor)
    action_names = ["L", "R"]
    wrapper = DiscreteDqnPredictorWrapper(dqn_with_preprocessor, action_names)

    prototype = dqn_with_preprocessor.input_prototype()
    returned_names, q_values = wrapper(*prototype)
    self.assertEqual(action_names, returned_names)
    self.assertEqual(q_values.shape, (1, 2))

    # Recompute the q-values by hand: preprocess, then run the raw DQN.
    preprocessed = preprocessor(*prototype[0])
    expected = dqn(rlt.PreprocessedState.from_tensor(preprocessed)).q_values
    self.assertTrue((expected == q_values).all())
def test_discrete_wrapper(self):
    """With a float-only ModelFeatureConfig, the wrapper's q-values must match
    the bare DQN applied to the preprocessed prototype features."""
    feature_ids = range(1, 5)
    norm_params = {fid: _cont_norm() for fid in feature_ids}
    preprocessor = Preprocessor(norm_params, False)
    dqn = models.FullyConnectedDQN(
        state_dim=len(norm_params),
        action_dim=2,
        sizes=[16],
        activations=["relu"],
    )
    float_infos = [
        rlt.FloatFeatureInfo(feature_id=i, name=f"feat_{i}") for i in feature_ids
    ]
    feature_config = rlt.ModelFeatureConfig(float_feature_infos=float_infos)
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
        dqn, preprocessor, feature_config
    )
    action_names = ["L", "R"]
    wrapper = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor, action_names, feature_config
    )

    prototype = dqn_with_preprocessor.input_prototype()[0]
    returned_names, q_values = wrapper(prototype)
    self.assertEqual(action_names, returned_names)
    self.assertEqual(q_values.shape, (1, 2))

    # Reference value: preprocess (values, presence) and feed the raw DQN.
    values_and_presence = prototype.float_features_with_presence
    expected = dqn(rlt.FeatureData(preprocessor(*values_and_presence)))
    self.assertTrue((expected == q_values).all())
def test_discrete_wrapper_with_id_list(self):
    """DQN with an id-list embedding: the wrapper's q-values must equal the raw
    model fed preprocessed floats plus id-list features keyed by name."""
    feature_ids = range(1, 5)
    norm_params = {i: _cont_norm() for i in feature_ids}
    preprocessor = Preprocessor(norm_params, False)
    action_dim = 2

    float_infos = [
        rlt.FloatFeatureInfo(name=str(i), feature_id=i) for i in feature_ids
    ]
    id_list_configs = [
        rlt.IdListFeatureConfig(name="A", feature_id=10, id_mapping_name="A_mapping")
    ]
    id_mappings = {"A_mapping": rlt.IdMapping(ids=[0, 1, 2])}
    feature_config = rlt.ModelFeatureConfig(
        float_feature_infos=float_infos,
        id_list_feature_configs=id_list_configs,
        id_mapping_config=id_mappings,
    )

    embedding_concat = models.EmbeddingBagConcat(
        state_dim=len(norm_params),
        model_feature_config=feature_config,
        embedding_dim=8,
    )
    dqn = models.Sequential(
        embedding_concat,
        rlt.TensorFeatureData(),
        models.FullyConnectedDQN(
            embedding_concat.output_dim,
            action_dim=action_dim,
            sizes=[16],
            activations=["relu"],
        ),
    )
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
        dqn, preprocessor, feature_config
    )
    action_names = ["L", "R"]
    wrapper = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor, action_names, feature_config
    )

    prototype = dqn_with_preprocessor.input_prototype()[0]
    returned_names, q_values = wrapper(prototype)
    self.assertEqual(action_names, returned_names)
    self.assertEqual(q_values.shape, (1, 2))

    # The prototype keys id-list features by feature_id; the raw model is fed
    # them keyed by the configured feature name instead.
    id_to_name = {
        cfg.feature_id: cfg.name for cfg in feature_config.id_list_feature_configs
    }
    named_id_lists = {
        id_to_name[fid]: value for fid, value in prototype.id_list_features.items()
    }
    values_and_presence = prototype.float_features_with_presence
    expected = dqn(
        rlt.FeatureData(
            float_features=preprocessor(*values_and_presence),
            id_list_features=named_id_lists,
        )
    )
    self.assertTrue((expected == q_values).all())
def test_predictor_torch_export(self):
    """Verify that q-values before model export equal q-values after model
    export. Meant to catch issues with export logic."""
    environment = Gridworld()
    samples = Samples(
        mdp_ids=["0"],
        sequence_numbers=[0],
        sequence_number_ordinals=[1],
        states=[{0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0, 5: 1.0, 15: 1.0, 24: 1.0}],
        actions=["D"],
        action_probabilities=[0.5],
        rewards=[0],
        possible_actions=[["R", "D"]],
        next_states=[{5: 1.0}],
        next_actions=["U"],
        terminals=[False],
        possible_next_actions=[["R", "U", "D"]],
    )
    tdps = environment.preprocess_samples(samples, 1)
    # unittest assertion instead of a bare `assert`, which `python -O` strips.
    self.assertEqual(len(tdps), 1, "Invalid number of data pages")

    trainer = self.get_trainer(environment, {}, False, False, False)
    # Named `model_input` to avoid shadowing the builtin `input`.
    model_input = rlt.FeatureData(tdps[0].states)
    pre_export_q_values = trainer.q_network(model_input).detach().numpy()

    preprocessor = Preprocessor(environment.normalization, False)
    cpu_q_network = trainer.q_network.cpu_model()
    cpu_q_network.eval()
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(cpu_q_network, preprocessor)
    serving_module = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor, action_names=environment.ACTIONS
    )

    with tempfile.TemporaryDirectory() as tmpdirname:
        buf = export_module_to_buffer(serving_module)
        tmp_path = os.path.join(tmpdirname, "model")
        # The `with` block closes the file; no explicit close() needed.
        with open(tmp_path, "wb") as f:
            f.write(buf.getvalue())
        predictor = DiscreteDqnTorchPredictor(torch.jit.load(tmp_path))

    post_export_q_values = predictor.predict([samples.states[0]])

    for i, action in enumerate(environment.ACTIONS):
        self.assertAlmostEqual(
            float(pre_export_q_values[0][i]),
            float(post_export_q_values[0][action]),
            places=4,
        )
def build_serving_module(
    self,
    q_network: ModelBase,
    state_normalization_data: NormalizationData,
    action_names: List[str],
    state_feature_config: rlt.ModelFeatureConfig,
) -> torch.nn.Module:
    """
    Returns a TorchScript predictor module

    A CPU copy of ``q_network`` (switched to eval mode) is placed behind a
    Preprocessor built from the dense normalization parameters, then wrapped
    in a DiscreteDqnPredictorWrapper with ``action_names`` and
    ``state_feature_config``.
    """
    dense_params = state_normalization_data.dense_normalization_parameters
    preprocessor = Preprocessor(dense_params, False)
    cpu_network = q_network.cpu_model().eval()
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(cpu_network, preprocessor)
    return DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor, action_names, state_feature_config
    )
def build_serving_module(
    self,
    q_network: ModelBase,
    state_normalization_parameters: Dict[int, NormalizationParameters],
    action_names: List[str],
    state_feature_config: rlt.ModelFeatureConfig,
) -> torch.nn.Module:
    """
    Returns a TorchScript predictor module

    The CPU/eval copy of ``q_network`` is chained with ``_Mean`` (defined
    elsewhere) inside a ``torch.nn.Sequential``. That chain is placed behind a
    Preprocessor and wrapped with ``action_names`` and
    ``state_feature_config``.
    """
    preprocessor = Preprocessor(state_normalization_parameters, False)
    # NOTE(review): _Mean presumably reduces a per-quantile output to one
    # value per action — confirm against its definition.
    averaged_network = torch.nn.Sequential(  # type: ignore
        q_network.cpu_model().eval(),
        _Mean(),
    )
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
        averaged_network, preprocessor
    )
    return DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor, action_names, state_feature_config
    )
def save_models(self, path: str):
    """Save the trainer and a TorchScript serving module under ``path``.

    ``path`` is passed through ``os.path.expanduser``. Both files are written
    into that same expanded directory and share a timestamp suffix
    (``round(time.time())``):

    * ``trainer_<ts>.pt`` via ``save_model_to_file``
    * ``model_<ts>.torchscript`` via ``self.save_torchscript_model``
    """
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
        self.trainer.q_network.cpu_model().eval(),
        Preprocessor(self.state_normalization, False),
    )
    serving_module = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor=dqn_with_preprocessor,
        action_names=self.model_params.actions,
    )
    export_time = round(time.time())
    output_path = os.path.expanduser(path)
    pytorch_output_path = os.path.join(output_path, f"trainer_{export_time}.pt")
    # Join onto the expanded directory (not the raw `path`) so a "~/..."
    # argument puts both artifacts in the same place.
    torchscript_output_path = os.path.join(
        output_path, f"model_{export_time}.torchscript"
    )
    logger.info("Saving PyTorch trainer to %s", pytorch_output_path)
    save_model_to_file(self.trainer, pytorch_output_path)
    self.save_torchscript_model(serving_module, torchscript_output_path)
def get_predictor(self, trainer, environment):
    """Build a DiscreteDqnTorchPredictor around ``trainer.q_network``.

    For a QRDQNTrainer, the network output is averaged over dim 2 before
    serving. A 3-D output is asserted.
    """
    preprocessor = Preprocessor(environment.normalization, False)
    network = trainer.q_network
    if isinstance(trainer, QRDQNTrainer):

        class _Mean(torch.nn.Module):
            # Average a 3-D tensor over its last dimension.
            def forward(self, x):
                assert x.ndim == 3
                return x.mean(dim=2)

        network = models.Sequential(network, _Mean())

    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
        network.cpu_model().eval(), preprocessor
    )
    serving_module = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor=dqn_with_preprocessor,
        action_names=environment.ACTIONS,
    )
    return DiscreteDqnTorchPredictor(serving_module)