def kl(self, dist_a, prior=None):
    """ KL divergence of dist_a against a prior, if none then Cat(1/k)

    :param dist_a: the distribution parameters; ``dist_a['discrete']['log_q_z']``
                   is read as the posterior logits
    :param prior: prior parameters with the same layout (or None)
    :returns: batch_size kl-div tensor
    :rtype: torch.Tensor

    """
    posterior_logits = dist_a['discrete']['log_q_z']
    posterior = D.OneHotCategorical(logits=posterior_logits)

    if prior is None:
        # Uniform prior: all-zero logits normalise to Cat(1/k).
        # The previous code built the "prior" from the posterior's own logits,
        # which made this term identically zero.
        prior_dist = D.OneHotCategorical(
            logits=torch.zeros_like(posterior_logits))
    else:
        # we have two distributions provided (eg: VRNN)
        prior_dist = D.OneHotCategorical(logits=prior['discrete']['log_q_z'])

    return torch.sum(
        GumbelSoftmax._kl_tf_version(posterior, prior_dist), -1)
def _get_sender_lstm_output(self, inputs):
    """Unroll the sender cell and collect one relaxed sample per message slot.

    For every message index the cell is stepped on the same ``inputs``, the
    hidden state is projected to logits, a Gumbel-softmax sample is drawn and
    the KL between the resulting categorical and the learned prior for that
    slot is added to a running per-example total.

    :param inputs: sender input batch; only ``inputs.shape[0]`` is read here
    :returns: ``(samples, sample_loss, total_kl)`` where ``samples`` is a list
              with one entry per message slot, ``sample_loss`` is an all-zero
              tensor of length batch_size (never updated in this method) and
              ``total_kl`` is the summed per-example KL
    """
    device = self.config['device']
    n_hidden = self.config['num_lstm_sender']
    batch_size = inputs.shape[0]

    sample_loss = torch.zeros(batch_size, device=device)
    total_kl = torch.zeros(batch_size, device=device)
    h_state = torch.zeros(batch_size, n_hidden, device=device)
    c_state = torch.zeros(batch_size, n_hidden, device=device)

    samples = []
    for slot in range(self.config['num_binary_messages']):
        h_state, c_state = self.sender_cell(inputs, (h_state, c_state))
        pre_logits = self.sender_out(self.sender_project(h_state))

        relaxed = utils.gumbel_softmax(
            pre_logits,
            self.temperature[slot],
            self.config['device'],
        )

        # Broadcast the slot's prior logits over the batch before comparing.
        slot_prior = self.prior[slot].unsqueeze(0).expand(
            batch_size, self.output_size)
        posterior = dists.OneHotCategorical(logits=pre_logits)
        prior = dists.OneHotCategorical(logits=slot_prior)
        total_kl += dists.kl_divergence(posterior, prior)

        samples.append(relaxed)

    return samples, sample_loss, total_kl
def decode_x(self, w, z):
    """Decode latents into one likelihood distribution per feature group.

    ``w`` and ``z`` are concatenated on the last axis and passed through
    ``self.decoder_x``.  ``self.likelihood_partition`` maps an inclusive
    ``(first_col, last_col)`` pair to a data-type string; the matching column
    slice of the decoder output parameterises that group's distribution.

    :param w: first latent tensor
    :param z: second latent tensor (same leading shape as ``w``)
    :returns: ``(params, sample_x, px_wz)`` -- the raw decoder output, one
              sample from every group concatenated along axis 1, and the
              list of distributions in partition order
    :raises NotImplementedError: for a data type with no likelihood defined
    """
    params = self.decoder_x(torch.cat((w, z), dim=-1))
    px_wz = []
    samples = []
    for indices in self.likelihood_partition:
        data_type = self.likelihood_partition[indices]
        # Partition bounds are inclusive on both ends.
        params_subset = params[:, indices[0]:(indices[1] + 1)]

        if data_type == 'real':
            cov_diag = self.likelihood_params['lik_var'] * torch.ones_like(
                params_subset).to(self.device)
            dist = D.Normal(loc=params_subset, scale=cov_diag.sqrt())
        elif data_type == 'categorical':
            dist = D.OneHotCategorical(logits=params_subset)
        elif data_type == 'binary':
            dist = D.Bernoulli(logits=params_subset)
        elif data_type == 'positive':
            lognormal_var = self.likelihood_params[
                'lik_var_lognormal'] * torch.ones_like(params_subset).to(
                    self.device)
            dist = D.LogNormal(loc=params_subset, scale=lognormal_var.sqrt())
        elif data_type == 'count':
            # softplus keeps the Poisson rate strictly positive
            positive_params_subset = F.softplus(params_subset)
            dist = D.Poisson(rate=positive_params_subset)
        elif data_type == 'binomial':
            num_trials = self.likelihood_params['binomial_num_trials']
            dist = D.Binomial(total_count=num_trials, logits=params_subset)
        elif data_type == 'ordinal':
            # Column 0 is the location h; the cumulative softplus of the
            # remaining columns gives increasing thresholds, so the
            # differences of sigmoid(theta - h) are valid class probabilities.
            h = params_subset[:, 0:1]
            thetas = torch.cumsum(F.softplus(params_subset[:, 1:]), dim=1)
            prob_lessthans = torch.sigmoid(thetas - h)
            n_rows = len(prob_lessthans)
            # new_ones/new_zeros inherit device and dtype from the decoder
            # output; plain torch.ones/zeros would live on the default device.
            upper = torch.cat(
                (prob_lessthans, prob_lessthans.new_ones(n_rows, 1)), dim=1)
            lower = torch.cat(
                (prob_lessthans.new_zeros(n_rows, 1), prob_lessthans), dim=1)
            dist = D.OneHotCategorical(probs=upper - lower)
        else:
            raise NotImplementedError(
                'no likelihood defined for data type {!r}'.format(data_type))

        samples.append(dist.sample())
        px_wz.append(dist)

    sample_x = torch.cat(samples, dim=1)
    return params, sample_x, px_wz
