# Shared third-party imports for the agent snippets below. Project-local classes
# (DQN, DQN2D, MLP, ValueNetwork, Actor, Critic, ReplayMemory, Memory, Transition,
# preprocess_state, to_tensor_unsqueeze, update_params, the Agent base class) are
# assumed to come from the surrounding repository.
import copy
import json
import logging
import os
import random
import time
from itertools import count

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Categorical


# Two-network build for the image-observation (84x84) separate-agents variant;
# mirrors AgentSep1D.build further below.
def build(self):
    self.policy_net1 = DQN2D(84, 84, self.pars).to(self.device)
    self.target_net1 = DQN2D(84, 84, self.pars).to(self.device)
    self.target_net1.load_state_dict(self.policy_net1.state_dict())
    self.target_net1.eval()
    self.policy_net2 = DQN2D(84, 84, self.pars).to(self.device)
    self.target_net2 = DQN2D(84, 84, self.pars).to(self.device)
    self.target_net2.load_state_dict(self.policy_net2.state_dict())
    self.target_net2.eval()
    self.optimizer1 = optim.SGD(self.policy_net1.parameters(),
                                lr=self.pars['lr'], momentum=self.pars['momentum'])
    self.optimizer2 = optim.SGD(self.policy_net2.parameters(),
                                lr=self.pars['lr'], momentum=self.pars['momentum'])
    # self.optimizer1 = optim.Adam(self.policy_net1.parameters())
    # self.optimizer2 = optim.Adam(self.policy_net2.parameters())
    self.memory2 = ReplayMemory(10000)
    self.memory1 = ReplayMemory(10000)
    if self.pars['ppe'] == '1':
        # Prioritized-replay variant.
        self.memory1 = Memory(10000)
        self.memory2 = Memory(10000)
def build(self):
    self.policy_net = DQN(97, self.pars, rec=self.pars['rec'] == 1).to(self.device)
    self.target_net = DQN(97, self.pars, rec=self.pars['rec'] == 1).to(self.device)
    self.target_net.load_state_dict(self.policy_net.state_dict())
    self.target_net.eval()
    if self.pars['momentum'] > 0:
        self.optimizer = optim.SGD(self.policy_net.parameters(),
                                   lr=self.pars['lr'], momentum=self.pars['momentum'])
    else:
        self.optimizer = optim.Adam(self.policy_net.parameters())
    self.memory = ReplayMemory(10000)
    if 'ppe' in self.pars:
        self.memory = Memory(10000)
    if self.pars['load'] is not None:
        self.load(self.pars['load'])
        self.target_net.load_state_dict(self.policy_net.state_dict())
        print('loaded')
def build(self):
    self.policy_net = DQN2D(84, 84, self.pars, rec=self.pars['rec'] == 1).to(self.device)
    self.q_net = DQN2D(84, 84, self.pars).to(self.device)
    self.target_net = DQN2D(84, 84, self.pars).to(self.device)
    self.target_net.load_state_dict(self.q_net.state_dict())
    self.target_net.eval()
    if self.pars['momentum'] > 0:
        self.optimizer = optim.SGD(self.q_net.parameters(),
                                   lr=self.pars['lr'], momentum=self.pars['momentum'])
        self.policy_optimizer = optim.SGD(self.policy_net.parameters(),
                                          lr=self.pars['lr'], momentum=self.pars['momentum'])
    else:
        self.optimizer = optim.Adam(self.q_net.parameters())
        self.policy_optimizer = optim.Adam(self.policy_net.parameters())
    self.memory = ReplayMemory(10000)
    if self.pars['ppe'] == '1':
        self.memory = Memory(10000)
    self.eps_threshold = 0.01
class DQNAgent(object):

    def __init__(self, env, args, work_dir):
        self.env = env
        self.args = args
        self.work_dir = work_dir
        self.n_action = self.env.action_space.n
        self.arr_actions = np.arange(self.n_action)
        self.memory = ReplayMemory(self.args.buffer_size, self.args.device)
        self.qNetwork = ValueNetwork(self.n_action, self.env).to(self.args.device)
        self.targetNetwork = ValueNetwork(self.n_action, self.env).to(self.args.device)
        self.qNetwork.train()
        self.targetNetwork.eval()
        self.optimizer = optim.RMSprop(self.qNetwork.parameters(),
                                       lr=0.00025, eps=0.001, alpha=0.95)
        self.crit = nn.MSELoss()
        self.eps = max(self.args.eps, self.args.eps_min)
        self.eps_delta = (self.eps - self.args.eps_min) / self.args.exploration_decay_speed

    def reset(self):
        return torch.cat([preprocess_state(self.env.reset(), self.env)] * 4, 1)

    def select_action(self, state):
        # Epsilon-greedy distribution over actions, then one environment step.
        action_prob = np.zeros(self.n_action, np.float32)
        action_prob.fill(self.eps / self.n_action)
        max_q, max_q_index = self.qNetwork(Variable(state.to(
            self.args.device))).data.cpu().max(1)
        action_prob[max_q_index[0]] += 1 - self.eps
        action = np.random.choice(self.arr_actions, p=action_prob)
        next_state, reward, done, _ = self.env.step(action)
        next_state = torch.cat(
            [state.narrow(1, 1, 3), preprocess_state(next_state, self.env)], 1)
        self.memory.push(
            (state, torch.LongTensor([int(action)]), torch.Tensor([reward]),
             next_state, torch.Tensor([done])))
        return next_state, reward, done, max_q[0]

    def run(self):
        state = self.reset()
        # Fill the replay buffer before learning starts.
        for _ in range(self.args.buffer_init_size):
            next_state, _, done, _ = self.select_action(state)
            state = self.reset() if done else next_state
        total_frame = 0
        reward_list = np.zeros(self.args.log_size, np.float32)
        qval_list = np.zeros(self.args.log_size, np.float32)
        start_time = time.time()
        for epi in count():
            reward_list[epi % self.args.log_size] = 0
            qval_list[epi % self.args.log_size] = -1e9
            state = self.reset()
            done = False
            ep_len = 0
            if epi % self.args.save_freq == 0:
                model_file = os.path.join(self.work_dir, 'model.th')
                with open(model_file, 'wb') as f:
                    torch.save(self.qNetwork, f)
            while not done:
                # Periodically sync the target network with the online network.
                if total_frame % self.args.sync_period == 0:
                    self.targetNetwork.load_state_dict(self.qNetwork.state_dict())
                self.eps = max(self.args.eps_min, self.eps - self.eps_delta)
                next_state, reward, done, qval = self.select_action(state)
                reward_list[epi % self.args.log_size] += reward
                qval_list[epi % self.args.log_size] = max(
                    qval_list[epi % self.args.log_size], qval)
                state = next_state
                total_frame += 1
                ep_len += 1
                if ep_len % self.args.learn_freq == 0:
                    batch_state, batch_action, batch_reward, batch_next_state, batch_done = \
                        self.memory.sample(self.args.batch_size)
                    batch_q = self.qNetwork(batch_state).gather(
                        1, batch_action.unsqueeze(1)).squeeze(1)
                    batch_next_q = self.targetNetwork(batch_next_state).detach().max(
                        1)[0] * self.args.gamma * (1 - batch_done)
                    loss = self.crit(batch_q, batch_reward + batch_next_q)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
            output_str = ('episode %d frame %d time %.2fs cur_rew %.3f mean_rew %.3f '
                          'cur_maxq %.3f mean_maxq %.3f' % (
                              epi, total_frame, time.time() - start_time,
                              reward_list[epi % self.args.log_size], np.mean(reward_list),
                              qval_list[epi % self.args.log_size], np.mean(qval_list)))
            print(output_str)
            logging.info(output_str)
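# --- Usage sketch (assumption, not from the original repo) ----------------------
# DQNAgent only requires that `args` expose the attributes read above
# (buffer_size, device, eps, eps_min, exploration_decay_speed, buffer_init_size,
# log_size, save_freq, sync_period, learn_freq, batch_size, gamma). The values and
# the Atari env id below are illustrative guesses; preprocess_state/ValueNetwork
# are assumed to expect image observations from the classic gym API.
def run_dqn_agent_example(work_dir='.'):
    import gym
    from argparse import Namespace

    args = Namespace(buffer_size=100_000, device='cpu', eps=1.0, eps_min=0.1,
                     exploration_decay_speed=100_000, buffer_init_size=1_000,
                     log_size=100, save_freq=50, sync_period=1_000,
                     learn_freq=4, batch_size=32, gamma=0.99)
    env = gym.make('PongNoFrameskip-v4')
    agent = DQNAgent(env, args, work_dir)
    agent.run()  # trains until interrupted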
class AgentSep1D(Agent):

    def __init__(self, name, pars, nrenvs=1, job=None, experiment=None):
        Agent.__init__(self, name, pars, nrenvs, job, experiment)

    def build(self):
        self.policy_net1 = DQN(71, self.pars).to(self.device)
        self.target_net1 = DQN(71, self.pars).to(self.device)
        self.target_net1.load_state_dict(self.policy_net1.state_dict())
        self.target_net1.eval()
        self.policy_net2 = DQN(71, self.pars).to(self.device)
        self.target_net2 = DQN(71, self.pars).to(self.device)
        self.target_net2.load_state_dict(self.policy_net2.state_dict())
        self.target_net2.eval()
        self.optimizer1 = optim.SGD(self.policy_net1.parameters(),
                                    lr=self.pars['lr'], momentum=self.pars['momentum'])
        self.optimizer2 = optim.SGD(self.policy_net2.parameters(),
                                    lr=self.pars['lr'], momentum=self.pars['momentum'])
        # Adam alternative kept commented out: perturb_learning_rate() below
        # reads and writes the 'momentum' param group, which only SGD provides.
        # self.optimizer1 = optim.Adam(self.policy_net1.parameters())
        # self.optimizer2 = optim.Adam(self.policy_net2.parameters())
        self.memory2 = ReplayMemory(10000)
        self.memory1 = ReplayMemory(10000)

    def getaction(self, state1, state2, test=False):
        mes = torch.tensor([[0, 0, 0, 0]], device=self.device)
        comm2 = (self.policy_net1(state2, 0, mes)[self.idC].detach()
                 if np.random.rand() < self.prob else mes)
        comm1 = (self.policy_net2(state1, 0, mes)[self.idC].detach()
                 if np.random.rand() < self.prob else mes)
        if test:
            action1 = self.policy_net1(state1, 1, comm2)[0].max(1)[1].view(1, 1)
            action2 = self.policy_net2(state2, 1, comm1)[0].max(1)[1].view(1, 1)
        else:
            action1 = self.select_action(state1, comm2, self.policy_net1)
            action2 = self.select_action(state2, comm1, self.policy_net2)
        return action1, action2, [comm1, comm2]

    def getStates(self, env):
        screen1 = env.render_env_1d()  # .transpose((2, 0, 1))
        return (torch.from_numpy(screen1).unsqueeze(0).to(self.device),
                torch.from_numpy(screen1).unsqueeze(0).to(self.device))

    def saveStates(self, state1, state2, action1, action2, next_state1,
                   next_state2, reward1, reward2, env_id):
        self.capmem += 2
        if self.pars['ppe'] != '1':
            self.memory2.push(state2, action2, next_state2, reward2, state1)
            self.memory1.push(state1, action1, next_state1, reward1, state2)
        else:
            self.memory1.store([state1, action1, next_state1, reward1, state2])
            self.memory2.store([state2, action2, next_state2, reward2, state1])
            # self.memory2.push(state2, action2, next_state2, reward2, state1)
            # self.memory1.push(state1, action1, next_state1, reward1, state2)

    def optimize(self):
        self.optimize_model(self.policy_net1, self.target_net1, self.memory1,
                            self.optimizer1)
        self.optimize_model(self.policy_net2, self.target_net2, self.memory2,
                            self.optimizer2)

    def updateTarget(self, i_episode, step=False):
        # soft_update(self.target_net, self.policy_net, tau=0.01)
        if step:
            return
        if i_episode % self.TARGET_UPDATE == 0:
            self.target_net1.load_state_dict(self.policy_net1.state_dict())
            self.target_net2.load_state_dict(self.policy_net2.state_dict())

    def save(self):
        torch.save(self.policy_net1.state_dict(),
                   self.pars['results_path'] + self.name + '/model1')
        torch.save(self.policy_net2.state_dict(),
                   self.pars['results_path'] + self.name + '/model2')

    def perturb_learning_rate(self, i_episode, nolast=True):
        if nolast:
            new_lr_factor = 10**np.random.normal(scale=1.0)
            new_momentum_delta = np.random.normal(scale=0.1)
            self.EPS_DECAY += np.random.normal(scale=50.0)
            if self.EPS_DECAY < 50:
                self.EPS_DECAY = 50
            if self.prob >= 0:
                self.prob += np.random.normal(scale=0.05) - 0.025
                self.prob = min(max(0, self.prob), 1)
        for param_group in self.optimizer1.param_groups:
            if nolast:
                param_group['lr'] *= new_lr_factor
                param_group['momentum'] += new_momentum_delta
            self.momentum1 = param_group['momentum']
            self.lr1 = param_group['lr']
        if nolast:
            new_lr_factor = 10**np.random.normal(scale=1.0)
            new_momentum_delta = np.random.normal(scale=0.1)
        for param_group in self.optimizer2.param_groups:
            if nolast:
                param_group['lr'] *= new_lr_factor
                param_group['momentum'] += new_momentum_delta
            self.momentum2 = param_group['momentum']
            self.lr2 = param_group['lr']
        with open(os.path.join(self.pars['results_path'] + self.name,
                               'hyper-{}.json').format(i_episode), 'w') as outfile:
            json.dump({'lr1': self.lr1, 'momentum1': self.momentum1,
                       'lr2': self.lr2, 'momentum2': self.momentum2,
                       'eps_decay': self.EPS_DECAY, 'prob': self.prob,
                       'i_episode': i_episode}, outfile)

    def clone(self, agent):
        state_dict = agent.policy_net1.state_dict()
        self.policy_net1.load_state_dict(state_dict)
        state_dict = agent.optimizer1.state_dict()
        self.optimizer1.load_state_dict(state_dict)
        state_dict = agent.policy_net2.state_dict()
        self.policy_net2.load_state_dict(state_dict)
        state_dict = agent.optimizer2.state_dict()
        self.optimizer2.load_state_dict(state_dict)
        self.target_net1.load_state_dict(self.policy_net1.state_dict())
        self.target_net2.load_state_dict(self.policy_net2.state_dict())
        self.EPS_DECAY = agent.EPS_DECAY
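# --- Population-based-training sketch (assumption) ------------------------------
# clone()/perturb_learning_rate() above look like the exploit-and-explore hooks of
# a PBT-style outer loop. A minimal step, given a population of AgentSep1D
# instances and an externally computed score per agent, might look like this:
def pbt_step(population, scores, i_episode):
    ranked = [agent for _, agent in sorted(zip(scores, population),
                                           key=lambda pair: pair[0])]
    worst, best = ranked[0], ranked[-1]
    worst.clone(best)                       # copy weights, optimizer state, schedules
    worst.perturb_learning_rate(i_episode)  # then jitter lr/momentum/eps-decay/prob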
class Agent(object):

    def __init__(self, env, exploration_method='epsilon_greedy'):
        self.action_space = env.action_space
        self.buffer = ReplayMemory(BUFFER_SIZE, MINI_BATCH_SIZE)
        self.exploration_method = exploration_method
        self.brain = MLP(env.observation_space.shape[0], env.action_space.n)
        self.brain_bis = copy.deepcopy(self.brain)
        self.Tensor = torch.Tensor
        self.LongTensor = torch.LongTensor
        self.cpt = 0
        self.optimizer = optim.Adam(self.brain.parameters())

    def act(self, ob, reward, done):
        if self.exploration_method == 'epsilon_greedy':
            if random.random() < EPSILON:
                return self.action_space.sample()
            else:
                return self.get_best_action(ob)

    def learn(self, prev_ob, action, ob, reward, done):
        self.cpt += 1
        if self.cpt % 10 == 0:
            # Periodically refresh the target network.
            self.brain_bis = copy.deepcopy(self.brain)
        self.buffer.add(prev_ob, action, ob, reward, done)
        batch = self.buffer.get_minibatch()
        [states, actions, next_states, rewards, dones] = zip(*batch)
        state_batch = Variable(self.Tensor(states))
        action_batch = Variable(self.LongTensor(actions))
        reward_batch = Variable(self.Tensor(rewards))
        next_states_batch = Variable(self.Tensor(next_states))
        state_action_values = self.brain(state_batch).gather(
            1, action_batch.view(-1, 1)).view(-1)
        with torch.no_grad():
            next_state_values = self.brain_bis(next_states_batch).max(1)[0]
        for i in range(len(batch)):
            if dones[i]:
                next_state_values.data[i] = 0
        # Compute the expected Q values
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch
        loss = F.mse_loss(state_action_values, expected_state_action_values)
        self.optimizer.zero_grad()
        loss.backward()
        # Clip gradients (normalising by max value of gradient L2 norm)
        nn.utils.clip_grad_norm_(self.brain.parameters(), 10)
        self.optimizer.step()
        '''
        for interaction in minibatch:
            self.cpt += 1
            if self.cpt % 100:
                self.brain_bis = copy.deepcopy(self.brain)
            q_vals_pred = self.brain(torch.tensor(interaction[0]).float())
            q_vals_pred_next = self.brain_bis(torch.tensor(interaction[2]).float())
            q_vals = [None for _ in range(self.action_space.n)]
            if interaction[4]:
                for i in range(self.action_space.n):
                    q_vals[i] = (q_vals_pred[i] - interaction[3]) ** 2
            else:
                for i in range(self.action_space.n):
                    q_vals[i] = (q_vals_pred[i] - (interaction[3] + GAMMA *
                                 torch.max(q_vals_pred_next).item())) ** 2
            self.brain.train()
            loss = self.brain.loss(q_vals_pred, torch.tensor(q_vals).float())
            loss = Variable(loss, requires_grad=True)
            loss.backward()
        '''

    def get_q_val(self, ob, action):
        return self.get_q_vals(ob)[action]

    def get_best_action(self, ob):
        index = torch.argmax(self.get_q_vals(ob), 0)
        print(index)
        return index.int().item()

    def get_max_q_val(self, ob):
        val, _ = torch.max(self.get_q_vals(ob), 0)
        return val.float().item()

    def get_q_vals(self, ob):
        self.brain.eval()
        with torch.no_grad():
            ob = torch.tensor(ob).float().unsqueeze(0)
            print(ob.shape)
            output = self.brain(ob)
            print(output.shape)
            # Drop the batch dimension so callers can index/argmax over actions.
            return output.squeeze(0)
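# --- Usage sketch (assumption) ---------------------------------------------------
# Drives the epsilon-greedy Agent above with the classic gym API
# (reset() -> obs, step() -> (obs, reward, done, info)). EPSILON, GAMMA,
# BUFFER_SIZE and MINI_BATCH_SIZE are module-level constants the class expects,
# and ReplayMemory.get_minibatch() is assumed to be able to serve a batch from
# the very first step (otherwise warm the buffer up first).
def run_agent_example(episodes=10):
    import gym

    env = gym.make('CartPole-v1')
    agent = Agent(env)
    for _ in range(episodes):
        ob, reward, done = env.reset(), 0.0, False
        while not done:
            action = agent.act(ob, reward, done)
            next_ob, reward, done, _ = env.step(action)
            agent.learn(ob, action, next_ob, reward, done)
            ob = next_ob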
class AgentACShare1D(Agent):

    def __init__(self, name, pars, nrenvs=1, job=None, experiment=None):
        Agent.__init__(self, name, pars, nrenvs, job, experiment)

    def build(self):
        self.policy_net = DQN(71, self.pars).to(self.device)
        self.q_net = DQN(71, self.pars).to(self.device)
        self.target_net = DQN(71, self.pars).to(self.device)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()
        if self.pars['momentum'] > 0:
            self.optimizer = optim.SGD(self.q_net.parameters(),
                                       lr=self.pars['lr'], momentum=self.pars['momentum'])
            self.policy_optimizer = optim.SGD(self.policy_net.parameters(),
                                              lr=self.pars['lr'],
                                              momentum=self.pars['momentum'])
        else:
            self.optimizer = optim.Adam(self.q_net.parameters())
            self.policy_optimizer = optim.Adam(self.policy_net.parameters())
        self.memory = ReplayMemory(10000)
        self.eps_threshold = 0.01
        self.bufs = [[] for _ in range(len(self.envs) * 2)]

    def updateTarget(self, i_episode, step=False):
        # soft_update(self.target_net, self.policy_net, tau=0.01)
        if step:
            return
        self.optimize_policy(self.policy_net, self.bufs, self.policy_optimizer)
        if i_episode % self.TARGET_UPDATE == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())
            self.eps_threshold -= 0.001

    def saveStates(self, state1, state2, action1, action2, next_state1,
                   next_state2, reward1, reward2, env_id):
        logp1, ent1, logp2, ent2 = self.rem
        if self.pars['ppe'] != '1':
            self.memory.push(state2, action2, next_state2, reward2, state1)
            self.memory.push(state1, action1, next_state1, reward1, state2)
        else:
            self.memory.store([state1, action1, next_state1, reward1, state2])
            self.memory.store([state2, action2, next_state2, reward2, state1])
        # self.buf2.append([state2, action2, 1, reward2, logp2, ent2])
        # self.buf1.append([state1, action1, 1, reward1, logp1, ent1])
        self.bufs[2 * env_id].append([state2, action2, 1, reward2, logp2, ent2])
        self.bufs[2 * env_id + 1].append([state1, action1, 1, reward1, logp1, ent1])

    def select_action(self, state, comm, policy_net):
        probs1, _ = policy_net(state, 1, comm)  # .cpu().data.numpy()
        m = Categorical(logits=probs1)
        action = m.sample()
        return action.view(1, 1), m.log_prob(action), m.entropy()

    def getComm(self, mes, policy_net, state1_batch):
        return (self.policy_net(state1_batch, 1, mes)[self.idC].detach()
                if np.random.rand() < self.prob else mes)

    def getaction(self, state1, state2, test=False):
        mes = torch.tensor([[0, 0, 0, 0]], device=self.device)  # maybe error
        comm2 = (self.policy_net(state2, 0, mes)[self.idC]
                 if (test and 0 < self.prob) or np.random.rand() < self.prob else mes)
        comm1 = (self.policy_net(state1, 0, mes)[self.idC]
                 if (test and 0 < self.prob) or np.random.rand() < self.prob else mes)
        action1, logp1, ent1 = self.select_action(state1, comm2, self.policy_net)
        action2, logp2, ent2 = self.select_action(state2, comm1, self.policy_net)
        self.rem = [logp1, ent1, logp2, ent2]
        return action1, action2, [comm1, comm2]

    def optimize_policy(self, policy_net, memories, optimizer):
        policy_loss = 0
        value_loss = 0
        ent = 0
        for memory in memories:  # [memory1, memory2]
            R = torch.zeros(1, 1, device=self.device)
            # GAE = torch.zeros(1, 1, device=self.device)
            saved_r = torch.cat([c[3].float() for c in memory])
            states = torch.cat([c[0].float() for c in memory])
            action_batch = torch.cat([c[1].float() for c in memory]).view(-1, 1)
            mes = torch.tensor([[0, 0, 0, 0] for i in memory], device=self.device)
            actionV = self.q_net(states, 0, mes)[0].gather(1, action_batch.long())
            mu = saved_r.mean()
            std = saved_r.std()
            eps = 0.000001
            for i in reversed(range(len(memory) - 1)):
                _, _, _, r, log_prob, entr = memory[i]
                # Normalised critic value for this state-action pair
                # (also uses mu and std).
                ac = (actionV[i] - mu) / (std + eps)
                # Discounted sum of future rewards + reward for the given state.
                R = self.GAMMA * R + (r.float() - mu) / (std + eps)
                advantage = R - ac
                policy_loss += -log_prob * advantage.detach()
                # ent += entr  # * 0
        optimizer.zero_grad()
        (policy_loss.mean() + self.eps_threshold * ent).backward()
        for param in policy_net.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        optimizer.step()

    def save(self):
        torch.save(self.policy_net.state_dict(),
                   self.pars['results_path'] + self.name + '/model')
        torch.save(self.q_net.state_dict(),
                   self.pars['results_path'] + self.name + '/modelQ')

    def load(self, PATH):
        map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.policy_net.load_state_dict(torch.load(PATH, map_location=map_location))
        self.q_net.load_state_dict(torch.load(PATH + 'Q', map_location=map_location))
        self.target_net.load_state_dict(self.q_net.state_dict())

    def optimize(self):
        self.optimize_model(self.q_net, self.target_net, self.memory, self.optimizer)

    def perturb_learning_rate(self, i_episode, nolast=True):
        if nolast:
            new_lr_factor = 10**np.random.normal(scale=1.0)
            new_momentum_delta = np.random.normal(scale=0.1)
            self.eps_threshold += np.random.normal(scale=0.1)
            self.alpha += np.random.normal(scale=0.1)
            if self.alpha > 1:
                self.alpha = 1
            if self.alpha < 0.5:
                self.alpha = 0.5
            if self.eps_threshold < 0:
                self.eps_threshold = 0.00001
            self.EPS_DECAY += np.random.normal(scale=50.0)
            if self.EPS_DECAY < 50:
                self.EPS_DECAY = 50
            if self.prob >= 0:
                self.prob += np.random.normal(scale=0.05) - 0.025
                self.prob = min(max(0, self.prob), 1)
        for param_group in self.optimizer.param_groups:
            if nolast:
                param_group['lr'] *= new_lr_factor
                param_group['momentum'] += new_momentum_delta
            self.momentum = param_group['momentum']
            self.lr = param_group['lr']
        if nolast:
            new_lr_factor = 10**np.random.normal(scale=1.0)
            new_momentum_delta = np.random.normal(scale=0.1)
        for param_group in self.policy_optimizer.param_groups:
            if nolast:
                param_group['lr'] *= new_lr_factor
                param_group['momentum'] += new_momentum_delta
            self.momentum1 = param_group['momentum']
            self.lr1 = param_group['lr']
        with open(os.path.join(self.pars['results_path'] + self.name,
                               'hyper-{}.json').format(i_episode), 'w') as outfile:
            json.dump({'lr': self.lr, 'momentum': self.momentum, 'alpha': self.alpha,
                       'lr1': self.lr1, 'momentum1': self.momentum1,
                       'eps_decay': self.EPS_DECAY,
                       'eps_entropy': self.eps_threshold,
                       'prob': self.prob, 'i_episode': i_episode}, outfile)

    def clone(self, agent):
        state_dict = agent.policy_net.state_dict()
        self.policy_net.load_state_dict(state_dict)
        state_dict = agent.policy_optimizer.state_dict()
        self.policy_optimizer.load_state_dict(state_dict)
        self.alpha = agent.alpha
        state_dict = agent.q_net.state_dict()
        self.q_net.load_state_dict(state_dict)
        state_dict = agent.optimizer.state_dict()
        self.optimizer.load_state_dict(state_dict)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.EPS_DECAY = agent.EPS_DECAY
        self.eps_threshold = agent.eps_threshold
        self.prob = agent.prob
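# --- Checkpoint sketch (assumption) ----------------------------------------------
# save() above writes the policy to <results_path><name>/model and the critic to
# the same path with a 'Q' suffix, which is exactly the layout load() expects, so
# a save/restore round trip between two agents built with the same `pars` is:
def restore_from(agent_src, agent_dst):
    agent_src.save()
    agent_dst.load(agent_src.pars['results_path'] + agent_src.name + '/model')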
class SAC:

    def __init__(self, env, lr=3e-4, gamma=0.99, polyak=5e-3, alpha=0.2,
                 reward_scale=1.0, cuda=True, writer=None):
        state_size = env.observation_space.shape[0]
        action_size = env.action_space.shape[0]
        self.actor = Actor(state_size, action_size)
        self.critic = Critic(state_size, action_size)
        self.target_critic = Critic(state_size, action_size).eval()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.q1_optimizer = optim.Adam(self.critic.q1.parameters(), lr=lr)
        self.q2_optimizer = optim.Adam(self.critic.q2.parameters(), lr=lr)
        self.target_critic.load_state_dict(self.critic.state_dict())
        for param in self.target_critic.parameters():
            param.requires_grad = False
        self.memory = ReplayMemory()
        self.gamma = gamma
        self.alpha = alpha
        self.polyak = polyak  # Always between 0 and 1, usually close to 1
        self.reward_scale = reward_scale
        self.writer = writer
        self.cuda = cuda
        if cuda:
            self.actor = self.actor.to('cuda')
            self.critic = self.critic.to('cuda')
            self.target_critic = self.target_critic.to('cuda')

    def explore(self, state):
        # Stochastic action for data collection.
        state = torch.tensor(state, dtype=torch.float).unsqueeze(0)
        if self.cuda:
            state = state.to('cuda')
        action, _, _ = self.actor.sample(state)
        # action, _ = self.actor(state)
        return action.cpu().detach().numpy().reshape(-1)

    def exploit(self, state):
        # Deterministic action for evaluation.
        state = torch.tensor(state, dtype=torch.float).unsqueeze(0)
        if self.cuda:
            state = state.to('cuda')
        _, _, action = self.actor.sample(state)
        return action.cpu().detach().numpy().reshape(-1)

    def store_step(self, state, action, next_state, reward, terminal):
        state = to_tensor_unsqueeze(state)
        if action.dtype == np.float32:
            action = torch.from_numpy(action)
        next_state = to_tensor_unsqueeze(next_state)
        reward = torch.from_numpy(np.array([reward]).astype(np.float64))
        terminal = torch.from_numpy(np.array([terminal]).astype(np.uint8))
        self.memory.push(state, action, next_state, reward, terminal)

    def target_update(self, target_net, net):
        # Polyak averaging: target <- (1 - polyak) * target + polyak * net
        for t, s in zip(target_net.parameters(), net.parameters()):
            # t.data.copy_(t.data * (1.0 - self.polyak) + s.data * self.polyak)
            t.data.mul_(1.0 - self.polyak)
            t.data.add_(self.polyak * s.data)

    def calc_target_q(self, next_states, rewards, terminals):
        with torch.no_grad():
            next_action, entropy, _ = self.actor.sample(next_states)  # penalty term
            next_q1, next_q2 = self.target_critic(next_states, next_action)
            next_q = torch.min(next_q1, next_q2) - self.alpha * entropy
            target_q = rewards * self.reward_scale + (1. - terminals) * self.gamma * next_q
        return target_q

    def calc_critic_loss(self, states, actions, next_states, rewards, terminals):
        q1, q2 = self.critic(states, actions)
        target_q = self.calc_target_q(next_states, rewards, terminals)
        q1_loss = torch.mean((q1 - target_q).pow(2))
        q2_loss = torch.mean((q2 - target_q).pow(2))
        return q1_loss, q2_loss

    def calc_actor_loss(self, states):
        action, entropy, _ = self.actor.sample(states)
        q1, q2 = self.critic(states, action)
        q = torch.min(q1, q2)
        # actor_loss = torch.mean(-q - self.alpha * entropy)
        actor_loss = (self.alpha * entropy - q).mean()
        return actor_loss, entropy

    def train(self, timestep, batch_size=256):
        if len(self.memory) < batch_size:
            return
        transitions = self.memory.sample(batch_size)
        transitions = Transition(*zip(*transitions))
        if self.cuda:
            states = torch.cat(transitions.state).to('cuda')
            actions = torch.stack(transitions.action).to('cuda')
            next_states = torch.cat(transitions.next_state).to('cuda')
            rewards = torch.stack(transitions.reward).to('cuda')
            terminals = torch.stack(transitions.terminal).to('cuda')
        else:
            states = torch.cat(transitions.state)
            actions = torch.stack(transitions.action)
            next_states = torch.cat(transitions.next_state)
            rewards = torch.stack(transitions.reward)
            terminals = torch.stack(transitions.terminal)
        # Compute critic losses against the target Q values
        q1_loss, q2_loss = self.calc_critic_loss(states, actions, next_states,
                                                 rewards, terminals)
        # Compute actor loss
        actor_loss, entropy = self.calc_actor_loss(states)
        update_params(self.q1_optimizer, self.critic.q1, q1_loss)
        update_params(self.q2_optimizer, self.critic.q2, q2_loss)
        update_params(self.actor_optimizer, self.actor, actor_loss)
        # Target update
        self.target_update(self.target_critic, self.critic)
        if timestep % 100 == 0 and self.writer:
            self.writer.add_scalar('Loss/Actor', actor_loss.item(), timestep)
            self.writer.add_scalar('Loss/Critic', q1_loss.item(), timestep)

    def save_weights(self, path):
        self.actor.save(os.path.join(path, 'actor'))
        self.critic.save(os.path.join(path, 'critic'))
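# --- Usage sketch (assumption) ----------------------------------------------------
# Trains the SAC agent above on a continuous-control task with the classic gym API.
# The env id, step count and cuda=False are illustrative choices; explore() is used
# for data collection, exploit() would be used for deterministic evaluation.
def run_sac_example(num_steps=10_000):
    import gym

    env = gym.make('Pendulum-v1')
    agent = SAC(env, cuda=False)
    state = env.reset()
    for t in range(num_steps):
        action = agent.explore(state)
        next_state, reward, done, _ = env.step(action)
        agent.store_step(state, action, next_state, reward, float(done))
        agent.train(t)  # no-op until the replay memory holds a full batch
        state = env.reset() if done else next_state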