def sample_simplax(probs):
    """Draw one relaxed one-hot sample from `probs` together with its argmax.

    Args:
        probs: Tensor of class probabilities; dim 1 is treated as the class
            axis when the hard index is taken.

    Returns:
        Tuple ``(z, b, logprob)`` where ``z`` is the reparameterised relaxed
        sample, ``b`` is ``argmax(z, dim=1)`` and ``logprob`` is the
        log-density of ``z`` under the relaxed distribution.
    """
    # Create the unit temperature next to `probs`: a CPU temperature tensor
    # combined with CUDA probabilities fails with a device mismatch.
    temperature = probs.new_ones(1)
    dist = RelaxedOneHotCategorical(probs=probs, temperature=temperature)
    z = dist.rsample()
    logprob = dist.log_prob(z)
    b = torch.argmax(z, dim=1)
    return z, b, logprob
def get_loss():
    """Return the surrogate's mean absolute error on one freshly sampled batch.

    Reads the module-level ``batch_size``, ``temp``, ``encoder``,
    ``surrogate`` and ``needsoftmax_mixtureweight``; all tensors are moved
    to CUDA.
    """
    batch = sample_true(batch_size).cuda()

    # Encoder logits are divided by 100 before the softmax.
    class_probs = torch.softmax(encoder.net(batch) / 100., dim=1)
    relaxed = RelaxedOneHotCategorical(
        probs=class_probs, temperature=torch.tensor([temp]).cuda())

    soft_sample = relaxed.rsample()
    hard_sample = H(soft_sample)

    # The log-density is only evaluated so it can be screened for NaNs.
    check_nan(relaxed.log_prob(soft_sample.detach()).view(batch_size, 1))

    target = logprob_undercomponent(
        batch,
        component=hard_sample,
        needsoftmax_mixtureweight=needsoftmax_mixtureweight,
        cuda=True)

    # Surrogate sees the relaxed sample concatenated with the data point.
    prediction = surrogate.net(torch.cat([soft_sample, batch], dim=1))
    return torch.mean(torch.abs(target.detach() - prediction))
def show_surr_preds():
    """Plot surrogate predictions against ``f`` for three sampled clusters.

    For each of three rows a cluster is sampled from the encoder, the
    surrogate and ``logprob_undercomponent`` are evaluated on a 40-point grid
    over [-9, 205], and the figure is written to ``exp_dir + 'gmm_surr.png'``.
    """
    batch_size = 1
    rows, cols = 3, 1
    n_evals = 40
    grid_np = np.linspace(-9, 205, n_evals)

    plt.figure(figsize=(10 + cols, 4 + rows), facecolor='white')

    for row in range(rows):
        seed_x = sample_true(1).cuda()
        # Encoder logits are divided by 100 before the softmax.
        class_probs = torch.softmax(encoder.net(seed_x) / 100., dim=1)
        relaxed = RelaxedOneHotCategorical(
            probs=class_probs, temperature=torch.tensor([1.]).cuda())
        soft = relaxed.rsample()
        hard = H(soft)
        check_nan(relaxed.log_prob(soft.detach()).view(batch_size, 1))

        # Hold the sampled cluster fixed and sweep the data value.
        grid = torch.from_numpy(grid_np).view(n_evals, 1).float().cuda()
        soft_tiled = soft.repeat(n_evals, 1)
        hard_tiled = hard.repeat(n_evals, 1)

        target = logprob_undercomponent(
            grid,
            component=hard_tiled,
            needsoftmax_mixtureweight=needsoftmax_mixtureweight,
            cuda=True)
        prediction = surrogate.net(torch.cat([soft_tiled, grid], dim=1))

        ax = plt.subplot2grid((rows, cols), (row, 0), frameon=False,
                              colspan=1, rowspan=1)
        ax.plot(grid_np, prediction.data.cpu().numpy(), label='Surr')
        ax.plot(grid_np, target.data.cpu().numpy(), label='f')
        ax.set_title(str(hard_tiled[0]))
        ax.legend()

    plt_path = exp_dir + 'gmm_surr.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()