def step(self, x, model):
    """Run ``self.n_steps`` Metropolis-Hastings moves on one-hot states.

    Each move proposes changing a single (dim, category) entry, with proposal
    logits from ``self.diff_fn`` and the currently active entries masked out,
    then accepts or rejects per example using the model difference plus the
    reverse/forward proposal log-ratio.

    :param x: current states; indexed here as (batch, dim, n_out) with a
              one-hot last axis
    :param model: callable scoring a batch of states
    :returns: the states after all moves
    Side effects: stores mean acceptance (``_ar``), mean model term (``_mt``),
    mean proposal term (``_pt``) and mean changed entries (``_hops``).
    """
    state = x
    accept_rates, model_terms, proposal_terms = [], [], []

    for _ in range(self.n_steps):
        # make sure we dont choose to stay where we are!
        fwd_logits = self.diff_fn(state, model) - 1e9 * state
        fwd_dist = dists.OneHotCategorical(
            logits=fwd_logits.view(state.size(0), -1))
        flat_move = fwd_dist.sample()
        # probability of sampling this change
        logp_fwd = fwd_dist.log_prob(flat_move)

        # back to (bs, dim, nout); summing the last axis marks the edited dim
        move = flat_move.view(state.size())
        touched = move.sum(-1)
        touched_mask = touched[:, :, None]

        # clear the edited dim, then write the newly chosen category
        proposal = state.clone() * (1. - touched_mask) + move

        rev_logits = self.diff_fn(proposal, model) - 1e9 * proposal
        rev_dist = dists.OneHotCategorical(
            logits=rev_logits.view(proposal.size(0), -1))
        # the reverse move restores the old category of the edited dim
        undo_move = state * touched_mask
        logp_rev = rev_dist.log_prob(undo_move.view(proposal.size(0), -1))

        model_term = model(proposal).squeeze() - model(state).squeeze()
        log_accept = model_term + logp_rev - logp_fwd
        accepted = (log_accept.exp() > torch.rand_like(log_accept)).float()
        gate = accepted[:, None, None]
        state = proposal * gate + state * (1. - gate)

        accept_rates.append(accepted.mean().item())
        model_terms.append(model_term.mean().item())
        proposal_terms.append((logp_rev - logp_fwd).mean().item())

    self._ar = np.mean(accept_rates)
    self._mt = np.mean(model_terms)
    self._pt = np.mean(proposal_terms)
    self._hops = (x != state).float().sum(-1).sum(-1).mean().item()
    return state
def mutual_info(self, params, eps=1e-9):
    """ Returns the Monte-Carlo mutual-information estimate, scaled by
    ``config['discrete_mut_info']``.

    The estimate itself is delegated to ``self.mutual_info_monte_carlo``.
    An older entropy + cross-entropy computation used to follow the return
    statement; it was unreachable and has been removed.

    :param params: distribution parameters
    :param eps: tolerance (unused; kept for interface compatibility)
    :returns: batch_size tensor of mutual info
    :rtype: torch.Tensor

    """
    return self.config['discrete_mut_info'] * self.mutual_info_monte_carlo(params)
def forward(self, x, return_latents=False):
    """Score ``x`` with the critic and optionally return latent posteriors.

    :param x: input batch passed through ``self.model``
    :param return_latents: when truthy, also build the discrete and
                           continuous latent distributions
    :returns: ``critic_score`` alone by default, or the tuple
              ``(critic_score, dist_dis, dist_cont)`` when
              ``return_latents`` is set

    The previous ``return a, b, c if flag else a`` applied the conditional
    only to the last element, so a 3-tuple came back even without latents.
    """
    features = self.model(x)
    critic_score = self.critic(features)
    if not return_latents:
        return critic_score

    # Flatten the head input to (-1, channels of the shared features).
    flat = self.dist_conv(features).view(-1, features.size(1))
    dist_dis = distributions.OneHotCategorical(
        logits=self.dis_categorical(flat))
    # The head predicts log-variance; exp(0.5 * logvar) is the std-dev.
    dist_cont = distributions.Normal(
        loc=self.cont_mean(flat),
        scale=torch.exp(0.5 * self.cont_logvar(flat)))
    return critic_score, dist_dis, dist_cont
def step(self, x, model):
    """Resample one block of coordinates within a Hamming ball.

    A block of indices is taken from ``self._inds`` (refilled through
    ``self._init_inds()`` when exhausted).  An auxiliary state ``u`` is drawn
    by applying a uniformly chosen flip pattern from the ball to ``x``; then
    every pattern in the ball is applied to ``u``, scored by ``model``, and
    one candidate per example is sampled from the resulting logits.

    :param x: current binary states, indexed as (batch, dim)
    :param model: callable returning one score per example
    :returns: the resampled states
    """
    if len(self._inds) == 0:  # ran out of inds
        self._inds = self._init_inds()

    inds = self._inds[:self.block_size]
    self._inds = self._inds[self.block_size:]

    # bit flips in the hamming ball (one row per flip pattern)
    H = torch.tensor(
        hamming_ball(len(inds), min(self.hamming_dist, len(inds)))
    ).float().to(x.device)

    # choice(n, size) draws from range(n), same as choosing from list(range(n))
    chosen_H_inds = np.random.choice(H.size(0), x.size(0))
    changes = H[chosen_H_inds]

    # apply sampled changes U ~ p(U | X): XOR the block with the pattern
    u = x.clone()
    u[:, inds] = changes * (1. - u[:, inds]) + (1. - changes) * u[:, inds]

    logits = []
    xs = []
    for pattern in H:
        # Rows of H are already float tensors on x.device; re-wrapping them
        # with torch.tensor() only produced a copy and a UserWarning.
        flip = pattern[None]
        xc = u.clone()
        xc[:, inds] = flip * (1. - xc[:, inds]) + (1. - flip) * xc[:, inds]
        score = model(xc).squeeze()
        xs.append(xc[:, :, None])
        logits.append(score[:, None])

    logits = torch.cat(logits, 1)
    xs = torch.cat(xs, 2)
    dist = dists.OneHotCategorical(logits=logits)
    choices = dist.sample()

    # the one-hot choice selects a single candidate along the last axis
    x_new = (xs * choices[:, None, :]).sum(-1)
    return x_new
