def get_loss(self, examples, step):
    # Anneal the word-dropout (UNK replacement) rate for the current step.
    self.step_unk_rate = wd_anneal_function(
        unk_max=self.unk_rate,
        anneal_function=self.args.unk_schedule,
        step=step,
        x0=self.args.x0,
        k=self.args.k,
    )
    explore = self.forward(examples)

    # KL divergence between the posterior and the prior, with its annealed weight.
    kl_loss, kl_weight = self.compute_kl_loss(explore['mean'], explore['logv'], step)
    kl_weight *= self.args.kl_factor

    # Normalize both terms by batch size.
    nll_loss = torch.sum(explore['nll_loss']) / explore['batch_size']
    kl_loss = kl_loss / explore['batch_size']

    kl_item = kl_loss * kl_weight
    loss = kl_item + nll_loss
    return {
        'KL Loss': kl_loss,
        'NLL Loss': nll_loss,
        'Model Score': nll_loss + kl_loss,
        'Loss': loss,
        'ELBO': loss,
        'KL Weight': kl_weight,
        'WD Drop': self.step_unk_rate,
        'KL Item': kl_item,
    }
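# --- Illustrative sketch only; `compute_kl_loss_sketch` is a hypothetical stand-in,
# not the repo's actual `compute_kl_loss` method. It assumes a diagonal-Gaussian
# posterior N(mean, exp(logv)), a standard-normal prior, and a logistic KL-weight
# schedule (as in Bowman et al.'s sentence VAE), which matches how the method is
# called above (returning a KL term plus an annealed weight).
import math

import torch


def compute_kl_loss_sketch(mean, logv, step, x0=2500, k=0.0025):
    # Analytical KL( N(mean, sigma^2) || N(0, I) ) = -0.5 * sum(1 + log sigma^2 - mean^2 - sigma^2)
    kl_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
    # Logistic annealing: the weight rises from ~0 toward 1 as `step` passes `x0`.
    kl_weight = 1.0 / (1.0 + math.exp(-k * (step - x0)))
    return kl_loss, kl_weight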
def get_loss(self, examples, train_iter, is_dis=False, **kwargs):
    # Anneal the word-dropout rate for the current training iteration.
    self.step_unk_rate = wd_anneal_function(
        unk_max=self.unk_rate,
        anneal_function=self.args.unk_schedule,
        step=train_iter,
        x0=self.args.x0,
        k=self.args.k,
    )
    explore = self.forward(examples, is_dis, norm_by_word=False)
    # When only the discriminators are being trained, return their losses directly.
    if is_dis:
        return explore

    # Separate KL terms for the semantic and syntactic latent variables.
    sem_kl, kl_weight = self.compute_kl_loss(
        mean=explore['sem_mean'],
        logv=explore['sem_logv'],
        step=train_iter,
    )
    syn_kl, _ = self.compute_kl_loss(
        mean=explore['syn_mean'],
        logv=explore['syn_logv'],
        step=train_iter,
    )

    batch_size = explore['batch_size']
    kl_weight *= self.args.kl_factor
    # Weighted average of the two KL terms, then normalize everything by batch size.
    kl_loss = (self.args.sem_weight * sem_kl + self.args.syn_weight * syn_kl) / (
        self.args.sem_weight + self.args.syn_weight)
    kl_loss /= batch_size

    mul_loss = explore['mul'] / batch_size
    adv_loss = explore['adv'] / batch_size
    nll_loss = explore['nll_loss'] / batch_size
    kl_item = kl_loss * kl_weight

    args = self.args
    return {
        'KL Loss': kl_loss,
        'NLL Loss': nll_loss,
        'MUL Loss': mul_loss,
        'ADV Loss': adv_loss,
        'KL Weight': kl_weight,
        'KL Item': kl_item,
        'Model Score': kl_loss + nll_loss,
        'ELBO': kl_item + nll_loss,
        # The adversarial term is only subtracted after the warm-up period.
        'Loss': (kl_item + nll_loss + mul_loss - adv_loss
                 if train_iter > args.warm_up
                 else kl_item + nll_loss + mul_loss),
        'SYN KL Loss': syn_kl / batch_size,
        'SEM KL Loss': sem_kl / batch_size,
    }
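# --- Illustrative sketch only; `wd_anneal_function_sketch` is a hypothetical stand-in
# for the `wd_anneal_function` helper used by both get_loss methods. The assumption,
# inferred from its arguments (unk_max, anneal_function, step, x0, k), is that it
# ramps the word-dropout rate from 0 up to `unk_max` on a logistic or linear schedule.
import math


def wd_anneal_function_sketch(unk_max, anneal_function, step, x0, k):
    if anneal_function == 'logistic':
        # Smoothly approach unk_max as `step` passes the midpoint `x0`.
        return unk_max / (1.0 + math.exp(-k * (step - x0)))
    if anneal_function == 'linear':
        # Increase linearly and saturate at unk_max once step reaches x0.
        return unk_max * min(1.0, step / x0)
    # Fallback: constant dropout rate.
    return unk_max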