def simplax(surrogate, x, logits, mixtureweights, k=1):
    """Compute the encoder and surrogate losses from k relaxed samples.

    Args:
        surrogate: Object whose ``net`` maps ``cat([relaxed sample, x, logits])``
            to a prediction.
        x: Data batch, concatenated along dim 1 with the sample and logits.
        logits: Encoder logits; gradients are taken with respect to these.
        mixtureweights: Tensor indexed by the hard cluster to obtain p(z).
        k: Number of samples to average over.

    Returns:
        Dict with keys ``net_loss``, ``f``, ``logpx_given_z``, ``logpz``,
        ``logq``, ``surr_loss``, ``surr_dif``, ``grad_path`` and
        ``grad_score``. ``net_loss`` and ``surr_loss`` are averages over the
        k samples; the remaining entries come from the last sample.
    """
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    cat = RelaxedOneHotCategorical(probs=probs,
                                   temperature=torch.tensor([1.]).cuda())

    net_loss = 0
    surr_loss = 0
    for jj in range(k):
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)  # presumably the hard cluster index -- defined elsewhere

        logq = cat.log_prob(cluster_S.detach()).view(B, 1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  # [B,1]
        f = logpxz - logq - 1.

        surr_input = torch.cat([cluster_S, x, logits], dim=1)
        surr_pred = surrogate.net(surr_input)

        # Score-function term with the surrogate as control variate, plus
        # the surrogate's own pathwise term.
        net_loss += -torch.mean(
            (f.detach() - surr_pred.detach()) * logq + surr_pred)

        grad_logq = torch.autograd.grad([torch.mean(logq)], [logits],
                                        create_graph=True,
                                        retain_graph=True)[0]
        grad_surr = torch.autograd.grad([torch.mean(surr_pred)], [logits],
                                        create_graph=True,
                                        retain_graph=True)[0]

        # Accumulate across samples: the total is divided by k below, so a
        # plain assignment here would keep only the last sample (scaled 1/k).
        surr_loss += torch.mean(
            ((f.detach() - surr_pred) * grad_logq + grad_surr)**2)

        # Diagnostics; these reflect the most recent sample only.
        surr_dif = torch.mean(torch.abs(f.detach() - surr_pred))
        grad_score = torch.autograd.grad(
            [torch.mean((f.detach() - surr_pred.detach()) * logq)], [logits],
            create_graph=True, retain_graph=True)[0]
        # grad_surr is already d mean(surr_pred) / d logits, so it is reused
        # rather than running the same backward pass a second time.
        grad_path = torch.mean(torch.abs(grad_surr))
        grad_score = torch.mean(torch.abs(grad_score))

    net_loss = net_loss / k
    surr_loss = surr_loss / k

    outputs = {}
    outputs['net_loss'] = net_loss
    outputs['f'] = f
    outputs['logpx_given_z'] = logpx_given_z
    outputs['logpz'] = logpz
    outputs['logq'] = logq
    outputs['surr_loss'] = surr_loss
    outputs['surr_dif'] = surr_dif
    outputs['grad_path'] = grad_path
    outputs['grad_score'] = grad_score
    return outputs
def forward(self, x):
    """Map `x` to logits and draw a relaxed one-hot sample from them.

    Returns:
        Tuple ``(sample, probs, logits)``: the reparameterised relaxed
        sample, ``softmax(logits, dim=-1)`` and the raw logits.
    """
    hidden = self.featCompressor(x)
    for layer in (self.fc1, self.fc2):
        hidden = layer(hidden)
    logits = self.fc3(hidden)

    # The unpacking doubles as a rank check: logits must be 3-dimensional.
    _batch, _length, _classes = logits.shape

    sampler = RelaxedOneHotCategorical(float(self.temper), logits=logits)
    sample = sampler.rsample()
    probs = F.softmax(logits, dim=-1)
    return sample, probs, logits
def simplax(surrogate, x, logits, mixtureweights, k=1):
    """Compute encoder and surrogate losses averaged over k relaxed samples.

    Returns:
        Tuple ``(net_loss, f, logpx_given_z, logpz, logq, surr_loss,
        surr_dif, grad_path, grad_score)``. The two losses are averaged over
        the k samples; every other entry comes from the last sample.
    """

    def _grad_wrt_logits(scalar):
        # Keep the graph so the gradient itself stays differentiable.
        return torch.autograd.grad([scalar], [logits], create_graph=True,
                                   retain_graph=True)[0]

    n_rows = logits.shape[0]
    relaxed = RelaxedOneHotCategorical(
        probs=torch.softmax(logits, dim=1),
        temperature=torch.tensor([1.]).cuda())

    net_loss = 0
    surr_loss = 0
    for _ in range(k):
        soft = relaxed.rsample()
        hard = H(soft)

        logq = relaxed.log_prob(soft.detach()).view(n_rows, 1)
        logpx_given_z = logprob_undercomponent(x, component=hard)
        logpz = torch.log(mixtureweights[hard]).view(n_rows, 1)
        f = logpx_given_z + logpz - logq - 1.

        surr_pred = surrogate.net(torch.cat([soft, x], dim=1))
        residual = f.detach() - surr_pred.detach()
        net_loss += -torch.mean(residual * logq + surr_pred)

        grad_logq = _grad_wrt_logits(torch.mean(logq))
        grad_surr = _grad_wrt_logits(torch.mean(surr_pred))
        surr_loss += torch.mean(
            ((f.detach() - surr_pred) * grad_logq + grad_surr)**2)

        # Diagnostics (overwritten on every sample).
        surr_dif = torch.mean(torch.abs(f.detach() - surr_pred))
        grad_path = torch.mean(
            torch.abs(_grad_wrt_logits(torch.mean(surr_pred))))
        grad_score = torch.mean(
            torch.abs(_grad_wrt_logits(torch.mean(residual * logq))))

    net_loss = net_loss / k
    surr_loss = surr_loss / k
    return (net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif,
            grad_path, grad_score)
def gumbel_softmax_dist(self, param, name, temperature=1e-1, hard=True, sample_size=()):
    """Sample a relaxed one-hot vector from logits `param`.

    When `hard` is true the sample is passed through ``self.st_op`` before
    being returned; `name` is accepted but not used.
    """
    relaxed = RelaxedOneHotCategorical(temperature, logits=param).rsample(sample_size)
    if not hard:
        return relaxed
    return self.st_op(relaxed)
def show_surr_preds():
    """Plot surrogate predictions against ``f`` for three sampled clusters.

    Here ``f`` is ``logprob_undercomponent(...)`` minus the relaxed sample's
    log-density. Each row fixes one sampled cluster and sweeps the data value
    over a 40-point grid on [-9, 205]; the figure is saved to
    ``exp_dir + 'gmm_surr.png'``.
    """
    batch_size = 1
    rows, cols = 3, 1
    n_evals = 40
    grid_np = np.linspace(-9, 205, n_evals)

    plt.figure(figsize=(10 + cols, 4 + rows), facecolor='white')

    for row in range(rows):
        seed_x = sample_true(1).cuda()
        class_probs = torch.softmax(encoder.net(seed_x), dim=1)
        relaxed = RelaxedOneHotCategorical(
            probs=class_probs, temperature=torch.tensor([1.]).cuda())
        soft = relaxed.rsample()
        hard = H(soft)
        logq = relaxed.log_prob(soft.detach()).view(batch_size, 1)
        check_nan(logq)

        grid = torch.from_numpy(grid_np).view(n_evals, 1).float().cuda()
        soft_tiled = soft.repeat(n_evals, 1)
        hard_tiled = hard.repeat(n_evals, 1)

        logpxz = logprob_undercomponent(
            grid,
            component=hard_tiled,
            needsoftmax_mixtureweight=needsoftmax_mixtureweight,
            cuda=True)
        # logq has a single row and broadcasts over the grid rows.
        target = logpxz - logq
        prediction = surrogate.net(torch.cat([soft_tiled, grid], dim=1))

        ax = plt.subplot2grid((rows, cols), (row, 0), frameon=False,
                              colspan=1, rowspan=1)
        ax.plot(grid_np, prediction.data.cpu().numpy(), label='Surr')
        ax.plot(grid_np, target.data.cpu().numpy(), label='f')
        ax.set_title(str(hard_tiled[0]))
        ax.legend()

    plt_path = exp_dir + 'gmm_surr.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()
def _gumbel_softmax(probs, tau: float, hard: bool): """ Computes sampling from the Gumbel Softmax (GS) distribution Args: probs (torch.tensor): probabilities of shape [batch_size, n_classes] tau (float): temperature parameter for the GS hard (bool): discretize if True """ rohc = RelaxedOneHotCategorical(tau, probs) y = rohc.rsample() if hard: y_hard = torch.zeros_like(y) y_hard.scatter_(-1, torch.argmax(y, dim=-1, keepdim=True), 1.0) y = (y_hard - y).detach() + y return y
def reinforce_pz(x, logits, mixtureweights, k=1):
    """Score-function loss averaged over k relaxed samples.

    Returns:
        Tuple ``(net_loss, f, logpx_given_z, logpz, logq)``; everything but
        ``net_loss`` comes from the last sample drawn.
    """
    n_rows = logits.shape[0]
    relaxed = RelaxedOneHotCategorical(
        probs=torch.softmax(logits, dim=1),
        temperature=torch.tensor([1.]).cuda())

    per_sample = []
    for _ in range(k):
        soft = relaxed.rsample()
        hard = H(soft)

        logq = relaxed.log_prob(soft.detach()).view(n_rows, 1)
        logpx_given_z = logprob_undercomponent(x, component=hard)
        logpz = torch.log(mixtureweights[hard]).view(n_rows, 1)
        f = logpx_given_z + logpz - logq - 1.

        # f is detached so it only weights the log-probability gradient.
        per_sample.append(-torch.mean(f.detach() * logq))

    net_loss = sum(per_sample) / k
    return net_loss, f, logpx_given_z, logpz, logq
def gumbel_softmax(self, logits, training, tau=1.0, msg_hard=None):
    """Return a one-hot message from `logits`, straight-through when training.

    Args:
        logits: Unnormalised class scores; the last dim is the class axis.
        training: When true, sample from the relaxed distribution and apply
            the straight-through trick; otherwise take the argmax of `logits`.
        tau: Temperature of the relaxed distribution (training only).
        msg_hard: Optional precomputed hard message. When given, it replaces
            the argmax one-hot in either mode.

    Returns:
        Tensor with the same shape as `logits`.
    """
    if training:
        # Here, Gumbel sample is taken:
        msg = RelaxedOneHotCategorical(tau, logits=logits).rsample()
        if msg_hard is None:
            # zeros_like already places the result on the sample's device, so
            # no explicit device is passed (a bare "cuda" would drop the
            # device index on multi-GPU hosts).
            msg_hard = torch.zeros_like(msg)
            msg_hard.scatter_(-1, torch.argmax(msg, dim=-1, keepdim=True), 1.0)
        # detach() removes the difference from the graph, so the forward
        # value is msg_hard while gradients flow through the relaxed msg.
        return (msg_hard - msg).detach() + msg

    if msg_hard is not None:
        return msg_hard

    # Follow the logits' own device instead of self.device, which may be
    # absent or differ from where `logits` lives.
    msg = torch.zeros_like(logits)
    msg.scatter_(-1, torch.argmax(logits, dim=-1, keepdim=True), 1.0)
    return msg
def reinforce_pz(x, logits, mixtureweights, k=1):
    """Average a score-function loss over k samples of the relaxed posterior.

    Returns:
        Tuple ``(net_loss, f, logpx_given_z, logpz, logq)``. ``net_loss`` is
        the mean over the k samples; the other entries belong to the final
        sample.
    """
    batch = logits.shape[0]
    posterior = RelaxedOneHotCategorical(
        probs=torch.softmax(logits, dim=1),
        temperature=torch.tensor([1.]).cuda())

    net_loss = 0
    for _sample in range(k):
        relaxed_z = posterior.rsample()
        component = H(relaxed_z)

        logq = posterior.log_prob(relaxed_z.detach()).view(batch, 1)
        logpx_given_z = logprob_undercomponent(x, component=component)
        logpz = torch.log(mixtureweights[component]).view(batch, 1)
        logpxz = logpx_given_z + logpz  # [B,1]
        f = logpxz - logq - 1.

        # Detached f acts as the weight on the log-probability.
        weighted = f.detach() * logq
        net_loss += -torch.mean(weighted)

    net_loss = net_loss / k
    return net_loss, f, logpx_given_z, logpz, logq
def get_loss():
    """Return the surrogate's mean absolute error on one freshly sampled batch.

    The regression target is ``logprob_undercomponent(...)`` minus the relaxed
    sample's log-density. Reads the module-level ``batch_size``, ``temp``,
    ``encoder``, ``surrogate`` and ``needsoftmax_mixtureweight``; all tensors
    are moved to CUDA.
    """
    batch = sample_true(batch_size).cuda()

    class_probs = torch.softmax(encoder.net(batch), dim=1)
    relaxed = RelaxedOneHotCategorical(
        probs=class_probs, temperature=torch.tensor([temp]).cuda())

    soft_sample = relaxed.rsample()
    hard_sample = H(soft_sample)

    logq = relaxed.log_prob(soft_sample.detach()).view(batch_size, 1)
    check_nan(logq)

    logpxz = logprob_undercomponent(
        batch,
        component=hard_sample,
        needsoftmax_mixtureweight=needsoftmax_mixtureweight,
        cuda=True)
    target = logpxz - logq

    # Surrogate sees the relaxed sample concatenated with the data point.
    prediction = surrogate.net(torch.cat([soft_sample, batch], dim=1))
    return torch.mean(torch.abs(target.detach() - prediction))
def HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1):
    """Compute encoder and surrogate losses that mix two estimators via alpha.

    A per-example weight ``alpha = sigmoid(surrogate2.net(x))`` blends a term
    built on the relaxed sample (with ``surrogate`` as control variate) and a
    plain score-function term built on the hard sample.

    Args:
        surrogate: Object whose ``net`` maps ``cat([relaxed sample, x])`` to a
            prediction.
        surrogate2: Object whose ``net`` maps ``x`` to the pre-sigmoid alpha.
        x: Data batch.
        logits: Encoder logits; all gradients below are taken w.r.t. these.
        mixtureweights: Tensor indexed by the hard cluster to obtain p(z).
        k: Number of samples to average the two losses over.

    Returns:
        Tuple ``(net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss,
        surr_dif, grad_path, grad_score, mean(alpha))``. ``net_loss`` and
        ``surr_loss`` are averaged over k; the rest come from the last sample.
    """
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    # Relaxed (continuous) and discrete distributions over the same probs.
    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
    cat_bernoulli = Categorical(probs=probs)
    net_loss = 0
    surr_loss = 0
    for jj in range(k):
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)  # presumably the hard cluster index -- H is defined elsewhere
        # Log-probabilities of the relaxed sample and of the hard sample.
        logq_z = cat.log_prob(cluster_S.detach()).view(B,1)
        logq_b = cat_bernoulli.log_prob(cluster_H.detach()).view(B,1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f_z = logpxz - logq_z - 1.
        f_b = logpxz - logq_b - 1.
        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        # surr_pred, alpha = surrogate.net(surr_input)
        surr_pred = surrogate.net(surr_input)
        alpha = torch.sigmoid(surrogate2.net(x))
        # alpha is detached here, so this loss does not train surrogate2.
        net_loss += - torch.mean( alpha.detach()*(f_z.detach() - surr_pred.detach()) * logq_z + alpha.detach()*surr_pred + (1-alpha.detach())*(f_b.detach() ) * logq_b)
        # surr_loss += torch.mean(torch.abs(f_z.detach() - surr_pred))
        # Gradients w.r.t. logits, averaged over dim 1 to one column per row;
        # create_graph=True keeps them differentiable for surr_loss.
        grad_logq_z = torch.mean( torch.autograd.grad([torch.mean(logq_z)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq_b = torch.mean( torch.autograd.grad([torch.mean(logq_b)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # print (alpha.shape, f_z.shape, surr_pred.shape, grad_logq_z.shape, grad_surr.shape)
        # grad_surr = torch.autograd.grad([surr_pred[0]], [logits], create_graph=True, retain_graph=True)[0]
        # print (grad_surr)
        # Squared blended gradient estimate; alpha is NOT detached here, so
        # this loss reaches both surrogate and surrogate2.
        surr_loss += torch.mean( (alpha*(f_z.detach() - surr_pred) * grad_logq_z + alpha*grad_surr + (1-alpha)*(f_b.detach()) * grad_logq_b )**2 )
        # Diagnostics below are overwritten on every sample.
        surr_dif = torch.mean(torch.abs(f_z.detach() - surr_pred))
        # gradd = torch.autograd.grad([surr_loss], [alpha], create_graph=True, retain_graph=True)[0]
        # print (gradd)
        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f_z.detach() - surr_pred.detach()) * logq_z)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))
    net_loss = net_loss/ k
    surr_loss = surr_loss/ k
    return net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif, grad_path, grad_score, torch.mean(alpha)
def forward(self, image, question=None, feature_saving=False, cut_image_info=False, cut_question_info=False, ground_truth=None, save_messages=False, save_features=False):
    """Score candidate images against a target, optionally through a message bottleneck.

    Index 0 along dim 1 of `image` (or `ground_truth` when
    ``self.train_from_symbolic``) is the target; the remaining entries are
    candidates. With ``self.bottleneck`` set, the target features are turned
    into a discrete message by an LSTM sender and decoded again by an LSTM
    receiver before being compared with the candidates.

    Args:
        image: Batched tensor indexed as ``image[:, i, :, :, :]``.
        question: Not read in this method.
        feature_saving: Only checked by the assertion (must be false while
            training).
        cut_image_info: Not read in this method.
        cut_question_info: Not read in this method.
        ground_truth: Symbolic inputs, required when
            ``self.train_from_symbolic`` or ``self.use_ground_truth`` is set.
        save_messages: When true, add the discrete message to ``comm_info``.
        save_features: When true, add intermediate representations to
            ``comm_info``.

    Returns:
        Tuple ``(out, comm_info)``. ``out[:, i]`` is the dot product between
        the (possibly bottlenecked) target features and candidate ``i``;
        ``comm_info`` is a dict, or ``None`` when no bottleneck is used.
    """
    self.use_gpu = torch.cuda.is_available()
    if self.use_gpu:
        device = 'cuda'
    else:
        device = 'cpu'
    assert not (feature_saving and self.training)
    self.batch_size = image.shape[0]
    im_features = []
    if not self.train_from_symbolic:
        # Encode every image slot with the conv stem; slot 0 is the target.
        for i in range(image.shape[1]):
            im_features.append(self.stem_conv(image[:, i, :, :, :]).view(self.batch_size, -1, 1))
        image_info = im_features[0]
        candidate_im_features = im_features[1:]
    else:
        # Symbolic mode: the target stays raw, candidates go through
        # bottleneck_in_fc.
        candidate_im_features = []
        ground_truth = ground_truth.type(torch.FloatTensor).to(device)
        image_info = ground_truth[:, 0, :]
        for i in range(image.shape[1]-1):
            candidate_im_features.append(self.bottleneck_in_fc(ground_truth[:, i+1, :]).view(self.batch_size, -1, 1))
    #detaching gradients
    candidate_im_features = [candidate.detach() for candidate in candidate_im_features]
    if self.bottleneck:
        # Messages start with the bound token: a one-hot row while training,
        # an integer index otherwise.
        if self.training:
            message = [torch.zeros((self.batch_size, self.vocab_size), dtype=torch.float32)]
            if self.use_gpu:
                message[0] = message[0].cuda()
            message[0][:, self.bound_token_idx] = 1.0
        else:
            message = [torch.full((self.batch_size,), fill_value=self.bound_token_idx, dtype=torch.int64)]
            if self.use_gpu:
                message[0] = message[0].cuda()
        # h0, c0
        flattened_im_features = image_info.view(self.batch_size, -1)
        sender_representation = self.drop(self.bottleneck_in_fc(flattened_im_features))
        h = sender_representation
        c = torch.zeros((self.batch_size, self.message_lstm_hidden_size))
        if self.use_gpu:
            c = c.cuda()
        entropy = 0.0
        # produce words one by one
        for i in range(self.max_sentence_length):
            # One-hot tokens are embedded by matmul (keeps gradients);
            # integer tokens by row lookup.
            emb = torch.matmul(message[-1], self.message_embedding) if self.training else self.message_embedding[message[-1]]
            h, c = self.lstm_cell(emb, (h, c))
            vocab_scores = self.drop(self.hidden2vocab(h))
            p = F.softmax(vocab_scores, dim=1)
            entropy += Categorical(p).entropy()
            if self.training:
                rohc = RelaxedOneHotCategorical(self.tau, p)
                token = rohc.rsample()
                # Straight-through part
                if not self.continuous_communication:
                    token_hard = torch.zeros_like(token)
                    token_hard.scatter_(-1, torch.argmax(token, dim=-1, keepdim=True), 1.0)
                    token = (token_hard - token).detach() + token
            else:
                if self.greedy:
                    _, token = torch.max(p, -1)
                else:
                    token = Categorical(p).sample()
            message.append(token)
        message = torch.stack(message, dim=1)
        # m holds token indices in both modes.
        if self.training:
            _, m = torch.max(message, dim=-1)
        else:
            m = message
        md = calc_message_distinctness(m)
        # If we feed the ground_truth to the receiver, we simply hijack the message here
        if self.use_ground_truth:
            if self.training:
                message = batch_2_onehot(ground_truth, self.max_sentence_length, self.vocab_size)
            else:
                message = ground_truth.type(torch.LongTensor)
            if self.use_gpu:
                message = message.cuda()
        # Receiver part
        h = torch.zeros((self.batch_size, self.encoder_lstm_hidden_size))
        c = torch.zeros((self.batch_size, self.encoder_lstm_hidden_size))
        if self.use_gpu:
            h = h.cuda()
            c = c.cuda()
        emb = torch.matmul(message, self.message_embedding) if self.training else self.message_embedding[message]
        _, (h, c) = self.message_encoder_lstm(emb, (h[None, ...], c[None, ...]))
        hidden_receiver = h[0]
        bottleneck_out = self.drop(self.aff_transform(hidden_receiver))
        # From here on the target is represented by the receiver's output.
        image_info = bottleneck_out
        # TODO: re-enable the real statistics; `entropy` and `md` are computed
        # above but the reported values are currently hard-coded to 0.
        # comm_info = {'entropy': torch.mean(entropy).item() / (self.max_sentence_length+ 1e-7), 'md': md}
        comm_info = {'entropy': 0, 'md': 0}
        if save_features:
            comm_info['image_features'] = flattened_im_features
            comm_info['sender_repr'] = sender_representation
            comm_info['receiver_repr'] = hidden_receiver
        if save_messages:
            comm_info['message'] = m
    # in case no communication bottleneck
    else:
        comm_info = None
    orig_im_features = image_info.view(self.batch_size, 1, -1)
    out = torch.zeros(self.batch_size, len(candidate_im_features)).type(torch.FloatTensor)
    if self.use_gpu:
        out = out.cuda()
    # Batched dot product of the target row vector with each candidate column.
    for i in range(len(candidate_im_features)):
        out[:, i] = torch.bmm(orig_im_features, candidate_im_features[i]).squeeze()
    return out, comm_info
return torch.argmax(soft, dim=1) surrogate = NN3(input_size=C, output_size=1, n_residual_blocks=2) train_ = 1 n_steps = 1000#0 #0 #1000 #50000 # B = 1 #32 #0 k=3 if train_: optim = torch.optim.Adam(surrogate.parameters(), lr=1e-4, weight_decay=1e-7) #Train surrogate for i in range(n_steps+1): warmup = 1. cat = RelaxedOneHotCategorical(logits=logits.repeat(B,1), temperature=torch.tensor([1.])) z = cat.rsample() logprob = cat.log_prob(z.detach()).view(B,1) b = H(z) reward = f(b).view(B,1) cz = surrogate.net(z) # estimator = (reward - cz) * logprob + cz # grad = torch.autograd.grad([torch.mean(estimator)], [logits], create_graph=True, retain_graph=True)[0] gradlogprob = torch.autograd.grad([torch.mean(logprob)], [logits], create_graph=True, retain_graph=True)[0] gradcz = torch.autograd.grad([torch.mean(cz)], [logits], create_graph=True, retain_graph=True)[0] # print (reward.shape, cz.shape, gradlogprob.shape, gradcz.shape) # fdasf # grad = (reward-cz) *gradlogprob + gradcz
train_ = 1 n_steps = 1000 #0 #0 #1000 #50000 # B = 1 #32 #0 k = 3 if train_: optim = torch.optim.Adam(surrogate.parameters(), lr=1e-4, weight_decay=1e-7) #Train surrogate for i in range(n_steps + 1): warmup = 1. cat = RelaxedOneHotCategorical(logits=logits.repeat(B, 1), temperature=torch.tensor([1.])) z = cat.rsample() logprob = cat.log_prob(z.detach()).view(B, 1) b = H(z) reward = f(b).view(B, 1) cz = surrogate.net(z) # estimator = (reward - cz) * logprob + cz # grad = torch.autograd.grad([torch.mean(estimator)], [logits], create_graph=True, retain_graph=True)[0] gradlogprob = torch.autograd.grad([torch.mean(logprob)], [logits], create_graph=True, retain_graph=True)[0] gradcz = torch.autograd.grad([torch.mean(cz)], [logits], create_graph=True, retain_graph=True)[0]
def forward(self, t, word_counts, tau=1.2):
    """Generate a message token by token from the input representation `t`.

    Args:
        t: Input batch; ``t.shape[0]`` is the batch size and
            ``self.aff_transform(t)`` gives the initial LSTM hidden state.
        word_counts: Tensor of per-token counts, indexed by token id; it is
            cloned before being modified.
        tau: Temperature of the relaxed one-hot distribution (training only).

    Returns:
        Tuple ``(message, seq_lengths, vl_loss, entropy, input_embed_rep)``:
        the stacked tokens (including the initial bound token), the sequence
        lengths maintained by ``self._calculate_seq_len``, the accumulated
        cross-entropy term (``0.0`` when ``self.vl_loss_weight`` is not
        positive), the mean entropy divided by ``self.max_sentence_length``,
        and the stacked per-step input embeddings.
    """
    batch_size = t.shape[0]
    # Messages start with the bound token: a one-hot row while training, an
    # integer index otherwise.
    if self.training:
        message = [
            torch.zeros((batch_size, self.vocab_size), dtype=torch.float32)
        ]
        if self.use_gpu:
            message[0] = message[0].cuda()
        message[0][:, self.bound_token_idx] = 1.0
    else:
        message = [
            torch.full((batch_size, ),
                       fill_value=self.bound_token_idx,
                       dtype=torch.int64)
        ]
        if self.use_gpu:
            message[0] = message[0].cuda()
    # h0, c0
    h = self.aff_transform(t)  # batch_size, hidden_size
    c = torch.zeros([batch_size, self.hidden_size])
    initial_length = self.max_sentence_length + 1
    seq_lengths = torch.ones([batch_size],
                             dtype=torch.int64) * initial_length
    # Unreduced loss: one value per batch row.
    ce_loss = nn.CrossEntropyLoss(reduction='none')
    # Handle alpha by giving weight to the padding token
    w_counts = word_counts.clone()  # clone so the caller's tensor is not modified
    w_counts[self.bound_token_idx] *= self.bound_weight
    denominator = w_counts.sum()
    # Normalise the counts unless they sum to zero (or less).
    if denominator > 0:
        normalized_word_counts = w_counts / denominator
    else:
        normalized_word_counts = w_counts
    vl_loss = 0.0
    entropy = 0.0
    if self.use_gpu:
        c = c.cuda()
        seq_lengths = seq_lengths.cuda()
    input_embed_rep = []
    for i in range(self.max_sentence_length
                   ):  # or sampled <EOS>, but this is batched
        # One-hot tokens are embedded by matmul (keeps gradients); integer
        # tokens by row lookup.
        emb = torch.matmul(
            message[-1], self.embedding
        ) if self.training else self.embedding[message[-1]]
        h, c = self.lstm_cell(emb, (h, c))
        vocab_scores = self.linear_probs(h)
        p = F.softmax(vocab_scores, dim=1)
        entropy += Categorical(p).entropy()
        if self.training:
            rohc = RelaxedOneHotCategorical(tau, p)
            token = rohc.rsample()
            # Straight-through part
            token_hard = torch.zeros_like(token)
            token_hard.scatter_(-1,
                                torch.argmax(token, dim=-1, keepdim=True),
                                1.0)
            token = (token_hard - token).detach() + token
        else:
            if self.greedy:
                _, token = torch.max(p, -1)
            else:
                token = Categorical(p).sample()
        message.append(token)
        input_embed_rep.append(emb)
        self._calculate_seq_len(seq_lengths,
                                token,
                                initial_length,
                                seq_pos=i + 1)
        if self.vl_loss_weight > 0.0:
            # Scores are shifted by the normalised word counts before the
            # cross-entropy against the discretised token.
            vl_loss += ce_loss(vocab_scores - normalized_word_counts,
                               self._discretize_token(token))
    return (torch.stack(message, dim=1), seq_lengths, vl_loss,
            torch.mean(entropy) / self.max_sentence_length,
            torch.stack(input_embed_rep, dim=1))
def HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1):
    """Blend a relaxed-sample estimator and a hard-sample estimator via alpha.

    ``alpha = sigmoid(surrogate2.net(x))`` weights the surrogate-corrected
    term on the relaxed sample against the score-function term on the hard
    sample.

    Returns:
        Tuple ``(net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss,
        surr_dif, grad_path, grad_score, mean(alpha))``. The two losses are
        averaged over k samples; the other entries come from the last sample.
    """

    def _grad_wrt_logits(scalar):
        # Keep the graph so the gradient itself stays differentiable.
        return torch.autograd.grad([scalar], [logits], create_graph=True,
                                   retain_graph=True)[0]

    def _row_mean_grad(scalar):
        return torch.mean(_grad_wrt_logits(scalar), dim=1, keepdim=True)

    n_rows = logits.shape[0]
    class_probs = torch.softmax(logits, dim=1)
    relaxed = RelaxedOneHotCategorical(
        probs=class_probs, temperature=torch.tensor([1.]).cuda())
    discrete = Categorical(probs=class_probs)

    net_loss = 0
    surr_loss = 0
    for _ in range(k):
        soft = relaxed.rsample()
        hard = H(soft)

        logq_z = relaxed.log_prob(soft.detach()).view(n_rows, 1)
        logq_b = discrete.log_prob(hard.detach()).view(n_rows, 1)

        logpx_given_z = logprob_undercomponent(x, component=hard)
        logpz = torch.log(mixtureweights[hard]).view(n_rows, 1)
        logpxz = logpx_given_z + logpz  # [B,1]
        f_z = logpxz - logq_z - 1.
        f_b = logpxz - logq_b - 1.

        surr_pred = surrogate.net(torch.cat([soft, x], dim=1))
        alpha = torch.sigmoid(surrogate2.net(x))

        # Encoder loss: alpha is detached, so surrogate2 is not trained here.
        fixed_alpha = alpha.detach()
        relaxed_term = (fixed_alpha * (f_z.detach() - surr_pred.detach()) *
                        logq_z)
        path_term = fixed_alpha * surr_pred
        hard_term = (1 - alpha.detach()) * (f_b.detach()) * logq_b
        net_loss += -torch.mean(relaxed_term + path_term + hard_term)

        grad_logq_z = _row_mean_grad(torch.mean(logq_z))
        grad_logq_b = _row_mean_grad(torch.mean(logq_b))
        grad_surr = _row_mean_grad(torch.mean(surr_pred))

        # Surrogate loss: squared blended gradient, with alpha left attached.
        blended = (alpha * (f_z.detach() - surr_pred) * grad_logq_z +
                   alpha * grad_surr +
                   (1 - alpha) * (f_b.detach()) * grad_logq_b)
        surr_loss += torch.mean(blended**2)

        # Diagnostics (overwritten on every sample).
        surr_dif = torch.mean(torch.abs(f_z.detach() - surr_pred))
        grad_path = torch.mean(
            torch.abs(_grad_wrt_logits(torch.mean(surr_pred))))
        grad_score = torch.mean(
            torch.abs(
                _grad_wrt_logits(
                    torch.mean(
                        (f_z.detach() - surr_pred.detach()) * logq_z))))

    net_loss = net_loss / k
    surr_loss = surr_loss / k
    return (net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif,
            grad_path, grad_score, torch.mean(alpha))
for iii in range(len(probs2)): print(str(iii)+':'+str(probs2[iii]), end =" ") print () # print (probs.shape) print (probs) # fsdf cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda()) net_loss = 0 loss = 0 surr_loss = 0 grads = 0 for jj in range(k): cluster_S = cat.rsample() # print (cluster_S.shape) # dfsafd # cluster_onehot = torch.zeros(n_components).cuda() # cluster_onehot[jj%20] =1. # cluster_S = torch.softmax(cluster_onehot, dim=0).view(1,20) # print (cluster_onehot) # logprob_cluster = cat.log_prob(cluster_onehot.detach()).view(batch_size,1) # print (logprob_cluster) # fsdfasd cluster_H = H(cluster_S)
def forward(self, obs, memory, instr_embedding=None, tau=1.2):
    """Encode an observation and decode a fixed-length one-hot message from it.

    Args:
        obs: Observation exposing ``obs.image`` and, when instructions are
            used, ``obs.instr``.
        memory: Recurrent state; split at ``self.semi_memory_size`` along
            dim 1 when ``self.use_memory`` is set.
        instr_embedding: Optional precomputed instruction embedding.
        tau: Overwritten inside the decoding loop before it is ever read, so
            the passed value has no effect.

    Returns:
        Tuple ``(comm, memory)``: the stacked message tokens and the updated
        memory (returned unchanged when ``self.use_memory`` is false).
    """
    # Calculating instruction embedding
    if self.use_instr and instr_embedding is None:
        if self.lang_model == 'gru':
            _, hidden = self.instr_rnn(self.word_embedding(obs.instr))
            instr_embedding = hidden[-1]
    # Calculating the image embedding: the two transposes move dim 3 to
    # position 1 (presumably NHWC -> NCHW -- confirm against the caller).
    x = torch.transpose(torch.transpose(obs.image, 1, 3), 2, 3)
    if self.arch.startswith("expert_filmcnn"):
        image_embedding = self.image_conv(x)
        #Calculating FiLM_embedding from image and instruction embedding
        # NOTE(review): every controller receives image_embedding rather than
        # the previous controller's output, so only the last one affects x --
        # confirm this is intended.
        for controler in self.controllers:
            x = controler(image_embedding, instr_embedding)
        FiLM_embedding = F.relu(self.film_pool(x))
    else:
        FiLM_embedding = self.image_conv(x)
    FiLM_embedding = FiLM_embedding.reshape(FiLM_embedding.shape[0], -1)
    #Going through the memory layer
    if self.use_memory:
        hidden = (memory[:, :self.semi_memory_size],
                  memory[:, self.semi_memory_size:])
        hidden = self.memory_rnn(FiLM_embedding, hidden)
        embedding = hidden[0]
        memory = torch.cat(hidden, dim=1)
    else:
        # NOTE(review): this uses x, not the flattened FiLM_embedding computed
        # above -- looks unintended; verify before relying on this branch.
        embedding = x
    if self.use_instr and not "filmcnn" in self.arch:
        embedding = torch.cat((embedding, instr_embedding), dim=1)
    # extra_predictions is computed but not returned by this method.
    if hasattr(self, 'aux_info') and self.aux_info:
        extra_predictions = {
            info: self.extra_heads[info](embedding)
            for info in self.extra_heads
        }
    else:
        extra_predictions = dict()
    memory_rnn_output = embedding
    batch_size = memory_rnn_output.shape[0]
    message = []
    for i in range(self.message_length):
        if i == 0:
            # The decoder starts from the SOS token with the encoder output
            # as its initial hidden state.
            decoder_input = torch.tensor([self.sos_id] * batch_size,
                                         dtype=torch.long,
                                         device=self.device)
            decoder_input_embedded = self.word_embedding_decoder(
                decoder_input).unsqueeze(1)
            decoder_hidden = memory_rnn_output.unsqueeze(0)
        decoder_out, decoder_hidden = self.decoder_rnn(
            decoder_input_embedded, decoder_hidden)
        vocab_scores = self.hidden2word(decoder_out)
        vocab_probs = F.softmax(vocab_scores, -1)
        # Temperature is predicted from the decoder state; this shadows the
        # `tau` argument.
        tau = 1.0 / (self.tau_layer(decoder_hidden).squeeze(0) +
                     self.max_tau)
        tau = tau.expand(-1, self.vocab_size).unsqueeze(1)
        if self.training:
            rohc = RelaxedOneHotCategorical(tau, vocab_probs)
            token = rohc.rsample()
            # Straight-through part
            token_hard = torch.zeros_like(token)
            token_hard.scatter_(-1,
                                torch.argmax(token, dim=-1, keepdim=True),
                                1.0)
            token = (token_hard - token).detach() + token
        else:
            # Greedy one-hot of the most probable token.
            token = torch.zeros_like(vocab_probs, device=self.device)
            token.scatter_(-1,
                           torch.argmax(vocab_probs, dim=-1, keepdim=True),
                           1.0)
        message.append(token)
        # The next decoder input is the embedding selected by the token.
        decoder_input_embedded = torch.matmul(
            token, self.word_embedding_decoder.weight)
    comm = torch.stack(message, dim=1).squeeze(2)
    return comm, memory
def forward(self, instruction=None, observation=None, memory=None, compute_message_probs=False, time=None):
    """Produce a correction message, either decoded by the RNN or read from a script.

    Args:
        instruction: Batched instruction; ``instruction.size(0)`` is the
            batch size. Forwarded to ``self.forward_film``.
        observation: Forwarded to ``self.forward_film``.
        memory: Forwarded to ``self.forward_film``.
        compute_message_probs: Not read in this method.
        time: Index into ``self.script`` when a script is configured.

    Returns:
        Dict with ``correction_encodings`` and ``correction_messages``; the
        decoded (non-script) path additionally provides
        ``correction_entropy``, ``corrector_memory``, ``correction_lengths``
        and ``correction_loss``.
    """
    # Default the optional flags when the attributes are missing on the
    # instance.
    if not hasattr(self, 'random_corrector'):
        self.random_corrector = False
    if not hasattr(self, 'var_len'):
        self.var_len = False
    if not hasattr(self, 'script'):
        self.script = False
    if not self.script:
        memory_rnn_output, memory = self.forward_film(
            instruction=instruction,
            observation=observation,
            memory=memory)
        batch_size = instruction.size(0)
        correction_encodings = []
        entropy = 0.0
        # Every sequence starts at full length; shortened below when var_len
        # is set.
        lengths = np.array([self.corr_length] * batch_size)
        total_corr_loss = 0
        for i in range(self.corr_length):
            if i == 0:
                # every message starts with a SOS token
                decoder_input = torch.tensor([self.sos_id] * batch_size,
                                             dtype=torch.long,
                                             device=self.device)
                decoder_input_embedded = self.word_embedding_corrector(
                    decoder_input).unsqueeze(1)
                decoder_hidden = memory_rnn_output.unsqueeze(0)
            if self.random_corrector:
                # randomize corrections
                device = torch.device(
                    "cuda" if decoder_input_embedded.is_cuda else "cpu")
                decoder_input_embedded = torch.randn(
                    decoder_input_embedded.size(), device=device)
                decoder_hidden = torch.randn(decoder_hidden.size(),
                                             device=device)
            rnn_output, decoder_hidden = self.decoder_rnn(
                decoder_input_embedded, decoder_hidden)
            vocab_scores = self.out(rnn_output)
            vocab_probs = F.softmax(vocab_scores, dim=-1)
            entropy += Categorical(vocab_probs).entropy()
            # Temperature is predicted from the decoder state.
            tau = 1.0 / (self.tau_layer(decoder_hidden).squeeze(0) +
                         self.max_tau)
            tau = tau.expand(-1, self.num_embeddings).unsqueeze(1)
            if self.training:
                # Apply Gumbel SM
                cat_distr = RelaxedOneHotCategorical(tau, vocab_probs)
                corr_weights = cat_distr.rsample()
                corr_weights_hard = torch.zeros_like(corr_weights,
                                                     device=self.device)
                corr_weights_hard.scatter_(
                    -1, torch.argmax(corr_weights, dim=-1, keepdim=True),
                    1.0)
                # detach() detaches the output from the computation graph, so no gradient will be backprop'ed along this variable
                corr_weights = (corr_weights_hard -
                                corr_weights).detach() + corr_weights
            else:
                # greedy sample
                corr_weights = torch.zeros_like(vocab_probs,
                                                device=self.device)
                corr_weights.scatter_(
                    -1, torch.argmax(vocab_probs, dim=-1, keepdim=True),
                    1.0)
            if self.var_len:
                # consider sequence done when eos receives highest value
                max_idx = torch.argmax(corr_weights, dim=-1)
                eos_batches = max_idx.data.eq(self.eos_id)
                if eos_batches.dim() > 0:
                    eos_batches = eos_batches.cpu().view(-1).numpy()
                    # Only shorten sequences that have not ended earlier.
                    update_idx = ((lengths > i) & eos_batches) != 0
                    lengths[update_idx] = i
                # compute correction error through pseudo-target: sequence of eos symbols to encourage short messages
                pseudo_target = torch.tensor(
                    [self.eos_id for j in range(batch_size)],
                    dtype=torch.long,
                    device=self.device)
                loss = self.correction_loss(corr_weights.squeeze(1),
                                            pseudo_target)
                total_corr_loss += loss
            correction_encodings += [corr_weights]
            # The next decoder input is the embedding selected by the token.
            decoder_input_embedded = torch.matmul(
                corr_weights, self.word_embedding_corrector.weight)
        # one-hot vectors on forward, soft approximations on backward pass
        correction_encodings = torch.stack(correction_encodings,
                                           dim=1).squeeze(2)
        lengths = torch.tensor(lengths,
                               dtype=torch.long,
                               device=self.device)
        result = {
            'correction_encodings': correction_encodings,
            'correction_messages':
            self.decode_corrections(correction_encodings),
            'correction_entropy': torch.mean(entropy),
            'corrector_memory': memory,
            'correction_lengths': lengths,
            'correction_loss': total_corr_loss
        }
    else:
        # there is a script of pre-established guidance messages
        correction_messages = self.script[time]
        correction_encodings = self.encode_corrections(correction_messages)
        result = {
            'correction_encodings': correction_encodings,
            'correction_messages': correction_messages
        }
    return (result)