示例#1
0
def train_network(params):
    """Run the batch DQN training workflow described by ``params``.

    The function logs the incoming configuration, builds the model
    parameter objects, loads the training dataset and the state
    normalization data, trains a ``DQNTrainer`` for the requested number
    of epochs, and finally saves the trainer through
    ``helpers.save_model_to_file``.

    Args:
        params: Mapping read with the following keys:
            ``"actions"``, ``"rl"``, ``"training"``, ``"rainbow"``,
            ``"training_data_path"``, ``"state_norm_data_path"``,
            ``"use_gpu"``, ``"epochs"`` and ``"pytorch_output_path"``.
            The ``"rl"``, ``"training"`` and ``"rainbow"`` entries are
            unpacked as keyword arguments into their parameter classes.
    """
    logger.info("Running DQN workflow with params:")
    logger.info(params)

    action_names = np.array(params["actions"])

    # Assemble the nested parameter objects the trainer is built from.
    rl_config = RLParameters(**params["rl"])
    training_config = TrainingParameters(**params["training"])
    rainbow_config = RainbowDQNParameters(**params["rainbow"])
    model_config = DiscreteActionModelParameters(
        actions=params["actions"],
        rl=rl_config,
        training=training_config,
        rainbow=rainbow_config,
    )

    minibatch_size = training_config.minibatch_size
    training_data_path = params["training_data_path"]

    dataset = JSONDataset(training_data_path, batch_size=minibatch_size)
    state_normalization = read_norm_params(
        JSONDataset(params["state_norm_data_path"]).read_all()
    )

    # Truncating division: a trailing partial batch is never read below.
    num_examples = len(dataset)
    num_batches = int(num_examples / minibatch_size)

    summary_template = (
        "Read in batch data set {} of size {} examples. Data split "
        "into {} batches of size {}."
    )
    logger.info(
        summary_template.format(
            training_data_path, num_examples, num_batches, minibatch_size
        )
    )

    trainer = DQNTrainer(model_config, state_normalization, params["use_gpu"])

    num_epochs = params["epochs"]
    for epoch in range(num_epochs):
        for batch_idx in range(num_batches):
            helpers.report_training_status(
                batch_idx, num_batches, epoch, num_epochs
            )
            training_batch = preprocess_batch_for_training(
                action_names,
                dataset.read_batch(batch_idx),
                state_normalization,
            )
            trainer.train(training_batch)

    model_path = params["pytorch_output_path"]
    logger.info(
        "Training finished. Saving PyTorch model to {}".format(model_path)
    )
    helpers.save_model_to_file(trainer, model_path)
示例#2
0
    def save_models(self, path: str) -> None:
        """Save the trainer and a TorchScript serving module under ``path``.

        Two files are written, both suffixed with the same export
        timestamp (seconds since the epoch, rounded):

        * ``trainer_<ts>.pt`` -- the trainer, via ``save_model_to_file``.
        * ``model_<ts>.torchscript`` -- the Q-network wrapped with its
          preprocessor, via ``self.save_torchscript_model``.

        Args:
            path: Output directory. A leading ``~`` is expanded with
                ``os.path.expanduser`` before either file path is built.
        """
        dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
            self.trainer.q_network.cpu_model().eval(),
            Preprocessor(self.state_normalization, False),
        )
        serving_module = DiscreteDqnPredictorWrapper(
            dqn_with_preprocessor=dqn_with_preprocessor,
            action_names=self.model_params.actions,
        )

        export_time = round(time.time())
        # Build BOTH artifact paths from the expanded directory. Previously
        # the TorchScript path was joined onto the raw ``path``, so a value
        # such as "~/models" sent the two files to different locations.
        output_path = os.path.expanduser(path)
        pytorch_output_path = os.path.join(
            output_path, f"trainer_{export_time}.pt"
        )
        torchscript_output_path = os.path.join(
            output_path, f"model_{export_time}.torchscript"
        )
        logger.info("Saving PyTorch trainer to {}".format(pytorch_output_path))
        save_model_to_file(self.trainer, pytorch_output_path)
        self.save_torchscript_model(serving_module, torchscript_output_path)