Esempio n. 1
0
    def test_discrete_wrapper(self):
        """Check that the predictor wrapper agrees with the bare DQN.

        A small fully connected DQN over four continuous features is wrapped
        with a preprocessor and a ``DiscreteDqnPredictorWrapper``. The wrapper
        is fed the model's own input prototype, and its action names and
        q-values are compared against running the preprocessor and the DQN
        by hand on the same prototype.
        """
        norm_params = {feature_id: _cont_norm() for feature_id in range(1, 5)}
        preprocessor = Preprocessor(norm_params, False)
        action_dim = 2
        q_network = FullyConnectedDQN(
            state_dim=len(norm_params),
            action_dim=action_dim,
            sizes=[16],
            activations=["relu"],
        )
        preprocessed_dqn = DiscreteDqnWithPreprocessor(q_network, preprocessor)
        action_names = ["L", "R"]
        wrapper = DiscreteDqnPredictorWrapper(preprocessed_dqn, action_names)

        prototype = preprocessed_dqn.input_prototype()
        returned_names, q_values = wrapper(*prototype)
        self.assertEqual(action_names, returned_names)
        self.assertEqual(q_values.shape, (1, 2))

        # Reference path: preprocess the first prototype entry manually and
        # push it through the unwrapped network.
        preprocessed_state = preprocessor(*prototype[0])
        reference = q_network(
            rlt.PreprocessedState.from_tensor(preprocessed_state)
        ).q_values
        self.assertTrue((reference == q_values).all())
Esempio n. 2
0
    def test_discrete_wrapper(self):
        """Check that the predictor wrapper agrees with the bare DQN.

        Builds a fully connected DQN over four continuous features, described
        by a ``ModelFeatureConfig`` with one float feature per id. The wrapper
        is called on the model's input prototype and its q-values are compared
        with those from preprocessing the prototype's float features manually
        and calling the DQN directly.
        """
        feature_ids = range(1, 5)
        norm_params = {fid: _cont_norm() for fid in feature_ids}
        preprocessor = Preprocessor(norm_params, False)
        action_dim = 2
        q_network = models.FullyConnectedDQN(
            state_dim=len(norm_params),
            action_dim=action_dim,
            sizes=[16],
            activations=["relu"],
        )
        float_infos = [
            rlt.FloatFeatureInfo(feature_id=fid, name=f"feat_{fid}")
            for fid in feature_ids
        ]
        feature_config = rlt.ModelFeatureConfig(float_feature_infos=float_infos)
        preprocessed_dqn = DiscreteDqnWithPreprocessor(
            q_network, preprocessor, feature_config
        )
        action_names = ["L", "R"]
        wrapper = DiscreteDqnPredictorWrapper(
            preprocessed_dqn, action_names, feature_config
        )

        # Only the first element of the prototype is fed to the wrapper.
        prototype = preprocessed_dqn.input_prototype()[0]
        returned_names, q_values = wrapper(prototype)
        self.assertEqual(action_names, returned_names)
        self.assertEqual(q_values.shape, (1, 2))

        # Reference path: preprocess the float features by hand and call the
        # unwrapped network.
        preprocessed = preprocessor(*prototype.float_features_with_presence)
        reference = q_network(rlt.FeatureData(preprocessed))
        self.assertTrue((reference == q_values).all())
