Example #1
    def _zero_mean_forward(self, x):
        if not isinstance(x, tuple):
            x_mean = x
            x_var = None
        else:
            x_mean = x[0]
            x_var = x[1]

        y_mean = F.linear(x_mean, torch.zeros_like(self.W).t()) + self.bias

        W_var = self._get_var(self.W_logvar)
        bias_var = self._get_var(self.bias_logvar)

        if x_var is None:
            xx = x_mean * x_mean
            y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var)
        else:
            y_var = compute_linear_var(x_mean, x_var, torch.zeros_like(self.W),
                                       W_var, self.bias, bias_var)

        if self.deterministic:
            return y_mean, y_var
        else:
            dst = MultivariateNormal(loc=y_mean, covariance_matrix=y_var)
            sample = dst.rsample()
            return sample, None
Example #2
    def _mc_forward(self, x):
        if isinstance(x, tuple):
            x_mean = x[0]
            x_var = x[1]
        else:
            x_mean = x
            x_var = None  # otherwise x_var would be undefined below

        if self.zero_mean:
            lrt_mean = 0.0
        else:
            lrt_mean = F.linear(x_mean, self.W)
        if self.bias is not None:
            lrt_mean = lrt_mean + self.bias

        sigma2 = torch.exp(self.log_alpha) * self.W * self.W
        if self.permute_sigma:
            sigma2 = sigma2.view(-1)[torch.randperm(
                self.in_features * self.out_features).cuda()].view(
                    self.out_features, self.in_features)

        if x_var is None:
            x_var = torch.diag_embed(x_mean * x_mean)

        lrt_cov = compute_linear_var(x_mean, x_var, self.W.t(), sigma2.t())
        dst = MultivariateNormal(lrt_mean, covariance_matrix=lrt_cov)
        return dst.rsample(), None
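As a side note on the fallback above, torch.diag_embed turns a batch of per-feature variances into a batch of diagonal covariance matrices. A tiny standalone illustration with toy values, independent of the layer:

import torch

x_mean = torch.tensor([[1.0, 2.0],
                       [3.0, 4.0]])
x_var = torch.diag_embed(x_mean * x_mean)  # shape (2, 2, 2)
print(x_var[0])  # diag(1, 4)
print(x_var[1])  # diag(9, 16)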
Example #3
    def forward(self, *state_args, deterministic=True):
        x = super(Policy, self).forward(*state_args)

        mean, log_std = torch.split(x, x.shape[1] // 2, dim=1)

        log_std = self.std_clamp(log_std)

        if deterministic:
            action = torch.tanh(mean)
            log_prob = torch.zeros(log_std.shape[0]).unsqueeze_(-1)
        else:
            std = log_std.exp()

            normal = MultivariateNormal(mean, torch.diag_embed(std.pow(2)))
            action_base = normal.rsample()

            log_prob = normal.log_prob(action_base)
            log_prob.unsqueeze_(-1)

            action = torch.tanh(action_base)

            action_bound_compensation = torch.log(1. - action.pow(2) +
                                                  np.finfo(float).eps).sum(
                                                      dim=1, keepdim=True)

            log_prob.sub_(action_bound_compensation)

        return action, log_prob
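Because the covariance built above is diagonal (torch.diag_embed(std.pow(2))), the same density can be obtained from an Independent stack of univariate normals. A minimal check with made-up shapes, not part of the policy class:

import torch
from torch.distributions import Independent, MultivariateNormal, Normal

mean = torch.randn(4, 3)      # hypothetical (batch, action_dim)
std = torch.rand(4, 3) + 0.1

mvn = MultivariateNormal(mean, torch.diag_embed(std.pow(2)))
ind = Independent(Normal(mean, std), 1)

x = mvn.rsample()
assert torch.allclose(mvn.log_prob(x), ind.log_prob(x), atol=1e-5)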
Example #4
    def forward(self, matrix, rets=None, **kwargs):
        """Perform forward pass.

        Only accepts keyword arguments to avoid ambiguity.

        Parameters
        ----------
        matrix : torch.Tensor
            Of shape (n_samples, n_assets, n_assets) representing the square root of the covariance matrix if
            `self.sqrt=True` else the covariance matrix itself.

        rets : torch.Tensor or None
            Of shape (n_samples, n_assets) representing expected returns (or whatever the feature extractor decided
            to encode). Note that `NCO` and `AnalyticalMarkowitz` allow for `rets=None` (using only minimum variance).

        kwargs : dict
            All additional input arguments the `self.allocator` needs to perform forward pass.

        Returns
        -------
        weights : torch.Tensor
            Of shape (n_samples, n_assets) representing the optimal weights.

        """
        if self.random_state is not None:
            torch.manual_seed(self.random_state)

        n_samples, n_assets, _ = matrix.shape
        dtype, device = matrix.dtype, matrix.device
        n_draws = self.n_draws or n_assets  # make sure that if None then we have the same N=M

        covmat = matrix @ matrix if self.sqrt else matrix
        dist_rets = torch.zeros(n_samples, n_assets, dtype=dtype, device=device) if rets is None else rets

        dist = MultivariateNormal(loc=dist_rets, covariance_matrix=covmat)

        portfolios = []  # n_portfolios elements of (n_samples, n_assets)

        for _ in range(self.n_portfolios):
            draws = dist.rsample((n_draws,))  # (n_draws, n_samples, n_assets)
            rets_ = draws.mean(dim=0) if rets is not None else None  # (n_samples, n_assets)
            covmat_ = CovarianceMatrix(sqrt=self.uses_sqrt)(draws.permute(1, 0, 2))  # (n_samples, n_assets, n_assets)

            if isinstance(self.allocator, (AnalyticalMarkowitz, NCO)):
                portfolio = self.allocator(covmat=covmat_, rets=rets_)

            elif isinstance(self.allocator, NumericalMarkowitz):
                gamma = kwargs['gamma']
                alpha = kwargs['alpha']
                portfolio = self.allocator(rets_, covmat_, gamma, alpha)

            portfolios.append(portfolio)

        portfolios_t = torch.stack(portfolios, dim=0)  # (n_portfolios, n_samples, n_assets)

        return portfolios_t.mean(dim=0)
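The shapes in the loop come straight from rsample(sample_shape), which prepends the requested sample shape to the distribution's batch and event shapes. A small standalone check with arbitrary sizes, not the actual allocator inputs:

import torch
from torch.distributions import MultivariateNormal

n_samples, n_assets, n_draws = 2, 4, 6
loc = torch.zeros(n_samples, n_assets)
covmat = torch.eye(n_assets).repeat(n_samples, 1, 1)

dist = MultivariateNormal(loc=loc, covariance_matrix=covmat)
draws = dist.rsample((n_draws,))
assert draws.shape == (n_draws, n_samples, n_assets)
assert draws.mean(dim=0).shape == (n_samples, n_assets)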
Example #5
    def select_action(self, x):
        mu, cov = self.forward(x)
        tril = self.reshape_output(mu, cov)
        dist = MultivariateNormal(mu, scale_tril=tril)
        if self.pwd:
            action = dist.rsample()
        else:
            action = dist.sample()
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        return action, log_prob, entropy
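The self.pwd flag above switches between the reparameterized and the plain sampling path; only rsample() keeps the computation graph. A toy illustration of the difference, with stand-in tensors rather than the original network:

import torch
from torch.distributions import MultivariateNormal

mu = torch.zeros(3, requires_grad=True)
dist = MultivariateNormal(mu, scale_tril=torch.eye(3))

a_r = dist.rsample()        # differentiable: gradients flow back to mu
a_r.sum().backward()
print(mu.grad)              # tensor([1., 1., 1.])

a_s = dist.sample()         # detached draw: no graph attached
print(a_s.requires_grad)    # False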
Example #6
    def forward(self, x, n_samples, reparam=True, squeeze=True):
        q_m = self.mean_encoder(x)
        l_mat = self.var_encoder
        q_v = l_mat.matmul(l_mat.T)

        variational_dist = MultivariateNormal(loc=q_m, scale_tril=l_mat)

        if squeeze and n_samples == 1:
            sample_shape = []
        else:
            sample_shape = (n_samples, )
        if reparam:
            latent = variational_dist.rsample(sample_shape=sample_shape)
        else:
            latent = variational_dist.sample(sample_shape=sample_shape)
        return dict(q_m=q_m, q_v=q_v, latent=latent)
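The q_v = l_mat.matmul(l_mat.T) line mirrors what the distribution does internally: when a lower-triangular factor L is passed as scale_tril, the implied covariance is L @ L.T. A quick check with a made-up factor:

import torch
from torch.distributions import MultivariateNormal

d = 3
l_mat = torch.tril(torch.randn(d, d))
l_mat.diagonal().copy_(torch.rand(d) + 0.5)  # scale_tril needs a positive diagonal

dist = MultivariateNormal(loc=torch.zeros(d), scale_tril=l_mat)
assert torch.allclose(dist.covariance_matrix, l_mat @ l_mat.T)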
Example #7
File: test_mvn.py  Project: mortonjt/catvae
    def test_log_prob(self):
        loc = torch.ones(self.d)

        wdw = self.W @ torch.diag(self.D) @ self.W.t()
        sI = self.s2 * self.Id
        sigma = sI + wdw
        dist2 = MultivariateNormal(loc, covariance_matrix=sigma)
        samples = dist2.rsample([10000])
        exp_logp = dist2.log_prob(samples)

        dist1 = MultivariateNormalFactorIdentity(loc, self.s2, self.D, self.W)
        res_logp = dist1.log_prob(samples)

        self.assertAlmostEqual(float(exp_logp.mean()),
                               float(res_logp.mean()),
                               places=3)
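For covariances of this W diag(D) W^T + s2*I form, torch.distributions also provides LowRankMultivariateNormal, which avoids materializing the dense matrix. A sketch with stand-in tensors; this is not the MultivariateNormalFactorIdentity class under test:

import torch
from torch.distributions import LowRankMultivariateNormal, MultivariateNormal

d, k = 5, 2
loc = torch.ones(d)
W = torch.randn(d, k)
D = torch.rand(k) + 0.1
s2 = 0.5

dense = MultivariateNormal(
    loc, covariance_matrix=W @ torch.diag(D) @ W.T + s2 * torch.eye(d))
lowrank = LowRankMultivariateNormal(
    loc, cov_factor=W * D.sqrt(), cov_diag=s2 * torch.ones(d))

x = dense.sample((100,))
assert torch.allclose(dense.log_prob(x), lowrank.log_prob(x), atol=1e-3)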
Example #8
class TanhNormal(Distribution):
    def __init__(self, loc, scale):
        super().__init__()
        self.normal = MultivariateNormal(loc, scale)

    def sample(self):
        return torch.tanh(self.normal.sample())

    def rsample(self):
        return torch.tanh(self.normal.rsample())

    # Calculates log probability of value using the change-of-variables technique (uses log1p = log(1 + x) for extra numerical stability)
    def log_prob(self, value):
        inv_value = (torch.log1p(value) - torch.log1p(-value)) / 2  # artanh(y)
        # log p(f^-1(y)) + log |det(J(f^-1(y)))|; tanh acts elementwise, so the
        # log-determinant of its Jacobian is summed over the event dimension.
        return self.normal.log_prob(inv_value) - torch.log1p(
            -value.pow(2) + 1e-6).sum(dim=-1)

    @property
    def mean(self):
        return torch.tanh(self.normal.mean)
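A minimal usage sketch for the TanhNormal wrapper above, assuming the class and its imports are in scope. The 2-D loc and scale values are made up, and scale is forwarded to MultivariateNormal as its covariance matrix:

import torch

loc = torch.zeros(2)
cov = torch.eye(2)            # passed through as the covariance matrix

dist = TanhNormal(loc, cov)
action = dist.rsample()       # differentiable, squashed into (-1, 1)
log_p = dist.log_prob(action)
print(action, float(log_p))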
Example #9
File: gpm.py  Project: kastnerkyle/pplvm
    def loss(self,
             Y,
             dT,
             T,
             switch=None,
             K_zz_inv=None,
             K_xz=None,
             qmu=None,
             qs=None,
             anneal=1.):

        # Calculating covariance
        T_induce, induce_idx = self.get_T_induce(T)
        if K_zz_inv is None:
            K_zz_inv = self.calc_K_inv(T_induce)
        if K_xz is None:
            K_xz = self.calc_K_xz(T_induce, T)

        # GP loss
        if qmu is not None and qs is not None:
            q_u = MultivariateNormal(qmu, torch.diag(qs.exp()))
        else:
            mu, log_var = self.x_model(Y, dT, induce_idx, switch)
            q_u = MultivariateNormal(mu.squeeze(),
                                     torch.diag(log_var.squeeze().exp()))
        # sparse GP KL
        u = q_u.rsample().squeeze()
        p_u = MultivariateNormal(torch.zeros(u.shape[0], dtype=torch.float64),
                                 precision_matrix=K_zz_inv)
        kl = kl_divergence(q_u, p_u)

        # HMM loss
        X = torch.mv(K_xz, torch.mv(K_zz_inv, u))
        log_pi0, log_pi, log_ab = self.calc_params(X)
        ll = self.log_like_dT(dT, log_ab) + self.likelihood.mixture_prob(Y)
        loss_hmm = -1. * hmmnorm_cython(log_pi0, log_pi.contiguous(),
                                        ll.contiguous())

        loss = loss_hmm + anneal * kl

        return loss
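The sparse-GP term above relies on the closed-form KL that PyTorch registers between two MultivariateNormal objects. A tiny standalone example of that call, with toy parameters only:

import torch
from torch.distributions import MultivariateNormal, kl_divergence

m = 4
q = MultivariateNormal(torch.randn(m), torch.diag(torch.rand(m) + 0.1))
p = MultivariateNormal(torch.zeros(m), precision_matrix=torch.eye(m))

kl = kl_divergence(q, p)  # non-negative scalar tensor
print(float(kl))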
Example #10
    def _mcvi_forward(self, x):
        W_var = self._get_var(self.W_logvar)
        bias_var = self._get_var(self.bias_logvar)

        if self.certain:
            x_mean = x
            x_var = None
        else:
            x_mean = x[0]
            x_var = x[1]

        y_mean = F.linear(x_mean, self.W.t()) + self.bias

        if self.certain or not self.deterministic:
            xx = x_mean * x_mean
            y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var)
        else:
            y_var = compute_linear_var(x_mean, x_var, self.W, W_var, self.bias,
                                       bias_var)

        dst = MultivariateNormal(loc=y_mean, covariance_matrix=y_var)
        sample = dst.rsample()
        return sample, None
Example #11
    def _loss_em_mc_efficient(
        self,
        past_targets: Union[Sequence[torch.Tensor], torch.Tensor],
        past_controls: Optional[Union[Sequence[ControlInputs],
                                      ControlInputs]] = None,
    ) -> torch.Tensor:
        """
        Monte Carlo loss as computed in the KVAE paper.
        Can be computed more efficiently when there is no missing data (no
        imputation), by batching some operations along the time axis.
        """
        past_controls = self._expand_particle_dim(past_controls)
        n_batch = len(past_targets[0])

        # A) SSM related distributions:
        # A1) smoothing.
        latents_smoothed = self._smooth_efficient(
            past_targets=past_targets,
            past_controls=past_controls,
            return_time_tensor=True,
        )

        state_smoothed_dist = MultivariateNormal(
            loc=latents_smoothed.variables.m,
            covariance_matrix=latents_smoothed.variables.V,
        )
        x = state_smoothed_dist.rsample()
        gls_params = latents_smoothed.gls_params

        # A2) prior && posterior transition distribution.
        prior_dist = self.state_prior_model(
            None, batch_shape_to_prepend=(self.n_particle, n_batch))

        # A, B, R are already 0:T-1.
        transition_dist = MultivariateNormal(
            loc=matvec(gls_params.A[:-1], x[:-1]) +
            (matvec(gls_params.B[:-1], past_controls.state[:-1])
             if gls_params.B is not None else 0.0),
            covariance_matrix=gls_params.R[:-1],
        )
        # A3) posterior predictive (auxiliary) distribution.
        auxiliary_predictive_dist = MultivariateNormal(
            loc=matvec(gls_params.C, x) +
            (matvec(gls_params.D, past_controls.target)
             if gls_params.D is not None else 0.0),
            covariance_matrix=gls_params.Q,
        )

        # A4) SSM related losses
        l_prior = (-prior_dist.log_prob(x[0:1]).sum(dim=(0, 1)) /
                   self.n_particle)  # time and particle dim
        l_transition = (-transition_dist.log_prob(x[1:]).sum(dim=(0, 1)) /
                        self.n_particle)  # time and particle dim
        l_auxiliary = (-auxiliary_predictive_dist.log_prob(
            latents_smoothed.variables.auxiliary).sum(dim=(0, 1)) /
                       self.n_particle)  # time and particle dim
        l_entropy = (
            state_smoothed_dist.log_prob(x).sum(dim=(0, 1))  # negative entropy
            / self.n_particle)  # time and particle dim

        # B) VAE related distributions
        # B1) inv_measurement_dist already obtained from smoothing (as we don't want to re-compute)
        # B2) measurement (decoder) distribution
        # transpose TPBF -> PTBF to broadcast log_prob of y (TBF) correctly
        z_particle_first = latents_smoothed.variables.auxiliary.transpose(0, 1)
        measurement_dist = self.measurement_model(z_particle_first)
        # B3) VAE related losses
        l_measurement = (
            -measurement_dist.log_prob(past_targets).sum(dim=(0, 1)) /
            self.n_particle)  # time and particle dim

        auxiliary_variational_dist = MultivariateNormal(
            loc=latents_smoothed.variables.m_auxiliary_variational,
            covariance_matrix=latents_smoothed.variables.
            V_auxiliary_variational,
        )
        l_inv_measurement = (
            auxiliary_variational_dist.log_prob(z_particle_first).sum(
                dim=(0, 1)) / self.n_particle)  # time and particle dim

        assert all(t.shape == l_prior.shape for t in (
            l_prior,
            l_transition,
            l_auxiliary,
            l_measurement,
            l_inv_measurement,
        ))

        l_total = (self.reconstruction_weight * l_measurement +
                   l_inv_measurement + l_auxiliary + l_prior + l_transition +
                   l_entropy)
        return l_total
Example #12
def sample_activations(x, n_samples):
    x_mean, x_var = x[0], x[1]
    sampler = MultivariateNormal(loc=x_mean, covariance_matrix=x_var)
    samples = sampler.rsample([n_samples])
    return samples
Example #13
class GaussianTorchDistribution(TorchDistribution):
    def __init__(self, mu, chol_flat, use_cuda):
        super().__init__(use_cuda)
        self._dim = mu.shape[0]

        self._mu = nn.Parameter(torch.as_tensor(mu, dtype=torch.float32),
                                requires_grad=True)
        self._chol_flat = nn.Parameter(torch.as_tensor(chol_flat,
                                                       dtype=torch.float32),
                                       requires_grad=True)

        self.distribution_t = MultivariateNormal(
            self._mu,
            scale_tril=self.to_tril_matrix(self._chol_flat, self._dim))

    def __copy__(self):
        return GaussianTorchDistribution(self._mu, self._chol_flat,
                                         self.use_cuda)

    def __deepcopy__(self, memodict=None):
        return GaussianTorchDistribution(copy.deepcopy(self._mu),
                                         copy.deepcopy(self._chol_flat),
                                         self.use_cuda)

    @staticmethod
    def to_tril_matrix(chol_flat, dim):
        if isinstance(chol_flat, np.ndarray):
            chol = np.zeros((dim, dim))
            exp_fun = np.exp
        else:
            chol = torch.zeros((dim, dim))
            exp_fun = torch.exp

        d1, d2 = np.diag_indices(dim)
        chol[d1, d2] += exp_fun(chol_flat[0:dim])
        ld1, ld2 = np.tril_indices(dim, k=-1)
        chol[ld1, ld2] += chol_flat[dim:]

        return chol

    @staticmethod
    def flatten_matrix(mat, tril=False):
        if not tril:
            mat = scpla.cholesky(mat, lower=True)

        dim = mat.shape[0]
        d1, d2 = np.diag_indices(dim)
        ld1, ld2 = np.tril_indices(dim, k=-1)

        return np.concatenate((np.log(mat[d1, d2]), mat[ld1, ld2]))

    def entropy_t(self):
        return self.distribution_t.entropy()

    def mean_t(self):
        return self.distribution_t.mean

    def log_pdf_t(self, x):
        return self.distribution_t.log_prob(x)

    def sample(self):
        return self.distribution_t.rsample()

    def covariance_matrix(self):
        return self.distribution_t.covariance_matrix.detach().numpy()

    def set_weights(self, weights):
        set_weights([self._mu], weights[0:self._dim], self._use_cuda)
        set_weights([self._chol_flat], weights[self._dim:], self._use_cuda)
        # This is important - otherwise the changes will not be reflected!
        self.distribution_t = MultivariateNormal(
            self._mu,
            scale_tril=self.to_tril_matrix(self._chol_flat, self._dim))

    def get_weights(self):
        mu_weights = get_weights([self._mu])
        chol_flat_weights = get_weights([self._chol_flat])

        return np.concatenate([mu_weights, chol_flat_weights])

    def parameters(self):
        return [self._mu, self._chol_flat]
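The flatten/unflatten pair above stores the Cholesky factor with log-diagonal entries, so every flat vector maps back to a valid lower-triangular factor. A quick round trip using only the static methods, assuming scpla is scipy.linalg as the body of flatten_matrix implies:

import numpy as np

cov = np.array([[2.0, 0.3],
                [0.3, 1.0]])

flat = GaussianTorchDistribution.flatten_matrix(cov)          # log-diag + strict lower part
tril = GaussianTorchDistribution.to_tril_matrix(flat, dim=2)  # numpy branch reconstructs L
assert np.allclose(tril @ tril.T, cov)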
Example #14
def mvnrnd(mu, sigma, sample_shape=()):
    d = MultivariateNormal(loc=mu, covariance_matrix=sigma)
    return d.rsample(sample_shape)
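A quick shape check for the mvnrnd helper above, assuming MultivariateNormal is imported in its module as in the other examples:

import torch

mu = torch.zeros(3)
sigma = torch.eye(3)

single = mvnrnd(mu, sigma)                    # shape (3,)
batch = mvnrnd(mu, sigma, sample_shape=(5,))  # shape (5, 3)
assert single.shape == (3,) and batch.shape == (5, 3)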
Example #15
File: model.py  Project: tzs930/FIVO_Gumbel
    def forward(self, x, mask, num_particles=4):

        log_hat_p_acc = torch.zeros(x.size(1)).to(device)  # (batch_size, )
        kl_acc = torch.zeros(x.size(1)).to(device)  # (batch_size, )

        h = Variable(
            torch.zeros(self.n_layers,
                        x.size(1) * num_particles, self.h_dim)).to(device)
        c = Variable(
            torch.zeros(self.n_layers,
                        x.size(1) * num_particles, self.h_dim)).to(device)

        # with torch.autograd.set_detect_anomaly(True):
        for t in range(x.size(0)):
            # VRNN Cell
            xts = x[t].repeat((1, num_particles)).reshape(
                (x.size(1) * num_particles, -1))
            phi_x_ts = self.phi_x(
                xts)  # [batch_size, num_particles, embed_size]

            enc_t = self.enc(
                torch.cat([phi_x_ts, h[-1]],
                          1))  # [batch_size, num_particles, embed_size]
            enc_mean_t = self.enc_mean(
                enc_t)  # [batch_size, num_particles, latent_size]
            enc_std_t = self.enc_std(
                enc_t)  # [batch_size, num_particles, latent_size]

            encoder_dist = MultivariateNormal(
                enc_mean_t, scale_tril=torch.diag_embed(enc_std_t))

            prior_t = self.prior(h[-1])
            prior_mean_t = self.prior_mean(prior_t)
            prior_std_t = self.prior_std(prior_t)

            prior_dist = MultivariateNormal(
                prior_mean_t, scale_tril=torch.diag_embed(prior_std_t))

            z_t_is = encoder_dist.rsample(
            )  # reparametrizable  # [batch_size * seq_len, latent_size]

            phi_z_ts = self.phi_z(z_t_is)

            dec_t = self.dec(torch.cat([phi_z_ts, h[-1]], 1))
            dec_mean_t = self.dec_mean(dec_t)
            decoder_dist = Bernoulli(probs=dec_mean_t)

            prior_logprob_ti = prior_dist.log_prob(z_t_is.detach())
            encoder_logprob_ti = encoder_dist.log_prob(z_t_is.detach())
            decoder_logprob_ti = decoder_dist.log_prob(xts).sum(-1)

            # recurrence
            _, (h, c) = self.rnn(
                torch.cat([phi_x_ts, phi_z_ts], 1).unsqueeze(0), (h, c))

            kl = torch.distributions.kl_divergence(encoder_dist, prior_dist)
            kl_acc += kl.mean(-1) * mask[t]
            nll = self._nll_bernoulli(dec_mean_t, xts)

            log_alpha_ti = -(nll + kl)
            # log_alpha_ti = prior_logprob_ti + decoder_logprob_ti - encoder_logprob_ti 	# [batch_size, ]
            log_alpha_ti = log_alpha_ti.reshape(
                x.size(1), -1)  # [batch_size, num_particles]
            log_alpha_ti = log_alpha_ti * mask[t][
                None].T  # [batch_size, num_particles] * [batch_size, 1]

            # hat_p = torch.exp(logweight_acc + log_alpha_ti) 		# [batch_size, num_particles]
            # log_hat_p = torch.exp(logweight_acc).sum(-1))
            log_hat_p = torch.logsumexp(
                log_alpha_ti.clone(), dim=-1) - math.log(float(num_particles))
            log_hat_p_acc += log_hat_p * mask[t]

            # logweight_acc *= (1. - should_resample_tiled.reshape(x.size(1), num_particles).float())

        iwae_bound = torch.sum(log_hat_p_acc)
        # kl_acc = kl_acc.mean(-1)
        # kl = torch.mean(kl_acc.reshape(x.size(1), -1), dim=-1)

        # return fivo_loss, kld_loss, nll_loss, \
        # 	(all_enc_mean, all_enc_std), \
        # 	(all_dec_mean, all_dec_std), \
        # 	log_hat_ps
        return -iwae_bound, log_hat_p_acc, _, kl_acc, log_hat_p_acc
Example #16
    def logprob_w_cov_gaussian_posterior(self,
                                         input,
                                         sample_size=128,
                                         z=None,
                                         std=None):
        # init
        batch_size = input.size(0)
        input = input.view(batch_size, self.input_dim)
        assert sample_size >= 2 * self.z_dim
        ''' get z and pseudo log q(newz|x) '''
        z, newz = [], []
        #cov_qz, rv_z = [], []
        logposterior = []
        inp = self.encode._forward_inp(input).detach()
        #for i in range(sample_size):
        for i in range(batch_size):
            _inp = inp[i:i + 1, :].expand(sample_size, inp.size(1))
            _nos = self.encode._forward_nos(batch_size=sample_size,
                                            std=std,
                                            device=input.device).detach()
            _z = self.encode._forward_all(_inp, _nos)  # ssz x zdim
            z += [_z.detach().unsqueeze(0)]
        z = torch.cat(z, dim=0)  # bsz x ssz x zdim
        mu_qz = torch.mean(z, dim=1)  # bsz x zdim
        for i in range(batch_size):
            _cov_qz = get_covmat(z[i, :, :])
            _rv_z = MultivariateNormal(mu_qz[i], _cov_qz)
            _newz = _rv_z.rsample(torch.Size([1, sample_size]))
            _logposterior = _rv_z.log_prob(_newz)

            #cov_qz += [_cov_qz.unsqueeze(0)]
            #rv_z += [_rv_z]
            newz += [_newz]
            logposterior += [_logposterior]
        #cov_qz = torch.cat(cov_qz, dim=0) # bsz x zdim x zdim
        newz = torch.cat(newz, dim=0)  # bsz x ssz x zdim
        logposterior = torch.cat(logposterior, dim=0)  # bsz x ssz
        ''' get log p(z) '''
        # get prior (as unit normal dist)
        mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logprior = logprob_gaussian(mu_pz,
                                    logvar_pz,
                                    newz,
                                    do_unsqueeze=False,
                                    do_mean=False)
        logprior = torch.sum(logprior.view(batch_size, sample_size,
                                           self.z_dim),
                             dim=2)  # bsz x ssz
        ''' get log p(x|z) '''
        # decode
        logit_x = []
        #for i in range(sample_size):
        for i in range(batch_size):
            _, _logit_x = self.decode(newz[i, :, :])  # ssz x zdim
            logit_x += [_logit_x.detach().unsqueeze(0)]
        logit_x = torch.cat(logit_x, dim=0)  # bsz x ssz x input_dim
        _input = input.unsqueeze(1).expand(
            batch_size, sample_size, self.input_dim)  # bsz x ssz x input_dim
        loglikelihood = -F.binary_cross_entropy_with_logits(
            logit_x, _input, reduction='none')
        loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz
        ''' get log p(x|z)p(z)/q(z|x) '''
        logprob = loglikelihood + logprior - logposterior  # bsz x ssz
        logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
        rprob = (logprob - logprob_max).exp()  # relative prob
        logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) +
                            1e-10) + logprob_max  # bsz x 1

        # return
        return logprob.mean()
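The max-shift at the end is the usual log-mean-exp stabilization; the same quantity can be computed with torch.logsumexp, as this toy comparison shows (it is not a drop-in change to the method above):

import math
import torch

logprob = torch.randn(4, 128) * 10  # stand-in for loglikelihood + logprior - logposterior

m, _ = torch.max(logprob, dim=1, keepdim=True)
manual = torch.log(torch.mean((logprob - m).exp(), dim=1, keepdim=True) + 1e-10) + m
stable = torch.logsumexp(logprob, dim=1, keepdim=True) - math.log(logprob.size(1))
assert torch.allclose(manual, stable, atol=1e-4)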
Example #17
File: model.py  Project: tzs930/FIVO_Gumbel
    def forward(self, x, mask, num_particles=4):

        logweight_acc = torch.zeros(x.size(1), num_particles).to(
            device)  # (batch_size, num_particles)
        log_hat_p_acc = torch.zeros(x.size(1)).to(device)  # (batch_size, )
        log_hat_p_iwae_acc = torch.zeros(x.size(1)).to(device)
        kl_acc = torch.zeros(x.size(1)).to(device)  # (batch_size, )

        # [0, 1, 2, 3, 4, 5, 6, 7, ... ]
        noresampleidxs = torch.arange(x.size(1) * num_particles).to(device)

        h = Variable(
            torch.zeros(self.n_layers,
                        x.size(1) * num_particles, self.h_dim)).to(device)
        c = Variable(
            torch.zeros(self.n_layers,
                        x.size(1) * num_particles, self.h_dim)).to(device)

        # with torch.autograd.set_detect_anomaly(True):
        for t in range(x.size(0)):
            # VRNN Cell
            xts = x[t].repeat((1, num_particles)).reshape(
                (x.size(1) * num_particles, -1))
            phi_x_ts = self.phi_x(
                xts)  # [batch_size * num_particle, embed_size]

            enc_t = self.enc(torch.cat([phi_x_ts, h[-1]], 1))
            enc_mean_t = self.enc_mean(enc_t)
            enc_std_t = self.enc_std(enc_t)

            encoder_dist = MultivariateNormal(
                enc_mean_t, scale_tril=torch.diag_embed(enc_std_t))

            prior_t = self.prior(h[-1])
            prior_mean_t = self.prior_mean(prior_t)
            prior_std_t = self.prior_std(prior_t)

            prior_dist = MultivariateNormal(
                prior_mean_t, scale_tril=torch.diag_embed(prior_std_t))

            z_t_is = encoder_dist.rsample(
            )  # reparametrizable  # [batch_size * seq_len, latent_size]

            phi_z_ts = self.phi_z(z_t_is)

            dec_t = self.dec(torch.cat([phi_z_ts, h[-1]], 1))
            dec_mean_t = self.dec_mean(dec_t)
            decoder_dist = Bernoulli(probs=dec_mean_t)

            prior_logprob_ti = prior_dist.log_prob(z_t_is.detach()) + 1e-7
            encoder_logprob_ti = encoder_dist.log_prob(z_t_is.detach()) + 1e-7
            decoder_logprob_ti = decoder_dist.log_prob(xts).sum(-1) + 1e-7

            # recurrence
            _, (h, c) = self.rnn(
                torch.cat([phi_x_ts, phi_z_ts], 1).unsqueeze(0), (h, c))

            kl = torch.distributions.kl_divergence(encoder_dist, prior_dist)
            kl_acc += kl.mean(-1) * mask[t]
            nll = self._nll_bernoulli(dec_mean_t, xts)

            # log_alpha_ti = prior_logprob_ti + decoder_logprob_ti - encoder_logprob_ti 	# [batch_size, ]
            log_alpha_ti = -(nll + kl)
            log_alpha_ti = log_alpha_ti.reshape(
                x.size(1), -1)  # [batch_size, num_particles]
            log_alpha_ti = log_alpha_ti * mask[t][
                None].T  # [batch_size, num_particles] * [batch_size, 1]

            # hat_p = torch.exp(logweight_acc + log_alpha_ti) 		# [batch_size, num_particles]
            logweight_acc += log_alpha_ti

            # Add resampling procedure here
            # ess = 1. / (torch.exp(logweight_acc) ** 2).sum(-1)  # [batch_size, ]
            # logess = torch.log(1. / (torch.exp(logweight_acc) ** 2).sum(-1) )
            logess_num = 2 * torch.logsumexp(logweight_acc, dim=-1)
            logess_denom = torch.logsumexp(2 * logweight_acc, dim=-1)
            logess = logess_num - logess_denom

            if not self.use_resampling_gradient:
                resample_dist = Categorical(
                    logits=logweight_acc.reshape(x.size(1), num_particles))
                resampled_idxs = resample_dist.sample([num_particles]).T

                # [0, 0, 0, 0, 4, 4, 4, 4, ... ]
                sample_offset = torch.arange(x.size(1)).repeat([
                    num_particles, 1
                ]).T.reshape(-1).to(device) * num_particles
                resampled_idxs = resampled_idxs.reshape(-1) + sample_offset

                should_resample = logess <= torch.log(
                    torch.ones_like(logess).to(device) * num_particles / 2.0)
                should_resample = should_resample & mask[t].bool()
                should_resample_tiled = should_resample.repeat(
                    [num_particles, 1]).T.reshape(-1)

                new_idxs = torch.where(should_resample_tiled, resampled_idxs,
                                       noresampleidxs)

                h[-1] = h[-1][new_idxs]
                c[-1] = c[-1][new_idxs]

                log_hat_p = torch.logsumexp(logweight_acc.clone(),
                                            dim=-1) - math.log(
                                                float(num_particles))
                log_hat_p_acc += log_hat_p * should_resample.float()

                logweight_acc *= (1. - should_resample_tiled.reshape(
                    x.size(1), num_particles).float())

            else:
                # raise NotImplementedError
                resample_dist = RelaxedOneHotCategorical(
                    logits=logweight_acc.reshape(x.size(1), num_particles),
                    temperature=0.1)
                resampled_onehot_relaxedidxs = resample_dist.rsample(
                    [num_particles]).permute(1, 0,
                                             2)  #.reshape(-1, num_particles)

                should_resample = logess <= torch.log(
                    torch.ones_like(logess).to(device) * num_particles / 2.0)
                should_resample = should_resample & mask[t].bool()
                should_resample_tiled = should_resample.repeat(
                    [num_particles, 1]).T.reshape(-1)

                # noresample_onehot = torch.eye(x.size(1) * num_particles)

                for batch_idx in range(x.size(1)):
                    if should_resample[batch_idx]:
                        # cur_slice = (batch_idx * x.size(1) * num_particles) : (batch_idx * x.size(1) * num_particles + x.size(1) * num_particles)
                        h[-1][(batch_idx * num_particles) : (batch_idx * num_particles + num_particles)] = \
                         resampled_onehot_relaxedidxs[batch_idx] @ h[-1][(batch_idx * num_particles) : (batch_idx * num_particles + num_particles)].clone()
                        c[-1][(batch_idx * num_particles) : (batch_idx * num_particles + num_particles)] = \
                         resampled_onehot_relaxedidxs[batch_idx] @ c[-1][(batch_idx * num_particles) : (batch_idx * num_particles + num_particles)].clone()

                log_hat_p = torch.logsumexp(logweight_acc.clone(),
                                            dim=-1) - math.log(
                                                float(num_particles))
                log_hat_p_acc += log_hat_p * should_resample.float()

                logweight_acc *= (1. - should_resample_tiled.reshape(
                    x.size(1), num_particles).float())

            log_hat_p_iwae_acc += (
                torch.logsumexp(log_alpha_ti.detach(), dim=-1) -
                math.log(float(num_particles))) * mask[t]
            #computing losses
            # kld_loss /= self.num_zs
            # nll_loss /= self.num_zs

        log_hat_p_acc += torch.logsumexp(logweight_acc, dim=-1) - math.log(
            float(num_particles))
        fivo_bound = torch.sum(log_hat_p_acc)
        # kl = torch.mean(kl_acc.reshape(x.size(1), -1), dim=-1)

        # return fivo_loss, kld_loss, nll_loss, \
        # 	(all_enc_mean, all_enc_std), \
        # 	(all_dec_mean, all_dec_std), \
        # 	log_hat_ps
        return -fivo_bound, log_hat_p_acc, logweight_acc, kl_acc, log_hat_p_iwae_acc
Example #18
    def _loss_em_mc(
        self,
        past_targets: Union[Sequence[torch.Tensor], torch.Tensor],
        past_controls: Optional[Union[Sequence[ControlInputs],
                                      ControlInputs]] = None,
        past_targets_is_observed: Optional[Union[Sequence[torch.Tensor],
                                                 torch.Tensor]] = None,
    ) -> torch.Tensor:
        """" Monte Carlo loss as computed in KVAE paper """
        n_batch = len(past_targets[0])

        past_controls = self._expand_particle_dim(past_controls)

        # A) SSM related distributions:
        # A1) smoothing.
        latents_smoothed = self.smooth(
            past_targets=past_targets,
            past_controls=past_controls,
            past_targets_is_observed=past_targets_is_observed,
        )
        m = torch.stack([l.variables.m for l in latents_smoothed])
        V = torch.stack([l.variables.V for l in latents_smoothed])
        z = torch.stack([l.variables.auxiliary for l in latents_smoothed])
        state_smoothed_dist = MultivariateNormal(loc=m, covariance_matrix=V)
        x = state_smoothed_dist.rsample()

        A = torch.stack([l.gls_params.A for l in latents_smoothed])
        C = torch.stack([l.gls_params.C for l in latents_smoothed])
        LR = torch.stack([l.gls_params.LR for l in latents_smoothed])
        LQ = torch.stack([l.gls_params.LQ for l in latents_smoothed])
        if latents_smoothed[0].gls_params.B is not None:
            B = torch.stack([l.gls_params.B for l in latents_smoothed])
        else:
            B = None
        if latents_smoothed[0].gls_params.D is not None:
            D = torch.stack([l.gls_params.D for l in latents_smoothed])
        else:
            D = None

        # A2) prior && posterior transition distribution.
        prior_dist = self.state_prior_model(
            None, batch_shape_to_prepend=(self.n_particle, n_batch))

        # A, B, R are already 0:T-1.
        transition_dist = MultivariateNormal(
            loc=matvec(A[:-1], x[:-1]) + (matvec(
                B[:-1], past_controls.state[:-1]) if B is not None else 0.0),
            scale_tril=LR[:-1],
        )
        # A3) posterior predictive (auxiliary) distribution.
        auxiliary_predictive_dist = MultivariateNormal(
            loc=matvec(C, x) +
            (matvec(D, past_controls.target) if D is not None else 0.0),
            scale_tril=LQ,
        )

        # A4) SSM related losses
        # mean over particle dim, sum over time (after masking), leave batch dim
        l_prior = -prior_dist.log_prob(x[0:1]).mean(dim=1).sum(dim=0)
        l_transition = -transition_dist.log_prob(x[1:]).mean(dim=1).sum(dim=0)
        l_entropy = state_smoothed_dist.log_prob(x).mean(dim=1).sum(dim=0)

        _l_aux_timewise = -auxiliary_predictive_dist.log_prob(z).mean(dim=1)
        if past_targets_is_observed is not None:
            _l_aux_timewise = _l_aux_timewise * past_targets_is_observed
        l_auxiliary = _l_aux_timewise.sum(dim=0)

        # B) VAE related distributions
        # B1) inv_measurement_dist already obtained from smoothing (as we don't want to re-compute)
        # B2) measurement (decoder) distribution
        # transpose TPBF -> PTBF to broadcast log_prob of y (TBF) correctly
        z_particle_first = z.transpose(0, 1)
        measurement_dist = self.measurement_model(z_particle_first)
        # B3) VAE related losses
        # We use z_particle_first for correct broadcasting -> dim=0 is particle.
        _l_meas_timewise = -measurement_dist.log_prob(past_targets).mean(dim=0)
        if past_targets_is_observed is not None:
            _l_meas_timewise = _l_meas_timewise * past_targets_is_observed
        l_measurement = _l_meas_timewise.sum(dim=0)

        auxiliary_variational_dist = MultivariateNormal(
            loc=torch.stack([
                l.variables.m_auxiliary_variational for l in latents_smoothed
            ]),
            covariance_matrix=torch.stack([
                l.variables.V_auxiliary_variational for l in latents_smoothed
            ]),
        )
        _l_variational_timewise = auxiliary_variational_dist.log_prob(
            z_particle_first).mean(dim=0)  # again dim=0 is particle dim here.
        if past_targets_is_observed is not None:
            _l_variational_timewise = (_l_variational_timewise *
                                       past_targets_is_observed)
        l_inv_measurement = _l_variational_timewise.sum(dim=0)

        assert all(t.shape == l_prior.shape for t in (
            l_prior,
            l_transition,
            l_auxiliary,
            l_measurement,
            l_inv_measurement,
        ))

        l_total = (self.reconstruction_weight * l_measurement +
                   l_inv_measurement + l_auxiliary + l_prior + l_transition +
                   l_entropy)
        return l_total
Example #19
    def logprob_w_cov_gaussian_posterior(self, input, sample_size=128, z=None, std=None):
        # init
        batch_size = input.size(0)
        input = input.view(batch_size, self.input_dim)
        assert sample_size >= 2*self.z_dim
        #assert int(math.sqrt(sample_size))**2 == sample_size

        ''' get z and pseudo log q(newz|x) '''
        #z, newz = [], []
        #logposterior = []
        #inp = self.encode._forward_inp(input).detach()
        #for i in range(batch_size):
        #    _inp = inp[i:i+1, :].expand(sample_size, inp.size(1))
        #    _nos = self.encode._forward_nos(sample_size, std=std, device=input.device).detach()
        #    _z = self.encode._forward_all(_inp, _nos) # ssz x zdim
        #    z += [_z.detach().unsqueeze(0)]
        #z = torch.cat(z, dim=0) # bsz x ssz x zdim
        #_nz = int(math.sqrt(sample_size))
        _, _, _, _, z, _, _, _, _ = self.encode._forward(input, std=std, nz=sample_size) # bsz x ssz x zdim
        newz = []
        logposterior = []
        eye = torch.eye(self.z_dim, device=z.device)
        mu_qz = torch.mean(z, dim=1) # bsz x zdim
        for i in range(batch_size):
            _cov_qz = get_covmat(z[i, :, :]) + 1e-5*eye
            _rv_z = MultivariateNormal(mu_qz[i], _cov_qz)
            _newz = _rv_z.rsample(torch.Size([1, sample_size]))
            _logposterior = _rv_z.log_prob(_newz)

            newz += [_newz]
            logposterior += [_logposterior]
        newz = torch.cat(newz, dim=0) # bsz x ssz x zdim
        logposterior = torch.cat(logposterior, dim=0) # bsz x ssz

        ''' get log p(z) '''
        # get prior (as unit normal dist)
        mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logprior = logprob_gaussian(mu_pz, logvar_pz, newz, do_unsqueeze=False, do_mean=False)
        logprior = torch.sum(logprior.view(batch_size, sample_size, self.z_dim), dim=2) # bsz x ssz

        ''' get log p(x|z) '''
        # decode
        mu_x, logvar_x = [], []
        #for i in range(sample_size):
        for i in range(batch_size):
            _, _mu_x, _logvar_x = self.decode(newz[i, :, :])
            mu_x += [_mu_x.detach().unsqueeze(0)]
            logvar_x += [_logvar_x.detach().unsqueeze(0)]
        mu_x = torch.cat(mu_x, dim=0) # bsz x ssz x input_dim
        logvar_x = torch.cat(logvar_x, dim=0) # bsz x ssz x input_dim
        _input = input.unsqueeze(1).expand(batch_size, sample_size, self.input_dim) # bsz x ssz x input_dim
        loglikelihood = logprob_gaussian(mu_x, logvar_x, _input, do_unsqueeze=False, do_mean=False)
        loglikelihood = torch.sum(loglikelihood, dim=2) # bsz x ssz

        ''' get log p(x|z)p(z)/q(z|x) '''
        logprob = loglikelihood + logprior - logposterior # bsz x ssz
        logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
        rprob = (logprob - logprob_max).exp() # relative prob
        logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max # bsz x 1

        # return
        return logprob.mean()
Example #20
    def filter_step(
        self,
        lats_tm1: Optional[LatentsKVAE],
        tar_t: torch.Tensor,
        ctrl_t: ControlInputs,
        tar_is_obs_t: Optional[torch.Tensor] = None,
    ) -> LatentsKVAE:
        is_initial_step = lats_tm1 is None
        if tar_is_obs_t is None:
            tar_is_obs_t = torch.ones(
                tar_t.shape[:-1],
                dtype=tar_t.dtype,
                device=tar_t.device,
            )

        # 1) Initial step must prepare previous latents with prior and learnt z.
        if is_initial_step:
            n_particle, n_batch = self.n_particle, len(tar_t)
            state_prior = self.state_prior_model(
                None,
                batch_shape_to_prepend=(n_particle, n_batch),
            )
            z_init = self.z_initial[None, None].repeat(n_particle, n_batch, 1)
            lats_tm1 = LatentsKVAE(
                variables=GLSVariablesKVAE(
                    m=state_prior.loc,
                    V=state_prior.covariance_matrix,
                    Cov=None,
                    x=None,
                    auxiliary=z_init,
                    rnn_state=None,
                    m_auxiliary_variational=None,
                    V_auxiliary_variational=None,
                ),
                gls_params=None,
            )
        # 2) Compute GLS params
        rnn_state_t, rnn_output_t = self.compute_deterministic_switch_step(
            rnn_input=lats_tm1.variables.auxiliary,
            rnn_prev_state=lats_tm1.variables.rnn_state,
        )
        gls_params_t = self.gls_base_parameters(
            switch=rnn_output_t,
            controls=ctrl_t,
        )

        # Perform filter step:
        # 3) Prediction Step: Only for t > 0 and using previous GLS params.
        # (In KVAE, they do first update then prediction step.)
        if is_initial_step:
            mp, Vp = lats_tm1.variables.m, lats_tm1.variables.V
        else:
            mp, Vp = filter_forward_prediction_step(
                m=lats_tm1.variables.m,
                V=lats_tm1.variables.V,
                R=lats_tm1.gls_params.R,
                A=lats_tm1.gls_params.A,
                b=lats_tm1.gls_params.b,
            )
        # 4) Update step
        # 4a) Observed data: Infer pseudo-obs by encoding obs && Bayes update
        auxiliary_variational_dist_t = self.encoder(tar_t)
        z_infer_t = auxiliary_variational_dist_t.rsample([self.n_particle])
        m_infer_t, V_infer_t = filter_forward_measurement_step(
            y=z_infer_t,
            m=mp,
            V=Vp,
            Q=gls_params_t.Q,
            C=gls_params_t.C,
            d=gls_params_t.d,
        )

        # 4b) Choice: inferred / predicted m, V for observed / missing data.
        is_filtered = tar_is_obs_t[None, :].repeat(self.n_particle, 1).byte()
        replace_m_fw = is_filtered[:, :, None].repeat(1, 1, mp.shape[2])
        replace_V_fw = is_filtered[:, :, None, None].repeat(
            1,
            1,
            Vp.shape[2],
            Vp.shape[3],
        )
        assert replace_m_fw.shape == m_infer_t.shape == mp.shape
        assert replace_V_fw.shape == V_infer_t.shape == Vp.shape

        m_t = torch.where(replace_m_fw, m_infer_t, mp)
        V_t = torch.where(replace_V_fw, V_infer_t, Vp)

        # 4c) Missing Data: Predict pseudo-observations && No Bayes update
        mpz_t, Vpz_t = filter_forward_predictive_distribution(
            m=m_t,  # posterior predictive or one-step-predictive (if missing)
            V=V_t,
            Q=gls_params_t.Q,
            C=gls_params_t.C,
            d=gls_params_t.d,
        )
        auxiliary_predictive_dist_t = MultivariateNormal(
            loc=mpz_t,
            covariance_matrix=Vpz_t,
        )
        z_gen_t = auxiliary_predictive_dist_t.rsample()

        # 4d) Choice: inferred / predicted z for observed / missing data.
        # One-step predictive if missing and inferred from encoder otherwise.
        replace_z = is_filtered[:, :, None].repeat(1, 1, z_gen_t.shape[2])
        z_t = torch.where(replace_z, z_infer_t, z_gen_t)

        # 5) Put result in Latents object, used in next iteration
        lats_t = LatentsKVAE(
            variables=GLSVariablesKVAE(
                m=m_t,
                V=V_t,
                Cov=None,
                x=None,
                auxiliary=z_t,
                rnn_state=rnn_state_t,
                m_auxiliary_variational=auxiliary_variational_dist_t.loc,
                V_auxiliary_variational=auxiliary_variational_dist_t.
                covariance_matrix,
            ),
            gls_params=gls_params_t,
        )
        return lats_t