class ELBOLoss(nn.Module):
    """Loss wrapper combining a ``kld`` term and an ``nll`` term.

    Each call computes ``Losses.kld(z)`` and ``Losses.nll(x, xx)`` and records
    both values, plus their sum logged under ``'elbo'``, in a per-step logger.
    The semantics of the two terms live in ``Losses``. Presumably they are the
    KL divergence of the latent code and the reconstruction negative
    log-likelihood; TODO confirm.

    If an optimizer wrapper ``opt`` is supplied, the sum is back-propagated and
    one optimizer step is taken. Otherwise the two terms are returned to the
    caller.
    """

    params = ['elbo', 'kld', 'nll']

    def __init__(self, opt=None):
        """Create the loss.

        Args:
            opt: optional optimizer wrapper. This class uses its ``step()``,
                ``optimizer.zero_grad()``, ``evolve()`` and ``log`` members.
                When ``None``, no optimization is performed.
        """
        super().__init__()
        # Long-term log: one aggregated entry per ``evolve()`` call.
        self.log = Logger(*ELBOLoss.params)
        # Short-term log: one entry per ``__call__``; averaged and cleared by ``evolve()``.
        self._log = Logger(*ELBOLoss.params)
        self.kld = None
        self.nll = None
        self.opt = opt

    def __call__(self, x, z, xx):
        """Compute both loss terms for one batch and log them.

        Args:
            x: first argument to ``Losses.nll`` (presumably the target batch; TODO confirm).
            z: argument to ``Losses.kld`` (presumably the latent/encoder output; TODO confirm).
            xx: second argument to ``Losses.nll`` (presumably the reconstruction).

        Returns:
            ``(kld, nll)`` when no optimizer wrapper is configured. Otherwise
            ``None``, after ``kld + nll`` has been back-propagated and one
            optimizer step taken.
        """
        self.kld = Losses.kld(z)
        self.nll = Losses.nll(x, xx)
        self._log.append(self.logger())
        if self.opt is None:
            return self.kld, self.nll
        (self.kld + self.nll).backward()
        self.opt.step()
        # Gradients are cleared after the step so the next call starts clean.
        self.opt.optimizer.zero_grad()
        return None

    def evolve(self):
        """Fold the per-step log into the long-term log and advance the optimizer.

        Appends the mean of the short-term log to ``self.log`` and clears the
        short-term log. It then forwards ``evolve()`` to the optimizer wrapper,
        if one exists. Presumably this is called once per epoch; confirm
        against the caller.
        """
        self.log.append(self._log.mean())
        self._log.clear()
        if self.opt is not None:
            self.opt.evolve()

    def logger(self):
        """Return the current step's values as ``{'elbo': kld + nll, 'kld': kld, 'nll': nll}``."""
        kld, nll = self.kld.item(), self.nll.item()
        return dict(zip(ELBOLoss.params, (kld + nll, kld, nll)))

    def print_summary(self):
        """Print the long-term log, plus the optimizer's log when an optimizer is configured."""
        # Without this guard, ``self.opt.log`` raised AttributeError for opt=None.
        logs = [self.log] if self.opt is None else [self.log, self.opt.log]
        print(3 * " " + print_format(logs, log10=True))

    def get_logger(self, var):
        """Return the logger that records ``var``.

        Raises:
            ValueError: if ``var`` is neither one of ``ELBOLoss.params`` nor,
                when an optimizer is configured, one of ``OptimModule.params``.
        """
        if var in ELBOLoss.params:
            return self.log
        # Only consult optimizer parameters when an optimizer exists. Otherwise
        # ``self.opt.log`` would raise AttributeError.
        if self.opt is not None and var in OptimModule.params:
            return self.opt.log
        raise ValueError("invalid parameter %s" % var)
