def gen_default_priors(self, data, K, L,
                       # sig_prior=Gamma(10, 10),
                       sig_prior=LogNormal(0, 0.01),
                       alpha_prior=Gamma(1, 1),
                       mu0_prior=None, mu1_prior=None, W_prior=None,
                       eta0_prior=None, eta1_prior=None):
    if L is None:
        L = [5, 5]
    self.__cache_model_constants__(data, K, L)

    if mu0_prior is None:
        mu0_prior = Gamma(1, 1)
    if mu1_prior is None:
        mu1_prior = Gamma(1, 1)
    if W_prior is None:
        W_prior = Dirichlet(torch.ones(self.K) / self.K)
    if eta0_prior is None:
        eta0_prior = Dirichlet(torch.ones(self.L[0]) / self.L[0])
    if eta1_prior is None:
        eta1_prior = Dirichlet(torch.ones(self.L[1]) / self.L[1])

    self.priors = {'mu0': mu0_prior, 'mu1': mu1_prior, 'sig': sig_prior,
                   'eta0': eta0_prior, 'eta1': eta1_prior, 'W': W_prior,
                   'alpha': alpha_prior}
def test_dirichlet_shape(self):
    alpha = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
    alpha_1d = Variable(torch.exp(torch.randn(4)), requires_grad=True)
    self.assertEqual(Dirichlet(alpha).sample().size(), (2, 3))
    self.assertEqual(Dirichlet(alpha).sample((5,)).size(), (5, 2, 3))
    self.assertEqual(Dirichlet(alpha_1d).sample().size(), (4,))
    self.assertEqual(Dirichlet(alpha_1d).sample((1,)).size(), (1, 4))
def log_prior(Z, variable_types):
    if 'pi_unconstrained' in variable_types.keys():
        pi = softmax(Z['pi_unconstrained'][0], dim=-1)
        logp = Dirichlet(torch.ones_like(pi)).log_prob(pi) \
            + log_det_jacobian_softmax(pi, dim=-1)
    else:
        logp = 0
    for i, (key, z) in enumerate(Z.items()):
        z = z[0]
        if key != 'pi_unconstrained':
            if variable_types[key] == 'Categorical':
                alpha = softmax(z, dim=-1, additional=-50.)
                logp += torch.sum(Dirichlet(torch.ones_like(alpha)).log_prob(alpha)
                                  + log_det_jacobian_softmax(alpha, dim=-1), dim=-1)
            # elif variable_types[key] == 'Bernoulli':
            #     theta = torch.sigmoid(z)
            #     logp += torch.sum(Beta(torch.ones_like(theta), torch.ones_like(theta)).log_prob(theta)
            #                       + log_det_jacobian_sigmoid(theta), dim=-1)
            elif variable_types[key] == 'Bernoulli':
                logp += TransBeta.log_prob(z).sum()
            elif variable_types[key] == 'Beta':
                alpha, beta = torch.exp(z)
                logp += torch.sum(Gamma(1.0, 1.0).log_prob(alpha) + torch.log(alpha), dim=-1)
                logp += torch.sum(Gamma(1.0, 1.0).log_prob(beta) + torch.log(beta), dim=-1)
    return torch.mean(logp)
def log_prior(Z, variable_types):
    """
    Z : a dictionary containing the draws from the variational posterior for each parameter.
    variable_types : a dictionary giving the distribution name assigned to each parameter.
    """
    ## We proceed as in the log-likelihood computation; however, since the elements of Z are
    ## in expanded form and the prior is not data dependent, we compute the contribution of
    ## only the first element of each entry of Z.
    pi = softmax(Z['pi_unconstrained'][0], dim=-1)
    logp = Dirichlet(torch.ones_like(pi)).log_prob(pi) + log_det_jacobian_softmax(pi, dim=-1)
    for i, (key, z) in enumerate(Z.items()):
        if key != 'pi_unconstrained':
            z = z[0]
            if variable_types[key] == 'Categorical':
                alpha = softmax(z, dim=-1, additional=-50.)
                logp += torch.sum(Dirichlet(torch.ones_like(alpha)).log_prob(alpha)
                                  + log_det_jacobian_softmax(alpha, dim=-1), dim=-1)
            elif variable_types[key] == 'Bernoulli':
                theta = torch.sigmoid(z)
                logp += torch.sum(Beta(torch.ones_like(theta), torch.ones_like(theta)).log_prob(theta)
                                  + log_det_jacobian_sigmoid(theta), dim=-1)
            elif variable_types[key] == 'Beta':
                alpha, beta = torch.exp(z)
                logp += torch.sum(Gamma(1.0, 1.0).log_prob(alpha) + torch.log(alpha), dim=-1)
                logp += torch.sum(Gamma(1.0, 1.0).log_prob(beta) + torch.log(beta), dim=-1)
    return logp
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    n_obs: int = 1.0,
):
    # generative_outputs is the dict returned by `generative(...)`;
    # `n_obs` is assumed to be the number of training data points
    p_x_c = generative_outputs["p_x_c"]
    gamma = generative_outputs["gamma"]

    # compute Q: take the mean over cells and rescale by n_obs (instead of summing over n)
    q_per_cell = torch.sum(gamma * -p_x_c, 1)

    # third term is the log prob of the prior terms in Q
    theta_log = F.log_softmax(self.theta_logit, dim=-1)
    theta_log_prior = Dirichlet(self.dirichlet_concentration)
    theta_log_prob = -theta_log_prior.log_prob(torch.exp(theta_log) + THETA_LOWER_BOUND)
    prior_log_prob = theta_log_prob

    delta_log_prior = Normal(self.delta_log_mean, self.delta_log_log_scale.exp().sqrt())
    delta_log_prob = torch.masked_select(delta_log_prior.log_prob(self.delta_log), (self.rho > 0))
    prior_log_prob += -torch.sum(delta_log_prob)

    loss = (torch.mean(q_per_cell) * n_obs + prior_log_prob) / n_obs

    return LossRecorder(loss, q_per_cell, torch.zeros_like(q_per_cell), prior_log_prob)
def test_dirichlet_shape(self):
    dist = Dirichlet(torch.Tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
    self.assertEqual(dist._batch_shape, torch.Size((3,)))
    self.assertEqual(dist._event_shape, torch.Size((2,)))
    self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
    self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
    self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
    self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
def select_action(self, obs):
    concentration, value = self.forward(obs)
    m = Dirichlet(concentration)
    action = m.sample()
    self.saved_actions.append(SavedAction(m.log_prob(action), value))
    return list(action.cpu().numpy())
def test_forward(self):
    source_dirichlet = Dirichlet(torch.ones(10))
    batch_size = 12
    observed_data = source_dirichlet.sample((batch_size,))
    # Check that the forward pass returns membership probabilities
    _, membership_probs = self.dmm(observed_data)
    # Ensure membership probabilities sum to one
    for prob in membership_probs.sum(dim=1):
        self.assertAlmostEqual(prob.item(), 1, places=5)
def test_dirichlet_log_prob(self):
    num_samples = 10
    alpha = torch.exp(torch.randn(5))
    dist = Dirichlet(alpha)
    x = dist.sample((num_samples,))
    actual_log_prob = dist.log_prob(x)
    for i in range(num_samples):
        expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
        self.assertAlmostEqual(actual_log_prob[i], expected_log_prob, places=3)
def sample_partitions(B, N, K, alpha=1.0, rand_K=True, device='cpu'):
    pi = Dirichlet(alpha * torch.ones(K)).sample([B]).to(device)
    if rand_K:
        to_use = (torch.rand(B, K) < 0.5).float().to(device)
        to_use[..., 0] = 1
        pi = pi * to_use
        pi = pi / pi.sum(1, keepdim=True)
    labels = Categorical(probs=pi).sample([N]).to(device)
    labels = labels.transpose(0, 1).contiguous()
    return labels
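# Minimal usage sketch for sample_partitions above; the argument values are illustrative.
# It draws cluster labels for B synthetic datasets of N points each with at most K clusters;
# rand_K=True randomly switches some clusters off before renormalizing the mixture weights.
labels = sample_partitions(B=2, N=6, K=4, alpha=1.0, rand_K=True, device='cpu')
print(labels.shape)  # torch.Size([2, 6]); entries are cluster indices in [0, K)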
def test_backward(self):
    source_dirichlet = Dirichlet(torch.ones(10))
    batch_size = 12
    observed_data = source_dirichlet.sample((batch_size,))
    # Obtain loss
    nll, _ = self.dmm(observed_data)
    loss = nll.sum()
    # Check that gradient is non-zero for params
    loss.backward()
    for param in self.dmm.parameters():
        self.assertIsNotNone(param.grad)
def sample_labels(B, N, K_low, K_high, alpha=1.0):
    pi = Dirichlet(alpha * torch.ones(K_high)).sample([B])
    K = torch.randint(K_low, K_high + 1, size=(B,))
    to_use = torch.zeros(B, K_high).int()
    for i, k in enumerate(K):
        to_use[i, :k] = 1
    pi = pi * to_use
    pi = pi / pi.sum(1, keepdim=True)
    labels = Categorical(probs=pi).sample([N])
    labels = labels.transpose(0, 1).contiguous()
    return labels
def forward(self, state):
    concentration = self._get_concentration(state)
    if self.training:
        # PyTorch can't backprop through _sample_dirichlet, so use rsample() during training
        action = Dirichlet(concentration).rsample()
    else:
        # ONNX can't export Dirichlet(), so fall back to torch._sample_dirichlet
        action = torch._sample_dirichlet(concentration)
    log_prob = Dirichlet(concentration).log_prob(action)
    return rlt.ActorOutput(action=action, log_prob=log_prob.unsqueeze(dim=1))
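# Minimal sketch (illustrative tensors, not part of the original model) of why the training
# branch above uses rsample(): Dirichlet supports implicit reparameterization, so gradients
# flow from the sample back to the concentration, whereas torch._sample_dirichlet does not
# track gradients.
import torch
from torch.distributions import Dirichlet

concentration = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
action = Dirichlet(concentration).rsample()
Dirichlet(concentration).log_prob(action).backward()
print(concentration.grad)  # not None: the pathwise gradient reached the parameters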
def sample_l1_sphere(device, shape):
    '''Sample uniformly from the unit l1 sphere, i.e. the boundary of the cross-polytope.

    Inputs:
        device: 'cpu' | 'cuda' | other torch devices
        shape: a pair (batchsize, dim)
    Outputs:
        matrix of shape `shape` such that each row is a sample.
    '''
    batchsize, dim = shape
    dirdist = Dirichlet(concentration=torch.ones(dim, device=device))
    noises = dirdist.sample([batchsize])
    signs = torch.sign(torch.rand_like(noises) - 0.5)
    return noises * signs
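# Minimal usage sketch for sample_l1_sphere above; the shape is illustrative. Every row
# should lie on the unit l1 sphere, so its absolute values sum to (numerically) one.
samples = sample_l1_sphere('cpu', (4, 3))
print(samples.abs().sum(dim=1))  # each entry should be ~1.0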
def test_dirichlet_sample(self):
    self._set_rng_seed()
    alpha = torch.exp(torch.randn(3))
    self._check_sampler_sampler(Dirichlet(alpha),
                                scipy.stats.dirichlet(alpha.numpy()),
                                'Dirichlet(alpha={})'.format(list(alpha)),
                                multivariate=True)
def test_TwoFactorFractionalSIR(self):
    dist = Dirichlet(torch.tensor([10000., 1., 1.]))
    sir = m.TwoFactorSIR((0.1, 0.05, 0.05, 0.05), dist, dt=1e-1)
    x = sir.sample_path(1000)
    self.assertEqual(x.shape, torch.Size([1000, 3]))
def _AMN_optimization_ENS_mixture(AMN_net, expert_net, optimizer, state_batch,
                                  feature_regression=False, tau=0.1, beta=0.01,
                                  GAMMA=0.99, training=True):
    """
    Update the AMN ensemble by matching a Dirichlet-weighted mixture of its member
    policies to the expert network's policy via a cross-entropy loss.
    """
    if not training:
        return None

    AMN_policy = 0
    alpha = Dirichlet(torch.ones(AMN_net.get_num_ensembles())).sample()
    for i in range(AMN_net.get_num_ensembles()):
        AMN_q_value = AMN_net(state_batch, ens_num=i, last_layer=False)
        AMN_policy += alpha[i] * to_policy(AMN_q_value)

    expert_q_value = expert_net(state_batch, last_layer=False)
    expert_policy = to_policy(expert_q_value).detach()
    loss = -torch.sum(expert_policy * torch.log(AMN_policy + 1e-8))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
def gen_default_priors(self, K, L,
                       sig_prior=LogNormal(0, 1),
                       alpha_prior=Gamma(1., 1.),
                       mu0_prior=None, mu1_prior=None, W_prior=None,
                       eta0_prior=None, eta1_prior=None):
    if L is None:
        L = [5, 5]
    self.L = L

    if K is None:
        K = 30
    self.K = K

    # FIXME: these should be an ordered prior of TruncatedNormals
    if mu0_prior is None:
        mu0_prior = Gamma(1, 1)
    if mu1_prior is None:
        mu1_prior = Gamma(1, 1)
    if W_prior is None:
        W_prior = Dirichlet(torch.ones(self.K) / self.K)
    if eta0_prior is None:
        eta0_prior = Dirichlet(torch.ones(self.L[0]) / self.L[0])
    if eta1_prior is None:
        eta1_prior = Dirichlet(torch.ones(self.L[1]) / self.L[1])

    self.priors = {
        'mu0': mu0_prior,
        'mu1': mu1_prior,
        'sig': sig_prior,
        'H': Normal(0, 1),
        'eta0': eta0_prior,
        'eta1': eta1_prior,
        'W': W_prior,
        'alpha': alpha_prior
    }
def kl_divergence(model_concentrations, target_concentrations, mode='reverse'):
    """
    Input: model concentrations and target concentrations (Dirichlet parameters).
    Output: the mean KL divergence between the two Dirichlet distributions.
    """
    assert torch.all(model_concentrations > 0)
    assert torch.all(target_concentrations > 0)

    target_dirichlet = Dirichlet(target_concentrations)
    model_dirichlet = Dirichlet(model_concentrations)
    kl_divergences = _kl_dirichlet_dirichlet(
        p=target_dirichlet if mode == 'forward' else model_dirichlet,
        q=model_dirichlet if mode == 'forward' else target_dirichlet)
    assert_no_nan_no_inf(kl_divergences)
    mean_kl = torch.mean(kl_divergences)
    assert_no_nan_no_inf(mean_kl)
    return mean_kl
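# Minimal usage sketch for kl_divergence above, assuming its helpers
# (_kl_dirichlet_dirichlet, assert_no_nan_no_inf) are in scope; the concentration values
# are made up. mode='reverse' returns the mean KL(model || target) and mode='forward'
# the mean KL(target || model).
import torch

model_concentrations = torch.tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])
target_concentrations = torch.ones_like(model_concentrations)
print(kl_divergence(model_concentrations, target_concentrations, mode='reverse'))
print(kl_divergence(model_concentrations, target_concentrations, mode='forward'))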
def backward(ctx, grad_output):
    grad_input = None
    if not ctx.train:
        raise RuntimeError('Running backward on shake when train is False')
    if ctx.needs_input_grad[0]:
        dist = Dirichlet(th.full((ctx.xsh[ctx.dim],), ctx.concentration))
        beta = dist.sample(sample_shape=th.Size([ctx.xsh[ctx.batchdim]]))
        beta = beta.to(th.device(ctx.dev))
        sh = [1 for _ in range(len(ctx.xsh))]
        sh[ctx.batchdim], sh[ctx.dim] = ctx.xsh[ctx.batchdim], ctx.xsh[ctx.dim]
        beta = beta.view(*sh)
        grad_output = grad_output.unsqueeze(ctx.dim).expand(*ctx.xsh)
        grad_input = grad_output * beta
    return grad_input, None, None, None, None, None
def test_TwoFactorSIRD(self):
    dist = Dirichlet(torch.tensor([10000., 1., 1., 1.]))
    sird = m.ThreeFactorSIRD((0.1, 0.05, 0.05, 0.01, 0.05, 0.05, 0.05), dist, dt=1e-1)
    x = sird.sample_path(1000)
    self.assertEqual(x.shape, torch.Size([1000, 4]))
def sample(self, B, N, K, return_gt=False):
    device = 'cpu' if not torch.cuda.is_available() \
        else torch.cuda.current_device()
    pi = Dirichlet(torch.ones(K)).sample(torch.Size([B])).to(device)
    labels = Categorical(probs=pi).sample(torch.Size([N])).to(device)
    labels = labels.transpose(0, 1).contiguous()
    X, params = self.mvn.sample(B, K, labels)
    if return_gt:
        return X, labels, pi, params
    else:
        return X
def construct_variational_parameters(T, N):
    # T-1 params for the second part of the variational Beta factors
    kappa = Variable(Uniform(0, 2).rsample([T - 1]), requires_grad=True)
    # T scale params for the variational Gamma factors
    tau_0 = Uniform(0, 100).rsample([T])
    # T rate params for the variational Gamma factors
    tau_1 = LogNormal(0, 1).rsample([T])
    tau = Variable(torch.stack((tau_0, tau_1)).T, requires_grad=True)
    # N x T params for the variational Categorical factors
    phi = Variable(Dirichlet(1 / T * torch.ones(T)).rsample([N]), requires_grad=True)
    return kappa, tau, phi
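# Minimal usage sketch for construct_variational_parameters above; T and N are illustrative.
# kappa holds T-1 Beta parameters, tau holds T (scale, rate) pairs, and phi holds N rows of
# T Dirichlet-initialized responsibilities.
kappa, tau, phi = construct_variational_parameters(T=10, N=100)
print(kappa.shape, tau.shape, phi.shape)  # torch.Size([9]) torch.Size([10, 2]) torch.Size([100, 10])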
def log_prior(theta_unconstrained, pi_unconstrained):
    theta = torch.sigmoid(theta_unconstrained)
    pi = softmax(pi_unconstrained, dim=-1)
    # Both theta and pi are transformed, so we need to add change-of-variables correction
    # terms to the log densities: for theta this is the log-determinant of the sigmoid
    # Jacobian, and for pi it is log_det_jacobian_softmax (for reasons too complicated to
    # be explained here).
    theta_logp = torch.sum(log_det_jacobian_sigmoid(theta_unconstrained))
    pi_logp = torch.sum(Dirichlet(torch.ones_like(pi)).log_prob(pi)
                        + log_det_jacobian_softmax(pi, dim=-1))
    return pi_logp + theta_logp
def forward(ctx, x, dim, batchdim, concentration, train):
    ctx.dim = dim
    ctx.batchdim = batchdim
    ctx.concentration = concentration
    xsh, ctx.xsh = [x.shape] * 2
    ctx.dev = x.device
    ctx.train = train
    if train:
        # Randomly sample mixing weights from a Dirichlet distribution
        dist = Dirichlet(th.full((xsh[dim],), concentration))
        alpha = dist.sample(sample_shape=th.Size([xsh[batchdim]]))
        alpha = alpha.to(th.device(x.device))
        sh = [1 for _ in range(len(xsh))]
        sh[batchdim], sh[dim] = xsh[batchdim], xsh[dim]
        alpha = alpha.view(*sh)
        y = (x * alpha).sum(dim)
    else:
        y = x.mean(dim)
    return y
def entropy(alpha, uncertainty_type, n_bins=10, plot=True):
    entropy = []
    if uncertainty_type == 'aleatoric':
        p = torch.nn.functional.normalize(alpha, p=1, dim=-1)
        entropy.append(Categorical(p).entropy().squeeze().cpu().detach().numpy())
    elif uncertainty_type == 'epistemic':
        entropy.append(Dirichlet(alpha).entropy().squeeze().cpu().detach().numpy())
    if plot:
        plt.hist(entropy, n_bins)
        plt.show()
    return entropy
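# Minimal usage sketch for the entropy helper above; the concentration values are made up.
# 'aleatoric' gives the entropy of the normalized mean categorical, 'epistemic' the entropy
# of the Dirichlet itself; plotting is disabled here.
import torch

alpha = torch.tensor([[10.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
aleatoric = entropy(alpha, 'aleatoric', plot=False)
epistemic = entropy(alpha, 'epistemic', plot=False)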
def test_compose_affine(event_dims):
    transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)

    base_dist = Normal(0, 1)
    if transform.domain.event_dim:
        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
    dist = TransformedDistribution(base_dist, transform.parts)
    assert dist.support.event_dim == max(event_dims)

    base_dist = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
    dist = TransformedDistribution(base_dist, transforms)
    assert dist.support.event_dim == max(1, max(event_dims))
def main():
    args = parse_arguments()
    x, c, y, S, antes = load_data(args, 'support2')
    print(antes[:5])
    print(type(antes))
    print(len(antes))

    # create model
    args['n_features'] = c.shape[-1]
    model = create_model(args, antes, S.shape[0])

    # use whole dataset for now
    S = torch.tensor(S, dtype=torch.float)
    c = torch.tensor(c, dtype=torch.float)
    n_train = S.shape[0]
    print(n_train)

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    # optimizer = optim.Adam(model.parameters(), lr=0.001)

    start_time = time.time()
    for ep in range(50):
        optimizer.zero_grad()

        d_soft = model(c, S)
        # print(d)
        maxes, argmaxes = d_soft.max(dim=-1)
        d = [antes[amax] for amax in argmaxes]
        print(d)

        n_classes = y.shape[-1]
        B = torch.zeros((n_train, n_classes, len(d) + 1))
        for i, xi in enumerate(x):
            for j, lhs in enumerate(d):
                if set(lhs).issubset(xi):
                    B[i, y[i, 0], j] = 1.
                    break
            B[i, y[i, 0], -1] = 1 - B[i, y[i, 0], :-1].sum()
        assert B.sum() == S.size()[0]

        # get dirichlet prior
        alpha = 1.
        priors = alpha + B.sum(0)
        thetas = torch.zeros((len(d) + 1, n_classes))
        for i in range(len(d) + 1):
            p_theta = Dirichlet(torch.tensor(priors[:, i]))
            thetas[i] = p_theta.rsample()

        # compute p(y | d)
        log_py = 0
        for i, yi in enumerate(y):
            for j in range(len(d) + 1):
                log_py += B[i, y[i, 0], j] * torch.log(thetas[j, y[i, 0]])

        # compute p(d | input), as p(d_1 | input) p(d_2 | d_1, input) ...
        log_pd = maxes.sum()

        log_prob = -(log_py + log_pd)

        elapsed = time.time() - start_time
        print(f"Epoch {ep}: log-prob: {log_prob:.2f}, log p(d|x): {log_pd:.2f}, ", end='')
        print(f"log p(y|d): {log_py:.2f} (Elapsed: {elapsed:.2f}s)")

        log_prob.backward()
        optimizer.step()
def __init__(self, *args, **kwargs):
    super(ThreeGAN, self).__init__(*args, **kwargs)
    if self.cls > 0:
        raise NotImplementedError("ThreeGAN not implemented for cls > 0")
    self.dirichlet = Dirichlet(torch.FloatTensor([1.0, 1.0, 1.0]))
def get_log_prob(self, state, action):
    concentration = self._get_concentration(state)
    log_prob = Dirichlet(concentration).log_prob(action)
    return log_prob.unsqueeze(dim=1)