Example #1
0
    def forward(self, z, z_targ=None, context=None):
        """Draw a reparameterized Gaussian sample from ``z`` and score it.

        A hidden representation ``h = self.actv(self.hidn(...))`` is computed
        from ``z`` (plus ``context`` when one is given).  The Gaussian location
        is ``self.mean(h)``; when ``self.ifgate`` is set it is instead a
        sigmoid-gated blend of ``self.mean(h)`` and ``z``.  The scale is
        ``softplus(self.lstd(h))`` and the sample is ``loc + eps * std``.

        Returns:
            ``(z_, logq)``: the fresh sample and ``log_normal(...).sum(1)``
            evaluated at ``z_targ`` when provided, otherwise at ``z_``.  When
            no target is given and ``self.doubler`` is set, the location and
            log-variance are detached inside the density term.
        """
        eps = Variable(torch.zeros(z.size()).normal_())
        if cuda:
            eps = eps.cuda()

        pre_act = self.hidn(z) if context is None else self.hidn(z, context)
        h = self.actv(pre_act)

        if self.ifgate:
            # The gate is evaluated before the mean network, as in the original order.
            g = torch.sigmoid(self.gate(h))
            loc = g * self.mean(h) + (1 - g) * z
        else:
            loc = self.mean(h)
        std = F.softplus(self.lstd(h))
        z_ = loc + eps * std

        if z_targ is not None:
            point, m, lv = z_targ, loc, torch.log(std) * 2
        elif self.doubler:
            point, m, lv = z_, loc.detach(), torch.log(std).detach() * 2
        else:
            point, m, lv = z_, loc, torch.log(std) * 2
        return z_, log_normal(point, m, lv).sum(1)
Example #2
0
    def loss(self, x, weight=1.0, bits=0.0, breakdown=False):
        """Return negative-ELBO terms for batch ``x`` under a flow posterior.

        Standard-normal noise of shape ``(n, self.dimz)`` (drawn with
        ``np.random.randn``) is pushed through ``self.inf``, conditioned on
        ``self.enc(x)``, to obtain ``z`` and a log-determinant.
        ``sigmoid(self.dec(z))`` gives probabilities that are scored against
        ``x`` with ``utils.bceloss``.

        Args:
            x: input batch; ``x.size(0)`` is the batch size ``n``.
            weight: multiplier applied to the per-sample KL term.
            bits: elementwise lower bound applied to ``weight * kl``.
            breakdown: when True, return ``(-logpx, -logqz, -logpz)``.

        Returns:
            ``(loss, -logpx, kl)`` with
            ``loss = -(logpx - max(weight * kl, bits))`` elementwise, or the
            breakdown triple when ``breakdown`` is True.
        """
        batch = x.size(0)
        zero = utils.varify(np.zeros(1).astype('float32'))
        context = self.enc(x)

        noise = utils.varify(np.random.randn(batch, self.dimz).astype('float32'))
        lgd0 = utils.varify(np.zeros(batch).astype('float32'))
        if self.cuda:
            noise, lgd0, zero = noise.cuda(), lgd0.cuda(), zero.cuda()

        z, logdet, _ = self.inf((noise, lgd0, context))
        probs = nn_.sigmoid(self.dec(z))

        logpx = -utils.bceloss(probs, x).sum(1).sum(1).sum(1)
        logqz = utils.log_normal(noise, zero, zero).sum(1) - logdet
        logpz = utils.log_normal(z, zero, zero).sum(1)

        if breakdown:
            return -logpx, -logqz, -logpz

        kl = logqz - logpz
        floored_kl = torch.max(kl * weight, torch.ones_like(kl) * bits)
        return -(logpx - floored_kl), -logpx, kl