def rsample(self, sample_shape=torch.Size([])):
    """Draw joint ``(x, a)`` samples with ``a`` uniform over ``self.a_domain``.

    For each ``a`` the conditional over ``self.x_support`` is tabulated from
    ``self.log_prob(x, a)`` and normalised.  The number of draws per ``a``
    comes from a uniform categorical; ``x`` is then drawn from the matching
    conditional.  Results are grouped by ``a`` in domain order, not shuffled.

    :param sample_shape: requested sample shape; ``numel()`` draws are made
    :returns: ``(x_samples, a_samples)`` reshaped to ``(*sample_shape, -1)``
    Side effect: ``self.test`` is overwritten with the conditional of the
    last ``a`` -- NOTE(review): looks like leftover debugging state.
    """
    a_sampler = D.OneHotCategorical(probs=torch.ones(len(self.a_domain)))

    probs = {}
    x_a_dists = {}
    for a in self.a_domain:
        weights = {x: math.exp(self.log_prob(x, a)) for x in self.x_support}
        normalise = sum(weights.values())
        probs[a] = {x: weight / normalise for x, weight in weights.items()}
        x_a_dists[a] = EmpiricalDistribution(
            None, domain=self.x_support, probs=probs[a])
        self.test = EmpiricalDistribution(
            None, domain=self.x_support, probs=probs[a])

    # Distribution.sample_n is deprecated; sample((n,)) is the replacement.
    a_one_hot = a_sampler.sample((sample_shape.numel(),))
    a_counts = torch.sum(a_one_hot, dim=0)

    x_samples = []
    a_samples = []
    for count, a_value in zip(a_counts, self.a_domain):
        count = int(count)
        x_samples.append(x_a_dists[a_value].sample_n(count))
        a_samples.append(torch.Tensor([a_value] * count))

    x_samples = torch.cat(x_samples).view(*sample_shape, -1)
    a_samples = torch.cat(a_samples).view(*sample_shape, -1)
    return x_samples, a_samples
def step(self, x, model):
    """Propose flipping one bit per example and accept it stochastically.

    The flipped coordinate is random per example when ``self.rand`` is set,
    otherwise it is the shared cursor ``self._i``.  The flip is accepted with
    a Bernoulli draw whose logit is the model score difference.

    :param x: binary states, indexed as (batch, dim)
    :param model: callable returning one score per example
    :returns: states after the accept/reject decision
    Side effects: writes the mean acceptance into ``self.changes[self._i]``,
    advances ``self._i`` modulo ``self.dim`` and stores the mean number of
    changed bits in ``self._hops`` and ``self._ar``.
    """
    n_batch = x.size(0)
    current = x.clone()
    score_keep = model(current).squeeze()

    if self.rand:
        uniform_pick = dists.OneHotCategorical(
            logits=torch.zeros((self.dim, )))
        flip_mask = uniform_pick.sample((n_batch, )).to(x.device)
    else:
        flip_mask = torch.zeros((n_batch, self.dim)).to(x.device)
        flip_mask[:, self._i] = 1.

    # flip exactly the masked entries, keep the rest
    flipped = (1. - flip_mask) * current + flip_mask * (1. - current)
    score_change = model(flipped).squeeze()

    accept = dists.Bernoulli(logits=score_change - score_keep).sample()
    gate = accept[:, None]
    result = flipped * gate + current * (1. - gate)

    self.changes[self._i] = accept.mean()
    self._i = (self._i + 1) % self.dim
    self._hops = (x != result).float().sum(-1).mean().item()
    self._ar = self._hops
    return result
def step(self, x, model):
    """Gibbs-resample the category of a single dimension for every example.

    The dimension is random when ``self.rand`` is set, otherwise the cursor
    ``self._i``.  Each category is written into that dimension in turn, the
    model scores the result, and a new category is sampled from those scores.

    :param x: one-hot states, indexed as (batch, dim, n_categories)
    :param model: callable returning one score per example
    :returns: states with the chosen dimension resampled
    Side effects: advances ``self._i`` modulo ``self.dim`` and stores the
    mean number of changed dimensions in ``self._hops`` and ``self._ar``.
    """
    target_dim = np.random.randint(0, self.dim) if self.rand else self._i

    n_categories = x.size(-1)
    one_hot_rows = torch.eye(n_categories)
    per_category = []
    for k in range(n_categories):
        candidate = x.clone()
        candidate[:, target_dim, :] = one_hot_rows[k]
        per_category.append(model(candidate).squeeze()[:, None])

    conditional = dists.OneHotCategorical(logits=torch.cat(per_category, 1))
    new_values = conditional.sample()

    result = x.clone()
    result[:, target_dim, :] = new_values

    self._i = (self._i + 1) % self.dim
    # a changed dimension differs in two one-hot entries, hence the / 2
    self._hops = ((x != result).float().sum(-1) / 2.).sum(-1).mean().item()
    self._ar = self._hops
    return result
