def forward(self, k, x, logposterior):
        '''
        Draw k samples of z for each row of x and estimate their log-density.

        Steps performed here:
          1. q(v|x): self.qv_weights (hidden layers followed by self.act_func,
             last layer split into [mean | logvar]) parameterizes the
             auxiliary variable v0, drawn with the reparameterization trick.
          2. q(z|x,v): self.qz_weights applied to [x, v0] parameterizes z0,
             also drawn by reparameterization.
          3. self.n_flows applications of self.norm_flow transform (z, v);
             the returned log-determinants are summed.
          4. r(vT|x,zT): self.rv_weights applied to [x, zT] parameterizes the
             reverse model used to score the final auxiliary vT.

        k: number of samples (P)
        x: [B,X]
        logposterior(z) -> [P,B]; accepted for interface compatibility but
            not used by this method.

        Returns (z, logq) with z of shape [P,B,Z] and
        logq = log q(z0) + log q(v0) - sum(logdet) - log r(vT), shape [P,B].
        Side effect: sets self.B and self.P.
        '''
        self.B = x.size()[0]
        self.P = k

        # q(v|x): the last layer's output holds mean in the first z_size
        # columns and logvar in the remaining ones.
        out = x
        for i in range(len(self.qv_weights) - 1):
            out = self.act_func(self.qv_weights[i](out))
        out = self.qv_weights[-1](out)
        mean = out[:, :self.z_size]
        logvar = out[:, self.z_size:]

        # Sample v0 = mean + exp(logvar / 2) * eps
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  # [P,B,Z]
        v = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
        logqv0 = lognormal(v, mean, logvar)  # [P,B]

        # Flatten the samples. x.repeat(k, 1) tiles the batch k times, so row
        # p*B + b of x_tiled is x[b], matching row p*B + b of v.view(-1, Z).
        v = v.view(-1, self.z_size)  # [PB,Z]
        x_tiled = x.repeat(k, 1)  # [PB,X]
        xv = torch.cat((x_tiled, v), 1)  # [PB,X+Z]

        # q(z|x,v)
        out = xv
        for i in range(len(self.qz_weights) - 1):
            out = self.act_func(self.qz_weights[i](out))
        out = self.qz_weights[-1](out)
        mean = out[:, :self.z_size].contiguous().view(
            self.P, self.B, self.z_size)
        logvar = out[:, self.z_size:].contiguous().view(
            self.P, self.B, self.z_size)

        # Sample z0
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  # [P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
        logqz0 = lognormal333(z, mean, logvar)  # [P,B]
        z = z.view(-1, self.z_size)  # [PB,Z]

        # Transform (z, v) through the flows, accumulating log-determinants.
        logdetsum = 0.
        for i in range(self.n_flows):
            z, v, logdet = self.norm_flow(self.params[i], z, v)
            logdetsum += logdet
        # With zero flows logdetsum is still the Python float 0., which has
        # no .view(); it broadcasts correctly in the final sum as-is.
        if self.n_flows > 0:
            logdetsum = logdetsum.view(k, self.B)

        # r(vT|x,zT)
        out = torch.cat((x_tiled, z), 1)
        for i in range(len(self.rv_weights) - 1):
            out = self.act_func(self.rv_weights[i](out))
        out = self.rv_weights[-1](out)
        mean = out[:, :self.z_size].contiguous().view(
            self.P, self.B, self.z_size)
        logvar = out[:, self.z_size:].contiguous().view(
            self.P, self.B, self.z_size)

        v = v.view(k, self.B, self.z_size)
        logrvT = lognormal333(v, mean, logvar)  # [P,B]

        z = z.view(k, self.B, self.z_size)

        return z, logqz0 + logqv0 - logdetsum - logrvT
