Ejemplo n.º 1
0
 def get_predictor(self, trainer, environment):
     """Wrap the trainer's Q-network in a ParametricDqnTorchPredictor.

     Args:
         trainer: Object exposing ``q_network``; the result of its
             ``cpu_model().eval()`` call is the model that gets wrapped.
         environment: Object exposing ``normalization`` and
             ``normalization_action``, used to build the state and action
             preprocessors respectively.

     Returns:
         A ``ParametricDqnTorchPredictor`` built around a
         ``ParametricDqnPredictorWrapper`` serving module.
     """
     # Both preprocessors receive False as their second positional argument
     # (presumably a use-GPU flag -- confirm against Preprocessor).
     state_prep = Preprocessor(environment.normalization, False)
     action_prep = Preprocessor(environment.normalization_action, False)

     # The model handed to the wrapper is whatever cpu_model().eval() returns.
     eval_model = trainer.q_network.cpu_model().eval()
     wrapped_model = ParametricDqnWithPreprocessor(
         eval_model,
         state_prep,
         action_prep,
     )

     return ParametricDqnTorchPredictor(
         ParametricDqnPredictorWrapper(dqn_with_preprocessor=wrapped_model))
Ejemplo n.º 2
0
    def _test_parametric_dqn_workflow(self,
                                      use_gpu=False,
                                      use_all_avail_gpus=False):
        """Smoke-test the Parametric DQN workflow end to end.

        Runs ``parametric_dqn_workflow.main`` for a single epoch on the
        parametric-action cartpole test data, then loads the exported
        ``model_*.torchscript`` file and checks that a prediction can be
        made. Only the absence of crashes and the shape of the output are
        checked; algorithm correctness is not tested here.

        Args:
            use_gpu: Passed to the workflow as ``params["use_gpu"]``.
            use_all_avail_gpus: Passed to the workflow as
                ``params["use_all_avail_gpus"]``.

        Raises:
            AssertionError: If the workflow does not leave exactly one
                ``model_*.torchscript`` file in the output directory, or if
                the prediction for the single test row does not contain
                exactly one key.
        """
        # All fixtures for this test live in the same directory.
        data_dir = os.path.join(curr_dir, "test_data", "parametric_action")

        with tempfile.TemporaryDirectory() as tmpdirname:
            # The lock file backs the "file://" init_method given to the
            # workflow; it has to exist before the workflow starts.
            lockfile = os.path.join(tmpdirname, "multiprocess_lock")
            Path(lockfile).touch()
            params = {
                "training_data_path":
                os.path.join(data_dir, "cartpole_training.json.bz2"),
                "eval_data_path":
                os.path.join(data_dir, "cartpole_eval.json.bz2"),
                "state_norm_data_path":
                os.path.join(data_dir, "state_features_norm.json"),
                "action_norm_data_path":
                os.path.join(data_dir, "action_norm.json"),
                "model_output_path":
                tmpdirname,
                "use_gpu":
                use_gpu,
                "use_all_avail_gpus":
                use_all_avail_gpus,
                "init_method":
                "file://" + lockfile,
                "num_nodes":
                1,
                "node_index":
                0,
                "epochs":
                1,
                "rl": {},
                "rainbow": {},
                "training": {
                    "minibatch_size": 128
                },
            }
            parametric_dqn_workflow.main(params)

            predictor_files = glob.glob(
                os.path.join(tmpdirname, "model_*.torchscript"))
            # Report the real count: the check fails for zero files as well
            # as for several, so a fixed "two files" message would mislead.
            assert len(predictor_files) == 1, (
                "Expected exactly one predictor file, found {}: {}".format(
                    len(predictor_files), predictor_files))
            predictor = ParametricDqnTorchPredictor(
                torch.jit.load(predictor_files[0]))

            # One state row and one matching action row, keyed by feature
            # id strings.
            test_float_state_features = [{
                "0": 1.0,
                "1": 1.0,
                "2": 1.0,
                "3": 1.0
            }]
            test_action_features = [{"4": 0.0, "5": 1.0}]
            q_values = predictor.predict(test_float_state_features,
                                         test_action_features)
            # The single input row must yield exactly one keyed entry.
            assert len(q_values[0].keys()) == 1