class Tester():
    """Run a pre-trained DDQN agent on a 6x6 MineSweeper board with 6 bombs.

    On construction the checkpoint numbered 20000 is loaded from the
    ``pre-trained`` directory (relative to the current working directory)
    and the model's ``epsilon`` is set to 0.
    """

    def __init__(self, render_flag):
        """Build the model and environment, then load checkpoint 20000.

        Args:
            render_flag: When truthy, a ``Render`` is created and the board
                is drawn on every call to ``do_step``.
        """
        self.render_flag = render_flag
        self.width = 6
        self.height = 6
        # One network input and one output per board cell (6 * 6 = 36).
        cell_count = self.width * self.height
        self.model = DDQN(cell_count, cell_count)
        self.env = MineSweeper(self.width, self.height, 6)
        if self.render_flag:
            self.renderer = Render(self.env.state)
        self.load_models(20000)

    def get_action(self, state):
        """Return the action chosen by the model for ``state``.

        The state is flattened, and the mask handed to the model is
        ``1 - env.fog`` flattened the same way.
        """
        flat_state = state.flatten()
        mask = (1 - self.env.fog).flatten()
        return self.model.act(flat_state, mask)

    def load_models(self, number):
        """Load ``current_state_dict`` from checkpoint ``number`` into the model.

        Args:
            number: Checkpoint index; the file read is
                ``./pre-trained/ddqn_dnn<number>.pth``.
        """
        # Forward slashes resolve on every platform; the former backslash
        # literal only worked on Windows and relied on "\d" not being a
        # recognised escape sequence.
        path = "./pre-trained/ddqn_dnn" + str(number) + ".pth"
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['current_state_dict'])
        # NOTE(review): epsilon = 0 presumably disables random exploration
        # in DDQN.act -- confirm against the DDQN implementation.
        self.model.epsilon = 0

    def do_step(self, action):
        """Apply a flat cell index to the environment.

        Args:
            action: Flat index of the chosen cell; converted to
                ``(row, column)`` using the board width.

        Returns:
            Tuple ``(next_state, terminal, reward)`` as produced by
            ``env.choose``.
        """
        i, j = divmod(int(action), self.width)
        if self.render_flag:
            # The board is drawn as it looks *before* the move is applied.
            self.renderer.state = self.env.state
            self.renderer.draw()
            self.renderer.bugfix()
        next_state, terminal, reward = self.env.choose(i, j)
        return next_state, terminal, reward
class Play():
    """Interactive 20x20 MineSweeper game with 20 bombs and a renderer."""

    def __init__(self):
        """Create the environment and a renderer bound to its state."""
        self.width, self.height, self.bombs = 20, 20, 20
        self.env = MineSweeper(self.width, self.height, self.bombs)
        self.renderer = Render(self.env.state)
        self.renderer.state = self.env.state

    def do_step(self, i, j):
        """Choose the cell addressed by coordinates ``(i, j)`` and redraw.

        Both coordinates are divided by 30 and truncated to an int before
        being passed to ``env.choose``.
        NOTE(review): 30 is presumably the on-screen size of one cell in
        pixels -- confirm against Render.

        Returns:
            Tuple ``(next_state, terminal, reward)`` from ``env.choose``.
        """
        row, col = (int(coord / 30) for coord in (i, j))
        next_state, terminal, reward = self.env.choose(row, col)
        # Redraw only after the move so the display shows the new board.
        renderer = self.renderer
        renderer.state = self.env.state
        renderer.draw()
        return next_state, terminal, reward
