Example 1
    def gen_default_priors(self, data, K, L,
                           # sig_prior=Gamma(10, 10),
                           sig_prior=LogNormal(0, 0.01),
                           alpha_prior=Gamma(1, 1),
                           mu0_prior=None,
                           mu1_prior=None,
                           W_prior=None,
                           eta0_prior=None,
                           eta1_prior=None):

        if L is None:
            L = [5, 5]

        self.__cache_model_constants__(data, K, L)

        if mu0_prior is None:
            mu0_prior = Gamma(1, 1)

        if mu1_prior is None:
            mu1_prior = Gamma(1, 1)

        if W_prior is None:
            W_prior = Dirichlet(torch.ones(self.K) / self.K)

        if eta0_prior is None:
            eta0_prior = Dirichlet(torch.ones(self.L[0]) / self.L[0])

        if eta1_prior is None:
            eta1_prior = Dirichlet(torch.ones(self.L[1]) / self.L[1])

        self.priors = {'mu0': mu0_prior, 'mu1': mu1_prior, 'sig': sig_prior,
                       'eta0': eta0_prior, 'eta1': eta1_prior, 'W': W_prior,
                       'alpha': alpha_prior}
Example 2
 def test_dirichlet_shape(self):
     alpha = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
     alpha_1d = Variable(torch.exp(torch.randn(4)), requires_grad=True)
     self.assertEqual(Dirichlet(alpha).sample().size(), (2, 3))
     self.assertEqual(Dirichlet(alpha).sample((5, )).size(), (5, 2, 3))
     self.assertEqual(Dirichlet(alpha_1d).sample().size(), (4, ))
     self.assertEqual(Dirichlet(alpha_1d).sample((1, )).size(), (1, 4))
Example 3
def log_prior(Z, variable_types):
    if 'pi_unconstrained' in variable_types.keys():
        pi = softmax(Z['pi_unconstrained'][0], dim=-1)
        logp = Dirichlet(
            torch.ones_like(pi)).log_prob(pi) + log_det_jacobian_softmax(
                pi, dim=-1)
    else:
        logp = 0
    for i, (key, z) in enumerate(Z.items()):
        z = z[0]
        if key != 'pi_unconstrained':
            if variable_types[key] == 'Categorical':
                alpha = softmax(z, dim=-1, additional=-50.)
                logp += torch.sum(Dirichlet(torch.ones_like(alpha)).log_prob(alpha) \
                  + log_det_jacobian_softmax(alpha, dim=-1), dim=-1)
            # elif variable_types[key] == 'Bernoulli':
            #     theta = torch.sigmoid(z)
            #     logp += torch.sum(Beta(torch.ones_like(theta), torch.ones_like(theta)).log_prob(theta)
            #                       + log_det_jacobian_sigmoid(theta), dim=-1)
            elif variable_types[key] == 'Bernoulli':
                logp += TransBeta.log_prob(z).sum()
            elif variable_types[key] == 'Beta':
                alpha, beta = torch.exp(z)
                logp += torch.sum(Gamma(1.0, 1.0).log_prob(alpha) +
                                  torch.log(alpha),
                                  dim=-1)
                logp += torch.sum(Gamma(1.0, 1.0).log_prob(beta) +
                                  torch.log(beta),
                                  dim=-1)
    return torch.mean(logp)
Example 4
def log_prior(Z, variable_types):
    """
	Z : A dictionary containing the draws from variational posterior for each parameter.
	variable_types : a dictionary that contains distribution name assigned to each parameter.
	"""
    ## We proceed similarly as in log-likelihood computation, however since the elements of
    ## Z are in expanded form and the prior is not data dependent we compute the contribution
    ## of only the first element of element of Z.
    pi = softmax(Z['pi_unconstrained'][0], dim=-1)
    logp = Dirichlet(
        torch.ones_like(pi)).log_prob(pi) + log_det_jacobian_softmax(pi,
                                                                     dim=-1)
    for i, (key, z) in enumerate(Z.items()):
        if key != 'pi_unconstrained':
            z = z[0]
            if variable_types[key] == 'Categorical':
                alpha = softmax(z, dim=-1, additional=-50.)
                logp += torch.sum(Dirichlet(torch.ones_like(alpha)).log_prob(alpha) \
                  + log_det_jacobian_softmax(alpha, dim=-1), dim=-1)
            elif variable_types[key] == 'Bernoulli':
                theta = torch.sigmoid(z)
                logp += torch.sum(Beta(torch.ones_like(theta), torch.ones_like(theta)).log_prob(theta)\
                  + log_det_jacobian_sigmoid(theta), dim=-1)
            elif variable_types[key] == 'Beta':
                alpha, beta = torch.exp(z)
                logp += torch.sum(Gamma(1.0, 1.0).log_prob(alpha) +
                                  torch.log(alpha),
                                  dim=-1)
                logp += torch.sum(Gamma(1.0, 1.0).log_prob(beta) +
                                  torch.log(beta),
                                  dim=-1)
    return logp
Example 5
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        n_obs: float = 1.0,
    ):
        # generative_outputs is a dict of the return value from `generative(...)`
        # assume that `n_obs` is the number of training data points
        p_x_c = generative_outputs["p_x_c"]
        gamma = generative_outputs["gamma"]

        # compute Q
        # take mean of number of cells and multiply by n_obs (instead of summing n)
        q_per_cell = torch.sum(gamma * -p_x_c, 1)

        # third term is log prob of prior terms in Q
        theta_log = F.log_softmax(self.theta_logit, dim=-1)
        theta_log_prior = Dirichlet(self.dirichlet_concentration)
        theta_log_prob = -theta_log_prior.log_prob(
            torch.exp(theta_log) + THETA_LOWER_BOUND)
        prior_log_prob = theta_log_prob
        delta_log_prior = Normal(self.delta_log_mean,
                                 self.delta_log_log_scale.exp().sqrt())
        delta_log_prob = torch.masked_select(
            delta_log_prior.log_prob(self.delta_log), (self.rho > 0))
        prior_log_prob += -torch.sum(delta_log_prob)

        loss = (torch.mean(q_per_cell) * n_obs + prior_log_prob) / n_obs

        return LossRecorder(loss, q_per_cell, torch.zeros_like(q_per_cell),
                            prior_log_prob)
Example 6
 def test_dirichlet_shape(self):
     dist = Dirichlet(torch.Tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
     self.assertEqual(dist._batch_shape, torch.Size((3,)))
     self.assertEqual(dist._event_shape, torch.Size((2,)))
     self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
     self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
     self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
     self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
Example 7
    def select_action(self, obs):
        concentration, value = self.forward(obs)

        m = Dirichlet(concentration)

        action = m.sample()
        self.saved_actions.append(SavedAction(m.log_prob(action), value))
        return list(action.cpu().numpy())
Example 8
 def test_forward(self):
     source_dirichlet = Dirichlet(torch.ones(10))
     batch_size = 12
     observed_data = source_dirichlet.sample((batch_size, ))
     # Check that forward function returns
     _, membership_probs = self.dmm(observed_data)
     # Ensure membership probabilities sum to one
     for prob in membership_probs.sum(dim=1):
         self.assertAlmostEqual(prob.item(), 1, places=5)
Example 9
 def test_dirichlet_log_prob(self):
     num_samples = 10
     alpha = torch.exp(torch.randn(5))
     dist = Dirichlet(alpha)
     x = dist.sample((num_samples,))
     actual_log_prob = dist.log_prob(x)
     for i in range(num_samples):
         expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
         self.assertAlmostEqual(actual_log_prob[i], expected_log_prob, places=3)
Example 10
def sample_partitions(B, N, K, alpha=1.0, rand_K=True, device='cpu'):
    # Mixture weights for each of the B datasets: symmetric Dirichlet over K clusters.
    pi = Dirichlet(alpha * torch.ones(K)).sample([B]).to(device)
    if rand_K:
        # Randomly deactivate some clusters (always keep the first) and renormalize.
        to_use = (torch.rand(B, K) < 0.5).float().to(device)
        to_use[..., 0] = 1
        pi = pi * to_use
        pi = pi / pi.sum(1, keepdim=True)
    # Cluster labels for N points per dataset, shape (B, N).
    labels = Categorical(probs=pi).sample([N]).to(device)
    labels = labels.transpose(0, 1).contiguous()
    return labels
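A minimal usage sketch for sample_partitions (made-up sizes; assumes torch, Dirichlet, and Categorical are imported as in the snippet above):

# Draw 8 random partitions of 100 points into at most 4 clusters.
labels = sample_partitions(B=8, N=100, K=4, alpha=1.0, rand_K=True)
print(labels.shape)   # torch.Size([8, 100]); entries are cluster ids in [0, K)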
Example 11
 def test_backward(self):
     source_dirichlet = Dirichlet(torch.ones(10))
     batch_size = 12
     observed_data = source_dirichlet.sample((batch_size, ))
     # Obtain loss
     nll, _ = self.dmm(observed_data)
     loss = nll.sum()
     # Check that gradient is non-zero for params
     loss.backward()
     for param in self.dmm.parameters():
         self.assertIsNotNone(param.grad)
Example 12
def sample_labels(B, N, K_low, K_high, alpha=1.0):
    pi = Dirichlet(alpha * torch.ones(K_high)).sample([B])
    K = torch.randint(K_low, K_high + 1, size=(B, ))
    to_use = torch.zeros(B, K_high).int()
    for i, k in enumerate(K):
        to_use[i, :k] = 1
    pi = pi * to_use
    pi = pi / pi.sum(1, keepdim=True)
    labels = Categorical(probs=pi).sample([N])
    labels = labels.transpose(0, 1).contiguous()
    return labels
Example 13
    def forward(self, state):
        concentration = self._get_concentration(state)
        if self.training:
            # PyTorch can't backwards pass _sample_dirichlet
            action = Dirichlet(concentration).rsample()
        else:
            # ONNX can't export Dirichlet()
            action = torch._sample_dirichlet(concentration)

        log_prob = Dirichlet(concentration).log_prob(action)
        return rlt.ActorOutput(action=action,
                               log_prob=log_prob.unsqueeze(dim=1))
Example 14
def sample_l1_sphere(device, shape):
    '''Sample uniformly from the unit l1 sphere, i.e. the cross polytope.
    Inputs:
        device: 'cpu' | 'cuda' | other torch devices
        shape: a pair (batchsize, dim)
    Outputs:
        matrix of shape `shape` such that each row is a sample.
    '''
    batchsize, dim = shape
    dirdist = Dirichlet(concentration=torch.ones(dim, device=device))
    noises = dirdist.sample([batchsize])
    signs = torch.sign(torch.rand_like(noises) - 0.5)
    return noises * signs
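A quick usage sketch for sample_l1_sphere with a made-up shape, checking that every sampled row has unit l1 norm:

x = sample_l1_sphere('cpu', (16, 10))
print(x.shape)               # torch.Size([16, 10])
print(x.abs().sum(dim=1))    # each entry is (numerically) 1.0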
Example 15
 def test_dirichlet_sample(self):
     self._set_rng_seed()
     alpha = torch.exp(torch.randn(3))
     self._check_sampler_sampler(Dirichlet(alpha),
                                 scipy.stats.dirichlet(alpha.numpy()),
                                 'Dirichlet(alpha={})'.format(list(alpha)),
                                 multivariate=True)
Example 16
    def test_TwoFactorFractionalSIR(self):
        dist = Dirichlet(torch.tensor([10000., 1., 1.]))
        sir = m.TwoFactorSIR((0.1, 0.05, 0.05, 0.05), dist, dt=1e-1)

        x = sir.sample_path(1000)

        self.assertEqual(x.shape, torch.Size([1000, 3]))
Example 17
def _AMN_optimization_ENS_mixture(AMN_net,
                                  expert_net,
                                  optimizer,
                                  state_batch,
                                  feature_regression=False,
                                  tau=0.1,
                                  beta=0.01,
                                  GAMMA=0.99,
                                  training=True):
    """
    Apply the standard procedure to deep Q network.
    """
    if not training:
        return None

    AMN_policy = 0
    alpha = Dirichlet(torch.ones(AMN_net.get_num_ensembles())).sample()

    for i in range(AMN_net.get_num_ensembles()):
        AMN_q_value = AMN_net(state_batch, ens_num=i, last_layer=False)
        AMN_policy += alpha[i] * to_policy(AMN_q_value)
        loss = 0

    expert_q_value = expert_net(state_batch, last_layer=False)
    expert_policy = to_policy(expert_q_value).detach()

    loss -= torch.sum(expert_policy * torch.log(AMN_policy + 1e-8))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.detach()
Example 18
    def gen_default_priors(self,
                           K,
                           L,
                           sig_prior=LogNormal(0, 1),
                           alpha_prior=Gamma(1., 1.),
                           mu0_prior=None,
                           mu1_prior=None,
                           W_prior=None,
                           eta0_prior=None,
                           eta1_prior=None):

        if L is None:
            L = [5, 5]

        self.L = L

        if K is None:
            K = 30

        self.K = K

        # FIXME: these should be an ordered prior of TruncatedNormals
        if mu0_prior is None:
            mu0_prior = Gamma(1, 1)

        if mu1_prior is None:
            mu1_prior = Gamma(1, 1)

        if W_prior is None:
            W_prior = Dirichlet(torch.ones(self.K) / self.K)

        if eta0_prior is None:
            eta0_prior = Dirichlet(torch.ones(self.L[0]) / self.L[0])

        if eta1_prior is None:
            eta1_prior = Dirichlet(torch.ones(self.L[1]) / self.L[1])

        self.priors = {
            'mu0': mu0_prior,
            'mu1': mu1_prior,
            'sig': sig_prior,
            'H': Normal(0, 1),
            'eta0': eta0_prior,
            'eta1': eta1_prior,
            'W': W_prior,
            'alpha': alpha_prior
        }
Example 19
def kl_divergence(model_concentrations, target_concentrations, mode='reverse'):
    """
    Input: model and target Dirichlet concentration parameters.
    Output: average KL divergence between the two Dirichlets
    (KL(model || target) for mode='reverse', KL(target || model) for mode='forward').
    """
    assert torch.all(model_concentrations > 0)
    assert torch.all(target_concentrations > 0)

    target_dirichlet = Dirichlet(target_concentrations)
    model_dirichlet = Dirichlet(model_concentrations)
    kl_divergences = _kl_dirichlet_dirichlet(
        p=target_dirichlet if mode == 'forward' else model_dirichlet,
        q=model_dirichlet if mode == 'forward' else target_dirichlet)
    assert_no_nan_no_inf(kl_divergences)
    mean_kl = torch.mean(kl_divergences)
    assert_no_nan_no_inf(mean_kl)
    return mean_kl
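The helper above depends on _kl_dirichlet_dirichlet and assert_no_nan_no_inf from its own module; here is a self-contained sketch of the same idea using PyTorch's public kl_divergence and made-up concentrations:

import torch
from torch.distributions import Dirichlet
from torch.distributions.kl import kl_divergence as torch_kl

# Hypothetical batch of 4 Dirichlets over 3 categories.
model_concentrations = torch.rand(4, 3) + 0.5
target_concentrations = torch.full((4, 3), 2.0)

# Reverse KL(model || target), averaged over the batch (mode='reverse' above).
mean_kl = torch_kl(Dirichlet(model_concentrations),
                   Dirichlet(target_concentrations)).mean()
print(mean_kl.item())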
Example 20
    def backward(ctx, grad_output):
        grad_input=None

        if not ctx.train:
            raise RuntimeError('Running backward on shake when train is False')

        if ctx.needs_input_grad[0]:
            dist = Dirichlet(th.full((ctx.xsh[ctx.dim],),ctx.concentration))
            beta = dist.sample(sample_shape=th.Size([ctx.xsh[ctx.batchdim]]))
            beta = beta.to(th.device(ctx.dev))
            sh = [1 for _ in range(len(ctx.xsh))]
            sh[ctx.batchdim], sh[ctx.dim] = ctx.xsh[ctx.batchdim], ctx.xsh[ctx.dim]
            beta = beta.view(*sh)
            grad_output = grad_output.unsqueeze(ctx.dim).expand(*ctx.xsh)

            grad_input = grad_output * beta

        return grad_input, None, None, None, None, None
Example 21
    def test_TwoFactorSIRD(self):
        dist = Dirichlet(torch.tensor([10000., 1., 1., 1.]))
        sird = m.ThreeFactorSIRD((0.1, 0.05, 0.05, 0.01, 0.05, 0.05, 0.05),
                                 dist,
                                 dt=1e-1)

        x = sird.sample_path(1000)

        self.assertEqual(x.shape, torch.Size([1000, 4]))
Example 22
    def sample(self, B, N, K, return_gt=False):
        device = 'cpu' if not torch.cuda.is_available() \
                else torch.cuda.current_device()
        pi = Dirichlet(torch.ones(K)).sample(torch.Size([B])).to(device)
        labels = Categorical(probs=pi).sample(torch.Size([N])).to(device)
        labels = labels.transpose(0, 1).contiguous()

        X, params = self.mvn.sample(B, K, labels)
        if return_gt:
            return X, labels, pi, params
        else:
            return X
Example 23
def construct_variational_parameters(T, N):
    # T-1 params for the second part of variational Beta factors
    kappa = Variable(Uniform(0, 2).rsample([T - 1]), requires_grad=True)
    # T scale params for the variational Gamma factors
    tau_0 = Uniform(0, 100).rsample([T])
    # T rate params for the variational Gamma factors
    tau_1 = LogNormal(0, 1).rsample([T])
    tau = Variable(torch.stack((tau_0, tau_1)).T, requires_grad=True)
    phi = Variable(
        Dirichlet(1 / T * torch.ones(T)).rsample([N]),
        requires_grad=True)  # N,T params for the variational Cat factors

    return kappa, tau, phi
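A usage sketch for construct_variational_parameters with made-up sizes, just to make the returned shapes concrete (assumes the Uniform, LogNormal, Dirichlet, and Variable imports used above):

T, N = 10, 100                   # hypothetical truncation level and data size
kappa, tau, phi = construct_variational_parameters(T, N)
print(kappa.shape)               # torch.Size([9]):      T-1 Beta parameters
print(tau.shape)                 # torch.Size([10, 2]):  T (scale, rate) pairs
print(phi.shape)                 # torch.Size([100, 10]): N x T Categorical parameters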
Example 24
def log_prior(theta_unconstrained, pi_unconstrained):
    # Both theta and pi are transformed, so we add the corresponding
    # log-det-Jacobian correction terms to the log densities:
    # log_det_jacobian_sigmoid for theta and log_det_jacobian_softmax for pi.
    theta = torch.sigmoid(theta_unconstrained)
    pi = softmax(pi_unconstrained, dim=-1)
    theta_logp = torch.sum(log_det_jacobian_sigmoid(theta_unconstrained))
    pi_logp = torch.sum(
        Dirichlet(torch.ones_like(pi)).log_prob(pi) +
        log_det_jacobian_softmax(pi, dim=-1))
    return pi_logp + theta_logp
Example 25
    def forward(ctx, x, dim, batchdim, concentration, train):
        ctx.dim = dim
        ctx.batchdim = batchdim
        ctx.concentration = concentration
        xsh, ctx.xsh = [x.shape]*2
        ctx.dev = x.device
        ctx.train = train

        if train:
            # Randomly sample from Dirichlet distribution
            dist = Dirichlet(th.full((xsh[dim],), concentration))
            alpha = dist.sample(sample_shape=th.Size([xsh[batchdim]]))
            alpha = alpha.to(th.device(x.device))
            sh = [1 for _ in range(len(xsh))]
            sh[batchdim], sh[dim] = xsh[batchdim], xsh[dim]
            alpha = alpha.view(*sh)


            y = (x * alpha).sum(dim)
        else:
            y = x.mean(dim)

        return y
Example 26
def entropy(alpha, uncertainty_type, n_bins=10, plot=True):
    entropy = []

    if uncertainty_type == 'aleatoric':
        p = torch.nn.functional.normalize(alpha, p=1, dim=-1)
        entropy.append(
            Categorical(p).entropy().squeeze().cpu().detach().numpy())
    elif uncertainty_type == 'epistemic':
        entropy.append(
            Dirichlet(alpha).entropy().squeeze().cpu().detach().numpy())

    if plot:
        plt.hist(entropy, n_bins)
        plt.show()
    return entropy
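A usage sketch for the entropy helper with hypothetical concentration parameters (plotting disabled so it runs headless); the 'aleatoric' branch scores the mean categorical, the 'epistemic' branch scores the Dirichlet itself:

import torch

# Hypothetical evidence for 3 predictions over 3 classes.
alpha = torch.tensor([[1.1, 1.1, 1.1],      # little evidence, spread out
                      [20.0, 2.0, 2.0],     # confident about class 0
                      [50.0, 50.0, 50.0]])  # strong evidence, evenly split

aleatoric = entropy(alpha, 'aleatoric', plot=False)
epistemic = entropy(alpha, 'epistemic', plot=False)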
Example 27
def test_compose_affine(event_dims):
    transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)

    base_dist = Normal(0, 1)
    if transform.domain.event_dim:
        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
    dist = TransformedDistribution(base_dist, transform.parts)
    assert dist.support.event_dim == max(event_dims)

    base_dist = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
    dist = TransformedDistribution(base_dist, transforms)
    assert dist.support.event_dim == max(1, max(event_dims))
Example 28
def main():
    args = parse_arguments()

    x, c, y, S, antes = load_data(args, 'support2')

    print(antes[:5])
    print(type(antes))
    print(len(antes))

    # create model
    args['n_features'] = c.shape[-1]

    model = create_model(args, antes, S.shape[0])

    # use whole dataset for now
    S = torch.tensor(S, dtype=torch.float)
    c = torch.tensor(c, dtype=torch.float)
    n_train = S.shape[0]
    print(n_train)

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    # optimizer = optim.Adam(model.parameters(), lr=0.001)

    start_time = time.time()
    for ep in range(50):
        optimizer.zero_grad()

        d_soft = model(c, S)

        # print(d)
        maxes, argmaxes = d_soft.max(dim=-1)

        d = [antes[amax] for amax in argmaxes]
        print(d)

        n_classes = y.shape[-1]
        B = torch.zeros((n_train, n_classes, len(d) + 1))
        for i, xi in enumerate(x):
            for j, lhs in enumerate(d):
                if set(lhs).issubset(xi):
                    B[i, y[i, 0], j] = 1.
                    break

            B[i, y[i, 0], -1] = 1 - B[i, y[i, 0], :-1].sum()

        assert B.sum() == S.size()[0]

        # get dirichlet prior
        alpha = 1.
        priors = alpha + B.sum(0)

        thetas = torch.zeros((len(d) + 1, n_classes))
        for i in range(len(d) + 1):
            p_theta = Dirichlet(priors[:, i])
            thetas[i] = p_theta.rsample()

        # compute p(y | d)
        log_py = 0
        for i, yi in enumerate(y):
            for j in range(len(d) + 1):
                log_py += B[i, y[i, 0], j] * torch.log(thetas[j, y[i, 0]])

        # compute p(d | input), as p(d_1 | input) p(d_2 | d_1, input) ...
        log_pd = maxes.sum()

        log_prob = -(log_py + log_pd)

        elapsed = time.time() - start_time
        print(
            f"Epoch {ep}: log-prob: {log_prob:.2f}, log p(d|x): {log_pd:.2f}, ",
            end='')
        print(f"log p(y|d): {log_py:.2f} (Elapsed: {elapsed:.2f}s)")

        log_prob.backward()
        optimizer.step()
Example 29
 def __init__(self, *args, **kwargs):
     super(ThreeGAN, self).__init__(*args, **kwargs)
     if self.cls > 0:
         raise NotImplementedError("ThreeGAN not implemented for cls > 0")
     self.dirichlet = Dirichlet(torch.FloatTensor([1.0, 1.0, 1.0]))
Example 30
 def get_log_prob(self, state, action):
     concentration = self._get_concentration(state)
     log_prob = Dirichlet(concentration).log_prob(action)
     return log_prob.unsqueeze(dim=1)