def forward(self, z, z_targ=None, context=None):
    """One stochastic transition z -> z_ = mean + ep * std.

    Returns (z_, log-density). When ``z_targ`` is None the density is
    evaluated at the freshly drawn sample ``z_``; otherwise it is evaluated
    at ``z_targ``. With ``self.doubler`` set, mean and std are detached so
    the density term carries no gradient through the proposal parameters.
    """
    # Standard-normal noise for the reparameterized sample.
    ep = Variable(torch.zeros(z.size()).normal_())
    if cuda:  # NOTE(review): module-level flag, not self.cuda — confirm intended
        ep = ep.cuda()
    if context is not None:
        h = self.actv(self.hidn(z, context))
    else:
        h = self.actv(self.hidn(z))
    if self.ifgate:
        # Gated residual mean: interpolate between proposed mean and input z.
        gate = torch.sigmoid(self.gate(h))
        mean = gate * (self.mean(h)) + (1 - gate) * z
    else:
        mean = self.mean(h)
    lstd = self.lstd(h)
    std = F.softplus(lstd)  # softplus keeps std strictly positive
    z_ = mean + ep * std
    if z_targ is None:
        if not self.doubler:
            # third argument of log_normal is the log-variance, i.e. 2*log(std)
            return z_, log_normal(z_, mean, torch.log(std) * 2).sum(1)
        else:
            return z_, log_normal(z_, mean.detach(), torch.log(std).detach() * 2).sum(1)
    else:
        return z_, log_normal(z_targ, mean, torch.log(std) * 2).sum(1)
def loss(self, x, weight=1.0, bits=0.0, breakdown=False):
    """Negative ELBO for a VAE with a flow-based inference network.

    ``weight`` scales the KL term; ``bits`` implements a free-bits floor —
    the (weighted) KL contribution is clamped below by ``bits``.
    Returns (-ELBO, -logpx, kl), or (-logpx, -logqz, -logpz) if
    ``breakdown`` is True.
    """
    n = x.size(0)
    zero = utils.varify(np.zeros(1).astype('float32'))
    context = self.enc(x)
    # Base noise fed to the flow, and accumulator for its log-determinant.
    ep = utils.varify(np.random.randn(n, self.dimz).astype('float32'))
    lgd = utils.varify(np.zeros(n).astype('float32'))
    if self.cuda:
        ep = ep.cuda()
        lgd = lgd.cuda()
        zero = zero.cuda()
    z, logdet, _ = self.inf((ep, lgd, context))
    pi = nn_.sigmoid(self.dec(z))
    # Bernoulli likelihood, summed over the three non-batch dims (C, H, W).
    logpx = -utils.bceloss(pi, x).sum(1).sum(1).sum(1)
    # q(z|x) via change of variables: base density minus log|det Jacobian|.
    logqz = utils.log_normal(ep, zero, zero).sum(1) - logdet
    logpz = utils.log_normal(z, zero, zero).sum(1)
    kl = logqz - logpz
    if breakdown:
        return -logpx, -logqz, -logpz
    else:
        # Free-bits: the weighted KL never contributes less than `bits`.
        return (-(logpx - torch.max(kl * weight, torch.ones_like(kl) * bits)), -logpx, kl)
def evaluate(self, z_, z0, context=None):
    """Evaluate log r(z0 | z_, context) under k components.

    ``z_`` is flattened to (n * k, dim2); the network predicts a
    per-component mean and log-variance reshaped to (n, k, dim1).
    Returns ``logr`` of shape (n, k).
    """
    n = z_.size(0)
    if context is None:
        context = torch.ones(n, self.dimc)
        if cuda:  # NOTE(review): module-level flag, not self.cuda — confirm intended
            context = context.cuda()
    if self.wtype == 'ar':
        # One distinct context embedding per component.
        context = self.function_emb(context).view(n, self.k, self.dimc)
    else:
        # elif self.wtype == 'bh' or self.wtype == 'pi':
        # Shared embedding broadcast across the k components.
        context = self.function_emb(context).view(n, 1, self.dimc)
        context = context.repeat((1, self.k, 1))
    context = context.view(n * self.k, self.dimc)
    z_ = z_.view(n * self.k, self.dim2)
    # Skip connection around the hidden layer into the mean head.
    skip = self.skip((z_, context))[0]
    h = self.actv(self.hidn((z_, context))[0])
    mean = self.mean((h, context))[0] + skip
    lvar = self.lvar((h, context))[0]
    mean = mean.view(n, self.k, self.dim1)
    lvar = lvar.view(n, self.k, self.dim1)
    # log_normal takes log-variance directly here; sum over the feature dim.
    logr = log_normal(z0, mean, lvar).sum(2)
    return logr
