def forward(self):
    a = torch.randn(3, 2)
    b = torch.rand(3, 2)
    c = torch.rand(3)
    log_probs = torch.randn(50, 16, 20).log_softmax(2).detach()
    targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
    input_lengths = torch.full((16,), 50, dtype=torch.long)
    target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
    # Evaluate every loss op once (self.x and self.y are multi-label inputs
    # defined elsewhere on the module) and return how many were run.
    return len((
        F.binary_cross_entropy(torch.sigmoid(a), b),
        F.binary_cross_entropy_with_logits(torch.sigmoid(a), b),
        F.poisson_nll_loss(a, b),
        F.cosine_embedding_loss(a, b, c),
        F.cross_entropy(a, b),
        F.ctc_loss(log_probs, targets, input_lengths, target_lengths),
        # F.gaussian_nll_loss(a, b, torch.ones(5, 1)),  # ENTER is not supported in mobile module
        F.hinge_embedding_loss(a, b),
        F.kl_div(a, b),
        F.l1_loss(a, b),
        F.mse_loss(a, b),
        F.margin_ranking_loss(c, c, c),
        F.multilabel_margin_loss(self.x, self.y),
        F.multilabel_soft_margin_loss(self.x, self.y),
        F.multi_margin_loss(self.x, torch.tensor([3])),
        F.nll_loss(a, torch.tensor([1, 0, 1])),
        F.huber_loss(a, b),
        F.smooth_l1_loss(a, b),
        F.soft_margin_loss(a, b),
        F.triplet_margin_loss(a, b, -b),
        # F.triplet_margin_with_distance_loss(a, b, -b),  # can't take variable number of arguments
    ))

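# For reference, a minimal standalone sketch of F.huber_loss itself: residuals whose
# magnitude is at most `delta` are penalised quadratically, larger ones linearly.
# The tensors below are illustrative toy values, not taken from the module above.
import torch
import torch.nn.functional as F

pred = torch.zeros(3)
target = torch.tensor([0.5, 1.0, 3.0])

per_elem = F.huber_loss(pred, target, reduction='none', delta=1.0)
manual = torch.tensor([0.5 * 0.5 ** 2,          # quadratic region: 0.5 * r**2
                       0.5 * 1.0 ** 2,          # boundary: both branches agree
                       1.0 * (3.0 - 0.5)])      # linear region: delta * (|r| - 0.5 * delta)
assert torch.allclose(per_elem, manual)
print(F.huber_loss(pred, target, delta=1.0))    # default 'mean' reduction, ~1.0417
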
def optimize(self):
    if len(self.memory._storage) < self.batch_size:
        return

    beta = self.beta_scheduler.value(self.optim_steps)
    state, action, reward, new_state, done, _, indices = self.memory.sample(
        self.batch_size, beta)

    state = torch.as_tensor(np.vstack(state), dtype=torch.float32, device=device)
    action = torch.as_tensor(np.vstack(action), dtype=torch.float32, device=device)
    # `done` is stored as (1 - done), so terminal transitions zero out the bootstrap term below.
    done = torch.as_tensor(np.vstack(1 - done), dtype=torch.float32, device=device)
    reward = torch.as_tensor(np.vstack(reward), dtype=torch.float32, device=device)
    new_state = torch.as_tensor(np.vstack(new_state), dtype=torch.float32, device=device)

    self.target_actor.eval()
    self.target_critic.eval()
    self.critic.train()
    self.actor.train()

    Q_target = self.target_critic.forward(
        new_state, self.target_actor.forward(new_state))
    Y = reward + (done * self.gamma * Q_target)
    Q = self.critic.forward(state, action)
    TD_errors = torch.sub(Y, Q).squeeze(dim=-1)

    # The PER importance-sampling weights are deliberately not applied to the
    # TD errors here; treating all weights as 1.0 is itself a hyperparameter choice.
    critic_loss = F.huber_loss(TD_errors, torch.zeros_like(TD_errors))
    self.critic.optimizer.zero_grad()
    critic_loss.backward()
    self.critic.optimizer.step()

    # Compute and apply the actor loss
    actor_loss = torch.mean(-1.0 * self.critic.forward(state, self.actor(state)))
    self.actor.optimizer.zero_grad()
    actor_loss.backward()
    self.actor.optimizer.step()

    td_errors: np.ndarray = TD_errors.detach().cpu().numpy()
    new_priorities = np.abs(td_errors) + 1e-6
    self.memory.update_priorities(indices, new_priorities)

    self._update_networks(self.tau)
    self.optim_steps += 1

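# If one did want to fold the importance-sampling weights into the critic loss
# instead of treating them all as 1.0, a common pattern is to keep the per-sample
# Huber losses and weight them before reducing. Hedged, self-contained sketch with
# toy stand-ins (the names below are illustrative, not from the code above):
import torch
import torch.nn.functional as F

td_errors = torch.randn(64)     # one TD error per sampled transition
is_weights = torch.rand(64)     # importance-sampling weights from a PER buffer

per_sample = F.huber_loss(td_errors, torch.zeros_like(td_errors), reduction='none')
weighted_critic_loss = (is_weights * per_sample).mean()
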
def optimize(self):
    if len(self.memory._storage) < self.batch_size:
        return

    state, action, reward, new_state, done = self.memory.sample(self.batch_size)

    state = torch.as_tensor(np.vstack(state), dtype=torch.float32, device=device)
    action = torch.as_tensor(np.vstack(action), dtype=torch.float32, device=device)
    done = torch.as_tensor(np.vstack(1 - done), dtype=torch.float32, device=device)
    reward = torch.as_tensor(np.vstack(reward), dtype=torch.float32, device=device)
    new_state = torch.as_tensor(np.vstack(new_state), dtype=torch.float32, device=device)

    self.target_actor.eval()
    self.target_critic.eval()
    self.critic.train()
    self.actor.train()

    Q_target = self.target_critic.forward(
        new_state, self.target_actor.forward(new_state))
    Y = reward + (done * self.gamma * Q_target)
    Q = self.critic.forward(state, action)
    TD_errors = torch.sub(Y, Q).squeeze(dim=-1)

    critic_loss = F.huber_loss(TD_errors, torch.zeros_like(TD_errors))
    self.critic.optimizer.zero_grad()
    critic_loss.backward()
    self.critic.optimizer.step()

    # Compute & Update Actor losses
    actor_loss = torch.mean(-1.0 * self.critic.forward(state, self.actor(state)))
    self.actor.optimizer.zero_grad()
    actor_loss.backward()
    self.actor.optimizer.step()

    self._update_networks(self.tau)

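# In both optimize() variants the critic loss is written as Huber of the TD error
# against a zero tensor. Because the Huber loss depends only on the magnitude of
# the residual (input - target), this is numerically identical to applying it to
# Q and Y directly, as the self-contained check below illustrates (toy values):
import torch
import torch.nn.functional as F

y = torch.randn(64, 1)      # stand-in for the bootstrapped targets Y
q = torch.randn(64, 1)      # stand-in for the critic estimates Q
td = (y - q).squeeze(-1)

assert torch.allclose(F.huber_loss(td, torch.zeros_like(td)),
                      F.huber_loss(q.squeeze(-1), y.squeeze(-1)))
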
def _train(self, BATCH):
    q_dist = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
    q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2)  # [T, B, A, N] => [T, B, N]

    target_q_dist = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
    target_q = target_q_dist.mean(-1)  # [T, B, A, N] => [T, B, A]
    _a = target_q.argmax(-1)  # [T, B]
    next_max_action = F.one_hot(_a, self.a_dim).float().unsqueeze(-1)  # [T, B, A, 1]
    # [T, B, A, N] => [T, B, N]
    target_q_dist = (target_q_dist * next_max_action).sum(-2)

    target = n_step_return(BATCH.reward.repeat(1, 1, self.nums),
                           self.gamma,
                           BATCH.done.repeat(1, 1, self.nums),
                           target_q_dist,
                           BATCH.begin_mask.repeat(1, 1, self.nums)).detach()  # [T, B, N]

    q_eval = q_dist.mean(-1, keepdim=True)  # [T, B, 1]
    q_target = target.mean(-1, keepdim=True)  # [T, B, 1]
    td_error = q_target - q_eval  # [T, B, 1], used for PER

    target = target.unsqueeze(-2)  # [T, B, 1, N]
    q_dist = q_dist.unsqueeze(-1)  # [T, B, N, 1]
    # [T, B, 1, N] - [T, B, N, 1] => [T, B, N, N]
    quantile_error = target - q_dist
    huber = F.huber_loss(target, q_dist, reduction="none", delta=self.huber_delta)  # [T, B, N, N]
    # [N,] - [T, B, N, N] => [T, B, N, N]
    huber_abs = (self.quantiles - quantile_error.detach().le(0.).float()).abs()
    loss = (huber_abs * huber).mean(-1)  # [T, B, N, N] => [T, B, N]
    loss = loss.sum(-1, keepdim=True)  # [T, B, N] => [T, B, 1]
    loss = (loss * BATCH.get('isw', 1.0)).mean()  # 1

    self.oplr.optimize(loss)
    return td_error, {
        'LEARNING_RATE/lr': self.oplr.lr,
        'LOSS/loss': loss,
        'Statistics/q_max': q_eval.max(),
        'Statistics/q_min': q_eval.min(),
        'Statistics/q_mean': q_eval.mean()
    }

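# The weighting applied to the Huber term in _train above (and in the IQN-style
# variant that follows) is the quantile Huber loss: each pairwise error u is scaled
# by |tau - 1{u < 0}| so that over- and under-shooting a quantile are penalised
# asymmetrically. A self-contained sketch of that pattern on toy tensors, using the
# standard QR-DQN sign convention (implementations differ slightly in where the
# indicator and the quantile axis sit; names and shapes here are illustrative only):
import torch
import torch.nn.functional as F

T, B, N = 2, 3, 4                                                   # time, batch, quantile count
taus = torch.arange(1, N + 1, dtype=torch.float32) / (N + 1)        # fixed fractions tau_i, [N]

online = torch.randn(T, B, N, 1)                                    # online quantile estimates
target = torch.randn(T, B, 1, N)                                    # target quantile samples

u = target - online                                                 # pairwise errors, [T, B, N, N]
huber = F.huber_loss(target, online, reduction='none', delta=1.0)   # broadcasts to [T, B, N, N]
weight = (taus.view(1, 1, N, 1) - u.detach().lt(0.).float()).abs()  # |tau_i - 1{u < 0}|
loss = (weight * huber).mean(-1).sum(-1).mean()                     # mean over target quantiles,
                                                                    # sum over online quantiles, mean over T*B
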
def _train(self, BATCH):
    time_step = BATCH.reward.shape[0]
    batch_size = BATCH.reward.shape[1]

    quantiles, quantiles_tiled = self._generate_quantiles(  # [T*B, N, 1], [N*T*B, X]
        batch_size=time_step * batch_size,
        quantiles_num=self.online_quantiles)
    # [T*B, N, 1] => [T, B, N, 1]
    quantiles = quantiles.view(time_step, batch_size, -1, 1)
    quantiles_tiled = quantiles_tiled.view(
        time_step, -1, self.quantiles_idx)  # [N*T*B, X] => [T, N*B, X]

    quantiles_value = self.q_net(
        BATCH.obs, quantiles_tiled, begin_mask=BATCH.begin_mask)  # [T, N, B, A]
    # [T, N, B, A] => [N, T, B, A] * [T, B, A] => [N, T, B, 1]
    quantiles_value = (quantiles_value.swapaxes(0, 1) * BATCH.action).sum(-1, keepdim=True)
    q_eval = quantiles_value.mean(0)  # [N, T, B, 1] => [T, B, 1]

    _, select_quantiles_tiled = self._generate_quantiles(  # [N*T*B, X]
        batch_size=time_step * batch_size,
        quantiles_num=self.select_quantiles)
    select_quantiles_tiled = select_quantiles_tiled.view(
        time_step, -1, self.quantiles_idx)  # [N*T*B, X] => [T, N*B, X]
    q_values = self.q_net(
        BATCH.obs_, select_quantiles_tiled, begin_mask=BATCH.begin_mask)  # [T, N, B, A]
    q_values = q_values.mean(1)  # [T, N, B, A] => [T, B, A]
    next_max_action = q_values.argmax(-1)  # [T, B]
    next_max_action = F.one_hot(next_max_action, self.a_dim).float()  # [T, B, A]

    _, target_quantiles_tiled = self._generate_quantiles(  # [N'*T*B, X]
        batch_size=time_step * batch_size,
        quantiles_num=self.target_quantiles)
    target_quantiles_tiled = target_quantiles_tiled.view(
        time_step, -1, self.quantiles_idx)  # [N'*T*B, X] => [T, N'*B, X]
    target_quantiles_value = self.q_net.t(
        BATCH.obs_, target_quantiles_tiled, begin_mask=BATCH.begin_mask)  # [T, N', B, A]
    target_quantiles_value = target_quantiles_value.swapaxes(0, 1)  # [T, N', B, A] => [N', T, B, A]
    target_quantiles_value = (target_quantiles_value * next_max_action).sum(-1, keepdim=True)  # [N', T, B, 1]

    target_q = target_quantiles_value.mean(0)  # [T, B, 1]
    q_target = n_step_return(BATCH.reward,  # [T, B, 1]
                             self.gamma,
                             BATCH.done,  # [T, B, 1]
                             target_q,  # [T, B, 1]
                             BATCH.begin_mask).detach()  # [T, B, 1]
    td_error = q_target - q_eval  # [T, B, 1]

    # [N', T, B, 1] => [N', T, B]
    target_quantiles_value = target_quantiles_value.squeeze(-1)
    target_quantiles_value = target_quantiles_value.permute(1, 2, 0)  # [N', T, B] => [T, B, N']
    quantiles_value_target = n_step_return(BATCH.reward.repeat(1, 1, self.target_quantiles),
                                           self.gamma,
                                           BATCH.done.repeat(1, 1, self.target_quantiles),
                                           target_quantiles_value,
                                           BATCH.begin_mask.repeat(1, 1, self.target_quantiles)).detach()  # [T, B, N']
    # [T, B, N'] => [T, B, 1, N']
    quantiles_value_target = quantiles_value_target.unsqueeze(-2)
    quantiles_value_online = quantiles_value.permute(1, 2, 0, 3)  # [N, T, B, 1] => [T, B, N, 1]
    # [T, B, N, 1] - [T, B, 1, N'] => [T, B, N, N']
    quantile_error = quantiles_value_online - quantiles_value_target
    huber = F.huber_loss(quantiles_value_online, quantiles_value_target,
                         reduction="none", delta=self.huber_delta)  # [T, B, N, N']
    # [T, B, N, 1] - [T, B, N, N'] => [T, B, N, N']
    huber_abs = (quantiles - quantile_error.detach().le(0.).float()).abs()
    loss = (huber_abs * huber).mean(-1)  # [T, B, N, N'] => [T, B, N]
    loss = loss.sum(-1, keepdim=True)  # [T, B, N] => [T, B, 1]
    loss = (loss * BATCH.get('isw', 1.0)).mean()  # 1

    self.oplr.optimize(loss)
    return td_error, {
        'LEARNING_RATE/lr': self.oplr.lr,
        'LOSS/loss': loss,
        'Statistics/q_max': q_eval.max(),
        'Statistics/q_min': q_eval.min(),
        'Statistics/q_mean': q_eval.mean()
    }

def compute_loss(self):
    if len(self.memory) < 5:
        return None

    transitions = self.memory.sample(min(self.batch_size, len(self.memory)))
    batch = Transition(*zip(*transitions))

    state_batch = dict()
    for k in batch.state[0].keys():
        state_batch[k] = torch.stack([data[k] for data in batch.state])
    assert state_batch['image'].max() < 1.1
    for i, img in enumerate(state_batch['image']):
        img_trans = torch.as_tensor(
            self.transform(img.permute(1, 2, 0).numpy() * 255)).permute(2, 0, 1) / 255
        state_batch['image'][i] = img_trans
        # import cv2
        # cv2.imshow('old', img.numpy().transpose(1, 2, 0))
        # cv2.imshow('new', state_batch['image'][i].numpy().transpose(1, 2, 0))
        # cv2.waitKey(1000)

    action_batch = torch.cat(batch.action)
    reward_batch = torch.stack(batch.reward)
    next_states = batch.next_state

    # Compute Q(s_t, a): the model computes Q(s_t), then we select the columns of
    # the actions taken, i.e. the actions policy_net would have chosen for each batch state.
    state_action_values = self.policy_net(state_batch)
    Q_values = state_action_values[numpy.arange(len(state_action_values)), action_batch]

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state is the one after which the simulation ended).
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),
                                  dtype=torch.bool)
    non_final_next_states = [s for s in batch.next_state if s is not None]
    non_final_state = dict()
    if non_final_next_states:
        for k in non_final_next_states[0].keys():
            non_final_state[k] = torch.stack([data[k] for data in non_final_next_states])

    device = next(self.target_net.parameters()).device
    next_Q_values = torch.zeros(len(non_final_mask)).to(device)
    # max_a' Q(s', a') from the target network; final states keep a value of zero
    next_Q_values[non_final_mask] = self.target_net(non_final_state).max(1)[0].detach()
    # r + gamma * max_a' Q(s', a'; theta)
    expected_Q_values = (next_Q_values * self.gamma) + reward_batch.to(device)

    # Compute Huber loss
    # loss = F.smooth_l1_loss(Q_values, expected_Q_values, beta=101)
    # loss = F.mse_loss(Q_values, expected_Q_values)
    loss = F.huber_loss(Q_values, expected_Q_values, delta=10)
    if torch.isnan(loss):
        import pdb; pdb.set_trace()

    self.iteration += 1
    # Soft-update the target network: blend 40% policy-net weights with 60% target-net weights
    if self.iteration % self.target_update == 0:
        logging.debug('update target network')
        with torch.no_grad():
            for pol_param, target_param in zip(self.policy_net.parameters(),
                                               self.target_net.parameters()):
                mean = 0.4 * pol_param.detach().cpu().numpy() + 0.6 * target_param.detach().cpu().numpy()
                target_param[:] = torch.as_tensor(mean).to(pol_param)
    if self.iteration % self.save_interval == 0 and self.iteration:
        self.save_memory()

    return loss

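# The commented-out smooth_l1_loss alternative above is closely related to the
# huber_loss actually used: for the same threshold d, F.huber_loss(..., delta=d)
# equals d * F.smooth_l1_loss(..., beta=d) (a scaling relation noted in the PyTorch
# HuberLoss docs), so the two differ only by a constant factor and hence by the
# effective learning rate. A quick self-contained check with toy values:
import torch
import torch.nn.functional as F

q = torch.randn(128)
target = torch.randn(128) * 20      # spread residuals across both loss regions
d = 10.0

assert torch.allclose(F.huber_loss(q, target, delta=d),
                      d * F.smooth_l1_loss(q, target, beta=d))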