def forward(self, x, cold=False):
    batch_size = x.shape[0]
    # take fft (B, 1, width*height, 2)
    x_freq = torch.rfft(x, signal_ndim=2, onesided=False).view(batch_size, 1, self.input_size, 2)
    if cold:
        temperature = 0.0
    else:
        temperature = self.temperature
    logits = self.logits.expand((batch_size, self.output_size, self.input_size))
    dist = RelaxedBernoulli(temperature=temperature, logits=logits)
    if cold:
        samples = dist.sample()
    else:
        samples = dist.rsample()
    # reshape so broadcasting works properly
    samples = samples.view(batch_size, self.output_size, self.input_size)
    # mask the frequencies (B, output_size, width*height, 2)
    sensed_freq = x_freq * samples
    # reshape for ifft
    sensed_freq = sensed_freq.view(-1, int(np.sqrt(self.input_size)))
    sensed_freq = sensed_freq.view(-1, self.resolution, self.resolution, 2)
    # (B*output_size, resolution, resolution)
    sensed_images = torch.irfft(sensed_freq, 2, normalized=False, onesided=False)
    sensed = torch.sum(sensed_images.view(batch_size, self.output_size, self.input_size), axis=-1)
    noise_scale = self.noise * torch.sqrt(sensed.detach())
    sensed += torch.randn_like(sensed) * noise_scale
    return sensed
def init_prior(self):
    # Set up priors
    K = self.max_lag
    m = self.num_X
    rho = self.prior_rho_A
    sigma = self.prior_sigma_W
    temperature = torch.tensor([self.temperature], device=self.device)

    prior_A = rho * torch.ones(size=(K, 2 * m, 2 * m), device=self.device)
    prior_A[:, m:, :] = torch.tensor([0], device=self.device)
    # Set the diagonal
    for i in range(m):
        prior_A[:, m + i, m + i] = torch.tensor([rho], device=self.device)
    self.prior_A = RelaxedBernoulli(temperature=temperature, probs=prior_A)

    prior_W_scale = sigma * torch.ones(size=(K, 2 * m, 2 * m), device=self.device)
    prior_W_scale[:, m:, :] = torch.tensor([0], device=self.device)
    # Set the diagonal
    for i in range(m):
        prior_W_scale[:, m + i, m + i] = torch.tensor([sigma], device=self.device)
    self.prior_W = Normal(loc=torch.zeros_like(prior_W_scale, device=self.device),
                          scale=prior_W_scale)
def sample(self, x):
    logits = self.inference_net(x)
    # logits = torch.tanh(logits) * 4.
    logits = torch.clamp(logits, min=-8., max=4.)
    dist = RelaxedBernoulli(torch.Tensor([1.]).cuda(), logits=logits)
    z = dist.rsample()  # [B, Z]
    z = torch.clamp(z, min=.00000001, max=.9999999)
    logqz = dist.log_prob(z.detach())
    logqz = torch.sum(logqz, 1)  # [B]
    return z, logits, logqz
def forward(self, x, rein_flag=False):
    h1 = self.Alice(x)
    if rein_flag:
        h_distribution = RelaxedBernoulli(self.temp, h1)
        h1_nograd = h_distribution.sample()
        h1_nograd = h1_nograd.float()
        out = self.Bob(h1_nograd)
    else:
        out = self.Bob(h1)
    return out, h1
def gumbel_sample(self, s):
    """
    Reparametrisation of the Bernoulli distribution.
    s : [batch, output_dim]
    """
    # s = F.relu(self.linear1(s))
    # p_logits = self.linear2(s)
    p_logits = self.linear(s)  # [batch, 1]
    action_dis = RelaxedBernoulli(temperature=0.8, logits=p_logits)
    action = action_dis.rsample()
    hard_action = (action > 0.5).float()
    # straight-through: return the hard value in the forward pass,
    # but let gradients flow through the relaxed sample
    return action + (hard_action - action).detach()
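The snippet above combines a relaxed Bernoulli sample with a straight-through estimator. A minimal, self-contained sketch of that trick outside any module, assuming only torch; the logits and temperature values are illustrative:

import torch
from torch.distributions import RelaxedBernoulli

logits = torch.zeros(4, requires_grad=True)           # illustrative logits
dist = RelaxedBernoulli(temperature=torch.tensor(0.8), logits=logits)
soft = dist.rsample()                                  # relaxed sample in (0, 1)
hard = (soft > 0.5).float()                            # binarized value
out = soft + (hard - soft).detach()                    # forward: hard, backward: gradient of soft
out.sum().backward()                                   # gradients reach the logits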
def get_weight(self, x):
    if self.training:
        pi = self.sample_pi()
        temp = torch.Tensor([0.1])
        if torch.cuda.is_available():
            temp = temp.cuda()
        p_z = RelaxedBernoulli(temp, probs=pi)
        z = p_z.rsample(torch.Size([x.size(0)]))
    else:
        Epi = self.get_Epi()
        mask = self.get_mask(Epi=Epi)
        z = Epi * mask
    return z
def forward(self, z):
    map1 = self.l1(z)
    map1 = map1.view(map1.shape[0], 128, self.init_size)
    out = self.model(map1)
    img = RelaxedBernoulli(torch.tensor([opt.temp]).type(TENSOR), probs=out).rsample()
    return img
def __init__(self, w, p, l, temperature=0.1, validate_args=None):
    relaxed_bernoulli = RelaxedBernoulli(temperature, p)
    affine_transform = AffineTransform(w, l - w)
    super(ToeplitzBernoulliDistribution, self).__init__(relaxed_bernoulli,
                                                        affine_transform,
                                                        validate_args)
    self.relaxed_bernoulli = relaxed_bernoulli
    self.affine_transform = affine_transform
def forward(self, x, mask=None, loss=None):
    """
    Args:
        x (torch.FloatTensor): class-wise representation. torch.Size([n_cls, feature_dim])
        mask: torch.Size([n_cls, 1])
        loss: torch.Size([n_cls, 1])

    Returns:
        x (torch.FloatTensor): class-wise mask layout. torch.Size([n_cls, 1])
    """
    # generate states from the features at the first loop.
    if self.state is None:
        self.state = self.state_linear(x.detach())
        # state = self.static_init_state(x.size(0))
    if self.input_more:
        # detach from the graph
        mask = mask.detach()
        loss = loss.detach()
        # averaged and relative loss
        loss = loss / np.log(loss.size(0))  # scaling by log
        loss_mean = loss.mean().repeat(loss.size(0), 1)  # averaged loss
        loss_rel = loss - loss_mean  # relative loss
        loss_mean = self.preprocess(loss_mean).detach()
        loss_rel = self.preprocess(loss_rel).detach()
        step = self.step.repeat(loss.size(0), 1)
        x = torch.cat([x, mask, loss_mean, loss_rel, step], dim=1)  # [n_cls, feature_dim + 1 + 2 + 2]
    else:
        step = self.step.repeat(loss.size(0), 1)
        x = torch.cat([x, step], dim=1)
    self.state = self.gru(x, self.state)  # [n_cls, rnn_h_dim]
    x = self.out_linear(self.state)  # [n_cls, 1]
    if self.output_more:
        mask = x[:, 0].unsqueeze(1)
        lr = (x[:, 1].mean() + self.c).exp()
    else:
        mask = x
    if self.sample_mode:
        mask = RelaxedBernoulli(self.temp, mask).rsample()
    else:
        mask = self.sigmoid(mask / self.temp)  # 0.2
    if self.output_more:
        return mask, lr
    else:
        return mask, None  # [n_cls, 1]
def __init__(self, w, p, temperature=0.1, validate_args=None):
    relaxed_bernoulli = RelaxedBernoulli(temperature, p)
    affine_transform = AffineTransform(0, w)
    one_minus_p = AffineTransform(1, -1)
    super(BernoulliDropoutDistribution, self).__init__(
        relaxed_bernoulli,
        ComposeTransform([one_minus_p, affine_transform]),
        validate_args)
    self.relaxed_bernoulli = relaxed_bernoulli
    self.affine_transform = affine_transform
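The two constructors above wrap RelaxedBernoulli in what appears to be a torch.distributions.TransformedDistribution subclass. A minimal sketch of the same composition with explicit imports, assuming scalar w and p purely for illustration:

import torch
from torch.distributions import RelaxedBernoulli, TransformedDistribution
from torch.distributions.transforms import AffineTransform, ComposeTransform

p = torch.tensor(0.3)   # illustrative drop probability
w = torch.tensor(2.0)   # illustrative weight being masked
base = RelaxedBernoulli(torch.tensor(0.1), probs=p)
# map a relaxed Bernoulli sample b to w * (1 - b)
dropout_dist = TransformedDistribution(
    base, ComposeTransform([AffineTransform(1., -1.), AffineTransform(0., w)]))
sample = dropout_dist.rsample()

Sampling from dropout_dist then yields w * (1 - b) for a relaxed b, which matches the one-minus-p-then-scale chain used in the constructor above.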
def forward(self, x, logits, cold=False):
    batch_size = x.shape[0]
    # x_local is (B, 1, resolution*resolution)
    x_local = x.view(batch_size, -1, self.input_size)
    if cold:
        temperature = 0.001
    else:
        temperature = self.temperature
    # samples are (B, output_size, 1, resolution*resolution)
    dist = RelaxedBernoulli(temperature=temperature, logits=logits)
    if cold:
        samples = dist.sample()
    else:
        samples = dist.rsample()
    # now "mask" the pixels of the input spatially by elementwise multiply
    # masked_x is (B, output_size, C, resolution*resolution)
    masked_x = samples * x_local
    # get sensor values, shape of (B, output_size), summing across both channel and pixels
    sensed = masked_x.sum(axis=-1)
    return sensed
def init_posterior(self):
    # Set up posteriors
    K = self.max_lag
    m = self.num_X
    rho = self.prior_rho_A
    sigma = self.prior_sigma_W
    temperature = torch.tensor([self.temperature], device=self.device)

    estimate_A = rho * torch.rand(size=(K, 2 * m, 2 * m), device=self.device)
    estimate_A[:, m:, :] = torch.tensor([0], device=self.device)
    # Set the diagonal
    for i in range(m):
        estimate_A[:, m + i, m + i] = torch.tensor([rho], device=self.device)
    estimate_A = estimate_A.requires_grad_(True)
    self.posterior_A = RelaxedBernoulli(temperature=temperature, probs=estimate_A)

    estimate_W_scale = sigma * torch.ones(size=(K, 2 * m, 2 * m), device=self.device)
    estimate_W_scale[:, m:, :] = torch.tensor([0], device=self.device)
    # Set the diagonal
    for i in range(m):
        estimate_W_scale[:, m + i, m + i] = torch.tensor([sigma], device=self.device)
    estimate_W_scale = estimate_W_scale.requires_grad_(True)

    estimate_W_loc = torch.rand(size=(K, 2 * m, 2 * m), device=self.device)
    estimate_W_loc[:, m:, :] = torch.tensor([0], device=self.device)
    estimate_W_loc = estimate_W_loc.requires_grad_(True)
    self.posterior_W = Normal(loc=estimate_W_loc, scale=estimate_W_scale)
def forward(self, a, b):
    """
    Args:
        a, b (tensor): beta distribution parameters.

    Returns:
        z (tensor): relaxed Beta-Bernoulli binary masks.
        pi (tensor): Bernoulli distribution parameters.
    """
    # a, b = self.get_params(x)
    pi = self.sample_pi(a, b)
    temp = C(torch.Tensor([0.001]))
    z = RelaxedBernoulli(temp, probs=pi).rsample()
    return z, pi
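A self-contained sketch of the relaxed Beta-Bernoulli masking pattern used above, with illustrative shapes and concentrations (the C(...) call above presumably just moves the temperature tensor to the right device):

import torch
from torch.distributions import Beta, RelaxedBernoulli

a = torch.full((8,), 2.0)                 # illustrative Beta concentrations
b = torch.full((8,), 1.0)
pi = Beta(a, b).rsample()                 # per-unit keep probabilities
z = RelaxedBernoulli(torch.tensor(0.001), probs=pi).rsample()  # relaxed binary mask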
def forward(self, x, z_in):
    if len(x.size()) == 4:
        h = F.avg_pool2d(x, [x.size(2), x.size(3)])
        h = h.view(h.size(0), -1)
    else:
        h = x
    h = self.bn(h.detach())
    if self.training:
        sigma = F.softplus(self.sigma_uc)
        eps = torch.randn(self.num_gates)
        if torch.cuda.is_available():
            eps = eps.cuda()
        h = h + sigma * eps
        temp = torch.Tensor([0.1])
        if torch.cuda.is_available():
            temp = temp.cuda()
        p_z = RelaxedBernoulli(temp, probs=h.clamp(1e-10, 1 - 1e-10))
        z = p_z.rsample()
    else:
        z = h.clamp(1e-10, 1 - 1e-10)
    if len(x.size()) == 4:
        z = z.view(-1, self.num_gates, 1, 1)
        z_in = z_in.view(-1, self.num_gates, 1, 1)
    else:
        z_in = z_in.view(-1, self.num_gates)
    z = z * z_in
    if not self.training:
        num_active = (z > self.thres).float().sum(1).mean(0).item()
        self.num_active = (self.num_active * self.counter + num_active) / (self.counter + 1)
        self.counter += 1
        z[z <= self.thres] = 0.
    return x * z
def forward(self, x):
    if self.gumbel:
        prior_exemplars = RelaxedBernoulli(self.t, logits=self.U).rsample()
    else:
        prior_exemplars = torch.sigmoid(self.U)
    prior_mu, prior_var = self.enc(prior_exemplars).chunk(2, -1)
    prior = Normal(prior_mu, (prior_var * 0.5).clamp(-5, 4).exp())
    posterior_mu, posterior_var = self.enc(x).chunk(2, -1)
    posterior = Normal(posterior_mu, (posterior_var * 0.5).clamp(-5, 4).exp())
    z = posterior.rsample()
    x_hat = self.dec(z)
    return {'prior': prior, 'posterior': posterior, 'z': z, 'x_hat': x_hat}
def get_weight(self, num_samps, training, samp_type='rel_ber'):
    temp = torch.Tensor([0.67])
    if torch.cuda.is_available():
        temp = temp.cuda()
    if training:
        pi = self.sample_pi()
        p_z = RelaxedBernoulli(temp, probs=pi)
        z = p_z.rsample(torch.Size([num_samps]))
    else:
        if samp_type == 'rel_ber':
            pi = self.sample_pi()
            p_z = RelaxedBernoulli(temp, probs=pi)
            z = p_z.rsample(torch.Size([num_samps]))
        elif samp_type == 'ber':
            pi = self.sample_pi()
            p_z = torch.distributions.Bernoulli(probs=pi)
            z = p_z.sample(torch.Size([num_samps]))
    return z, pi
def f(self, x, z, logits, hard=False):
    B = x.shape[0]
    # image likelihood given b
    # b = harden(z).detach()
    x_hat = self.generator.forward(z)
    alpha = torch.sigmoid(x_hat)
    beta = Beta(alpha * self.beta_scale, (1. - alpha) * self.beta_scale)
    # add uniform noise to the pixels before evaluating the likelihood
    x_noise = torch.clamp(x + torch.FloatTensor(x.shape).uniform_(0., 1. / 256.).cuda(),
                          min=1e-5, max=1 - 1e-5)
    logpx = beta.log_prob(x_noise)  # [120, 3, 112, 112]
    logpx = torch.sum(logpx.view(B, -1), 1)  # [PB] * self.w_logpx
    # prior is constant, I think
    # for q(b|x), we just want to increase its entropy
    if hard:
        dist = Bernoulli(logits=logits)
    else:
        dist = RelaxedBernoulli(torch.Tensor([1.]).cuda(), logits=logits)
    logqb = dist.log_prob(z.detach())
    logqb = torch.sum(logqb, 1)
    return logpx, logqb, alpha
    grads.append(f(samp) * logprobgrad.numpy())

print('Grad Estimator: REINFORCE H(z), temp=10')
print('Avg samp', np.mean(samps))
print('Grad mean', np.mean(grads))
print('Grad std', np.std(grads))
print()

# REINFORCE, but sampling p(z)
dist_relaxedbern = RelaxedBernoulli(torch.Tensor([1.]), bern_param)
dist_bern = Bernoulli(bern_param)
samps = []
grads = []
for i in range(n):
    samp = dist_relaxedbern.sample()
    Hsamp = Hpy(samp)
    logprob = dist_bern.log_prob(Hsamp)
    logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param),
                                      retain_graph=True)[0]
    samps.append(Hsamp.numpy())
    grads.append(f(Hsamp.numpy()) * logprobgrad.numpy())

print('Grad Estimator: REINFORCE but sampling p(z)')
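These script fragments compare score-function (REINFORCE) estimators against relaxed samples. A compact, self-contained sketch of the two gradient estimators they juxtapose, with an illustrative objective f and parameter value:

import torch
from torch.distributions import RelaxedBernoulli, Bernoulli

bern_param = torch.tensor([0.7], requires_grad=True)    # illustrative probability

def f(z):
    return (z - 0.4) ** 2                                # illustrative objective

# pathwise / reparameterized estimate through the relaxed sample
dist = RelaxedBernoulli(torch.tensor([1.0]), probs=bern_param)
z = dist.rsample()
grad_pathwise = torch.autograd.grad(f(z).sum(), bern_param)[0]

# single-sample REINFORCE estimate on the hard Bernoulli, for comparison
hard = Bernoulli(probs=bern_param)
b = hard.sample()
grad_reinforce = f(b) * torch.autograd.grad(hard.log_prob(b).sum(), bern_param)[0]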
    grads.append((f(samp.numpy()) - 0.) * logprobgrad.numpy())
    logprobgrads.append(logprobgrad.numpy())

print('Grad Estimator: REINFORCE')
print('Avg samp', np.mean(samps))
print('Grad mean', np.mean(grads))
print('Grad std', np.std(grads))
print('Avg logprobgrad', np.mean(logprobgrads))
print('Std logprobgrad', np.std(logprobgrads))
print()

# n = 1000
dist = RelaxedBernoulli(torch.Tensor([5.]), bern_param)
samps = []
grads = []
logprobgrads = []
for i in range(n):
    samp = dist.sample()
    samp = torch.clamp(samp, min=.0000001, max=.99999999999999)
    logprob = dist.log_prob(samp)
    if logprob != logprob:
        print(samp, logprob)
ax = plt.subplot2grid((rows, cols), (row, col), frameon=False, colspan=1, rowspan=1)

xs = np.linspace(.001, .999, 30)
# dist = RelaxedBernoulli(temperature=torch.Tensor([1.]), logits=torch.tensor([0.]))
dist = RelaxedBernoulli(temperature=torch.Tensor([.2]), probs=torch.tensor([0.3]))
ys = []
samps = []
for x in xs:
    prob = torch.exp(dist.log_prob(torch.tensor([x]))).numpy()[0]
def sample_lambda0_r(self, d, batch_size, offset=0, object_locations=None,
                     object_margin=None, num_objects=None, gau=None,
                     max_rejections=1000, margin_offset=2):
    """Sample dataset parameters perturbed by r."""
    name = d['name']
    family = d['family']
    attr_name = '{}_{}'.format(name, 'center')
    if self.wn:
        lambda_r = self.normalize_weights(name=name, prop='center')
    elif family != 'half_normal':
        lambda_r = getattr(self, attr_name)
    parameters = []
    if family == 'gaussian':
        attr_name = '{}_{}'.format(name, 'scale')
        if self.wn:
            lambda_r_scale = self.normalize_weights(name=name, prop='scale')
        else:
            lambda_r_scale = getattr(self, attr_name)
        # lambda_r = transform_to(constraints.greater_than(1.))(lambda_r)
        # lambda_r_scale = transform_to(constraints.greater_than(
        #     self.minimum_spatial_scale))(lambda_r_scale)
        # TODO: Add constraint function here
        if gau is None:
            gau = MultivariateNormal(loc=lambda_r, covariance_matrix=lambda_r_scale)
        if d['return_sampler']:
            return gau
        if name == 'object_location':
            if not len(object_locations):
                return gau.rsample(), gau
            else:
                parameters = self.rejection_sampling(
                    object_margin=object_margin,
                    margin_offset=margin_offset,
                    object_locations=object_locations,
                    max_rejections=max_rejections,
                    num_objects=num_objects,
                    gau=gau)
        else:
            raise NotImplementedError(name)
    elif family == 'normal':
        attr_name = '{}_{}'.format(name, 'scale')
        if self.wn:
            lambda_r_scale = self.normalize_weights(name=name, prop='scale')
        else:
            lambda_r_scale = getattr(self, attr_name)
        nor = Normal(loc=lambda_r, scale=lambda_r_scale)
        if d['return_sampler']:
            return nor
        elif name == 'object_location':
            if not len(object_locations):
                return nor.rsample(), nor
            else:
                parameters = self.rejection_sampling(
                    object_margin=object_margin,
                    margin_offset=margin_offset,
                    object_locations=object_locations,
                    max_rejections=max_rejections,
                    num_objects=num_objects,
                    gau=nor)
        else:
            for idx in range(batch_size):
                parameters.append(nor.rsample())
    elif family == 'cnormal':
        attr_name = '{}_{}'.format(name, 'scale')
        if self.wn:
            lambda_r_scale = self.normalize_weights(name=name, prop='scale')
        else:
            lambda_r_scale = getattr(self, attr_name)
        # Explicitly clamp the scale!
        lambda_r_scale = torch.clamp(lambda_r_scale, self.minimum_spatial_scale, 999.)
        nor = CNormal(loc=lambda_r, scale=lambda_r_scale)
        if d['return_sampler']:
            return nor
        elif name == 'object_location':
            if not len(object_locations):
                return nor.rsample(), nor
            else:
                parameters = self.rejection_sampling(
                    object_margin=object_margin,
                    margin_offset=margin_offset,
                    object_locations=object_locations,
                    max_rejections=max_rejections,
                    num_objects=num_objects,
                    gau=nor)
        else:
            for idx in range(batch_size):
                parameters.append(nor.rsample())
    elif family == 'abs_normal':
        attr_name = '{}_{}'.format(name, 'scale')
        if self.wn:
            lambda_r_scale = self.normalize_weights(name=name, prop='scale')
        else:
            lambda_r_scale = getattr(self, attr_name)
        # Rectify only the scale; applying torch.abs to the loc kills gradients.
        lambda_r_scale = torch.abs(lambda_r_scale)
        nor = Normal(loc=lambda_r, scale=lambda_r_scale)
        if d['return_sampler']:
            return nor
        else:
            parameters = nor.rsample([batch_size])
    elif family == 'half_normal':
        attr_name = '{}_{}'.format(name, 'scale')
        if self.wn:
            lambda_r_scale = self.normalize_weights(name=name, prop='scale')
        else:
            lambda_r_scale = getattr(self, attr_name)
        nor = HalfNormal(scale=lambda_r_scale)
        if d['return_sampler']:
            return nor
        else:
            parameters = nor.rsample([batch_size])
    elif family == 'categorical':
        if d['return_sampler']:
            gum = RelaxedOneHotCategorical(1e-1, logits=lambda_r)
            return gum
        for _ in range(batch_size):
            # Use default temperature -> max
            parameters.append(self.argmax(self.gumbel_fun(lambda_r, name=name)) + offset)  # noqa
    elif family == 'relaxed_bernoulli':
        bern = RelaxedBernoulli(temperature=1e-1, logits=lambda_r)
        if d['return_sampler']:
            return bern
        else:
            parameters = bern.rsample([batch_size])
    else:
        raise NotImplementedError('{} not implemented in sampling.'.format(family))
    return parameters
xs = np.linspace(.001, .999, 30)
# dist = RelaxedBernoulli(temperature=torch.Tensor([1.]), logits=torch.tensor([0.]))
dist = RelaxedBernoulli(temperature=torch.Tensor([.2]), probs=torch.tensor([0.3]))
ys = []
samps = []
for x in xs:
    prob = torch.exp(dist.log_prob(torch.tensor([x]))).numpy()[0]
class TimeLatent(object):
    _logger = logging.getLogger(__name__)

    def __init__(self, num_X, max_lag, num_samples, device, prior_rho_A,
                 prior_sigma_W, temperature, sigma_Z, sigma_X):
        self.num_X = num_X
        self.max_lag = max_lag
        self.num_samples = num_samples
        self.device = device
        self.prior_rho_A = prior_rho_A
        self.temperature = temperature
        self.prior_sigma_W = prior_sigma_W
        self.prior_sigma_Z = sigma_Z * torch.ones(size=[num_X], device=device)
        self.likelihood_sigma_X = sigma_X * torch.ones(size=[num_X], device=device)
        self.posterior_sigma_Z = sigma_Z * torch.ones(size=[num_X], device=device)
        self.init_prior()
        self.init_posterior()
        self._logger.debug('Finished building model')

    def init_prior(self):
        # Set up priors
        K = self.max_lag
        m = self.num_X
        rho = self.prior_rho_A
        sigma = self.prior_sigma_W
        temperature = torch.tensor([self.temperature], device=self.device)

        prior_A = rho * torch.ones(size=(K, 2 * m, 2 * m), device=self.device)
        prior_A[:, m:, :] = torch.tensor([0], device=self.device)
        # Set the diagonal
        for i in range(m):
            prior_A[:, m + i, m + i] = torch.tensor([rho], device=self.device)
        self.prior_A = RelaxedBernoulli(temperature=temperature, probs=prior_A)

        prior_W_scale = sigma * torch.ones(size=(K, 2 * m, 2 * m), device=self.device)
        prior_W_scale[:, m:, :] = torch.tensor([0], device=self.device)
        # Set the diagonal
        for i in range(m):
            prior_W_scale[:, m + i, m + i] = torch.tensor([sigma], device=self.device)
        self.prior_W = Normal(loc=torch.zeros_like(prior_W_scale, device=self.device),
                              scale=prior_W_scale)
        # The prior over Z depends on the samples of A and W, so Z is not set up here.

    def init_posterior(self):
        # Set up posteriors
        K = self.max_lag
        m = self.num_X
        rho = self.prior_rho_A
        sigma = self.prior_sigma_W
        temperature = torch.tensor([self.temperature], device=self.device)

        estimate_A = rho * torch.rand(size=(K, 2 * m, 2 * m), device=self.device)
        estimate_A[:, m:, :] = torch.tensor([0], device=self.device)
        # Set the diagonal
        for i in range(m):
            estimate_A[:, m + i, m + i] = torch.tensor([rho], device=self.device)
        estimate_A = estimate_A.requires_grad_(True)
        self.posterior_A = RelaxedBernoulli(temperature=temperature, probs=estimate_A)

        estimate_W_scale = sigma * torch.ones(size=(K, 2 * m, 2 * m), device=self.device)
        estimate_W_scale[:, m:, :] = torch.tensor([0], device=self.device)
        # Set the diagonal
        for i in range(m):
            estimate_W_scale[:, m + i, m + i] = torch.tensor([sigma], device=self.device)
        estimate_W_scale = estimate_W_scale.requires_grad_(True)

        estimate_W_loc = torch.rand(size=(K, 2 * m, 2 * m), device=self.device)
        estimate_W_loc[:, m:, :] = torch.tensor([0], device=self.device)
        estimate_W_loc = estimate_W_loc.requires_grad_(True)
        self.posterior_W = Normal(loc=estimate_W_loc, scale=estimate_W_scale)

    def ln_p_AWZ(self, A, W, Z):
        K = self.max_lag
        m = self.num_X
        ln_p_A = self.prior_A.log_prob(A)[:, :m, :].sum() + sum([
            torch.diagonal((self.prior_A.log_prob(A)[i, m:, m:]), 0)
            for i in range(K)
        ]).sum()
        ln_p_W = self.prior_W.log_prob(W)[:, :m, :].sum() + sum([
            torch.diagonal((self.prior_W.log_prob(W)[i, m:, m:]), 0)
            for i in range(K)
        ]).sum()
        # distributions p(Z(1)), p(Z(2)), ..., p(Z(T))
        sigma_Z = self.prior_sigma_Z
        p_Z1 = Normal(loc=torch.zeros_like(sigma_Z, device=self.device), scale=sigma_Z)
        ln_p_Z1 = p_Z1.log_prob(Z[0])
        ln_p_ZK = torch.zeros(size=[m], device=self.device)
        for t in range(2, K + 1):
            A_22 = A[:t - 1, m:, m:]
            W_22 = W[:t - 1, m:, m:]
            mean_t = []
            for i in range(1, t):
                A_22_i = torch.diagonal(A_22[i - 1])
                W_22_i = torch.diagonal(W_22[i - 1])
                mean_t.append(Z[t - 1 - i] * A_22_i * W_22_i)
            p_Zt = Normal(loc=sum(mean_t), scale=sigma_Z)
            ln_p_ZK += p_Zt.log_prob(Z[t - 1])
        ln_p_ZT = torch.zeros(size=[m], device=self.device)
        T = self.num_samples
        A_22 = A[:, m:, m:]
        W_22 = W[:, m:, m:]
        mean_t = []
        for i in range(1, K + 1):
            A_22_i = torch.diagonal(A_22[i - 1])
            W_22_i = torch.diagonal(W_22[i - 1])
            mean_t.append(Z[K - i:T - i] * A_22_i * W_22_i)
        p_Zt = Normal(loc=sum(mean_t), scale=sigma_Z)
        ln_p_ZT = p_Zt.log_prob(Z[K:]).sum(0)
        return (ln_p_ZT.sum() + ln_p_ZK.sum() + ln_p_Z1.sum()) + ln_p_A + ln_p_W

    def ln_p_X_AWZ(self, X, A, W, Z):
        sigma = self.likelihood_sigma_X
        T = self.num_samples
        K = self.max_lag
        m = self.num_X
        Sum_X_mu = torch.tensor([0.0], device=self.device)
        # t = 1
        Sum_X_mu += (X[0] ** 2).sum()
        # 2 <= t <= K
        for t in range(2, K + 1):
            A_11 = A[:, :m, :m]
            W_11 = W[:, :m, :m]
            A_12 = A[:, :m, m:]
            W_12 = W[:, :m, m:]
            mu = torch.zeros(size=[m], device=self.device)
            for i in range(1, t):
                A_11_i = A_11[i - 1]
                W_11_i = W_11[i - 1]
                A_12_i = A_12[i - 1]
                W_12_i = W_12[i - 1]
                mu += torch.matmul(X[t - 1 - i], (A_11_i * W_11_i).t()) + \
                    torch.matmul(Z[t - 1 - i], (A_12_i * W_12_i).t())
            Sum_X_mu += ((X[t - 1] - mu) ** 2).sum()
        # K+1 <= t <= T
        A_11 = A[:, :m, :m]
        W_11 = W[:, :m, :m]
        A_12 = A[:, :m, m:]
        W_12 = W[:, :m, m:]
        mu = []
        for i in range(1, K + 1):
            A_11_i = A_11[i - 1]
            W_11_i = W_11[i - 1]
            A_12_i = A_12[i - 1]
            W_12_i = W_12[i - 1]
            mu.append(torch.matmul(X[K - i:T - i], (A_11_i * W_11_i).t()) +
                      torch.matmul(Z[K - i:T - i], (A_12_i * W_12_i).t()))
        Sum_X_mu += ((X[K:T + 1] - sum(mu)) ** 2).sum()
        return -T / 2 * torch.log(2 * torch.tensor([np.math.pi], device=self.device)) \
            - T / 2 * torch.log(sigma * sigma).sum() \
            - 1 / (2 * (sigma * sigma).sum()) * Sum_X_mu

    def sample_Z(self, A, W):
        # Sample Z from q(Z).
        m = self.num_X
        sigma_Z = self.posterior_sigma_Z
        # sample Z(1)
        q_Z1 = Normal(loc=torch.zeros_like(sigma_Z, device=self.device), scale=sigma_Z)
        Z1 = q_Z1.rsample()
        ln_q_Z = torch.tensor([q_Z1.log_prob(Z1).sum()], device=self.device)
        # store the samples Z(1:T)
        T = self.num_samples
        Z = []
        Z.append(Z1)
        # sample Z(2:K)
        K = self.max_lag
        for t in range(2, K + 1):
            A_22 = A[:t - 1, m:, m:]
            W_22 = W[:t - 1, m:, m:]
            mean_t = []
            for i in range(1, t):
                A_22_i = torch.diagonal(A_22[i - 1])
                W_22_i = torch.diagonal(W_22[i - 1])
                mean_t.append(Z[t - 1 - i] * A_22_i * W_22_i)
            q_Zt = Normal(loc=sum(mean_t), scale=sigma_Z)
            Z_t = q_Zt.rsample()
            ln_q_Z += q_Zt.log_prob(Z_t).sum()
            # Normalize Z_t, otherwise it grows too large and leads to large Z(t), even NaN.
            Z_t = F.normalize(Z_t, dim=0)
            Z.append(Z_t)
        # sample Z(K+1:T)
        for t in range(K + 1, T + 1):
            A_22 = A[:, m:, m:]
            W_22 = W[:, m:, m:]
            mean_t = []
            for i in range(1, K + 1):
                A_22_i = torch.diagonal(A_22[i - 1])
                W_22_i = torch.diagonal(W_22[i - 1])
                mean_t.append(Z[t - 1 - i] * A_22_i * W_22_i)
            q_Zt = Normal(loc=sum(mean_t), scale=sigma_Z)
            Z_t = q_Zt.rsample()
            ln_q_Z += q_Zt.log_prob(Z_t).sum()
            # Normalize Z_t, otherwise it grows too large and leads to large Z(t), even NaN.
            Z_t = F.normalize(Z_t, dim=0)
            Z.append(Z_t)
        return torch.stack(Z, 0), ln_q_Z

    def ln_q_Z(self, Z):
        loss = torch.tensor([0.0], device=self.device)
        T = self.num_samples
        for i in range(T):
            Zt = Z[i]
            log_prob = self.q_Z[i].log_prob(Zt)
            loss += log_prob.sum()
        return loss

    def loss(self, X):
        """Return the negative ELBO."""
        # sample
        A = self.posterior_A.rsample()
        W = self.posterior_W.rsample()
        # ln_q_Z is computed while sampling Z, so the q_Z distributions need not be
        # stored, which reduces the running memory.
        Z, ln_q_Z = self.sample_Z(A, W)
        # Because we assume that X does not cause Z and that the Z_i are mutually
        # independent, A_22 is always a diagonal matrix and A_21 is always a zero matrix.
        m = self.num_X
        K = self.max_lag
        ln_q_A = self.posterior_A.log_prob(A)[:, :m, :].sum() + sum([
            torch.diagonal((self.posterior_A.log_prob(A)[i, m:, m:]), 0)
            for i in range(K)
        ]).sum()
        ln_q_W = self.posterior_W.log_prob(W)[:, :m, :].sum() + sum([
            torch.diagonal((self.posterior_W.log_prob(W)[i, m:, m:]), 0)
            for i in range(K)
        ]).sum()
        ln_q_AWZ = ln_q_Z + ln_q_A + ln_q_W
        # Calculate L_kl
        L_kl = -(ln_q_AWZ - self.ln_p_AWZ(A, W, Z))
        # Calculate L_ell
        L_ell = self.ln_p_X_AWZ(X, A, W, Z)
        ELBO = L_kl + L_ell
        loss = -ELBO
        return loss

    @property
    def logger(self):
        try:
            return self._logger
        except AttributeError:
            raise NotImplementedError('self._logger does not exist!')
def forward(self, h_t, h_s, auxi_hs, memory_lengths=None, coverage=None):
    """
    Args:
        auxi_hs (FloatTensor): auxiliary output vectors ``(batch, src_len, dim)``
        h_t (FloatTensor): query vectors ``(batch, tgt_len, dim)``
        h_s (FloatTensor): source vectors ``(batch, src_len, dim)``
        memory_lengths (LongTensor): the source context lengths ``(batch,)``
        coverage (FloatTensor): None (not supported yet)

    Returns:
        (FloatTensor, FloatTensor):
        * Computed vector ``(tgt_len, batch, dim)``
        * Attention distributions for each query ``(tgt_len, batch, src_len)``
    """
    # one step input
    if h_t.dim() == 2:
        one_step = True
        h_t = h_t.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, dim = h_s.size()
    batch_, target_l, dim_ = h_t.size()
    aeq(batch, batch_)  # Assert all the arguments have the same value.
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        h_s += self.linear_cover(cover).view_as(h_s)
        h_s = torch.tanh(h_s)

    # Main modification:
    # Expand the target hidden state to concatenate with the source hidden state.
    H_t = h_t.expand(-1, source_l, -1)
    concat_h = torch.cat([auxi_hs, H_t], 2).view(batch * source_l, dim * 2)  # (batch*source_l, dim*2)
    new_concat_h = self.linear_map(concat_h).view(batch, source_l, 1)
    # Calculate the probability from the output of the auxiliary network.
    p = self.sigmoid(new_concat_h)  # (batch, source_l, 1)
    # Sample the gate, which follows a Bernoulli distribution with probability p.
    # For training (the first argument is the temperature hyperparameter):
    G = RelaxedBernoulli(torch.tensor([1]).cuda(), p).sample()  # (batch, source_l, 1)
    # For testing:
    # G = Bernoulli(p)
    # e = MLP(h_s) to get the information of the source hidden state.
    e = self.mlp_h(h_s)
    e = e.view(batch, source_l, 1)
    # Calculate the alignment score.
    align_score = (self.softmax(e, G)).transpose(1, 2)  # align_score -- (batch, 1, source)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align_score.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        # align_score.masked_fill_(~mask, -float('inf'))  # the original; it may cause nan.
        align_score.masked_fill_(~mask, -100000)

    align_vectors = align_score / torch.sum(align_score, 2, keepdim=True)  # alpha -- (batch, 1, source_l)
    print(align_vectors, file=filename)

    # Calculate the context vectors.
    c = torch.bmm(align_vectors, h_s)  # context_vec -- (batch, 1, dim)
    # Concatenate context vectors with the current target hidden state.
    concat_c = torch.cat([c, h_t], 2).view(batch * target_l, dim * 2)
    # Get the final output hidden state.
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    return attn_h, align_vectors
zeros[np.arange(len(zeros)), samples] = 1.
samples = np.sum(zeros, axis=0) / n_samples

ax = plt.subplot2grid((rows, cols), (cur_row, 0), frameon=False, colspan=2)
ax.bar(['0', '1', '2', '3'], samples)
ax.text(-.5, .5, s=r'Samples Hard Gumbel', fontsize=10, family='serif')
cur_row += 1

# Plot Gumbel pdf
ax = plt.subplot2grid((rows, cols), (cur_row, 0), frameon=False, colspan=2)
dist = RelaxedBernoulli(probs=torch.Tensor([0.5]), temperature=torch.Tensor([0.5]))
x = [[.01], [.1], [.3], [.5], [.7], [.9], [.99]]
x_len = len(x)
logprob = dist.log_prob(torch.Tensor(x))
logprob = np.reshape(logprob.numpy(), [x_len])
x = np.reshape(np.array(x), [x_len])
plt.plot(x, logprob)
cur_row += 1
def forward(self, mask_mode, feature_extraction_fn=None):
    assert isinstance(mask_mode, MaskMode)
    if feature_extraction_fn is not None:
        assert callable(feature_extraction_fn)
        output = feature_extraction_fn()
    pairwise_dist = output.pairwise_dist.detach()
    classwise_loss = output.classwise_loss.detach()
    classwise_acc = output.classwise_acc.detach()
    n_classes = output.n_classes

    # attention-based pairwise distance information
    atten = self.softmax(self.pairwse_attention(pairwise_dist))
    p = self.tanh((pairwise_dist * atten).sum(dim=1))  # [n_classes, hidden_dim]

    # loss
    loss_class = classwise_loss / np.log(n_classes)
    loss_mean = loss_class.mean()
    loss_rel = (loss_class - loss_mean) / loss_class.std()  # [n_classes]
    # acc
    acc_class = classwise_acc
    acc_mean = classwise_acc.mean()
    acc_rel = (acc_class - acc_mean) / acc_class.std()  # [n_classes]

    # running mean of loss & acc
    loss_mean = loss_mean.repeat(len(self.m))
    acc_mean = acc_mean.repeat(len(self.m))
    if self.loss_mean is None:
        self.loss_mean = loss_mean  # [n_momentum]
        self.acc_mean = acc_mean  # [n_momentum]
    else:
        self.loss_mean = self.m * self.loss_mean + (1 - self.m) * loss_mean
        self.acc_mean = self.m * self.acc_mean + (1 - self.m) * acc_mean

    # time encoding
    self.t += 1
    time = C(torch.tensor(self.t_encoder(self.t)))

    # shared feature encoding
    # log_loss_mean: take the log to suppress very large losses,
    # adding 1 first to avoid negative infinity
    log_loss_mean = (self.loss_mean + 1).log()
    s = torch.cat([log_loss_mean, self.acc_mean, time], dim=0).detach()
    s = self.tanh(self.shared_encoder(s)).repeat(n_classes, 1)  # s: [n_classes, hidden_dim]

    # relative feature encoding
    r = torch.stack([loss_rel, acc_rel], dim=1).detach()
    r = self.tanh(self.relative_encoder(r))  # r: [n_classes, hidden_dim]

    # mask generation (binary classification)
    h = torch.cat([p, s + r], dim=1)
    mask_logits = self.mask_generator(h)

    # learning rate generation
    h_mean = self.tanh(h.mean(dim=0))
    lr_logits = self.lr_generator(h_mean)
    lr = F.softplus(lr_logits) * 0.1

    # soft mask
    if mask_mode.dist is MaskDist.SOFT:
        mask = self.sigmoid(mask_logits[:, 0])  # TODO: logit[:, 1] is redundant
    elif mask_mode.dist is MaskDist.RL:
        mask = self.softmax(mask_logits)  # for the RL case
    # discrete mask
    elif mask_mode.dist is MaskDist.DISCRETE:
        mask = lr_logits[:, 0].max(dim=1)[1]  # TODO: logit[:, 1] is redundant
    # concrete mask
    elif mask_mode.dist is MaskDist.CONCRETE:
        # infer Bernoulli parameter
        mean = self.sigmoid(mask_logits[:, 0])
        sigma = F.softplus(mask_logits[:, 1]) * 0.1
        eps = torch.randn(mean.size()).to(mean.device)
        # continuously relaxed Bernoulli
        probs = mean + (sigma * eps)
        temp = torch.tensor([0.1]).to(mean.device)
        mask = RelaxedBernoulli(temp, probs=probs)
        # mask = mask.rsample()
        if torch.isnan(mask.rsample()).sum() > 0:
            import pdb
            pdb.set_trace()

    return mask, lr
print()
print('REINFORCE H(z)')
print('Value:', val)
print()

# net = NN()
optim = torch.optim.Adam([bern_param], lr=.004)
# optim_NN = torch.optim.Adam([net.parameters()], lr=.0004)

steps = []
losses4 = []
for step in range(total_steps):
    dist = RelaxedBernoulli(torch.Tensor([1.]), logits=bern_param)
    optim.zero_grad()
    zs = []
    for i in range(20):
        z = dist.rsample()
        zs.append(z)
    zs = torch.FloatTensor(zs).unsqueeze(1)
    logprob = dist.log_prob(zs.detach())
    # logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]
    H_z = Hpy(zs)
samples = np.sum(zeros, axis=0) / n_samples

ax = plt.subplot2grid((rows, cols), (cur_row, 0), frameon=False, colspan=2)
ax.bar(['0', '1', '2', '3'], samples)
ax.text(-.5, .5, s=r'Samples Hard Gumbel', fontsize=10, family='serif')
cur_row += 1

# Plot Gumbel pdf
ax = plt.subplot2grid((rows, cols), (cur_row, 0), frameon=False, colspan=2)
dist = RelaxedBernoulli(probs=torch.Tensor([0.5]), temperature=torch.Tensor([0.5]))
x = [[.01], [.1], [.3], [.5], [.7], [.9], [.99]]
x_len = len(x)
logprob = dist.log_prob(torch.Tensor(x))
logprob = np.reshape(logprob.numpy(), [x_len])
x = np.reshape(np.array(x), [x_len])
plt.plot(x, logprob)
cur_row += 1