def step(self, x, model):
    """One Metropolis-Hastings move over rows of a fixed flip-pattern matrix.

    ``self.H`` holds one flip pattern per row.  Proposal logits are the
    ``diff_fn`` output projected onto the patterns; the sampled pattern is
    XOR-ed into the state and the move is accepted using the model
    difference plus the reverse/forward proposal log-ratio.

    :param x: binary states, indexed as (batch, dim)
    :param model: callable returning one score per example
    :returns: states after the accept/reject decision
    """
    H = self.H.to(x.device)
    x_cur = x

    forward_delta = self.diff_fn(x_cur, model)
    forward_logits = forward_delta @ H.t()
    cd_forward = dists.Categorical(logits=forward_logits.detach())
    changes = cd_forward.sample()  # integer pattern index per example
    lp_forward = cd_forward.log_prob(changes)

    x_changes = H[changes]
    x_delta = (1. - x_cur) * x_changes + x_cur * (1. - x_changes)

    reverse_delta = self.diff_fn(x_delta.detach(), model)
    reverse_logits = reverse_delta @ H.t()
    # The reverse move is the same pattern index, so it must be scored with
    # an index-valued Categorical.  A OneHotCategorical would read the index
    # vector as a one-hot event and score the wrong outcome (or raise).
    cd_reverse = dists.Categorical(logits=reverse_logits.detach())
    lp_reverse = cd_reverse.log_prob(changes)

    m_term = (model(x_delta).squeeze() - model(x_cur).squeeze())
    la = m_term + lp_reverse - lp_forward
    a = (la.exp() > torch.rand_like(la)).float()
    x_cur = x_delta * a[:, None] + x_cur * (1. - a[:, None])
    return x_cur
def generate(self, decode_fn, prior: torch.Tensor, length=2048,
             tf_board_writer: SummaryWriter = None):
    """Autoregressively extend ``prior`` by sampling one token per step.

    :param decode_fn: callable ``(tokens, look_ahead_mask) -> (hidden, _)``
    :param prior: seed token ids, indexed as (batch, time)
    :param length: upper bound on the number of steps (also capped by
                   ``self.max_seq``)
    :param tf_board_writer: optional writer; receives the per-step
                            probabilities through ``add_image``
    :returns: the token sequence of the first batch element

    Each new token is sampled from the softmax at the last position.  The
    former greedy branch was guarded by ``u > 1`` with ``u`` in [0, 1] and
    could never run, and the sampling branch appended the raw one-hot sample
    rather than a token id; the id is now taken with ``argmax``.
    """
    decode_array = prior
    for i in Bar('generating').iter(range(min(self.max_seq, length))):
        if decode_array.shape[1] >= self.max_seq:
            break
        _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(
            decode_array.shape[1], decode_array, decode_array)

        result, _ = decode_fn(decode_array, look_ahead_mask)
        result = self.fc(result)
        result = result.softmax(-1)

        if tf_board_writer:
            tf_board_writer.add_image("logits", result, global_step=i)

        pdf = dist.OneHotCategorical(probs=result[:, -1])
        # one-hot sample -> token id, shaped (batch, 1) for concatenation
        next_token = pdf.sample().argmax(-1).unsqueeze(-1)
        next_token = next_token.to(decode_array.dtype)
        decode_array = torch.cat((decode_array, next_token), dim=-1)
        del look_ahead_mask

    decode_array = decode_array[0]
    return decode_array
def log_likelihood(self, z, params):
    """ Log-probability of z under the one-hot categorical given by params.

    :param z: inferred latent z (one-hot along the last axis)
    :param params: the params of the distribution; the logits are read from
                   ``params['discrete']['logits']``
    :returns: log-likelihood
    :rtype: torch.Tensor

    """
    logits = params['discrete']['logits']
    categorical = D.OneHotCategorical(logits=logits)
    return categorical.log_prob(z)
def dist_from_h(self, h, mode):
    """Turn a flat logit vector into ``N`` independent ``K``-way categoricals.

    ``h`` is reshaped to (-1, N, K) and each group of K logits is centred to
    zero mean.  When ``self.z_logit_clip`` is set and ``mode`` equals
    ``ModeKeys.TRAIN`` the centred logits are clamped to [-clip, clip].

    :param h: tensor whose elements regroup into (-1, self.N, self.K)
    :param mode: run mode, compared against ``ModeKeys.TRAIN``
    :returns: ``td.OneHotCategorical`` over the processed logits
    """
    grouped = torch.reshape(h, (-1, self.N, self.K))
    centred = grouped - torch.mean(grouped, dim=-1, keepdim=True)

    clip_active = self.z_logit_clip is not None and mode == ModeKeys.TRAIN
    if clip_active:
        bound = self.z_logit_clip
        centred = torch.clamp(centred, min=-bound, max=bound)

    return td.OneHotCategorical(logits=centred)
def prior_distribution(self, batch_size, **kwargs):
    """ Build the prior as a torch distribution.

    Delegates to ``self.prior_params`` and wraps the returned
    ``['discrete']['logits']`` in a one-hot categorical.

    :param batch_size: size of the prior
    :returns: one-hot categorical over the prior logits
    :rtype: torch.distribution

    """
    prior = self.prior_params(batch_size, **kwargs)
    prior_logits = prior['discrete']['logits']
    return D.OneHotCategorical(logits=prior_logits)