Example #2
0
    def sample(k):
        '''
        Draw k samples of z and return them with a log-density estimate.

        B, z_size, model, mean_v, logvar_v, qz_weights, rv_weights, act_func,
        n_flows, norm_flow and params are free names resolved from a scope
        not visible here (presumably an enclosing function - confirm).

        Steps:
          1. v0 ~ N(mean_v, exp(logvar_v)) by reparameterization.
          2. q(z|v): qz_weights applied to v0 gives [mean | logvar]; a +5.
             offset is added to logvar before sampling z0.
          3. n_flows applications of norm_flow transform (z, v).
          4. r(vT|zT): rv_weights applied to zT scores the final vT.

        Returns (z, logq) with z of shape [P,B,Z] and
        logq = log q(z0) + log q(v0) - sum(logdet) - log r(vT), shape [P,B].
        '''
        P = k

        # Sample v0
        eps = Variable(
            torch.FloatTensor(k, B,
                              z_size).normal_().type(model.dtype))  # [P,B,Z]
        v = eps.mul(torch.exp(.5 * logvar_v)) + mean_v  # [P,B,Z]
        logqv0 = lognormal(v, mean_v, logvar_v)  # [P,B]
        v = v.view(-1, model.z_size)  # [PB,Z]

        # q(z|v): last layer's output split into [mean | logvar].
        out = v
        for i in range(len(qz_weights) - 1):
            out = act_func(qz_weights[i](out))
        out = qz_weights[-1](out)
        mean = out[:, :z_size]
        # The +5. shift of logvar is kept as in the original; its rationale
        # is not visible here.
        logvar = out[:, z_size:] + 5.

        # Sample z0
        eps = Variable(
            torch.FloatTensor(k, B,
                              z_size).normal_().type(model.dtype))  # [P,B,Z]
        mean = mean.contiguous().view(P, B, model.z_size)
        logvar = logvar.contiguous().view(P, B, model.z_size)
        z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
        logqz0 = lognormal333(z, mean, logvar)  # [P,B]
        z = z.view(-1, z_size)  # [PB,Z]

        # Transform (z, v) through the flows.
        logdetsum = 0.
        for i in range(n_flows):
            z, v, logdet = norm_flow(params[i], z, v)
            logdetsum += logdet
        # Bring the accumulated log-determinant to [P,B] like the other
        # terms; a flat [P*B] value would otherwise mis-broadcast (e.g. to
        # [P,P] when B == 1). With zero flows it is still the float 0.
        if n_flows > 0:
            logdetsum = logdetsum.view(k, B)

        # r(vT|zT)
        out = z
        for i in range(len(rv_weights) - 1):
            out = act_func(rv_weights[i](out))
        out = rv_weights[-1](out)
        mean = out[:, :model.z_size].contiguous().view(P, B, model.z_size)
        logvar = out[:, model.z_size:].contiguous().view(P, B, model.z_size)

        v = v.view(k, B, model.z_size)
        logrvT = lognormal333(v, mean, logvar)  # [P,B]

        z = z.view(k, B, model.z_size)

        logq = logqz0 + logqv0 - logdetsum - logrvT
        return z, logq
    def sample(self, mean, logvar, k, logposterior):
        '''
        Draw k samples per batch row from q(z0) pushed through the flows.

        mean, logvar: parameters passed to Gaussian.sample for z0; batch size
            B is taken from mean.size()[0].
        k: number of samples (P).
        logposterior: forwarded to self.norm_flow at every flow step.

        v0 is drawn via Gaussian.sample with zero mean and zero logvar.
        Gaussian.sample is presumably returning (sample, log-density) - confirm.

        Returns (z, logq) with z of shape [P,B,Z] and
        logq = log q(z0) + log q(v0) - sum(logdet) - log r(vT), shape [P,B].
        Side effects: sets self.P, self.B and self.grad_outputs (ones [P,B]).
        '''
        self.P = k
        self.B = mean.size()[0]
        use_cuda = torch.cuda.is_available()

        if use_cuda:
            self.grad_outputs = torch.ones(k, self.B).cuda()
        else:
            self.grad_outputs = torch.ones(k, self.B)

        gaus = Gaussian(self.hyper_config)

        # q(z0)
        z, logqz0 = gaus.sample(mean, logvar, k)

        # q(v0). Place the zeros on the GPU only when one is available (same
        # rule as grad_outputs above); an unconditional .cuda() crashes on
        # CPU-only machines.
        zeros = torch.zeros(self.B, self.z_size)
        if use_cuda:
            zeros = zeros.cuda()
        zeros = Variable(zeros)
        v, logqv0 = gaus.sample(zeros, zeros, k)

        z = z.view(-1, self.z_size)  # [PB,Z]
        v = v.view(-1, self.z_size)  # [PB,Z]

        # Transform (z, v) through the flows.
        logdetsum = 0.
        for i in range(self.n_flows):
            params = self.flow_params[i]
            z, v, logdet = self.norm_flow(params, z, v, logposterior)
            logdetsum += logdet
        # With zero flows logdetsum is still the Python float 0., which has
        # no .view(); it broadcasts correctly in the final sum as-is.
        if self.n_flows > 0:
            logdetsum = logdetsum.view(k, self.B)

        # r(vT|zT): last layer's output split into [mean | logvar].
        out = z  # [PB,Z]
        for i in range(len(self.rv_weights) - 1):
            out = self.act_func(self.rv_weights[i](out))
        out = self.rv_weights[-1](out)
        rv_mean = out[:, :self.z_size].contiguous().view(
            k, self.B, self.z_size)
        rv_logvar = out[:, self.z_size:].contiguous().view(
            k, self.B, self.z_size)

        v = v.view(k, self.B, self.z_size)
        z = z.view(k, self.B, self.z_size)

        logrvT = lognormal333(v, rv_mean, rv_logvar)

        logpz = logqz0 + logqv0 - logdetsum - logrvT
        return z, logpz
