def get_best_move(self, s, v, rm=None):
    proba = torch.tensor([v[a] for a in s.legal_moves])  # pylint: disable=E
    proba = F.softmax(self.c.eval_move * proba, dim=0).numpy()
    i_max_no_noise = proba.argmax()

    if self.add_noise:
        dir_dist = Dirichlet(torch.zeros(len(s.legal_moves)) + self.c.alpha_dir)
        noise = dir_dist.sample().numpy()
        proba = (1 - self.c.eps_dir) * proba + self.c.eps_dir * noise

    # Best move
    i_max = proba.argmax()
    best_move = s.legal_moves[i_max]

    # For RunManager
    if rm:
        proba_dictate = int(i_max_no_noise == i_max)
        # noise is only defined when add_noise is set; report 0 otherwise
        noise_term = self.c.eps_dir * noise[i_max] if self.add_noise else 0.0
        rm.proba(proba[i_max], noise_term, proba_dictate)

    return best_move

def get_best_move(self, s, v, rm=None):
    # Compute the indices of the legal moves in the tensor v.
    legal_mask = torch.zeros(v.shape, dtype=torch.bool)
    for a in s.legal_moves:
        legal_mask[encoding.a_id(a)] = True

    # Compute the probability of each legal move.
    proba = torch.from_numpy(v[legal_mask])
    proba = F.softmax(self.c.eval_move * proba, dim=0).numpy()
    i_max_no_noise = proba.argmax()

    # Add noise if requested.
    if self.add_noise:
        dir_dist = Dirichlet(
            torch.zeros(len(s.legal_moves)) + self.c.alpha_dir)
        noise = dir_dist.sample().numpy()
        proba = (1 - self.c.eps_dir) * proba + self.c.eps_dir * noise

    # Best move
    i_max = proba.argmax()
    best_move = s.legal_moves[i_max]

    # For RunManager
    if rm:
        best_move_code = np.ravel_multi_index(encoding.a_id(best_move), v.shape)
        v_dictate = int(v.argmax() == best_move_code)
        proba_dictate = int(i_max_no_noise == i_max)
        # noise is only defined when add_noise is set; report 0 otherwise
        noise_term = self.c.eps_dir * noise[i_max] if self.add_noise else 0.0
        rm.proba(v.max(), proba[i_max], noise_term,
                 v[legal_mask.logical_not()].max().item(),
                 v_dictate, proba_dictate)

    return best_move

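# A minimal, self-contained sketch of the Dirichlet exploration noise used by
# the two get_best_move variants above (AlphaZero-style root noise). The
# constants eval_move, alpha_dir and eps_dir are illustrative stand-ins for
# the snippets' config values, not values taken from any actual config.
import torch
import torch.nn.functional as F
from torch.distributions import Dirichlet

eval_move, alpha_dir, eps_dir = 1.0, 0.3, 0.25

scores = torch.tensor([1.2, 0.4, 0.9])           # raw scores of the legal moves
proba = F.softmax(eval_move * scores, dim=0)     # normalized move probabilities
noise = Dirichlet(torch.full((3,), alpha_dir)).sample()
mixed = (1 - eps_dir) * proba + eps_dir * noise  # convex mix stays a distribution
best = int(mixed.argmax())                       # index of the noisy best move
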
def fit(self, epoch_num, optimizer, train_all):
    self.train()
    # Shuffle the input
    train_loader = torch.utils.data.DataLoader(train_all, batch_size=32,
                                               shuffle=True)
    for epoch in range(epoch_num):
        loss_total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            data = data.to(device)
            # predict alpha
            target_a = target_alpha(target)
            target_a = target_a.to(device)
            output_alpha = torch.exp(self.forward(data))
            dirichlet1 = Dirichlet(output_alpha)
            dirichlet2 = Dirichlet(target_a)
            loss = torch.sum(dist.kl.kl_divergence(dirichlet1, dirichlet2))
            loss_total += loss.item()
            loss.backward()
            optimizer.step()
        # average over a hard-coded dataset size (120000)
        print('Train Epoch: {} \t Loss: {:.6f}'.format(
            epoch, loss_total / 120000))

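# A toy sketch of the objective optimized in fit above: network outputs are
# exponentiated into Dirichlet concentrations and pulled toward target
# concentrations via a KL divergence. All shapes and values are illustrative.
import torch
from torch.distributions import Dirichlet, kl_divergence

logits = torch.randn(32, 5, requires_grad=True)  # stand-in for self.forward(data)
target_conc = torch.rand(32, 5) * 5 + 1          # stand-in target concentrations

pred = Dirichlet(torch.exp(logits))              # exp keeps concentrations positive
target = Dirichlet(target_conc)
loss = kl_divergence(pred, target).sum()         # one KL term per batch row
loss.backward()                                  # gradients flow into the logits
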
def KL_phi(self):
    if self.inference == "collapsed":
        return ELBO_collapsed_Categorical(self.qphi_logits, self.alpha_z,
                                          K=self.n_basis, N=self.data_dim)
    elif self.inference == "fixed_pi":
        qphi = self.get_phi()
        pi = torch.ones_like(qphi) / self.n_basis
        KL = (qphi * (torch.log(qphi + 1e-16) - torch.log(pi))).sum()
        return KL
    elif self.inference == "non-collapsed":
        qDir = Dirichlet(concentration=self.qalpha_z)
        pDir = Dirichlet(concentration=self.alpha_z)
        # KL(q(pi) || p(pi))
        KL_Dir = torch.distributions.kl_divergence(qDir, pDir)
        # E[log q(phi) - log p(phi | pi)] under q(pi)q(phi)
        qpi = qDir.rsample()
        qphi = self.get_phi()
        # KL categorical
        KL_Cat = (qphi * (torch.log(qphi + 1e-16)
                          - torch.log(qpi[None, :]))).sum()
        return KL_Dir + KL_Cat

def forward(self, inputs, labels, topics, lengths, sample_topics=False):
    enc_emb = self.lookup(inputs)
    dec_emb = self.lookup(inputs)
    lab_emb = self.label_lookup(labels).unsqueeze(0)  # to match the shape of z
    topics.unsqueeze_(0)
    # prior of z
    mu_pr, logvar_pr = self.z_prior(lab_emb)
    h, _ = self.encoder(enc_emb, lengths)
    if self.is_joint:
        hn = torch.cat([h, topics, lab_emb], dim=2)
    else:
        hn = torch.cat([h, lab_emb], dim=2)
    # posterior of z
    mu_po, logvar_po = self.fcmu(hn), self.fclogvar(hn)
    if self.training:
        z = self.reparameterize(mu_po, logvar_po)
    else:
        z = mu_po
    alphas = self.topic_prior(torch.cat([z, lab_emb], dim=2))
    if sample_topics and not self.is_joint:
        # sampling only valid for the marginal model
        dist = Dirichlet((topics * topics.size(2)).cpu())
        topics = dist.rsample().to(alphas.device)
    code = torch.cat([z, topics, lab_emb], dim=2)
    outputs, _ = self.decoder(dec_emb, code, lengths=lengths)
    outputs = self.fcout(outputs)
    bow = self.bow_predictor(torch.cat([z, lab_emb], dim=2))
    return outputs, (mu_pr, mu_po), (logvar_pr, logvar_po), alphas, bow

def max_ucb_noise_node(self, n):
    Nc = len(n.children)
    dir_dist = Dirichlet(torch.zeros(Nc) + self.alpha_dir)
    noises = dir_dist.sample().numpy()
    i_max = max(range(Nc),
                key=lambda i: self.ucb_noise(n.children[i], noises[i]))
    return n.children[i_max]

def augmentAndMix(x_orig, k, alpha, preprocess):
    # k : number of chains
    # alpha : sampling constant
    x_aug = torch.zeros_like(preprocess(x_orig))
    mixing_weight_dist = Dirichlet(torch.empty(k).fill_(alpha))
    mixing_weights = mixing_weight_dist.sample()

    for i in range(k):
        # each chain starts from the original image (otherwise the
        # augmentations of successive chains would compound)
        x_temp = x_orig
        sampled_augs = random.sample(augmentations, k)
        aug_chain_length = random.choice(range(1, k + 1))
        aug_chain = sampled_augs[:aug_chain_length]
        for aug in aug_chain:
            severity = random.choice(range(1, 6))
            x_temp = aug(x_temp, severity)
        x_aug += mixing_weights[i] * preprocess(x_temp)

    # skip connection back to the original image
    skip_conn_weight_dist = Beta(torch.tensor([alpha]), torch.tensor([alpha]))
    skip_conn_weight = skip_conn_weight_dist.sample()
    x_augmix = skip_conn_weight * x_aug + \
        (1 - skip_conn_weight) * preprocess(x_orig)
    return x_augmix

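# The convex-combination structure of augmentAndMix, shown on plain tensors:
# Dirichlet weights mix k augmented views and a Beta weight blends the mix
# back into the original. The noisy "views" are stand-ins for real image ops.
import torch
from torch.distributions import Dirichlet, Beta

k, alpha = 3, 1.0
x = torch.rand(3, 32, 32)                                   # stand-in image
views = [x + 0.01 * torch.randn_like(x) for _ in range(k)]  # fake augmented views

w = Dirichlet(torch.full((k,), alpha)).sample()             # weights sum to 1
x_mix = sum(wi * v for wi, v in zip(w, views))
m = Beta(torch.tensor([alpha]), torch.tensor([alpha])).sample()
x_augmix = m * x_mix + (1 - m) * x                          # skip connection
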
def forward(self, input, target):
    # Compute the KL annealing factor.
    kl_annealing = self.annealing_factor()

    # Get the prior over the relation probabilities for this batch.
    if self._alpha_prior_val is None:
        # PyTorch's Categorical distribution normalises the input probs,
        # yielding a proper probability distribution.
        probs = torch.ones(input[0].shape[1],
                           device=input[0].device).unsqueeze(0).expand(
                               input[0].shape[0], -1)
    else:
        alpha = torch.tensor(self._alpha_prior_val,
                             device=input[0].device).expand(input[0].shape[1])
        if self._instance_prior:
            alpha = alpha.unsqueeze(0).expand(input[0].shape[0], -1)
            probs = Dirichlet(alpha).sample()
        else:
            probs = Dirichlet(alpha).sample().unsqueeze(0).expand(
                input[0].shape[0], -1)

    # Compute the loss.
    loss = re_bow_loss(input, target,
                       prior=Categorical(probs=probs),
                       reduction=self.reduction,
                       ignore_index=self._ignore_index,
                       kl_annealing=kl_annealing,
                       _DEBUG=self._DEBUG)
    return loss

class Beta(Distribution):
    r"""
    Beta distribution parameterized by `concentration1` and `concentration0`.

    Example::

        >>> m = Beta(torch.Tensor([0.5]), torch.Tensor([0.5]))
        >>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
         0.1046
        [torch.FloatTensor of size 1]

    Args:
        concentration1 (float or Tensor or Variable): 1st concentration parameter
            of the distribution (often referred to as alpha)
        concentration0 (float or Tensor or Variable): 2nd concentration parameter
            of the distribution (often referred to as beta)
    """
    params = {'concentration1': constraints.positive,
              'concentration0': constraints.positive}
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0):
        if isinstance(concentration1, Number) and isinstance(concentration0, Number):
            concentration1_concentration0 = torch.Tensor([concentration1,
                                                          concentration0])
        else:
            concentration1, concentration0 = broadcast_all(concentration1,
                                                           concentration0)
            concentration1_concentration0 = torch.stack([concentration1,
                                                         concentration0], -1)
        self._dirichlet = Dirichlet(concentration1_concentration0)
        super(Beta, self).__init__(self._dirichlet._batch_shape)

    def rsample(self, sample_shape=()):
        value = self._dirichlet.rsample(sample_shape).select(-1, 0)
        if isinstance(value, Number):
            value = self._dirichlet.concentration.new([value])
        return value

    def log_prob(self, value):
        self._validate_log_prob_arg(value)
        heads_tails = torch.stack([value, 1.0 - value], -1)
        return self._dirichlet.log_prob(heads_tails)

    def entropy(self):
        return self._dirichlet.entropy()

    @property
    def concentration1(self):
        result = self._dirichlet.concentration[..., 0]
        if isinstance(result, Number):
            return torch.Tensor([result])
        else:
            return result

    @property
    def concentration0(self):
        result = self._dirichlet.concentration[..., 1]
        if isinstance(result, Number):
            return torch.Tensor([result])
        else:
            return result

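# All the Beta implementations in this section reduce to a two-component
# Dirichlet: rsample takes the first coordinate and log_prob stacks
# (x, 1 - x). A quick numerical check of that identity with current torch:
import torch
from torch.distributions import Beta, Dirichlet

a, b = torch.tensor(2.0), torch.tensor(5.0)
x = torch.tensor(0.3)

lp_beta = Beta(a, b).log_prob(x)
lp_dir = Dirichlet(torch.stack([a, b])).log_prob(torch.stack([x, 1 - x]))
assert torch.allclose(lp_beta, lp_dir)
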
def __init__(self, concentration1, concentration0):
    if isinstance(concentration1, Number) and isinstance(concentration0, Number):
        concentration1_concentration0 = torch.Tensor([concentration1,
                                                      concentration0])
    else:
        concentration1, concentration0 = broadcast_all(concentration1,
                                                       concentration0)
        concentration1_concentration0 = torch.stack([concentration1,
                                                     concentration0], -1)
    self._dirichlet = Dirichlet(concentration1_concentration0)
    super(Beta, self).__init__(self._dirichlet._batch_shape)

def init_params_random(self) -> None:
    """
    Randomly sets the parameters of the model using the Dirichlet priors.
    """
    self.log_T0 = Dirichlet(self.T0_prior).sample().log()
    self.log_T = Dirichlet(self.T_prior).sample().log()
    for s in self.states:
        s.init_params_random()

def __init__(self, concentration1, concentration0, validate_args=None):
    if isinstance(concentration1, Real) and isinstance(concentration0, Real):
        concentration1_concentration0 = torch.tensor(
            [float(concentration1), float(concentration0)])
    else:
        concentration1, concentration0 = broadcast_all(concentration1,
                                                       concentration0)
        concentration1_concentration0 = torch.stack(
            [concentration1, concentration0], -1)
    self._dirichlet = Dirichlet(concentration1_concentration0)
    super(Beta, self).__init__(self._dirichlet._batch_shape,
                               validate_args=validate_args)

def __init__(self, alpha, beta):
    if isinstance(alpha, Number) and isinstance(beta, Number):
        alpha_beta = torch.Tensor([alpha, beta])
    else:
        alpha, beta = broadcast_all(alpha, beta)
        alpha_beta = torch.stack([alpha, beta], -1)
    self._dirichlet = Dirichlet(alpha_beta)
    super(Beta, self).__init__(self._dirichlet._batch_shape)

class Beta(Distribution):
    r"""
    Creates a Beta distribution parameterized by concentration `alpha` and `beta`.

    Example::

        >>> m = Beta(torch.Tensor([0.5]), torch.Tensor([0.5]))
        >>> m.sample()  # Beta distributed with concentration alpha and beta
         0.1046
        [torch.FloatTensor of size 1]

    Args:
        alpha (float or Tensor or Variable): 1st concentration parameter of the distribution
        beta (float or Tensor or Variable): 2nd concentration parameter of the distribution
    """
    params = {'alpha': constraints.positive, 'beta': constraints.positive}
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, alpha, beta):
        if isinstance(alpha, Number) and isinstance(beta, Number):
            alpha_beta = torch.Tensor([alpha, beta])
        else:
            alpha, beta = broadcast_all(alpha, beta)
            alpha_beta = torch.stack([alpha, beta], -1)
        self._dirichlet = Dirichlet(alpha_beta)
        super(Beta, self).__init__(self._dirichlet._batch_shape)

    def rsample(self, sample_shape=()):
        value = self._dirichlet.rsample(sample_shape).select(-1, 0)
        if isinstance(value, Number):
            value = self._dirichlet.alpha.new([value])
        return value

    def log_prob(self, value):
        self._validate_log_prob_arg(value)
        heads_tails = torch.stack([value, 1.0 - value], -1)
        return self._dirichlet.log_prob(heads_tails)

    def entropy(self):
        return self._dirichlet.entropy()

    @property
    def alpha(self):
        result = self._dirichlet.alpha[..., 0]
        if isinstance(result, Number):
            return torch.Tensor([result])
        else:
            return result

    @property
    def beta(self):
        result = self._dirichlet.alpha[..., 1]
        if isinstance(result, Number):
            return torch.Tensor([result])
        else:
            return result

def train(self, x, sampling=True, independent=True):
    '''
    Parameters
    ----------
    x : a batch of data
    sampling : whether to sample from the variational posterior
        distributions (if True, the default), or to just use the means
        of the variational distributions

    Returns
    -------
    log_likelihoods : log likelihood for each sample
    kl_sum : sum of the KL divergences between the variational
        distributions and their priors
    '''
    # The variational distributions
    mu = Normal(self.locs, self.scales)
    sigma = Gamma(self.alpha, self.beta)
    theta = Dirichlet(self.couts)

    # Sample from the variational distributions
    if sampling:
        # Nb = x.shape[0]
        Nb = 1
        mu_sample = mu.rsample((Nb,))
        sigma_sample = torch.pow(sigma.rsample((Nb,)), -0.5)
        theta_sample = theta.rsample((Nb,))
    else:
        mu_sample = torch.reshape(mu.mean, (1, self.Nc, self.Nd))
        sigma_sample = torch.pow(
            torch.reshape(sigma.mean, (1, self.Nc, self.Nd)), -0.5)
        theta_sample = torch.reshape(theta.mean, (1, self.Nc))  # 1*Nc

    # The mixture density
    log_var = (sigma_sample ** 2).log()
    log_likelihoods = GMM.get_likelihoods(
        x,
        mu_sample.reshape((self.Nc, self.Nd)),
        log_var.reshape((self.Nc, self.Nd)),
        log=True)  # Nc*Nb
    log_prob = theta_sample @ log_likelihoods

    # Compute the KL divergence sum
    mu_div = kl_divergence(mu, self.mu_prior)
    sigma_div = kl_divergence(sigma, self.sigma_prior)
    theta_div = kl_divergence(theta, self.theta_prior)
    KL = mu_div + sigma_div + theta_div
    if 0:
        print("mu_div: %f \t sigma_div: %f \t theta_div: %f" %
              (mu_div.sum().detach().numpy(),
               sigma_div.sum().detach().numpy(),
               theta_div.sum().detach().numpy()))
    return KL, log_prob

def log_parameters_prob(self) -> float:
    """
    :returns: log probability of the parameters given the priors.
    """
    ll = Dirichlet(self.T0_prior).log_prob(self.T0)
    ll += Dirichlet(self.T_prior).log_prob(self.T).sum(0)
    for s in self.states:
        ll += s.log_parameters_prob()
    return ll

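# A small self-contained sketch of the HMM prior pattern shared by
# init_params_random and log_parameters_prob above: each row of the
# transition matrix is an independent Dirichlet draw, scored row-wise.
# The flat priors below are illustrative, not the model's actual priors.
import torch
from torch.distributions import Dirichlet

n_states = 4
T0_prior = torch.ones(n_states)           # prior over the initial state
T_prior = torch.ones(n_states, n_states)  # one prior row per state

T0 = Dirichlet(T0_prior).sample()         # initial-state distribution
T = Dirichlet(T_prior).sample()           # each row of T sums to 1

log_prior = (Dirichlet(T0_prior).log_prob(T0)
             + Dirichlet(T_prior).log_prob(T).sum(0))
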
def __init__(self, k=3, alpha=1, severity=3):
    super(AugMix, self).__init__()
    self.k = k
    self.alpha = alpha
    self.severity = severity
    self.dirichlet = Dirichlet(
        torch.full(torch.Size([k]), alpha, dtype=torch.float32))
    self.beta = Beta(alpha, alpha)
    self.augs = augmentations
    self.kl = nn.KLDivLoss(reduction='batchmean')

def set_temperature(self, value):
    self.dirs_normal[:] = []
    for a in self.alphas_normal:
        self.dirs_normal.append(
            Dirichlet(value * torch.ones(a.size()).cuda()))
    self.dirs_reduce[:] = []
    for a in self.alphas_reduce:
        self.dirs_reduce.append(
            Dirichlet(value * torch.ones(a.size()).cuda()))

def sample(self, labels, max_length, sos_id, scale=1):
    lab_emb = self.label_lookup(labels).unsqueeze(0)
    mu, logvar = self.z_prior(lab_emb)
    z = self.reparameterize(mu, logvar)
    if scale != 1:
        z = mu + (z - mu) * scale
    alphas = self.topic_prior(torch.cat([z, lab_emb], dim=2))
    dist = Dirichlet(alphas.cpu())
    topics = dist.sample().to(alphas.device)
    return self.generate(z, topics, lab_emb, max_length, sos_id)

def sample(self, num_samples, max_length, sos_id, device):
    """Randomly sample latent codes to generate texts.

    Note that num_samples should not be too large.
    """
    z_size = self.fcmu.out_features
    z = torch.randn(1, num_samples, z_size, device=device)
    alphas = self.topic_prior(z)
    dist = Dirichlet(alphas.cpu())
    topics = dist.sample().to(device)
    return self.generate(z, topics, max_length, sos_id)

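# The sampling pattern shared by the two sample methods above: a network head
# maps the latent code to Dirichlet concentrations, and topic proportions are
# drawn from that distribution. topic_prior is mocked here with a small
# Softplus-capped linear layer; the real head may differ.
import torch
import torch.nn as nn
from torch.distributions import Dirichlet

z_size, n_topics = 16, 10
topic_prior = nn.Sequential(nn.Linear(z_size, n_topics), nn.Softplus())

z = torch.randn(1, 4, z_size)        # (num_layers, batch, z_size) latent code
alphas = topic_prior(z) + 1e-6       # keep concentrations strictly positive
topics = Dirichlet(alphas).sample()  # (1, 4, n_topics), rows sum to 1
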
def log_prob(self, value):
    lp = torch.zeros_like(value, dtype=torch.float)
    if torch.mul(value > 0., value < 1.).any():
        beta_idx = torch.where(torch.mul(value > 0., value < 1.))
        self._dirichlet = Dirichlet(
            self.concentration1_concentration0[beta_idx])
        lp[beta_idx] = (self.log1m_p[beta_idx] + self.log1m_q[beta_idx]
                        + self.beta_lp(value[beta_idx]))
    lp[torch.where(value == 0.)] = self.log_p[torch.where(value == 0.)]
    lp[torch.where(value == 1.)] = (self.log1m_p[torch.where(value == 1.)]
                                    + self.log_q[torch.where(value == 1.)])
    return lp

def _sample_volume_alphas(self, n_related):
    if self.uniform_volumes:
        u = Uniform(0.25, 1.25)
        return u.sample().repeat(n_related)
    if isinstance(self.concentration, (float, int)):
        concentration = self.concentration
    else:
        concentration = self.concentration.rvs()
    dirichlet = Dirichlet(
        torch.tensor([concentration for _ in range(n_related)]))
    if self.random_seed is not None:
        torch.manual_seed(self.random_seed)
    return dirichlet.sample() * float(self.n_classes)

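# How the concentration in _sample_volume_alphas shapes the draw: small
# concentrations yield highly unequal volumes, large ones near-uniform ones.
# Values below are purely illustrative.
import torch
from torch.distributions import Dirichlet

n_related, n_classes = 5, 5
for concentration in (0.1, 1.0, 100.0):
    alphas = Dirichlet(torch.full((n_related,), concentration)).sample()
    volumes = alphas * n_classes  # rescaled so the entries sum to n_classes
    print(concentration, volumes)
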
def __init__(self, INITIAL_EPSILON, FINAL_EPSILON, policy_net, EPS_DECAY,
             n_actions, lamb, device):
    self._eps = INITIAL_EPSILON
    self._FINAL_EPSILON = FINAL_EPSILON
    self._INITIAL_EPSILON = INITIAL_EPSILON
    self._policy_net = policy_net
    self._EPS_DECAY = EPS_DECAY
    self._n_actions = n_actions
    self._device = device
    distn_params = [1 / lamb for _ in range(policy_net.get_num_ensembles())]
    self.distn = Dirichlet(torch.tensor(distn_params))

def plot_dir(alpha, size):
    model = Dirichlet(torch.tensor(alpha))
    sample = model.sample(torch.Size([size])).data
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    ax.scatter3D(sample[:, 0], sample[:, 1], sample[:, 2], color='red')
    # Edges of the 2-simplex
    ax.plot([0, 0], [1, 0], [0, 1], linewidth=3, color='purple')
    ax.plot([0, 1], [0, 0], [1, 0], linewidth=3, color='purple')
    ax.plot([0, 1], [1, 0], [0, 0], linewidth=3, color='purple')
    ax.set_xlim((0, 1))
    ax.set_ylim((0, 1))
    ax.set_zlim((0, 1))
    ax.view_init(60, 35)
    return fig

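# Example call for plot_dir above: with a symmetric 3-dimensional alpha the
# samples spread over the 2-simplex traced by the purple edges. plt.show()
# is assumed to be needed outside a notebook.
import matplotlib.pyplot as plt

plot_dir([2.0, 2.0, 2.0], size=500)
plt.show()
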
def log_joint_pdf(self, samples, log_f):
    """
    Returns (the log of) p(piece | cluster, transposition); assuming a uniform
    prior over clusters and transpositions, this is proportional to the joint
    p(piece, cluster, transposition)
    :param samples: array of shape
        (n_pieces, n_samples, n_pitches, n_clusters, n_transpositions)
    :param log_f: array of shape
        (n_pieces, n_samples, n_pitches, n_clusters, n_transpositions)
    :return: array of shape (...)
    """
    # construct Dirichlet distributions (move the pitch dimension last)
    dir = Dirichlet(torch.einsum('abcde->abdec', log_f.exp()))
    # get point-wise probabilities and multiply up (log-sum) samples for each piece
    probs = dir.log_prob(torch.einsum('abcde->abdec', samples))
    return probs.sum(dim=1)

def loss_function(targets, outputs, mu, logvar, alphas, topics,
                  bow=None, joint=False):
    """
    Inputs:
        targets: target tokens
        outputs: predicted tokens
        mu: latent mean
        logvar: log of the latent variance
        alphas: parameters of the Dirichlet prior p(w|z) given the latent code
        topics: actual distribution of topics q(w|x,z), i.e. the posterior given x
    Outputs:
        ce_loss: cross entropy loss of the tokens
        kld: D(q(z|x)||p(z))
        kld_tpc: D(q(w|x,z)||p(w|z))
    """
    ce_loss = F.cross_entropy(
        outputs.view(outputs.size(0) * outputs.size(1), outputs.size(2)),
        targets.view(-1),
        size_average=False,
        ignore_index=PAD_ID)
    if bow is None:
        bow_loss = torch.tensor(0., device=outputs.device)
    else:
        bow = bow.unsqueeze(1).repeat(1, outputs.size(1), 1).contiguous()
        bow_loss = F.cross_entropy(
            bow.view(bow.size(0) * bow.size(1), bow.size(2)),
            targets.view(-1),
            size_average=False,
            ignore_index=PAD_ID)
    if isinstance(mu, torch.Tensor):
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    else:
        kld = -0.5 * torch.sum(1 + logvar[1] - logvar[0] -
                               ((mu[1] - mu[0]).pow(2) + logvar[1].exp()) /
                               logvar[0].exp())
    prior = Dirichlet(alphas)
    if joint:
        loss_tpc = -torch.sum(prior.log_prob(topics))
    else:
        alphas2 = topics * topics.size(1)
        posterior = Dirichlet(alphas2)
        loss_tpc = kl_divergence(posterior, prior).sum()
    return ce_loss, kld, loss_tpc, bow_loss

def test_dirichlet_logpdf():
    alpha = torch.tensor([0.5, 0.6, 1.2])
    ps = torch.tensor([0.2, 0.3, 0.5])
    log_pdf = dirichlet_logpdf(ps, alpha)
    # pytorch implementation
    dist = Dirichlet(concentration=alpha)
    log_prob = dist.log_prob(ps)
    print(log_pdf)
    print(log_prob)
    # compare up to floating-point tolerance rather than exact equality
    assert torch.allclose(log_pdf, log_prob)

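# The closed form that dirichlet_logpdf presumably implements, written out
# with torch.lgamma for reference:
# log Dir(p; alpha) = sum((alpha - 1) * log p) + lgamma(sum alpha) - sum(lgamma(alpha))
import torch
from torch.distributions import Dirichlet

def dirichlet_logpdf_manual(ps, alpha):
    log_norm = torch.lgamma(alpha.sum()) - torch.lgamma(alpha).sum()
    return ((alpha - 1) * ps.log()).sum() + log_norm

alpha = torch.tensor([0.5, 0.6, 1.2])
ps = torch.tensor([0.2, 0.3, 0.5])
assert torch.allclose(dirichlet_logpdf_manual(ps, alpha),
                      Dirichlet(alpha).log_prob(ps))
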
def __init__(self, Nc, Nd):
    '''
    Nc : number of components
    Nd : number of dimensions
    '''
    # Initialize
    super(GaussianMixtureModel, self).__init__()
    self.Nc = Nc
    self.Nd = Nd

    # Variational distribution variables for the means: mu ~ Normal(locs, scales)
    self.locs = Variable(torch.normal(10 * torch.zeros((Nc, Nd)), 1),
                         requires_grad=True)
    self.scales = Variable(torch.pow(Gamma(5, 5).rsample((Nc, Nd)), -0.5),
                           requires_grad=True)  # ??

    # Variational distribution variables for the standard deviations:
    # sigma ~ Gamma(alpha, beta)
    self.alpha = Variable(torch.rand(Nc, Nd) * 2 + 4,
                          requires_grad=True)  # 4 is a hyperparameter
    self.beta = Variable(torch.rand(Nc, Nd) * 2 + 4,
                         requires_grad=True)  # 4 is a hyperparameter

    # Variational distribution variables for the component weights: theta ~ Dir(c)
    self.couts = Variable(2 * torch.ones((Nc,)),
                          requires_grad=True)  # 2 is a hyperparameter

    # Prior distribution for the means
    self.mu_prior = Normal(torch.zeros((Nc, Nd)), torch.ones((Nc, Nd)))
    # Prior distribution for the standard deviations
    self.sigma_prior = Gamma(5 * torch.ones((Nc, Nd)),
                             5 * torch.ones((Nc, Nd)))
    # Prior distribution for the component weights
    self.theta_prior = Dirichlet(5 * torch.ones((Nc,)))  # uniform 0.2 * 5

def __init__(self, graph_location: "str", color_count: "int"):
    self.episode = 0
    self.episode_cond = None
    self._color_count = color_count
    self._game = ColoredGraphGame(graph_location=graph_location)
    self._model = MCTS()
    self._model.initiate_sample(self._game.graph())
    _tmp_dirichlet = 10 / (
        self._color_count * self._game._colored_graph.vertex_count)
    self._dirichlet = Dirichlet(
        torch.tensor([_tmp_dirichlet for _ in range(self._color_count)]))
    self._resign_treshold = float("inf")
    self._count_uncolored_vertices = 0

def forward(self, xs=None, nData=None):
    """
    :param xs: list (batch) of strings
    """
    batch_size = len(xs)

    # Add any values waiting to be added
    if len(self.add_queue) > 0:
        replace_idxs = [idx for idx in np.argpartition(
            self.mu.data, len(self.add_queue))][:len(self.add_queue)]
        for idx, value in zip(replace_idxs, self.add_queue):
            self.values[idx] = value
        self.add_queue.clear()

    # Sample weights
    dist = NormalLogSoftmax(self.mu.unsqueeze(0).repeat(batch_size, 1),
                            self.sigma.unsqueeze(0).repeat(batch_size, 1))
    log_pi = dist.rsample()
    p_pi = Dirichlet(1 * self.t.new_ones(batch_size, self.T + 1)).log_prob(
        log_pi.exp())
    jacobian = -log_pi.sum(dim=1)  # todo: double check this
    p_log_pi = p_pi - jacobian
    q_log_pi = dist.log_prob(log_pi)

    # Calculate probabilities from the mixture distribution and the base distribution
    component_probs = self.getComponentProbs(xs, log_pi)
    base_probs = log_pi[:, -1] + self.base(xs)

    # Total
    conditional = logsumexp(
        torch.cat([component_probs[:, None], base_probs[:, None]], dim=1))

    # Variational bound
    score = ((p_log_pi - q_log_pi) / nData + conditional).mean()
    score += (self.base(self.values).sum() / nData).mean()
    return score

def __init__(self, concentration1, concentration0, validate_args=None):
    if isinstance(concentration1, Number) and isinstance(concentration0, Number):
        concentration1_concentration0 = torch.tensor(
            [float(concentration1), float(concentration0)])
    else:
        concentration1, concentration0 = broadcast_all(concentration1,
                                                       concentration0)
        concentration1_concentration0 = torch.stack(
            [concentration1, concentration0], -1)
    self._dirichlet = Dirichlet(concentration1_concentration0)
    super(Beta, self).__init__(self._dirichlet._batch_shape,
                               validate_args=validate_args)

def __init__(self, concentration1, concentration0):
    if isinstance(concentration1, Number) and isinstance(concentration0, Number):
        concentration1_concentration0 = variable([concentration1,
                                                  concentration0])
    else:
        concentration1, concentration0 = broadcast_all(concentration1,
                                                       concentration0)
        concentration1_concentration0 = torch.stack(
            [concentration1, concentration0], -1)
    self._dirichlet = Dirichlet(concentration1_concentration0)
    super(Beta, self).__init__(self._dirichlet._batch_shape)

class Beta(Distribution):
    r"""
    Creates a Beta distribution parameterized by concentration `alpha` and `beta`.

    Example::

        >>> m = Beta(torch.Tensor([0.5]), torch.Tensor([0.5]))
        >>> m.sample()  # Beta distributed with concentration alpha and beta
         0.1046
        [torch.FloatTensor of size 1]

    Args:
        alpha (Tensor or Variable): 1st concentration parameter of the distribution
        beta (Tensor or Variable): 2nd concentration parameter of the distribution
    """
    params = {'alpha': constraints.positive, 'beta': constraints.positive}
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, alpha, beta):
        if isinstance(alpha, Number) and isinstance(beta, Number):
            alpha_beta = torch.Tensor([alpha, beta])
        else:
            alpha, beta = broadcast_all(alpha, beta)
            alpha_beta = torch.stack([alpha, beta], -1)
        self._dirichlet = Dirichlet(alpha_beta)
        super(Beta, self).__init__(self._dirichlet._batch_shape)

    def rsample(self, sample_shape=()):
        value = self._dirichlet.rsample(sample_shape).select(-1, 0)
        if isinstance(value, Number):
            value = self._dirichlet.alpha.new([value])
        return value

    def log_prob(self, value):
        self._validate_log_prob_arg(value)
        heads_tails = torch.stack([value, 1.0 - value], -1)
        return self._dirichlet.log_prob(heads_tails)

    def entropy(self):
        return self._dirichlet.entropy()

class Beta(Distribution):
    r"""
    Beta distribution parameterized by `concentration1` and `concentration0`.

    Example::

        >>> m = Beta(torch.Tensor([0.5]), torch.Tensor([0.5]))
        >>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
         0.1046
        [torch.FloatTensor of size 1]

    Args:
        concentration1 (float or Tensor or Variable): 1st concentration parameter
            of the distribution (often referred to as alpha)
        concentration0 (float or Tensor or Variable): 2nd concentration parameter
            of the distribution (often referred to as beta)
    """
    params = {'concentration1': constraints.positive,
              'concentration0': constraints.positive}
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0):
        if isinstance(concentration1, Number) and isinstance(concentration0, Number):
            concentration1_concentration0 = variable([concentration1,
                                                      concentration0])
        else:
            concentration1, concentration0 = broadcast_all(concentration1,
                                                           concentration0)
            concentration1_concentration0 = torch.stack([concentration1,
                                                         concentration0], -1)
        self._dirichlet = Dirichlet(concentration1_concentration0)
        super(Beta, self).__init__(self._dirichlet._batch_shape)

    @property
    def mean(self):
        return self.concentration1 / (self.concentration1 + self.concentration0)

    @property
    def variance(self):
        total = self.concentration1 + self.concentration0
        return (self.concentration1 * self.concentration0 /
                (total.pow(2) * (total + 1)))

    def rsample(self, sample_shape=()):
        value = self._dirichlet.rsample(sample_shape).select(-1, 0)
        if isinstance(value, Number):
            value = self._dirichlet.concentration.new([value])
        return value

    def log_prob(self, value):
        self._validate_log_prob_arg(value)
        heads_tails = torch.stack([value, 1.0 - value], -1)
        return self._dirichlet.log_prob(heads_tails)

    def entropy(self):
        return self._dirichlet.entropy()

    @property
    def concentration1(self):
        result = self._dirichlet.concentration[..., 0]
        if isinstance(result, Number):
            return torch.Tensor([result])
        else:
            return result

    @property
    def concentration0(self):
        result = self._dirichlet.concentration[..., 1]
        if isinstance(result, Number):
            return torch.Tensor([result])
        else:
            return result

class Beta(ExponentialFamily):
    r"""
    Beta distribution parameterized by `concentration1` and `concentration0`.

    Example::

        >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
        >>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
        tensor([ 0.1046])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the
            distribution (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the
            distribution (often referred to as beta)
    """
    arg_constraints = {'concentration1': constraints.positive,
                       'concentration0': constraints.positive}
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        if isinstance(concentration1, Number) and isinstance(concentration0, Number):
            concentration1_concentration0 = torch.tensor([float(concentration1),
                                                          float(concentration0)])
        else:
            concentration1, concentration0 = broadcast_all(concentration1,
                                                           concentration0)
            concentration1_concentration0 = torch.stack([concentration1,
                                                         concentration0], -1)
        self._dirichlet = Dirichlet(concentration1_concentration0)
        super(Beta, self).__init__(self._dirichlet._batch_shape,
                                   validate_args=validate_args)

    @property
    def mean(self):
        return self.concentration1 / (self.concentration1 + self.concentration0)

    @property
    def variance(self):
        total = self.concentration1 + self.concentration0
        return (self.concentration1 * self.concentration0 /
                (total.pow(2) * (total + 1)))

    def rsample(self, sample_shape=()):
        value = self._dirichlet.rsample(sample_shape).select(-1, 0)
        if isinstance(value, Number):
            value = self._dirichlet.concentration.new_tensor(value)
        return value

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        heads_tails = torch.stack([value, 1.0 - value], -1)
        return self._dirichlet.log_prob(heads_tails)

    def entropy(self):
        return self._dirichlet.entropy()

    @property
    def concentration1(self):
        result = self._dirichlet.concentration[..., 0]
        if isinstance(result, Number):
            return torch.tensor([result])
        else:
            return result

    @property
    def concentration0(self):
        result = self._dirichlet.concentration[..., 1]
        if isinstance(result, Number):
            return torch.tensor([result])
        else:
            return result

    @property
    def _natural_params(self):
        return (self.concentration1, self.concentration0)

    def _log_normalizer(self, x, y):
        return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)

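# Usage of the final (current torch.distributions) Beta above: mean matches
# the analytic alpha / (alpha + beta), and rsample is reparameterized through
# the underlying Dirichlet, so gradients flow into the concentrations.
import torch
from torch.distributions import Beta

a = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([3.0], requires_grad=True)
m = Beta(a, b)

assert torch.allclose(m.mean, a / (a + b))
x = m.rsample()     # pathwise sample, differentiable w.r.t. a and b
x.sum().backward()  # populates a.grad and b.grad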