def _sample_batch_from_proposal(self, batch_size, return_log_density_of_samples=False):
    """Sample a batch from the proposal, one input dimension at a time.

    For each dimension the autoregressive net is evaluated on the samples
    drawn so far, the proposal parameters for that dimension are sliced out
    of its output, a proposal distribution is built and stored on
    ``self.proposal``, and one value per row is drawn and written into the
    corresponding column.

    :param batch_size: number of rows to sample
    :param return_log_density_of_samples: when true, also return the summed
        per-dimension proposal log-density of every row
    :returns: ``samples``, or ``(samples, log_density)`` when requested

    NOTE(review): the inline ``[B, D, M]`` / ``[S, B, D]`` shape comments are
    inherited from the original and were not verified against
    ``distributions_.MixtureSameFamily`` -- confirm before relying on them.
    """
    # need to do n_samples passes through autoregressive net
    samples = torch.zeros(batch_size, self.autoregressive_net.input_dim)
    log_density_of_samples = torch.zeros(batch_size,
                                         self.autoregressive_net.input_dim)
    for dim in range(self.autoregressive_net.input_dim):
        # compute autoregressive outputs
        autoregressive_outputs = self.autoregressive_net(samples).reshape(
            -1, self.dim, self.autoregressive_net.output_dim_multiplier)

        # grab proposal params for dth dimensions
        # (everything past the first context_dim entries of that dimension)
        proposal_params = autoregressive_outputs[..., dim, self.context_dim:]

        # make mixture coefficients, locs, and scales for proposal
        logits = proposal_params[
            ..., :self.n_proposal_mixture_components]  # [B, D, M]
        if logits.shape[0] == 1:
            logits = logits.reshape(self.dim,
                                    self.n_proposal_mixture_components)
        locs = proposal_params[..., self.n_proposal_mixture_components:(
            2 * self.n_proposal_mixture_components)]  # [B, D, M]
        # the activation output is offset by a minimum scale
        scales = self.mixture_component_min_scale + self.scale_activation(
            proposal_params[..., (
                2 * self.n_proposal_mixture_components):])  # [B, D, M]

        # create proposal
        if self.Component is not None:
            mixture_distribution = distributions.OneHotCategorical(
                logits=logits, validate_args=True)
            components_distribution = self.Component(loc=locs, scale=scales)
            self.proposal = distributions_.MixtureSameFamily(
                mixture_distribution=mixture_distribution,
                components_distribution=components_distribution)
            proposal_samples = self.proposal.sample((1, ))  # [S, B, D]
        else:
            # no component family configured: fall back to Uniform(-4, 4)
            self.proposal = distributions.Uniform(low=-4, high=4)
            proposal_samples = self.proposal.sample((1, batch_size, 1))
        proposal_samples = proposal_samples.permute(1, 2, 0)  # [B, D, S]
        proposal_log_density = self.proposal.log_prob(proposal_samples)
        # detach: sampled values and their densities carry no gradient
        log_density_of_samples[:, dim] += proposal_log_density.reshape(
            -1).detach()
        samples[:, dim] += proposal_samples.reshape(-1).detach()

    if return_log_density_of_samples:
        return samples, torch.sum(log_density_of_samples, 
 dim=-1)
    else:
        return samples
def step(self, x, model):
    """Run ``self.n_steps`` multi-flip Metropolis-Hastings moves.

    Each move draws ``self.n_samples`` coordinates from a categorical whose
    logits come from ``self.diff_fn``, flips every coordinate drawn at least
    once, and accepts per example using the model difference plus the
    reverse/forward proposal log-ratio.

    :param x: binary states, indexed as (batch, dim)
    :param model: callable returning one score per example
    :returns: the states after all moves
    Side effects: stores ``_phops`` (proposed changes relative to the
    original ``x``), ``_ar``, ``_mt``, ``_pt`` and ``_hops``.
    """
    state = x
    accept_rates, model_terms, proposal_terms = [], [], []

    for _ in range(self.n_steps):
        fwd_dist = dists.OneHotCategorical(logits=self.diff_fn(state, model))
        draws = fwd_dist.sample((self.n_samples, ))
        logp_fwd = fwd_dist.log_prob(draws).sum(0)

        # a coordinate is flipped if any of the draws selected it
        flip = (draws.sum(0) > 0.).float()
        proposal = (1. - state) * flip + state * (1. - flip)
        self._phops = (proposal != x).float().sum(-1).mean().item()

        rev_dist = dists.OneHotCategorical(
            logits=self.diff_fn(proposal, model))
        logp_rev = rev_dist.log_prob(draws).sum(0)

        model_term = model(proposal).squeeze() - model(state).squeeze()
        log_accept = model_term + logp_rev - logp_fwd
        accepted = (log_accept.exp() > torch.rand_like(log_accept)).float()
        gate = accepted[:, None]
        state = proposal * gate + state * (1. - gate)

        accept_rates.append(accepted.mean().item())
        model_terms.append(model_term.mean().item())
        proposal_terms.append((logp_rev - logp_fwd).mean().item())

    self._ar = np.mean(accept_rates)
    self._mt = np.mean(model_terms)
    self._pt = np.mean(proposal_terms)
    self._hops = (x != state).float().sum(-1).mean().item()
    return state
def forward(self, x):
    """Map features to an action distribution.

    :param x: input passed through ``self.model``
    :returns: for ``self.dist == "tanh_normal"``, a tanh-squashed Normal
              wrapped in ``td.Independent`` over the last axis
    :raises NotImplementedError: for ``"onehot"`` (not implemented yet) and
        for any other value of ``self.dist``; an unknown value previously
        fell through to ``return dist`` and raised UnboundLocalError
    """
    x = self.model(x)

    if self.dist == "tanh_normal":
        # inverse softplus of init_std: softplus(raw_init_std) == init_std
        raw_init_std = np.log(np.exp(self.init_std) - 1)
        mean, std = torch.chunk(x, 2, dim=-1)
        # soft-limit the mean to (-mean_scale, mean_scale)
        mean = self.mean_scale * torch.tanh(mean / self.mean_scale)
        std = self.softplus(std + raw_init_std) + self.min_std
        dist = td.Normal(mean, std)
        transforms = [TanhBijector()]
        dist = td.transformed_distribution.TransformedDistribution(
            dist, transforms)
        return td.Independent(dist, 1)

    if self.dist == "onehot":
        # would be td.OneHotCategorical(logits=x) once supported
        raise NotImplementedError("Atari not implemented yet!")

    raise NotImplementedError(
        "Unsupported action distribution: {!r}".format(self.dist))