class InfoLoss(nn.Module):
    """MMD-based loss wrapper with an ``mmd`` term and an ``nll`` term.

    The ``mmd`` term is the sum of two parts:

    * ``beta`` times ``Losses.cmmd`` between ``z`` and a ``rand_norm(0, 1, ...)``
      sample of matching leading shape;
    * ``gamma`` times three conditional ``Losses.cmmd`` terms between ``zz`` and
      ``z``, one per (endo, exo) pair of ``chain`` entries: (0, 1), (1, 2), (0, 2).

    The ``nll`` term is ``Losses.nll(x, xx)``. Values are logged per step, with
    ``mmd + nll`` logged under ``'elbo'``. If an optimizer wrapper is supplied,
    the sum is back-propagated and one optimizer step is taken. Otherwise the
    two terms are returned.
    """

    params = ['elbo', 'mmd', 'nll']

    # (endo, exo) index pairs into ``chain``. The order here is the order in
    # which the conditional-MMD terms are summed.
    _CHAIN_PAIRS = ((0, 1), (1, 2), (0, 2))

    def __init__(self, mask, chain, opt, beta=1.0, gamma=500.0, reg=0.2):
        """Create the loss.

        Args:
            mask: stored as ``self.mask``.
                NOTE(review): not read anywhere in this class; confirm whether
                callers rely on it.
            chain: sequence of (at least) three entries. Single-element slices
                of it are passed to ``Losses.cmmd`` as ``endo``/``exo``.
            opt: optimizer wrapper, or ``None`` to disable optimization. This
                class uses its ``step()``, ``optimizer.zero_grad()``,
                ``evolve()`` and ``log`` members.
            beta: weight of the MMD term against the ``rand_norm`` sample.
            gamma: weight of the summed conditional-MMD chain terms.
            reg: forwarded to ``Losses.cmmd`` as ``l``.
        """
        super().__init__()
        # Long-term log: one aggregated entry per ``evolve()`` call.
        self.log = Logger(*InfoLoss.params)
        # Short-term log: one entry per ``__call__``; averaged and cleared by ``evolve()``.
        self._log = Logger(*InfoLoss.params)
        self.mask = mask
        self.chain = tuple(chain)
        self.opt = opt
        self.beta = beta
        self.gamma = gamma
        self.reg = reg
        self.mmd = None
        self.nll = None

    def __call__(self, x, z, zz, xx):
        """Compute both loss terms for one batch and log them.

        Args:
            x: first argument to ``Losses.nll`` (presumably the target batch; TODO confirm).
            z: compared against the ``rand_norm`` sample and against ``zz``.
                At least 2-D: ``z.shape[0]`` and ``z.shape[1]`` are read.
            zz: compared against ``z`` by the conditional-MMD chain terms.
            xx: second argument to ``Losses.nll`` (presumably the reconstruction).

        Returns:
            ``(mmd, nll)`` when no optimizer wrapper is configured. Otherwise
            ``None``, after ``mmd + nll`` has been back-propagated and one
            optimizer step taken.
        """
        # NOTE(review): the reference sample is cast to float64 and given a
        # trailing singleton axis. This presumably matches z's dtype/layout; confirm.
        std_norm = rand_norm(0.0, 1.0, z.shape[0], z.shape[1]).unsqueeze(-1).double()
        prior_term = self.beta * Losses.cmmd(z, std_norm)
        chain_term = sum(
            Losses.cmmd(
                zz, z,
                endo=self.chain[i:i + 1],
                exo=self.chain[j:j + 1],
                l=self.reg)
            for i, j in InfoLoss._CHAIN_PAIRS)
        # Built out-of-place rather than via ``+=`` on an autograd tensor.
        self.mmd = prior_term + self.gamma * chain_term
        self.nll = Losses.nll(x, xx)
        self._log.append(self.logger())
        if self.opt is None:
            return self.mmd, self.nll
        (self.mmd + self.nll).backward()
        self.opt.step()
        # Gradients are cleared after the step so the next call starts clean.
        self.opt.optimizer.zero_grad()
        return None

    def evolve(self):
        """Fold the per-step log into the long-term log and advance the optimizer.

        Appends the mean of the short-term log to ``self.log`` and clears the
        short-term log. It then forwards ``evolve()`` to the optimizer wrapper,
        if one exists. Presumably this is called once per epoch; confirm
        against the caller.
        """
        self.log.append(self._log.mean())
        self._log.clear()
        if self.opt is not None:
            self.opt.evolve()

    def logger(self):
        """Return the current step's values as ``{'elbo': mmd + nll, 'mmd': mmd, 'nll': nll}``."""
        mmd, nll = self.mmd.item(), self.nll.item()
        return dict(zip(InfoLoss.params, (mmd + nll, mmd, nll)))

    def print_summary(self):
        """Print the long-term log, plus the optimizer's log when an optimizer is configured."""
        # Without this guard, ``self.opt.log`` raised AttributeError for opt=None.
        logs = [self.log] if self.opt is None else [self.log, self.opt.log]
        print(3 * " " + print_format(logs, log10=True))

    def get_logger(self, var):
        """Return the logger that records ``var``.

        Raises:
            ValueError: if ``var`` is neither one of ``InfoLoss.params`` nor,
                when an optimizer is configured, one of ``OptimModule.params``.
        """
        if var in InfoLoss.params:
            return self.log
        # Only consult optimizer parameters when an optimizer exists. Otherwise
        # ``self.opt.log`` would raise AttributeError.
        if self.opt is not None and var in OptimModule.params:
            return self.opt.log
        raise ValueError("invalid parameter %s" % var)