Example #3
0
File: hiwvi.py Project: zizai/HIWAE
    def evaluate(self, z_, z0, context=None):
        """Score ``z0`` under Gaussians predicted from ``z_`` and ``context``.

        ``z_`` is flattened to ``(n * self.k, self.dim2)`` and paired with a
        per-sample embedding of ``context``.  A network with a skip connection
        then predicts a mean and ``lvar`` (passed as ``log_normal``'s third
        argument) of width ``self.dim1`` for each of the ``self.k`` samples.

        Args:
            z_: samples; ``z_.size(0)`` is the batch size ``n`` and the tensor
                must hold ``n * self.k * self.dim2`` elements.
            z0: points to score; must broadcast against
                ``(n, self.k, self.dim1)``.
            context: conditioning input; defaults to
                ``torch.ones(n, self.dimc)``.

        Returns:
            ``log_normal(z0, mean, lvar)`` summed over dim 2.
        """
        n = z_.size(0)
        if context is None:
            context = torch.ones(n, self.dimc)
        if cuda:
            context = context.cuda()

        embedded = self.function_emb(context)
        if self.wtype == 'ar':
            # A distinct embedding for each of the k samples.
            context = embedded.view(n, self.k, self.dimc)
        else:
            # Every other wtype: one embedding shared by all k samples.
            context = embedded.view(n, 1, self.dimc).repeat((1, self.k, 1))

        rows = n * self.k
        flat_ctx = context.view(rows, self.dimc)
        flat_z = z_.view(rows, self.dim2)

        skip = self.skip((flat_z, flat_ctx))[0]
        hidden = self.actv(self.hidn((flat_z, flat_ctx))[0])
        mean = self.mean((hidden, flat_ctx))[0] + skip
        lvar = self.lvar((hidden, flat_ctx))[0]

        grid = (n, self.k, self.dim1)
        return log_normal(z0, mean.view(grid), lvar.view(grid)).sum(2)
Example #4
0
 def density(self, spl, lgd=None, context=None, zeros=None):
     """Return the log-density of ``spl`` under the flow ``self.mdl``.

     Each optional argument left as None falls back to the instance
     attribute of the same name (``self.lgd``, ``self.context``,
     ``self.zeros``).  ``spl`` is pushed through ``self.mdl`` together with
     the initial log-determinant and context.  The result is the base
     log-density of the flow output (``utils.log_normal`` summed over dim 1)
     plus the accumulated log-determinant.
     """
     if lgd is None:
         lgd = self.lgd
     if context is None:
         context = self.context
     if zeros is None:
         zeros = self.zeros

     out, logdet, _ = self.mdl((spl, lgd, context))
     # NOTE(review): other calls in this file pass a log-variance
     # (e.g. ``torch.log(std) * 2``) as log_normal's third argument; if
     # utils.log_normal follows that convention, ``zeros + 1.0`` gives a
     # base of variance e rather than a standard normal — confirm intended.
     log_base = utils.log_normal(out, zeros, zeros + 1.0).sum(1)
     return log_base + logdet
Example #5
0
    def density(self, spl):
        """Return the log-density of the batch ``spl`` under ``self.flow``.

        Builds zero-filled float32 flow inputs: a context of shape ``(n, 1)``,
        an initial log-determinant of shape ``(n,)`` and base means of shape
        ``(n, self.p)``.  These are moved to the GPU when ``self.cuda`` is
        set.  The result is the base log-density of the flow output plus the
        accumulated log-determinant.
        """
        n = spl.size(0)
        buffers = [
            Variable(torch.FloatTensor(n, 1).zero_()),
            Variable(torch.FloatTensor(n).zero_()),
            Variable(torch.FloatTensor(n, self.p).zero_()),
        ]
        if self.cuda:
            buffers = [buf.cuda() for buf in buffers]
        context, lgd, zeros = buffers

        out, logdet, _ = self.flow((spl, lgd, context))
        # NOTE(review): elsewhere in this file log_normal's third argument is a
        # log-variance; if so, ``zeros + 1.0`` means variance e, not a standard
        # normal base — confirm intended.
        return utils.log_normal(out, zeros, zeros + 1.0).sum(1) + logdet