def mutual_info_analytic(self, params, eps=1e-9):
    """ I(z_d; x) ~ H(z_prior, z_d) + H(z_prior), i.e. analytic version.

    Hard targets are the argmax of ``params['discrete']['logits']``; the
    cross-entropy is taken per example (no reduction) against
    ``params['q_z_given_xhat']['discrete']['logits']`` and added to the
    negated entropy of a categorical built from
    ``params['discrete']['z_hard']``.

    :param params: parameters of distribution
    :param eps: tolerance (unused; kept for interface compatibility)
    :returns: batch_size mutual information (prop-to) tensor.
    :rtype: torch.Tensor

    """
    targets = torch.argmax(
        F.softmax(params['discrete']['logits'], -1), dim=-1
    ).type(long_type(self.config['cuda']))

    # reduction='none' replaces the deprecated reduce=False
    crossent_loss = F.cross_entropy(
        input=params['q_z_given_xhat']['discrete']['logits'],
        target=targets, reduction='none')

    ent_loss = -torch.sum(
        D.OneHotCategorical(logits=params['discrete']['z_hard']).entropy(),
        -1)
    return ent_loss + crossent_loss
def prior_params(self, batch_size, **kwargs):
    """ Build the parameters of the uniform prior.

    A (batch_size, output_size) tensor of the type chosen by ``same_type``
    is filled with ``1 / output_size`` and converted to logits.

    :param batch_size: the size of the batch
    :returns: a dictionary of parameters
    :rtype: dict

    """
    tensor_ctor = same_type(self.config['half'], self.config['cuda'])
    probs = tensor_ctor(batch_size, self.output_size).zero_()
    probs.add_(1.0 / self.output_size)

    prior_logits = D.OneHotCategorical(probs=probs).logits
    return {'discrete': {'logits': prior_logits}}
def forward(self, h, z_logit_clip=None):
    '''
    Build the latent distribution from a hidden state and store it.

    h: hidden state used to compute distribution parameter, (batch, self.K)
    z_logit_clip: optional bound; in training mode the centred logits are
        clamped to [-z_logit_clip, z_logit_clip]

    Returns nothing: the result is written to self.dist, and h.device is
    recorded in self.device.
    '''
    self.device = h.device

    projected = self.h_to_logit(h)
    grouped = torch.reshape(projected, (-1, self.N, self.K))
    # centre each group of K logits to zero mean
    centred = grouped - torch.mean(grouped, dim=-1, keepdim=True)

    clip_active = z_logit_clip is not None and self.training
    if clip_active:
        centred = torch.clamp(centred, min=-z_logit_clip, max=z_logit_clip)

    self.dist = td.OneHotCategorical(logits=centred)
def mutual_info_analytic(self, params, eps=1e-9):
    """ I(z_d; x) ~ H(z_prior, z_d) + H(z_prior)

    Hard targets are the argmax of ``params['discrete']['z_hard']``; the
    returned value is the negated per-example cross-entropy against
    ``params['q_z_given_xhat']['discrete']['logits']`` plus the negated
    entropy of a categorical built from ``z_hard``.

    :param params: parameters of distribution
    :param eps: tolerance (unused; kept for interface compatibility)
    :returns: per-example tensor (negated entropy + negated cross-entropy)
    :rtype: torch.Tensor

    """
    targets = torch.argmax(params['discrete']['z_hard'].type(
        long_type(self.config['cuda'])), dim=-1)

    # reduction='none' replaces the deprecated reduce=False
    crossent_loss = -F.cross_entropy(
        input=params['q_z_given_xhat']['discrete']['logits'],
        target=targets, reduction='none')

    ent_loss = -torch.sum(
        D.OneHotCategorical(logits=params['discrete']['z_hard']).entropy(),
        -1)
    return ent_loss + crossent_loss
def check_test_acc(self):
    """Sample messages from the learned prior and score the decoder on them.

    For every message slot, 5000 one-hot samples are drawn from the prior;
    the stacked messages are re-encoded as hard one-hot vectors on the
    configured device, decoded with ``self.test_forward`` and checked by
    ``utils.check_correct_preds``.

    :returns: that count divided by ``config['batch_size']``

    The per-sample prior log-probabilities were accumulated but never used,
    so that computation has been dropped.
    NOTE(review): the count is normalised by config['batch_size'] although
    5000 messages are drawn -- confirm the two are meant to match.
    """
    n_samples = 5000
    device = self.config['device']

    messages = []
    for num in range(self.config['num_binary_messages']):
        prior_dst = dists.OneHotCategorical(logits=self.prior[num])
        messages.append(prior_dst.sample((n_samples, )))

    # (slots, samples, classes) -> (samples, slots, classes)
    messages = torch.stack(messages).permute(1, 0, 2)

    # rebuild hard one-hot vectors on the target device
    maxz = torch.argmax(messages, dim=-1, keepdim=True)
    h_z = torch.zeros(messages.shape, device=device).scatter_(-1, maxz, 1)

    _, final_preds = self.test_forward(h_z)
    final_preds = final_preds.data.cpu().numpy()
    no_rep = utils.check_correct_preds(final_preds)
    return no_rep / self.config['batch_size']
