def sequence_ce_lm_loss(
    logits: torch.FloatTensor,
    lm_logits: torch.FloatTensor,
    targets: torch.LongTensor,
    mask: torch.FloatTensor,
    kl_coef: float,
):
    """Sequence Cross Entropy with Language Model KL"""
    # shape : (batch, sequence_length, num_classes)
    log_probs = torch.log_softmax(logits, dim=-1)
    lm_probs = torch.softmax(lm_logits, dim=-1)

    # shape : (batch, sequence_length)
    negative_log_likelihood = -torch.gather(
        log_probs, dim=2, index=targets.unsqueeze(2)
    ).squeeze(2)

    # ignores the mask and is normalized by sequence length
    # (reduction=2 is the legacy integer code for 'sum')
    lm_kl = (
        torch.kl_div(input=log_probs, target=lm_probs, reduction=2)
        / log_probs.shape[1]
    )

    # shape : (batch, sequence_length)
    loss = negative_log_likelihood * mask + kl_coef * lm_kl
    loss = loss.sum(1) / (mask.sum(1) + 1e-5)
    loss = loss.mean()

    return loss, lm_kl
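
# A minimal usage sketch for sequence_ce_lm_loss (not from the original source):
# the shapes (batch=2, seq_len=4, vocab=8) and kl_coef=0.1 are illustrative only,
# and it assumes the function above is already in scope.
import torch

batch, seq_len, vocab = 2, 4, 8
logits = torch.randn(batch, seq_len, vocab)
lm_logits = torch.randn(batch, seq_len, vocab)
targets = torch.randint(vocab, (batch, seq_len))
mask = torch.ones(batch, seq_len)

loss, lm_kl = sequence_ce_lm_loss(logits, lm_logits, targets, mask, kl_coef=0.1)
print(loss.item(), lm_kl.item())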
def calculate_inception_score(p_yx):
    # marginal class distribution p(y), shape (1, num_classes)
    p_y = torch.unsqueeze(p_yx.mean(axis=0), 0)
    # element-wise KL terms p(y|x) * (log p(y|x) - log p(y));
    # reduction=0 ('none') keeps per-element terms instead of reducing to a scalar
    kl_d = torch.kl_div(torch.log(p_y), p_yx, reduction=0)
    # KL(p(y|x) || p(y)) per sample, averaged over samples
    sum_kl_d = kl_d.sum(axis=1)
    avg_kl_d = torch.mean(sum_kl_d)
    # IS = exp(E_x[KL(p(y|x) || p(y))])
    is_score = torch.exp(avg_kl_d)
    return is_score
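
# A hypothetical usage sketch (not from the original source): p_yx stands in for
# a classifier's softmax outputs over, say, 100 generated samples and 10 classes.
# The printed value is only a smoke test, since the random predictions carry no
# real class structure.
import torch

p_yx = torch.softmax(torch.randn(100, 10), dim=1)
print(calculate_inception_score(p_yx).item())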
def test_backwards(self):
    a = torch.zeros(5, 5, device='msnpu', requires_grad=True)
    self.assertEqual(msnpu_extension.get_test_int(), 0)

    b = torch.zeros(5, 5, device='msnpu')
    self.assertEqual(msnpu_extension.get_test_int(), 0)

    c = torch.kl_div(a, b)
    self.assertEqual(msnpu_extension.get_test_int(), 3)

    d = c.sum()
    self.assertEqual(msnpu_extension.get_test_int(), 2)

    d.backward()
    self.assertEqual(msnpu_extension.get_test_int(), 4)
def train(self, T_max, graph_name=None):
    step = 0
    self.num_lookahead = 5
    self.reset_workers()
    self.wait_for_workers()

    stat = {
        'ploss': [],
        'vloss': [],
        'score': [],
        'int_reward': [],
        'entropy': [],
        'fwd_kl_div': [],
        'running_loss': 0
    }

    reward_tracker = RunningMeanStd()
    reward_buffer = np.empty((self.batch_size, self.num_lookahead), dtype=np.float32)

    while step < T_max:
        # these will keep tensors, which we'll use later for backpropagation
        values = []
        log_probs = []
        rewards = []
        entropies = []
        actions = []
        actions_pred = []
        features = []
        features_pred = []

        state = torch.from_numpy(self.sh_state).to(self.device)

        for i in range(self.num_lookahead):
            step += self.batch_size

            logit, value = self.model(state)
            prob = torch.softmax(logit, dim=1)
            log_prob = torch.log_softmax(logit, dim=1)
            entropy = -(prob * log_prob).sum(1, keepdim=True)

            action = prob.multinomial(1)
            sampled_lp = log_prob.gather(1, action)

            # one-hot action
            oh_action = torch.zeros(
                self.batch_size, self.num_actions, device=self.device
            ).scatter_(1, action, 1)

            self.broadcast_actions(action)
            self.wait_for_workers()

            next_state = torch.from_numpy(self.sh_state).to(self.device)
            s1, s1_pred, action_pred = self.icm(state, oh_action, next_state)

            with torch.no_grad():
                int_reward = 0.5 * (s1 - s1_pred).pow(2).sum(dim=1, keepdim=True)
                reward_buffer[:, i] = int_reward.cpu().numpy().ravel()

            state = next_state

            # save variables for gradient descent
            values.append(value)
            log_probs.append(sampled_lp)
            rewards.append(int_reward)
            entropies.append(entropy)
            if not self.random:
                actions.append(action.flatten())
                actions_pred.append(action_pred)
            features.append(s1)
            features_pred.append(s1_pred)

            stat['entropy'].append(entropy.sum(dim=1).mean().item())
            stat['fwd_kl_div'].append(torch.kl_div(s1_pred, s1).mean().item())

        # may have to update reward_buffer with gamma first
        reward_mean, reward_std, count = mpi_moments(reward_buffer.ravel())
        reward_tracker.update_from_moments(reward_mean, reward_std ** 2, count)
        std = np.sqrt(reward_tracker.var)
        rewards = [rwd / std for rwd in rewards]

        for rwd in rewards:
            stat['int_reward'].append(rwd.mean().item())

        state = torch.from_numpy(self.sh_state.astype(np.float32)).to(self.device)
        with torch.no_grad():
            _, R = self.model(state)  # R is the estimated return
        values.append(R)

        ploss = 0
        vloss = 0
        fwd_loss = 0
        inv_loss = 0

        delta = torch.zeros((self.batch_size, 1), dtype=torch.float, device=self.device)

        for i in reversed(range(self.num_lookahead)):
            R = rewards[i] + self.gamma * R
            advantage = R - values[i]
            vloss += (0.5 * advantage.pow(2)).mean()

            delta = rewards[i] + self.gamma * values[i + 1].detach() - values[i].detach()
            ploss += -(log_probs[i] * delta + 0.01 * entropies[i]).mean()  # beta = 0.01

            fwd_loss += 0.5 * (features[i] - features_pred[i]).pow(2).sum(dim=1).mean()
            if not self.random:
                inv_loss += self.cross_entropy(actions_pred[i], actions[i])

        self.optim.zero_grad()

        # inv_loss is 0 if using random features
        # 2018 Large scale curiosity paper simply sums them (no lambda and beta anymore)
        loss = ploss + vloss + fwd_loss + inv_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            list(self.model.parameters()) + list(self.icm.parameters()), 40)
        self.optim.step()

        while not self.channel.empty():
            score = self.channel.get()
            stat['score'].append(score)

        stat['ploss'].append(ploss.item() / self.num_lookahead)
        stat['vloss'].append(vloss.item() / self.num_lookahead)
        stat['running_loss'] = (0.99 * stat['running_loss']
                                + 0.01 * loss.item() / self.num_lookahead)

        if len(stat['score']) > 20 and step % (self.batch_size * 1000) == 0:
            now = datetime.datetime.now().strftime("%H:%M")
            print(
                f"Step {step: <10} | Running loss: {stat['running_loss']:.4f} | "
                f"Running score: {np.mean(stat['score'][-10:]):.2f} | Time: {now}"
            )

        if graph_name is not None and step % (self.batch_size * 10000) == 0:
            plot(step, stat['score'], stat['int_reward'], stat['ploss'],
                 stat['vloss'], stat['entropy'], name=graph_name)
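
# A small sanity sketch (not from the original source) of what torch.kl_div
# computes in the snippets above, with reduction=0 ('none') and the default
# log_target=False: the element-wise term target * (log(target) - input),
# where `input` is expected to hold log-probabilities.
import torch

p = torch.softmax(torch.randn(3, 5), dim=1)           # target distribution
log_q = torch.log_softmax(torch.randn(3, 5), dim=1)   # input log-probabilities

elementwise = torch.kl_div(log_q, p, reduction=0)
manual = p * (p.log() - log_q)
print(torch.allclose(elementwise, manual))  # expected: True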