Example #4
0
    def sample(self, mean, logvar, k):
        '''
        Draw k samples per batch row from q(z0) pushed through the flows.

        mean, logvar: parameters passed to Gaussian.sample for z0; batch size
            B is taken from mean.size()[0].
        k: number of samples (P).

        v0 is drawn via Gaussian.sample with zero mean and zero logvar.
        Gaussian.sample is presumably returning (sample, log-density) - confirm.

        Returns (z, logq) with z of shape [P,B,Z] and
        logq = log q(z0) + log q(v0) - sum(logdet) - log r(vT), shape [P,B].
        Side effect: sets self.B.
        '''
        self.B = mean.size()[0]
        gaus = Gaussian(self.hyper_config)

        # q(z0)
        z, logqz0 = gaus.sample(mean, logvar, k)

        # q(v0). Only move to the GPU when one is available; an
        # unconditional .cuda() crashes on CPU-only machines.
        zeros = torch.zeros(self.B, self.z_size)
        if torch.cuda.is_available():
            zeros = zeros.cuda()
        zeros = Variable(zeros)
        v, logqv0 = gaus.sample(zeros, zeros, k)

        z = z.view(-1, self.z_size)  # [PB,Z]
        v = v.view(-1, self.z_size)  # [PB,Z]

        # Transform (z, v) through the flows.
        logdetsum = 0.
        for i in range(self.n_flows):
            params = self.flow_params[i]
            z, v, logdet = self.norm_flow(params, z, v)
            logdetsum += logdet
        # With zero flows logdetsum is still the Python float 0., which has
        # no .view(); it broadcasts correctly in the final sum as-is.
        if self.n_flows > 0:
            logdetsum = logdetsum.view(k, self.B)

        # r(vT|zT): last layer's output split into [mean | logvar].
        out = z  # [PB,Z]
        for i in range(len(self.rv_weights) - 1):
            out = self.act_func(self.rv_weights[i](out))
        out = self.rv_weights[-1](out)
        rv_mean = out[:, :self.z_size].contiguous().view(
            k, self.B, self.z_size)
        rv_logvar = out[:, self.z_size:].contiguous().view(
            k, self.B, self.z_size)

        v = v.view(k, self.B, self.z_size)
        z = z.view(k, self.B, self.z_size)

        logrvT = lognormal333(v, rv_mean, rv_logvar)

        logpz = logqz0 + logqv0 - logdetsum - logrvT
        return z, logpz
    def sample(k):
        '''
        Draw k samples of z and return them with a log-density estimate.

        B, z_size, model, mean_v, logvar_v, qz_weights, rv_weights, act_func,
        n_flows, norm_flow and params are free names resolved from a scope
        not visible here (presumably an enclosing function - confirm).

        Steps:
          1. v0 ~ N(mean_v, exp(logvar_v)) by reparameterization.
          2. q(z|v): qz_weights applied to v0 gives [mean | logvar]; a +5.
             offset is added to logvar before sampling z0.
          3. n_flows applications of norm_flow transform (z, v).
          4. r(vT|zT): rv_weights applied to zT scores the final vT.

        Returns (z, logq) with z of shape [P,B,Z] and
        logq = log q(z0) + log q(v0) - sum(logdet) - log r(vT), shape [P,B].
        '''
        P = k

        # Sample v0
        eps = Variable(
            torch.FloatTensor(k, B,
                              z_size).normal_().type(model.dtype))  # [P,B,Z]
        v = eps.mul(torch.exp(.5 * logvar_v)) + mean_v  # [P,B,Z]
        logqv0 = lognormal(v, mean_v, logvar_v)  # [P,B]
        v = v.view(-1, model.z_size)  # [PB,Z]

        # q(z|v): last layer's output split into [mean | logvar].
        out = v
        for i in range(len(qz_weights) - 1):
            out = act_func(qz_weights[i](out))
        out = qz_weights[-1](out)
        mean = out[:, :z_size]
        # The +5. shift of logvar is kept as in the original; its rationale
        # is not visible here.
        logvar = out[:, z_size:] + 5.

        # Sample z0
        eps = Variable(
            torch.FloatTensor(k, B,
                              z_size).normal_().type(model.dtype))  # [P,B,Z]
        mean = mean.contiguous().view(P, B, model.z_size)
        logvar = logvar.contiguous().view(P, B, model.z_size)
        z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
        logqz0 = lognormal333(z, mean, logvar)  # [P,B]
        z = z.view(-1, z_size)  # [PB,Z]

        # Transform (z, v) through the flows.
        logdetsum = 0.
        for i in range(n_flows):
            z, v, logdet = norm_flow(params[i], z, v)
            logdetsum += logdet
        # Bring the accumulated log-determinant to [P,B] like the other
        # terms; a flat [P*B] value would otherwise mis-broadcast (e.g. to
        # [P,P] when B == 1). With zero flows it is still the float 0.
        if n_flows > 0:
            logdetsum = logdetsum.view(k, B)

        # r(vT|zT)
        out = z
        for i in range(len(rv_weights) - 1):
            out = act_func(rv_weights[i](out))
        out = rv_weights[-1](out)
        mean = out[:, :model.z_size].contiguous().view(P, B, model.z_size)
        logvar = out[:, model.z_size:].contiguous().view(P, B, model.z_size)

        v = v.view(k, B, model.z_size)
        logrvT = lognormal333(v, mean, logvar)  # [P,B]

        z = z.view(k, B, model.z_size)

        logq = logqz0 + logqv0 - logdetsum - logrvT
        return z, logq
    def forward(self, k, x, logposterior):
        '''
        Draw k samples of z for each row of x and estimate their log-density.

        Steps performed here:
          1. q(v|x): self.qv_weights (hidden layers followed by self.act_func,
             last layer split into [mean | logvar]) parameterizes the
             auxiliary variable v0, drawn with the reparameterization trick.
          2. q(z|x,v): self.qz_weights applied to [x, v0] parameterizes z0,
             also drawn by reparameterization.
          3. self.n_flows applications of self.norm_flow transform (z, v);
             the returned log-determinants are summed.
          4. r(vT|x,zT): self.rv_weights applied to [x, zT] parameterizes the
             reverse model used to score the final auxiliary vT.

        k: number of samples (P)
        x: [B,X]
        logposterior(z) -> [P,B]; accepted for interface compatibility but
            not used by this method.

        Returns (z, logq) with z of shape [P,B,Z] and
        logq = log q(z0) + log q(v0) - sum(logdet) - log r(vT), shape [P,B].
        Side effect: sets self.B and self.P.
        '''
        self.B = x.size()[0]
        self.P = k

        # q(v|x): the last layer's output holds mean in the first z_size
        # columns and logvar in the remaining ones.
        out = x
        for i in range(len(self.qv_weights) - 1):
            out = self.act_func(self.qv_weights[i](out))
        out = self.qv_weights[-1](out)
        mean = out[:, :self.z_size]
        logvar = out[:, self.z_size:]

        # Sample v0 = mean + exp(logvar / 2) * eps
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  # [P,B,Z]
        v = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
        logqv0 = lognormal(v, mean, logvar)  # [P,B]

        # Flatten the samples. x.repeat(k, 1) tiles the batch k times, so row
        # p*B + b of x_tiled is x[b], matching row p*B + b of v.view(-1, Z).
        v = v.view(-1, self.z_size)  # [PB,Z]
        x_tiled = x.repeat(k, 1)  # [PB,X]
        xv = torch.cat((x_tiled, v), 1)  # [PB,X+Z]

        # q(z|x,v)
        out = xv
        for i in range(len(self.qz_weights) - 1):
            out = self.act_func(self.qz_weights[i](out))
        out = self.qz_weights[-1](out)
        mean = out[:, :self.z_size].contiguous().view(
            self.P, self.B, self.z_size)
        logvar = out[:, self.z_size:].contiguous().view(
            self.P, self.B, self.z_size)

        # Sample z0
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  # [P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
        logqz0 = lognormal333(z, mean, logvar)  # [P,B]
        z = z.view(-1, self.z_size)  # [PB,Z]

        # Transform (z, v) through the flows, accumulating log-determinants.
        logdetsum = 0.
        for i in range(self.n_flows):
            z, v, logdet = self.norm_flow(self.params[i], z, v)
            logdetsum += logdet
        # With zero flows logdetsum is still the Python float 0., which has
        # no .view(); it broadcasts correctly in the final sum as-is.
        if self.n_flows > 0:
            logdetsum = logdetsum.view(k, self.B)

        # r(vT|x,zT)
        out = torch.cat((x_tiled, z), 1)
        for i in range(len(self.rv_weights) - 1):
            out = self.act_func(self.rv_weights[i](out))
        out = self.rv_weights[-1](out)
        mean = out[:, :self.z_size].contiguous().view(
            self.P, self.B, self.z_size)
        logvar = out[:, self.z_size:].contiguous().view(
            self.P, self.B, self.z_size)

        v = v.view(k, self.B, self.z_size)
        logrvT = lognormal333(v, mean, logvar)  # [P,B]

        z = z.view(k, self.B, self.z_size)

        return z, logqz0 + logqv0 - logdetsum - logrvT