def step(
    self, action: np.ndarray, q_values: np.ndarray = None
) -> Tuple[np.ndarray, np.float64, bool, dict]:
    """Take an action and store distillation data to buffer storage."""
    # Teacher's interim test: no DistillationBuffer is attached, so just
    # defer to the plain DQN transition without dumping anything to disk.
    is_interim_teacher_test = (
        self.args.test
        and hasattr(self, "memory")
        and not isinstance(self.memory, DistillationBuffer)
    )
    if is_interim_teacher_test:
        return DQNAgent.step(self, action)

    # Every stored frame gets its own zero-padded pickle file.
    frame_path = f"{self.save_distillation_dir}/{self.save_count:07}.pkl"
    self.save_count += 1

    if not self.args.test:
        # Teacher training: record only the visited state, then step as usual.
        with open(frame_path, "wb") as f:
            pickle.dump([self.curr_state], f, protocol=pickle.HIGHEST_PROTOCOL)
        return DQNAgent.step(self, action)

    # Test phase: generate the expert's data (state plus teacher Q-values).
    next_state, reward, done, info = self.env.step(action)
    with open(frame_path, "wb") as f:
        pickle.dump(
            [self.curr_state, q_values], f, protocol=pickle.HIGHEST_PROTOCOL
        )
    # Force episode termination once enough frames have been collected.
    if self.save_count >= self.hyper_params.n_frame_from_last:
        done = True
    return next_state, reward, done, info
def _initialize(self):
    """Initialize non-common things."""
    self.save_distillation_dir = None
    if self.hyper_params.is_student:
        # Student mode: training a student or generating distillation
        # data at test time, both of which need a DistillationBuffer.
        print("[INFO] Student mode.")
        self.softmax_tau = 0.01
        build_args = {
            "hyper_params": self.hyper_params,
            "log_cfg": self.log_cfg,
            "env_name": self.env_info.name,
            "state_size": self.env_info.observation_space.shape,
            "output_size": self.env_info.action_space.n,
            "is_test": self.is_test,
            "load_from": self.load_from,
        }
        self.learner = build_learner(self.learner_cfg, build_args)
        self.dataset_path = self.hyper_params.dataset_path
        self.memory = DistillationBuffer(
            self.hyper_params.batch_size, self.dataset_path
        )
        if self.is_test:
            self.make_distillation_dir()
    else:
        # Training a teacher does not require a DistillationBuffer, so
        # this overloads DQNAgent._initialize and only adds the
        # distillation save directory on top.
        print("[INFO] Teacher mode.")
        DQNAgent._initialize(self)
        self.make_distillation_dir()
def _initialize(self):
    """Initialize non-common things."""
    self.save_distillation_dir = None
    if self.args.student:
        # Student mode: training a student or generating distillation
        # data (test), both of which rely on a DistillationBuffer.
        self.softmax_tau = 0.01
        self.learner = build_learner(self.learner_cfg)
        self.dataset_path = self.hyper_params.dataset_path
        self.memory = DistillationBuffer(
            self.hyper_params.batch_size, self.dataset_path
        )
        if self.args.test:
            self.make_distillation_dir()
    else:
        # Training a teacher does not require a DistillationBuffer, so
        # this overloads DQNAgent._initialize and only adds the
        # distillation save directory on top.
        DQNAgent._initialize(self)
        self.make_distillation_dir()
def train(self):
    """Execute appropriate learning code according to the running type.

    Student mode: supervised distillation over the stored expert frames.
    Teacher mode: plain DQN training, then optionally copy the last
    ``n_frame_from_last`` stored frames into a separate directory.
    """
    if self.args.student:
        self.memory.reset_dataloader()
        if not self.memory.is_contain_q:
            # Stored frames lack expert Q-values: run the expert once to
            # annotate them, then re-point the dataset at the new dump and
            # rebuild learner/buffer via _initialize().
            print(
                "train-phase student training. Generating expert agent Q.."
            )
            # NOTE(review): assert is stripped under `python -O`; an
            # explicit raise would be safer for input validation.
            assert (
                self.args.load_from is not None
            ), "Train-phase training requires expert agent. Please use load-from argument."
            self.add_expert_q()
            self.hyper_params.dataset_path = [self.save_distillation_dir]
            # Clear load_from so re-initialization does not reload the expert.
            self.args.load_from = None
            self._initialize()
            self.memory.reset_dataloader()
        print("start student training..")
        # train student
        assert self.memory.buffer_size >= self.hyper_params.batch_size
        if self.args.log:
            self.set_wandb()
        # iter_1 = batches per epoch; total steps = batches * epochs.
        iter_1 = self.memory.buffer_size // self.hyper_params.batch_size
        train_steps = iter_1 * self.hyper_params.epochs
        print(
            f"[INFO] Total epochs: {self.hyper_params.epochs}\t Train steps: {train_steps}"
        )
        n_epoch = 0
        for steps in range(train_steps):
            loss = self.update_distillation()
            if self.args.log:
                wandb.log({"dqn loss": loss[0], "avg q values": loss[1]})
            # At each epoch boundary: log progress, checkpoint, and
            # reshuffle the dataloader for the next pass.
            if steps % iter_1 == 0:
                print(f"Training {n_epoch} epochs, {steps} steps.. " +
                      f"loss: {loss[0]}, avg_q_value: {loss[1]}")
                self.learner.save_params(steps)
                n_epoch += 1
                self.memory.reset_dataloader()
        # Final checkpoint; relies on `steps` leaking out of the loop
        # (train_steps >= 1 is guaranteed by the buffer-size assert above).
        self.learner.save_params(steps)
    else:
        DQNAgent.train(self)
        if self.hyper_params.n_frame_from_last is not None:
            # Copy last n_frame in new directory.
            print("saving last %d frames.." %
                  self.hyper_params.n_frame_from_last)
            last_frame_dir = (
                self.save_distillation_dir +
                "_last_%d/" % self.hyper_params.n_frame_from_last)
            # NOTE(review): raises if the directory already exists —
            # confirm failing fast is intended (vs. exist_ok=True).
            os.makedirs(last_frame_dir)
            # Load directory. Sorting works because frame files are
            # zero-padded, so lexical order == chronological order.
            episode_dir_list = sorted(
                os.listdir(self.save_distillation_dir))
            episode_dir_list = episode_dir_list[-self.hyper_params.
                                                n_frame_from_last:]
            for _dir in tqdm(episode_dir_list):
                with open(self.save_distillation_dir + "/" + _dir, "rb") as f:
                    tmp = pickle.load(f)
                with open(last_frame_dir + _dir, "wb") as f:
                    pickle.dump(tmp, f, protocol=pickle.HIGHEST_PROTOCOL)
            print("\nsuccessfully saved")
            print(f"All train-phase dir: {self.save_distillation_dir}/")
            print(
                f"last {self.hyper_params.n_frame_from_last} frames dir: {last_frame_dir}"
            )