class Driver():
    """Training driver for a DDQN agent playing MineSweeper.

    Owns the environment, the current and target networks, the optimizer
    and learning-rate scheduler, the replay buffer, and a text log written
    to ``./Logs/ddqn_log.txt``. Call ``close()`` when training is finished
    to release the log file.
    """

    def __init__(self, width, height, bomb_no, render_flag):
        """Set up environment, models, optimizer, buffer and log file.

        Args:
            width: Board width in cells.
            height: Board height in cells.
            bomb_no: Number of bombs passed to ``MineSweeper``.
            render_flag: When truthy, a ``Render`` is created and the board
                is drawn on every call to ``do_step``.
        """
        import os  # local import: keeps this class self-contained

        self.width = width
        self.height = height
        self.bomb_no = bomb_no
        self.box_count = width * height
        self.env = MineSweeper(self.width, self.height, self.bomb_no)

        # Both networks take one input and give one output per board cell.
        self.current_model = DDQN(self.box_count, self.box_count)
        self.target_model = DDQN(self.box_count, self.box_count)
        self.target_model.eval()
        self.optimizer = torch.optim.Adam(
            self.current_model.parameters(), lr=0.003, weight_decay=1e-5
        )
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=2000, gamma=0.95
        )
        # Start the target network as an exact copy of the current one.
        self.target_model.load_state_dict(self.current_model.state_dict())

        self.buffer = Buffer(100000)
        self.gamma = 0.99
        self.render_flag = render_flag
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.90
        self.reward_threshold = 0.12
        self.reward_step = 0.01
        self.batch_size = 4096
        self.tau = 5e-5

        # Create the log directory first so a fresh checkout does not crash
        # with FileNotFoundError on the open() below.
        os.makedirs("./Logs", exist_ok=True)
        self.log = open("./Logs/ddqn_log.txt", 'w')
        if self.render_flag:
            self.Render = Render(self.env.state)

    def close(self):
        """Close the log file if it is open; safe to call more than once."""
        log = getattr(self, "log", None)
        if log is not None and not log.closed:
            log.close()

    def load_models(self, number):
        """Restore models, optimizer and epsilon from checkpoint ``number``.

        Reads ``./pre-trained/ddqn_dnn<number>.pth`` as written by
        ``save_checkpoints``.
        """
        path = "./pre-trained/ddqn_dnn" + str(number) + ".pth"
        weights = torch.load(path)
        self.current_model.load_state_dict(weights['current_state_dict'])
        self.target_model.load_state_dict(weights['target_state_dict'])
        self.optimizer.load_state_dict(weights['optimizer_state_dict'])
        self.current_model.epsilon = weights['epsilon']

    def get_action(self, state, mask):
        """Get an action from the current model for a state and its mask.

        Both arguments are flattened before being passed to ``act``.
        """
        return self.current_model.act(state.flatten(), mask.flatten())

    def do_step(self, action):
        """Perform ``action`` in the environment.

        Args:
            action: Flat index of the chosen cell; converted to
                ``(row, column)`` using the board width.

        Returns:
            Tuple ``(next_state, terminal, reward, next_fog)`` where
            ``next_fog`` is ``1 - env.fog`` read after the move.
        """
        i, j = divmod(int(action), self.width)
        if self.render_flag:
            # The board is drawn as it looks *before* the move is applied.
            self.Render.state = self.env.state
            self.Render.draw()
            self.Render.bugfix()
        next_state, terminal, reward = self.env.choose(i, j)
        next_fog = 1 - self.env.fog
        return next_state, terminal, reward, next_fog

    def epsilon_update(self, avg_reward):
        """Reward-based epsilon decay.

        When ``avg_reward`` exceeds the current reward threshold, epsilon is
        multiplied by ``epsilon_decay`` (never dropping below
        ``epsilon_min``) and the threshold is raised by ``reward_step``.
        """
        if avg_reward > self.reward_threshold:
            self.current_model.epsilon = max(
                self.epsilon_min,
                self.current_model.epsilon * self.epsilon_decay,
            )
            self.reward_threshold += self.reward_step

    def TD_Loss(self):
        """Run one optimisation step on a sampled batch and return the loss.

        The regression target is
        ``reward + gamma * max_a Q_target(next_state, a) * (1 - done)``
        and the loss is the mean squared difference between that target and
        the current model's Q-value for the action taken. After the
        optimizer and scheduler step, the target network is moved towards
        the current network by a factor ``tau``.

        Returns:
            The loss as a Python float.
        """
        # Sample a batch from the replay buffer.
        (state, action, mask, reward,
         next_state, next_mask, terminal) = self.buffer.sample(self.batch_size)

        # Convert the batch to tensors. The deprecated Variable wrapper is
        # not used: it returns its tensor argument unchanged.
        state = FloatTensor(float32(state))
        mask = FloatTensor(float32(mask))
        next_state = FloatTensor(float32(next_state))
        action = LongTensor(float32(action))
        next_mask = FloatTensor(float32(next_mask))
        reward = FloatTensor(reward)
        done = FloatTensor(terminal)

        # Q-values for the present state (current model) and for the next
        # state (target model). The target side never receives gradients,
        # so its graph is not recorded.
        q_values = self.current_model(state, mask)
        with torch.no_grad():
            next_q_values = self.target_model(next_state, next_mask)

        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
        next_q_value = next_q_values.max(1)[0]
        # (1 - done) removes the bootstrap term for terminal transitions.
        expected_q_value = reward + self.gamma * next_q_value * (1 - done)
        loss = (q_value - expected_q_value.detach()).pow(2).mean()
        loss_print = loss.item()

        # Propagate the loss.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        # Soft update: target <- tau * current + (1 - tau) * target.
        for target_param, local_param in zip(self.target_model.parameters(),
                                             self.current_model.parameters()):
            target_param.data.copy_(
                self.tau * local_param.data
                + (1.0 - self.tau) * target_param.data
            )
        return loss_print

    def save_checkpoints(self, batch_no):
        """Write models, optimizer state and epsilon to a checkpoint file.

        The file is ``./pre-trained/ddqn_dnn<batch_no>.pth``; the directory
        is created if it does not exist.
        """
        import os  # local import: keeps this class self-contained

        path = "./pre-trained/ddqn_dnn" + str(batch_no) + ".pth"
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({
            'epoch': batch_no,
            'current_state_dict': self.current_model.state_dict(),
            'target_state_dict': self.target_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.current_model.epsilon,
        }, path)

    def save_logs(self, batch_no, avg_reward, loss, wins):
        """Print one progress line and append it to the log file.

        The line holds the batch number, average reward, loss, win count
        and the current epsilon; the log is flushed after every write.
        """
        res = [
            str(batch_no),
            "\tAvg Reward: ",
            str(avg_reward),
            "\t Loss: ",
            str(loss),
            "\t Wins: ",
            str(wins),
            "\t Epsilon: ",
            str(self.current_model.epsilon),
        ]
        log_line = " ".join(res)
        print(log_line)
        self.log.write(log_line + "\n")
        self.log.flush()