def density(self, spl, lgd=None, context=None, zeros=None):
    """Return the log-density of ``spl`` under the flow model.

    Any of ``lgd``/``context``/``zeros`` left as None falls back to the
    corresponding attribute cached on the instance.
    """
    if lgd is None:
        lgd = self.lgd
    if context is None:
        context = self.context
    if zeros is None:
        zeros = self.zeros
    z, logdet, _ = self.mdl((spl, lgd, context))
    # Change of variables: base log-density plus the flow's log-det term.
    base_logp = utils.log_normal(z, zeros, zeros + 1.0).sum(1)
    return base_logp + logdet
def density(self, spl):
    """Return the per-sample log-density of ``spl`` under the flow."""
    batch = spl.size(0)
    context = Variable(torch.FloatTensor(batch, 1).zero_())
    lgd = Variable(torch.FloatTensor(batch).zero_())
    zeros = Variable(torch.FloatTensor(batch, self.p).zero_())
    if self.cuda:
        context = context.cuda()
        lgd = lgd.cuda()
        zeros = zeros.cuda()
    z, logdet, _ = self.flow((spl, lgd, context))
    # Base standard-normal log-density plus the flow's log-det correction.
    base_logp = utils.log_normal(z, zeros, zeros + 1.0).sum(1)
    return base_logp + logdet
def loss(self, x):
    """Compute IWAE / hierarchical-IWAE terms for a batch x.

    Returns (logpx, logpz, logq, logr), each shaped (n, niw) except that
    logr is the integer 0 in plain 'iwae' mode (no reverse model).
    """
    n = x.size(0)
    zero = utils.varify(np.zeros(1).astype('float32'))
    if cuda:  # NOTE(review): module-level flag, not self.cuda — confirm intended
        zero = zero.cuda()
    context = self.enc(x)
    if self.mode == 'iwae':
        # Flat proposal: replicate the context niw times per example.
        context = context.repeat(1, self.niw).view(n * self.niw, self.dimc)
        z, logq = self.qnet.sample(context, n * self.niw)
        logq = logq.view(n, self.niw)
        logr = 0
    elif self.mode == 'hiwae':
        if self.dep == 0:
            # Independent draws: take the j-th component of the j-th call
            # so the niw samples come from niw independent z0 draws.
            z0 = list()
            z = list()
            logq = list()
            for j in range(self.niw):
                z0_, z_, logq_ = self.qnet.sample(context)
                # z0_: batch_size x dimz
                # z_: batch_size x niw x dimz
                # logq_: batch_size x niw
                z0.append(z0_.unsqueeze(1))
                z.append(z_[:, j:j+1])
                logq.append(logq_[:, j:j+1])
            z0 = torch.cat(z0, 1)
            z = torch.cat(z, 1)
            logq = torch.cat(logq, 1)
            logr = self.rnet.evaluate(z, z0, context)
        elif self.dep == 1:
            # All niw samples share a single z0 draw.
            z0, z, logq = self.qnet.sample(context)
            logr = self.rnet.evaluate(z, z0.unsqueeze(1), context)
        elif self.dep == 2:
            """ iwae with hierarchical q; baseline """
            context = context.repeat(1, self.niw).view(n * self.niw, self.dimc)
            z0, z, logq = self.qnet.sample(context)
            logr = self.rnet.evaluate(z, z0.unsqueeze(1), context)
            logq = logq.view(n, self.niw)
            # assumes logr here has n*niw elements total — TODO confirm
            logr = logr.view(n, self.niw)
    # Flatten importance samples into the batch dim for decoding.
    z = z.view(n * self.niw, self.dimz)
    pi = nn_.sigmoid(self.dec(z))
    pi = pi.view(n, self.niw, *x.size()[1:])
    # Bernoulli log-likelihood per importance sample.
    logpx = - utils.bceloss(pi, x.unsqueeze(1)).sum(2).sum(2).sum(2)
    logpz = utils.log_normal(z, zero, zero).sum(1).view(n, self.niw)
    return logpx, logpz, logq, logr
def sample(self, context=None, n=None):
    """Draw ``n`` reparameterized samples z ~ N(mean(context), std(context)^2).

    Returns (z, logq0) where logq0 is the per-sample log-density of z under
    the proposal. Note: ``n`` must be supplied; it is not inferred from
    ``context``.
    """
    # Fix: removed the dead local `zero` the original allocated (and moved
    # to GPU) without ever using.
    ep0 = Variable(torch.zeros(n, self.dim).normal_())
    if cuda:  # NOTE(review): module-level flag, not self.cuda — confirm intended
        ep0 = ep0.cuda()
    mean = self.mean(context)
    lstd = self.lstd(context)
    std = self.realify(lstd)  # map the raw output to a positive std
    z = mean + std * ep0
    # third argument of log_normal is the log-variance, i.e. 2*log(std)
    logq0 = log_normal(z, mean, torch.log(std) * 2).sum(1)
    return z, logq0