def __init__(self, args):
    """Build the generative (p) and inference (q) networks.

    Reads ``image_dims``, ``y_dim``, ``z_dim``, ``hidden_dim`` and
    ``device`` from ``args``.  All three networks share one layout:
    Linear -> Softplus -> Linear -> Softplus -> Linear.
    """
    super().__init__()
    channels, height, width = args.image_dims
    x_dim = channels * height * width

    def three_layer_net(n_in, n_out):
        # shared layout of decoder and both encoders
        return nn.Sequential(nn.Linear(n_in, args.hidden_dim),
                             nn.Softplus(),
                             nn.Linear(args.hidden_dim, args.hidden_dim),
                             nn.Softplus(),
                             nn.Linear(args.hidden_dim, n_out))

    # --------------------
    # p model -- SSL paper generative semi supervised model M2
    # --------------------

    # uniform prior over labels and standard normal prior over z
    self.p_y = D.OneHotCategorical(
        probs=1 / args.y_dim * torch.ones(1, args.y_dim, device=args.device))
    self.p_z = D.Normal(torch.tensor(0., device=args.device),
                        torch.tensor(1., device=args.device))

    # parametrized data likelihood p(x|y,z)
    self.decoder = three_layer_net(args.z_dim + args.y_dim, x_dim)

    # --------------------
    # q model -- SSL paper eq 4
    # --------------------

    # q(y|x) = Cat(y|pi_phi(x)): outputs the categorical parametrization
    self.encoder_y = three_layer_net(x_dim, args.y_dim)

    # q(z|x,y) = Normal(z|mu_phi(x,y), diag(sigma2_phi(x))): outputs mean
    # and diagonal variance parametrization, hence 2 * z_dim outputs
    self.encoder_z = three_layer_net(x_dim + args.y_dim, 2 * args.z_dim)

    # initialize weights to N(0, 0.001) and biases to 0 (cf SSL section 4.4)
    for param in self.parameters():
        param.data.normal_(0, 0.001)
        if param.ndimension() == 1:
            param.data.fill_(0.)
def test_prior(self, data):
    """Compare posterior and prior log-probabilities of sampled messages.

    The sender is unrolled over every message slot.  For each slot a
    Gumbel-softmax sample is drawn and hardened to one-hot by argmax; its
    log-probability is accumulated under the slot's prior and under the
    sender's own log-softmax output.

    :param data: input batch fed to ``self.sender_embedding``
    :returns: ``(post_prob, prior_prob, samples)`` -- the two accumulated
              log-probabilities divided by the number of slots (numpy), and
              the relaxed samples stacked and permuted with ``(1, 0, 2)``
    """
    device = self.config['device']
    n_slots = self.config['num_binary_messages']
    n_hidden = self.config['num_lstm_sender']
    batch_size = data.shape[0]

    flat_width = self.config['num_digits'] * self.config['embedding_size_sender']
    inputs = self.sender_embedding(data).view(batch_size, flat_width)

    h_state = torch.zeros(batch_size, n_hidden, device=device)
    c_state = torch.zeros(batch_size, n_hidden, device=device)

    samples = []
    log_probs = 0
    post_probs = 0
    batch_index = torch.arange(batch_size)
    for slot in range(n_slots):
        h_state, c_state = self.sender_cell(inputs, (h_state, c_state))
        pre_logits = self.sender_out(self.sender_project(h_state))
        posterior_logp = torch.log_softmax(pre_logits, -1)

        relaxed = utils.gumbel_softmax(pre_logits, self.temperature[slot],
                                       self.config['device'])
        samples.append(relaxed)

        # harden the relaxed sample into a one-hot vector
        winner = torch.argmax(relaxed, dim=-1, keepdim=True)
        hard = torch.zeros(relaxed.shape, device=device).scatter_(
            -1, winner, 1)

        slot_prior = dists.OneHotCategorical(logits=self.prior[slot])
        log_probs += slot_prior.log_prob(hard).detach().cpu().numpy()
        post_probs += posterior_logp[batch_index, winner.squeeze()]

    samples = torch.stack(samples).permute(1, 0, 2)
    prior_prob = log_probs / self.config['num_binary_messages']
    post_prob = post_probs.detach().cpu().numpy(
    ) / self.config['num_binary_messages']
    return post_prob, prior_prob, samples
def generate(self, prior: torch.Tensor, length=2048,
             tf_board_writer: SummaryWriter = None):
    """Sample ``length`` new tokens after ``prior`` with a sliding context.

    :param prior: seed token ids, indexed as (batch, time)
    :param length: number of tokens to generate
    :param tf_board_writer: optional writer; receives the per-step
                            probabilities through ``add_image``
    :returns: seed plus generated tokens for the first batch element

    The decoder context drops its oldest token once it reaches
    ``config.threshold_len``; the returned sequence keeps every token.

    Changes from the previous version: the debug ``print`` calls are gone
    (one read ``pdf.shape``, which torch distributions do not have, so the
    first step raised AttributeError), the greedy branch guarded by the
    constant ``u = 0`` was dead, and the look-ahead mask was computed but
    never handed to the decoder.
    NOTE(review): the decoder still receives ``None`` as its mask, as
    before -- confirm that this is intended.
    """
    decode_array = prior
    result_array = prior
    for i in Bar('generating').iter(range(length)):
        if decode_array.size(1) >= config.threshold_len:
            decode_array = decode_array[:, 1:]

        result, _ = self.Decoder(decode_array, None)
        result = self.fc(result)
        result = result.softmax(-1)

        if tf_board_writer:
            tf_board_writer.add_image("logits", result, global_step=i)

        pdf = dist.OneHotCategorical(probs=result[:, -1])
        # one-hot sample -> token id, shaped (batch, 1) for concatenation
        result = pdf.sample().argmax(-1).unsqueeze(-1)
        decode_array = torch.cat((decode_array, result), dim=-1)
        result_array = torch.cat((result_array, result), dim=-1)

    result_array = result_array[0]
    return result_array
