class Sampler():
    def __init__(self, device, actionsize):
        self.samplenet = DQN(actionsize).to(device)
        self.targetnet = DQN(actionsize).to(device)
        self.opt = torch.optim.Adam(self.samplenet.parameters(), lr=0.00001, betas=(0.0, 0.9))
        self.device = device
        self.memory = ReplayMemory(1000, device=device)
        self.BATCH_SIZE = 10
        self.GAMMA = 0.99
        self.count = 0

    def select_action(self, model):
        self.samplenet.eval()
        with torch.no_grad():
            # score the target model's conv2 filters and return the highest-valued action
            action = self.samplenet(model.conv2.weight.data.view(-1, 5, 5).unsqueeze(0))
        return torch.max(action, 1)[1]

    def step(self, state, action, next_state, reward, done):
        self.memory.push(state, action, next_state, reward, done)
        # don't bother if you don't have enough in memory
        if len(self.memory) >= self.BATCH_SIZE:
            self.optimize()

    def optimize(self):
        self.samplenet.train()
        self.targetnet.eval()
        s1, actions, r1, s2, d = self.memory.sample(self.BATCH_SIZE)

        # get old Q values and new Q values for the Bellman equation
        qvals = self.samplenet(s1)
        state_action_values = qvals.gather(1, actions[:, 0].unsqueeze(1))
        with torch.no_grad():
            qvals_t = self.targetnet(s2)
            q1_t = qvals_t.max(1)[0].unsqueeze(1)
        expected_state_action_values = (q1_t * self.GAMMA) * (1 - d) + r1

        # loss is the L2 error of the Bellman equation
        loss = torch.nn.MSELoss()(state_action_values, expected_state_action_values)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        # periodically copy the online network into the target network
        if self.count % 20 == 0:
            self.targetnet.load_state_dict(self.samplenet.state_dict())
        self.count += 1
        return loss.item()
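# The Sampler above relies on two classes defined elsewhere in the project: DQN
# (the Q-network) and ReplayMemory (the experience buffer). Their real
# definitions are not shown in this section, so the versions below are only a
# minimal sketch of the interface the Sampler appears to assume -- in
# particular, sample() is assumed to return batches in the order
# (state, action, reward, next_state, done), which is how optimize() unpacks
# it. Layer sizes are illustrative, not the author's.

import random
from collections import deque

import torch
import torch.nn as nn


class DQN(nn.Module):
    """Sketch of a Q-network that scores a stack of 5x5 conv filters."""

    def __init__(self, actionsize):
        super().__init__()
        # Sampler feeds this network model.conv2.weight viewed as
        # (1, num_filters, 5, 5); LazyLinear infers the flattened size on the
        # first forward pass, so the filter count does not need to be known here.
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(128),
            nn.ReLU(),
            nn.Linear(128, actionsize),
        )

    def forward(self, x):
        return self.net(x)


class ReplayMemory:
    """Sketch of a fixed-size experience buffer with batched sampling."""

    def __init__(self, capacity, device="cpu"):
        self.buffer = deque(maxlen=capacity)
        self.device = device

    def push(self, state, action, next_state, reward, done):
        self.buffer.append((state, action, next_state, reward, done))

    def __len__(self):
        return len(self.buffer)

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, next_states, rewards, dones = zip(*batch)

        def stack(seq):
            return torch.stack([torch.as_tensor(x) for x in seq]).to(self.device)

        def column(seq):
            # rewards / done flags as float column vectors of shape (B, 1)
            return torch.tensor([float(x) for x in seq], device=self.device).view(-1, 1)

        # Assumed return order: (state, action, reward, next_state, done).
        return stack(states), stack(actions), column(rewards), stack(next_states), column(dones)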
class Generator():
    def __init__(self, device, data):
        self.data = data
        self.actor = Actor().to(device)
        self.critic = Critic().to(device)
        #self.ctarget = Critic().to(device)
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=0.0001, betas=(0.0, 0.9))
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=0.001, betas=(0.0, 0.9))

        def init_weights(m):
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                torch.nn.init.xavier_uniform_(m.weight.data)
        self.actor.apply(init_weights)
        self.critic.apply(init_weights)
        #self.ctarget.apply(init_weights)

        self.device = device
        self.memory = ReplayMemory(1000, device=device)
        self.batch_size = 5
        self.GAMMA = 0.99
        self.count = 0

    def select_action(self, imgs):
        with torch.no_grad():
            self.actor.eval()
            action = self.actor(imgs)
        return action

    def step(self, state, action, next_state, reward, done):
        self.memory.push(state, action, next_state, reward, done)
        if len(self.memory) >= self.batch_size:
            self.optimize()

    def optimize(self):
        self.actor.train()
        self.critic.train()
        #self.ctarget.eval()
        s1, a, r, s2, d = self.memory.sample(self.batch_size)

        # train the critic: its mean score over a stored action (a generated
        # batch) should match the reward that batch actually received
        for reward, action in zip(r, a):
            qval = self.critic(action)
            avgQ = qval.mean().unsqueeze(0)
            loss = torch.nn.L1Loss()(avgQ, reward)
            self.critic_opt.zero_grad()
            loss.backward()
            self.critic_opt.step()

        # train the actor: maximize the critic's score on a freshly generated batch
        img, target = self.data[random.randint(0, len(self.data) - 1)]
        batch = self.actor(img)
        score = self.critic(batch)
        actor_loss = -score.mean()
        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        #if self.count % 5 == 0:
        #    self.ctarget.load_state_dict(self.critic.state_dict())
        #self.count += 1

    def save(self):
        torch.save(self.actor.state_dict(), os.path.join('model', 'actor.pth'))
        torch.save(self.critic.state_dict(), os.path.join('model', 'critic.pth'))
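# As with the Sampler, the Generator depends on two networks, Actor and Critic,
# whose definitions live elsewhere. From the calls above, the Actor maps an
# input image to a batch of generated samples and the Critic maps such a batch
# to one score per sample (which optimize() averages). The sketch below is only
# a guess at that interface: the 28x28 single-channel output, the batch size of
# 16 generated samples, and all layer sizes are made up for illustration.

import torch.nn as nn


class Actor(nn.Module):
    """Sketch: turn one input image into a batch of generated 28x28 samples."""

    def __init__(self, batch_out=16):
        super().__init__()
        self.batch_out = batch_out
        self.net = nn.Sequential(
            nn.Flatten(start_dim=0),          # treat the single input image as one flat vector
            nn.LazyLinear(256),
            nn.ReLU(),
            nn.Linear(256, batch_out * 28 * 28),
            nn.Tanh(),
        )

    def forward(self, img):
        out = self.net(img)
        # one generated "batch" per call, shape (batch_out, 1, 28, 28)
        return out.view(self.batch_out, 1, 28, 28)


class Critic(nn.Module):
    """Sketch: score each generated sample with a single scalar."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.LazyLinear(1),
        )

    def forward(self, batch):
        # returns shape (N, 1): one score per generated sample
        return self.net(batch)

# With these stand-ins, Generator(device, data).optimize() runs end to end once
# the replay memory holds at least batch_size transitions; the real Actor and
# Critic in the project may differ in architecture and output shapes.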