Esempio n. 3
0
    def test_discrete_wrapper_with_id_list(self):
        """Check the predictor wrapper on a model with an id-list feature.

        The network is an ``EmbeddingBagConcat`` followed by a fully connected
        DQN. The feature config declares four float features plus one id-list
        feature ("A", feature id 10) backed by the "A_mapping" id mapping.
        The wrapper's q-values on the input prototype must equal those from
        calling the network directly on hand-built ``FeatureData``.
        """
        norm_params = {feature_id: _cont_norm() for feature_id in range(1, 5)}
        preprocessor = Preprocessor(norm_params, False)
        action_dim = 2

        float_infos = [
            rlt.FloatFeatureInfo(name=str(feature_id), feature_id=feature_id)
            for feature_id in range(1, 5)
        ]
        id_list_configs = [
            rlt.IdListFeatureConfig(
                name="A", feature_id=10, id_mapping_name="A_mapping"
            )
        ]
        feature_config = rlt.ModelFeatureConfig(
            float_feature_infos=float_infos,
            id_list_feature_configs=id_list_configs,
            id_mapping_config={"A_mapping": rlt.IdMapping(ids=[0, 1, 2])},
        )

        embedding_concat = models.EmbeddingBagConcat(
            state_dim=len(norm_params),
            model_feature_config=feature_config,
            embedding_dim=8,
        )
        q_head = models.FullyConnectedDQN(
            embedding_concat.output_dim,
            action_dim=action_dim,
            sizes=[16],
            activations=["relu"],
        )
        q_network = models.Sequential(
            embedding_concat, rlt.TensorFeatureData(), q_head
        )

        preprocessed_dqn = DiscreteDqnWithPreprocessor(
            q_network, preprocessor, feature_config
        )
        action_names = ["L", "R"]
        wrapper = DiscreteDqnPredictorWrapper(
            preprocessed_dqn, action_names, feature_config
        )

        # Only the first element of the prototype is fed to the wrapper.
        prototype = preprocessed_dqn.input_prototype()[0]
        returned_names, q_values = wrapper(prototype)
        self.assertEqual(action_names, returned_names)
        self.assertEqual(q_values.shape, (1, 2))

        # The prototype keys id-list features by feature id; re-key them by
        # feature name before handing them to the network directly.
        name_by_feature_id = {
            cfg.feature_id: cfg.name
            for cfg in feature_config.id_list_feature_configs
        }
        id_lists_by_name = {
            name_by_feature_id[feature_id]: value
            for feature_id, value in prototype.id_list_features.items()
        }
        float_features = preprocessor(*prototype.float_features_with_presence)
        reference = q_network(
            rlt.FeatureData(
                float_features=float_features,
                id_list_features=id_lists_by_name,
            )
        )
        self.assertTrue((reference == q_values).all())
Esempio n. 4
0
    def test_predictor_torch_export(self):
        """Verify that q-values before model export equal q-values after
        model export. Meant to catch issues with export logic.

        A single Gridworld sample is preprocessed and scored by the trainer's
        q-network. The network is then wrapped for serving, exported to a
        buffer, written to a temporary file, re-loaded with ``torch.jit.load``
        and scored again through ``DiscreteDqnTorchPredictor``. The two sets
        of q-values must agree to 4 decimal places for every action.
        """
        environment = Gridworld()
        samples = Samples(
            mdp_ids=["0"],
            sequence_numbers=[0],
            sequence_number_ordinals=[1],
            states=[
                {
                    0: 1.0,
                    1: 1.0,
                    2: 1.0,
                    3: 1.0,
                    4: 1.0,
                    5: 1.0,
                    15: 1.0,
                    24: 1.0,
                }
            ],
            actions=["D"],
            action_probabilities=[0.5],
            rewards=[0],
            possible_actions=[["R", "D"]],
            next_states=[{5: 1.0}],
            next_actions=["U"],
            terminals=[False],
            possible_next_actions=[["R", "U", "D"]],
        )
        tdps = environment.preprocess_samples(samples, 1)
        # unittest assertion instead of a bare ``assert`` so the check is not
        # stripped when Python runs with -O.
        self.assertEqual(len(tdps), 1, "Invalid number of data pages")

        trainer = self.get_trainer(environment, {}, False, False, False)
        # Named ``state_input`` to avoid shadowing the ``input`` builtin.
        state_input = rlt.FeatureData(tdps[0].states)

        pre_export_q_values = trainer.q_network(state_input).detach().numpy()

        preprocessor = Preprocessor(environment.normalization, False)
        cpu_q_network = trainer.q_network.cpu_model()
        cpu_q_network.eval()
        dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
            cpu_q_network, preprocessor
        )
        serving_module = DiscreteDqnPredictorWrapper(
            dqn_with_preprocessor, action_names=environment.ACTIONS
        )

        with tempfile.TemporaryDirectory() as tmpdirname:
            buf = export_module_to_buffer(serving_module)
            tmp_path = os.path.join(tmpdirname, "model")
            # The ``with`` block flushes and closes the file on exit, so no
            # explicit close() is needed.
            with open(tmp_path, "wb") as f:
                f.write(buf.getvalue())
            # Load only after the write handle has been released, but while
            # the temporary directory still exists.
            predictor = DiscreteDqnTorchPredictor(torch.jit.load(tmp_path))

        post_export_q_values = predictor.predict([samples.states[0]])

        for i, action in enumerate(environment.ACTIONS):
            self.assertAlmostEqual(
                float(pre_export_q_values[0][i]),
                float(post_export_q_values[0][action]),
                places=4,
            )
