def continuous_update_weights(self, shared_storage, replay_buffer):
    # Wait for the replay buffer to be filled
    while shared_storage.get_info("num_played_games") < 1:
        time.sleep(0.1)

    next_batch = replay_buffer.get_batch()
    # Training loop
    while (
        self.training_step < self.config.training_steps
        and not shared_storage.get_info("terminate")
    ):
        index_batch, batch = next_batch
        next_batch = replay_buffer.get_batch()
        self.update_lr()
        (
            priorities,
            total_loss,
            value_loss,
            reward_loss,
            policy_loss,
        ) = self.update_weights(batch)

        if self.config.PER:
            # Save new priorities in the replay buffer (See https://arxiv.org/abs/1803.00933)
            replay_buffer.update_priorities(priorities, index_batch)

        # Save to the shared storage
        if self.training_step % self.config.checkpoint_interval == 0:
            shared_storage.set_info(
                {
                    "weights": copy.deepcopy(self.model.get_weights()),
                    "optimizer_state": copy.deepcopy(
                        models.dict_to_cpu(self.optimizer.state_dict())
                    ),
                }
            )
            if self.config.save_model:
                shared_storage.save_checkpoint()
        shared_storage.set_info(
            {
                "training_step": self.training_step,
                "lr": self.optimizer.param_groups[0]["lr"],
                "total_loss": total_loss,
                "value_loss": value_loss,
                "reward_loss": reward_loss,
                "policy_loss": policy_loss,
            }
        )

        # Managing the self-play / training ratio
        if self.config.training_delay:
            time.sleep(self.config.training_delay)
        if self.config.ratio:
            while (
                self.training_step
                / max(1, shared_storage.get_info("num_played_steps"))
                > self.config.ratio
                and self.training_step < self.config.training_steps
                and not shared_storage.get_info("terminate")
            ):
                time.sleep(0.5)
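
# The function above calls self.update_lr() before every optimization step, but the
# scheduler itself is not part of this snippet. The method-style sketch below is an
# assumption, not confirmed code from this repository: it shows the exponential decay
# schedule commonly used in MuZero-style trainers, driven by hypothetical config
# fields lr_init, lr_decay_rate and lr_decay_steps, and is meant to live on the same
# trainer class as continuous_update_weights.
def update_lr(self):
    """Sketch: exponentially decay the optimizer's learning rate with training_step."""
    lr = self.config.lr_init * self.config.lr_decay_rate ** (
        self.training_step / self.config.lr_decay_steps
    )
    for param_group in self.optimizer.param_groups:
        param_group["lr"] = lr
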
def continuous_update_weights(self, replay_buffer, shared_storage):
    # Wait for the replay buffer to be filled
    while ray.get(shared_storage.get_info.remote("num_played_games")) < 1:
        time.sleep(0.1)

    # Training loop
    while self.training_step < self.config.training_steps and not ray.get(
        shared_storage.get_info.remote("terminate")
    ):
        index_batch, batch = ray.get(replay_buffer.get_batch.remote())
        self_supervised = self.training_step < self.config.self_supervised_steps
        self.update_lr()
        (
            priorities,
            total_loss,
            value_loss,
            reward_loss,
            policy_loss,
            reconstruction_loss,
            consistency_loss,
            reward_prediction_error,
            value_prediction_error,
        ) = self.update_weights(batch, self_supervised)

        if self.config.PER:
            # Save new priorities in the replay buffer (See https://arxiv.org/abs/1803.00933)
            replay_buffer.update_priorities.remote(priorities, index_batch)

        # Save to the shared storage
        if self.training_step % self.config.checkpoint_interval == 0:
            shared_storage.set_info.remote(
                {
                    "weights": copy.deepcopy(self.model.get_weights()),
                    "optimizer_state": copy.deepcopy(
                        models.dict_to_cpu(self.optimizer.state_dict())
                    ),
                }
            )
            if self.config.save_model:
                shared_storage.save_checkpoint.remote()
        shared_storage.set_info.remote(
            {
                "training_step": self.training_step,
                "lr": self.optimizer.param_groups[0]["lr"],
                "total_loss": total_loss,
                "value_loss": value_loss,
                "reward_loss": reward_loss,
                "policy_loss": policy_loss,
                "reconstruction_loss": reconstruction_loss,
                "consistency_loss": consistency_loss,
                "reward_prediction_error": reward_prediction_error,
                "value_prediction_error": value_prediction_error,
            }
        )

        # Managing the self-play / training ratio
        if self.config.training_delay:
            time.sleep(self.config.training_delay)
        if self.config.ratio:
            while (
                self.training_step
                / max(
                    1,
                    ray.get(shared_storage.get_info.remote("num_played_steps")),
                )
                > self.config.ratio
                and self.training_step < self.config.training_steps
                and not ray.get(shared_storage.get_info.remote("terminate"))
            ):
                time.sleep(0.5)
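
# Both variants checkpoint the optimizer state through models.dict_to_cpu so the saved
# tensors do not keep GPU memory pinned. That helper is not shown in this snippet; the
# sketch below is an assumption about what such a helper might look like: a recursive
# copy of a (possibly nested) state dict that moves every torch.Tensor to the CPU.
import torch

def dict_to_cpu(dictionary):
    """Sketch: return a copy of a nested state dict with all tensors moved to CPU."""
    cpu_dict = {}
    for key, value in dictionary.items():
        if isinstance(value, torch.Tensor):
            cpu_dict[key] = value.cpu()
        elif isinstance(value, dict):
            cpu_dict[key] = dict_to_cpu(value)
        else:
            cpu_dict[key] = value
    return cpu_dict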