Example no. 1
def set_up_model(seed, seed_data=10, dim=2):

    mu_prior = torch.zeros([dim])
    Sigma_prior = 5 * torch.eye(dim)
    prior = MultivariateNormal(mu_prior, Sigma_prior)

    torch.manual_seed(seed)
    mu_real = prior.rsample(sample_shape=(1, )).reshape(dim)

    torch.manual_seed(seed_data)
    a = torch.randn(dim, dim) * dim / 2
    Sigma_real = torch.mm(a, a.t())  # make symmetric positive-definite
    Sigma_real_tril = torch.linalg.cholesky(Sigma_real)  # torch.cholesky is deprecated

    # Set model
    model = MultivariateNormal(loc=mu_real, scale_tril=Sigma_real_tril)

    # set conj model
    conj_model = cMVN.ConjugateMultivariateNormalSummayStats(
        model, prior, 500, dim)

    # generate data
    x_o = conj_model.sample_fixed_seed(seed)

    # Analytical posterior

    analytical_posterior = conj_model.calc_analytical_posterior(x_o)

    return x_o, conj_model, analytical_posterior
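The conjugate-update details live inside the `cMVN` helper above. For reference, a minimal sketch of the analytical posterior for a Gaussian likelihood with known covariance and a Gaussian prior on the mean (an illustrative standalone function, not the helper's actual API):

import torch
from torch.distributions import MultivariateNormal

def analytical_posterior_known_cov(x, mu_0, Sigma_0, Sigma):
    """Posterior over the mean of N(theta, Sigma) given data x of shape (N, dim),
    under the prior theta ~ N(mu_0, Sigma_0)."""
    N = x.shape[0]
    prior_prec = torch.inverse(Sigma_0)
    lik_prec = N * torch.inverse(Sigma)
    Sigma_post = torch.inverse(prior_prec + lik_prec)
    mu_post = Sigma_post @ (prior_prec @ mu_0 + lik_prec @ x.mean(0))
    return MultivariateNormal(mu_post, Sigma_post)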
Example no. 2
    def forward_conditional(self, past, current, sig_inds):
        if current.shape[-1] == len(sig_inds):
            return current, current
        past = past.to(self.device)
        current = current.to(self.device)
        if len(current.shape) == 1:
            current = current.unsqueeze(0)
        mean, covariance = self.likelihood_distribution(past)  # P(X_t|X_0:t-1)
        sig_inds_comp = list(set(range(past.shape[-2])) - set(sig_inds))
        ind_len = len(sig_inds)
        ind_len_not = len(sig_inds_comp)
        x_ind = current[:, sig_inds].view(-1, ind_len)
        mean_1 = mean[:, sig_inds_comp].view(-1, ind_len_not)
        cov_1_2 = covariance[:, sig_inds_comp, :][:, :, sig_inds].view(
            -1, ind_len_not, ind_len)
        cov_2_2 = covariance[:, sig_inds, :][:, :, sig_inds].view(
            -1, ind_len, ind_len)
        cov_1_1 = covariance[:, sig_inds_comp, :][:, :, sig_inds_comp].view(
            -1, ind_len_not, ind_len_not)
        mean_cond = mean_1 + torch.bmm(
            (torch.bmm(cov_1_2, torch.inverse(cov_2_2))),
            (x_ind - mean[:, sig_inds]).view(-1, ind_len, 1)).squeeze(-1)
        covariance_cond = cov_1_1 - torch.bmm(
            torch.bmm(cov_1_2, torch.inverse(cov_2_2)),
            torch.transpose(cov_1_2, 2, 1))

        # P(x_{-i,t}|x_{i,t})
        likelihood = MultivariateNormal(loc=mean_cond.squeeze(-1),
                                        covariance_matrix=covariance_cond)
        sample = likelihood.rsample()
        full_sample = current.clone()
        full_sample[:, sig_inds_comp] = sample
        return full_sample, mean[:, sig_inds_comp]
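The block above is the standard Gaussian conditioning identity. A minimal standalone sketch of the same computation for a single (unbatched) Gaussian, written only for reference (the helper name is hypothetical):

import torch
from torch.distributions import MultivariateNormal

def condition_mvn(mean, cov, obs_idx, x_obs):
    """Return p(x_rest | x_obs) for a joint Gaussian N(mean, cov)."""
    rest_idx = [i for i in range(mean.shape[0]) if i not in obs_idx]
    mu_o, mu_r = mean[obs_idx], mean[rest_idx]
    S_oo = cov[obs_idx][:, obs_idx]
    S_ro = cov[rest_idx][:, obs_idx]
    S_rr = cov[rest_idx][:, rest_idx]
    K = S_ro @ torch.inverse(S_oo)           # regression coefficients
    mean_cond = mu_r + K @ (x_obs - mu_o)    # conditional mean
    cov_cond = S_rr - K @ S_ro.T             # conditional covariance
    return MultivariateNormal(mean_cond, cov_cond)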
Example no. 3
class GaussianCopulaVariable(nn.Module):
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    
    def __init__(self, loc, scale, covariance_matrix=None, validate_args=None):
        super(GaussianCopulaVariable, self).__init__()
        # affine = AffineTransform(-loc, 1/scale)
        self.multinormal = MultivariateNormal(loc, covariance_matrix)
        self.normal = Normal(loc, scale)
        self.standard_normal = Normal(0, 1)
        self.loc = loc
        self.scale = scale
        self.covariance_matrix = covariance_matrix

    def forward(self, x):
        
        pass

    def sample(self):
        r'''
        Sample from Gaussian Copula
            q ~ N(0, \Sigma)
            u = [cdf(q_1),...,cdf(q_n)]^T
        '''
        q = self.multinormal.rsample()
        u = torch.stack([self.standard_normal.cdf(q_i/s_i) for q_i, s_i in zip(q, self.scale)]).squeeze()

        return u
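A small usage sketch for the class above (illustrative values; assumes torch and the class are in scope):

loc = torch.zeros(3)
scale = torch.ones(3)
cov = torch.tensor([[1.0, 0.5, 0.2],
                    [0.5, 1.0, 0.3],
                    [0.2, 0.3, 1.0]])
copula = GaussianCopulaVariable(loc, scale, covariance_matrix=cov)
u = copula.sample()  # marginally Uniform(0, 1) with Gaussian dependence structure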
Example no. 4
def run_model_sim(N_samples,
                  seed,
                  conj_model,
                  analytical_posterior,
                  Sigma_real,
                  dim=2,
                  from_prior=True):

    if from_prior:
        torch.manual_seed(seed)
        theta = conj_model.prior.rsample(sample_shape=(N_samples, ))
    else:
        torch.manual_seed(seed)
        theta = analytical_posterior.rsample(sample_shape=(N_samples, ))

    torch.manual_seed(seed)

    x = torch.zeros(N_samples, conj_model.N, dim)

    for i in range(N_samples):
        model_tmp = MultivariateNormal(theta[i], Sigma_real)
        x[i, :, :] = model_tmp.rsample(sample_shape=(conj_model.N, ))

    return calc_summary_stats(
        x), theta  # /math.sqrt(5) # div with std of prior to normalize data
Example no. 5
def generate_batch(batchlen, plot=False):
    cov = torch.tensor(np.identity(2) * 0.01, dtype=torch.float64)
    mu1 = torch.tensor([2**(-1 / 2), 2**(-1 / 2)], dtype=torch.float64)
    mu2 = torch.tensor([0, 1], dtype=torch.float64)

    gaussian1 = MultivariateNormal(loc=mu1, covariance_matrix=cov)
    gaussian2 = MultivariateNormal(loc=mu2, covariance_matrix=cov)

    d1 = gaussian1.rsample((int(batchlen / 2), ))
    d2 = gaussian2.rsample((int(batchlen / 2), ))

    data = np.concatenate((d1, d2), axis=0)
    np.random.shuffle(data)

    if plot:
        plt.scatter(data[:, 0], data[:, 1], s=2.0, color='gray')
        plt.show()
    return torch.Tensor(data).to(device)
Example no. 6
    def select_action(self, state, deterministic, reparameterize=False):
        mu, std = self.forward(state)
        dist = MultivariateNormal(loc=mu, scale_tril=torch.diag_embed(std))
        if deterministic:
            action = mu  # (bsize, action_dim)
        else:
            if reparameterize:
                action = dist.rsample()  # (bsize, action_dim)
            else:
                action = dist.sample()  # (bsize, action_dim)
        return action, dist
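Because the scale_tril here is diagonal, the policy distribution factorises over action dimensions; a quick equivalence check against an independent Normal (illustrative only):

import torch
from torch.distributions import MultivariateNormal, Normal

mu, std = torch.zeros(4, 3), 0.5 * torch.ones(4, 3)
mvn = MultivariateNormal(loc=mu, scale_tril=torch.diag_embed(std))
diag = Normal(mu, std)
a = mvn.sample()
assert torch.allclose(mvn.log_prob(a), diag.log_prob(a).sum(-1), atol=1e-5)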
Example no. 7
def simulator(theta):
    N_samples = theta.shape[0]

    x = torch.zeros(N_samples, conj_model.N, dim)

    for i in range(N_samples):
        model_tmp = MultivariateNormal(theta[i], conj_model.model.covariance_matrix)
        x[i, :, :] = model_tmp.rsample(sample_shape=(conj_model.N,))

    # return calc_summary_stats(x), theta #/math.sqrt(5) # div with std of prior to normalize data
    return func.flatten(x)
Example no. 8
def predict_diagonal_sampling(model,
                              test_loader,
                              M_W_post,
                              M_b_post,
                              C_W_post,
                              C_b_post,
                              n_samples,
                              verbose=False,
                              cuda=False,
                              timing=False):
    py = []
    max_len = len(test_loader)
    if timing:
        time_sum = 0

    for batch_idx, (x, y) in enumerate(test_loader):

        if cuda:
            x, y = x.cuda(), y.cuda()

        phi = model.features(x)

        mu, Sigma = get_Gaussian_output(phi, M_W_post, M_b_post, C_W_post,
                                        C_b_post)
        #print("mu size: ", mu.size())
        #print("sigma size: ", Sigma.size())

        post_pred = MultivariateNormal(mu, Sigma)

        # MC-integral
        t0 = time.time()
        py_ = 0

        for _ in range(n_samples):
            f_s = post_pred.rsample()
            py_ += torch.softmax(f_s, 1)

        py_ /= n_samples
        py_ = py_.detach()

        py.append(py_)
        t1 = time.time()
        if timing:
            time_sum += (t1 - t0)

        if verbose:
            print("Batch: {}/{}".format(batch_idx, max_len))

    if timing:
        print("time used for sampling with {} samples: {}".format(
            n_samples, time_sum))

    return torch.cat(py, dim=0)
Example no. 9
def predict_KFAC_sampling(model,
                          test_loader,
                          M_W_post,
                          M_b_post,
                          U_post,
                          V_post,
                          B_post,
                          n_samples,
                          timing=False,
                          verbose=False,
                          cuda=False):
    py = []
    max_len = len(test_loader)
    if timing:
        time_sum = 0

    for batch_idx, (x, y) in enumerate(test_loader):

        if cuda:
            x, y = x.cuda(), y.cuda()

        phi = model.features(x).detach()

        mu_pred = phi @ M_W_post + M_b_post
        Cov_pred = torch.diag(phi @ V_post @ phi.t()).reshape(
            -1, 1, 1) * U_post.unsqueeze(0) + B_post.unsqueeze(0)

        post_pred = MultivariateNormal(mu_pred, Cov_pred)

        # MC-integral
        t0 = time.time()
        py_ = 0

        for _ in range(n_samples):
            f_s = post_pred.rsample()
            py_ += torch.softmax(f_s, 1)

        py_ /= n_samples
        py_ = py_.detach()

        py.append(py_)
        t1 = time.time()
        if timing:
            time_sum += (t1 - t0)

        if verbose:
            print("Batch: {}/{}".format(batch_idx, max_len))

    if timing:
        print("time used for sampling with {} samples: {}".format(
            n_samples, time_sum))

    return torch.cat(py, dim=0)
Example no. 10
def sample(pi, sigma, mu):
    """Draw samples from a MoG.
    # Original implementation
    categorical = Categorical(pi)
    pis = list(categorical.sample().data)
    sample = Variable(sigma.data.new(sigma.size(0), sigma.size(2)).normal_())
    for i, idx in enumerate(pis):
        sample[i] = sample[i].mul(sigma[i,idx]).add(mu[i,idx])
    return sample
    """
    ######################
    # new implementation #
    ######################
    categorical = Categorical(pi)
    pis = list(categorical.sample().data)
    #print('len of pis', len(pis))
    #print('pis', pis)
    print('size of sigma = ', sigma.size())
    print('size of mu = ', mu.size())
    D = mu.size(-1)
    samples = torch.zeros([len(pi), D])
    sigma_cpu_all = sigma.detach().cpu()
    mu_cpu_all = mu.detach().cpu()
    for i, idx in enumerate(pis):
        #print('i = {}'.format(i))
        sigma_cpu = sigma_cpu_all[i, idx]
        precision_mat_diag_pos = torch.matmul(sigma_cpu,
                                              torch.transpose(sigma_cpu, 0, 1))
        mu_cpu = mu_cpu_all[i, idx]
        #precision_mat = sigma[i, idx] + torch.transpose(sigma[i, idx], 0, 1)
        # Add 1 to the diagonal so the precision matrix is positive definite
        # (torch.eye matches the dtype of precision_mat_diag_pos and avoids the
        # float64/float32 mismatch of the original np.zeros-based tensor).
        precision_mat_diag_pos += torch.eye(D, dtype=precision_mat_diag_pos.dtype)
        #precision_mat_diag_pos = precision_mat + diagonal_mat.fill_diagonal_(1 - torch.min(torch.diagonal(precision_mat)).detach().cpu().numpy())
        #print('precision_mat = ', precision_mat_diag_pos)
        #print(precision_mat_diag_pos)
        #print(mu_cpu)
        try:
            #print('precision_mat = ', precision_mat_diag_pos)
            MVN = MultivariateNormal(loc=mu_cpu,
                                     precision_matrix=precision_mat_diag_pos)
            draw_sample = MVN.rsample()
        except (RuntimeError, ValueError):
            print(
                "The precision matrix is singular; assigning a sentinel sample "
                "so this draw does not count toward the test loss"
            )
            draw_sample = -999 * torch.ones([1, D])
        #print('sample size = ', draw_sample.size())
        samples[i, :] = draw_sample
    #print('samples', samples.size())
    return samples
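When per-component means and scales are available up front, torch.distributions.MixtureSameFamily can draw MoG samples without an explicit Python loop. A hedged sketch (it uses covariance factors directly, not the precision-factor convention of the code above):

import torch
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal

K, D = 3, 2
pi = torch.softmax(torch.randn(K), dim=0)                # mixture weights
mu = torch.randn(K, D)                                   # component means
scale_tril = torch.diag_embed(torch.rand(K, D) + 0.1)    # lower-triangular scales

mog = MixtureSameFamily(Categorical(probs=pi),
                        MultivariateNormal(loc=mu, scale_tril=scale_tril))
samples = mog.sample((100,))                             # shape (100, D)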
Example no. 11
    def forward(self, x, S):
        x = x.view(-1, self.x_dim)
        bsz = x.size(0)

        ### get w and \alpha and L(\theta)
        mu, logvar = self.encoder(x)
        q_phi = Normal(loc=mu, scale=torch.exp(0.5 * logvar))
        z_q = q_phi.rsample((S, ))
        recon_batch = self.decoder(z_q)
        x_dist = Bernoulli(logits=recon_batch)
        log_lik = x_dist.log_prob(x).sum(-1)
        log_prior = self.prior.log_prob(z_q).sum(-1)
        log_q = q_phi.log_prob(z_q).sum(-1)
        log_w = log_lik + log_prior - log_q
        tmp_alpha = torch.logsumexp(log_w, dim=0).unsqueeze(0)
        alpha = torch.exp(log_w - tmp_alpha).detach()
        if self.version == 'v1':
            p_loss = -alpha * (log_lik + log_prior)

        ### get moment-matched proposal
        mu_r = alpha.unsqueeze(2) * z_q
        mu_r = mu_r.sum(0).detach()
        z_minus_mu_r = z_q - mu_r.unsqueeze(0)
        reshaped_diff = z_minus_mu_r.view(S * bsz, -1, 1)
        reshaped_diff_t = reshaped_diff.permute(0, 2, 1)
        outer = torch.bmm(reshaped_diff, reshaped_diff_t)
        outer = outer.view(S, bsz, self.z_dim, self.z_dim)
        Sigma_r = outer.mean(0) * S / (S - 1)
        Sigma_r = Sigma_r + torch.eye(self.z_dim).to(device) * 1e-6  ## ridging

        ### get v, \beta, and L(\phi)
        L = torch.linalg.cholesky(Sigma_r)  # torch.cholesky is deprecated
        r_phi = MultivariateNormal(loc=mu_r, scale_tril=L)

        z = r_phi.rsample((S, ))
        z_r = z.detach()
        recon_batch_r = self.decoder(z_r)
        x_dist_r = Bernoulli(logits=recon_batch_r)
        log_lik_r = x_dist_r.log_prob(x).sum(-1)
        log_prior_r = self.prior.log_prob(z_r).sum(-1)
        log_r = r_phi.log_prob(z_r)
        log_v = log_lik_r + log_prior_r - log_r
        tmp_beta = torch.logsumexp(log_v, dim=0).unsqueeze(0)
        beta = torch.exp(log_v - tmp_beta).detach()
        log_q = q_phi.log_prob(z_r).sum(-1)
        q_loss = -beta * log_q

        if self.version == 'v2':
            p_loss = -beta * (log_lik_r + log_prior_r)

        rem_loss = torch.sum(q_loss + p_loss, 0).sum()
        return rem_loss
Example no. 12
class StochasticBaseMultivariate(nn.Module):
    """ Trainable triangular matrix L, so Sigma=LL^T. """
    def __init__(self, D):
        super().__init__()
        self.gen = Generator(D)
        self.mu = nn.Parameter(torch.zeros(D), requires_grad=False)
        self.L = nn.Parameter(torch.rand(D, D))

    @property
    def sigma(self):
        return self.L.tril() @ self.L.tril().T

    def forward(self, x):
        x = self.gen(x)
        self.dist = MultivariateNormal(self.mu, scale_tril=self.L.tril())
        x_sample = self.dist.rsample()
        x = x + x_sample
        return x
Example no. 13
    def model_sim(self, theta):
        """
        Simulate from model for a given theta

        :param theta: Matrix of size n x dim_theta
        :return: x_samples: Matrix of size n x dim_x with samples from the model
        """
        n = theta.shape[0]

        x_samples = torch.zeros(n, self.x_dim)

        for j in range(theta.shape[0]):
            model_tmp = MultivariateNormal(theta[j, :],
                                           self.model.covariance_matrix)
            x_samples[j, :] = model_tmp.rsample(sample_shape=(
                self.N, )).flatten()  # flatten here so I do not need to
            # flatten later

        return x_samples
Example no. 14
def predict_KFAC_sampling(model,
                          test_loader,
                          M_W_post,
                          M_b_post,
                          U_post,
                          V_post,
                          B_post,
                          n_samples,
                          verbose=False,
                          cuda=False):
    py = []
    max_len = len(test_loader)  # number of batches, used for the progress print below

    for batch_idx, (x, y) in enumerate(test_loader):

        if cuda:
            x, y = x.cuda(), y.cuda()

        phi = model.phi(x)

        mu_pred = phi @ M_W_post + M_b_post
        Cov_pred = torch.diag(phi @ U_post @ phi.t()).reshape(
            -1, 1, 1) * V_post.unsqueeze(0) + B_post.unsqueeze(0)

        post_pred = MultivariateNormal(mu_pred, Cov_pred)

        # MC-integral
        py_ = 0

        for _ in range(n_samples):
            f_s = post_pred.rsample()
            py_ += torch.softmax(f_s, 1)

        py_ /= n_samples
        py_ = py_.detach()

        py.append(py_)

        if verbose:
            print("Batch: {}/{}".format(batch_idx, max_len))

    return torch.cat(py, dim=0)
Example no. 15
    def model_sim(self, theta):
        """
        Simulate from model for a given theta

        :param theta: Matrix of size n x dim_theta
        :return: x_samples: Matrix of size n x S(x) with samples from the model
        """

        n = theta.shape[0]

        x_samples = torch.zeros(n, self.N, self.x_dim)

        for j in range(theta.shape[0]):
            model_tmp = MultivariateNormal(theta[j, :],
                                           self.model.covariance_matrix)
            x_samples[j, :, :] = model_tmp.rsample(sample_shape=(self.N, ))

        print(x_samples.shape)

        return func.calc_summary_stats(x_samples)
Example no. 16
class NormalizingFlows(nn.Module):
    def __init__(self, transforms, dim=2):

        super().__init__()
        if isinstance(transforms, nn.Module):
            self.transforms = nn.ModuleList([
                transforms,
            ])
        elif isinstance(transforms, list):
            if not all(isinstance(t, nn.Module) for t in transforms):
                raise ValueError("Wrong type of transforms")
            self.transforms = nn.ModuleList(transforms)
        else:
            raise ValueError("Wrong type of transforms")
        self.dim = dim
        self.base_dist = MultivariateNormal(torch.zeros(self.dim),
                                            torch.eye(self.dim))

    def log_prob(self, x):

        inv_log_det = 0.0
        for transform in reversed(self.transforms):
            z, inv_log_det_jacobian = transform.inverse(x)
            inv_log_det += inv_log_det_jacobian
            x = z
        log_base = self.base_dist.log_prob(x)
        log_prob = (inv_log_det + log_base)

        return log_prob

    def sample(self, batch_size):

        x = self.base_dist.rsample([batch_size])
        log_base = self.base_dist.log_prob(x)
        log_det = 0.0
        for transform in self.transforms:
            x, log_det_jacobian = transform.forward(x)
            log_det += log_det_jacobian
        log_prob = -log_det + log_base

        return x, log_prob
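A minimal usage sketch for the class above, with a toy invertible transform written only for illustration (the AffineFlow layer below is hypothetical and not part of the original code; it just satisfies the expected forward/inverse interface):

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal

class AffineFlow(nn.Module):
    """Element-wise y = x * exp(log_scale) + shift; trivially invertible."""
    def __init__(self, dim):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(dim))
        self.shift = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        y = x * torch.exp(self.log_scale) + self.shift
        return y, self.log_scale.sum().expand(x.shape[0])

    def inverse(self, y):
        x = (y - self.shift) * torch.exp(-self.log_scale)
        return x, (-self.log_scale).sum().expand(y.shape[0])

flow = NormalizingFlows([AffineFlow(2), AffineFlow(2)], dim=2)
x = torch.randn(16, 2)
log_prob = flow.log_prob(x)          # shape (16,)
samples, sample_log_prob = flow.sample(16)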
Example no. 17
def predict_diagonal_sampling(model,
                              test_loader,
                              M_W_post,
                              M_b_post,
                              C_W_post,
                              C_b_post,
                              n_samples,
                              verbose=False,
                              cuda=False):
    py = []
    max_len = len(test_loader)

    for batch_idx, (x, y) in enumerate(test_loader):

        if cuda:
            x, y = x.cuda(), y.cuda()

        phi = model.phi(x)

        mu, Sigma = get_Gaussian_output(phi, M_W_post, M_b_post, C_W_post,
                                        C_b_post)

        post_pred = MultivariateNormal(mu, Sigma)

        # MC-integral
        py_ = 0

        for _ in range(n_samples):
            f_s = post_pred.rsample()
            py_ += torch.softmax(f_s, 1)

        py_ /= n_samples
        py_ = py_.detach()

        py.append(py_)

        if verbose:
            print("Batch: {}/{}".format(batch_idx, max_len))

    return torch.cat(py, dim=0)
Example no. 18
class ResNet152_StochasticBaseMultivariate(nn.Module):
    """ Trainable lower triangular matrix L, so Sigma=LL^T. """
    def __init__(self, D, disable_noise=False):
        super().__init__()
        self.gen = GeneratorResNet152()
        self.fc1 = nn.Linear(2048, D)
        self.mu = nn.Parameter(torch.zeros(D), requires_grad=False)
        self.L = nn.Parameter(torch.rand(D, D).tril(), requires_grad=(not disable_noise))
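        # Note: unlike StochasticBaseMultivariate further above, L is not re-masked
        # with .tril() in forward(), so gradient updates can leave nonzero entries
        # in the upper triangle.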
        self.disable_noise = disable_noise

    @property
    def sigma(self):
        return self.L @ self.L.T

    def forward(self, x):
        x = self.gen(x)
        x = f.relu(self.fc1(x))
        if not self.disable_noise:
            self.dist = MultivariateNormal(self.mu, scale_tril=self.L)
            x_sample = self.dist.rsample()
            x = x + x_sample
        return x
Example no. 19
    def sample(self,
               x: torch.Tensor,
               raw_action: Optional[torch.Tensor] = None,
               deterministic: bool = False) -> Tuple[torch.Tensor, ...]:
        mean, log_std = self.forward(x)
        covariance = torch.diag_embed(log_std.exp())  # diagonal scale (std dev), used as scale_tril
        dist = MultivariateNormal(loc=mean, scale_tril=covariance)

        if raw_action is None:
            if self._reparameterize:
                raw_action = dist.rsample()
            else:
                raw_action = dist.sample()

        action = torch.tanh(raw_action) if self._squash else raw_action
        log_prob = dist.log_prob(raw_action).unsqueeze(-1)
        if self._squash:
            log_prob -= self._squash_correction(raw_action)
        entropy = dist.entropy().unsqueeze(-1)

        if deterministic:
            action = torch.tanh(dist.mean)
        return action, log_prob, entropy
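The `_squash_correction` method is not shown here; for a tanh-squashed Gaussian policy it is usually the log-determinant of the tanh transform. An assumed standalone implementation, for reference only:

import math
import torch
import torch.nn.functional as F

def squash_correction(raw_action: torch.Tensor) -> torch.Tensor:
    """sum_i log(1 - tanh(u_i)^2), in the numerically stable form
    log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u))."""
    return (2.0 * (math.log(2.0) - raw_action
                   - F.softplus(-2.0 * raw_action))).sum(-1, keepdim=True)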
Example no. 20
def predict(dataloader, model, M_W, M_b, U, V, B, n_samples=100, delta=1, apply_softmax=True):
    py = []

    for x, y in dataloader:
        x, y = delta*x.cuda(), y.cuda()
        phi = model.features(x)

        mu_pred = phi @ M_W + M_b
        Cov_pred = torch.diag(phi @ U @ phi.t()).view(-1, 1, 1) * V.unsqueeze(0) + B.unsqueeze(0)

        post_pred = MultivariateNormal(mu_pred, Cov_pred)

        # MC-integral
        py_ = 0

        for _ in range(n_samples):
            f_s = post_pred.rsample()
            py_ += torch.softmax(f_s, 1) if apply_softmax else f_s

        py_ /= n_samples

        py.append(py_)

    return torch.cat(py, dim=0)
Example no. 21
    def forward_conditional(self, past, current, sig_inds):
        # print(current.shape)
        cond_samples = []
        for sample in current:
            mean = torch.Tensor(self.gmm.means_).to('cuda')
            covariance = torch.Tensor(self.gmm.covariances_).to('cuda')
            sig_inds_comp = list(set(range(current.shape[1])) - set(sig_inds))
            ind_len = len(sig_inds)
            ind_len_not = len(sig_inds_comp)
            x_ind = sample[sig_inds].view(-1, ind_len)
            mean_1 = mean[:, sig_inds_comp].view(-1, ind_len_not)
            cov_1_2 = covariance[:, sig_inds_comp, :][:, :, sig_inds].view(
                -1, ind_len_not, ind_len)
            cov_2_2 = covariance[:, sig_inds, :][:, :, sig_inds].view(
                -1, ind_len, ind_len)
            cov_1_1 = covariance[:,
                                 sig_inds_comp, :][:, :, sig_inds_comp].view(
                                     -1, ind_len_not, ind_len_not)
            cond_means = mean_1 + torch.bmm(
                (torch.bmm(cov_1_2, torch.inverse(cov_2_2))),
                (x_ind - mean[:, sig_inds]).view(-1, ind_len, 1)).squeeze(-1)
            cond_covariance = cov_1_1 - torch.bmm(
                torch.bmm(cov_1_2, torch.inverse(cov_2_2)),
                torch.transpose(cov_1_2, 2, 1))
            # Mixture responsibilities: weight each component by the marginal
            # density of the observed coordinates, N(x_i; mu_{k,i}, Sigma_{k,ii}).
            marginal_dist = MultivariateNormal(loc=mean[:, sig_inds],
                                               covariance_matrix=cov_2_2)
            # print(torch.exp(marginal_dist.log_prob(sample[sig_inds])))
            cond_pi = torch.Tensor(self.gmm.weights_).to('cuda') * \
                      torch.exp(marginal_dist.log_prob(sample[sig_inds]))
            # print(self.gmm.weights_)
            # print(cond_pi)
            cond_pi = torch.nn.Softmax(dim=0)(cond_pi)
            m = torch.multinomial(input=cond_pi, num_samples=1)

            # for i in range(self.n_components):
            # mean = torch.Tensor(self.gmm.means_[i]).unsqueeze(0).to('cuda')
            # covariance = torch.Tensor(self.gmm.covariances_[i]).unsqueeze(0).to('cuda')
            # sig_inds_comp = list(set(range(current.shape[1])) - set(sig_inds))
            # ind_len = len(sig_inds)
            # ind_len_not = len(sig_inds_comp)
            # x_ind = sample[sig_inds].view(-1, ind_len)
            # mean_1 = mean[:,sig_inds_comp].view(-1, ind_len_not)
            # cov_1_2 = covariance[:, sig_inds_comp, :][:, :, sig_inds].view(-1, ind_len_not, ind_len)
            # cov_2_2 = covariance[:, sig_inds, :][:, :, sig_inds].view(-1, ind_len, ind_len)
            # cov_1_1 = covariance[:, sig_inds_comp, :][:, :, sig_inds_comp].view(-1, ind_len_not, ind_len_not)
            # mean_cond = mean_1 + torch.bmm((torch.bmm(cov_1_2, torch.inverse(cov_2_2))),
            #                                (x_ind - mean[:, sig_inds]).view(-1, ind_len, 1)).squeeze(-1)
            # covariance_cond = cov_1_1 - torch.bmm(torch.bmm(cov_1_2, torch.inverse(cov_2_2)),
            #                                       torch.transpose(cov_1_2, 2, 1))
            # cond_means.append(mean_cond[0])
            # cond_covariance.append(covariance_cond[0])
            # marginal_dist =  MultivariateNormal(loc=mean_1[0], covariance_matrix=cov_1_1[0])
            # cond_dist = MultivariateNormal(loc=mean_cond[0].squeeze(-1),
            #                           covariance_matrix=covariance_cond[0])
            # print(mean_1[0].shape, cov_1_1[0].shape, sample[sig_inds])
            # cond_pi.append(self.gmm.weights_[i] * torch.exp(marginal_dist.log_prob(sample[sig_inds])))

            # m = torch.multinomial(input=torch.stack(cond_pi), num_samples=1)
            # print(torch.stack(cond_pi), m)
            dist = MultivariateNormal(loc=cond_means[m],
                                      covariance_matrix=cond_covariance[m])
            x = dist.rsample()
            # print('sample:', sample)
            full_sample = sample.clone()
            full_sample[sig_inds_comp] = x
            # print('Conditional', full_sample)
            cond_samples.append(full_sample)
        # print('%%%%%%%%%%%%%%', len(cond_samples))
        return torch.stack(cond_samples), None
Example no. 22
def sample_2d_data(dataset, n_samples):
    z = torch.randn(n_samples, 2)

    if dataset == '8gaussians':
        scale = 4
        sq2 = 1 / math.sqrt(2)
        centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (sq2, sq2), (-sq2, sq2),
                   (sq2, -sq2), (-sq2, -sq2)]
        centers = torch.tensor([(scale * x, scale * y) for x, y in centers])
        return sq2 * (0.5 * z +
                      centers[torch.randint(len(centers), size=(n_samples, ))])

    elif dataset == '1gaussian':
        m = MultivariateNormal(torch.zeros(2), torch.eye(2))
        data = m.rsample(torch.Size([n_samples, 1, 1]))
        return data

    elif dataset == 'sine':
        xs = torch.rand((n_samples, 1)) * 4 - 2
        ys = torch.randn(n_samples, 1) * 0.25

        return torch.cat((xs, torch.sin(3 * xs) + ys), dim=1)

    elif dataset == 'moons':
        from sklearn.datasets import make_moons
        data = make_moons(n_samples=n_samples, shuffle=True, noise=0.05)[0]
        data = torch.tensor(data)
        return data

    elif dataset == 'trimodal':
        centers = torch.tensor([(0, 0), (5, 5), (5, -5)])
        stds = torch.tensor([1., 0.5, 0.5]).unsqueeze(-1)
        seq = torch.randint(len(centers), size=(n_samples, ))
        return stds[seq] * z + centers[seq]

    elif dataset == 'trimodal2':
        centers = torch.tensor([(0, 0), (5, 5), (5, -5)])
        stds = torch.tensor([0.5, 0.5, 0.5]).unsqueeze(-1)
        seq = torch.randint(len(centers), size=(n_samples, ))
        data = stds[seq] * z + centers[seq]
        return data.unsqueeze(1).unsqueeze(1)

    elif dataset == 'smile':
        scale = 4
        sq2 = 1 / math.sqrt(2)

        # SMILE

        centers = []
        centers.append((scale * 0.5, -scale * 0.8660254037844387))
        centers.append((-scale * 0.5, -scale * 0.8660254037844387))

        centers.append((scale * 0, -scale * 0))

        centers.append((scale * 0, scale * 1))
        centers.append((scale * sq2, scale * sq2))
        centers.append((scale * -sq2, scale * sq2))
        centers.append((scale * 0.5, scale * math.sqrt(3) / 2))
        centers.append(
            (scale * 0.25881904510252074, scale * 0.9659258262890683))
        centers.append((-scale * 0.5, scale * math.sqrt(3) / 2))
        centers.append(
            (-scale * 0.25881904510252074, scale * 0.9659258262890683))
        centers = torch.tensor(centers)

        weights = torch.tensor([
            0.5 / 3, 0.5 / 3, 0.5 / 3, 0.5 / 7, 0.5 / 7, 0.5 / 7, 0.5 / 7,
            0.5 / 7, 0.5 / 7, 0.5 / 7
        ])

        stds = torch.tensor([0.5] * len(centers)).unsqueeze(-1)

        from torch.distributions import Categorical
        seq = Categorical(probs=weights).sample((n_samples, ))

        return stds[seq] * z + centers[seq]

    elif dataset == '2spirals':
        n = torch.sqrt(torch.rand(n_samples // 2)) * 540 * (2 * math.pi) / 360
        d1x = -torch.cos(n) * n + torch.rand(n_samples // 2) * 0.5
        d1y = torch.sin(n) * n + torch.rand(n_samples // 2) * 0.5
        x = torch.cat(
            [torch.stack([d1x, d1y], dim=1),
             torch.stack([-d1x, -d1y], dim=1)],
            dim=0) / 3
        return x + 0.1 * z

    elif dataset == 'checkerboard':
        x1 = torch.rand(n_samples) * 4 - 2
        x2_ = torch.rand(n_samples) - torch.randint(
            0, 2, (n_samples, ), dtype=torch.float) * 2
        x2 = x2_ + x1.floor() % 2
        return torch.stack([x1, x2], dim=1) * 2

    elif dataset == 'rings':
        n_samples4 = n_samples3 = n_samples2 = n_samples // 4
        n_samples1 = n_samples - n_samples4 - n_samples3 - n_samples2

        # so as not to have the first point = last point,
        # set endpoint=False in np; here shifted by one
        linspace4 = torch.linspace(0, 2 * math.pi, n_samples4 + 1)[:-1]
        linspace3 = torch.linspace(0, 2 * math.pi, n_samples3 + 1)[:-1]
        linspace2 = torch.linspace(0, 2 * math.pi, n_samples2 + 1)[:-1]
        linspace1 = torch.linspace(0, 2 * math.pi, n_samples1 + 1)[:-1]

        circ4_x = torch.cos(linspace4)
        circ4_y = torch.sin(linspace4)
        circ3_x = torch.cos(linspace3) * 0.75
        circ3_y = torch.sin(linspace3) * 0.75
        circ2_x = torch.cos(linspace2) * 0.5
        circ2_y = torch.sin(linspace2) * 0.5
        circ1_x = torch.cos(linspace1) * 0.25
        circ1_y = torch.sin(linspace1) * 0.25

        x = torch.stack([
            torch.cat([circ4_x, circ3_x, circ2_x, circ1_x]),
            torch.cat([circ4_y, circ3_y, circ2_y, circ1_y])
        ],
                        dim=1) * 3.0

        # random sample
        x = x[torch.randint(0, n_samples, size=(n_samples, ))]

        # Add noise
        return x + torch.normal(mean=torch.zeros_like(x),
                                std=0.08 * torch.ones_like(x))

    else:
        raise RuntimeError('Invalid `dataset` to sample from.')
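Quick usage sketch (assumes the function above and its imports are in scope):

data = sample_2d_data('checkerboard', 2048)   # tensor of shape (2048, 2)
rings = sample_2d_data('rings', 1024)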
Example no. 23
class TopicRNN(Model):
    """
    Replication of Dieng et al.'s
    ``TopicRNN: A Recurrent Neural Network with Long-Range Semantic Dependency``
    (https://arxiv.org/abs/1611.01702).

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    text_encoder : ``Seq2SeqEncoder``
        The encoder used to encode input text.
    text_decoder: ``Seq2SeqEncoder``
        Projects latent word representations into probabilities over the vocabulary.
    variational_autoencoder : ``FeedForward``
        The feedforward network to produce the parameters for the variational distribution.
    topic_dim: ``int``
        The number of latent topics to use.
    freeze_feature_extraction: ``bool``, optional
        If true, the encoding of text as well as learned topics will be frozen.
    classification_mode: ``bool``, optional
        If true, the model will output cross entropy loss w.r.t. sentiment instead of
        predicting the rest of the sequence.
    pretrained_file: ``str``, optional
        If provided, will initialize the model with the weights provided in this file.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 text_encoder: Seq2SeqEncoder,
                 variational_autoencoder: FeedForward = None,
                 sentiment_classifier: FeedForward = None,
                 topic_dim: int = 20,
                 freeze_feature_extraction: bool = False,
                 classification_mode: bool = False,
                 pretrained_file: str = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(TopicRNN, self).__init__(vocab, regularizer)

        self.metrics = {
            'cross_entropy': Average(),
            'negative_kl_divergence': Average(),
            'stopword_loss': Average()
        }

        self.classification_mode = classification_mode
        if classification_mode:
            self.metrics['sentiment'] = CategoricalAccuracy()

        if pretrained_file:
            archive = load_archive(pretrained_file)
            pretrained_model = archive.model
            self._init_from_archive(pretrained_model)
        else:
            # Model parameter definition.
            #
            # Defaults reflect Dieng et al.'s decisions when training their semi-unsupervised
            # IMDB sentiment classifier.
            self.text_field_embedder = text_field_embedder
            self.vocab_size = self.vocab.get_vocab_size("tokens")
            self.text_encoder = text_encoder
            self.topic_dim = topic_dim
            self.vocabulary_projection_layer = TimeDistributed(
                Linear(text_encoder.get_output_dim(), self.vocab_size))

            # Parameter gamma from the paper; projects hidden states into binary logits for whether a
            # word is a stopword.
            self.stopword_projection_layer = TimeDistributed(
                Linear(text_encoder.get_output_dim(), 2))

            self.tokens_to_index = vocab.get_token_to_index_vocabulary()

            # This step should only ever be performed ONCE.
            # When running allennlp train, the vocabulary will be constructed before the model instantiation, but
            # we can't create the stopless namespace until we get here.
            # Check if there already exists a stopless namespace: if so refrain from altering it.
            if "stopless" not in vocab._token_to_index.keys():
                assert self.tokens_to_index[DEFAULT_PADDING_TOKEN] == 0 and \
                       self.tokens_to_index[DEFAULT_OOV_TOKEN] == 1
                for token, _ in self.tokens_to_index.items():
                    if token not in STOP_WORDS:
                        vocab.add_token_to_namespace(token, "stopless")

                # Since a vocabulary with the stopless namespace hasn't been saved, save one for convenience.
                vocab.save_to_files("vocabulary")

            # Compute stop indices in the normal vocab space to prevent stop words
            # from contributing to the topic additions.
            self.stop_indices = torch.LongTensor(
                [vocab.get_token_index(stop) for stop in STOP_WORDS])

            # Learnable topics.
            # TODO: How should these be initialized?
            self.beta = nn.Parameter(torch.rand(topic_dim, self.vocab_size))

            # mu: The mean of the variational distribution.
            self.mu_linear = nn.Linear(topic_dim, topic_dim)

            # sigma: The root standard deviation of the variational distribution.
            self.sigma_linear = nn.Linear(topic_dim, topic_dim)

            # noise: used when sampling.
            self.noise = MultivariateNormal(torch.zeros(topic_dim),
                                            torch.eye(topic_dim))

            stopless_dim = vocab.get_vocab_size("stopless")
            self.variational_autoencoder = variational_autoencoder or FeedForward(
                # Takes as input the word frequencies in the stopless dimension and projects
                # the word frequencies into a latent topic representation.
                #
                # Each latent representation will help tune the variational dist.'s parameters.
                stopless_dim,
                3,
                [500, 500, topic_dim],
                torch.nn.ReLU(),
            )

            # The shape for the feature vector for sentiment classification.
            # (RNN Hidden Size + Inference Network output dimension).
            sentiment_input_size = text_encoder.get_output_dim() + topic_dim
            self.sentiment_classifier = sentiment_classifier or FeedForward(
                # As done by the paper; a simple single layer with 50 hidden units
                # and sigmoid activation for sentiment classification.
                sentiment_input_size,
                2,
                [50, 2],
                torch.nn.Sigmoid(),
            )

        if freeze_feature_extraction:
            # Freeze the RNN and VAE pipeline so that only the classifier is trained.
            for name, param in self.named_parameters():
                if "sentiment_classifier" not in name:
                    param.requires_grad = False

        self.sentiment_criterion = nn.CrossEntropyLoss()

        self.num_samples = 50

        initializer(self)

    def _init_from_archive(self, pretrained_model: Model):
        """ Given a TopicRNN instance, take its weights. """
        self.text_field_embedder = pretrained_model.text_field_embedder
        self.vocab_size = pretrained_model.vocab_size
        self.text_encoder = pretrained_model.text_encoder

        # This function is only to be invoked when needing to classify.
        # To avoid manually dealing with padding, instantiate a Seq2Vec instead.
        self.text_to_vec = PytorchSeq2VecWrapper(
            self.text_encoder._modules['_module'])

        self.topic_dim = pretrained_model.topic_dim
        self.vocabulary_projection_layer = pretrained_model.vocabulary_projection_layer
        self.stopword_projection_layer = pretrained_model.stopword_projection_layer
        self.tokens_to_index = pretrained_model.tokens_to_index
        self.stop_indices = pretrained_model.stop_indices
        self.beta = pretrained_model.beta
        self.mu_linear = pretrained_model.mu_linear
        self.sigma_linear = pretrained_model.sigma_linear
        self.noise = pretrained_model.noise
        self.variational_autoencoder = pretrained_model.variational_autoencoder
        self.sentiment_classifier = pretrained_model.sentiment_classifier

    @overrides
    def forward(
            self,  # type: ignore
            input_tokens: Dict[str, torch.LongTensor],
            output_tokens: Dict[str, torch.LongTensor],
            frequency_tokens: Dict[str, torch.LongTensor],
            sentiment: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        input_tokens : Dict[str, Variable], required
            The BPTT portion of text to encode.
        output_tokens : Dict[str, Variable], required
            The BPTT portion of text to produce.
        frequency_tokens : Dict[str, Variable], required
            Words mapped to the frequency in which they occur in their source text.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_classes)`` representing a distribution over the
            label classes for each instance.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        output_dict = {}

        # Encode the input text.
        # Shape: (batch x sequence length x hidden size)
        embedded_input = self.text_field_embedder(input_tokens)
        input_mask = util.get_text_field_mask(input_tokens)
        encoded_input = self.text_encoder(embedded_input, input_mask)

        # Initial projection into vocabulary space, v^T * h_t.
        # Shape: (batch x sequence length x vocabulary size)
        logits = self.vocabulary_projection_layer(encoded_input)

        # Predict stopwords.
        # Note that for every logit in the projection into the vocabulary, the stop indicator
        # will be the same within time steps. This is because we predict whether forthcoming
        # words are stops or not and zero out topic additions for those time steps.
        stopword_logits = sigmoid(
            self.stopword_projection_layer(encoded_input))
        stopword_predictions = torch.argmax(stopword_logits, dim=-1)
        stopword_predictions = stopword_predictions.unsqueeze(2).expand_as(
            logits)

        # Word frequency vectors and noise aren't generated with the model. If the model
        # is running on a GPU, these tensors need to be moved to the correct device.
        device = logits.device

        # Mask the output for proper loss calculation.
        output_mask = util.get_text_field_mask(output_tokens)
        relevant_output = output_tokens['tokens'].contiguous()
        relevant_output_mask = output_mask.contiguous()

        # Compute Gaussian parameters.
        stopless_word_frequencies = self._compute_word_frequency_vector(
            frequency_tokens).to(device=device)
        mapped_term_frequencies = self.variational_autoencoder(
            stopless_word_frequencies)

        # If the inference network ever learns to output just 0, something has gone wrong.
        assert mapped_term_frequencies.sum().item() > 0

        mu = self.mu_linear(mapped_term_frequencies)
        log_sigma = self.sigma_linear(mapped_term_frequencies)

        # I. Compute KL-divergence.
        # A closed-form solution exists since we're assuming q is drawn
        # from a normal distribution.
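        # Concretely, KL(N(mu, diag(sigma^2)) || N(0, I)) = -0.5 * sum(1 + 2*log_sigma - mu^2 - exp(2*log_sigma)).
        # The code below accumulates the per-dimension term of that sum, so `kl_divergence`
        # ends up holding the negative KL (hence the `-kl_divergence` in the loss).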
        kl_divergence = 2 * log_sigma - (mu**2) - torch.exp(2 * log_sigma)

        # Sum along the topic dimension and add const.
        kl_divergence = (self.topic_dim + torch.sum(kl_divergence)) / 2

        aggregate_cross_entropy_loss = 0
        for _ in range(self.num_samples):

            # Compute noise for sampling.
            epsilon = self.noise.rsample().to(device=device)

            # Compute noisy topic proportions given Gaussian parameters.
            theta = mu + torch.exp(log_sigma) * epsilon

            # II. Compute cross entropy against next words for the current sample of noise.
            # Padding and OOV tokens are indexed at 0 and 1.
            topic_additions = torch.mm(theta, self.beta)
            topic_additions.t()[0] = 0  # Padding will be treated as stops.
            topic_additions.t()[1] = 0  # Unknowns will be treated as stops.

            # Stop words have no contribution via topics.
            topic_additions = (1 - stopword_predictions).float(
            ) * topic_additions.unsqueeze(1).expand_as(logits)
            cross_entropy_loss = util.sequence_cross_entropy_with_logits(
                logits + topic_additions, relevant_output,
                relevant_output_mask)
            aggregate_cross_entropy_loss += cross_entropy_loss

        averaged_cross_entropy_loss = aggregate_cross_entropy_loss / self.num_samples

        # III. Compute stopword probabilities and gear RNN hidden states toward learning them.
        relevant_stopword_output = self._compute_stopword_mask(
            output_tokens).contiguous().to(device=device)
        stopword_loss = util.sequence_cross_entropy_with_logits(
            stopword_logits, relevant_stopword_output, relevant_output_mask)

        if self.classification_mode:
            output_dict['loss'] = self._classify_sentiment(
                frequency_tokens, mapped_term_frequencies, sentiment)
        else:
            output_dict[
                'loss'] = -kl_divergence + averaged_cross_entropy_loss + stopword_loss

        self.metrics['negative_kl_divergence']((-kl_divergence).item())
        self.metrics['cross_entropy'](averaged_cross_entropy_loss.item())
        self.metrics['stopword_loss'](stopword_loss.item())

        return output_dict

    def _classify_sentiment(
            self,  # type: ignore
            frequency_tokens: Dict[str, torch.LongTensor],
            mapped_term_frequencies: torch.Tensor,
            sentiment: Dict[str, torch.LongTensor]) -> torch.Tensor:
        # pylint: disable=arguments-differ
        """
        Using the entire review (frequency_tokens), classify it as positive or negative.
        """

        # Encode the input text.
        # Shape: (batch, sequence length, hidden size)
        embedded_input = self.text_field_embedder(frequency_tokens)
        input_mask = util.get_text_field_mask(frequency_tokens)

        # Use text_to_vec to avoid dealing with padding.
        encoded_input = self.text_to_vec(embedded_input, input_mask)

        # Construct feature vector.
        # Shape: (batch, RNN hidden size + number of topics)
        sentiment_features = torch.cat(
            [encoded_input, mapped_term_frequencies], dim=-1)

        # Classify.
        logits = self.sentiment_classifier(sentiment_features)
        loss = self.sentiment_criterion(logits, sentiment)

        self.metrics['sentiment'](logits, sentiment)

        return loss

    def _compute_word_frequency_vector(
        self,
        frequency_tokens: Dict[str,
                               torch.LongTensor]) -> Dict[str, torch.Tensor]:
        """ Given the window in which we're allowed to collect word frequencies, produce a
            vector in the 'stopless' dimension for the variational distribution.
        """
        batch_size = frequency_tokens['tokens'].size(0)
        res = torch.zeros(batch_size, self.vocab.get_vocab_size("stopless"))
        for i, row in enumerate(frequency_tokens['tokens']):
            # A conversion between namespaces (full vocab to stopless) is necessary.
            words = [
                self.vocab.get_token_from_index(index)
                for index in row.tolist()
            ]
            word_counts = dict(Counter(words))
            num_words = sum(word_counts.values())

            # TODO: Make this faster.
            for word, count in word_counts.items():
                if word in self.tokens_to_index:
                    index = self.vocab.get_token_index(word, "stopless")

                    # Exclude padding token from influencing inference.
                    res[i][index] = (count * int(index > 0)) / num_words

        return res

    def _compute_stopword_mask(
            self,
            output_tokens: Dict[str,
                                torch.LongTensor]) -> Dict[str, torch.Tensor]:
        """ Given a set of output tokens, compute a mask where 1 indicates stopword presence and 0
            indicates stopword absence.
        """
        res = torch.zeros_like(output_tokens['tokens'])
        for i, row in enumerate(output_tokens['tokens']):
            words = [
                self.vocab.get_token_from_index(index)
                for index in row.tolist()
            ]
            res[i] = torch.LongTensor(
                [int(word in STOP_WORDS) for word in words])

        return res

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            metric_name: metric.get_metric(reset)
            for metric_name, metric in self.metrics.items()
        }

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        # TODO: What makes sense for a decode for TopicRNN?
        return output_dict
Example no. 24
    def sample_lambda0_r(self,
                         d,
                         batch_size,
                         offset=0,
                         object_locations=None,
                         object_margin=None,
                         num_objects=None,
                         gau=None,
                         max_rejections=1000,
                         margin_offset=2):
        """Sample dataset parameters perturbed by r."""
        name = d['name']
        family = d['family']
        attr_name = '{}_{}'.format(name, 'center')
        if self.wn:
            lambda_r = self.normalize_weights(name=name, prop='center')
        elif family != 'half_normal':
            lambda_r = getattr(self, attr_name)
        parameters = []
        if family == 'gaussian':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)
            # lambda_r = transform_to(constraints.greater_than(
            #     1.))(lambda_r)
            # lambda_r_scale = transform_to(constraints.greater_than(
            #     self.minimum_spatial_scale))(lambda_r_scale)
            # TODO: Add constraint function here
            # w=module.weight.data
            # w=w.clamp(0.5,0.7)
            # module.weight.data=w

            if gau is None:
                gau = MultivariateNormal(loc=lambda_r,
                                         covariance_matrix=lambda_r_scale)
            if d['return_sampler']:
                return gau
            if name == 'object_location':
                if not len(object_locations):
                    return gau.rsample(), gau
                else:
                    parameters = self.rejection_sampling(
                        object_margin=object_margin,
                        margin_offset=margin_offset,
                        object_locations=object_locations,
                        max_rejections=max_rejections,
                        num_objects=num_objects,
                        gau=gau)
            else:
                raise NotImplementedError(name)
        elif family == 'normal':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)
            nor = Normal(loc=lambda_r, scale=lambda_r_scale)
            if d['return_sampler']:
                return nor
            elif name == 'object_location':
                # nor.arg_constraints['scale'] = constraints.greater_than(self.minimum_spatial_scale)  # noqa
                if not len(object_locations):
                    return nor.rsample(), nor
                else:
                    parameters = self.rejection_sampling(
                        object_margin=object_margin,
                        margin_offset=margin_offset,
                        object_locations=object_locations,
                        max_rejections=max_rejections,
                        num_objects=num_objects,
                        gau=nor)
            else:
                for idx in range(batch_size):
                    parameters.append(nor.rsample())
        elif family == 'cnormal':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)

            # Explicitly clamp the scale!
            lambda_r_scale = torch.clamp(lambda_r_scale,
                                         self.minimum_spatial_scale, 999.)
            nor = CNormal(loc=lambda_r, scale=lambda_r_scale)
            if d['return_sampler']:
                return nor
            elif name == 'object_location':
                # nor.arg_constraints['scale'] = constraints.greater_than(self.minimum_spatial_scale)  # noqa
                if not len(object_locations):
                    return nor.rsample(), nor
                else:
                    parameters = self.rejection_sampling(
                        object_margin=object_margin,
                        margin_offset=margin_offset,
                        object_locations=object_locations,
                        max_rejections=max_rejections,
                        num_objects=num_objects,
                        gau=nor)
            else:
                for idx in range(batch_size):
                    parameters.append(nor.rsample())
        elif family == 'abs_normal':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)
            # lambda_r = transform_to(Normal.arg_constraints['loc'])(lambda_r)
            # lambda_r_scale = transform_to(Normal.arg_constraints['scale'])(lambda_r_scale)  # noqa
            # lambda_r = transforms.AbsTransform()(lambda_r)
            # lambda_r_scale = transforms.AbsTransform()(lambda_r_scale)
            # These kill grads!! # lambda_r = torch.abs(lambda_r)
            # These kill grads!! lambda_r_scale = torch.abs(lambda_r_scale)
            nor = Normal(loc=lambda_r, scale=lambda_r_scale)
            if d['return_sampler']:
                return nor
            else:
                parameters = nor.rsample([batch_size])
        elif family == 'half_normal':
            attr_name = '{}_{}'.format(name, 'scale')
            if self.wn:
                lambda_r_scale = self.normalize_weights(name=name,
                                                        prop='scale')
            else:
                lambda_r_scale = getattr(self, attr_name)
            nor = HalfNormal(scale=lambda_r_scale)
            if d['return_sampler']:
                return nor
            else:
                parameters = nor.rsample([batch_size])
        elif family == 'categorical':
            if d['return_sampler']:
                gum = RelaxedOneHotCategorical(1e-1, logits=lambda_r)
                return gum
                # return lambda sample_size: self.argmax(self.gumbel_fun(lambda_r, name=name)) + offset   # noqa
            for _ in range(batch_size):
                parameters.append(
                    self.argmax(self.gumbel_fun(lambda_r, name=name)) +
                    offset)  # noqa Use default temperature -> max
        elif family == 'relaxed_bernoulli':
            bern = RelaxedBernoulli(temperature=1e-1, logits=lambda_r)
            if d['return_sampler']:
                return bern
            else:
                parameters = bern.rsample([batch_size])
        else:
            raise NotImplementedError(
                '{} not implemented in sampling.'.format(family))
        return parameters
Example no. 25
    def forward(self, index, X_c=None, X_d=None, X_b=None):

        B = len(X_c)  # batch size
        M_samples = 1000

        KL_z = 0
        KL_s = 0
        LL = 0
        elbo = 0

        term_1 = 0
        term_2 = 0
        term_3 = 0

        rik = np.zeros((self.N, self.K))

        start = B * index
        end = B * (index + 1)

        for i in range(start, end):
            #for i in range(len(X_c)):
            # -------Variational function q(z)-------
            self.q_z_CoVar = torch.exp(self.q_log_var[i, :]) * torch.eye(
                self.L)
            q_z = MultivariateNormal(self.q_z_mean[i, :], self.q_z_CoVar)
            p_z = MultivariateNormal(
                self.posterior_mean[i, :],
                self.posterior_var[i, :] * torch.eye(self.L))
            KL_i_z = torch.distributions.kl.kl_divergence(q_z, p_z)
            #print(f'KL z = {KL_i_z}')
            # -------Variational function q(s)-------
            q_s = Categorical(self.q_s_param[i, :])
            p_s = Categorical(self.posterior_mu[i, :])
            KL_i_s = torch.distributions.kl.kl_divergence(q_s, p_s)
            s_pred = torch.nn.functional.gumbel_softmax(
                self.q_s_param
            )  # This gives the estimated profile for that day
            #print(f'KL s = {KL_i_s}')
            # ----------------------------------------

            LL_i = 0
            term_1_i = 0
            term_2_i = 0
            term_3_i = 0
            z_pred = q_z.rsample([M_samples])  # M x L

            for k in range(self.K):
                term_1_k = self.gaussian.log_pdf(k, X_c[i - start], z_pred)
                term_2_k = self.bernoulli.log_pdf(k, X_b[i - start], z_pred)
                term_3_k = self.categorical.log_pdf(k, X_d[i - start], z_pred)
                #term_3_k = 0
                LL_i += self.q_s_param[i, k] * torch.sum(term_1_k + term_2_k +
                                                         term_3_k)
                term_1_i += term_1_k
                term_2_i += term_2_k
                term_3_i += term_3_k

            rik[i, :] = torch.nn.functional.softmax(self.q_s_param[i, :],
                                                    dim=0).detach().numpy()
            elbo_i = LL_i - KL_i_z - KL_i_s

            elbo += elbo_i

            # print(KL_s)
            term_1 += term_1_i
            term_2 += term_2_i
            term_3 += term_3_i
            LL += LL_i

        # Only the KL terms from the last item in the batch are returned
        KL_z = KL_i_z
        KL_s = KL_i_s

        return -elbo, LL, KL_z, KL_s, rik, term_1, term_2, term_3
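
# --- Hedged sketch of the KL terms assembled above (illustrative values) ----
# kl_divergence has registered closed forms for (MultivariateNormal,
# MultivariateNormal) and (Categorical, Categorical), which is what the loop
# relies on. L and K below are assumptions, not the model's actual attributes.
import torch
from torch.distributions import Categorical, MultivariateNormal
from torch.distributions.kl import kl_divergence

L, K = 5, 3
q_z = MultivariateNormal(torch.zeros(L), torch.exp(torch.zeros(L)) * torch.eye(L))
p_z = MultivariateNormal(torch.ones(L), 2.0 * torch.eye(L))
kl_z = kl_divergence(q_z, p_z)                     # scalar tensor

q_s = Categorical(probs=torch.full((K,), 1.0 / K))
p_s = Categorical(probs=torch.tensor([0.5, 0.3, 0.2]))
kl_s = kl_divergence(q_s, p_s)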
Exemplo n.º 26
0
class RealNVP_MLP(nn.Module):
    """ Minimal Real NVP architecture

    Args:
        dim (int): input dimension
        n_realnvp_blocks (int): number of pairs of coupling layers
        block_depth (int): repetition of blocks with shared parameters
        init_weight_scale (float): scaling factor for weights in s and t layers
        prior_arg (dict): specifies the base distribution
        mask_type (str): 'half' or 'inter' masking pattern
        hidden_dim (int): number of hidden neurons per layer in the coupling MLPs
        hidden_depth (int): number of layers in each coupling MLP
        hidden_bias (bool): bias flag for the coupling MLPs
        hidden_activation (callable): activation function for the coupling MLPs
        device (str): torch device on which the model lives
    """

    def __init__(self, dim, n_realnvp_blocks, 
                 block_depth,
                 init_weight_scale=None,
                 prior_arg={'type': 'standn'},
                 mask_type='half',  
                 hidden_dim=10,
                 hidden_depth=3,
                 hidden_bias=True,
                 hidden_activation=torch.relu,
                 device='cpu'):
        super(RealNVP_MLP, self).__init__()

        self.device = device
        self.dim = dim
        self.n_blocks = n_realnvp_blocks
        self.block_depth = block_depth
        self.couplings_per_block = 2  # one update of entire layer per block 
        self.n_layers_in_coupling = hidden_depth  # depth of MLPs in coupling layers 
        self.hidden_dim_in_coupling = hidden_dim
        self.hidden_bias = hidden_bias
        self.hidden_activation = hidden_activation
        self.init_scale_in_coupling = init_weight_scale

        mask = torch.ones(dim, device=self.device)
        if mask_type == 'half':
            mask[:int(dim / 2)] = 0
        elif mask_type == 'inter':
            idx = torch.arange(dim, device=self.device)
            mask = mask * (idx % 2 == 0)
        else:
            raise RuntimeError("mask_type must be 'half' or 'inter'")
        self.mask = mask.view(1, dim)

        self.coupling_layers = self.initialize()

        self.beta = 1.  # effective temperature needed e.g. in Langevin

        self.prior_arg = prior_arg

        if prior_arg['type'] == 'standn':
            self.prior_prec = torch.eye(dim).to(device)
            self.prior_log_det = 0
            self.prior_distrib = MultivariateNormal(
                torch.zeros((dim,), device=self.device), self.prior_prec)

        elif prior_arg['type'] == 'uncoupled':
            self.prior_prec = prior_arg['a'] * torch.eye(dim).to(device)
            self.prior_log_det = - torch.logdet(self.prior_prec)
            self.prior_distrib = MultivariateNormal(
                torch.zeros((dim,), device=self.device),
                precision_matrix=self.prior_prec)

        elif prior_arg['type'] == 'coupled':
            self.beta_prior = prior_arg['beta']
            self.coef = prior_arg['alpha'] * dim
            prec = torch.eye(dim) * (3 * self.coef + 1 / self.coef)
            prec -= self.coef * torch.triu(torch.triu(torch.ones_like(prec),
                                                      diagonal=-1).T, diagonal=-1)
            prec = prior_arg['beta'] * prec
            self.prior_prec = prec.to(self.device)
            self.prior_log_det = - torch.logdet(prec)
            self.prior_distrib = MultivariateNormal(
                torch.zeros((dim,), device=self.device),
                precision_matrix=self.prior_prec)

        elif prior_arg['type'] == 'white':
            cov = prior_arg['cov']
            self.prior_prec = torch.inverse(cov).to(device)
            self.prior_prec = 0.5 * (self.prior_prec + self.prior_prec.T)
            self.prior_mean = prior_arg['mean'].to(device)
            self.prior_log_det = - torch.logdet(self.prior_prec)
            self.prior_distrib = MultivariateNormal(
                prior_arg['mean'],
                precision_matrix=self.prior_prec
                )

        elif prior_arg['type'] == 'bridge':
            self.bridge_kwargs = prior_arg['bridge_kwargs']

        else:
            raise NotImplementedError("Invalid prior arg type")

    def forward(self, x, return_per_block=False):
        log_det_jac = torch.zeros(x.shape[0], device=self.device)

        if return_per_block:
            xs = [x]
            log_det_jacs = [log_det_jac]

        for block in range(self.n_blocks):
            couplings = self.coupling_layers[block]

            for dt in range(self.block_depth):
                for coupling_layer in couplings:
                    x, log_det_jac = coupling_layer(x, log_det_jac)

                if return_per_block:
                    xs.append(x)
                    log_det_jacs.append(log_det_jac)

        if return_per_block:
            return xs, log_det_jacs
        else:
            return x, log_det_jac

    def backward(self, x, return_per_block=False):
        log_det_jac = torch.zeros(x.shape[0], device=self.device)

        if return_per_block:
            xs = [x]
            log_det_jacs = [log_det_jac]
        
        for block in range(self.n_blocks):
            couplings = self.coupling_layers[::-1][block]

            for dt in range(self.block_depth):
                for coupling_layer in couplings[::-1]:
                    x, log_det_jac = coupling_layer(
                        x, log_det_jac, inverse=True)

                if return_per_block:
                    xs.append(x)
                    log_det_jacs.append(log_det_jac)

        if return_per_block:
            return xs, log_det_jacs
        else:
            return x, log_det_jac

    def initialize(self):
        dim = self.dim
        coupling_layers = []

        for block in range(self.n_blocks):
            layer_dims = [self.hidden_dim_in_coupling] * \
                (self.n_layers_in_coupling - 2)
            layer_dims = [dim] + layer_dims + [dim]

            couplings = self.build_coupling_block(layer_dims)

            coupling_layers.append(nn.ModuleList(couplings))

        return nn.ModuleList(coupling_layers)

    def build_coupling_block(self, layer_dims=None, nets=None, reverse=False):
        coupling_layers = []
        for count in range(self.couplings_per_block):
            s = MLP(layer_dims, init_scale=self.init_scale_in_coupling)
            s = s.to(self.device)
            t = MLP(layer_dims, init_scale=self.init_scale_in_coupling)
            t = t.to(self.device)

            if count % 2 == 0:
                mask = 1 - self.mask
            else:
                mask = self.mask
            
            dt = self.n_blocks * self.couplings_per_block * self.block_depth
            dt = 2 / dt
            coupling_layers.append(ResidualAffineCoupling(
                s, t, mask, dt=dt))

        return coupling_layers

    def nll(self, x):
        z, log_det_jac = self.backward(x)

        if self.prior_arg['type']=='bridge':
            a_min = torch.tensor([self.bridge_kwargs["x0"],self.bridge_kwargs["y0"]])
            b_min = torch.tensor([self.bridge_kwargs["x1"],self.bridge_kwargs["y1"]])
            dt = self.bridge_kwargs["dt"]
            prior_nll = bridge_energy(z, dt=dt, a_min=a_min, b_min=b_min, device=self.device)
            return prior_nll - log_det_jac
        elif self.prior_arg['type'] == 'white':
            z = z - self.prior_mean

        prior_ll = - 0.5 * torch.einsum('ki,ij,kj->k', z, self.prior_prec, z)
        prior_ll -= 0.5 * (self.dim * np.log(2 * np.pi) + self.prior_log_det)

        ll = prior_ll + log_det_jac
        nll = -ll
        return nll

    def sample(self, n):
        if self.prior_arg['type'] == 'standn':
            z = torch.randn(n, self.dim, device=self.device)
        elif self.prior_arg['type'] == 'bridge':
            # get a bridge
            n_steps = self.bridge_kwargs["n_steps"]
            bridges = torch.zeros(n, n_steps - 2, 1, 2, device=self.device)
            for i in range(n):
                bridges[i,:] = get_bridge(**self.bridge_kwargs)
            z = bridges.detach().requires_grad_().view(n, -1)
        else:
            z = self.prior_distrib.rsample(torch.Size([n, ])).to(self.device)

        return self.forward(z)[0]

    def U(self, x):
        """
        alias
        """
        return self.nll(x)

    def V(self, x):
        z, log_det_jac = self.backward(x)
        return self.beta_prior * (z ** 2 / 2).sum(dim=-1) / self.coef - log_det_jac

    def U_coupling_per_site(self, x):
        """
        return the (\nable phi) ** 2 to be used in direct computation
        with dirichlet boundary conditions to 0
        U = U_coup_per_site.sum(dim=-1) + V
        """
        bc_value = 0
        z, _ = self.backward(x)

        z_ = F.pad(input=z, pad=(1,) * 2, mode='constant',
                   value=bc_value)

        return ((z_[:, 1:] - z_[:, :-1]) ** 2 / 2) * self.coef * self.beta_prior
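
# --- Standalone sketch (assumptions: random SPD precision, zero mean) -------
# Checks that the einsum-based Gaussian log-density used in nll() matches
# torch.distributions.MultivariateNormal for a precision-parameterized prior.
import numpy as np
import torch
from torch.distributions import MultivariateNormal

dim = 4
a = torch.randn(dim, dim)
prec = a @ a.t() + dim * torch.eye(dim)      # symmetric positive-definite precision
prior_log_det = -torch.logdet(prec)          # log|Sigma| = -log|Lambda|
z = torch.randn(8, dim)

manual_ll = -0.5 * torch.einsum('ki,ij,kj->k', z, prec, z)
manual_ll = manual_ll - 0.5 * (dim * np.log(2 * np.pi) + prior_log_det)

reference = MultivariateNormal(torch.zeros(dim), precision_matrix=prec).log_prob(z)
assert torch.allclose(manual_ll, reference, atol=1e-3)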