Esempio n. 5
0
 def build_serving_module(
     self,
     q_network: ModelBase,
     state_normalization_data: NormalizationData,
     action_names: List[str],
     state_feature_config: rlt.ModelFeatureConfig,
 ) -> torch.nn.Module:
     """
     Returns a TorchScript predictor module

     The state preprocessor is built from the dense normalization parameters
     of ``state_normalization_data``. The q-network is moved to its CPU
     model and put in eval mode before being paired with the preprocessor
     and wrapped in a ``DiscreteDqnPredictorWrapper``.
     """
     dense_params = state_normalization_data.dense_normalization_parameters
     preprocessor = Preprocessor(dense_params, False)
     eval_network = q_network.cpu_model().eval()
     return DiscreteDqnPredictorWrapper(
         DiscreteDqnWithPreprocessor(eval_network, preprocessor),
         action_names,
         state_feature_config,
     )
 def build_serving_module(
     self,
     q_network: ModelBase,
     state_normalization_parameters: Dict[int, NormalizationParameters],
     action_names: List[str],
     state_feature_config: rlt.ModelFeatureConfig,
 ) -> torch.nn.Module:
     """
     Returns a TorchScript predictor module

     The q-network (CPU model, eval mode) is chained with a ``_Mean`` module
     before being paired with the state preprocessor and wrapped in a
     ``DiscreteDqnPredictorWrapper``.
     """
     preprocessor = Preprocessor(state_normalization_parameters, False)
     eval_network = q_network.cpu_model().eval()
     # NOTE(review): ``_Mean`` is defined outside this block; presumably it
     # reduces the network output to one q-value per action -- confirm.
     reduced_network = torch.nn.Sequential(  # type: ignore
         eval_network, _Mean()
     )
     return DiscreteDqnPredictorWrapper(
         DiscreteDqnWithPreprocessor(reduced_network, preprocessor),
         action_names,
         state_feature_config,
     )
Esempio n. 7
0
    def save_models(self, path: str):
        """Save the trainer and its TorchScript serving module under ``path``.

        Two files are written, both named with the export time (current time
        in seconds, rounded):

        * ``trainer_<time>.pt`` -- the trainer, via ``save_model_to_file``.
        * ``model_<time>.torchscript`` -- the serving module, via
          ``self.save_torchscript_model``.

        Args:
            path: Output directory. A leading ``~`` is expanded, and the
                expanded directory is used for both files.
        """
        dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
            self.trainer.q_network.cpu_model().eval(),
            Preprocessor(self.state_normalization, False),
        )
        serving_module = DiscreteDqnPredictorWrapper(
            dqn_with_preprocessor=dqn_with_preprocessor,
            action_names=self.model_params.actions,
        )

        export_time = round(time.time())
        output_path = os.path.expanduser(path)
        pytorch_output_path = os.path.join(
            output_path, f"trainer_{export_time}.pt"
        )
        # Built from the expanded directory (previously the raw ``path``), so
        # both artefacts land in the same place when ``path`` starts with "~".
        torchscript_output_path = os.path.join(
            output_path, f"model_{export_time}.torchscript"
        )
        logger.info("Saving PyTorch trainer to %s", pytorch_output_path)
        save_model_to_file(self.trainer, pytorch_output_path)
        self.save_torchscript_model(serving_module, torchscript_output_path)
Esempio n. 8
0
    def get_predictor(self, trainer, environment):
        """Build a ``DiscreteDqnTorchPredictor`` for ``trainer``'s q-network.

        The state preprocessor comes from ``environment.normalization`` and
        the action names from ``environment.ACTIONS``. For a ``QRDQNTrainer``
        the q-network is followed by a module that averages its rank-3 output
        over the last axis before wrapping.
        """
        preprocessor = Preprocessor(environment.normalization, False)
        network = trainer.q_network
        if isinstance(trainer, QRDQNTrainer):

            # Averages a rank-3 tensor over axis 2, yielding a rank-2 tensor.
            class _Mean(torch.nn.Module):
                def forward(self, input):
                    assert input.ndim == 3
                    return input.mean(dim=2)

            network = models.Sequential(network, _Mean())

        serving_module = DiscreteDqnPredictorWrapper(
            dqn_with_preprocessor=DiscreteDqnWithPreprocessor(
                network.cpu_model().eval(), preprocessor
            ),
            action_names=environment.ACTIONS,
        )
        return DiscreteDqnTorchPredictor(serving_module)