示例#1
0
    def test_discrete_wrapper(self):
        """Check the discrete DQN predictor wrapper against the raw network.

        The wrapper must echo back the configured action names and return
        q-values of shape (1, 2) that equal what the DQN produces when fed
        the preprocessed input prototype directly.
        """
        # One normalization entry per feature id 1..4.
        norm_params = {feature_id: _cont_norm() for feature_id in range(1, 5)}
        preprocessor = Preprocessor(norm_params, False)

        action_names = ["L", "R"]
        network = FullyConnectedDQN(
            state_dim=len(norm_params),
            action_dim=len(action_names),
            sizes=[16],
            activations=["relu"],
        )
        network_with_preprocessor = DiscreteDqnWithPreprocessor(
            network, preprocessor
        )
        wrapper = DiscreteDqnPredictorWrapper(
            network_with_preprocessor, action_names
        )

        prototype = network_with_preprocessor.input_prototype()
        returned_names, returned_q = wrapper(*prototype)
        self.assertEqual(action_names, returned_names)
        self.assertEqual(returned_q.shape, (1, 2))

        # Reference path: preprocess the state part of the prototype by hand
        # and run the bare network on it.
        preprocessed_state = preprocessor(*prototype[0])
        reference_q = network(
            rlt.PreprocessedState.from_tensor(preprocessed_state)
        ).q_values
        self.assertTrue((reference_q == returned_q).all())
 def get_predictor(self, trainer, environment):
     """Build a DiscreteDqnTorchPredictor from the trainer's q-network.

     A Preprocessor is created from ``environment.normalization`` and
     paired with the q-network's CPU model (after ``.eval()``); the result
     is wrapped with ``environment.ACTIONS`` as the action names.
     """
     preprocessor = Preprocessor(environment.normalization, False)
     cpu_network = trainer.q_network.cpu_model().eval()
     wrapped_network = DiscreteDqnWithPreprocessor(cpu_network, preprocessor)
     return DiscreteDqnTorchPredictor(
         DiscreteDqnPredictorWrapper(
             dqn_with_preprocessor=wrapped_network,
             action_names=environment.ACTIONS,
         )
     )
示例#3
0
    def test_predictor_torch_export(self):
        """Verify that q-values before model export equal q-values after
        model export. Meant to catch issues with export logic.

        A single Gridworld sample is scored by the trainer's q-network,
        then the network is wrapped for serving, exported to a file,
        reloaded with ``torch.jit.load`` and scored again; the two sets of
        q-values must agree to 4 decimal places for every action.
        """
        environment = Gridworld()
        samples = Samples(
            mdp_ids=["0"],
            sequence_numbers=[0],
            sequence_number_ordinals=[1],
            states=[{0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0, 5: 1.0, 15: 1.0, 24: 1.0}],
            actions=["D"],
            action_probabilities=[0.5],
            rewards=[0],
            possible_actions=[["R", "D"]],
            next_states=[{5: 1.0}],
            next_actions=["U"],
            terminals=[False],
            possible_next_actions=[["R", "U", "D"]],
        )
        tdps = environment.preprocess_samples(samples, 1)
        # Use a unittest assertion rather than a bare ``assert`` so the check
        # is not stripped when Python runs with -O.
        self.assertEqual(len(tdps), 1, "Invalid number of data pages")

        trainer, exporter = self.get_modular_sarsa_trainer_exporter(
            environment, {}, False
        )
        # Named ``state_input`` to avoid shadowing the ``input`` builtin.
        state_input = rlt.PreprocessedState.from_tensor(tdps[0].states)

        pre_export_q_values = (
            trainer.q_network(state_input).q_values.detach().numpy()
        )

        preprocessor = Preprocessor(environment.normalization, False)
        cpu_q_network = trainer.q_network.cpu_model()
        cpu_q_network.eval()
        dqn_with_preprocessor = DiscreteDqnWithPreprocessor(cpu_q_network, preprocessor)
        serving_module = DiscreteDqnPredictorWrapper(
            dqn_with_preprocessor, action_names=environment.ACTIONS
        )

        with tempfile.TemporaryDirectory() as tmpdirname:
            buf = export_module_to_buffer(serving_module)
            tmp_path = os.path.join(tmpdirname, "model")
            # The context manager closes (and flushes) the file; no explicit
            # close() is needed.
            with open(tmp_path, "wb") as f:
                f.write(buf.getvalue())
            # Load only after the write handle is closed, but while the
            # temporary directory still exists.
            predictor = DiscreteDqnTorchPredictor(torch.jit.load(tmp_path))

        post_export_q_values = predictor.predict([samples.states[0]])

        for i, action in enumerate(environment.ACTIONS):
            self.assertAlmostEqual(
                float(pre_export_q_values[0][i]),
                float(post_export_q_values[0][action]),
                places=4,
            )
 def build_serving_module(
     self,
     q_network: ModelBase,
     state_normalization_parameters: Dict[int, NormalizationParameters],
     action_names: List[str],
     state_feature_config: rlt.ModelFeatureConfig,
 ) -> torch.nn.Module:
     """
     Returns a TorchScript predictor module

     The q-network's CPU model (after ``.eval()``) is combined with a
     Preprocessor built from ``state_normalization_parameters`` and then
     wrapped together with ``action_names`` and ``state_feature_config``.
     """
     preprocessor = Preprocessor(state_normalization_parameters, False)
     cpu_network = q_network.cpu_model().eval()
     wrapped_network = DiscreteDqnWithPreprocessor(cpu_network, preprocessor)
     serving_module = DiscreteDqnPredictorWrapper(
         wrapped_network, action_names, state_feature_config
     )
     return serving_module
示例#5
0
    def save_models(self, path: str):
        """Save the trainer and its TorchScript serving module under ``path``.

        Two files are written into the same directory, both stamped with the
        current time rounded to whole seconds:

        * ``trainer_<ts>.pt`` via ``save_model_to_file``
        * ``model_<ts>.torchscript`` via ``self.save_torchscript_model``

        Args:
            path: Output directory; a leading ``~`` is expanded to the
                user's home directory.
        """
        dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
            self.trainer.q_network.cpu_model().eval(),
            Preprocessor(self.state_normalization, False),
        )
        serving_module = DiscreteDqnPredictorWrapper(
            dqn_with_preprocessor=dqn_with_preprocessor,
            action_names=self.model_params.actions,
        )

        export_time = round(time.time())
        output_path = os.path.expanduser(path)
        pytorch_output_path = os.path.join(
            output_path, f"trainer_{export_time}.pt"
        )
        # Join onto the expanded directory (not the raw ``path``) so both
        # artifacts land in the same place when ``path`` starts with "~".
        torchscript_output_path = os.path.join(
            output_path, f"model_{export_time}.torchscript"
        )
        logger.info("Saving PyTorch trainer to %s", pytorch_output_path)
        save_model_to_file(self.trainer, pytorch_output_path)
        self.save_torchscript_model(serving_module, torchscript_output_path)