def sample(self, context=None, n=None):
    """Sample z0 then k conditional samples z_, with importance weights.

    Returns (z0, z_, logq) where z_ has shape (n, k, dim2) and logq
    (shape (n, k)) is the log proposal density adjusted according to
    ``self.wtype`` ('ar'/'pi', 'bh', or 'l<power>').
    """
    if context is None:
        assert n is not None, 'context and n cannot both be None'
        context = torch.ones(n, self.dimc)
    n = context.size(0)
    # Noise for all k conditional samples at once.
    ep = Variable(torch.zeros(n, self.dim2 * self.k).normal_())
    if cuda:  # NOTE(review): module-level flag, not self.cuda — confirm intended
        ep = ep.cuda()
        context = context.cuda()
    z0, logq0 = self.z0.sample(context, n)
    logq0 = logq0.unsqueeze(1)  # broadcast over the k samples
    # Skip connection around the hidden layer into the mean head.
    skip = self.skip((z0, context))[0]
    h = self.actv(self.hidn((z0, context))[0])
    mean = self.mean((h, context))[0] + skip
    lstd = self.lstd((h, context))[0]
    std = F.softplus(lstd)  # softplus keeps std strictly positive
    ep = ep.view(n, self.k, self.dim2)
    mean = mean.view(n, self.k, self.dim2)
    lstd = lstd.view(n, self.k, self.dim2)
    std = std.view(n, self.k, self.dim2)
    z_ = mean + ep * std
    if self.wtype == 'ar' or self.wtype == 'pi':
        # Each sample scored only under its own component; with doubler the
        # density carries no gradient through the proposal parameters.
        if not self.doubler:
            logq = logq0 + log_normal(z_, mean, torch.log(std) * 2).sum(2)
        else:
            logq = logq0 + log_normal(z_, mean.detach(), torch.log(std).detach() * 2).sum(2)
    elif self.wtype == 'bh':
        # Each sample scored under the mixture of all k components
        # (pairwise densities via broadcasting, then log-mean-exp over them).
        if not self.doubler:
            logq = logq0 + log_mean_exp(
                log_normal(z_.unsqueeze(2), mean.unsqueeze(1),
                           torch.log(std).unsqueeze(1) * 2).sum(3), 2)[:, :, 0]
        else:
            logq = logq0 + log_mean_exp(
                log_normal(z_.unsqueeze(2), mean.detach().unsqueeze(1),
                           torch.log(std).unsqueeze(1).detach() * 2).sum(3), 2)[:, :, 0]
    elif self.wtype[0] == 'l':
        p = float(self.wtype[1:])  # power
        # Power-tempered correction: subtract the normalized p-powered own
        # density relative to all components.
        if not self.doubler:
            logq = logq0 + log_normal(z_, mean, torch.log(std) * 2).sum(2)
            den = log_sum_exp(
                log_normal(z_.unsqueeze(2), mean.unsqueeze(1),
                           torch.log(std).unsqueeze(1) * 2).sum(3) * p, 2)[:, :, 0]
            nom = log_normal(z_, mean, torch.log(std) * 2).sum(2) * p
            logq = logq - (nom - den)
        else:
            logq = logq0 + log_normal(z_, mean.detach(), torch.log(std).detach() * 2).sum(2)
            den = log_sum_exp(
                log_normal(z_.unsqueeze(2), mean.detach().unsqueeze(1),
                           torch.log(std).detach().unsqueeze(1) * 2).sum(3) * p, 2)[:, :, 0]
            nom = log_normal(z_, mean.detach(), torch.log(std).detach() * 2).sum(2) * p
            logq = logq - (nom - den)
    # Uniform mixture weight 1/k over the components.
    logq = logq - np.log(self.k)
    return z0, z_, logq
def energy1(f):
    """Unnormalized log-likelihood (energy) of frequency ``f`` for a sinusoid fit.

    Computes mu = a0 * sin(2*pi*f*x0 + b0) and scores it against y0 with a
    squared-error energy, equivalent to a Gaussian with variance 0.25.
    Relies on module-level data x0, y0 and parameters a0, b0.
    """
    mu = torch.mul(torch.sin(x0.permute(1, 0) * 2.0 * np.pi * f + b0), a0)
    # Negative squared error scaled by 1/0.25 (Gaussian precision 4).
    return -((mu - y0.permute(1, 0)) ** 2 * (1 / 0.25)).sum(1)
    # Fix: the original continued with an unreachable
    # `ll = utils.log_normal(y0.permute(1, 0), mu, zero).sum(1); return ll`
    # after this return; removed as dead code.