Example #1
    def _compute_mse_nlp(self, input, target, size_average=True, out=False):
        """Evaluate the MSE and Negative Log Probability.

        Args:
            input (Tensor): (N, iC, iH, iW)
            target (Tensor): (N, oC, oH, oW)
            size_average (bool): If False, return sums instead of means.
            out (bool): If True, also return the output of `bayes_nn` on `input`.

        Returns:
            (mse, nlp) if `out` is False, otherwise (mse, nlp, output),
            where output is of size (S, N, oC, oH, oW)
        """
        # S x N x oC x oH x oW
        output = self.forward(input)
        # S x 1 x 1 x 1 x 1
        log_beta = self.log_beta.unsqueeze(-1).unsqueeze(-1).unsqueeze(
            -1).unsqueeze(-1)
        log_2pi_S = torch.tensor(
            0.5 * target[0].numel() * math.log(2 * math.pi) +
            math.log(self.n_samples),
            device=device)
        # S x N
        exponent = - 0.5 * (log_beta.exp() * ((target - output) ** 2)).view(
            self.n_samples, target.size(0), -1).sum(-1) \
                   + 0.5 * target[0].numel() * self.log_beta.unsqueeze(-1)

        # n = target[0].numel()
        nlp = -log_sum_exp(exponent, dim=0).mean() + log_2pi_S
        mse = ((target - output.mean(0))**2).mean()

        if not size_average:
            mse *= target.numel()
            nlp *= target.size(0)
        if not out:
            return mse, nlp
        else:
            return mse, nlp, output
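
This snippet relies on a module-level `log_sum_exp` helper that is not shown. A minimal sketch of what it is assumed to do, simply wrapping PyTorch's numerically stable `torch.logsumexp`; the project's own helper may be implemented differently:

    import torch

    def log_sum_exp(value, dim=0):
        """Assumed helper: numerically stable log(sum(exp(value))) along `dim`."""
        return torch.logsumexp(value, dim=dim)
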
Example #2
    def importance_sampling_mi(self, gaussian_dist, n_sample):
        assert n_sample % _n_sample == 0

        B = gaussian_dist.mean.shape[0]

        samplify = {'log_qz': [], 'log_qzx': [], 'z': []}
        for sample_id in range(n_sample // _n_sample):
            # ----- Sampling -----
            _z = gaussian_dist.rsample(torch.Size(
                [_n_sample]))  # shape = (_n_sample, n_batch, latent_dim)
            assert tuple(_z.shape) == (_n_sample, B, self.latent_dim)

            _log_qzx = gaussian_dist.log_prob(_z).sum(
                2)  # shape = (_n_sample, n_batch)
            _log_qz = gaussian_dist.log_prob(
                _z.unsqueeze(2).expand(-1, -1, B, -1)).sum(
                    3)  # shape = (_n_sample, n_batch, n_batch)
            # Exclude itself.
            _log_qz.masked_fill_(
                gpu_wrapper(torch.eye(B).long()).eq(1).unsqueeze(0).expand(
                    _n_sample, -1, -1),
                -float('inf'))  # shape = (_n_sample, n_batch, n_batch)
            _log_qz = (log_sum_exp(_log_qz, dim=2) - np.log(B - 1)
                       )  # shape = (_n_sample, n_batch)

            samplify['log_qzx'].append(
                _log_qzx)  # shape = (_n_sample, n_batch)
            samplify['log_qz'].append(_log_qz)  # shape = (_n_sample, n_batch)
            samplify['z'].append(_z)  # shape = (_n_sample, n_batch, out_dim)

        for key in samplify.keys():
            samplify[key] = torch.cat(samplify[key],
                                      dim=0)  # shape = (n_sample, ?)

        # ----- Importance sampling for MI -----
        mi = samplify['log_qzx'].mean(0) - samplify['log_qz'].mean(0)

        return mi, samplify['z'].transpose(0, 1)
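
A hypothetical call site for `importance_sampling_mi`, assuming `model` is an instance of the class this method belongs to, `_n_sample` is the module-level chunk size it asserts against, and the encoder yields a diagonal Gaussian posterior; the names below are illustrative, not from the original code:

    import torch
    from torch.distributions import Normal

    n_batch, latent_dim = 32, 16
    mu = torch.zeros(n_batch, latent_dim)    # placeholder encoder means
    std = torch.ones(n_batch, latent_dim)    # placeholder encoder stds
    posterior = Normal(mu, std)              # factorized q(z|x)

    # n_sample must be a multiple of the module-level _n_sample
    mi, z = model.importance_sampling_mi(posterior, n_sample=512)
    # mi: (n_batch,) Monte Carlo estimate of the mutual information I(x; z)
    # z:  (n_batch, n_sample, latent_dim) posterior samples
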
Example #3
    def importance_sampling(self, gaussian_dist, go, eos, n_sample):
        B = go.shape[0]
        assert n_sample % _n_sample == 0

        samplify = {
            'xent': [],
            'log_pz': [],
            'log_pxz': [],
            'log_qzx': [],
            'z': []
        }
        for sample_id in range(n_sample // _n_sample):

            # ----- Sampling -----
            _z = gaussian_dist.rsample(torch.Size(
                [_n_sample]))  # shape = (_n_sample, n_batch, latent_dim)
            assert tuple(_z.shape) == (_n_sample, B, self.latent_dim)

            # ----- Initial Decoding States -----
            assert self.enc_bi
            # shape = (layers, _n_sample * n_batch, n_dir * hid_dim)
            _init_states = gpu_wrapper(torch.zeros([
                self.enc_layers, _n_sample * B, self.n_dir * self.hid_dim
            ])).float()

            # ----- Importance sampling for NLL -----
            _logits = self.Decoder(
                # shape = (layers, _n_sample * n_batch, n_dir * hid_dim)
                init_states=_init_states,
                # shape = (_n_sample * n_batch, out_dim)
                latent_vector=_z.contiguous().view(_n_sample * B, self.latent_dim),
                # shape = (_n_sample * n_batch, 15)
                helper=go.unsqueeze(0).expand(
                    _n_sample, -1, -1).contiguous().view(_n_sample * B, -1),
                test_lm=True)  # shape = (_n_sample * n_batch, 16, V)
            _xent = self.criterionSeq(
                _logits,  # shape = (_n_sample * n_batch, 16, V)
                # shape = (_n_sample * n_batch, 16)
                eos.unsqueeze(0).expand(_n_sample, -1, -1).contiguous().view(
                    _n_sample * B, -1),
                keep_batch=True).view(_n_sample, B)  # shape = (_n_sample, n_batch)

            _log_pz = self.PriorGaussian.log_prob(_z).sum(
                2)  # shape = (_n_sample, n_batch)
            _log_pxz = -_xent  # shape = (_n_sample, n_batch)
            _log_qzx = gaussian_dist.log_prob(_z).sum(
                2)  # shape = (_n_sample, n_batch)

            samplify['xent'].append(_xent)  # shape = (_n_sample, n_batch)
            samplify['log_pz'].append(_log_pz)  # shape = (_n_sample, n_batch)
            samplify['log_pxz'].append(
                _log_pxz)  # shape = (_n_sample, n_batch)
            samplify['log_qzx'].append(
                _log_qzx)  # shape = (_n_sample, n_batch)
            samplify['z'].append(_z)  # shape = (_n_sample, n_batch, out_dim)

        for key in samplify.keys():
            samplify[key] = torch.cat(samplify[key],
                                      dim=0)  # shape = (n_sample, ?)

        ll = log_sum_exp(
            samplify['log_pz'] + samplify['log_pxz'] - samplify['log_qzx'],
            dim=0) - np.log(n_sample)  # shape = (n_batch, )
        nll = -ll  # shape = (n_batch, )

        # ----- Importance sampling for KL -----
        # kl = kl_with_isogaussian(gaussian_dist)  # shape = (n_batch, )
        kl = (samplify['log_qzx'] - samplify['log_pz']).mean(
            0)  # shape = (n_batch, )

        return samplify['xent'].mean(0), nll, kl, samplify['z'].transpose(0, 1)
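
The `nll` above is the standard importance-weighted estimate -log[(1/S) * sum_s p(z_s) p(x|z_s) / q(z_s|x)]. A self-contained sketch of just that reduction, assuming the three log-probability tensors have shape (n_sample, n_batch); the function name is illustrative:

    import math
    import torch

    def importance_weighted_nll(log_pz, log_pxz, log_qzx):
        """-log of the averaged importance weights p(z) * p(x|z) / q(z|x)."""
        n_sample = log_pz.size(0)
        ll = torch.logsumexp(log_pz + log_pxz - log_qzx, dim=0) - math.log(n_sample)
        return -ll  # shape = (n_batch,)
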
Example #4
    def importance_sampling_mi(self, Q0, last_states, n_sample):
        assert n_sample % _n_sample == 0

        B = Q0.mean.shape[0]

        samplify = {'log_qz': [], 'log_qzx': [], 'z': []}
        for sample_id in range(n_sample // _n_sample):
            # ----- Sampling -----
            _z0 = Q0.rsample(torch.Size(
                [_n_sample]))  # shape = (_n_sample, n_batch, out_dim)
            assert tuple(_z0.shape) == (_n_sample, B, self.latent_dim)

            # ----- Flows -----
            _zk, _sum_log_jacobian = self.Flows(
                # shape = (_n_sample * n_batch, out_dim)
                z0=_z0.contiguous().view(_n_sample * B, self.latent_dim),
                # shape = (_n_sample * n_batch, layers * n_dir * hid_dim)
                cond=last_states.unsqueeze(0).expand(
                    _n_sample, -1, -1).contiguous().view(_n_sample * B, -1))
            # _zk.shape = (_n_sample * n_batch, latent_dim)
            # _sum_log_jacobian.shape = (_n_sample * n_batch, )
            _zk = _zk.view(
                _n_sample, B,
                self.latent_dim)  # shape = (_n_sample, n_batch, latent_dim)
            _sum_log_jacobian = _sum_log_jacobian.view(
                _n_sample, B)  # shape = (_n_sample, n_batch)

            # ----- Flows for the aggregate posterior -----
            _, _sum_log_jacobian_batch = self.Flows(
                # shape = (_n_sample * n_batch * n_batch, out_dim)
                z0=_z0.unsqueeze(2).expand(-1, -1, B, -1).contiguous().view(
                    _n_sample * B * B, self.latent_dim),
                # shape = (_n_sample * n_batch * n_batch, layers * n_dir * hid_dim)
                cond=last_states.unsqueeze(0).unsqueeze(1).expand(
                    _n_sample, B, -1, -1).contiguous().view(_n_sample * B * B, -1))
            # _sum_log_jacobian_batch.shape = (_n_sample * n_batch * n_batch, )
            _sum_log_jacobian_batch = _sum_log_jacobian_batch.view(
                _n_sample, B, B)  # shape = (_n_sample, n_batch, n_batch)

            _log_qzx = Q0.log_prob(_z0).sum(
                2) - _sum_log_jacobian  # shape = (_n_sample, n_batch)
            _log_qz = Q0.log_prob(_z0.unsqueeze(2).expand(-1, -1, B, -1)).sum(
                3
            ) - _sum_log_jacobian_batch  # shape = (_n_sample, n_batch, n_batch)
            # Exclude itself.
            _log_qz.masked_fill_(
                gpu_wrapper(torch.eye(B).long()).eq(1).unsqueeze(0).expand(
                    _n_sample, -1, -1),
                -float('inf'))  # shape = (_n_sample, n_batch, n_batch)
            _log_qz = (log_sum_exp(_log_qz, dim=2) - np.log(B - 1)
                       )  # shape = (_n_sample, n_batch)

            samplify['log_qzx'].append(
                _log_qzx)  # shape = (_n_sample, n_batch)
            samplify['log_qz'].append(_log_qz)  # shape = (_n_sample, n_batch)
            samplify['z'].append(_zk)  # shape = (_n_sample, n_batch, out_dim)

        for key in samplify.keys():
            samplify[key] = torch.cat(samplify[key],
                                      dim=0)  # shape = (n_sample, ?)

        # ----- Importance sampling for MI -----
        mi = samplify['log_qzx'].mean(0) - samplify['log_qz'].mean(0)

        return mi, samplify['z'].transpose(0, 1)
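
Both MI examples estimate log q(z) against the aggregate posterior of the current minibatch, masking out each sample's own input before the log-sum-exp. A standalone sketch of that step, assuming a (n_sample, n_batch, n_batch) matrix whose entry [s, i, j] is log q(z_{s,i} | x_j); the function name is illustrative:

    import math
    import torch

    def minibatch_log_qz(log_qz_matrix):
        """Estimate log q(z_i) from log q(z_i | x_j), excluding j == i."""
        n_sample, n_batch, _ = log_qz_matrix.shape
        self_mask = torch.eye(n_batch, dtype=torch.bool).unsqueeze(0).expand(
            n_sample, -1, -1)
        masked = log_qz_matrix.masked_fill(self_mask, -float('inf'))
        # shape = (n_sample, n_batch)
        return torch.logsumexp(masked, dim=2) - math.log(n_batch - 1)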