def sample_simplax(probs):
    dist = RelaxedOneHotCategorical(probs=probs, temperature=torch.Tensor([1.]))
    z = dist.rsample()
    logprob = dist.log_prob(z)
    b = torch.argmax(z, dim=1)
    return z, b, logprob

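# Quick usage sketch for sample_simplax() above (not from the original source;
# the batch of random probs and the shapes are illustrative, only torch is assumed):
import torch
from torch.distributions import RelaxedOneHotCategorical

probs = torch.softmax(torch.randn(4, 3), dim=1)  # batch of 4 distributions over 3 categories
z, b, logprob = sample_simplax(probs)
print(z.shape)        # [4, 3] relaxed (soft) one-hot samples
print(b.shape)        # [4] hard argmax indices
print(logprob.shape)  # [4] log-density of the relaxed samples
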
def get_loss():
    x = sample_true(batch_size).cuda()  # .view(1,1)
    logits = encoder.net(x)
    probs = torch.softmax(logits / 100., dim=1)
    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
    cluster_S = cat.rsample()
    cluster_H = H(cluster_S)
    # cluster_onehot = torch.zeros(n_components)
    # cluster_onehot[cluster_H] = 1.
    logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size, 1)
    check_nan(logprob_cluster)
    logpxz = logprob_undercomponent(
        x, component=cluster_H,
        needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
    f = logpxz  # - logprob_cluster
    surr_input = torch.cat([cluster_S, x], dim=1)  # [B,21]
    surr_pred = surrogate.net(surr_input)
    # net_loss = - torch.mean((f.detach() - surr_pred.detach()) * logprob_cluster + surr_pred)
    # loss = - torch.mean(f)
    surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))
    return surr_loss

def show_surr_preds():
    batch_size = 1
    rows = 3
    cols = 1
    fig = plt.figure(figsize=(10 + cols, 4 + rows), facecolor='white')  # , dpi=150)
    for i in range(rows):
        x = sample_true(1).cuda()  # .view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits / 100., dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size, 1)
        check_nan(logprob_cluster)

        z = cluster_S
        n_evals = 40
        x1 = np.linspace(-9, 205, n_evals)
        x = torch.from_numpy(x1).view(n_evals, 1).float().cuda()
        z = z.repeat(n_evals, 1)
        cluster_H = cluster_H.repeat(n_evals, 1)
        xz = torch.cat([z, x], dim=1)

        logpxz = logprob_undercomponent(
            x, component=cluster_H,
            needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz  # - logprob_cluster

        surr_pred = surrogate.net(xz)
        surr_pred = surr_pred.data.cpu().numpy()
        f = f.data.cpu().numpy()

        col = 0
        row = i
        # print (row)
        ax = plt.subplot2grid((rows, cols), (row, col), frameon=False, colspan=1, rowspan=1)
        ax.plot(x1, surr_pred, label='Surr')
        ax.plot(x1, f, label='f')
        ax.set_title(str(cluster_H[0]))
        ax.legend()

    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    plt_path = exp_dir + 'gmm_surr.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()

def simplax(surrogate, x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())

    outputs = {}
    net_loss = 0
    surr_loss = 0
    for jj in range(k):
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logq = cat.log_prob(cluster_S.detach()).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  # [B,1]
        f = logpxz - logq - 1.

        surr_input = torch.cat([cluster_S, x, logits], dim=1)  # [B,21]
        surr_pred = surrogate.net(surr_input)

        net_loss += - torch.mean((f.detach() - surr_pred.detach()) * logq + surr_pred)

        # surr_loss += torch.mean(torch.abs(f.detach()-1.-surr_pred))
        # grad_logq = torch.mean(torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # grad_surr = torch.mean(torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq = torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_surr = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss = torch.mean(((f.detach() - surr_pred) * grad_logq + grad_surr)**2)

        surr_dif = torch.mean(torch.abs(f.detach() - surr_pred))
        # surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))

        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f.detach() - surr_pred.detach()) * logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))

    net_loss = net_loss / k
    surr_loss = surr_loss / k

    outputs['net_loss'] = net_loss
    outputs['f'] = f
    outputs['logpx_given_z'] = logpx_given_z
    outputs['logpz'] = logpz
    outputs['logq'] = logq
    outputs['surr_loss'] = surr_loss
    outputs['surr_dif'] = surr_dif
    outputs['grad_path'] = grad_path
    outputs['grad_score'] = grad_score
    return outputs  # net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif, grad_path, grad_score

def forward(self, x):
    x = self.featCompressor(x)
    x = self.fc1(x)
    x = self.fc2(x)
    logits = self.fc3(x)
    B, L, K = logits.shape
    RelaxedOneHotSampler = RelaxedOneHotCategorical(float(self.temper), logits=logits)
    y = RelaxedOneHotSampler.rsample()
    return y, F.softmax(logits, dim=-1), logits

def simplax(surrogate, x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())

    net_loss = 0
    surr_loss = 0
    for jj in range(k):
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logq = cat.log_prob(cluster_S.detach()).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  # [B,1]
        f = logpxz - logq - 1.

        surr_input = torch.cat([cluster_S, x], dim=1)  # [B,21]
        surr_pred = surrogate.net(surr_input)

        net_loss += -torch.mean((f.detach() - surr_pred.detach()) * logq + surr_pred)

        # surr_loss += torch.mean(torch.abs(f.detach()-1.-surr_pred))
        # grad_logq = torch.mean(torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # grad_surr = torch.mean(torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq = torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_surr = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss += torch.mean(((f.detach() - surr_pred) * grad_logq + grad_surr)**2)

        surr_dif = torch.mean(torch.abs(f.detach() - surr_pred))

        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f.detach() - surr_pred.detach()) * logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))

    net_loss = net_loss / k
    surr_loss = surr_loss / k

    return net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif, grad_path, grad_score

def gumbel_softmax_dist(self, param, name, temperature=1e-1, hard=True, sample_size=()):
    """ST gumbel with pytorch distributions."""
    gumbel = RelaxedOneHotCategorical(temperature, logits=param)
    y = gumbel.rsample(sample_size)
    if hard:
        # One-hot the y
        y = self.st_op(y)
    return y

def show_surr_preds():
    batch_size = 1
    rows = 3
    cols = 1
    fig = plt.figure(figsize=(10 + cols, 4 + rows), facecolor='white')  # , dpi=150)
    for i in range(rows):
        x = sample_true(1).cuda()  # .view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size, 1)
        check_nan(logprob_cluster)

        z = cluster_S
        n_evals = 40
        x1 = np.linspace(-9, 205, n_evals)
        x = torch.from_numpy(x1).view(n_evals, 1).float().cuda()
        z = z.repeat(n_evals, 1)
        cluster_H = cluster_H.repeat(n_evals, 1)
        xz = torch.cat([z, x], dim=1)

        logpxz = logprob_undercomponent(
            x, component=cluster_H,
            needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz - logprob_cluster

        surr_pred = surrogate.net(xz)
        surr_pred = surr_pred.data.cpu().numpy()
        f = f.data.cpu().numpy()

        col = 0
        row = i
        # print (row)
        ax = plt.subplot2grid((rows, cols), (row, col), frameon=False, colspan=1, rowspan=1)
        ax.plot(x1, surr_pred, label='Surr')
        ax.plot(x1, f, label='f')
        ax.set_title(str(cluster_H[0]))
        ax.legend()

    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    plt_path = exp_dir + 'gmm_surr.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()

def reparameterize(self, p_i, tau, k, num_sample=1):
    ## sampling
    p_i_ = p_i.view(p_i.size(0), 1, 1, -1)
    p_i_ = p_i_.expand(p_i_.size(0), num_sample, k, p_i_.size(-1))
    C_dist = RelaxedOneHotCategorical(tau, p_i_)
    V = torch.max(C_dist.sample(), -2)[0]  # [batch-size, multi-shot, d]

    ## without sampling
    V_fixed_size = p_i.unsqueeze(1).size()
    _, V_fixed_idx = p_i.unsqueeze(1).topk(k, dim=-1)  # batch * 1 * k
    V_fixed = idxtobool(V_fixed_idx, V_fixed_size, is_cuda=self.args.cuda)
    V_fixed = V_fixed.type(torch.float)

    return V, V_fixed

def _gumbel_softmax(probs, tau: float, hard: bool):
    """Sample from the Gumbel-Softmax (GS) distribution.

    Args:
        probs (torch.Tensor): probabilities of shape [batch_size, n_classes]
        tau (float): temperature parameter for the GS
        hard (bool): discretize (straight-through) if True
    """
    rohc = RelaxedOneHotCategorical(tau, probs)
    y = rohc.rsample()

    if hard:
        y_hard = torch.zeros_like(y)
        y_hard.scatter_(-1, torch.argmax(y, dim=-1, keepdim=True), 1.0)
        y = (y_hard - y).detach() + y
    return y

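# Hypothetical usage of _gumbel_softmax() above (names, shapes and the toy loss
# are illustrative, not from the original source; assumes only torch and the helper itself):
import torch

logits = torch.randn(2, 5, requires_grad=True)
probs = torch.softmax(logits, dim=-1)

soft = _gumbel_softmax(probs, tau=0.5, hard=False)  # relaxed sample on the simplex
hard = _gumbel_softmax(probs, tau=0.5, hard=True)   # one-hot forward, relaxed backward (straight-through)

loss = (hard * torch.arange(5.0)).sum()  # toy downstream objective
loss.backward()                          # gradients reach `logits` through the relaxed sample
print(soft.sum(dim=-1))   # each row sums to ~1
print(hard)               # exact one-hot rows
print(logits.grad.shape)  # torch.Size([2, 5])
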
def forward(self, inputs, lengths, temp=None, y=None):
    enc_emb = self.lookup(inputs)
    dec_emb = self.lookup(inputs)
    hn = self.encoder(enc_emb, lengths)
    py = self.classifier(hn)
    if y is None:
        dist = RelaxedOneHotCategorical(temp, logits=py)
        y = dist.sample().max(1)[1]
    y_emb = self.y_lookup(y)
    h = torch.cat([hn, y_emb.unsqueeze(0)], dim=2)
    mu, logvar = self.fcmu(h), self.fclogvar(h)
    if self.training:
        z = self.reparameterize(mu, logvar)
    else:
        z = mu
    code = torch.cat([z, y_emb.unsqueeze(0)], dim=2)
    outputs, _ = self.decoder(dec_emb, code, lengths=lengths)
    outputs = self.fcout(outputs)
    bow = self.bow_predictor(code)
    return outputs, mu, logvar, bow, py

def forward(self, y, c, m, useGumbel=True, temp=0.5):
    # TODO: add m to all things calling this.
    # y is the one-hot input (goal is to predict next output)
    # c is one-hot representing context
    # useGumbel=True, then uses Gumbel instead of softmax

    # embed in N-dim
    x_context = torch.matmul(self.Wfix_context, c)
    x_state = torch.matmul(self.Wfix, y)

    # concatenate token and context
    x = torch.cat([x_context, x_state, m])  # TODO: confirm works

    # hidden node
    b = torch.tanh(self.fc1(x))

    # ---- STATE
    if useGumbel:
        # will not get gradients for fc3 if do this.
        h = torch.tanh(self.fc2(b))  # linear layer + nonlinearity
        z = Gumbel(torch.tensor([0.0]), torch.tensor([1.0])).sample(torch.Size((h.shape[0],)))  # add gumbel noise
        yind = (h.view((-1,)) + z.view((-1,))).argmax()  # take argmax
        yout = self.idx_to_onehot(yind, self.outdim)  # convert to onehot

        # ---- CONTEXT
        h_context = torch.tanh(self.fc3(b))
        z_context = Gumbel(torch.tensor([0.0]), torch.tensor([1.0])).sample(torch.Size((h_context.shape[0],)))  # add gumbel noise
        c_ind = (h_context.view((-1,)) + z_context.view((-1,))).argmax()  # take argmax
        c_out = self.idx_to_onehot(c_ind, self.Kc)
    else:
        h = torch.tanh(self.fc2(b))  # linear layer + nonlinearity
        yout = RelaxedOneHotCategorical(temp, logits=h).rsample()
        # yout = self.idx_to_onehot(yind, self.outdim)  # convert to onehot
        yind = []

        # ---- CONTEXT
        h_context = torch.tanh(self.fc3(b))
        c_out = RelaxedOneHotCategorical(temp, logits=h_context).rsample()
        c_ind = []

    m = b  # rename as m
    return yout, h, yind, c_out, h_context, c_ind, m

def reinforce_pz(x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())

    net_loss = 0
    for jj in range(k):
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logq = cat.log_prob(cluster_S.detach()).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  # [B,1]
        f = logpxz - logq - 1.

        net_loss += - torch.mean((f.detach()) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

    net_loss = net_loss / k

    return net_loss, f, logpx_given_z, logpz, logq

def gumbel_softmax(self, logits, training, tau=1.0, msg_hard=None):
    device = torch.device("cuda" if logits.is_cuda else "cpu")

    if training:
        # Here, a Gumbel sample is taken:
        msg_dists = RelaxedOneHotCategorical(tau, logits=logits)
        msg = msg_dists.rsample()

        if msg_hard is None:
            msg_hard = torch.zeros_like(msg, device=device)
            msg_hard.scatter_(-1, torch.argmax(msg, dim=-1, keepdim=True), 1.0)

        # detach() detaches the output from the computation graph,
        # so no gradient will be backprop'ed along this variable
        msg = (msg_hard - msg).detach() + msg
    else:
        if msg_hard is None:
            msg = torch.zeros_like(logits, device=device)
            msg.scatter_(-1, torch.argmax(logits, dim=-1, keepdim=True), 1.0)
        else:
            msg = msg_hard

    return msg

def get_loss():
    x = sample_true(batch_size).cuda()  # .view(1,1)
    logits = encoder.net(x)
    probs = torch.softmax(logits, dim=1)
    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
    cluster_S = cat.rsample()
    cluster_H = H(cluster_S)
    # cluster_onehot = torch.zeros(n_components)
    # cluster_onehot[cluster_H] = 1.
    logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size, 1)
    check_nan(logprob_cluster)
    logpxz = logprob_undercomponent(
        x, component=cluster_H,
        needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
    f = logpxz - logprob_cluster
    surr_input = torch.cat([cluster_S, x], dim=1)  # [B,21]
    surr_pred = surrogate.net(surr_input)
    # net_loss = - torch.mean((f.detach() - surr_pred.detach()) * logprob_cluster + surr_pred)
    # loss = - torch.mean(f)
    surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))
    return surr_loss

def cat_softmax(probs, mode, tau=1, hard=False, dim=-1):
    if mode == 'REINFORCE' or mode == 'SCST':
        cat_distr = OneHotCategorical(probs=probs)
        return cat_distr.sample(), cat_distr.entropy()
    elif mode == 'GUMBEL':
        cat_distr = RelaxedOneHotCategorical(tau, probs=probs)
        y_soft = cat_distr.rsample()

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(probs, device=DEVICE).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret, ret

def decoding_sampler(logits, mode, tau=1, hard=False, dim=-1):
    if mode == 'REINFORCE' or mode == 'SCST':
        cat_distr = OneHotCategorical(logits=logits)
        return cat_distr.sample()
    elif mode == 'GUMBEL':
        cat_distr = RelaxedOneHotCategorical(tau, logits=logits)
        y_soft = cat_distr.rsample()
    elif mode == 'SOFTMAX':
        y_soft = F.softmax(logits, dim=1)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, device=args.device).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret

def reparameterize(self, p_pep, p_tcr, tau, k, num_sample):
    # sampling
    batch_size = p_pep.size(0)
    len_pep = p_pep.size(1)  # batch_size * len_pep
    len_tcr = p_tcr.size(1)  # batch_size * len_tcr
    p_pep_ = p_pep.view(batch_size, 1, 1, len_pep).expand(batch_size, num_sample, k, len_pep)
    p_tcr_ = p_tcr.view(batch_size, 1, 1, len_tcr).expand(batch_size, num_sample, k, len_tcr)
    C_pep = RelaxedOneHotCategorical(tau, p_pep_)
    C_tcr = RelaxedOneHotCategorical(tau, p_tcr_)
    Z_pep, _ = torch.max(C_pep.sample(), -2)  # batch_size, num_sample, len_pep
    Z_tcr, _ = torch.max(C_tcr.sample(), -2)  # batch_size, num_sample, len_tcr

    # without sampling
    _, Z_fixed_pep = p_pep.topk(k, dim=-1)  # batch_size, k
    _, Z_fixed_tcr = p_tcr.topk(k, dim=-1)  # batch_size, k
    size_pep = p_pep.size()
    size_tcr = p_tcr.size()
    Z_fixed_pep = idxtobool(Z_fixed_pep, size_pep, self.cuda)
    Z_fixed_tcr = idxtobool(Z_fixed_tcr, size_tcr, self.cuda)

    return Z_pep, Z_tcr, Z_fixed_pep, Z_fixed_tcr

def sample_lambda0_r(self, d, batch_size, offset=0, object_locations=None,
                     object_margin=None, num_objects=None, gau=None,
                     max_rejections=1000, margin_offset=2):
    """Sample dataset parameters perturbed by r."""
    name = d['name']
    family = d['family']
    attr_name = '{}_{}'.format(name, 'center')
    if self.wn:
        lambda_r = self.normalize_weights(name=name, prop='center')
    elif family != 'half_normal':
        lambda_r = getattr(self, attr_name)
    parameters = []
    if family == 'gaussian':
        attr_name = '{}_{}'.format(name, 'scale')
        if self.wn:
            lambda_r_scale = self.normalize_weights(name=name, prop='scale')
        else:
            lambda_r_scale = getattr(self, attr_name)
        # lambda_r = transform_to(constraints.greater_than(1.))(lambda_r)
        # lambda_r_scale = transform_to(constraints.greater_than(self.minimum_spatial_scale))(lambda_r_scale)
        # TODO: Add constraint function here
        # w = module.weight.data
        # w = w.clamp(0.5, 0.7)
        # module.weight.data = w
        if gau is None:
            gau = MultivariateNormal(loc=lambda_r, covariance_matrix=lambda_r_scale)
        if d['return_sampler']:
            return gau
        if name == 'object_location':
            if not len(object_locations):
                return gau.rsample(), gau
            else:
                parameters = self.rejection_sampling(
                    object_margin=object_margin,
                    margin_offset=margin_offset,
                    object_locations=object_locations,
                    max_rejections=max_rejections,
                    num_objects=num_objects,
                    gau=gau)
        else:
            raise NotImplementedError(name)
    elif family == 'normal':
        attr_name = '{}_{}'.format(name, 'scale')
        if self.wn:
            lambda_r_scale = self.normalize_weights(name=name, prop='scale')
        else:
            lambda_r_scale = getattr(self, attr_name)
        nor = Normal(loc=lambda_r, scale=lambda_r_scale)
        if d['return_sampler']:
            return nor
        elif name == 'object_location':
            # nor.arg_constraints['scale'] = constraints.greater_than(self.minimum_spatial_scale)  # noqa
            if not len(object_locations):
                return nor.rsample(), nor
            else:
                parameters = self.rejection_sampling(
                    object_margin=object_margin,
                    margin_offset=margin_offset,
                    object_locations=object_locations,
                    max_rejections=max_rejections,
                    num_objects=num_objects,
                    gau=nor)
        else:
            for idx in range(batch_size):
                parameters.append(nor.rsample())
    elif family == 'cnormal':
        attr_name = '{}_{}'.format(name, 'scale')
        if self.wn:
            lambda_r_scale = self.normalize_weights(name=name, prop='scale')
        else:
            lambda_r_scale = getattr(self, attr_name)
        # Explicitly clamp the scale!
        lambda_r_scale = torch.clamp(lambda_r_scale, self.minimum_spatial_scale, 999.)
        nor = CNormal(loc=lambda_r, scale=lambda_r_scale)
        if d['return_sampler']:
            return nor
        elif name == 'object_location':
            # nor.arg_constraints['scale'] = constraints.greater_than(self.minimum_spatial_scale)  # noqa
            if not len(object_locations):
                return nor.rsample(), nor
            else:
                parameters = self.rejection_sampling(
                    object_margin=object_margin,
                    margin_offset=margin_offset,
                    object_locations=object_locations,
                    max_rejections=max_rejections,
                    num_objects=num_objects,
                    gau=nor)
        else:
            for idx in range(batch_size):
                parameters.append(nor.rsample())
    elif family == 'abs_normal':
        attr_name = '{}_{}'.format(name, 'scale')
        if self.wn:
            lambda_r_scale = self.normalize_weights(name=name, prop='scale')
        else:
            lambda_r_scale = getattr(self, attr_name)
        # lambda_r = transform_to(Normal.arg_constraints['loc'])(lambda_r)
        # lambda_r_scale = transform_to(Normal.arg_constraints['scale'])(lambda_r_scale)  # noqa
        # lambda_r = transforms.AbsTransform()(lambda_r)
        # lambda_r_scale = transforms.AbsTransform()(lambda_r_scale)  # These kill grads!!
        # lambda_r = torch.abs(lambda_r)  # These kill grads!!
        lambda_r_scale = torch.abs(lambda_r_scale)
        nor = Normal(loc=lambda_r, scale=lambda_r_scale)
        if d['return_sampler']:
            return nor
        else:
            parameters = nor.rsample([batch_size])
    elif family == 'half_normal':
        attr_name = '{}_{}'.format(name, 'scale')
        if self.wn:
            lambda_r_scale = self.normalize_weights(name=name, prop='scale')
        else:
            lambda_r_scale = getattr(self, attr_name)
        nor = HalfNormal(scale=lambda_r_scale)
        if d['return_sampler']:
            return nor
        else:
            parameters = nor.rsample([batch_size])
    elif family == 'categorical':
        if d['return_sampler']:
            gum = RelaxedOneHotCategorical(1e-1, logits=lambda_r)
            return gum
            # return lambda sample_size: self.argmax(self.gumbel_fun(lambda_r, name=name)) + offset  # noqa
        for _ in range(batch_size):
            parameters.append(self.argmax(self.gumbel_fun(lambda_r, name=name)) + offset)  # noqa Use default temperature -> max
    elif family == 'relaxed_bernoulli':
        bern = RelaxedBernoulli(temperature=1e-1, logits=lambda_r)
        if d['return_sampler']:
            return bern
        else:
            parameters = bern.rsample([batch_size])
    else:
        raise NotImplementedError('{} not implemented in sampling.'.format(family))
    return parameters

def forward(self, t, word_counts, tau=1.2):
    batch_size = t.shape[0]

    if self.training:
        message = [torch.zeros((batch_size, self.vocab_size), dtype=torch.float32)]
        if self.use_gpu:
            message[0] = message[0].cuda()
        message[0][:, self.bound_token_idx] = 1.0
    else:
        message = [torch.full((batch_size,), fill_value=self.bound_token_idx, dtype=torch.int64)]
        if self.use_gpu:
            message[0] = message[0].cuda()

    # h0, c0
    h = self.aff_transform(t)  # batch_size, hidden_size
    c = torch.zeros([batch_size, self.hidden_size])

    initial_length = self.max_sentence_length + 1
    seq_lengths = torch.ones([batch_size], dtype=torch.int64) * initial_length

    ce_loss = nn.CrossEntropyLoss(reduction='none')

    # Handle alpha by giving weight to the padding token
    w_counts = word_counts.clone()  # Tensor is passed by ref
    w_counts[self.bound_token_idx] *= self.bound_weight

    denominator = w_counts.sum()
    if denominator > 0:
        normalized_word_counts = w_counts / denominator
    else:
        normalized_word_counts = w_counts

    vl_loss = 0.0
    entropy = 0.0

    if self.use_gpu:
        c = c.cuda()
        seq_lengths = seq_lengths.cuda()

    input_embed_rep = []

    for i in range(self.max_sentence_length):  # or sampled <EOS>, but this is batched
        emb = torch.matmul(message[-1], self.embedding) if self.training else self.embedding[message[-1]]
        h, c = self.lstm_cell(emb, (h, c))

        vocab_scores = self.linear_probs(h)
        p = F.softmax(vocab_scores, dim=1)
        entropy += Categorical(p).entropy()

        if self.training:
            rohc = RelaxedOneHotCategorical(tau, p)
            token = rohc.rsample()

            # Straight-through part
            token_hard = torch.zeros_like(token)
            token_hard.scatter_(-1, torch.argmax(token, dim=-1, keepdim=True), 1.0)
            token = (token_hard - token).detach() + token
        else:
            if self.greedy:
                _, token = torch.max(p, -1)
            else:
                token = Categorical(p).sample()

        message.append(token)
        input_embed_rep.append(emb)

        self._calculate_seq_len(seq_lengths, token, initial_length, seq_pos=i + 1)

        if self.vl_loss_weight > 0.0:
            vl_loss += ce_loss(vocab_scores - normalized_word_counts, self._discretize_token(token))

    return (torch.stack(message, dim=1),
            seq_lengths,
            vl_loss,
            torch.mean(entropy) / self.max_sentence_length,
            torch.stack(input_embed_rep, dim=1))

cols = 1
fig = plt.figure(figsize=(10 + cols, 4 + rows), facecolor='white')  # , dpi=150)
col = 0
row = 0
ax = plt.subplot2grid((rows, cols), (row, col), frameon=False, colspan=1, rowspan=1)

n_cats = 2
# needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True)
# needsoftmax_mixtureweight = torch.tensor([], requires_grad=True)
# weights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
theta = .99
weights = torch.tensor([1 - theta, theta], requires_grad=True).float()
cat = RelaxedOneHotCategorical(probs=weights, temperature=torch.tensor([1.]))

val = 1.
val2 = 0
val3 = 0
cmap = 'Blues'
alpha = 1.
xlimits = [val3, val]
ylimits = [val2, val]
numticks = 51

x = np.linspace(*xlimits, num=numticks)
y = np.linspace(*ylimits, num=numticks)
X, Y = np.meshgrid(x, y)
aaa = np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
aaa = torch.tensor(aaa).float()
logprob = cat.log_prob(aaa)

# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True)
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
# cat = Categorical(probs=weights)

avg_samp = torch.zeros(n_cats)
grads = []
logprobgrads = []
momem = 0
for step in range(max_steps):
    weights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
    cat = RelaxedOneHotCategorical(probs=weights, temperature=torch.tensor([1.]))
    cluster_S = cat.sample()
    logprob = cat.log_prob(cluster_S.detach())
    cluster_H = H(cluster_S)
    # logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(needsoftmax_mixtureweight), retain_graph=False)[0]
    one_hot = torch.zeros(n_cats)
    one_hot[cluster_H] = 1.
    f_val = f(one_hot)
    # grad = f_val * logprobgrad
    # needsoftmax_mixtureweight = needsoftmax_mixtureweight + lr*grad
    grad = f_val * logprobgrad

def forward(self, image, question=None, feature_saving=False, cut_image_info=False,
            cut_question_info=False, ground_truth=None, save_messages=False, save_features=False):
    self.use_gpu = torch.cuda.is_available()
    if self.use_gpu:
        device = 'cuda'
    else:
        device = 'cpu'

    assert not (feature_saving and self.training)
    self.batch_size = image.shape[0]

    im_features = []
    if not self.train_from_symbolic:
        for i in range(image.shape[1]):
            im_features.append(self.stem_conv(image[:, i, :, :, :]).view(self.batch_size, -1, 1))
        image_info = im_features[0]
        candidate_im_features = im_features[1:]
    else:
        candidate_im_features = []
        ground_truth = ground_truth.type(torch.FloatTensor).to(device)
        image_info = ground_truth[:, 0, :]
        for i in range(image.shape[1] - 1):
            candidate_im_features.append(self.bottleneck_in_fc(ground_truth[:, i + 1, :]).view(self.batch_size, -1, 1))

    # detaching gradients
    candidate_im_features = [candidate.detach() for candidate in candidate_im_features]

    if self.bottleneck:
        if self.training:
            message = [torch.zeros((self.batch_size, self.vocab_size), dtype=torch.float32)]
            if self.use_gpu:
                message[0] = message[0].cuda()
            message[0][:, self.bound_token_idx] = 1.0
        else:
            message = [torch.full((self.batch_size,), fill_value=self.bound_token_idx, dtype=torch.int64)]
            if self.use_gpu:
                message[0] = message[0].cuda()

        # h0, c0
        flattened_im_features = image_info.view(self.batch_size, -1)
        sender_representation = self.drop(self.bottleneck_in_fc(flattened_im_features))
        h = sender_representation
        c = torch.zeros((self.batch_size, self.message_lstm_hidden_size))
        if self.use_gpu:
            c = c.cuda()

        entropy = 0.0

        # produce words one by one
        for i in range(self.max_sentence_length):
            emb = torch.matmul(message[-1], self.message_embedding) if self.training else self.message_embedding[message[-1]]
            h, c = self.lstm_cell(emb, (h, c))

            vocab_scores = self.drop(self.hidden2vocab(h))
            p = F.softmax(vocab_scores, dim=1)
            entropy += Categorical(p).entropy()

            if self.training:
                rohc = RelaxedOneHotCategorical(self.tau, p)
                token = rohc.rsample()

                # Straight-through part
                if not self.continuous_communication:
                    token_hard = torch.zeros_like(token)
                    token_hard.scatter_(-1, torch.argmax(token, dim=-1, keepdim=True), 1.0)
                    token = (token_hard - token).detach() + token
            else:
                if self.greedy:
                    _, token = torch.max(p, -1)
                else:
                    token = Categorical(p).sample()

            message.append(token)

        message = torch.stack(message, dim=1)

        if self.training:
            _, m = torch.max(message, dim=-1)
        else:
            m = message
        md = calc_message_distinctness(m)

        # If we feed the ground_truth to the receiver, we simply hijack the message here
        if self.use_ground_truth:
            if self.training:
                message = batch_2_onehot(ground_truth, self.max_sentence_length, self.vocab_size)
            else:
                message = ground_truth.type(torch.LongTensor)
            if self.use_gpu:
                message = message.cuda()

        # Receiver part
        h = torch.zeros((self.batch_size, self.encoder_lstm_hidden_size))
        c = torch.zeros((self.batch_size, self.encoder_lstm_hidden_size))
        if self.use_gpu:
            h = h.cuda()
            c = c.cuda()

        emb = torch.matmul(message, self.message_embedding) if self.training else self.message_embedding[message]
        _, (h, c) = self.message_encoder_lstm(emb, (h[None, ...], c[None, ...]))
        hidden_receiver = h[0]
        bottleneck_out = self.drop(self.aff_transform(hidden_receiver))
        image_info = bottleneck_out

        # todo recomment
        # comm_info = {'entropy': torch.mean(entropy).item() / (self.max_sentence_length + 1e-7), 'md': md}
        comm_info = {'entropy': 0, 'md': 0}
        if save_features:
            comm_info['image_features'] = flattened_im_features
            comm_info['sender_repr'] = sender_representation
            comm_info['receiver_repr'] = hidden_receiver
        if save_messages:
            comm_info['message'] = m

    # in case no communication bottleneck
    else:
        comm_info = None

    orig_im_features = image_info.view(self.batch_size, 1, -1)

    out = torch.zeros(self.batch_size, len(candidate_im_features)).type(torch.FloatTensor)
    if self.use_gpu:
        out = out.cuda()
    for i in range(len(candidate_im_features)):
        out[:, i] = torch.bmm(orig_im_features, candidate_im_features[i]).squeeze()

    return out, comm_info

    surr_loss = get_loss()
    optim_surr.zero_grad()
    surr_loss.backward()
    optim_surr.step()
    if ii % 1000 == 0:
        print(ii, surr_loss)

for step in range(n_steps):
    x = sample_true(batch_size).cuda()  # .view(1,1)
    logits = encoder.net(x)
    probs = torch.softmax(logits / 100., dim=1)
    # print (probs)
    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())

    net_loss = 0
    loss = 0
    surr_loss = 0
    for jj in range(k):
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        # cluster_onehot = torch.zeros(n_components)
        # cluster_onehot[cluster_H] = 1.
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size, 1)
        check_nan(logprob_cluster)
        logpxz = logprob_undercomponent(
            x, component=cluster_H,
            needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz  # - logprob_cluster

def HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
    cat_bernoulli = Categorical(probs=probs)

    net_loss = 0
    surr_loss = 0
    for jj in range(k):
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logq_z = cat.log_prob(cluster_S.detach()).view(B, 1)
        logq_b = cat_bernoulli.log_prob(cluster_H.detach()).view(B, 1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  # [B,1]
        f_z = logpxz - logq_z - 1.
        f_b = logpxz - logq_b - 1.

        surr_input = torch.cat([cluster_S, x], dim=1)  # [B,21]
        # surr_pred, alpha = surrogate.net(surr_input)
        surr_pred = surrogate.net(surr_input)
        alpha = torch.sigmoid(surrogate2.net(x))

        net_loss += - torch.mean(alpha.detach() * (f_z.detach() - surr_pred.detach()) * logq_z
                                 + alpha.detach() * surr_pred
                                 + (1 - alpha.detach()) * (f_b.detach()) * logq_b)

        # surr_loss += torch.mean(torch.abs(f_z.detach() - surr_pred))
        grad_logq_z = torch.mean(torch.autograd.grad([torch.mean(logq_z)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq_b = torch.mean(torch.autograd.grad([torch.mean(logq_b)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_surr = torch.mean(torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # print (alpha.shape, f_z.shape, surr_pred.shape, grad_logq_z.shape, grad_surr.shape)
        # grad_surr = torch.autograd.grad([surr_pred[0]], [logits], create_graph=True, retain_graph=True)[0]
        # print (grad_surr)
        surr_loss += torch.mean((alpha * (f_z.detach() - surr_pred) * grad_logq_z
                                 + alpha * grad_surr
                                 + (1 - alpha) * (f_b.detach()) * grad_logq_b)**2)

        surr_dif = torch.mean(torch.abs(f_z.detach() - surr_pred))
        # gradd = torch.autograd.grad([surr_loss], [alpha], create_graph=True, retain_graph=True)[0]
        # print (gradd)

        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f_z.detach() - surr_pred.detach()) * logq_z)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))

    net_loss = net_loss / k
    surr_loss = surr_loss / k

    return net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif, grad_path, grad_score, torch.mean(alpha)

def forward(self, instruction=None, observation=None, memory=None, compute_message_probs=False, time=None):
    if not hasattr(self, 'random_corrector'):
        self.random_corrector = False
    if not hasattr(self, 'var_len'):
        self.var_len = False
    if not hasattr(self, 'script'):
        self.script = False

    if not self.script:
        memory_rnn_output, memory = self.forward_film(
            instruction=instruction, observation=observation, memory=memory)

        batch_size = instruction.size(0)
        correction_encodings = []
        entropy = 0.0
        lengths = np.array([self.corr_length] * batch_size)
        total_corr_loss = 0

        for i in range(self.corr_length):
            if i == 0:
                # every message starts with a SOS token
                decoder_input = torch.tensor([self.sos_id] * batch_size, dtype=torch.long, device=self.device)
                decoder_input_embedded = self.word_embedding_corrector(decoder_input).unsqueeze(1)
                decoder_hidden = memory_rnn_output.unsqueeze(0)

            if self.random_corrector:
                # randomize corrections
                device = torch.device("cuda" if decoder_input_embedded.is_cuda else "cpu")
                decoder_input_embedded = torch.randn(decoder_input_embedded.size(), device=device)
                decoder_hidden = torch.randn(decoder_hidden.size(), device=device)

            rnn_output, decoder_hidden = self.decoder_rnn(decoder_input_embedded, decoder_hidden)
            vocab_scores = self.out(rnn_output)
            vocab_probs = F.softmax(vocab_scores, dim=-1)
            entropy += Categorical(vocab_probs).entropy()

            tau = 1.0 / (self.tau_layer(decoder_hidden).squeeze(0) + self.max_tau)
            tau = tau.expand(-1, self.num_embeddings).unsqueeze(1)

            if self.training:
                # Apply Gumbel SM
                cat_distr = RelaxedOneHotCategorical(tau, vocab_probs)
                corr_weights = cat_distr.rsample()
                corr_weights_hard = torch.zeros_like(corr_weights, device=self.device)
                corr_weights_hard.scatter_(-1, torch.argmax(corr_weights, dim=-1, keepdim=True), 1.0)
                # detach() detaches the output from the computation graph,
                # so no gradient will be backprop'ed along this variable
                corr_weights = (corr_weights_hard - corr_weights).detach() + corr_weights
            else:
                # greedy sample
                corr_weights = torch.zeros_like(vocab_probs, device=self.device)
                corr_weights.scatter_(-1, torch.argmax(vocab_probs, dim=-1, keepdim=True), 1.0)

            if self.var_len:
                # consider sequence done when eos receives highest value
                max_idx = torch.argmax(corr_weights, dim=-1)
                eos_batches = max_idx.data.eq(self.eos_id)
                if eos_batches.dim() > 0:
                    eos_batches = eos_batches.cpu().view(-1).numpy()
                    update_idx = ((lengths > i) & eos_batches) != 0
                    lengths[update_idx] = i

                # compute correction error through pseudo-target:
                # sequence of eos symbols to encourage short messages
                pseudo_target = torch.tensor([self.eos_id for j in range(batch_size)], dtype=torch.long, device=self.device)
                loss = self.correction_loss(corr_weights.squeeze(1), pseudo_target)
                total_corr_loss += loss

            correction_encodings += [corr_weights]
            decoder_input_embedded = torch.matmul(corr_weights, self.word_embedding_corrector.weight)

        # one-hot vectors on forward, soft approximations on backward pass
        correction_encodings = torch.stack(correction_encodings, dim=1).squeeze(2)
        lengths = torch.tensor(lengths, dtype=torch.long, device=self.device)

        result = {
            'correction_encodings': correction_encodings,
            'correction_messages': self.decode_corrections(correction_encodings),
            'correction_entropy': torch.mean(entropy),
            'corrector_memory': memory,
            'correction_lengths': lengths,
            'correction_loss': total_corr_loss
        }
    else:
        # there is a script of pre-established guidance messages
        correction_messages = self.script[time]
        correction_encodings = self.encode_corrections(correction_messages)
        result = {
            'correction_encodings': correction_encodings,
            'correction_messages': correction_messages
        }

    return (result)

def forward(self, obs, memory, instr_embedding=None, tau=1.2):
    # Calculating instruction embedding
    if self.use_instr and instr_embedding is None:
        if self.lang_model == 'gru':
            _, hidden = self.instr_rnn(self.word_embedding(obs.instr))
            instr_embedding = hidden[-1]

    # Calculating the image embedding
    x = torch.transpose(torch.transpose(obs.image, 1, 3), 2, 3)
    if self.arch.startswith("expert_filmcnn"):
        image_embedding = self.image_conv(x)
        # Calculating FiLM_embedding from image and instruction embedding
        for controler in self.controllers:
            x = controler(image_embedding, instr_embedding)
        FiLM_embedding = F.relu(self.film_pool(x))
    else:
        FiLM_embedding = self.image_conv(x)
    FiLM_embedding = FiLM_embedding.reshape(FiLM_embedding.shape[0], -1)

    # Going through the memory layer
    if self.use_memory:
        hidden = (memory[:, :self.semi_memory_size], memory[:, self.semi_memory_size:])
        hidden = self.memory_rnn(FiLM_embedding, hidden)
        embedding = hidden[0]
        memory = torch.cat(hidden, dim=1)
    else:
        embedding = x

    if self.use_instr and not "filmcnn" in self.arch:
        embedding = torch.cat((embedding, instr_embedding), dim=1)

    if hasattr(self, 'aux_info') and self.aux_info:
        extra_predictions = {info: self.extra_heads[info](embedding) for info in self.extra_heads}
    else:
        extra_predictions = dict()

    memory_rnn_output = embedding
    batch_size = memory_rnn_output.shape[0]

    message = []
    for i in range(self.message_length):
        if i == 0:
            decoder_input = torch.tensor([self.sos_id] * batch_size, dtype=torch.long, device=self.device)
            decoder_input_embedded = self.word_embedding_decoder(decoder_input).unsqueeze(1)
            decoder_hidden = memory_rnn_output.unsqueeze(0)

        decoder_out, decoder_hidden = self.decoder_rnn(decoder_input_embedded, decoder_hidden)
        vocab_scores = self.hidden2word(decoder_out)
        vocab_probs = F.softmax(vocab_scores, -1)

        tau = 1.0 / (self.tau_layer(decoder_hidden).squeeze(0) + self.max_tau)
        tau = tau.expand(-1, self.vocab_size).unsqueeze(1)

        if self.training:
            rohc = RelaxedOneHotCategorical(tau, vocab_probs)
            token = rohc.rsample()

            # Straight-through part
            token_hard = torch.zeros_like(token)
            token_hard.scatter_(-1, torch.argmax(token, dim=-1, keepdim=True), 1.0)
            token = (token_hard - token).detach() + token
        else:
            token = torch.zeros_like(vocab_probs, device=self.device)
            token.scatter_(-1, torch.argmax(vocab_probs, dim=-1, keepdim=True), 1.0)

        message.append(token)
        decoder_input_embedded = torch.matmul(token, self.word_embedding_decoder.weight)

    comm = torch.stack(message, dim=1).squeeze(2)
    return comm, memory

print('Grad mean', np.mean(grads, axis=0))
print('Grad std', np.std(grads, axis=0))
print('Avg logprobgrad', np.mean(logprobgrads, axis=0))
print('Std logprobgrad', np.std(logprobgrads, axis=0))
print()

# REINFORCE P(Z)
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True).float()
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
# cat = Categorical(probs=weights)
cat = RelaxedOneHotCategorical(probs=weights, temperature=torch.tensor([1.]))  # .cuda())
# dist = Bernoulli(bern_param)

# samps = []
avg_samp = torch.zeros(n_cats)
gradspz = []
logprobgrads = []
for i in range(n):
    # samp = dist.sample()
    cluster_S = cat.sample()
    logprob = cat.log_prob(cluster_S.detach())
    cluster_H = H(cluster_S)
    #
    one_hot = torch.zeros(n_cats)
    one_hot[cluster_H] = 1.
    # logprob = dist.log_prob(samp.detach())

    optim_surr.step()

# x = sample_true(batch_size).cuda()  # .view(1,1)
logits = encoder.net(x)
probs = torch.softmax(logits, dim=1)

probs2 = probs.cpu().data.numpy()[0]
print()
for iii in range(len(probs2)):
    print(str(iii) + ':' + str(probs2[iii]), end=" ")
print()
# print (probs.shape)
print(probs)

cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())

net_loss = 0
loss = 0
surr_loss = 0
grads = 0
for jj in range(k):
    cluster_S = cat.rsample()
    # print (cluster_S.shape)
    # cluster_onehot = torch.zeros(n_components).cuda()
    # cluster_onehot[jj%20] = 1.
    # cluster_S = torch.softmax(cluster_onehot, dim=0).view(1,20)

for step in range(n_steps):
    optim.zero_grad()

    loss = 0
    net_loss = 0
    surr_loss = 0
    for i in range(batch_size):
        x = sample_true().cuda().view(1, 1)
        logits = encoder.net(x)
        # print (logits.shape)
        # print (torch.softmax(logits, dim=1))
        # cat = Categorical(probs=torch.softmax(logits, dim=0))
        cat = RelaxedOneHotCategorical(probs=torch.softmax(logits, dim=1), temperature=torch.tensor([1.]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        # cluster_onehot = torch.zeros(n_components)
        # cluster_onehot[cluster_H] = 1.
        # print (cluster_onehot)
        # print (cluster_H)
        # print (cluster_S)
        logprob_cluster = cat.log_prob(cluster_S.detach())
        if logprob_cluster != logprob_cluster:
            print('nan')
        # print (logprob_cluster)
        pxz = logprob_undercomponent(
            x,

def H(soft):
    return torch.argmax(soft, dim=1)


surrogate = NN3(input_size=C, output_size=1, n_residual_blocks=2)

train_ = 1
n_steps = 1000  # 0 #0 #1000 #50000 #
B = 1  # 32 #0
k = 3
if train_:
    optim = torch.optim.Adam(surrogate.parameters(), lr=1e-4, weight_decay=1e-7)

    # Train surrogate
    for i in range(n_steps + 1):
        warmup = 1.

        cat = RelaxedOneHotCategorical(logits=logits.repeat(B, 1), temperature=torch.tensor([1.]))
        z = cat.rsample()
        logprob = cat.log_prob(z.detach()).view(B, 1)
        b = H(z)
        reward = f(b).view(B, 1)
        cz = surrogate.net(z)

        # estimator = (reward - cz) * logprob + cz
        # grad = torch.autograd.grad([torch.mean(estimator)], [logits], create_graph=True, retain_graph=True)[0]
        gradlogprob = torch.autograd.grad([torch.mean(logprob)], [logits], create_graph=True, retain_graph=True)[0]
        gradcz = torch.autograd.grad([torch.mean(cz)], [logits], create_graph=True, retain_graph=True)[0]
        # print (reward.shape, cz.shape, gradlogprob.shape, gradcz.shape)
        # grad = (reward - cz) * gradlogprob + gradcz

lr = .002

# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True)
# needsoftmax_mixtureweight = torch.randn(n_cats, requires_grad=True)
weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
# cat = Categorical(probs=weights)

avg_samp = torch.zeros(n_cats)
grads = []
logprobgrads = []
momem = 0
for step in range(max_steps):
    weights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
    cat = RelaxedOneHotCategorical(probs=weights, temperature=torch.tensor([1.]))
    cluster_S = cat.sample()
    logprob = cat.log_prob(cluster_S.detach())
    cluster_H = H(cluster_S)
    # logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(needsoftmax_mixtureweight), retain_graph=False)[0]
    one_hot = torch.zeros(n_cats)
    one_hot[cluster_H] = 1.
    f_val = f(one_hot)
    # grad = f_val * logprobgrad
    # needsoftmax_mixtureweight = needsoftmax_mixtureweight + lr*grad
