Example #1
    def step(
        self,
        action: np.ndarray,
        q_values: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.float64, bool, dict]:
        """Take an action and store distillation data to buffer storage."""

        output = None
        if (self.args.test and hasattr(self, "memory")
                and not isinstance(self.memory, DistillationBuffer)):
            # Interim test during teacher training.
            output = DQNAgent.step(self, action)
        else:
            current_frame_path = f"{self.save_distillation_dir}/{self.save_count:07}.pkl"
            self.save_count += 1
            if self.args.test:
                # Generating expert's test-phase data.
                next_state, reward, done, info = self.env.step(action)
                with open(current_frame_path, "wb") as f:
                    pickle.dump([self.curr_state, q_values],
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
                if self.save_count >= self.hyper_params.n_frame_from_last:
                    done = True
                output = next_state, reward, done, info
            else:
                # Teacher training.
                with open(current_frame_path, "wb") as f:
                    pickle.dump([self.curr_state],
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
                output = DQNAgent.step(self, action)

        return output
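Each call to step writes one frame to disk as its own pickle file: [curr_state] during teacher training, and [curr_state, q_values] when generating expert test-phase data. As a quick sanity check, here is a minimal sketch for reading those files back in save order (load_distillation_frames is a hypothetical helper, not part of the snippets):

    import glob
    import pickle


    def load_distillation_frames(distillation_dir):
        """Load the per-frame pickles written by step, in save order."""
        frames = []
        for path in sorted(glob.glob(f"{distillation_dir}/*.pkl")):
            with open(path, "rb") as f:
                # Teacher-training frames hold [state]; expert test-phase
                # frames hold [state, q_values].
                frames.append(pickle.load(f))
        return frames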
Example #2
    def _initialize(self):
        """Initialize non-common things."""
        self.save_distillation_dir = None
        if not self.hyper_params.is_student:
            # Training the teacher does not require a DistillationBuffer,
            # so teacher mode delegates to DQNAgent._initialize.
            print("[INFO] Teacher mode.")
            DQNAgent._initialize(self)
            self.make_distillation_dir()
        else:
            # Training the student or generating distillation data (test).
            print("[INFO] Student mode.")
            self.softmax_tau = 0.01

            build_args = dict(
                hyper_params=self.hyper_params,
                log_cfg=self.log_cfg,
                env_name=self.env_info.name,
                state_size=self.env_info.observation_space.shape,
                output_size=self.env_info.action_space.n,
                is_test=self.is_test,
                load_from=self.load_from,
            )
            self.learner = build_learner(self.learner_cfg, build_args)
            self.dataset_path = self.hyper_params.dataset_path

            self.memory = DistillationBuffer(self.hyper_params.batch_size,
                                             self.dataset_path)
            if self.is_test:
                self.make_distillation_dir()
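The student branch sets softmax_tau = 0.01, but the distillation update itself is not shown here. In policy distillation (Rusu et al., 2016), a small temperature sharpens the teacher's softened Q distribution and the student minimizes the KL divergence to it; the sketch below illustrates that loss under the assumption that update_distillation does something similar (distillation_loss is hypothetical):

    import torch.nn.functional as F


    def distillation_loss(student_q, teacher_q, tau=0.01):
        """KL(softmax(teacher_q / tau) || softmax(student_q)), batch-averaged."""
        soft_teacher = F.softmax(teacher_q / tau, dim=1)
        log_student = F.log_softmax(student_q, dim=1)
        return F.kl_div(log_student, soft_teacher, reduction="batchmean")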
Example #3
    def _initialize(self):
        """Initialize non-common things."""
        self.save_distillation_dir = None
        if not self.args.student:
            # Training the teacher does not require a DistillationBuffer,
            # so teacher mode delegates to DQNAgent._initialize.
            DQNAgent._initialize(self)
            self.make_distillation_dir()
        else:
            # Training the student or generating distillation data (test).
            self.softmax_tau = 0.01
            self.learner = build_learner(self.learner_cfg)
            self.dataset_path = self.hyper_params.dataset_path

            self.memory = DistillationBuffer(self.hyper_params.batch_size,
                                             self.dataset_path)
            if self.args.test:
                self.make_distillation_dir()
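Both _initialize variants (and step in Example #1) rely on make_distillation_dir, whose body is not included. The surrounding code only requires that it create save_distillation_dir and reset save_count; a plausible minimal version under that assumption (the directory layout is invented for illustration):

    import os
    import time


    def make_distillation_dir(self):
        """Create a fresh frame directory and reset the frame counter."""
        self.save_distillation_dir = os.path.join(
            "data", "distillation_buffer", time.strftime("%Y%m%d_%H%M%S"))
        os.makedirs(self.save_distillation_dir, exist_ok=True)
        self.save_count = 0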
Example #4
    def train(self):
        """Execute appropriate learning code according to the running type."""
        if self.args.student:
            self.memory.reset_dataloader()
            if not self.memory.is_contain_q:
                print(
                    "[INFO] Train-phase student training: generating expert agent Q-values.."
                )
                assert (
                    self.args.load_from is not None
                ), "Train-phase training requires an expert agent; use the load-from argument."
                self.add_expert_q()
                self.hyper_params.dataset_path = [self.save_distillation_dir]
                self.args.load_from = None
                self._initialize()
                self.memory.reset_dataloader()
                print("start student training..")

            # train student
            assert (
                self.memory.buffer_size >= self.hyper_params.batch_size
            ), "Dataset must contain at least one full batch."
            if self.args.log:
                self.set_wandb()

            steps_per_epoch = (self.memory.buffer_size //
                               self.hyper_params.batch_size)
            train_steps = steps_per_epoch * self.hyper_params.epochs
            print(
                f"[INFO] Total epochs: {self.hyper_params.epochs}\t Train steps: {train_steps}"
            )
            n_epoch = 0
            for steps in range(train_steps):
                loss = self.update_distillation()

                if self.args.log:
                    wandb.log({"dqn loss": loss[0], "avg q values": loss[1]})

                if steps % steps_per_epoch == 0:
                    print(f"Training {n_epoch} epochs, {steps} steps.. " +
                          f"loss: {loss[0]}, avg_q_value: {loss[1]}")
                    self.learner.save_params(steps)
                    n_epoch += 1
                    self.memory.reset_dataloader()

            self.learner.save_params(steps)

        else:
            DQNAgent.train(self)

            if self.hyper_params.n_frame_from_last is not None:
                # Copy last n_frame in new directory.
                print("saving last %d frames.." %
                      self.hyper_params.n_frame_from_last)
                last_frame_dir = (
                    self.save_distillation_dir +
                    "_last_%d/" % self.hyper_params.n_frame_from_last)
                os.makedirs(last_frame_dir)

                # Load directory.
                episode_dir_list = sorted(
                    os.listdir(self.save_distillation_dir))
                episode_dir_list = episode_dir_list[
                    -self.hyper_params.n_frame_from_last:]
                for _dir in tqdm(episode_dir_list):
                    with open(os.path.join(self.save_distillation_dir, _dir),
                              "rb") as f:
                        tmp = pickle.load(f)
                    with open(os.path.join(last_frame_dir, _dir), "wb") as f:
                        pickle.dump(tmp, f, protocol=pickle.HIGHEST_PROTOCOL)
                print("\nsuccessfully saved")
                print(f"All train-phase dir: {self.save_distillation_dir}/")
                print(
                    f"last {self.hyper_params.n_frame_from_last} frames dir: {last_frame_dir}"
                )
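The student branch of train touches only a small surface of DistillationBuffer: batch_size, buffer_size, is_contain_q, and reset_dataloader. A bare-bones stand-in illustrating that assumed interface (MinimalDistillationBuffer is hypothetical; the real class must also serve batches to update_distillation):

    import glob
    import pickle


    class MinimalDistillationBuffer:
        """Stand-in exposing only the members the train loop above uses."""

        def __init__(self, batch_size, dataset_path):
            self.batch_size = batch_size
            self.frame_paths = []
            for directory in dataset_path:
                self.frame_paths += sorted(glob.glob(f"{directory}/*.pkl"))
            self.buffer_size = len(self.frame_paths)
            # Frames saved as [state, q_values] mark a Q-labeled dataset;
            # frames saved as [state] alone do not.
            self.is_contain_q = False
            if self.frame_paths:
                with open(self.frame_paths[0], "rb") as f:
                    self.is_contain_q = len(pickle.load(f)) == 2
            self._cursor = 0

        def reset_dataloader(self):
            """Restart iteration from the first frame (called each epoch)."""
            self._cursor = 0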