def test_step(self):
    """Play random episodes and track how much memory the replay buffer uses."""
    # self.mse.reset(video_log_file='test_multi_step_env_from_env')
    # vl = VideoLogger('test_multi_step_env', self.height, self.width)
    replay_memory = ReplayBuffer(capacity=40000)
    mem_size = 0.0
    for eps in range(50):
        obs = self.mse.reset()
        total_reward = 0.0
        done = False
        n = 0
        # all_objects = muppy.get_objects()
        # sum1 = summary.summarize(all_objects)
        # summary.print_(sum1)
        while not done:
            action_idx = np.random.randint(len(self.action_dict))
            next_obs, reward, done, _ = self.mse.step(action_idx)
            total_reward += reward
            n += 1
            replay_memory.insert(obs, action_idx, reward, next_obs, done)
            mem_size += obs.nbytes + next_obs.nbytes
            # print(obs.shape, obs.dtype, obs.nbytes / (1024.0 * 1024.0))
            # print(f'n = {n}, total_reward = {total_reward}, done = {done}, action = {action_idx}')
            if done:
                break
            obs = next_obs
        # Report buffer size in MiB and the projected percentage of a 48 GiB budget.
        print(f'eps {eps} done...', replay_memory.size,
              mem_size / (1024.0 * 1024.0),
              100 * (mem_size / 48) / (1024.0 * 1024.0 * 1024.0))
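# --- Hedged sketch (not from the original source) ----------------------------
# The test above and the DQN class further down assume a ReplayBuffer exposing
# ReplayBuffer(capacity=...), insert(obs, action_idx, reward, next_obs, done),
# a `size` attribute and sample(batch_size) returning numpy batches. The real
# implementation lives elsewhere in this repo; the class below is only a
# minimal sketch of that assumed interface (it assumes numpy is imported as np).
class SketchReplayBuffer:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.buffer = []   # stored (obs, action_idx, reward, next_obs, done) tuples
        self.idx = 0       # next write position, wraps around at capacity

    @property
    def size(self):
        return len(self.buffer)

    def insert(self, obs, action_idx, reward, next_obs, done):
        transition = (obs, action_idx, reward, next_obs, done)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.idx] = transition
        self.idx = (self.idx + 1) % self.capacity

    def sample(self, batch_size: int):
        idxs = np.random.randint(0, len(self.buffer), size=batch_size)
        obs, actions, rewards, next_obs, dones = zip(*(self.buffer[i] for i in idxs))
        return (np.stack(obs), np.array(actions), np.array(rewards, dtype=np.float32),
                np.stack(next_obs), np.array(dones, dtype=np.float32))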
@dataclass  # __post_init__ is only invoked for dataclasses; the decorator is assumed here
class EpisodeManager:
    env_reg_name: str
    buffer_size: int

    def __post_init__(self):
        self.env = gym.make(self.env_reg_name)
        self.num_steps = 0
        self.episodes = 0
        self.replay_buffer = ReplayBuffer(self.buffer_size)

    def play_episode(self, frame_wrapper: Frame, render=False, model: Callable = None):
        prev_obs = frame_wrapper.reset(self.env)
        done = False
        while not done:
            if render:
                self.env.render()
            if model is not None:
                action = model(prev_obs)
            else:
                action = self.env.action_space.sample()
            obs, reward, done, info = self.step(action, frame_wrapper)
            self.replay_buffer.insert(
                StateTransition(prev_obs, action, obs, reward, done, info))
            prev_obs = obs
            self.num_steps += 1
        self.episodes += 1

    def step(self, action: int, frame_wrapper: Frame):
        if frame_wrapper is not None:
            out = frame_wrapper.step(self.env, action=action)
        else:
            out = self.env.step(action)
        return out
def train(env, critic_net, target_net, policy, total_steps=10**6, lr=3e-4,
          gamma=0.99, polyak=0.995, batch_size=256, seed=0):
    # NOTE: `seed` was referenced but never defined in the original excerpt;
    # it is exposed here as a keyword argument.
    update_interval = 1
    target_entropy = -env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    env.seed(seed)
    env.action_space.np_random.seed(seed)

    # The temperature (entropy coefficient) is learned via its log to stay positive.
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha = log_alpha.exp()
    temp_optimizer = torch.optim.Adam([log_alpha], lr=lr)

    replay_buffer = ReplayBuffer(state_dim, action_dim)
    freq_print = 1000
    start_after = 10000
    update_after = 1000
    update_times = 1
    episode_rewards = []
    state = env.reset()
    episode_reward = 0
    q1_loss, q2_loss, policy_loss_ = 0., 0., 0.

    for step in range(total_steps):
        if step % 10**4 == 0 and step > 1:
            torch.save(policy, "Policy" + str(step) + ".pt")

        with torch.no_grad():
            if step < start_after:
                action = env.action_space.sample()
            else:
                action = get_action(state, policy)

        next_state, reward, done, _ = env.step(action)
        episode_reward += reward
        replay_buffer.insert(state, action, reward, next_state, done)
        state = next_state

        if done:
            state = env.reset()
            episode_rewards.append(episode_reward)
            episode_reward = 0

        if step % freq_print == 0 and step > 1:
            running_reward = np.mean(episode_rewards)
            print(step, np.min(episode_rewards), running_reward,
                  np.max(episode_rewards), q1_loss + q2_loss, policy_loss_, alpha)
            log = [step, np.min(episode_rewards), running_reward,
                   np.max(episode_rewards), q1_loss + q2_loss, policy_loss_]
            log = [str(x) for x in log]
            with open("plt_file", "a") as s:
                s.write(" ".join(log) + "\n")
            episode_rewards = []

        if step > update_after:  # and step % update_interval == 0:
            for update_count in range(update_times):
                batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones = \
                    replay_buffer.get_batch(batch_size)
                batch_states = torch.from_numpy(batch_states).float()
                batch_actions = torch.from_numpy(batch_actions).float()
                batch_rewards = torch.from_numpy(batch_rewards).float()
                batch_next_states = torch.from_numpy(batch_next_states).float()
                batch_dones = torch.from_numpy(batch_dones).float()

                # Critic update: clipped double-Q target with entropy bonus.
                with torch.no_grad():
                    next_actions, logprob_next_actions, _ = policy(batch_next_states)
                    q_t1, q_t2 = target_net(batch_next_states, next_actions)
                    q_target = torch.min(q_t1, q_t2)
                    critic_target = batch_rewards + (1.0 - batch_dones) * gamma * (
                        q_target - alpha * logprob_next_actions)

                q_1, q_2 = critic_net(batch_states, batch_actions)
                loss_1 = torch.nn.MSELoss()(q_1, critic_target)
                loss_2 = torch.nn.MSELoss()(q_2, critic_target)
                q_loss_step = loss_1 + loss_2

                critic_net.optimizer.zero_grad()
                q_loss_step.backward()
                critic_net.optimizer.step()

                q1_loss = loss_1.detach().item()
                q2_loss = loss_2.detach().item()

                # Freeze critic parameters while updating the policy and temperature.
                for p in critic_net.parameters():
                    p.requires_grad = False

                policy_action, log_prob_policy_action, _ = policy(batch_states)
                p1, p2 = critic_net(batch_states, policy_action)
                target = torch.min(p1, p2)
                policy_loss = (alpha * log_prob_policy_action - target).mean()

                policy.optimizer.zero_grad()
                policy_loss.backward()
                policy.optimizer.step()

                # Temperature update toward the target entropy.
                temp_loss = -log_alpha * (log_prob_policy_action.detach() +
                                          target_entropy).mean()
                temp_optimizer.zero_grad()
                temp_loss.backward()
                temp_optimizer.step()

                for p in critic_net.parameters():
                    p.requires_grad = True

                alpha = log_alpha.exp()
                policy_loss_ = policy_loss.detach().item()

                # Polyak-average the critic into the target network.
                with torch.no_grad():
                    for target_q_param, q_param in zip(target_net.parameters(),
                                                       critic_net.parameters()):
                        target_q_param.data.copy_((1 - polyak) * q_param.data +
                                                  polyak * target_q_param.data)
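# --- Hedged helper sketch (not part of the original file) ---------------------
# The Polyak averaging done inline at the end of train() can be factored into a
# standalone function; `polyak` keeps its meaning from above (the fraction of
# the old target weights retained per update).
def soft_update(target_net, source_net, polyak=0.995):
    with torch.no_grad():
        for target_param, source_param in zip(target_net.parameters(),
                                              source_net.parameters()):
            target_param.data.mul_(polyak)
            target_param.data.add_((1.0 - polyak) * source_param.data)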
class DQN(object):
    def __init__(self,
                 multi_step_env: MultiStepEnv = None,
                 gamma: float = None,
                 eps_max: float = None,
                 eps_min: float = None,
                 eps_decay_steps: int = None,
                 replay_min_size: int = None,
                 replay_max_size: int = None,
                 target_update_freq: int = None,
                 steps_per_update: int = None,
                 train_batch_size: int = None,
                 enable_rgb: bool = None,
                 model_save_file: str = None,
                 optim_l2_reg_coeff: float = None,
                 optim_lr: float = None,
                 eval_freq: int = None):
        self.env = multi_step_env
        self.gamma = gamma
        self.eps_max = eps_max
        self.eps_min = eps_min
        self.eps_decay_steps = eps_decay_steps
        self.replay_min_size = replay_min_size
        self.target_update_freq = target_update_freq
        self.train_batch_size = train_batch_size
        self.steps_per_update = steps_per_update
        self.model_save_file = model_save_file
        self.optim_lr = optim_lr
        self.optim_l2_reg_coeff = optim_l2_reg_coeff
        self.eval_freq = eval_freq
        self.replay_memory = ReplayBuffer(capacity=replay_max_size)
        self.n_steps = 0

        # RGB frame stacks carry 3 channels per frame; grayscale stacks carry 1.
        if enable_rgb:
            self.q_train = Q(self.env.frame_stack_size * 3, self.env.height,
                             self.env.width, self.env.num_actions).to(settings.device)
            self.q_target = Q(self.env.frame_stack_size * 3, self.env.height,
                              self.env.width, self.env.num_actions).to(settings.device)
        else:
            self.q_train = Q(self.env.frame_stack_size, self.env.height,
                             self.env.width, self.env.num_actions).to(settings.device)
            self.q_target = Q(self.env.frame_stack_size, self.env.height,
                              self.env.width, self.env.num_actions).to(settings.device)

        self.optimizer = Adam(self.q_train.parameters(),
                              eps=1e-7,
                              lr=self.optim_lr,
                              weight_decay=self.optim_l2_reg_coeff)
        # self.mse_loss = nn.MSELoss()
        assert (self.q_train.state_dict().keys() ==
                self.q_target.state_dict().keys())

    def _update_step_counter(self):
        self.n_steps += 1

    def copyAtoB(self, A, B, tau=None):
        for paramA, paramB in zip(A.parameters(), B.parameters()):
            paramB.data.copy_(paramA.data)

    def _update_q_target(self):
        if (self.n_steps % self.target_update_freq) == 0:
            self.copyAtoB(self.q_train, self.q_target)

    @staticmethod
    def normalize(obs):
        return (obs / 255.0 * 2 - 1)

    def _update_q_train(self):
        if self.replay_memory.size >= self.replay_min_size and (
                self.n_steps % self.steps_per_update) == 0:
            self.q_train.train()
            states, action_idxs, rewards, next_states, dones = \
                self.replay_memory.sample(self.train_batch_size)
            states = torch.from_numpy(states).float().to(
                settings.device).permute(0, 3, 1, 2)
            action_idxs = torch.from_numpy(action_idxs).long().to(settings.device)
            rewards = torch.from_numpy(rewards).float().to(settings.device)
            next_states = torch.from_numpy(next_states).float().to(
                settings.device).permute(0, 3, 1, 2)
            dones = torch.from_numpy(dones).float().to(settings.device)
            states = self.normalize(states)
            next_states = self.normalize(next_states)

            q_cur = torch.gather(self.q_train(states), -1,
                                 action_idxs.unsqueeze(-1)).squeeze(-1)
            with torch.no_grad():
                q_next = self.q_target(next_states)
                v_next, _ = torch.max(q_next, dim=-1)
            targets = rewards + self.gamma * v_next * (1 - dones)
            targets = targets.detach()

            loss = 0.5 * F.mse_loss(q_cur, targets)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            return loss.detach().cpu().item()
        return None

    def _get_eps_greedy_action(self, obs, eps=0):
        if np.random.rand() < eps:
            return self.env.random_action()
        else:
            self.q_train.eval()
            with torch.no_grad():
                obs = self.normalize(obs)
                obs = torch.from_numpy(obs).float().to(settings.device)
                obs = obs.unsqueeze(0).permute(0, 3, 1, 2)
                q = self.q_train(obs).squeeze(0)
                action = torch.argmax(q).detach().cpu().item()
            return action

    def _get_epsilon(self):
        eps = self.eps_min + max(
            0, (self.eps_decay_steps - self.n_steps) /
            self.eps_decay_steps) * (self.eps_max - self.eps_min)
        return eps

    def save_state(self, file):
        checkpoint = {}
        checkpoint['q_train'] = self.q_train.state_dict()
        checkpoint['q_target'] = self.q_target.state_dict()
        checkpoint['optimizer'] = self.optimizer.state_dict()
        torch.save(checkpoint, file)

    def load_state(self, file):
        checkpoint = torch.load(file)
        self.q_train.load_state_dict(checkpoint['q_train'])
        self.q_target.load_state_dict(checkpoint['q_target'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

    def train(self, num_episodes):
        all_rewards = []
        for n_episode in range(num_episodes):
            # Log test episode
            if n_episode % self.eval_freq == 0:
                self.eval(1)

            # Reset the environment
            obs = self.env.reset()

            # Play an episode
            total_reward = 0.0
            total_loss = 0.0
            n = 0
            done = False
            while not done:
                # Epsilon greedy action selection
                action_idx = self._get_eps_greedy_action(
                    obs, eps=self._get_epsilon())

                # Take a step
                next_obs, reward, done, _ = self.env.step(action_idx)
                # nobs, nnext = self.normalize(obs), self.normalize(next_obs)
                # print(np.mean(np.abs(nobs[:,:,0], nnext[:,:,0])))
                # print(self.normalize(obs))

                # Update replay buffer
                self.replay_memory.insert(obs, action_idx, reward, next_obs, done)

                # Update networks
                self._update_q_target()
                loss = self._update_q_train()

                # Bookkeeping
                total_reward += reward
                self._update_step_counter()
                if loss is not None:
                    total_loss += loss
                n += 1
                obs = next_obs

            # Save weights
            self.save_state(self.model_save_file)

            all_rewards.append(total_reward)
            if len(all_rewards) > 100:
                all_rewards = all_rewards[-100:]
            last_100_avg_rwd = np.sum(all_rewards) / len(all_rewards)
            avg_loss = total_loss * self.steps_per_update / n
            print('[TRAIN] n_episode: {}/{}, steps: {}, total_steps: {}, episode_reward: {:.03f}, 100_avg_rwd: {:.03f}, avg_loss: {:.03f}, eps: {:.03f}, replay_size: {}'
                  .format(n_episode + 1, num_episodes, n, self.n_steps,
                          total_reward, last_100_avg_rwd, avg_loss,
                          self._get_epsilon(), self.replay_memory.size))

    def eval(self, num_episodes):
        all_rewards = []
        for n_episode in range(num_episodes):
            # Reset the environment
            obs = self.env.reset(
                video_log_file=f'dqn_test_log_episode_{n_episode}')

            # Play an episode
            total_reward = 0.0
            n = 0
            done = False
            while not done:
                # Greedy action selection
                action_idx = self._get_eps_greedy_action(obs, eps=0.0)

                # Take a step
                next_obs, reward, done, _ = self.env.step(action_idx)

                # Bookkeeping
                total_reward += reward
                n += 1
                obs = next_obs

            all_rewards.append(total_reward)
            avg_rwd = np.sum(all_rewards) / len(all_rewards)
            print('[EVAL] n_episode: {}/{}, steps: {}, episode_reward: {:.03f}, avg_rwd: {:.03f}'
                  .format(n_episode + 1, num_episodes, n, total_reward, avg_rwd))
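# --- Hedged standalone sketch (not part of the DQN class above) ---------------
# The linear epsilon schedule used by DQN._get_epsilon, written as a free
# function so the decay can be checked without constructing a DQN instance.
def linear_epsilon(step, eps_max, eps_min, decay_steps):
    frac = max(0.0, (decay_steps - step) / decay_steps)
    return eps_min + frac * (eps_max - eps_min)

# e.g. linear_epsilon(0, 1.0, 0.1, 100000) == 1.0
#      linear_epsilon(100000, 1.0, 0.1, 100000) == 0.1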
class DQNModelsHandler:
    def __init__(self, env_class: CartpoleEnv, buffer_size, lr=0.001, online_log=True):
        self.environment_class = env_class
        with env_class() as env:
            self.n_states = env.n_states
            self.n_actions = env.n_actions
        if online_log:
            online_logger.init(
                name=f"cart-{datetime.datetime.now().isoformat()}")
        self._online_log = online_log
        self.model = DQNNetwork(self.n_states, self.n_actions, lr=lr)
        self.target_model = DQNNetwork(self.n_states, self.n_actions, lr=lr)
        self.rolling_loss = deque(maxlen=12)
        self.replay_buffer = ReplayBuffer(buffer_size)
        self.episode_count = 0
        self.rolling_reward = deque(maxlen=12)
        self.model_update_count = 0

    def train_step(self, dis_fact=0.99):
        trans_sts = self.replay_buffer.sample(self._sampling_size)
        states = torch.stack([trans.state_tensor for trans in trans_sts])
        next_states = torch.stack(
            [trans.next_state_tensor for trans in trans_sts])
        not_done = torch.Tensor([trans.not_done_tensor for trans in trans_sts])
        actions = [trans.action for trans in trans_sts]
        rewards = torch.stack([trans.reward_tensor for trans in trans_sts])

        with torch.no_grad():
            qvals_predicted = self.target_model(next_states).max(-1)

        self.model.optimizer.zero_grad()
        qvals_current = self.model(states)
        one_hot_actions = torch.nn.functional.one_hot(
            torch.LongTensor(actions), self.n_actions)
        # TD target: reward + discounted bootstrap value for non-terminal transitions.
        # (dis_fact was accepted but never applied in the original excerpt.)
        loss = ((rewards + dis_fact * (not_done * qvals_predicted.values) -
                 torch.sum(qvals_current * one_hot_actions, -1))**2).mean()
        loss.backward()
        self.model.optimizer.step()
        self.rolling_loss.append(loss.detach().item())

    def update_target_model(self):
        state_dict = deepcopy(self.model.state_dict())
        self.target_model.load_state_dict(state_dict)
        self.model_update_count += 1

    def play_episode(self, update_model=True):
        with self.environment_class() as env:
            while not env.episode_finished:
                if (exp_decay(self.n_steps) > np.random.random()
                        and self.n_steps > self._min_samples_before_update):
                    state_trans = env.random_step()
                else:
                    state = env.current_state
                    predicted_action = self.target_model(torch.Tensor(state))
                    state_trans = env.step(predicted_action.argmax().item())
                self.replay_buffer.insert(state_trans)
        if update_model:
            if self.matches_update_criteria():
                self.train_step()
                self.update_target_model()
                self.check_reward()
                self.verbose_training()
                if (self.model_update_count % self._model_save_every_nth_update == 0):
                    self.save_target_model()
        self.episode_count += 1

    def save_target_model(self):
        file_name = f"{datetime.datetime.now().strftime('%H:%M:%S')}.pth"
        model_save_name = f"/tmp/{file_name}"
        torch.save(self.target_model.state_dict(), model_save_name)
        if self._online_log:
            online_logger.save(model_save_name)
        else:
            os.rename(model_save_name, f"./{file_name}")

    def check_reward(self):
        with torch.no_grad():
            with self.environment_class() as reward_env:
                while not reward_env.episode_finished:
                    state = reward_env.current_state
                    predicted_action = self.target_model(torch.Tensor(state))
                    reward_env.step(predicted_action.argmax().item())
                self.rolling_reward.append(reward_env.reward)

    def get_rolling_reward(self):
        return sum(self.rolling_reward) / len(self.rolling_reward)

    def get_rolling_loss(self):
        return sum(self.rolling_loss) / len(self.rolling_loss)

    def set_model_updt_criteria(
        self,
        min_samples_before_update,
        update_every,
        sampling_size,
        model_save_every_nth_update,
    ):
        self._min_samples_before_update = min_samples_before_update
        self._update_every = update_every
        self._sampling_size = sampling_size
        self._model_save_every_nth_update = model_save_every_nth_update

    @property
    def n_steps(self):
        return self.replay_buffer.idx

    def matches_update_criteria(self):
        if self.replay_buffer.idx >= self._min_samples_before_update:
            if self.episode_count % self._update_every == 0:
                return True
        return False

    def verbose_training(self):
        logging_dict = {
            "Rolling_reward: ": self.get_rolling_reward(),
            "Rolling Loss:": self.get_rolling_loss(),
        }
        if self._online_log:
            online_logger.log(logging_dict)
        else:
            print(logging_dict)
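# --- Hedged usage sketch (not from the original source) -----------------------
# Assumes CartpoleEnv, DQNNetwork, ReplayBuffer and online_logger are importable
# from this repo; the hyperparameter values below are illustrative only. Note
# that set_model_updt_criteria must be called before play_episode, since the
# update-criteria attributes are only defined there.
if __name__ == "__main__":
    handler = DQNModelsHandler(CartpoleEnv, buffer_size=50000, lr=0.001,
                               online_log=False)
    handler.set_model_updt_criteria(
        min_samples_before_update=1000,
        update_every=1,
        sampling_size=64,
        model_save_every_nth_update=100,
    )
    for _ in range(500):
        handler.play_episode(update_model=True)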