Code example #1
File: resconv.py  Project: lim0606/pytorch-ardae-vae
    def logprob(self, input, sample_size=128, z=None):
        '''
        input: positive samples
        '''
        # init
        batch_size = input.size(0)
        input = input.view(batch_size, self.input_channels, self.input_height,
                           self.input_height)
        ''' get log q(z|x) '''
        _, mu_qz, logvar_qz = self.encode(input)
        mu_qz = mu_qz.detach().repeat(1, sample_size).view(
            batch_size, sample_size, self.z_dim)
        logvar_qz = logvar_qz.detach().repeat(1, sample_size).view(
            batch_size, sample_size, self.z_dim)
        z = self.encode.sample(mu_qz, logvar_qz)
        logposterior = logprob_gaussian(mu_qz,
                                        logvar_qz,
                                        z,
                                        do_unsqueeze=False,
                                        do_mean=False)
        logposterior = torch.sum(logposterior.view(batch_size, sample_size,
                                                   self.z_dim),
                                 dim=2)  # bsz x ssz
        ''' get log p(z) '''
        # get prior (as unit normal dist)
        mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logprior = logprob_gaussian(mu_pz,
                                    logvar_pz,
                                    z,
                                    do_unsqueeze=False,
                                    do_mean=False)
        logprior = torch.sum(logprior.view(batch_size, sample_size,
                                           self.z_dim),
                             dim=2)  # bsz x ssz
        ''' get log p(x|z) '''
        # decode
        logit_x = []
        for i in range(batch_size):
            _, _logit_x = self.decode(z[i, :, :])  # ssz x zdim
            logit_x += [_logit_x.detach().unsqueeze(0)]
        logit_x = torch.cat(logit_x, dim=0)  # bsz x ssz x input_dim
        _input = input.unsqueeze(1).expand(
            batch_size, sample_size, self.input_channels, self.input_height,
            self.input_height)  # bsz x ssz x input_dim
        loglikelihood = -F.binary_cross_entropy_with_logits(
            logit_x, _input, reduction='none')
        loglikelihood = torch.sum(loglikelihood.view(batch_size, sample_size,
                                                     -1),
                                  dim=2)  # bsz x ssz
        ''' get log p(x|z)p(z)/q(z|x) '''
        logprob = loglikelihood + logprior - logposterior  # bsz x ssz
        logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
        rprob = (logprob - logprob_max).exp()  # relative prob
        logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) +
                            1e-10) + logprob_max  # bsz x 1

        # return
        return logprob.mean()
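A note on the final reduction shared by the logprob estimators on this page: the per-sample importance weights log p(x|z) + log p(z) - log q(z|x) are collapsed into log p(x) with a max-shifted log-mean-exp. A minimal standalone sketch of that reduction (the helper name `log_mean_exp` is ours, not the project's):

import torch

def log_mean_exp(logw, dim=1, eps=1e-10):
    """Numerically stable log(mean(exp(logw))) along `dim`."""
    logw_max, _ = torch.max(logw, dim=dim, keepdim=True)  # shift by the max
    rprob = (logw - logw_max).exp()  # relative probabilities in (0, 1]
    return torch.log(torch.mean(rprob, dim=dim, keepdim=True) + eps) + logw_max

# usage: 4 inputs, 128 importance samples each
logw = torch.randn(4, 128)
logpx = log_mean_exp(logw)  # 4 x 1, matching the in-method reduction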
Code example #2
    def logprob_w_prior(self, input, sample_size=128, z=None):
        # init
        batch_size = input.size(0)
        input = input.view(batch_size, self.input_dim)
        ''' get z samples from p(z) '''
        # get prior (as unit normal dist)
        if z is None:
            mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
            logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
            z = sample_gaussian(mu_pz, logvar_pz)  # sample z
        ''' get log p(x|z) '''
        # decode
        _z = z.view(-1, self.z_dim)
        _, mu_x, logvar_x = self.decode(_z)  # bsz*ssz x zdim
        mu_x = mu_x.view(batch_size, sample_size, self.input_dim)
        logvar_x = logvar_x.view(batch_size, sample_size, self.input_dim)
        _input = input.unsqueeze(1).expand(
            batch_size, sample_size, self.input_dim)  # bsz x ssz x input_dim
        loglikelihood = logprob_gaussian(mu_x,
                                         logvar_x,
                                         _input,
                                         do_unsqueeze=False,
                                         do_mean=False)
        loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz
        ''' get log p(x) '''
        logprob = loglikelihood  # bsz x ssz
        logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
        rprob = (logprob - logprob_max).exp()  # relative prob
        logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) +
                            1e-10) + logprob_max  # bsz x 1

        # return
        return logprob.mean()
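`sample_gaussian` and `logprob_gaussian` are imported helpers whose bodies are not shown on this page. A plausible reading, consistent with how they are called above (a reparameterized draw, and an elementwise diagonal-Gaussian log-density that the callers sum over the last dimension); the project's actual helpers may differ, and the variants used in examples #9 and #10 also take a `do_sum` flag:

import math
import torch

def sample_gaussian(mu, logvar):
    # reparameterized draw: mu + sigma * eps with eps ~ N(0, I)
    return mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

def logprob_gaussian(mu, logvar, z, do_unsqueeze=False, do_mean=False):
    # elementwise log N(z; mu, diag(exp(logvar))); flags omitted in this sketch
    return -0.5 * (math.log(2.0 * math.pi) + logvar
                   + (z - mu) ** 2 / logvar.exp())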
Code example #3
    def logprob_w_cov_gaussian_posterior(self, input, sample_size=128, z=None, std=None):
        # init
        batch_size = input.size(0)
        input = input.view(batch_size, self.input_dim)
        assert sample_size >= 2 * self.z_dim

        ''' get z and pseudo log q(newz|x) '''
        _, _, _, _, z, _, _, _, _ = self.encode._forward(input, std=std, nz=sample_size) # bsz x ssz x zdim
        newz = []
        logposterior = []
        eye = torch.eye(self.z_dim, device=z.device)
        mu_qz = torch.mean(z, dim=1) # bsz x zdim
        for i in range(batch_size):
            _cov_qz = get_covmat(z[i, :, :]) + 1e-5*eye
            _rv_z = MultivariateNormal(mu_qz[i], _cov_qz)
            _newz = _rv_z.rsample(torch.Size([1, sample_size]))
            _logposterior = _rv_z.log_prob(_newz)

            newz += [_newz]
            logposterior += [_logposterior]
        newz = torch.cat(newz, dim=0) # bsz x ssz x zdim
        logposterior = torch.cat(logposterior, dim=0) # bsz x ssz

        ''' get log p(z) '''
        # get prior (as unit normal dist)
        mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logprior = logprob_gaussian(mu_pz, logvar_pz, newz, do_unsqueeze=False, do_mean=False)
        logprior = torch.sum(logprior.view(batch_size, sample_size, self.z_dim), dim=2) # bsz x ssz

        ''' get log p(x|z) '''
        # decode
        mu_x, logvar_x = [], []
        for i in range(batch_size):
            _, _mu_x, _logvar_x = self.decode(newz[i, :, :])
            mu_x += [_mu_x.detach().unsqueeze(0)]
            logvar_x += [_logvar_x.detach().unsqueeze(0)]
        mu_x = torch.cat(mu_x, dim=0) # bsz x ssz x input_dim
        logvar_x = torch.cat(logvar_x, dim=0) # bsz x ssz x input_dim
        _input = input.unsqueeze(1).expand(batch_size, sample_size, self.input_dim) # bsz x ssz x input_dim
        loglikelihood = logprob_gaussian(mu_x, logvar_x, _input, do_unsqueeze=False, do_mean=False)
        loglikelihood = torch.sum(loglikelihood, dim=2) # bsz x ssz

        ''' get log p(x|z)p(z)/q(z|x) '''
        logprob = loglikelihood + logprior - logposterior # bsz x ssz
        logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
        rprob = (logprob - logprob_max).exp() # relative prob
        logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max # bsz x 1

        # return
        return logprob.mean()
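`get_covmat` is likewise an imported helper. Consistent with its use above (an ssz x zdim sample matrix goes in, a zdim x zdim matrix comes out and is handed to `MultivariateNormal`), a plausible stand-in is the sample covariance; the `1e-5 * eye` jitter added above keeps the result positive definite:

import torch

def get_covmat(z):
    """Sample covariance of the rows of z (n_samples x dim)."""
    zc = z - z.mean(dim=0, keepdim=True)  # center the samples
    return zc.t() @ zc / (z.size(0) - 1)  # dim x dim, symmetric

z = torch.randn(256, 8)
cov = get_covmat(z)  # 8 x 8; the assert above (ssz >= 2 * zdim) helps keep it full rank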
Code example #4
    def logprob(self, input, sample_size=128, z=None):
        # init
        batch_size = input.size(0)
        input = input.view(batch_size, self.input_dim)
        ''' get log q(z|x) '''
        _, mu_qz, logvar_qz = self.encode(input)
        mu_qz = mu_qz.detach().repeat(1, sample_size).view(
            batch_size, sample_size, self.z_dim)
        logvar_qz = logvar_qz.detach().repeat(1, sample_size).view(
            batch_size, sample_size, self.z_dim)
        z = self.encode.sample(mu_qz, logvar_qz)
        logposterior = logprob_gaussian(mu_qz,
                                        logvar_qz,
                                        z,
                                        do_unsqueeze=False,
                                        do_mean=False)
        logposterior = torch.sum(logposterior.view(batch_size, sample_size,
                                                   self.z_dim),
                                 dim=2)  # bsz x ssz
        ''' get log p(z) '''
        # get prior (as unit normal dist)
        mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logprior = logprob_gaussian(mu_pz,
                                    logvar_pz,
                                    z,
                                    do_unsqueeze=False,
                                    do_mean=False)
        logprior = torch.sum(logprior.view(batch_size, sample_size,
                                           self.z_dim),
                             dim=2)  # bsz x ssz
        ''' get log p(x|z) '''
        # decode
        _z = z.view(-1, self.z_dim)
        _, mu_x, logvar_x = self.decode(_z)  # bsz*ssz x zdim
        mu_x = mu_x.view(batch_size, sample_size, self.input_dim)
        logvar_x = logvar_x.view(batch_size, sample_size, self.input_dim)
        _input = input.unsqueeze(1).expand(
            batch_size, sample_size, self.input_dim)  # bsz x ssz x input_dim
        loglikelihood = logprob_gaussian(mu_x,
                                         logvar_x,
                                         _input,
                                         do_unsqueeze=False,
                                         do_mean=False)
        loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz
        ''' get log p(x|z)p(z)/q(z|x) '''
        logprob = loglikelihood + logprior - logposterior  # bsz x ssz
        logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
        rprob = (logprob - logprob_max).exp()  # relative prob
        logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) +
                            1e-10) + logprob_max  # bsz x 1

        # return
        return logprob.mean()
Code example #5
File: mnist.py  Project: lim0606/pytorch-ardae-vae
    def logprob_w_diag_gaussian_posterior(self,
                                          input,
                                          sample_size=128,
                                          z=None,
                                          std=None):
        # init
        batch_size = input.size(0)
        input = input.view(batch_size, self.input_dim)
        ''' get z '''
        z = []
        for i in range(sample_size):
            _z = self.encode(input, std=std)  # bsz x zdim
            z += [_z.detach().unsqueeze(1)]
        z = torch.cat(z, dim=1)  # bsz x ssz x zdim
        mu_qz = torch.mean(z, dim=1)
        logvar_qz = torch.log(torch.var(z, dim=1) + 1e-10)
        ''' get pseudo log q(z|x) '''
        mu_qz = mu_qz.detach().repeat(1, sample_size).view(
            batch_size, sample_size, self.z_dim)
        logvar_qz = logvar_qz.detach().repeat(1, sample_size).view(
            batch_size, sample_size, self.z_dim)
        newz = sample_gaussian(mu_qz, logvar_qz)
        logposterior = logprob_gaussian(mu_qz,
                                        logvar_qz,
                                        newz,
                                        do_unsqueeze=False,
                                        do_mean=False)
        logposterior = torch.sum(logposterior.view(batch_size, sample_size,
                                                   self.z_dim),
                                 dim=2)  # bsz x ssz
        ''' get log p(z) '''
        # get prior (as unit normal dist)
        mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logprior = logprob_gaussian(mu_pz,
                                    logvar_pz,
                                    newz,
                                    do_unsqueeze=False,
                                    do_mean=False)
        logprior = torch.sum(logprior.view(batch_size, sample_size,
                                           self.z_dim),
                             dim=2)  # bsz x ssz
        ''' get log p(x|z) '''
        # decode
        logit_x = []
        for i in range(sample_size):
            _, _logit_x = self.decode(newz[:, i, :])
            logit_x += [_logit_x.detach().unsqueeze(1)]
        logit_x = torch.cat(logit_x, dim=1)  # bsz x ssz x input_dim
        _input = input.unsqueeze(1).expand(
            batch_size, sample_size, self.input_dim)  # bsz x ssz x input_dim
        loglikelihood = -F.binary_cross_entropy_with_logits(
            logit_x, _input, reduction='none')
        loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz
        ''' get log p(x|z)p(z)/q(z|x) '''
        logprob = loglikelihood + logprior - logposterior  # bsz x ssz
        logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
        rprob = (logprob - logprob_max).exp()  # relative prob
        logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) +
                            1e-10) + logprob_max  # bsz x 1

        # return
        return logprob.mean()
Code example #6
    def logprob(self, input, sample_size=128, z=None):
        # init
        batch_size = input.size(0)
        sample_size1 = sample_size
        sample_size2 = 1
        input = input.view(batch_size, self.input_channels, self.input_height,
                           self.input_height)
        ''' get - (log q(z|z0,x) + log q(z0|x) - log p(z0|z,x) - log p(z)) '''
        ''' get log q(z0|x) '''
        _, mu_qz0, logvar_qz0, _ = self.aux_encode(input)
        mu_qz0 = mu_qz0.unsqueeze(1).expand(
            batch_size, sample_size1,
            self.z0_dim).contiguous().view(batch_size * sample_size1,
                                           self.z0_dim)  # bsz*ssz1 x z0_dim
        logvar_qz0 = logvar_qz0.unsqueeze(1).expand(
            batch_size, sample_size1,
            self.z0_dim).contiguous().view(batch_size * sample_size1,
                                           self.z0_dim)  # bsz*ssz1 x z0_dim
        z0 = self.aux_encode.sample(mu_qz0, logvar_qz0)  # bsz*ssz1 x z0_dim
        log_qz0 = logprob_gaussian(mu_qz0,
                                   logvar_qz0,
                                   z0,
                                   do_unsqueeze=False,
                                   do_mean=False)
        log_qz0 = torch.sum(log_qz0.view(batch_size, sample_size1,
                                         self.z0_dim),
                            dim=2)  # bsz x ssz1
        log_qz0 = log_qz0.unsqueeze(2).expand(
            batch_size, sample_size1, sample_size2).contiguous().view(
                batch_size, sample_size1 * sample_size2)  # bsz x ssz1*ssz2
        ''' get log q(z|z0,x) '''
        # forward
        _, mu_qz, logvar_qz, _ = self.encode(
            input, z0, nz=sample_size1)  # bsz*ssz1 x z_dim
        mu_qz = mu_qz.detach().repeat(1, sample_size2).view(
            batch_size * sample_size1, sample_size2, self.z_dim)
        logvar_qz = logvar_qz.detach().repeat(1, sample_size2).view(
            batch_size * sample_size1, sample_size2, self.z_dim)
        z = self.encode.sample(mu_qz, logvar_qz)  # bsz*ssz1 x ssz2 x z_dim
        log_qz = logprob_gaussian(mu_qz,
                                  logvar_qz,
                                  z,
                                  do_unsqueeze=False,
                                  do_mean=False)
        log_qz = torch.sum(log_qz.view(batch_size, sample_size1 * sample_size2,
                                       self.z_dim),
                           dim=2)  # bsz x ssz1*ssz2
        ''' get log p(z0|z,x) '''
        # encode
        _z0 = z0.unsqueeze(1).expand(batch_size * sample_size1, sample_size2,
                                     self.z0_dim).contiguous().view(
                                         batch_size, sample_size1,
                                         sample_size2, self.z0_dim).detach()
        _, mu_pz0, logvar_pz0 = self.aux_decode(
            input, z.view(-1, self.z_dim),
            nz=sample_size1 * sample_size2)  # bsz*ssz1*ssz2 x z0_dim
        mu_pz0 = mu_pz0.view(batch_size, sample_size1, sample_size2,
                             self.z0_dim)
        logvar_pz0 = logvar_pz0.view(batch_size, sample_size1, sample_size2,
                                     self.z0_dim)
        log_pz0 = logprob_gaussian(mu_pz0,
                                   logvar_pz0,
                                   _z0,
                                   do_unsqueeze=False,
                                   do_mean=False)  # bsz x ssz1 x ssz2 x z0_dim
        log_pz0 = torch.sum(log_pz0.view(batch_size,
                                         sample_size1 * sample_size2,
                                         self.z0_dim),
                            dim=2)  # bsz x ssz1*ssz2
        ''' get log p(z) '''
        # get prior (as unit normal dist)
        mu_pz = input.new_zeros(batch_size * sample_size1, sample_size2,
                                self.z_dim)
        logvar_pz = input.new_zeros(batch_size * sample_size1, sample_size2,
                                    self.z_dim)
        log_pz = logprob_gaussian(mu_pz,
                                  logvar_pz,
                                  z,
                                  do_unsqueeze=False,
                                  do_mean=False)
        log_pz = torch.sum(log_pz.view(batch_size, sample_size1 * sample_size2,
                                       self.z_dim),
                           dim=2)  # bsz x ssz1*ssz2
        ''' get log p(x|z) '''
        # decode
        _input = input.unsqueeze(1).unsqueeze(1).expand(
            batch_size, sample_size1, sample_size2, self.input_channels,
            self.input_height,
            self.input_height)  # bsz x ssz1 x ssz2 x input_dim
        _z = z.view(-1, self.z_dim)
        _, logit_px = self.decode(_z)  # bsz*ssz1*ssz2 x zdim
        logit_px = logit_px.view(batch_size, sample_size1, sample_size2,
                                 self.input_channels, self.input_height,
                                 self.input_height)
        loglikelihood = -F.binary_cross_entropy_with_logits(
            logit_px, _input, reduction='none')
        loglikelihood = torch.sum(loglikelihood.view(
            batch_size, sample_size1 * sample_size2, -1),
                                  dim=2)  # bsz x ssz1*ssz2
        ''' get log p(x|z)p(z)p(z0|z,x)/(q(z|z0,x)q(z0|x)) '''
        logprob = loglikelihood + log_pz + log_pz0 - log_qz - log_qz0  # bsz x ssz1*ssz2
        logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
        rprob = (logprob - logprob_max).exp()  # relative prob
        logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) +
                            1e-10) + logprob_max  # bsz x 1

        # return
        return logprob.mean()
Code example #7
File: mnist.py  Project: lim0606/pytorch-ardae-vae
    def logprob_w_cov_gaussian_posterior(self,
                                         input,
                                         sample_size=128,
                                         z=None,
                                         std=None):
        # init
        batch_size = input.size(0)
        input = input.view(batch_size, self.input_dim)
        assert sample_size >= 2 * self.z_dim
        ''' get z and pseudo log q(newz|x) '''
        z, newz = [], []
        logposterior = []
        inp = self.encode._forward_inp(input).detach()
        for i in range(batch_size):
            _inp = inp[i:i + 1, :].expand(sample_size, inp.size(1))
            _nos = self.encode._forward_nos(batch_size=sample_size,
                                            std=std,
                                            device=input.device).detach()
            _z = self.encode._forward_all(_inp, _nos)  # ssz x zdim
            z += [_z.detach().unsqueeze(0)]
        z = torch.cat(z, dim=0)  # bsz x ssz x zdim
        mu_qz = torch.mean(z, dim=1)  # bsz x zdim
        for i in range(batch_size):
            _cov_qz = get_covmat(z[i, :, :])
            _rv_z = MultivariateNormal(mu_qz[i], _cov_qz)
            _newz = _rv_z.rsample(torch.Size([1, sample_size]))
            _logposterior = _rv_z.log_prob(_newz)

            newz += [_newz]
            logposterior += [_logposterior]
        newz = torch.cat(newz, dim=0)  # bsz x ssz x zdim
        logposterior = torch.cat(logposterior, dim=0)  # bsz x ssz
        ''' get log p(z) '''
        # get prior (as unit normal dist)
        mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logprior = logprob_gaussian(mu_pz,
                                    logvar_pz,
                                    newz,
                                    do_unsqueeze=False,
                                    do_mean=False)
        logprior = torch.sum(logprior.view(batch_size, sample_size,
                                           self.z_dim),
                             dim=2)  # bsz x ssz
        ''' get log p(x|z) '''
        # decode
        logit_x = []
        for i in range(batch_size):
            _, _logit_x = self.decode(newz[i, :, :])  # ssz x zdim
            logit_x += [_logit_x.detach().unsqueeze(0)]
        logit_x = torch.cat(logit_x, dim=0)  # bsz x ssz x input_dim
        _input = input.unsqueeze(1).expand(
            batch_size, sample_size, self.input_dim)  # bsz x ssz x input_dim
        loglikelihood = -F.binary_cross_entropy_with_logits(
            logit_x, _input, reduction='none')
        loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz
        ''' get log p(x|z)p(z)/q(z|x) '''
        logprob = loglikelihood + logprior - logposterior  # bsz x ssz
        logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
        rprob = (logprob - logprob_max).exp()  # relative prob
        logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) +
                            1e-10) + logprob_max  # bsz x 1

        # return
        return logprob.mean()
Code example #8
File: mnist.py  Project: lim0606/pytorch-ardae-vae
    def logprob_w_kde_posterior(self,
                                input,
                                sample_size=128,
                                z=None,
                                std=None):
        # init
        batch_size = input.size(0)
        input = input.view(batch_size, self.input_dim)
        assert sample_size >= 2 * self.z_dim
        ''' get z and pseudo log q(newz|x) '''
        z, newz = [], []
        logposterior = []
        inp = self.encode._forward_inp(input).detach()
        for i in range(batch_size):
            _inp = inp[i:i + 1, :].expand(sample_size, inp.size(1))
            _nos = self.encode._forward_nos(sample_size,
                                            std=std,
                                            device=input.device).detach()
            _z = self.encode._forward_all(_inp, _nos)  # ssz x zdim
            z += [_z.detach().unsqueeze(0)]
        z = torch.cat(z, dim=0)  # bsz x ssz x zdim
        for i in range(batch_size):
            _z = z[i, :, :].cpu().numpy().T  # zdim x ssz
            kernel = stats.gaussian_kde(_z)
            _newz = kernel.resample(sample_size)  # zdim x ssz
            _logposterior = kernel.logpdf(_newz)  # ssz

            _newz = torch.from_numpy(_newz.T).float().to(
                input.device)  # ssz x zdim
            _logposterior = torch.from_numpy(_logposterior).float().to(
                input.device)  # ssz
            newz += [_newz.unsqueeze(0)]
            logposterior += [_logposterior.unsqueeze(0)]
        newz = torch.cat(newz, dim=0)  # bsz x ssz x zdim
        logposterior = torch.cat(logposterior, dim=0)  # bsz x ssz
        ''' get log p(z) '''
        # get prior (as unit normal dist)
        mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logprior = logprob_gaussian(mu_pz,
                                    logvar_pz,
                                    newz,
                                    do_unsqueeze=False,
                                    do_mean=False)
        logprior = torch.sum(logprior.view(batch_size, sample_size,
                                           self.z_dim),
                             dim=2)  # bsz x ssz
        ''' get log p(x|z) '''
        # decode
        logit_x = []
        for i in range(batch_size):
            _, _logit_x = self.decode(newz[i, :, :])  # ssz x zdim
            logit_x += [_logit_x.detach().unsqueeze(0)]
        logit_x = torch.cat(logit_x, dim=0)  # bsz x ssz x input_dim
        _input = input.unsqueeze(1).expand(
            batch_size, sample_size, self.input_dim)  # bsz x ssz x input_dim
        loglikelihood = -F.binary_cross_entropy_with_logits(
            logit_x, _input, reduction='none')
        loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz
        ''' get log p(x|z)p(z)/q(z|x) '''
        logprob = loglikelihood + logprior - logposterior  # bsz x ssz
        logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
        rprob = (logprob - logprob_max).exp()  # relative prob
        logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) +
                            1e-10) + logprob_max  # bsz x 1

        # return
        return logprob.mean()
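The scipy round-trip above, in isolation: `stats.gaussian_kde` expects data laid out as zdim x ssz, `resample` returns draws in the same layout, and `logpdf` scores columns, hence the transposes on the way in and out of torch:

import torch
from scipy import stats

z = torch.randn(128, 8)                   # ssz x zdim posterior samples
kernel = stats.gaussian_kde(z.numpy().T)  # gaussian_kde wants zdim x ssz
newz_np = kernel.resample(128)            # zdim x ssz draws from the KDE
logq = kernel.logpdf(newz_np)             # (ssz,) log-densities of the draws

newz = torch.from_numpy(newz_np.T).float()     # back to ssz x zdim
logposterior = torch.from_numpy(logq).float()  # ssz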
Code example #9
    def infogain(self,
                 reps_context,
                 context_sizes,
                 reps_target,
                 target_sizes,
                 input_tuples,
                 num_steps=None,
                 beta=1.0,
                 std=1.0):
        # init
        num_episodes = len(reps_context)
        loss_kl = 0
        ''' forward posterior / prior '''
        # init states
        states_p = self.rnn_p.init_state(num_episodes,
                                         [self.z_height, self.z_width])
        states_q = self.rnn_q.init_state(num_episodes,
                                         [self.z_height, self.z_width])
        hiddens_p = [state_p[0] for state_p in states_p]
        hiddens_q = [state_q[0] for state_q in states_q]
        latents = []
        init_input_q = False
        init_input_p = False
        for i in range(num_steps if num_steps is not None else self.num_steps):
            # aggregate observations (posterior)
            if not init_input_q:
                reps_context = pad_sequence(reps_context, context_sizes)
                reps_context = torch.sum(reps_context, dim=1)
                reps_context = reps_context.view(-1, self.nc_context,
                                                 self.z_height, self.z_width)

                reps_target = pad_sequence(reps_target, target_sizes)
                reps_target = torch.sum(reps_target, dim=1)
                reps_target = reps_target.view(-1, self.nc_context,
                                               self.z_height, self.z_width)

                input_q = torch.cat([reps_target, reps_context], dim=1)
                init_input_q = True

            # forward posterior
            means_q, logvars_q, hiddens_q, states_q = self.rnn_q(
                input_q, states_q, hiddens_p)

            # sample z from posterior
            zs = self.rnn_q.sample(means_q, logvars_q)

            # aggregate observations (prior)
            if not init_input_p:
                input_p = reps_context
                init_input_p = True

            # forward prior
            _, means_p, logvars_p, hiddens_p, states_p = self.rnn_p(
                input_p, states_p, latents_q=zs)

            # append z to latent
            latents += ([torch.cat(zs, dim=1).unsqueeze(1)]
                        if len(zs) > 1 else [zs[0].unsqueeze(1)])

            # update accumulated KL
            for j in range(self.num_layers):
                loss_kl += logprob_gaussian(means_q[j],
                                            logvars_q[j],
                                            zs[j],
                                            do_sum=False)
                loss_kl += -logprob_gaussian(means_p[j],
                                             logvars_p[j],
                                             zs[j],
                                             do_sum=False)
        ''' loss '''
        # additional loss info
        info = {}
        info['kl'] = loss_kl.detach()

        # return
        return None, latents, loss_kl.detach(), info
Code example #10
    def predict(self,
                reps_context,
                context_sizes,
                reps_target,
                target_sizes,
                input_tuples,
                num_steps=None,
                beta=1.0,
                std=1.0,
                is_grayscale=False,
                use_uint8=True):
        # init
        num_episodes = len(reps_context)
        logprob_kl = 0
        loss_kl = 0
        ''' forward posterior / prior '''
        # init states
        states_p = self.rnn_p.init_state(num_episodes,
                                         [self.z_height, self.z_width])
        states_q = self.rnn_q.init_state(num_episodes,
                                         [self.z_height, self.z_width])
        hiddens_p = [state_p[0] for state_p in states_p]
        hiddens_q = [state_q[0] for state_q in states_q]
        latents = []
        init_input_q = False
        init_input_p = False
        for i in range(num_steps if num_steps is not None else self.num_steps):
            # aggregate observations (posterior)
            if not init_input_q:
                reps_context = pad_sequence(reps_context, context_sizes)
                reps_context = torch.sum(reps_context, dim=1)
                reps_context = reps_context.view(-1, self.nc_context,
                                                 self.z_height, self.z_width)

                reps_target = pad_sequence(reps_target, target_sizes)
                reps_target = torch.sum(reps_target, dim=1)
                reps_target = reps_target.view(-1, self.nc_context,
                                               self.z_height, self.z_width)

                input_q = torch.cat([reps_target, reps_context], dim=1)
                init_input_q = True

            # forward posterior
            means_q, logvars_q, hiddens_q, states_q = self.rnn_q(
                input_q, states_q, hiddens_p)

            # sample z from posterior
            zs = self.rnn_q.sample(means_q, logvars_q)

            # aggregate observations (prior)
            if not init_input_p:
                input_p = reps_context
                init_input_p = True

            # forward prior
            _, means_p, logvars_p, hiddens_p, states_p = self.rnn_p(
                input_p, states_p, latents_q=zs)

            # append z to latent
            latents += ([torch.cat(zs, dim=1).unsqueeze(1)]
                        if len(zs) > 1 else [zs[0].unsqueeze(1)])

            # update accumulated KL
            for j in range(self.num_layers):
                loss_kl += loss_kld_gaussian_vs_gaussian(
                    means_q[j], logvars_q[j], means_p[j], logvars_p[j])
                logprob_kl += logprob_gaussian(means_p[j],
                                               logvars_p[j],
                                               zs[j],
                                               do_sum=False)
                logprob_kl += -logprob_gaussian(means_q[j],
                                                logvars_q[j],
                                                zs[j],
                                                do_sum=False)
        ''' likelihood '''
        info = {}
        info['logprob_mod_likelihoods'] = []
        logprob_likelihood = 0
        info['mod_likelihoods'] = []
        loss_likelihood = 0
        mean_recons = []
        for idx, (dim, input_tuple) in enumerate(zip(self.dims, input_tuples)):
            channels, height, width, _, mtype = dim
            mod_target, mod_queries, mod_target_indices, mod_batch_sizes = input_tuple
            if len(mod_queries) > 0:
                num_mod_data = len(mod_target)
                assert sum(mod_batch_sizes) == num_mod_data

                # run renderer (likelihood)
                mod_mean_recon = self._forward_renderer(
                    idx, mod_queries, latents, num_episodes, mod_batch_sizes,
                    mod_target_indices).detach()

                # convert to gray scale
                if mtype == 'image' and is_grayscale:
                    mod_mean_recon = rgb2gray(mod_mean_recon)
                    mod_target = rgb2gray(mod_target)
                    if not use_uint8:
                        mod_mean_recon = mod_mean_recon / 255
                        mod_target = mod_target / 255
                elif mtype == 'image' and use_uint8:
                    mod_mean_recon = 255 * mod_mean_recon
                    mod_target = 255 * mod_target

                # estimate recon loss
                loss_mod_likelihood = loss_recon_gaussian_w_fixed_var(
                    mod_mean_recon, mod_target, std=std,
                    add_logvar=False).detach()
                logprob_mod_likelihood = logprob_gaussian_w_fixed_var(
                    mod_mean_recon,
                    mod_target,
                    std=std,
                    do_sum=False).detach()

                # estimate recon loss without std
                loss_mod_likelihood_nostd = loss_recon_gaussian_w_fixed_var(
                    mod_mean_recon.detach(), mod_target).detach()

                # sum per episode
                logprob_mod_likelihood = sum_tensor_per_episode(
                    logprob_mod_likelihood, mod_batch_sizes,
                    mod_target_indices, num_episodes)
            else:
                mod_mean_recon = reps_context.new_zeros(
                    0, channels, height, width)
                loss_mod_likelihood = None
                loss_mod_likelihood_nostd = None
                logprob_mod_likelihood = None

            # add to loss_likelihood
            if loss_mod_likelihood is not None:
                loss_likelihood += loss_mod_likelihood
            if logprob_mod_likelihood is not None:
                logprob_likelihood += logprob_mod_likelihood

            # append to list
            mean_recons += [mod_mean_recon]
            info['mod_likelihoods'] += [loss_mod_likelihood]
            info['logprob_mod_likelihoods'] += [logprob_mod_likelihood]
        ''' loss '''
        # sum loss
        loss = loss_likelihood + beta * loss_kl
        logprob = logprob_likelihood + logprob_kl

        # additional loss info
        info['likelihood'] = loss_likelihood.detach() if type(
            loss_likelihood) is not int else 0
        info['kl'] = loss_kl.detach()

        # return
        return mean_recons, latents, logprob, info
Code example #11
File: sac_ardae.py  Project: lim0606/pytorch-ardae-rl
    def est_partition_func(
        self,
        sample_size=128,
        next_state_batch=None,
        mask_batch=None,
        memory=None,
        batch_size=None,
        ptflogvar=-2.,
    ):
        if memory is not None:
            assert batch_size is not None
            # sample
            _, _, _, next_state_batch, mask_batch = memory.sample(
                batch_size=batch_size)
            next_state_batch = torch.FloatTensor(next_state_batch).to(
                self.device)
            mask_batch = torch.FloatTensor(mask_batch).to(
                self.device).unsqueeze(1)
        else:
            assert next_state_batch is not None
            assert mask_batch is not None
            batch_size = next_state_batch.size(0)

        # context
        _, nxt_preact_mean, nxt_hidden, _ = self.policy.evaluate(
            next_state_batch, eval=True)
        nxt_preact_mean = nxt_preact_mean.view(batch_size, 1, -1).detach()
        if self.dae_ctx_type == 'state':
            nxt_context = next_state_batch.view(batch_size, 1, -1).detach()
        elif self.dae_ctx_type == 'hidden':
            nxt_context = nxt_hidden.view(batch_size, 1, -1).detach()
        else:
            raise ValueError(
                'unknown dae_ctx_type: {}'.format(self.dae_ctx_type))

        # sample
        _nxt_preact_mean = nxt_preact_mean.expand(batch_size, sample_size,
                                                  self.num_actions)
        _nxt_preact_logvar = ptflogvar * nxt_preact_mean.new_ones(
            _nxt_preact_mean.size())
        _newz = sample_gaussian(_nxt_preact_mean,
                                _nxt_preact_logvar)  # bsz x ssz x zdim

        # proposal distribution
        logproposal = logprob_gaussian(
            _nxt_preact_mean,
            _nxt_preact_logvar,
            _newz,
            do_unsqueeze=False,
            do_mean=False,
        )  # bsz x ssz x num_actions
        logproposal = torch.sum(logproposal, dim=2, keepdim=True) \
                    - self.num_actions * math.log(self.std_scale) # bsz x ssz x 1

        # unnormalized distribution
        newz = _newz - nxt_preact_mean
        scaled_newz = self.std_scale * newz
        stdmat = torch.zeros(batch_size, sample_size, 1, device=self.device)
        logp_ptfunc = (self.cdae.logprob(
            scaled_newz, nxt_context, std=stdmat,
            scale=self.std_scale).detach() - logproposal)

        logp_ptfunc_max, _ = torch.max(logp_ptfunc, dim=1, keepdim=True)
        rprob_ptfunc = (logp_ptfunc - logp_ptfunc_max).exp()  # relative prob
        logp_ptfunc = torch.log(
            torch.mean(rprob_ptfunc, dim=1, keepdim=True) +
            1e-12) + logp_ptfunc_max  # bsz x 1

        return logp_ptfunc.detach()
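The `- self.num_actions * math.log(self.std_scale)` term subtracted from `logproposal` reads as a change-of-variables correction for scoring the scaled samples `scaled_newz = std_scale * newz`: if y = s * x, then log p_Y(y) = log p_X(y / s) - log s per dimension. A one-dimensional sanity check of that identity:

import torch
from torch.distributions import Normal

s = 2.0
base = Normal(0.0, 1.0)    # x ~ N(0, 1)
scaled = Normal(0.0, s)    # y = s * x  =>  y ~ N(0, s^2)
y = s * base.sample((5,))
lhs = scaled.log_prob(y)
rhs = base.log_prob(y / s) - torch.log(torch.tensor(s))
assert torch.allclose(lhs, rhs, atol=1e-6)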