def select_action(args, logits, status='train', exploration=True, info={}):
    if args.continuous:
        act_mean = logits
        act_std = cuda_wrapper(torch.ones_like(act_mean), args.cuda)
        if status == 'train':
            return Normal(act_mean, act_std).sample()
        elif status == 'test':
            return act_mean
    else:
        if status == 'train':
            if exploration:
                if args.epsilon_softmax:
                    # Epsilon-smoothed softmax policy.
                    eps = info['softmax_eps']
                    p_a = (1 - eps) * torch.softmax(logits, dim=-1) + eps / logits.size(-1)
                    return OneHotCategorical(logits=None, probs=p_a).sample()
                elif args.gumbel_softmax:
                    return GumbelSoftmax(logits=logits).sample()
                else:
                    return OneHotCategorical(logits=logits).sample()
            else:
                if args.gumbel_softmax:
                    temperature = 1.0
                    return torch.softmax(logits / temperature, dim=-1)
                else:
                    return OneHotCategorical(logits=logits).sample()
        elif status == 'test':
            # Greedy one-hot action.
            p_a = torch.softmax(logits, dim=-1)
            return (p_a == torch.max(p_a, dim=-1, keepdim=True)[0]).float()
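# Usage sketch for select_action (illustrative only). The argparse-style `args`
# namespace below is hypothetical; the original module is assumed to provide the
# Normal/OneHotCategorical imports plus `cuda_wrapper` and `GumbelSoftmax`.
import argparse
import torch

args = argparse.Namespace(continuous=False, cuda=False,
                          epsilon_softmax=False, gumbel_softmax=False)
logits = torch.randn(4, 6)                                   # [batch, n_actions]
train_action = select_action(args, logits, status='train')   # stochastic one-hot sample
test_action = select_action(args, logits, status='test')     # greedy one-hot action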
def __init__(self):
    super(Transform5, self).__init__()
    # Fixed 3x3 sharpening kernel.
    kernel = [[[[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]]]]
    kernel = torch.from_numpy(np.array(kernel)).float()
    self.conv1 = nn.Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    self.conv1.weight = nn.Parameter(kernel)
    self.conv_trans1 = nn.Conv2d(1, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    self.conv_trans2 = nn.Conv2d(1, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    self.conv_trans3 = nn.Conv2d(1, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    self.conv_trans4 = nn.Conv2d(1, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    # 3x3 box (mean) filter.
    self.conv_smooth = nn.Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    self.conv_smooth.weight = nn.Parameter(torch.ones(9).cuda().view(1, 1, 3, 3) * 1 / 9.)
    self.drop = nn.Dropout(p=0.05)
    self.relu = nn.ReLU(inplace=True)
    self.one_hot1 = OneHotCategorical(torch.Tensor([0.6, 0.4]))
    # 25-way categorical; the center entry (index 12) has zero probability.
    self.one_hot2 = OneHotCategorical(torch.Tensor(
        [1 / 24.] * 12 + [0.000] + [1 / 24.] * 12))
    # All parameters are fixed transforms, not learned.
    for param in self.parameters():
        param.requires_grad = False
def _prior(self, inference='sample'):
    device = self.z2hdec.weight.device
    batchSize = self.images.shape[0]
    if self.varType == 'cont':
        if inference == 'sample':
            sample = torch.randn(batchSize, self.latentSize, device=device)
        else:
            sample = torch.zeros(batchSize, self.latentSize, device=device)
        mean = torch.zeros_like(sample)
        logv = torch.ones_like(sample)
        z = sample, (mean, logv)
    elif self.varType == 'gumbel':
        K = self.num_embeddings
        V = self.num_vars
        # Uniform categorical prior over the K embeddings, repeated for each of the V variables.
        prior_probs = torch.tensor([1 / K] * K, dtype=torch.float, device=device)
        logprior = torch.log(prior_probs)
        if inference == 'sample':
            prior = OneHotCategorical(prior_probs)
            z_vs = prior.sample(sample_shape=(batchSize * V,))
            z = z_vs.reshape([batchSize, -1])
        else:
            z_vs = prior_probs.expand(batchSize * V, -1)
            z = z_vs.reshape([batchSize, -1])
        z = (z, logprior)
    elif self.varType == 'none':
        raise Exception('Z has no prior for varType==none')
    return z
def sample_stroke(self, pis, mus, sigmas, rhos, qs, gamma):
    """
    Input:
        pis    [batch, seq_len, M]
        mus    [batch, seq_len, M, 2]
        sigmas [batch, seq_len, M, 2]
        rhos   [batch, seq_len, M]
        qs     [batch, seq_len, 3]
    Output:
        strokes [batch, seq_len, 5]
    """
    batch_size, seq_len, M = pis.size()
    strokes = []
    sigmas = sigmas * gamma
    # Sample for each sketch
    for i in range(batch_size):
        # Pick one mixture component per time step.
        comp_m = OneHotCategorical(logits=pis[i, :, :])
        comp_choice = (comp_m.sample() == 1)
        mu, sigma, rho, q = (mus[i, :, :, :][comp_choice],
                             sigmas[i, :, :, :][comp_choice],
                             rhos[i, :, :][comp_choice],
                             qs[i, :, :])
        # Build the 2x2 covariance of each selected Gaussian component.
        cov = torch.stack([
            torch.diag(sigma[j] * sigma[j])
            + (1 - torch.eye(2).to(mu.device)) * rho[j] * torch.prod(sigma[j])
            for j in range(seq_len)
        ]).to(device=mu.device)
        normal_m = MultivariateNormal(mu, cov)
        stroke_move = normal_m.sample().to(pis.device)  # [seq_len, 2]
        pen_states = (q == q.max(dim=1, keepdim=True)[0]).to(dtype=torch.float)  # [seq_len, 3]
        stroke = torch.cat([stroke_move, pen_states], dim=1).to(pis.device)
        strokes.append(stroke)
    return torch.stack(strokes)
def sample_single_stroke(self, pis, mus, sigmas, rhos, qs, gamma):
    """
    Input:
        pis    [M]
        mus    [M, 2]
        sigmas [M, 2]
        rhos   [M]
        qs     [3]
    Output:
        stroke [5]
    """
    # Pick one mixture component.
    comp_m = OneHotCategorical(logits=pis)
    comp_choice = (comp_m.sample() == 1)
    mu, sigma, rho, q = (mus[comp_choice].view(-1),
                         sigmas[comp_choice].view(-1),
                         rhos[comp_choice].view(-1),
                         qs.view(-1))
    # 2x2 covariance of the selected Gaussian component.
    cov = (torch.diag(sigma * sigma)
           + (1 - torch.eye(2).to(mu.device)) * rho * torch.prod(sigma)).to(device=mu.device)
    normal_m = MultivariateNormal(mu, cov)
    stroke_move = normal_m.sample().to(pis.device)  # [2]
    pen_states = (q == q.max(dim=0, keepdim=True)[0]).to(dtype=torch.float)  # [3]
    stroke = torch.cat([stroke_move.view(-1), pen_states.view(-1)], dim=0).to(pis.device)
    return stroke
def distributions(self):
    """Return the one-hot categorical and normal prior distributions."""
    from torch.distributions import Normal
    from torch.distributions.one_hot_categorical import OneHotCategorical
    pz = Normal(torch.zeros([self.cz]), torch.ones([self.cz]))
    py = OneHotCategorical(probs=torch.ones([self.cy]) / self.cy)
    return py, pz
def __init__(self):
    super(Transform4, self).__init__()
    self.conv_trans = nn.Conv2d(3, 3, kernel_size=(5, 5), stride=(1, 1),
                                padding=(2, 2), bias=False, groups=3)
    # 25-way categorical; the center entry (index 12) has zero probability.
    self.one_hot2 = OneHotCategorical(torch.Tensor(
        [1 / 24.] * 12 + [0.000] + [1 / 24.] * 12))
def sample(self, probas, beta=1):
    batch_size = probas[0].size(0)
    phi = beta * sum([p.view(batch_size, self.q, self.N) for p in probas])
    phi += self.linear.weights.view(1, self.q, self.N)
    self.phi = phi
    # Softmax over the q states at each of the N sites, then sample one-hot states.
    distribution = OneHotCategorical(probs=F.softmax(phi, 1).permute(0, 2, 1))
    return distribution.sample().permute(0, 2, 1)
def generate_spikes(self, neurons_group):
    # Prepend a zero-potential column before the softmax; the sampled first column
    # is dropped when writing to the spiking history.
    spikes = OneHotCategorical(
        torch.softmax(torch.cat(
            (torch.zeros([len(neurons_group), 1]).to(self.device),
             self.potential[neurons_group - self.n_input_neurons]),
            dim=-1), dim=-1)).sample()
    self.spiking_history[neurons_group, :, -1] = spikes[:, 1:].to(self.device)
def forward(self, logits, training=True, temperature=None):
    # Gumbel-Softmax (training and evaluation)
    if temperature is not None:
        return F.gumbel_softmax(logits, hard=not training, tau=temperature)
    # Softmax (training)
    elif training:
        return F.softmax(logits, dim=1)
    # Hard one-hot sample (evaluation)
    else:
        return OneHotCategorical(logits=logits).sample()
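# Standalone sketch of the three paths taken by the forward above (assumes plain torch).
import torch
import torch.nn.functional as F
from torch.distributions import OneHotCategorical

logits = torch.randn(4, 10)
gumbel = F.gumbel_softmax(logits, tau=0.5, hard=False)   # temperature given: Gumbel-Softmax
soft = F.softmax(logits, dim=1)                          # training without temperature
hard = OneHotCategorical(logits=logits).sample()         # evaluation: hard one-hot sample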
def forward(self, inputs):
    batch_size = inputs.size(0)
    h, c = self.lstm(
        inputs.expand(self.len_msg, batch_size, self.embedding_size).transpose(0, 1))
    logits = self.linear(h)
    dists_speaker = OneHotCategorical(logits=F.log_softmax(logits, dim=2))
    return dists_speaker
def forward(self, x):
    h1 = self.act(self.fc1(x))
    h2 = self.act(self.fc2(h1))
    h3 = self.act(self.fc3(h2))
    h4 = self.act(self.fc4(h3))
    out = self.fc5(h4)
    # Two independent categorical messages: the first NG1 logits and the rest.
    probs1 = F.softmax(out[:, :NG1], dim=1)
    probs2 = F.softmax(out[:, NG1:], dim=1)
    distr1 = OneHotCategorical(probs=probs1)
    distr2 = OneHotCategorical(probs=probs2)
    msg_oht1 = distr1.sample()
    msg_oht2 = distr2.sample()
    self.get_log_probs = torch.log((probs1 * msg_oht1).sum(1)) + torch.log((probs2 * msg_oht2).sum(1))
    self.get_entropy = distr2.entropy()
    msg1 = msg_oht1.argmax(1)
    msg2 = msg_oht2.argmax(1)
    msgs_value = torch.cat((msg1.unsqueeze(1), msg2.unsqueeze(1)), dim=1)
    return out, msgs_value
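# Note: torch.log((probs * msg_oht).sum(1)) above equals distr.log_prob(msg_oht) for a
# one-hot sample. A minimal standalone check (plain torch, illustrative sizes):
import torch
from torch.distributions import OneHotCategorical

probs = torch.softmax(torch.randn(3, 7), dim=1)
d = OneHotCategorical(probs=probs)
m = d.sample()
assert torch.allclose(torch.log((probs * m).sum(1)), d.log_prob(m))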
def get_action(self, state):
    all_hp_probs, all_anchor_probs = self.forward(state)
    all_anchor_act, all_hp_act = [], []
    for layer_anchor_probs in all_anchor_probs:
        anchor_sampler = Bernoulli(layer_anchor_probs)
        layer_anchor_act = anchor_sampler.sample()
        all_anchor_act.append(layer_anchor_act)
    for hp_probs in all_hp_probs:
        sampler = OneHotCategorical(logits=hp_probs)
        all_hp_act.append(sampler.sample())
    return all_hp_act, all_anchor_act
def construct_samples(executor_reply):
    onehot_reply = {}
    sample_reply = {}
    for (k, v) in executor_reply.items():
        one_hot = OneHotCategorical(v).sample()
        onehot_reply[k] = one_hot
        sample_reply["sample_" + k] = Categorical(one_hot).sample()
    pre_log_prob_sum, pre_log_probs = compute_log_prob(sample_reply, executor_reply)
    return pre_log_prob_sum, pre_log_probs, onehot_reply, sample_reply
def build_diayn(n_skills=4, env_name="MountainCar-v0", alpha=.1): ''' :param n_skills: :param env_name: "MountainCar-v0" or "Navigation2D" :return: ''' env = gym.make(env_name) alpha = .1 gamma = .9 prior = OneHotCategorical(torch.ones((1, n_skills))) hidden_sizes = {s: [30, 30] for s in ("actor", "discriminator", "critic")} model = diayn.DIAYN(env, prior, hidden_sizes, alpha=alpha, gamma=gamma) return model
def sample_z(args):
    # Generate samples from the prior.
    z_cat = OneHotCategorical(
        logits=torch.zeros(args.batch_size, args.cat_dim)).sample()
    z_noise = dist.Uniform(-1, 1).sample(
        torch.Size((args.batch_size, args.noise_dim)))
    z_cont = dist.Uniform(-1, 1).sample(
        torch.Size((args.batch_size, args.cont_dim)))
    # Concatenate the incompressible noise, discrete latents, and continuous latents.
    z = torch.cat([z_noise, z_cat, z_cont], dim=1)
    return (z.to(args.device), z_cat.to(args.device),
            z_noise.to(args.device), z_cont.to(args.device))
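# Minimal sketch of the latent pieces combined by sample_z above. The sizes
# (batch_size=8, cat_dim=10, noise_dim=62, cont_dim=2) are stand-ins, not values
# taken from the original args.
import torch
import torch.distributions as dist
from torch.distributions import OneHotCategorical

z_cat = OneHotCategorical(logits=torch.zeros(8, 10)).sample()   # discrete latent code
z_noise = dist.Uniform(-1, 1).sample(torch.Size((8, 62)))       # incompressible noise
z_cont = dist.Uniform(-1, 1).sample(torch.Size((8, 2)))         # continuous latent codes
z = torch.cat([z_noise, z_cat, z_cont], dim=1)                  # [8, 74]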
def generate_input(self) -> torch.Tensor: noise_input = torch.randn(self.feature_spec['noise']) categorical_input = [] categorical_labels = [] for n_cat in self.feature_spec['categorical']: categorical_input.append(OneHotCategorical(torch.ones(n_cat)/n_cat).sample()) categorical_labels.append(torch.argmax(categorical_input[-1])) categorical_input = torch.hstack(categorical_input) gaussian_input = torch.randn(self.feature_spec['gaussian']) uniform_input = Uniform(-1, 1).sample((self.feature_spec['uniform'], )) gen_input = torch.hstack([noise_input, categorical_input, gaussian_input, uniform_input]) gen_input = gen_input.to(self.device) return gen_input, torch.tensor(categorical_labels)
def forward(self, rating_matrix):
    cores = F.normalize(self.k_embedding.weight, dim=1)
    items = F.normalize(self.item_embedding.weight, dim=1)

    rating_matrix = F.normalize(rating_matrix)
    rating_matrix = F.dropout(rating_matrix, self.drop_out, training=self.training)

    cates_logits = torch.matmul(items, cores.transpose(0, 1)) / self.tau

    if self.nogb:
        cates = torch.softmax(cates_logits, dim=1)
    else:
        cates_dist = OneHotCategorical(logits=cates_logits)
        cates_sample = cates_dist.sample()
        cates_mode = torch.softmax(cates_logits, dim=1)
        # Hard one-hot samples during training, soft assignments at evaluation.
        cates = (self.training * cates_sample + (1 - self.training) * cates_mode)

    probs = None
    mulist = []
    logvarlist = []
    for k in range(self.kfac):
        cates_k = cates[:, k].reshape(1, -1)
        # encoder
        x_k = rating_matrix * cates_k
        h = self.encoder(x_k)
        mu = h[:, :self.embedding_size]
        mu = F.normalize(mu, dim=1)
        logvar = h[:, self.embedding_size:]

        mulist.append(mu)
        logvarlist.append(logvar)

        z = self.reparameterize(mu, logvar)

        # decoder
        z_k = F.normalize(z, dim=1)
        logits_k = torch.matmul(z_k, items.transpose(0, 1)) / self.tau
        probs_k = torch.exp(logits_k)
        probs_k = probs_k * cates_k
        probs = (probs_k if (probs is None) else (probs + probs_k))

    logits = torch.log(probs)
    return logits, mulist, logvarlist
def log_ce_with_pg(pred, truth, r, b):
    # (bs, t, c)
    all_hp_pred, all_prob_anchors = pred
    all_hp_act, all_act_anchors = truth
    loss = torch.FloatTensor([0.0])
    for hp_pred, hp_act in list(zip(all_hp_pred, all_hp_act)):
        target = hp_act.detach()
        sampler = OneHotCategorical(logits=hp_pred)
        # REINFORCE-style term: negative log-prob weighted by (b - r).
        l = torch.mean(torch.sum(-sampler.log_prob(target), dim=-1) * (b - r))
        loss += l
    for anchors_pred, anchors_act in list(zip(all_prob_anchors, all_act_anchors)):
        target = anchors_act.detach()
        sampler = Bernoulli(logits=anchors_pred)
        l = torch.mean(torch.sum(-sampler.log_prob(target), dim=-1) * (b - r))
        loss += l
    return loss
def gumbel_softmax(self, logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, return a hard one-hot sample drawn from the soft probabilities
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, the returned sample is one-hot; otherwise it is a
        probability distribution that sums to 1 across classes.
    """
    prob = self.gumbel_softmax_sample(logits, temperature)
    if hard:
        sampler = OneHotCategorical(prob)
        prob = sampler.sample()
    return prob
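# Standalone sketch of hard vs. soft Gumbel-Softmax sampling (assumes plain torch;
# torch.nn.functional.gumbel_softmax stands in for self.gumbel_softmax_sample here,
# and its hard=True path uses a straight-through argmax rather than a categorical draw).
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                              # [batch_size, n_class]
soft = F.gumbel_softmax(logits, tau=0.5)                 # probabilities summing to 1
hard = F.gumbel_softmax(logits, tau=0.5, hard=True)      # one-hot forward, soft gradient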
def build_diayn(n_skills=4, env_name="MountainCar-v0", alpha=0.1, gamma=0.1, seed = 101): ''' :param n_skills: :param env_name: "MountainCar-v0" or "Navigation2D" :alpha=0.1, :gamma=0.1, :seed = 101 :return: ''' if env_name == "Navigation2D" : env = Navigation2D(20) else : env = gym.make(env_name) prior = OneHotCategorical(torch.ones((1, n_skills))) hidden_sizes = {s: [30, 30] for s in ("actor", "discriminator", "critic")} model = DIAYN(env, prior, hidden_sizes, alpha, gamma, seed = seed) return model
def cat_softmax(probs, mode, tau=1, hard=False, dim=-1):
    if mode == 'REINFORCE' or mode == 'SCST':
        cat_distr = OneHotCategorical(probs=probs)
        return cat_distr.sample(), cat_distr.entropy()
    elif mode == 'GUMBEL':
        cat_distr = RelaxedOneHotCategorical(tau, probs=probs)
        y_soft = cat_distr.rsample()

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(probs, device=DEVICE).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret, ret
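# Standalone sketch of the two sampling modes handled by cat_softmax above
# (DEVICE is defined elsewhere in the original module; plain CPU tensors are used here).
import torch
from torch.distributions import OneHotCategorical, RelaxedOneHotCategorical

probs = torch.softmax(torch.randn(2, 5), dim=-1)

# REINFORCE / SCST: non-differentiable one-hot sample plus entropy for the policy loss.
d = OneHotCategorical(probs=probs)
sample, entropy = d.sample(), d.entropy()

# GUMBEL: differentiable relaxed sample via rsample().
relaxed = RelaxedOneHotCategorical(temperature=1.0, probs=probs).rsample()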
def recon_sketches(self, pis, mus, sigmas, rhos, qs, gamma):
    """
    Input:
        pis    [batch, seq_len, M]
        mus    [batch, seq_len, M, 2]
        sigmas [batch, seq_len, M, 2]
        rhos   [batch, seq_len, M]
        qs     [batch, seq_len, 3]
    Output:
        sketches [batch, seq_len, 5]
    """
    batch_size, seq_len, M = pis.size()
    sketches = []
    sigmas = sigmas * gamma
    # Sample for each sketch
    for i in range(batch_size):
        strokes = []
        for j in range(seq_len):
            # Pick one mixture component for this time step.
            comp_m = OneHotCategorical(logits=pis[i, j])
            comp_choice = (comp_m.sample() == 1)
            mu, sigma, rho, q = (mus[i, j][comp_choice].view(-1),
                                 sigmas[i, j][comp_choice].view(-1),
                                 rhos[i, j][comp_choice].view(-1),
                                 qs[i, j].view(-1))
            # 2x2 covariance of the selected Gaussian component.
            cov = (torch.diag(sigma * sigma)
                   + (1 - torch.eye(2).to(mu.device)) * rho * torch.prod(sigma)).to(device=mu.device)
            normal_m = MultivariateNormal(mu, cov)
            stroke_move = normal_m.sample().to(pis.device)  # [2]
            pen_states = (q == q.max(dim=0, keepdim=True)[0]).to(dtype=torch.float)  # [3]
            stroke = torch.cat([stroke_move.view(-1), pen_states.view(-1)], dim=0).to(pis.device)
            strokes.append(stroke)
        sketches.append(torch.stack(strokes))
    return torch.stack(sketches)
def decoding_sampler(logits, mode, tau=1, hard=False, dim=-1):
    if mode == 'REINFORCE' or mode == 'SCST':
        cat_distr = OneHotCategorical(logits=logits)
        return cat_distr.sample()
    elif mode == 'GUMBEL':
        cat_distr = RelaxedOneHotCategorical(tau, logits=logits)
        y_soft = cat_distr.rsample()
    elif mode == 'SOFTMAX':
        y_soft = F.softmax(logits, dim=1)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, device=args.device).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret
def forward(self, tgt_x):
    batch_size = tgt_x.shape[0]
    tgt_hid = self.x_to_embd(tgt_x)
    lstm_input = torch.zeros((batch_size, NG1)).cuda()
    lstm_hid = tgt_hid.squeeze(1)
    lstm_cell = tgt_hid.squeeze(1)
    msgs = []
    msgs_value = []
    logits = []
    log_probs = 0.
    for _ in range(2):
        lstm_hid, lstm_cell = self.lstm(lstm_input, (lstm_hid, lstm_cell))
        logit = self.out_layer(lstm_hid)
        logits.append(logit)
        probs = nn.functional.softmax(logit, dim=1)
        if self.training:
            # Sample a one-hot symbol and keep its entropy (exposed via self.get_entropy).
            cat_distr = OneHotCategorical(probs=probs)
            msg_oht, entropy = cat_distr.sample(), cat_distr.entropy()
            self.get_entropy = entropy
        else:
            # Greedy decoding at evaluation time.
            msg_oht = nn.functional.one_hot(
                torch.argmax(probs, dim=1), num_classes=self.out_size).float()
        log_probs += torch.log((probs * msg_oht).sum(1))
        msgs.append(msg_oht)
        msgs_value.append(msg_oht.argmax(1))
        lstm_input = msg_oht
    msgs = torch.stack(msgs)
    msgs_value = torch.stack(msgs_value).transpose(0, 1)
    logits = torch.stack(logits)
    logits = logits.transpose(0, 1).reshape(batch_size, -1)
    self.get_log_probs = log_probs
    return logits, msgs_value
def getNdiracs(data, N, sparse=False, flat=False, replace=True):
    if not sparse:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        graphcount = data.num_nodes        # number of graphs in data/batch object
        totalnodecount = data.x.shape[1]   # number of total nodes for each graph
        actualnodecount = 0                # cumulative number of nodes
        diracmatrix = torch.zeros((graphcount, totalnodecount, N), device=device)  # matrix with dirac pulses

        for k in range(graphcount):
            graph_nodes = data.mask[k].sum()  # number of nodes in the graph
            actualnodecount += graph_nodes    # might not need this, we'll see
            probabilities = torch.ones((graph_nodes.item(), 1), device=device) / graph_nodes  # uniform probs
            node_distribution = OneHotCategorical(probs=probabilities.squeeze())
            node_sample = node_distribution.sample(sample_shape=(N,))
            # concat zeros to fit dataset shape
            node_sample = torch.cat(
                (node_sample, torch.zeros((N, totalnodecount - node_sample.shape[1]), device=device)), -1)
            # add everything to the final matrix
            diracmatrix[k, :] = torch.transpose(node_sample, dim0=-1, dim1=-2)

        return diracmatrix

    else:
        original_batch_index = data.batch
        original_edge_index = data.edge_index
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        graphcount = data.num_graphs  # number of graphs in data/batch object
        diracmatrix = torch.zeros(0, device=device)
        batch_prime = torch.zeros(0, device=device).long()
        locationmatrix = torch.zeros(0, device=device).long()

        global_offset = 0
        for k in range(graphcount):
            graph_nodes = (data.batch == k).sum()

            # for diracmatrix: place N dirac pulses on uniformly chosen nodes of graph k
            randInt = np.random.choice(range(graph_nodes), N, replace=replace)
            node_sample = torch.zeros(N * graph_nodes, device=device)
            offs = torch.arange(N, device=device) * graph_nodes
            dirac_locations = (offs + torch.from_numpy(randInt).to(device))
            node_sample[dirac_locations] = 1

            dirac_locations2 = torch.from_numpy(randInt).to(device) + global_offset
            global_offset += graph_nodes

            diracmatrix = torch.cat((diracmatrix, node_sample), 0)
            locationmatrix = torch.cat((locationmatrix, dirac_locations2), 0)

            # for batch prime
            dirac_indices = torch.arange(N, device=device).unsqueeze(-1).expand(-1, graph_nodes).contiguous().view(-1)
            dirac_indices = dirac_indices.long()
            dirac_indices += k * N
            batch_prime = torch.cat((batch_prime, dirac_indices))

        # Replicate the edge index N times, offsetting node ids by the batch size for each copy.
        offset = torch.arange(N).unsqueeze(-1).expand(-1, data.edge_index.size()[1]).contiguous().view(-1) * data.batch.shape[0]
        offset_2 = torch.cat(2 * [offset.unsqueeze(0)], dim=0)
        edge_index_prime = torch.cat(N * [data.edge_index], dim=1) + offset_2

        normalization_indices = data.batch.unsqueeze(-1).expand(-1, N).contiguous().view(-1).to(device)

        return Batch(batch=batch_prime, x=diracmatrix, edge_index=edge_index_prime, y=data.y,
                     locations=locationmatrix, norm_index=normalization_indices,
                     batch_old=original_batch_index, edge_index_old=original_edge_index)
def sample_action_with_prob(self, x):
    likelihood = self.forward(x)
    m = OneHotCategorical(likelihood)
    action = m.sample()
    return action, m.log_prob(action)
def get_kl(dists, P, eps=1e-4, eta=1e-20, Fdiv='kl'):
    """Get KL divergences for different distributions."""
    kl = 0.
    for dst in dists:
        name = dst['name']
        family = dst['family']
        if family == 'gaussian' or family == 'mv_normal':
            if P.wn:
                lambda_r_mu = normalize_weights(P=P, name=name, prop='center')
                lambda_r_scale = normalize_weights(P=P, name=name, prop='scale')  # noqa
            else:
                attr_name = '{}_{}'.format(name, 'center')
                lambda_r_mu = getattr(P, attr_name)
                attr_name = '{}_{}'.format(name, 'scale')
                lambda_r_scale = getattr(P, attr_name)
            lambda_0_mu = dst['lambda_0']
            lambda_0_scale = dst['lambda_0_scale']
            lambda_r_dist = MultivariateNormal(loc=lambda_r_mu, covariance_matrix=lambda_r_scale)  # noqa
            lambda_0_dist = MultivariateNormal(loc=lambda_0_mu, covariance_matrix=lambda_0_scale)  # noqa
        elif family == 'low_mv_normal':
            if P.wn:
                lambda_r_mu = normalize_weights(P=P, name=name, prop='center')
                lambda_r_scale = normalize_weights(P=P, name=name, prop='scale')  # noqa
                lambda_r_factor = normalize_weights(P=P, name=name, prop='factor')  # noqa
            else:
                attr_name = '{}_{}'.format(name, 'center')
                lambda_r_mu = getattr(P, attr_name)
                attr_name = '{}_{}'.format(name, 'scale')
                lambda_r_scale = getattr(P, attr_name)
                attr_name = '{}_{}'.format(name, 'factor')
                lambda_r_factor = getattr(P, attr_name)
            lambda_0_mu = dst['lambda_0']
            lambda_0_scale = dst['lambda_0_scale']
            lambda_0_factor = dst['lambda_0_factor']
            lambda_r_dist = LowRankMultivariateNormal(loc=lambda_r_mu, cov_diag=lambda_r_scale, cov_factor=lambda_r_factor)  # noqa
            lambda_0_dist = LowRankMultivariateNormal(loc=lambda_0_mu, cov_diag=lambda_0_scale, cov_factor=lambda_0_factor)  # noqa
        elif family == 'normal' or family == 'abs_normal' or family == 'cnormal':
            if P.wn:
                lambda_r_mu = normalize_weights(P=P, name=name, prop='center')
                lambda_r_scale = normalize_weights(P=P, name=name, prop='scale')  # noqa
            else:
                attr_name = '{}_{}'.format(name, 'center')
                lambda_r_mu = getattr(P, attr_name)
                attr_name = '{}_{}'.format(name, 'scale')
                lambda_r_scale = getattr(P, attr_name)
            lambda_0_mu = dst['lambda_0']
            lambda_0_scale = dst['lambda_0_scale']
            lambda_r_dist = Normal(loc=lambda_r_mu, scale=lambda_r_scale)
            lambda_0_dist = Normal(loc=lambda_0_mu, scale=lambda_0_scale)
        elif family == 'half_normal':
            if P.wn:
                lambda_r_scale = normalize_weights(P=P, name=name, prop='scale')  # noqa
            attr_name = '{}_{}'.format(name, 'scale')
            lambda_r_scale = getattr(P, attr_name)
            lambda_0_scale = dst['lambda_0_scale']
            lambda_r_dist = HalfNormal(scale=lambda_r_scale)
            lambda_0_dist = HalfNormal(scale=lambda_0_scale)
        elif family == 'categorical':
            if P.wn:
                lambda_r_mu = normalize_weights(P=P, name=name, prop='center')
            attr_name = '{}_{}'.format(name, 'center')
            lambda_r_mu = getattr(P, attr_name)
            lambda_0 = dst['lambda_0']  # This is probs
            log_0 = (lambda_0 + eta).log()  # noqa
            lambda_r_dist = RelaxedOneHotCategorical(temperature=1e-1, logits=lambda_r_mu)  # Log probs  # noqa
            lambda_0_dist = RelaxedOneHotCategorical(temperature=1e-1, logits=log_0)
            # The relaxed distributions above are immediately replaced by exact ones for the KL.
            lambda_r_dist = OneHotCategorical(logits=lambda_r_mu)  # Log probs
            lambda_0_dist = OneHotCategorical(logits=log_0)
        elif family == 'relaxed_bernoulli':
            if P.wn:
                lambda_r_mu = normalize_weights(P=P, name=name, prop='center')
            attr_name = '{}_{}'.format(name, 'center')
            lambda_r_mu = getattr(P, attr_name)
            lambda_0 = dst['lambda_0']  # This is probs
            log_0 = (lambda_0 + eta).log()  # noqa
            lambda_r_dist = RelaxedBernoulli(temperature=1e-1, logits=lambda_r_mu)  # Log probs  # noqa
            lambda_0_dist = RelaxedBernoulli(temperature=1e-1, logits=log_0)
            # The relaxed distributions above are immediately replaced by exact ones for the KL.
            lambda_r_dist = Bernoulli(logits=lambda_r_mu)  # Log probs
            lambda_0_dist = Bernoulli(logits=log_0)
        else:
            raise NotImplementedError('KL for {} is not implemented.'.format(family))
        if Fdiv == 'kl':
            it_kl = kl_divergence(p=lambda_0_dist, q=lambda_r_dist).sum()
        elif Fdiv == 'js':
            raise RuntimeError('Needs per-distribution implementation.')
            m = 0.5 * (lambda_0_dist.probs * lambda_r_dist.probs)
            p = kl_divergence(p=lambda_0_dist, q=m).sum()
            q = kl_divergence(p=lambda_r_dist, q=m).sum()
            it_kl = 0.5 * p + 0.5 * q
        else:
            raise NotImplementedError(Fdiv)
        if it_kl < -1e-4 or torch.isnan(it_kl):  # Give a numerical margin
            print(kl)
        kl = kl + it_kl
    return kl
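# Minimal sketch of the categorical KL term used by get_kl above; kl_divergence is
# registered for pairs of OneHotCategorical distributions in torch.distributions.
import torch
from torch.distributions import OneHotCategorical
from torch.distributions.kl import kl_divergence

lambda_r_dist = OneHotCategorical(logits=torch.randn(5))
lambda_0_dist = OneHotCategorical(probs=torch.full((5,), 1 / 5.))
it_kl = kl_divergence(lambda_0_dist, lambda_r_dist).sum()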
def sample(self, dist_info):
    prob = dist_info["prob"]
    sampler = OneHotCategorical(prob)
    return sampler.sample()
def __init__(self, p):
    super(IndependantSampler, self).__init__()
    self.sampler = OneHotCategorical(p)
    self.p = p