def train_network(params):
    """Run the DQN training workflow described by ``params`` and save the trainer.

    Keys read from ``params``:
        "actions": action names, passed to the model parameters and used
            (as a NumPy array) when preprocessing each batch.
        "rl", "training", "rainbow": keyword-argument dicts for
            ``RLParameters``, ``TrainingParameters`` and
            ``RainbowDQNParameters`` respectively.
        "training_data_path": location handed to ``JSONDataset`` for the
            training examples.
        "state_norm_data_path": location handed to ``JSONDataset`` for the
            state normalization data.
        "use_gpu": forwarded as the third argument of ``DQNTrainer``.
        "epochs": number of passes over the batches.
        "pytorch_output_path": where the trained trainer is saved.
    """
    logger.info("Running DQN workflow with params:")
    logger.info(params)

    action_names = np.array(params["actions"])

    rl_parameters = RLParameters(**params["rl"])
    training_parameters = TrainingParameters(**params["training"])
    rainbow_parameters = RainbowDQNParameters(**params["rainbow"])
    model_config = DiscreteActionModelParameters(
        actions=params["actions"],
        rl=rl_parameters,
        training=training_parameters,
        rainbow=rainbow_parameters,
    )

    minibatch_size = training_parameters.minibatch_size
    dataset = JSONDataset(params["training_data_path"], batch_size=minibatch_size)

    norm_dataset = JSONDataset(params["state_norm_data_path"])
    state_normalization = read_norm_params(norm_dataset.read_all())

    # Truncating division: only whole batches are counted, so the loop below
    # never requests a batch index past the last full minibatch.
    example_count = len(dataset)
    num_batches = int(example_count / minibatch_size)

    summary_template = (
        "Read in batch data set {} of size {} examples. Data split "
        "into {} batches of size {}."
    )
    logger.info(
        summary_template.format(
            params["training_data_path"],
            example_count,
            num_batches,
            minibatch_size,
        )
    )

    trainer = DQNTrainer(model_config, state_normalization, params["use_gpu"])

    total_epochs = params["epochs"]
    for epoch in range(total_epochs):
        for batch_idx in range(num_batches):
            helpers.report_training_status(
                batch_idx, num_batches, epoch, total_epochs
            )
            raw_batch = dataset.read_batch(batch_idx)
            training_batch = preprocess_batch_for_training(
                action_names, raw_batch, state_normalization
            )
            trainer.train(training_batch)

    model_path = params["pytorch_output_path"]
    logger.info("Training finished. Saving PyTorch model to {}".format(model_path))
    helpers.save_model_to_file(trainer, model_path)
def save_models(self, path: str):
    """Export the trainer checkpoint and a TorchScript serving module.

    Two files are written into ``path`` (with a leading ``~`` expanded),
    both stamped with the same export time in whole seconds:

    * ``trainer_<export_time>.pt`` -- ``self.trainer`` saved through
      ``save_model_to_file``.
    * ``model_<export_time>.torchscript`` -- the serving wrapper saved
      through ``self.save_torchscript_model``.

    Args:
        path: Output directory; may start with ``~``.
    """
    # Serving module: the CPU copy of the Q-network in eval mode, bundled
    # with a state preprocessor and the action names.
    # NOTE(review): the second Preprocessor argument (False) is presumably a
    # use-GPU flag -- confirm against Preprocessor's signature.
    dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
        self.trainer.q_network.cpu_model().eval(),
        Preprocessor(self.state_normalization, False),
    )
    serving_module = DiscreteDqnPredictorWrapper(
        dqn_with_preprocessor=dqn_with_preprocessor,
        action_names=self.model_params.actions,
    )

    export_time = round(time.time())
    output_path = os.path.expanduser(path)

    # Both artifacts must be joined onto the expanded directory; joining the
    # TorchScript file onto the raw ``path`` would leave "~" unexpanded and
    # separate it from the trainer checkpoint.
    pytorch_output_path = os.path.join(output_path, f"trainer_{export_time}.pt")
    torchscript_output_path = os.path.join(
        output_path, f"model_{export_time}.torchscript"
    )

    logger.info("Saving PyTorch trainer to {}".format(pytorch_output_path))
    save_model_to_file(self.trainer, pytorch_output_path)
    self.save_torchscript_model(serving_module, torchscript_output_path)