Example #6
0
 def loss(self, x):
     """Compute the per-sample terms of the (H)IWAE objective for batch ``x``.

     The encoder output ``self.enc(x)`` conditions the proposal ``self.qnet``.
     In ``'hiwae'`` mode, ``self.rnet.evaluate`` additionally scores the
     sampled ``z0``.  ``sigmoid(self.dec(z))`` is scored against ``x`` with
     ``utils.bceloss``.

     Args:
         x: input batch; ``x.size(0)`` is the batch size ``n``.

     Returns:
         ``(logpx, logpz, logq, logr)``.  ``logpx`` is the negative
         ``utils.bceloss`` summed three times over dim 2, ``logpz`` is
         reshaped to ``(n, self.niw)``, and ``logr`` is the integer ``0`` in
         ``'iwae'`` mode.

     Raises:
         ValueError: if ``self.mode`` is neither ``'iwae'`` nor ``'hiwae'``,
             or if ``self.mode == 'hiwae'`` and ``self.dep`` is not 0, 1 or 2.
     """
     # Reject unknown configurations up front.  Otherwise z/logq/logr would
     # never be bound and the failure would surface later as an
     # UnboundLocalError.
     if self.mode not in ('iwae', 'hiwae'):
         raise ValueError('unknown mode: {!r}'.format(self.mode))
     if self.mode == 'hiwae' and self.dep not in (0, 1, 2):
         raise ValueError('unknown dep for hiwae mode: {!r}'.format(self.dep))

     n = x.size(0)
     zero = utils.varify(np.zeros(1).astype('float32'))
     if cuda:
         zero = zero.cuda()
     context = self.enc(x)

     if self.mode == 'iwae':
         # Replicate the context niw times and draw one sample per copy.
         context = context.repeat(1, self.niw).view(n * self.niw, self.dimc)
         z, logq = self.qnet.sample(context, n * self.niw)
         logq = logq.view(n, self.niw)
         logr = 0
     else:  # 'hiwae' (validated above)
         if self.dep == 0:
             # One qnet.sample call per importance sample: keep that call's
             # z0 and only its j-th z_/logq slot.
             z0, z, logq = [], [], []
             for j in range(self.niw):
                 z0_, z_, logq_ = self.qnet.sample(context)
                 # z0_: batch_size x dimz
                 # z_: batch_size x niw x dimz
                 # logq_: batch_size x niw
                 z0.append(z0_.unsqueeze(1))
                 z.append(z_[:, j:j + 1])
                 logq.append(logq_[:, j:j + 1])
             z0 = torch.cat(z0, 1)
             z = torch.cat(z, 1)
             logq = torch.cat(logq, 1)
             logr = self.rnet.evaluate(z, z0, context)
         elif self.dep == 1:
             # One z0 per row, broadcast across the samples via unsqueeze(1).
             z0, z, logq = self.qnet.sample(context)
             logr = self.rnet.evaluate(z, z0.unsqueeze(1), context)
         else:  # self.dep == 2
             # IWAE with a hierarchical q; baseline.
             context = context.repeat(1, self.niw).view(n * self.niw, self.dimc)
             z0, z, logq = self.qnet.sample(context)
             logr = self.rnet.evaluate(z, z0.unsqueeze(1), context)
             logq = logq.view(n, self.niw)
             logr = logr.view(n, self.niw)
         z = z.view(n * self.niw, self.dimz)

     pi = nn_.sigmoid(self.dec(z))
     pi = pi.view(n, self.niw, *x.size()[1:])
     logpx = -utils.bceloss(pi, x.unsqueeze(1)).sum(2).sum(2).sum(2)
     logpz = utils.log_normal(z, zero, zero).sum(1).view(n, self.niw)
     return logpx, logpz, logq, logr
Example #7
0
File: hiwvi.py Project: zizai/HIWAE
    def sample(self, context=None, n=None):
        """Draw a reparameterized Gaussian sample conditioned on ``context``.

        Args:
            context: conditioning input fed to ``self.mean`` and ``self.lstd``.
            n: number of rows of noise to draw.  When None it is taken from
                ``context.size(0)``.

        Returns:
            ``(z, logq0)`` where ``z = mean + std * eps``, with ``eps`` of shape
            ``(n, self.dim)`` and ``std = self.realify(self.lstd(context))``.
            ``logq0`` is ``log_normal(z, mean, torch.log(std) * 2)`` summed
            over dim 1.

        Raises:
            ValueError: if both ``context`` and ``n`` are None.
        """
        if n is None:
            if context is None:
                raise ValueError('context and n cannot both be None')
            n = context.size(0)

        ep0 = Variable(torch.zeros(n, self.dim).normal_())
        if cuda:
            ep0 = ep0.cuda()

        mean = self.mean(context)
        std = self.realify(self.lstd(context))

        z = mean + std * ep0
        logq0 = log_normal(z, mean, torch.log(std) * 2).sum(1)
        return z, logq0
