Example #1
    def continuous_update_weights(self, shared_storage, replay_buffer):
        # Wait for the replay buffer to be filled
        while shared_storage.get_info("num_played_games") < 1:
            time.sleep(0.1)

        # Fetch a first batch so one is always ready at the top of the loop
        next_batch = replay_buffer.get_batch()
        # Training loop
        while (self.training_step < self.config.training_steps
               and not shared_storage.get_info("terminate")):
            index_batch, batch = next_batch
            next_batch = replay_buffer.get_batch()
            self.update_lr()
            (
                priorities,
                total_loss,
                value_loss,
                reward_loss,
                policy_loss,
            ) = self.update_weights(batch)

            if self.config.PER:
                # Save new priorities in the replay buffer (See https://arxiv.org/abs/1803.00933)
                replay_buffer.update_priorities(priorities, index_batch)

            # Save to the shared storage
            if self.training_step % self.config.checkpoint_interval == 0:
                shared_storage.set_info({
                    "weights": copy.deepcopy(self.model.get_weights()),
                    "optimizer_state": copy.deepcopy(
                        models.dict_to_cpu(self.optimizer.state_dict())),
                })
                if self.config.save_model:
                    shared_storage.save_checkpoint()

            shared_storage.set_info({
                "training_step": self.training_step,
                "lr": self.optimizer.param_groups[0]["lr"],
                "total_loss": total_loss,
                "value_loss": value_loss,
                "reward_loss": reward_loss,
                "policy_loss": policy_loss,
            })

            # Managing the self-play / training ratio
            if self.config.training_delay:
                time.sleep(self.config.training_delay)

            if self.config.ratio:
                while (self.training_step /
                       max(1, shared_storage.get_info("num_played_steps")) >
                       self.config.ratio
                       and self.training_step < self.config.training_steps
                       and not shared_storage.get_info("terminate")):
                    time.sleep(0.5)
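
Note: the block guarded by self.config.ratio at the end of the loop above keeps the number of optimisation steps roughly proportional to the number of self-played steps. Below is a minimal, self-contained sketch of that throttle; the RatioThrottle class and the plain callables are illustrative stand-ins, not part of the project above.

import time


class RatioThrottle:
    """Illustrative stand-in: pause training until self-play catches up."""

    def __init__(self, ratio, poll_interval=0.5):
        self.ratio = ratio                  # target training steps per played step
        self.poll_interval = poll_interval  # seconds between re-checks

    def wait(self, get_training_step, get_played_steps, should_stop):
        # Block while the trainer is ahead of the self-play workers.
        while (get_training_step() / max(1, get_played_steps()) > self.ratio
               and not should_stop()):
            time.sleep(self.poll_interval)


if __name__ == "__main__":
    counters = {"training": 80, "played": 100}
    throttle = RatioThrottle(ratio=1.0)
    # 80 / 100 <= 1.0, so this returns immediately without sleeping.
    throttle.wait(lambda: counters["training"],
                  lambda: counters["played"],
                  lambda: False)
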
Example #2
File: trainer.py  Project: pikaju/muzero-g
    def continuous_update_weights(self, replay_buffer, shared_storage):
        # Wait for the replay buffer to be filled
        while ray.get(shared_storage.get_info.remote("num_played_games")) < 1:
            time.sleep(0.1)

        # Training loop
        while self.training_step < self.config.training_steps and not ray.get(
                shared_storage.get_info.remote("terminate")):
            index_batch, batch = ray.get(replay_buffer.get_batch.remote())
            # Self-supervised mode is active only for the first self_supervised_steps training steps
            self_supervised = self.training_step < self.config.self_supervised_steps
            self.update_lr()
            (
                priorities,
                total_loss,
                value_loss,
                reward_loss,
                policy_loss,
                reconstruction_loss,
                consistency_loss,
                reward_prediction_error,
                value_prediction_error,
            ) = self.update_weights(batch, self_supervised)

            if self.config.PER:
                # Save new priorities in the replay buffer (See https://arxiv.org/abs/1803.00933)
                replay_buffer.update_priorities.remote(priorities, index_batch)

            # Save to the shared storage
            if self.training_step % self.config.checkpoint_interval == 0:
                shared_storage.set_info.remote({
                    "weights": copy.deepcopy(self.model.get_weights()),
                    "optimizer_state": copy.deepcopy(
                        models.dict_to_cpu(self.optimizer.state_dict())),
                })
                if self.config.save_model:
                    shared_storage.save_checkpoint.remote()

            shared_storage.set_info.remote({
                "training_step": self.training_step,
                "lr": self.optimizer.param_groups[0]["lr"],
                "total_loss": total_loss,
                "value_loss": value_loss,
                "reward_loss": reward_loss,
                "policy_loss": policy_loss,
                "reconstruction_loss": reconstruction_loss,
                "consistency_loss": consistency_loss,
                "reward_prediction_error": reward_prediction_error,
                "value_prediction_error": value_prediction_error,
            })

            # Managing the self-play / training ratio
            if self.config.training_delay:
                time.sleep(self.config.training_delay)
            if self.config.ratio:
                while (self.training_step
                       / max(1, ray.get(
                           shared_storage.get_info.remote("num_played_steps")))
                       > self.config.ratio
                       and self.training_step < self.config.training_steps
                       and not ray.get(
                           shared_storage.get_info.remote("terminate"))):
                    time.sleep(0.5)
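
The main plumbing difference between the two examples is how they reach the shared storage and replay buffer: Example #1 calls them directly in-process, while Example #2 talks to Ray actors, issuing fire-and-forget writes with .remote() and resolving reads with ray.get(). Below is a minimal sketch of that actor pattern; the SharedStorage stand-in is illustrative and not taken from either project.

import ray


@ray.remote
class SharedStorage:
    """Illustrative stand-in for the shared storage actor used in Example #2."""

    def __init__(self):
        self.info = {"num_played_games": 0, "num_played_steps": 0,
                     "terminate": False, "training_step": 0}

    def get_info(self, key):
        return self.info[key]

    def set_info(self, updates):
        self.info.update(updates)


if __name__ == "__main__":
    ray.init()
    storage = SharedStorage.remote()

    # Writes are fire-and-forget, like the set_info.remote(...) calls above.
    storage.set_info.remote({"num_played_games": 1})

    # Reads go through ray.get to resolve the ObjectRef into a value.
    print(ray.get(storage.get_info.remote("num_played_games")))  # -> 1

    ray.shutdown()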