class DDPG(Base):
    def __init__(self, env_name='Pendulum-v0', load_dir='./ckpt', log_dir="./log",
                 buffer_size=1e6, seed=1, max_episode_steps=None, batch_size=64,
                 discount=0.99, learning_starts=500, tau=0.005, save_eps_num=100,
                 external_env=None):
        self.env_name = env_name
        self.load_dir = load_dir
        self.log_dir = log_dir
        self.seed = seed
        self.max_episode_steps = max_episode_steps
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.discount = discount
        self.learning_starts = learning_starts
        self.tau = tau
        self.save_eps_num = save_eps_num

        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        self.writer = SummaryWriter(log_dir=self.log_dir)

        if external_env is None:
            env = gym.make(self.env_name)
        else:
            env = external_env
        if self.max_episode_steps is not None:
            env._max_episode_steps = self.max_episode_steps
        else:
            self.max_episode_steps = env._max_episode_steps
        self.env = NormalizedActions(env)
        self.ou_noise = OUNoise(self.env.action_space)

        state_dim = self.env.observation_space.shape[0]
        action_dim = self.env.action_space.shape[0]
        hidden_dim = 256
        self.q_net = QNet(state_dim, action_dim, hidden_dim).to(device)
        self.policy_net = PolicyNet(state_dim, action_dim, hidden_dim).to(device)
        self.target_q_net = QNet(state_dim, action_dim, hidden_dim).to(device)
        self.target_policy_net = PolicyNet(state_dim, action_dim, hidden_dim).to(device)

        try:
            self.load(directory=self.load_dir, filename=self.env_name)
            print('Load model successfully!')
        except Exception:
            print('WARNING: No model to load!')

        # Hard-copy the online networks into the targets before training starts.
        soft_update(self.q_net, self.target_q_net, soft_tau=1.0)
        soft_update(self.policy_net, self.target_policy_net, soft_tau=1.0)

        q_lr = 1e-3
        policy_lr = 1e-4
        self.value_optimizer = optim.Adam(self.q_net.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)
        self.value_criterion = nn.MSELoss()

        self.replay_buffer = ReplayBuffer(int(self.buffer_size))
        self.total_steps = 0
        self.episode_num = 0
        self.episode_timesteps = 0

    def save(self, directory, filename):
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save(self.q_net.state_dict(), '%s/%s_q_net.pkl' % (directory, filename))
        torch.save(self.policy_net.state_dict(), '%s/%s_policy_net.pkl' % (directory, filename))

    def load(self, directory, filename):
        self.q_net.load_state_dict(torch.load('%s/%s_q_net.pkl' % (directory, filename)))
        self.policy_net.load_state_dict(torch.load('%s/%s_policy_net.pkl' % (directory, filename)))

    def train_step(self, min_value=-np.inf, max_value=np.inf, soft_tau=1e-2):
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)

        state = torch.FloatTensor(state).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        action = torch.FloatTensor(action).to(device)
        reward = torch.FloatTensor(reward).unsqueeze(1).to(device)
        done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(device)

        # Deterministic policy gradient: maximize Q(s, pi(s)).
        policy_loss = self.q_net(state, self.policy_net(state))
        policy_loss = -policy_loss.mean()

        next_action = self.target_policy_net(next_state)
        target_value = self.target_q_net(next_state, next_action.detach())
        expected_value = reward + (1.0 - done) * self.discount * target_value
        expected_value = torch.clamp(expected_value, min_value, max_value)

        value = self.q_net(state, action)
        value_loss = self.value_criterion(value, expected_value.detach())

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

        soft_update(self.q_net, self.target_q_net, self.tau)
        soft_update(self.policy_net, self.target_policy_net, self.tau)

    def predict(self, state):
        return self.policy_net.get_action(state)

    def learn(self, max_steps=1e7):
        while self.total_steps < max_steps:
            state = self.env.reset()
            self.episode_timesteps = 0
            episode_reward = 0

            for step in range(self.max_episode_steps):
                action = self.policy_net.get_action(state)
                action = self.ou_noise.get_action(action, self.total_steps)
                next_state, reward, done, _ = self.env.step(action)
                self.replay_buffer.push(state, action, reward, next_state, done)

                state = next_state
                episode_reward += reward
                self.total_steps += 1
                self.episode_timesteps += 1

                if done or self.episode_timesteps == self.max_episode_steps:
                    if len(self.replay_buffer) > self.learning_starts:
                        for _ in range(self.episode_timesteps):
                            self.train_step()
                    self.episode_num += 1
                    if self.episode_num > 0 and self.episode_num % self.save_eps_num == 0:
                        self.save(directory=self.load_dir, filename=self.env_name)
                    self.writer.add_scalar('episode_reward', episode_reward, self.episode_num)
                    break

        self.env.close()
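# Illustration (not part of the original source): DDPG keeps target copies of
# both networks and blends them toward the online weights after every gradient
# step. The module-level soft_update() used above is assumed to perform this
# Polyak averaging; a minimal reference sketch under that assumption
# (soft_tau=1.0 reduces to a hard copy):
def _polyak_update_sketch(net, target_net, soft_tau=0.005):
    """target <- soft_tau * online + (1 - soft_tau) * target (illustrative only)."""
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(soft_tau * param.data + (1.0 - soft_tau) * target_param.data)


# Minimal usage sketch, assuming the helpers this module relies on (gym, device,
# NormalizedActions, OUNoise, QNet, PolicyNet, ReplayBuffer, soft_update) are
# importable here:
def _demo_ddpg():
    """Short DDPG training run on Pendulum-v0, then one greedy action."""
    agent = DDPG(env_name='Pendulum-v0', max_episode_steps=200)
    agent.learn(max_steps=20000)      # short run; raise max_steps for real training
    state = agent.env.reset()
    print('greedy action:', agent.predict(state))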
class RainbowDQN(object):
    def __init__(self, env_id="CartPole-v0", Vmin=-10, Vmax=10, num_atoms=51):
        self.Vmin = Vmin
        self.Vmax = Vmax
        self.num_atoms = num_atoms
        self.env_id = env_id
        self.env = gym.make(self.env_id)

        self.current_model = TinyRainbowDQN(self.env.observation_space.shape[0],
                                            self.env.action_space.n,
                                            self.num_atoms, self.Vmin, self.Vmax)
        self.target_model = TinyRainbowDQN(self.env.observation_space.shape[0],
                                           self.env.action_space.n,
                                           self.num_atoms, self.Vmin, self.Vmax)
        if torch.cuda.is_available():
            self.current_model = self.current_model.cuda()
            self.target_model = self.target_model.cuda()

        self.optimizer = optim.Adam(self.current_model.parameters(), 0.001)
        self.replay_buffer = ReplayBuffer(10000)
        self.update_target(self.current_model, self.target_model)
        self.losses = []

    def update_target(self, current_model, target_model):
        target_model.load_state_dict(current_model.state_dict())

    def projection_distribution(self, next_state, rewards, dones):
        batch_size = next_state.size(0)

        delta_z = float(self.Vmax - self.Vmin) / (self.num_atoms - 1)
        support = torch.linspace(self.Vmin, self.Vmax, self.num_atoms)

        next_dist = self.target_model(next_state).data.cpu() * support
        next_action = next_dist.sum(2).max(1)[1]
        next_action = next_action.unsqueeze(1).unsqueeze(1).expand(next_dist.size(0), 1, next_dist.size(2))
        next_dist = next_dist.gather(1, next_action).squeeze(1)

        rewards = rewards.unsqueeze(1).expand_as(next_dist)
        dones = dones.unsqueeze(1).expand_as(next_dist)
        support = support.unsqueeze(0).expand_as(next_dist)

        Tz = rewards + (1 - dones) * 0.99 * support
        Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)
        b = (Tz - self.Vmin) / delta_z
        l = b.floor().long()
        u = b.ceil().long()

        offset = torch.linspace(0, (batch_size - 1) * self.num_atoms, batch_size).long() \
            .unsqueeze(1).expand(batch_size, self.num_atoms)

        proj_dist = torch.zeros(next_dist.size())
        proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1))
        proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1))
        return proj_dist

    def train_step(self, gamma=0.99, batch_size=32):
        state, action, reward, next_state, done = self.replay_buffer.sample(batch_size)

        state = torch.FloatTensor(np.float32(state))
        next_state = torch.FloatTensor(np.float32(next_state))
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        done = torch.FloatTensor(np.float32(done))

        proj_dist = self.projection_distribution(next_state, reward, done)

        dist = self.current_model(state)
        action = action.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, self.num_atoms)
        dist = dist.gather(1, action).squeeze(1)
        dist.data.clamp_(0.01, 0.99)

        loss = -(proj_dist * dist.log()).sum(1)
        loss = loss.mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.current_model.reset_noise()
        self.target_model.reset_noise()
        self.losses.append(loss.item())

    def learn(self, num_frames=15000, batch_size=32):
        all_rewards = []
        episode_reward = 0

        state = self.env.reset()
        for frame_idx in range(1, num_frames + 1):
            action = self.current_model.act(state)
            next_state, reward, done, _ = self.env.step(action)
            self.replay_buffer.push(state, action, reward, next_state, done)

            state = next_state
            episode_reward += reward

            if done:
                state = self.env.reset()
                all_rewards.append(episode_reward)
                episode_reward = 0

            if len(self.replay_buffer) > batch_size:
                self.train_step()

            # if frame_idx % 200 == 0:
            if frame_idx == num_frames:
                plt.figure(figsize=(20, 5))
                plt.subplot(121)
                plt.title('frame %s. reward: %s' % (frame_idx, np.mean(all_rewards[-10:])))
                plt.plot(all_rewards)
                plt.subplot(122)
                plt.title('loss')
                plt.plot(self.losses)
                plt.show()

            if frame_idx % 1000 == 0:
                self.update_target(self.current_model, self.target_model)
                print(frame_idx)
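# Illustration (not part of the original source): projection_distribution()
# above implements the C51 categorical projection. Each target atom
# Tz = r + (1 - done) * gamma * z_j is clamped to [Vmin, Vmax] and its
# probability mass is split between the two nearest support atoms. A tiny
# standalone example with 5 atoms, one transition, and made-up numbers:
def _c51_projection_example():
    import torch
    Vmin, Vmax, num_atoms, gamma = -2.0, 2.0, 5, 0.9
    delta_z = (Vmax - Vmin) / (num_atoms - 1)             # atom spacing = 1.0
    support = torch.linspace(Vmin, Vmax, num_atoms)       # [-2, -1, 0, 1, 2]
    next_dist = torch.tensor([0.1, 0.2, 0.4, 0.2, 0.1])   # assumed target-net output
    reward, done = 0.15, 0.0
    Tz = (reward + (1 - done) * gamma * support).clamp(Vmin, Vmax)
    b = (Tz - Vmin) / delta_z                             # fractional atom index
    l, u = b.floor().long(), b.ceil().long()
    proj = torch.zeros(num_atoms)
    proj.index_add_(0, l, next_dist * (u.float() - b))    # mass to the lower atom
    proj.index_add_(0, u, next_dist * (b - l.float()))    # mass to the upper atom
    print(proj, proj.sum())   # projected distribution; the mass still sums to 1
    return proj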
class SQL():
    def __init__(self, env_name='Pendulum-v0', load_dir='./ckpt', log_dir="./log",
                 buffer_size=1e6, seed=1, max_episode_steps=None, batch_size=64,
                 discount=0.99, learning_starts=500, tau=0.005, save_eps_num=100,
                 external_env=None):
        self.env_name = env_name
        self.load_dir = load_dir
        self.log_dir = log_dir
        self.seed = seed
        self.max_episode_steps = max_episode_steps
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.discount = discount
        self.learning_starts = learning_starts
        self.tau = tau
        self.save_eps_num = save_eps_num

        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        self.writer = SummaryWriter(log_dir=self.log_dir)

        if external_env is None:
            env = gym.make(self.env_name)
        else:
            env = external_env
        if self.max_episode_steps is not None:
            env._max_episode_steps = self.max_episode_steps
        else:
            self.max_episode_steps = env._max_episode_steps
        self.env = NormalizedActions(env)

        self.action_dim = self.env.action_space.shape[0]
        self.state_dim = self.env.observation_space.shape[0]
        self.hidden_dim = 256
        self.q_net = QNet(self.state_dim, self.action_dim, self.hidden_dim).to(device)
        self.policy_net = PolicyNet(self.state_dim, self.action_dim, self.hidden_dim).to(device)

        try:
            self.load(directory=self.load_dir, filename=self.env_name)
            print('Load model successfully!')
        except Exception:
            print('WARNING: No model to load!')

        self.q_criterion = nn.MSELoss()
        q_lr = 3e-4
        policy_lr = 3e-4
        self.exploration_prob = 0.0
        self.alpha = 1.0
        self.value_alpha = 0.3
        self.q_optimizer = optim.Adam(self.q_net.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)

        self.replay_buffer = ReplayBuffer(int(self.buffer_size))
        self.total_steps = 0
        self.episode_num = 0
        self.episode_timesteps = 0

        # Fixed set of 32 actions spaced evenly on the unit circle, used to
        # estimate the soft state value by sampling.
        self.action_set = []
        for j in range(32):
            self.action_set.append((np.sin((3.14 * 2 / 32) * j), np.cos((3.14 * 2 / 32) * j)))

    def save(self, directory, filename):
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save(self.q_net.state_dict(), '%s/%s_q_net.pkl' % (directory, filename))
        torch.save(self.policy_net.state_dict(), '%s/%s_policy_net.pkl' % (directory, filename))

    def load(self, directory, filename):
        self.q_net.load_state_dict(torch.load('%s/%s_q_net.pkl' % (directory, filename)))
        self.policy_net.load_state_dict(torch.load('%s/%s_policy_net.pkl' % (directory, filename)))

    def forward_QNet(self, obs, action):
        obs = torch.FloatTensor(obs).to(device)
        action = torch.FloatTensor(action).to(device)
        q_pred = self.q_net(obs, action).detach().cpu().numpy()
        return q_pred

    def forward_PolicyNet(self, obs, noise):
        inputs = torch.FloatTensor([obs + noise]).unsqueeze(0).to(device)
        action_pred = self.policy_net(inputs)
        return action_pred.detach().cpu().numpy()[0][0]

    def rbf_kernel(self, input1, input2):
        return np.exp(-3.14 * (np.dot(input1 - input2, input1 - input2)))

    def rbf_kernel_grad(self, input1, input2):
        diff = input1 - input2
        mult_val = self.rbf_kernel(input1, input2) * -2 * 3.14
        return [x * mult_val for x in diff]

    def train_step(self):
        current_state, current_action, current_reward, next_state, done = \
            self.replay_buffer.sample(self.batch_size)

        # Perform updates on the Q-Network.
        best_q_val_next = 0
        for j in range(32):
            # Sample 32 actions and use them in the next state to get an
            # estimate of the soft state value.
            action_temp = [[self.action_set[j][1]]] * self.batch_size
            q_value_temp = self.forward_QNet(next_state, action_temp)
            q_value_temp = (1.0 / self.value_alpha) * q_value_temp
            q_value_temp = np.exp(q_value_temp) / (1.0 / 32)
            best_q_val_next += q_value_temp * (1.0 / 32)
        best_q_val_next = self.value_alpha * np.log(best_q_val_next)

        current_reward = current_reward.reshape(self.batch_size, 1)
        expected_q = current_reward + 0.99 * best_q_val_next
        expected_q = (1 - self.alpha) * self.forward_QNet(current_state, current_action) \
            + self.alpha * expected_q

        # Regress Q(s, a) onto the soft Bellman target, keeping the graph on
        # the prediction so the loss can back-propagate into the Q-network.
        state_t = torch.FloatTensor(current_state).to(device)
        action_t = torch.FloatTensor(current_action).to(device)
        predicted_q = self.q_net(state_t, action_t)
        expected_q = torch.FloatTensor(expected_q).to(device)

        self.q_optimizer.zero_grad()
        loss = self.q_criterion(predicted_q, expected_q)
        loss.backward()
        self.q_optimizer.step()

        # Perform updates on the Policy-Network using SVGD.
        action_predicted = self.forward_PolicyNet(current_state, (0.0, 0.0, 0.0))
        final_action_gradient = [0.0, 0.0]
        for j in range(32):
            action_temp = tuple(self.forward_PolicyNet(
                current_state, (np.random.normal(0.0, 0.5))).data.numpy()[0].tolist())
            inputs_temp = torch.FloatTensor([current_state + action_temp])
            predicted_q = self.q_net(inputs_temp)
            # Perform standard Q-value computation for each of the selected actions.
            best_q_val_next = 0
            for k in range(32):
                # Sample 32 actions and use them in the next state to get an
                # estimate of the soft state value.
                action_temp_2 = self.action_set[k]
                q_value_temp = (1.0 / self.value_alpha) * \
                    self.forward_QNet(next_state, action_temp_2).data.numpy()[0][0]
                q_value_temp = np.exp(q_value_temp) / (1.0 / 32)
                best_q_val_next += q_value_temp * (1.0 / 32)
            best_q_val_next = self.value_alpha * np.log(best_q_val_next)

            expected_q = current_reward + 0.99 * best_q_val_next
            expected_q = (1 - self.alpha) * predicted_q.data.numpy()[0][0] + self.alpha * expected_q
            expected_q = torch.FloatTensor([[expected_q]])
            loss = self.q_criterion(predicted_q, expected_q)
            loss.backward()

            action_gradient_temp = [inputs_temp.grad.data.numpy()[0][2],
                                    inputs_temp.grad.data.numpy()[0][3]]
            kernel_val = self.rbf_kernel(list(action_temp), action_predicted.data.numpy()[0])
            kernel_grad = self.rbf_kernel_grad(list(action_temp), action_predicted.data.numpy()[0])
            final_temp_grad = ([x * kernel_val for x in action_gradient_temp]
                               + [x * self.value_alpha for x in kernel_grad])
            final_action_gradient[0] += (1.0 / 32) * final_temp_grad[0]
            final_action_gradient[1] += (1.0 / 32) * final_temp_grad[1]

        action_predicted.backward(torch.FloatTensor([final_action_gradient]))

        # Apply the updates using the optimizers.
        self.q_optimizer.zero_grad()
        self.q_optimizer.step()
        self.policy_optimizer.zero_grad()
        self.policy_optimizer.step()

    def learn(self, max_steps=1e7):
        while self.total_steps < max_steps:
            state = self.env.reset()
            self.episode_timesteps = 0
            episode_reward = 0

            for step in range(self.max_episode_steps):
                action = self.forward_PolicyNet(state, (np.random.normal(0.0, 0.5),
                                                        np.random.normal(0.0, 0.5),
                                                        np.random.normal(0.0, 0.5)))
                if random.uniform(0.0, 1.0) < self.exploration_prob:
                    x_val = random.uniform(-1.0, 1.0)
                    action = (x_val, random.choice([-1.0, 1.0]) * np.sqrt(1.0 - x_val * x_val))
                next_state, reward, done, _ = self.env.step(action)
                self.replay_buffer.push(state, action, reward, next_state, done)

                state = next_state
                episode_reward += reward
                self.total_steps += 1
                self.episode_timesteps += 1

                if done or self.episode_timesteps == self.max_episode_steps:
                    if len(self.replay_buffer) > self.learning_starts:
                        for _ in range(self.episode_timesteps):
                            self.train_step()
                    self.episode_num += 1
                    if self.episode_num > 0 and self.episode_num % self.save_eps_num == 0:
                        self.save(directory=self.load_dir, filename=self.env_name)
                    self.writer.add_scalar('episode_reward', episode_reward, self.episode_num)
                    break

        self.env.close()
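# Illustration (not part of the original source): the Q-update in train_step()
# targets the soft Bellman backup of Soft Q-Learning, with the soft state value
#   V(s') ~= alpha * log( mean_j [ exp(Q(s', a_j) / alpha) / p(a_j) ] )
# estimated from 32 actions a_j drawn uniformly (so p(a_j) = 1/32). A small
# self-contained sketch of that estimator with a made-up Q vector:
def _soft_value_estimate_sketch():
    import numpy as np
    alpha = 0.3
    rng = np.random.default_rng(0)
    q_values = rng.normal(size=32)                       # stand-in for Q(s', a_j)
    weights = np.exp(q_values / alpha) / (1.0 / 32)      # importance weight 1 / p(a_j)
    v_soft = alpha * np.log(np.mean(weights))            # soft state value V(s')
    print('soft value: %.3f (hard max: %.3f)' % (v_soft, q_values.max()))
    return v_soft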
class TD3(DDPG):
    def __init__(self, env_name='Pendulum-v0', load_dir='./ckpt', log_dir="./log",
                 buffer_size=1e6, seed=1, max_episode_steps=None, noise_decay_steps=1e5,
                 batch_size=64, discount=0.99, train_freq=100, policy_freq=2,
                 learning_starts=500, tau=0.005, save_eps_num=100, external_env=None):
        self.env_name = env_name
        self.load_dir = load_dir
        self.log_dir = log_dir
        self.seed = seed
        self.max_episode_steps = max_episode_steps
        self.buffer_size = buffer_size
        self.noise_decay_steps = noise_decay_steps
        self.batch_size = batch_size
        self.discount = discount
        self.policy_freq = policy_freq
        self.learning_starts = learning_starts
        self.tau = tau
        self.save_eps_num = save_eps_num

        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        self.writer = SummaryWriter(log_dir=self.log_dir)

        if external_env is None:
            env = gym.make(self.env_name)
        else:
            env = external_env
        if self.max_episode_steps is not None:
            env._max_episode_steps = self.max_episode_steps
        else:
            self.max_episode_steps = env._max_episode_steps
        self.env = NormalizedActions(env)
        self.noise = GaussianExploration(self.env.action_space,
                                         decay_period=self.noise_decay_steps)

        state_dim = self.env.observation_space.shape[0]
        action_dim = self.env.action_space.shape[0]
        hidden_dim = 256
        self.value_net1 = QNet(state_dim, action_dim, hidden_dim).to(device)
        self.value_net2 = QNet(state_dim, action_dim, hidden_dim).to(device)
        self.policy_net = PolicyNet(state_dim, action_dim, hidden_dim).to(device)
        self.target_value_net1 = QNet(state_dim, action_dim, hidden_dim).to(device)
        self.target_value_net2 = QNet(state_dim, action_dim, hidden_dim).to(device)
        self.target_policy_net = PolicyNet(state_dim, action_dim, hidden_dim).to(device)

        try:
            self.load(directory=self.load_dir, filename=self.env_name)
            print('Load model successfully!')
        except Exception:
            print('WARNING: No model to load!')

        # Hard-copy the online networks into the targets before training starts.
        soft_update(self.value_net1, self.target_value_net1, soft_tau=1.0)
        soft_update(self.value_net2, self.target_value_net2, soft_tau=1.0)
        soft_update(self.policy_net, self.target_policy_net, soft_tau=1.0)

        self.value_criterion = nn.MSELoss()
        policy_lr = 3e-4
        value_lr = 3e-4
        self.value_optimizer1 = optim.Adam(self.value_net1.parameters(), lr=value_lr)
        self.value_optimizer2 = optim.Adam(self.value_net2.parameters(), lr=value_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)

        self.replay_buffer = ReplayBuffer(int(self.buffer_size))
        self.total_steps = 0
        self.episode_num = 0
        self.episode_timesteps = 0

    def save(self, directory, filename):
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save(self.value_net1.state_dict(), '%s/%s_value_net1.pkl' % (directory, filename))
        torch.save(self.value_net2.state_dict(), '%s/%s_value_net2.pkl' % (directory, filename))
        torch.save(self.policy_net.state_dict(), '%s/%s_policy_net.pkl' % (directory, filename))

    def load(self, directory, filename):
        self.value_net1.load_state_dict(torch.load('%s/%s_value_net1.pkl' % (directory, filename)))
        self.value_net2.load_state_dict(torch.load('%s/%s_value_net2.pkl' % (directory, filename)))
        self.policy_net.load_state_dict(torch.load('%s/%s_policy_net.pkl' % (directory, filename)))

    def train_step(self, step, noise_std=0.2, noise_clip=0.5):
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)

        state = torch.FloatTensor(state).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        action = torch.FloatTensor(action).to(device)
        reward = torch.FloatTensor(reward).unsqueeze(1).to(device)
        done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(device)

        # Target policy smoothing: add clipped noise to the target action.
        next_action = self.target_policy_net(next_state)
        noise = torch.normal(torch.zeros(next_action.size()), noise_std).to(device)
        noise = torch.clamp(noise, -noise_clip, noise_clip)
        next_action += noise

        # Clipped double-Q target.
        target_q_value1 = self.target_value_net1(next_state, next_action)
        target_q_value2 = self.target_value_net2(next_state, next_action)
        target_q_value = torch.min(target_q_value1, target_q_value2)
        expected_q_value = reward + (1.0 - done) * self.discount * target_q_value

        q_value1 = self.value_net1(state, action)
        q_value2 = self.value_net2(state, action)
        value_loss1 = self.value_criterion(q_value1, expected_q_value.detach())
        value_loss2 = self.value_criterion(q_value2, expected_q_value.detach())

        self.value_optimizer1.zero_grad()
        value_loss1.backward()
        self.value_optimizer1.step()

        self.value_optimizer2.zero_grad()
        value_loss2.backward()
        self.value_optimizer2.step()

        # Delayed policy and target-network updates.
        if step % self.policy_freq == 0:
            policy_loss = self.value_net1(state, self.policy_net(state))
            policy_loss = -policy_loss.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            soft_update(self.value_net1, self.target_value_net1, soft_tau=self.tau)
            soft_update(self.value_net2, self.target_value_net2, soft_tau=self.tau)
            soft_update(self.policy_net, self.target_policy_net, soft_tau=self.tau)

    def learn(self, max_steps=1e7):
        while self.total_steps < max_steps:
            state = self.env.reset()
            self.episode_timesteps = 0
            episode_reward = 0

            for step in range(self.max_episode_steps):
                action = self.policy_net.get_action(state)
                action = self.noise.get_action(action, self.total_steps)
                next_state, reward, done, _ = self.env.step(action)
                self.replay_buffer.push(state, action, reward, next_state, done)

                state = next_state
                episode_reward += reward
                self.total_steps += 1
                self.episode_timesteps += 1

                if done or self.episode_timesteps == self.max_episode_steps:
                    if len(self.replay_buffer) > self.learning_starts:
                        for i in range(self.episode_timesteps):
                            self.train_step(i)
                    self.episode_num += 1
                    if self.episode_num > 0 and self.episode_num % self.save_eps_num == 0:
                        self.save(directory=self.load_dir, filename=self.env_name)
                    self.writer.add_scalar('episode_reward', episode_reward, self.episode_num)
                    break

        self.env.close()
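# Minimal usage sketch (not part of the original source), assuming the same
# module-level helpers as DDPG plus GaussianExploration are importable here:
def _demo_td3():
    """Short TD3 training run on Pendulum-v0."""
    agent = TD3(env_name='Pendulum-v0', max_episode_steps=200,
                policy_freq=2,            # delayed policy/target updates
                noise_decay_steps=1e5)    # exploration-noise decay horizon
    agent.learn(max_steps=20000)          # short run; raise max_steps for real training
    print('episodes finished:', agent.episode_num)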
class QuantileRegressionDQN(object):
    def __init__(self, env_id="CartPole-v0", num_quant=51, Vmin=-10, Vmax=10, batch_size=32):
        self.env_id = env_id
        self.env = gym.make(self.env_id)
        self.num_quant = num_quant
        self.Vmin = Vmin
        self.Vmax = Vmax
        self.batch_size = batch_size

        self.current_model = TinyQRDQN(self.env.observation_space.shape[0],
                                       self.env.action_space.n, self.num_quant)
        self.target_model = TinyQRDQN(self.env.observation_space.shape[0],
                                      self.env.action_space.n, self.num_quant)
        if USE_CUDA:
            self.current_model = self.current_model.cuda()
            self.target_model = self.target_model.cuda()

        self.optimizer = optim.Adam(self.current_model.parameters())
        self.replay_buffer = ReplayBuffer(10000)
        self.update_target(self.current_model, self.target_model)
        self.losses = []

    def update_target(self, current_model, target_model):
        target_model.load_state_dict(current_model.state_dict())

    def projection_distribution(self, dist, next_state, reward, done):
        next_dist = self.target_model(next_state)
        next_action = next_dist.mean(2).max(1)[1]
        next_action = next_action.unsqueeze(1).unsqueeze(1).expand(self.batch_size, 1, self.num_quant)
        next_dist = next_dist.gather(1, next_action).squeeze(1).cpu().data

        expected_quant = reward.unsqueeze(1) + 0.99 * next_dist * (1 - done.unsqueeze(1))

        quant_idx = torch.sort(dist, 1, descending=False)[1]
        tau_hat = torch.linspace(0.0, 1.0 - 1. / self.num_quant, self.num_quant) + 0.5 / self.num_quant
        tau_hat = tau_hat.unsqueeze(0).repeat(self.batch_size, 1)
        quant_idx = quant_idx.cpu().data
        batch_idx = np.arange(self.batch_size)
        tau = tau_hat[:, quant_idx][batch_idx, batch_idx]
        return tau, expected_quant

    def train_step(self):
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)

        state = torch.FloatTensor(np.float32(state))
        next_state = torch.FloatTensor(np.float32(next_state))
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        done = torch.FloatTensor(np.float32(done))

        dist = self.current_model(state)
        action = action.unsqueeze(1).unsqueeze(1).expand(self.batch_size, 1, self.num_quant)
        dist = dist.gather(1, action).squeeze(1)

        tau, expected_quant = self.projection_distribution(dist, next_state, reward, done)

        k = 1
        u = expected_quant - dist
        huber_loss = 0.5 * u.abs().clamp(min=0.0, max=k).pow(2)
        huber_loss += k * (u.abs() - u.abs().clamp(min=0.0, max=k))
        quantile_loss = (tau - (u < 0).float()).abs() * huber_loss
        loss = quantile_loss.sum() / self.num_quant

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.current_model.parameters(), 0.5)
        self.optimizer.step()

        self.losses.append(loss.item())

    def learn(self, num_frames=10000, batch_size=32, epsilon_start=1.0,
              epsilon_final=0.01, epsilon_decay=500):
        all_rewards = []
        episode_reward = 0

        state = self.env.reset()
        for frame_idx in range(1, num_frames + 1):
            action = self.current_model.act(
                state,
                epsilon_final + (epsilon_start - epsilon_final) * math.exp(-1. * frame_idx / epsilon_decay))
            next_state, reward, done, _ = self.env.step(action)
            self.replay_buffer.push(state, action, reward, next_state, done)

            state = next_state
            episode_reward += reward

            if done:
                state = self.env.reset()
                all_rewards.append(episode_reward)
                episode_reward = 0

            if len(self.replay_buffer) > batch_size:
                self.train_step()

            # if frame_idx % 200 == 0:
            if frame_idx == num_frames:
                clear_output(True)
                plt.figure(figsize=(20, 5))
                plt.subplot(121)
                plt.title('frame %s. reward: %s' % (frame_idx, np.mean(all_rewards[-10:])))
                plt.plot(all_rewards)
                plt.subplot(122)
                plt.title('loss')
                plt.plot(self.losses)
                plt.show()

            if frame_idx % 1000 == 0:
                self.update_target(self.current_model, self.target_model)
                print(frame_idx)
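# Illustration (not part of the original source): train_step() above minimizes
# the quantile Huber loss. With predicted quantiles theta at levels tau and an
# element-wise target (as in the code), the loss is
#   |tau - 1{u < 0}| * Huber_k(u),  where u = target - theta and k = 1.
# A tiny standalone example with 4 quantiles and made-up numbers:
def _quantile_huber_loss_example():
    import torch
    num_quant, k = 4, 1.0
    tau = torch.linspace(0.0, 1.0 - 1.0 / num_quant, num_quant) + 0.5 / num_quant
    theta = torch.tensor([0.0, 0.5, 1.0, 1.5])      # assumed predicted quantiles
    target = torch.tensor([0.2, 0.4, 1.3, 1.2])     # assumed Bellman target samples
    u = target - theta
    huber = 0.5 * u.abs().clamp(max=k).pow(2) + k * (u.abs() - u.abs().clamp(max=k))
    loss = ((tau - (u < 0).float()).abs() * huber).sum() / num_quant
    print('quantile Huber loss: %.4f' % loss.item())
    return loss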