Example #8
0
File: hiwvi.py Project: zizai/HIWAE
    def sample(self, context=None, n=None):
        """Draw ``self.k`` conditional samples per row on top of a base ``z0``.

        ``z0`` and its log-density ``logq0`` come from ``self.z0.sample``.
        Conditioned on ``(z0, context)``, a network with a skip connection
        predicts a mean and (through softplus) a scale for ``self.k`` blocks
        of ``self.dim2`` values each.  The samples are
        ``z_ = mean + eps * std``.

        ``logq`` is ``logq0`` plus a term chosen by ``self.wtype``:

        * ``'ar'`` / ``'pi'``: ``log N(z_k | mean_k, var_k)`` summed over the
          last dim.
        * ``'bh'``: ``log_mean_exp`` over j of ``log N(z_k | mean_j, var_j)``.
        * ``'l<p>'`` (e.g. ``'l0.5'``): the ``'ar'`` term, minus
          ``p * log N(z_k | mean_k, var_k)``, plus ``log_sum_exp`` over j of
          ``p * log N(z_k | mean_j, var_j)``, minus ``log(self.k)``.

        With ``self.doubler`` set, ``mean`` and the log-variance are detached
        inside these density terms.

        Args:
            context: conditioning input; its ``size(0)`` sets the row count.
                When None, ``torch.ones(n, self.dimc)`` is used.
            n: row count, consulted only when ``context`` is None.

        Returns:
            ``(z0, z_, logq)`` with ``z_`` viewed as ``(n, self.k, self.dim2)``.

        Raises:
            ValueError: if ``self.wtype`` is not ``'ar'``, ``'pi'``, ``'bh'``
                or ``'l<float>'``, or if both ``context`` and ``n`` are None.
        """
        # Validate before drawing noise.  An unknown wtype previously left
        # logq unbound, and the input check used a bare assert that is
        # stripped under -O.
        if self.wtype in ('ar', 'pi', 'bh'):
            power = None
        elif self.wtype[:1] == 'l':
            power = float(self.wtype[1:])
        else:
            raise ValueError('unknown wtype: {!r}'.format(self.wtype))

        if context is None:
            if n is None:
                raise ValueError('context and n cannot both be None')
            context = torch.ones(n, self.dimc)
        n = context.size(0)

        ep = Variable(torch.zeros(n, self.dim2 * self.k).normal_())
        if cuda:
            ep = ep.cuda()
            context = context.cuda()
        z0, logq0 = self.z0.sample(context, n)
        logq0 = logq0.unsqueeze(1)

        skip = self.skip((z0, context))[0]
        h = self.actv(self.hidn((z0, context))[0])
        mean = self.mean((h, context))[0] + skip
        std = F.softplus(self.lstd((h, context))[0])

        shape = (n, self.k, self.dim2)
        ep = ep.view(shape)
        mean = mean.view(shape)
        std = std.view(shape)
        z_ = mean + ep * std

        # Parameters used inside the density terms.  With self.doubler they
        # are detached, so those terms depend on the network only through z_.
        log_var = torch.log(std) * 2
        if self.doubler:
            mu, lv = mean.detach(), log_var.detach()
        else:
            mu, lv = mean, log_var

        def _pairwise():
            # Entry [:, i, j] is log N(z_i | mu_j, lv_j), assuming log_normal
            # broadcasts elementwise over the (n, k, k, dim2) grid.
            return log_normal(z_.unsqueeze(2), mu.unsqueeze(1),
                              lv.unsqueeze(1)).sum(3)

        if self.wtype in ('ar', 'pi'):
            logq = logq0 + log_normal(z_, mu, lv).sum(2)
        elif self.wtype == 'bh':
            logq = logq0 + log_mean_exp(_pairwise(), 2)[:, :, 0]
        else:  # 'l<power>'
            diag = log_normal(z_, mu, lv).sum(2)
            den = log_sum_exp(_pairwise() * power, 2)[:, :, 0]
            nom = diag * power
            logq = logq0 + diag - (nom - den) - np.log(self.k)
        return z0, z_, logq
Example #9
0
def energy1(f, var=0.25):
    """Return the negative scaled squared-error energy of frequency ``f``.

    Uses the module-level tensors ``x0``, ``y0``, ``a0`` and ``b0``.  The
    prediction is ``mu = a0 * sin(2 * pi * f * x0^T + b0)`` (``x0`` and ``y0``
    are transposed with ``permute(1, 0)``).  The result is
    ``-sum((mu - y0^T) ** 2 / var, dim=1)``.

    Args:
        f: frequency, broadcast against ``x0.permute(1, 0)``.
        var: divisor applied to the squared residuals.  The default 0.25
            reproduces the previous hard-coded ``* (1 / 0.25)`` weighting
            exactly (division by a power of two is exact).
    """
    mu = torch.mul(torch.sin(x0.permute(1, 0) * 2.0 * np.pi * f + b0), a0)
    resid = mu - y0.permute(1, 0)
    return -(resid ** 2 / var).sum(1)