def __init__(self, dim, n_out=3, init_sigma=.15, init_bias=0.,
             learn_G=False, learn_sigma=False, learn_bias=False):
    """Set up a model on a ``dim x dim`` circular lattice.

    :param dim: side length of the lattice (``dim ** 2`` sites)
    :param n_out: number of categories per site
    :param init_sigma: initial value of the ``sigma`` parameter
    :param init_bias: constant used to fill the per-site bias
    :param learn_G: whether the adjacency matrix requires gradients
    :param learn_sigma: whether ``sigma`` requires gradients
    :param learn_bias: whether the bias requires gradients
    """
    super().__init__()
    n_sites = dim**2

    # circular=True gives periodic boundary conditions
    lattice = ig.Graph.Lattice(dim=[dim, dim], circular=True)
    adjacency = np.asarray(lattice.get_adjacency().data)

    self.G = nn.Parameter(torch.tensor(adjacency).float(),
                          requires_grad=learn_G)
    self.sigma = nn.Parameter(torch.tensor(init_sigma).float(),
                              requires_grad=learn_sigma)
    bias_values = torch.ones((n_sites, n_out)).float() * init_bias
    self.bias = nn.Parameter(bias_values, requires_grad=learn_bias)

    # initial-state distribution built from the bias at construction time
    self.init_dist = dists.OneHotCategorical(logits=self.bias)

    self.dim = dim
    self.n_out = n_out
    self.data_dim = n_sites
def generate(self, prior, length=2048):
    """Sample ``length`` tokens autoregressively after ``prior``.

    Each step builds the mask for the current sequence, runs the decoder
    and output layer, and samples the next token from the softmax at the
    last position.

    :param prior: seed token ids, indexed as (batch, time)
    :param length: number of tokens to append
    :returns: seed plus generated tokens for the first batch element
    """
    context = prior
    generated = prior
    for _ in range(length):
        _, _, causal_mask = get_masked_with_pad_tensor(
            context.size(1), context, context, self.pad_token)

        hidden, _ = self.Decoder(context, causal_mask)
        token_probs = self.fc(hidden).softmax(dim=-1)

        sampler = dist.OneHotCategorical(probs=token_probs[:, -1])
        # one-hot sample -> token id, shaped (batch, 1)
        next_token = sampler.sample().argmax(-1).unsqueeze(-1)

        context = torch.cat((context, next_token), dim=-1)
        generated = torch.cat((generated, next_token), dim=-1)
        del causal_mask

    return generated[0]
def step(self, x, model):
    """Exact block-Gibbs update over every 0/1 assignment of one block.

    A block of indices is taken from ``self._inds`` (refilled through
    ``self._init_inds()`` when exhausted).  Every binary assignment of the
    block is written into a copy of ``x`` and scored by ``model``; one
    assignment per example is then sampled from those scores.

    :param x: binary states, indexed as (batch, dim)
    :param model: callable returning one score per example
    :returns: states with the block resampled
    """
    if len(self._inds) == 0:  # ran out of inds
        self._inds = self._init_inds()

    block = self._inds[:self.block_size]
    self._inds = self._inds[self.block_size:]

    candidates = []
    scores = []
    for assignment in itertools.product([0., 1.], repeat=len(block)):
        candidate = x.clone()
        candidate[:, block] = torch.tensor(assignment).float().to(
            candidate.device)
        score = model(candidate).squeeze()
        candidates.append(candidate[:, :, None])
        scores.append(score[:, None])

    stacked = torch.cat(candidates, 2)
    picker = dists.OneHotCategorical(logits=torch.cat(scores, 1))
    choice = picker.sample()

    # the one-hot choice selects a single candidate along the last axis
    return (stacked * choice[:, None, :]).sum(-1)
# Demo script: draw from a continuous and a one-hot categorical distribution
# with two storch sampling methods and concatenate the results.
# NOTE(review): the shape comments below are the original author's expected
# outputs; they were not verified here.
import storch
import torch
import torch.distributions as td

# --- Case 1: both methods are created with the same plate name "1" ---
method1 = storch.method.Reparameterization
method2 = storch.method.ScoreFunction

method1 = method1(plate_name="1",n_samples=25)
method2 = method2(plate_name="1",n_samples=25)

# Independent(..., 0) reinterprets zero batch dims, so shapes are unchanged.
p1 = td.Independent(td.Normal(loc=torch.zeros([1000, 2]), scale=torch.ones([1000, 2])), 0)
# probs are uniform random values, not normalised by this script
p2 = td.Independent(td.OneHotCategorical(probs=torch.zeros([1000, 3]).uniform_()), 0)
samp1 = method1(p1)
samp2 = method2 (p2)
# expected samp1 shape: torch.Size([25, 1000, 2])
# expected samp2 shape: torch.Size([25, 1000, 3])
print(storch.cat([samp1,samp2], 2).shape)
# expected output: torch.Size([25, 1000, 5])

# --- Case 2: different plates ("1" and "2"), second method takes k=25 ---
method1 = storch.method.Reparameterization
method2 = storch.method.UnorderedSetEstimator

method1 = method1(plate_name="1",n_samples=25)
method2 = method2(plate_name="2",k=25)

p1 = td.Independent(td.Normal(loc=torch.zeros([1000, 2]), scale=torch.ones([1000, 2])), 0)
p2 = td.Independent(td.OneHotCategorical(probs=torch.zeros([1000, 3]).uniform_()), 0)
samp1 = method1(p1)
samp2 = method2 (p2)
# expected samp1 shape: torch.Size([25, 1000, 2])
# expected samp2 shape: torch.Size([25, 1000, 3])