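# --- Added sketch (hypothetical, not part of the original agents): the classes
# --- below assume the usual imports (os, csv, pickle, numpy as np, torch,
# --- torch.optim as optim) plus project modules providing BaseAgent, DQN,
# --- AtariBody and the replay buffers. The stub here only illustrates the
# --- constructor signature and the sample_noise() hook that declare_networks()
# --- and the loss/action code rely on; the real DQN presumably swaps in noisy
# --- (factorised-Gaussian) linear layers when noisy=True.
import torch
import torch.nn as nn


class AtariBodySketch(nn.Module):
    def __init__(self, input_shape):
        super(AtariBodySketch, self).__init__()
        # standard Atari convolutional trunk for channel-first image observations
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU())

    def forward(self, x):
        return self.conv(x).view(x.size(0), -1)


class DQNSketch(nn.Module):
    def __init__(self, input_shape, num_actions, noisy=False, sigma_init=0.5, body=AtariBodySketch):
        # noisy / sigma_init are accepted for signature compatibility with the
        # agents' declare_networks() calls but unused in this noise-free sketch
        super(DQNSketch, self).__init__()
        self.body = body(input_shape)
        with torch.no_grad():
            feature_dim = self.body(torch.zeros(1, *input_shape)).size(1)
        self.head = nn.Sequential(nn.Linear(feature_dim, 512), nn.ReLU(), nn.Linear(512, num_actions))

    def forward(self, x):
        return self.head(self.body(x))

    def sample_noise(self):
        # no-op here: the noisy-net layers are omitted from this sketch
        pass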
class Model(BaseAgent):
    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):
        super(Model, self).__init__(config=config, env=env, log_dir=log_dir)
        self.device = config.device

        self.noisy = config.USE_NOISY_NETS
        self.priority_replay = config.USE_PRIORITY_REPLAY

        self.gamma = config.GAMMA
        self.lr = config.LR
        self.target_net_update_freq = config.TARGET_NET_UPDATE_FREQ
        self.experience_replay_size = config.EXP_REPLAY_SIZE
        self.batch_size = config.BATCH_SIZE
        self.learn_start = config.LEARN_START
        self.update_freq = config.UPDATE_FREQ
        self.sigma_init = config.SIGMA_INIT
        self.priority_beta_start = config.PRIORITY_BETA_START
        self.priority_beta_frames = config.PRIORITY_BETA_FRAMES
        self.priority_alpha = config.PRIORITY_ALPHA

        self.static_policy = static_policy
        self.num_feats = env.observation_space.shape
        self.num_actions = env.action_space.n
        self.env = env

        self.declare_networks()
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        # move to correct device
        self.model = self.model.to(self.device)
        self.target_model.to(self.device)

        if self.static_policy:
            self.model.eval()
            self.target_model.eval()
        else:
            self.model.train()
            self.target_model.train()

        self.update_count = 0

        self.declare_memory()

        self.nsteps = config.N_STEPS
        self.nstep_buffer = []

    def declare_networks(self):
        self.model = DQN(self.num_feats, self.num_actions, noisy=self.noisy, sigma_init=self.sigma_init, body=AtariBody)
        self.target_model = DQN(self.num_feats, self.num_actions, noisy=self.noisy, sigma_init=self.sigma_init, body=AtariBody)

    def declare_memory(self):
        if not self.priority_replay:
            self.memory = ExperienceReplayMemory(self.experience_replay_size)
        else:
            self.memory = PrioritizedReplayMemory(self.experience_replay_size, self.priority_alpha,
                                                  self.priority_beta_start, self.priority_beta_frames)

    def append_to_replay(self, s, a, r, s_):
        self.nstep_buffer.append((s, a, r, s_))

        if len(self.nstep_buffer) < self.nsteps:
            return

        R = sum([self.nstep_buffer[i][2] * (self.gamma**i) for i in range(self.nsteps)])
        state, action, _, _ = self.nstep_buffer.pop(0)

        self.memory.push((state, action, R, s_))

    def prep_minibatch(self):
        # random transition batch is taken from experience replay memory
        transitions, indices, weights = self.memory.sample(self.batch_size)

        batch_state, batch_action, batch_reward, batch_next_state = zip(*transitions)

        shape = (-1, ) + self.num_feats

        batch_state = torch.tensor(batch_state, device=self.device, dtype=torch.float).view(shape)
        batch_action = torch.tensor(batch_action, device=self.device, dtype=torch.long).squeeze().view(-1, 1)
        batch_reward = torch.tensor(batch_reward, device=self.device, dtype=torch.float).squeeze().view(-1, 1)

        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch_next_state)), device=self.device, dtype=torch.uint8)
        try:  # sometimes all next states are false
            non_final_next_states = torch.tensor([s for s in batch_next_state if s is not None],
                                                 device=self.device, dtype=torch.float).view(shape)
            empty_next_state_values = False
        except:
            non_final_next_states = None
            empty_next_state_values = True

        return batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights

    def compute_loss(self, batch_vars):  # faster
        batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights = batch_vars

        # estimate
        self.model.sample_noise()
        current_q_values = self.model(batch_state).gather(1, batch_action)

        # target
        with torch.no_grad():
            max_next_q_values = torch.zeros(self.batch_size, device=self.device, dtype=torch.float).unsqueeze(dim=1)
            if not empty_next_state_values:
                max_next_action = self.get_max_next_state_action(non_final_next_states)
                self.target_model.sample_noise()
                max_next_q_values[non_final_mask] = self.target_model(non_final_next_states).gather(1, max_next_action)
            expected_q_values = batch_reward + ((self.gamma**self.nsteps) * max_next_q_values)

        diff = (expected_q_values - current_q_values)
        if self.priority_replay:
            self.memory.update_priorities(indices, diff.detach().squeeze().abs().cpu().numpy().tolist())
            loss = self.MSE(diff).squeeze() * weights
        else:
            loss = self.MSE(diff)
        loss = loss.mean()

        return loss

    def update(self, s, a, r, s_, frame=0):
        if self.static_policy:
            return None

        self.append_to_replay(s, a, r, s_)

        if frame < self.learn_start or frame % self.update_freq != 0:
            return None

        batch_vars = self.prep_minibatch()

        loss = self.compute_loss(batch_vars)

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        for group in self.optimizer.param_groups:
            for p in group['params']:
                state = self.optimizer.state[p]
                if 'step' in state and state['step'] >= 1024:
                    state['step'] = 1000
        for param in self.model.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        self.update_target_model()
        self.save_td(loss.item(), frame)
        self.save_sigma_param_magnitudes(frame)

    def get_action(self, s, eps=0.1):  # faster
        with torch.no_grad():
            if np.random.random() >= eps or self.static_policy or self.noisy:
                X = torch.tensor([s], device=self.device, dtype=torch.float)
                self.model.sample_noise()
                a = self.model(X).max(1)[1].view(1, 1)
                return a.item()
            else:
                return np.random.randint(0, self.num_actions)

    def update_target_model(self):
        self.update_count += 1
        self.update_count = self.update_count % self.target_net_update_freq
        if self.update_count == 0:
            self.target_model.load_state_dict(self.model.state_dict())

    def get_max_next_state_action(self, next_states):
        return self.target_model(next_states).max(dim=1)[1].view(-1, 1)

    def finish_nstep(self):
        while len(self.nstep_buffer) > 0:
            R = sum([self.nstep_buffer[i][2] * (self.gamma**i) for i in range(len(self.nstep_buffer))])
            state, action, _, _ = self.nstep_buffer.pop(0)

            self.memory.push((state, action, R, None))

    def reset_hx(self):
        pass
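# --- Added sketch (hypothetical, not part of the original file): both Model and
# --- DQNAgent only assume a replay object exposing push(transition) and
# --- sample(batch_size) -> (transitions, indices, weights); indices and weights
# --- are consumed only on the prioritized path. A minimal uniform buffer that
# --- satisfies that contract could look like the following; the real
# --- ExperienceReplayMemory / ExperienceReplayBuffer classes are imported from
# --- the project's replay module.
import random


class UniformReplaySketch(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []

    def push(self, transition):
        # transition = (state, action, n_step_return, next_state or None)
        self.buffer.append(transition)
        if len(self.buffer) > self.capacity:
            del self.buffer[0]

    def sample(self, batch_size):
        # uniform sampling: no importance-sampling weights, no tree indices
        return random.sample(self.buffer, batch_size), None, None

    def __len__(self):
        return len(self.buffer)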
class DQNAgent(object):
    def __init__(self, config=None, env=None, log_dir=None, static_policy=False):
        # Train or Test
        self.static_policy = static_policy

        # Tricks Flags
        self.noisy = config.USE_NOISY_NETS
        self.priority_replay = config.USE_PRIORITY_REPLAY
        self.nsteps = config.N_STEPS

        # Tricks Parameters
        self.sigma_init = config.SIGMA_INIT
        self.alpha = config.PRIORITY_ALPHA
        self.priority_beta_start = config.PRIORITY_BETA_START
        self.priority_beta_frames = config.PRIORITY_BETA_FRAMES
        self.nstep_buffer = []

        # Categorical-DQN
        self.atoms = config.ATOMS
        self.v_max = config.V_MAX
        self.v_min = config.V_MIN

        # QR-DQN
        self.quantiles = config.QUANTILES

        # Device
        self.device = config.device

        # Memory
        self.replay_buffer_size = config.REPLAY_BUFFER_SIZE

        # LR & BATCH_SIZE & Discount
        self.lr = config.LR
        self.batch_size = config.BATCH_SIZE
        self.gamma = config.GAMMA

        # Learn Procedure
        self.max_frames = config.MAX_FRAMES
        self.learn_start = config.LEARN_START
        self.update_freq = config.CURRENT_NET_UPDATE_FREQUENCY
        self.target_update_freq = config.TARGET_NET_UPDATE_FREQUENCY
        self.update_count = 0

        # Log Info
        self.action_log_frequency = config.ACTION_SELECTION_COUNT_FREQUENCY
        self.log_dir = log_dir
        self.rewards = []
        self.action_selections = [0 for _ in range(env.action_space.n)]

        # Exploration policy
        self.epsilon_start = config.EPSILON_START
        self.epsilon_final = config.EPSILON_FINAL
        self.epsilon_decay = config.EPSILON_DECAY
        self.epsilon_by_frame = config.EPSILON_BY_FRAME

        # Env
        self.env = env
        self.input_shape = env.observation_space.shape
        self.num_actions = env.action_space.n

        # Construct Entities
        self.declare_memory()
        self.declare_networks()

        # Network Initialization & Optimizer & Movement
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.model = self.model.to(self.device)
        self.target_model.to(self.device)

        # Train or Test Function
        if self.static_policy:
            self.model.eval()
            self.target_model.eval()
        else:
            self.model.train()
            self.target_model.train()

    def declare_memory(self):
        if not self.priority_replay:
            self.memory = ExperienceReplayBuffer(self.replay_buffer_size)
        else:
            pass

    def declare_networks(self):
        self.model = DQN(self.input_shape, self.num_actions, self.noisy, self.sigma_init, body=AtariBody)
        self.target_model = DQN(self.input_shape, self.num_actions, self.noisy, self.sigma_init, body=AtariBody)

    def loss_function(self, loss_name, x):
        if loss_name == 'huber':
            cond = (x.abs() < 1.0).float().detach()
            return 0.5 * x.pow(2) * cond + (x.abs() - 0.5) * (1.0 - cond)
        if loss_name == 'MSE':
            return 0.5 * x.pow(2)

    ###################### Main Update Step #######################
    def update(self, s, a, r, s_, frame_idx):
        # train or test
        if self.static_policy:
            return None

        # store in memory
        self.append_to_replay(s, a, r, s_)

        # when to learn & how often we learn
        if frame_idx < self.learn_start or frame_idx % self.update_freq != 0:
            return None

        # prepare data from memory
        batch_vars = self.prep_minibatch()

        # compute the task-specific loss
        loss = self.compute_loss(batch_vars)

        # gradient update
        self.optimizer.zero_grad()
        loss.backward()
        for parameter in self.model.parameters():
            parameter.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        # when to update the target
        self.update_target_model()

        # save some info like the TD-error
        self.save_td(loss.item(), frame_idx)
        self.save_sigma_param_magnitudes(frame_idx)

    ############################### Utility Function ################
    def append_to_replay(self, s, a, r, s_):
        self.nstep_buffer.append((s, a, r, s_))

        if len(self.nstep_buffer) < self.nsteps:
            return

        R = sum([self.nstep_buffer[i][2] * (self.gamma**i) for i in range(self.nsteps)])
        state, action, _, _ = self.nstep_buffer.pop(0)

        self.memory.push((state, action, R, s_))

    def prep_minibatch(self):
        transitions, indices, weights = self.memory.sample(self.batch_size)

        batch_states, batch_actions, batch_rewards, batch_next_states = zip(*transitions)

        batch_states_shape = (-1, ) + self.input_shape

        batch_states = torch.tensor(batch_states, device=self.device, dtype=torch.float).view(batch_states_shape)
        batch_actions = torch.tensor(batch_actions, device=self.device, dtype=torch.long).view(-1, 1)
        batch_rewards = torch.tensor(batch_rewards, device=self.device, dtype=torch.float).view(-1, 1)

        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch_next_states)), device=self.device, dtype=torch.bool)
        try:  # sometimes all next states are false
            non_final_next_states = torch.tensor([s for s in batch_next_states if s is not None],
                                                 device=self.device, dtype=torch.float).view(batch_states_shape)
            empty_next_state_values = False
        except:
            non_final_next_states = None
            empty_next_state_values = True

        return batch_states, batch_actions, batch_rewards, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights

    def compute_loss(self, batch_vars):
        batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values, indices, weights = batch_vars

        # current-q-values
        self.model.sample_noise()
        current_q_values = self.model(batch_state).gather(1, batch_action)

        # target-q-values
        with torch.no_grad():
            max_next_q_values = torch.zeros(self.batch_size, device=self.device, dtype=torch.float).unsqueeze(dim=1)
            if not empty_next_state_values:
                # get_max_next_state_action
                # action selection comes from target model (not double)
                max_next_action = self.get_max_next_state_action(non_final_next_states)
                self.target_model.sample_noise()
                max_next_q_values[non_final_mask] = self.target_model(non_final_next_states).gather(1, max_next_action)
            target_q_values = batch_reward + (self.gamma**self.nsteps) * max_next_q_values

        diff = target_q_values - current_q_values
        loss = self.loss_function('MSE', diff)
        loss = loss.mean()

        return loss

    def get_max_next_state_action(self, non_final_next_states):
        max_next_action = self.target_model(non_final_next_states).max(dim=1)[1].view(-1, 1)
        return max_next_action

    def update_target_model(self):
        self.update_count += 1
        if self.update_count % self.target_update_freq == 0:
            self.target_model.load_state_dict(self.model.state_dict())
            self.update_count = 0

    def get_action(self, s, eps=0.1):
        with torch.no_grad():
            if np.random.random() >= eps or self.static_policy or self.noisy:
                X = torch.tensor([s], device=self.device, dtype=torch.float)
                self.model.sample_noise()
                a = self.model(X).max(1)[1].view(1, 1)
                return a.item()
            else:
                return np.random.randint(0, self.num_actions)

    def finish_nstep(self):
        while len(self.nstep_buffer) > 0:
            R = sum([self.nstep_buffer[i][2] * (self.gamma**i) for i in range(len(self.nstep_buffer))])
            state, action, _, _ = self.nstep_buffer.pop(0)

            self.memory.push((state, action, R, None))

    ################################ save & load & log ############################
    def save_weight(self):
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, 'model.dump'))
        torch.save(self.optimizer.state_dict(), os.path.join(self.log_dir, 'optim.dump'))

    def load_weight(self):
        fname_model = os.path.join(self.log_dir, 'model.dump')
        fname_optim = os.path.join(self.log_dir, 'optim.dump')

        if os.path.isfile(fname_model):
            self.model.load_state_dict(torch.load(fname_model))
            self.target_model.load_state_dict(self.model.state_dict())

        if os.path.isfile(fname_optim):
            self.optimizer.load_state_dict(torch.load(fname_optim))

    def save_replay(self):
        pickle.dump(self.memory, open(os.path.join(self.log_dir, 'exp_replay_agent.dump'), 'wb'))

    def load_replay(self):
        fname = os.path.join(self.log_dir, 'exp_replay_agent.dump')
        if os.path.isfile(fname):
            self.memory = pickle.load(open(fname, 'rb'))

    def save_td(self, loss, timestep):
        with open(os.path.join(self.log_dir, 'td.csv'), 'a') as f:
            writer = csv.writer(f)
            writer.writerow((timestep, loss))

    def save_reward(self, episode_reward):
        self.rewards.append(episode_reward)

    def save_action(self, action, tstep):
        self.action_selections[int(action)] += 1.0 / self.action_log_frequency
        if (tstep + 1) % self.action_log_frequency == 0:
            with open(os.path.join(self.log_dir, 'action_log.csv'), 'a') as f:
                writer = csv.writer(f)
                writer.writerow(list([tstep] + self.action_selections))
            self.action_selections = [0 for _ in range(len(self.action_selections))]

    def save_sigma_param_magnitudes(self, tstep):
        with torch.no_grad():
            sum_, count = 0.0, 0.0
            for name, param in self.model.named_parameters():
                if param.requires_grad and 'sigma' in name:
                    sum_ += torch.sum(param.abs()).item()
                    count += np.prod(param.shape)
            if count > 0:
                with open(os.path.join(self.log_dir, 'sig_param_mag.csv'), 'a') as f:
                    writer = csv.writer(f)
                    writer.writerow((tstep, sum_ / count))
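# --- Added usage sketch (hypothetical, not part of the original file): one way
# --- to drive DQNAgent with a classic gym-style env whose step() returns
# --- (obs, reward, done, info). Terminal next states are stored as None so
# --- that prep_minibatch() can mask them out.
def train_sketch(agent, env):
    episode_reward = 0.0
    s = env.reset()
    for frame_idx in range(1, agent.max_frames + 1):
        eps = agent.epsilon_by_frame(frame_idx)  # assumed callable: frame index -> epsilon
        a = agent.get_action(s, eps)
        agent.save_action(a, frame_idx)

        s_, r, done, _ = env.step(a)
        agent.update(s, a, r, None if done else s_, frame_idx)
        episode_reward += r
        s = s_

        if done:
            agent.finish_nstep()              # flush partial n-step returns
            agent.save_reward(episode_reward)
            episode_reward = 0.0
            s = env.reset()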