def forward(self, input, discrete=False):
    # This `if` was added because in AutoSpeech individual items have two
    # dimensions (time, freq), while in my ToyASV2019 items have three
    # dimensions (channels, time, freq), with channels always 1.
    if len(input.shape) < 4:
        input = input.unsqueeze(1)
    s0 = s1 = self.stem(input)
    for i, cell in enumerate(self.cells):
        if cell.reduction:
            if discrete:
                weights = self.alphas_reduce
            else:
                weights = gumbel_softmax(
                    F.log_softmax(self.alphas_reduce, dim=-1))
        else:
            if discrete:
                weights = self.alphas_normal
            else:
                weights = gumbel_softmax(
                    F.log_softmax(self.alphas_normal, dim=-1))
        s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
    v = self.global_pooling(s1)
    v = v.view(v.size(0), -1)
    # Not needed: in anti-spoofing we always forward through the whole network.
    # if not self.training:
    #     return v
    y = self.classifier(v)
    return y
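# Context sketch (not from any of the repos in this collection): the snippets
# here call repo-local `gumbel_softmax` helpers with varying signatures. A
# minimal self-contained version with the straight-through "hard" option,
# mirroring torch.nn.functional.gumbel_softmax, might look like this; the
# name `gumbel_softmax_sketch` is an assumption.
import torch
import torch.nn.functional as F

def gumbel_softmax_sketch(logits, tau=1.0, hard=False, dim=-1):
    # Sample Gumbel(0, 1) noise: -log(Exponential(1)).
    gumbels = -torch.empty_like(logits).exponential_().log()
    y_soft = F.softmax((logits + gumbels) / tau, dim=dim)
    if hard:
        # Straight-through: one-hot forward pass, soft gradients backward.
        index = y_soft.argmax(dim, keepdim=True)
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        return y_hard - y_soft.detach() + y_soft
    return y_soft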
def forward(self, data, sp_send, sp_rec, sp_send_t, sp_rec_t):
    r"""
    Parameters:
        data: input concept embedding matrix
        sp_send: one-hot encoded send-node index (sparse tensor)
        sp_rec: one-hot encoded receive-node index (sparse tensor)
        sp_send_t: one-hot encoded send-node index (sparse tensor, transposed)
        sp_rec_t: one-hot encoded receive-node index (sparse tensor, transposed)
    Shape:
        data: [concept_num, embedding_dim]
        sp_send: [edge_num, concept_num]
        sp_rec: [edge_num, concept_num]
        sp_send_t: [concept_num, edge_num]
        sp_rec_t: [concept_num, edge_num]
    Return:
        graphs: list of latent graphs modeled by z, one per edge type
        output: the reconstructed data
        prob: the q(z|x) distribution
    """
    logits = self.encoder(data, sp_send, sp_rec, sp_send_t,
                          sp_rec_t)  # [edge_num, output_dim(edge_type_num)]
    edges = gumbel_softmax(logits, tau=self.tau, dim=-1)  # [edge_num, edge_type_num]
    prob = F.softmax(logits, dim=-1)
    output = self.decoder(data, edges, sp_send, sp_rec, sp_send_t,
                          sp_rec_t)  # [concept_num, embedding_dim]
    graphs = self._get_graph(edges, sp_send, sp_rec)
    return graphs, output, prob
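# Minimal stand-alone sketch of the edge relaxation above: per-edge logits
# become a differentiable near-one-hot edge-type sample plus the softmax
# posterior q(z|x) used later in a KL term. The shapes are assumptions.
import torch
import torch.nn.functional as F

edge_num, edge_type_num = 20, 2
logits = torch.randn(edge_num, edge_type_num)
edges = F.gumbel_softmax(logits, tau=0.5, dim=-1)  # relaxed sample
prob = F.softmax(logits, dim=-1)                   # q(z|x)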
def encode_t2r(self, inputs, traj_emb_pred, n_epoch):
    latent_edge_samples, latent_edge_probs, latent_edge_logits = \
        None, None, None
    if self.opt_dynamics(n_epoch) and hasattr(self, 'enc_t2r'):
        if traj_emb_pred is None:
            nri_input = inputs['traj']
        else:
            nri_input = traj_emb_pred
        if self.enc_t2r_detach_i2t_grad:
            nri_input = nri_input.detach()
        latent_edge_logits = self.enc_t2r(nri_input)
        latent_edge_samples = gumbel_softmax(latent_edge_logits, self.temp,
                                             hard=not self.training)
        latent_edge_probs = F.softmax(latent_edge_logits, dim=-1)
        # [B * T, n_atoms * (n_atoms - 1), 256] - dynamic graph inference
        # [B * 1, n_atoms * (n_atoms - 1), 256] - static graph inference
        batch_size = nri_input.size(0)
        shp = latent_edge_logits.shape
        shp = [batch_size, -1] + list(shp[1:])
        latent_edge_logits = latent_edge_logits.view(shp)
        latent_edge_samples = latent_edge_samples.view(shp)
        latent_edge_probs = latent_edge_probs.view(shp)
    return latent_edge_samples, latent_edge_probs, latent_edge_logits
def forward(self, x, y, K):
    '''
    :param x: encoded utterance in shape (B, 2*H)
    :param y: encoded response in shape (B, 2*H) (optional)
    :param K: encoded knowledge in shape (B, N, 2*H)
    :return: prior, posterior, selected knowledge, selected knowledge
        logits for the BOW loss
    '''
    if y is not None:
        prior = F.log_softmax(
            torch.bmm(x.unsqueeze(1), K.transpose(-1, -2)),
            dim=-1).squeeze(1)
        response = self.mlp(torch.cat((x, y), dim=-1))
        # response: [n_batch, 2*n_hidden]
        K = K.transpose(-1, -2)  # K: [n_batch, 2*n_hidden, N]
        posterior_logits = torch.bmm(response.unsqueeze(1), K).squeeze(1)
        posterior = F.softmax(posterior_logits, dim=-1)
        k_idx = gumbel_softmax(posterior_logits, self.temperature)
        # k_idx: [n_batch, N] (one-hot)
        k_i = torch.bmm(K, k_idx.unsqueeze(2)).squeeze(2)
        # k_i: [n_batch, 2*n_hidden]
        k_logits = F.log_softmax(self.mlp_k(k_i), dim=-1)
        # k_logits: [n_batch, n_vocab]
        # prior: [n_batch, N], posterior: [n_batch, N]
        return prior, posterior, k_i, k_logits
    else:
        n_batch = K.size(0)
        k_i = torch.Tensor(n_batch, 2 * self.n_hidden).cuda()
        prior = torch.bmm(x.unsqueeze(1), K.transpose(-1, -2)).squeeze(1)
        k_idx = prior.max(1)[1].unsqueeze(1)  # k_idx: [n_batch, 1]
        for i in range(n_batch):
            k_i[i] = K[i, k_idx[i]]
        return k_i
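# Minimal sketch of the selection step above: F.gumbel_softmax with
# hard=True draws a differentiable one-hot over the N candidates, and bmm
# picks out the matching knowledge vector. B=4, N=8, D=32 are assumptions.
import torch
import torch.nn.functional as F

posterior_logits = torch.randn(4, 8)                            # [B, N]
K = torch.randn(4, 32, 8)                                       # [B, D, N]
k_idx = F.gumbel_softmax(posterior_logits, tau=1.0, hard=True)  # [B, N] one-hot
k_i = torch.bmm(K, k_idx.unsqueeze(2)).squeeze(2)               # [B, D] selected row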
def code_transformer(codes, step, max_steps):
    (codes_prob, codes_logprob, codes_oh, codes_oh_prob, codes_oh_trans,
     codes_oh_scale, codes_sg, codes_hg_trans, codes_hg_scale) = (
        {}, {}, {}, {}, {}, {}, {}, {}, {})
    alpha = np.power(1 - step / max_steps, 2)
    for name in codes:
        logits = codes[name]
        prob = torch.softmax(logits, 1)  # B x n_values
        logprob = torch.log_softmax(logits, 1)
        argmax = torch.argmax(logits, 1, True)  # B
        one_hot = torch.zeros_like(logits).scatter_(1, argmax, 1.0)
        one_hot_prob = prob * one_hot
        one_hot_trans = one_hot - one_hot_prob.detach() + one_hot_prob
        one_hot_scale = one_hot_prob / one_hot_prob.sum(1, keepdim=True).detach()
        soft_gumbel, hard_gumbel_trans, hard_gumbel_scale = U.gumbel_softmax(
            logits, tau=1, dim=1)
        codes_prob[name] = prob
        codes_logprob[name] = logprob
        codes_oh[name] = one_hot
        codes_oh_prob[name] = one_hot_prob
        codes_oh_trans[name] = one_hot_trans
        codes_oh_scale[name] = one_hot_scale
        codes_sg[name] = soft_gumbel
        codes_hg_trans[name] = hard_gumbel_trans
        codes_hg_scale[name] = hard_gumbel_scale
    return {'logits': codes, 'prob': codes_prob, 'logprob': codes_logprob,
            'one_hot': codes_oh, 'one_hot_prob': codes_oh_prob,
            'one_hot_trans': codes_oh_trans, 'one_hot_scale': codes_oh_scale,
            'soft_gumbel': codes_sg, 'hard_gumbel_trans': codes_hg_trans,
            'hard_gumbel_scale': codes_hg_scale}
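# Sketch of the two straight-through variants built above: `one_hot_trans`
# copies gradients from the selected probability, while `one_hot_scale`
# renormalizes by the (detached) selected mass. Sizes are assumptions.
import torch

logits = torch.randn(4, 10, requires_grad=True)   # B x n_values
prob = torch.softmax(logits, 1)
argmax = torch.argmax(logits, 1, keepdim=True)
one_hot = torch.zeros_like(logits).scatter(1, argmax, 1.0)
one_hot_prob = prob * one_hot
one_hot_trans = one_hot - one_hot_prob.detach() + one_hot_prob
one_hot_scale = one_hot_prob / one_hot_prob.sum(1, keepdim=True).detach()
one_hot_trans.sum().backward()                    # gradients reach `logits`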
def nextStep(query, key, value, mask=None, tau=1):
    nonlocal i
    # single decoding step: query is one position, so scores are batch x key_len
    dot_products = (query.unsqueeze(1) * key).sum(-1)
    if self.attend_mode == "only_attend_front":
        dot_products[:, i + 1:] -= 1e9
    if self.window > 0:
        dot_products[:, :max(0, i - self.window)] -= 1e9
        dot_products[:, i + self.window + 1:] -= 1e9
    if self.attend_mode == "not_attend_self":
        dot_products[:, i] -= 1e9
    if mask is not None:
        dot_products -= (1 - mask) * 1e9
    logits = dot_products / self.scale
    if self.gumbel_attend and self.training:
        probs = gumbel_softmax(logits, tau, dim=-1)
    else:
        probs = torch.softmax(logits, dim=-1)
    # zero out rows whose positions are all masked (all scores <= -5e8)
    probs = probs * ((dot_products <= -5e8).sum(-1, keepdim=True)
                     < dot_products.shape[-1]).float()
    i += 1
    return torch.einsum("ij, ijk->ik", self.dropout(probs), value)
def _get_sender_lstm_output(self, inputs):
    samples = []
    batch_size = inputs.shape[0]
    sample_loss = torch.zeros(batch_size, device=self.config['device'])
    total_kl = torch.zeros(batch_size, device=self.config['device'])
    hx = torch.zeros(batch_size, self.config['num_lstm_sender'],
                     device=self.config['device'])
    cx = torch.zeros(batch_size, self.config['num_lstm_sender'],
                     device=self.config['device'])
    for num in range(self.config['num_binary_messages']):
        hx, cx = self.sender_cell(inputs, (hx, cx))
        output = self.sender_project(hx)
        pre_logits = self.sender_out(output)
        sample = utils.gumbel_softmax(
            pre_logits,
            self.temperature[num],
            self.config['device'],
        )
        logits_dist = dists.OneHotCategorical(logits=pre_logits)
        prior_logits = self.prior[num].unsqueeze(0)
        prior_logits = prior_logits.expand(batch_size, self.output_size)
        prior_dist = dists.OneHotCategorical(logits=prior_logits)
        kl = dists.kl_divergence(logits_dist, prior_dist)
        total_kl += kl
        samples.append(sample)
    return samples, sample_loss, total_kl
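# Sketch of the KL term accumulated above: torch.distributions gives a
# closed-form KL between two OneHotCategorical distributions, one value per
# batch element. Batch size 8 and a 2-symbol alphabet are assumptions.
import torch
import torch.distributions as dists

pre_logits = torch.randn(8, 2)     # posterior logits for one message slot
prior_logits = torch.zeros(8, 2)   # uniform prior
kl = dists.kl_divergence(
    dists.OneHotCategorical(logits=pre_logits),
    dists.OneHotCategorical(logits=prior_logits))  # shape [8]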
def code_transformer(codes):
    (codes_prob, codes_oh, codes_oh_prob, codes_oh_trans, codes_oh_scale,
     codes_sg, codes_hg) = {}, {}, {}, {}, {}, {}, {}
    for name in codes:
        logits = codes[name]
        prob = torch.softmax(logits, 1)  # B x n_values
        argmax = torch.argmax(logits, 1, True)  # B
        one_hot = torch.zeros_like(logits).scatter_(1, argmax, 1.0)
        one_hot_prob = prob * one_hot
        one_hot_trans = one_hot - one_hot_prob.detach() + one_hot_prob
        one_hot_scale = one_hot_prob / one_hot_prob.sum(1, keepdim=True).detach()
        soft_gumbel, hard_gumbel, _ = U.gumbel_softmax(logits, tau=1, dim=1)
        codes_prob[name] = prob
        codes_oh[name] = one_hot
        codes_oh_prob[name] = one_hot_prob
        codes_oh_trans[name] = one_hot_trans
        codes_oh_scale[name] = one_hot_scale
        codes_sg[name] = soft_gumbel
        codes_hg[name] = hard_gumbel
    return {
        'prob': codes_prob,
        'one_hot': codes_oh,
        'one_hot_prob': codes_oh_prob,
        'one_hot_trans': codes_oh_trans,
        'one_hot_scale': codes_oh_scale,
        'soft_gumbel': codes_sg,
        'hard_gumbel': codes_hg
    }
def forward(self, data, rel_type, rel_rec, rel_send, pred_steps=1,
            burn_in=False, burn_in_steps=1, dynamic_graph=False,
            encoder=None, temp=None):
    inputs = data.transpose(1, 2).contiguous()
    time_steps = inputs.size(1)
    # inputs has shape [batch_size, num_timesteps, num_atoms, num_dims]
    # rel_type has shape [batch_size, num_atoms*(num_atoms-1), num_edge_types]
    hidden = Variable(
        torch.zeros(inputs.size(0), inputs.size(2), self.msg_out_shape))
    if inputs.is_cuda:
        hidden = hidden.cuda()
    pred_all = []
    for step in range(0, inputs.size(1) - 1):
        if burn_in:
            if step <= burn_in_steps:
                ins = inputs[:, step, :, :]
            else:
                ins = pred_all[step - 1]
        else:
            assert pred_steps <= time_steps
            # Use ground-truth trajectory input vs. last prediction
            if not step % pred_steps:
                ins = inputs[:, step, :, :]
            else:
                ins = pred_all[step - 1]
        if dynamic_graph and step >= burn_in_steps:
            # NOTE: assumes burn_in_steps == args.timesteps
            logits = encoder(
                data[:, :, step - burn_in_steps:step, :].contiguous(),
                rel_rec, rel_send)
            rel_type = gumbel_softmax(logits, tau=temp, hard=True)
        pred, hidden = self.single_step_forward(ins, rel_rec, rel_send,
                                                rel_type, hidden)
        pred_all.append(pred)
    preds = torch.stack(pred_all, dim=1)
    return preds.transpose(1, 2).contiguous()
def route(self, x, hard=True):
    x = x.float()
    logits = self.linear(x)
    if hard and self.training:
        return gumbel_softmax(logits)
    elif hard and not self.training:
        return ohe_from_logits(logits)
    else:
        return logits.softmax(dim=-1)
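# Self-contained sketch of the routing policy above (a stand-in, not the
# repo's class): a hard Gumbel sample while training keeps gradients flowing
# to the router; at eval time a deterministic argmax one-hot is used instead.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyRouter(nn.Module):
    def __init__(self, in_features=16, num_experts=4):
        super().__init__()
        self.linear = nn.Linear(in_features, num_experts)

    def route(self, x, hard=True):
        logits = self.linear(x.float())
        if hard and self.training:
            return F.gumbel_softmax(logits, hard=True)
        if hard:
            idx = logits.argmax(-1, keepdim=True)
            return torch.zeros_like(logits).scatter_(-1, idx, 1.0)
        return logits.softmax(dim=-1)

router = TinyRouter()
x = torch.randn(2, 16)
router.train()
print(router.route(x))   # differentiable one-hot sample
router.eval()
print(router.route(x))   # greedy one-hot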
def update(self, agent_id, obs, acs, rews, next_obs, dones, t_step,
           logger=None):
    obs = torch.from_numpy(obs).float()
    acs = torch.from_numpy(acs).float()
    rews = torch.from_numpy(rews[:, agent_id]).float()
    next_obs = torch.from_numpy(next_obs).float()
    dones = torch.from_numpy(dones[:, agent_id]).float()
    acs = acs.view(-1, 2)
    # --------- update critic ------------ #
    self.critic_optimizer.zero_grad()
    all_trgt_acs = self.target_policy(next_obs)
    target_value = (rews + self.gamma *
                    self.target_critic(next_obs, all_trgt_acs) * (1 - dones))
    actual_value = self.critic(obs, acs)
    vf_loss = MSELoss(actual_value, target_value.detach())
    # Minimize the loss
    vf_loss.backward()
    # torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1)
    self.critic_optimizer.step()
    # --------- update actor --------------- #
    self.policy_optimizer.zero_grad()
    if self.discrete_action:
        curr_pol_out = self.policy(obs)
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_out = self.policy(obs)
        curr_pol_vf_in = curr_pol_out
    pol_loss = -self.critic(obs, curr_pol_vf_in).mean()
    # pol_loss += (curr_pol_out**2).mean() * 1e-3
    pol_loss.backward()
    torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1)
    self.policy_optimizer.step()
    self.update_all_targets()
    self.eps -= self.eps_decay
    self.eps = max(self.eps, 0)
    if logger is not None:
        logger.add_scalars('agent%i/losses' % self.agent_name,
                           {'vf_loss': vf_loss, 'pol_loss': pol_loss},
                           self.niter)
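# Sketch of the discrete-action actor step above: the straight-through
# Gumbel sample lets the critic score a one-hot action while gradients still
# reach the policy logits. The linear stand-ins and sizes are assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

policy = nn.Linear(8, 5)        # stand-in policy network
critic = nn.Linear(8 + 5, 1)    # stand-in Q(s, a)
obs = torch.randn(32, 8)
curr_pol_out = policy(obs)
curr_pol_vf_in = F.gumbel_softmax(curr_pol_out, hard=True)  # one-hot, differentiable
pol_loss = -critic(torch.cat([obs, curr_pol_vf_in], dim=1)).mean()
pol_loss.backward()             # gradients reach policy.weight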
def forward(self, query, temperature=1.):
    '''
    query: batch_size x timestep x embedding_dim
    embedding: n_latent x embedding_dim
    '''
    unnormalized_logits = query @ self.embedding.weight
    unnormalized_logits = unnormalized_logits / torch.norm(
        query, p=2, dim=-1, keepdim=True)
    logits = unnormalized_logits / torch.norm(
        torch.t(self.embedding.weight), p=2, dim=-1)
    proba = gumbel_softmax(logits, temperature=temperature)
    output = self.embedding(proba)
    return proba, output
def forward(self, input, temperature=5.0):
    s0 = s1 = self.stem(input)
    for i, cell in enumerate(self.cells):
        if cell.reduction:
            if self.alphas_reduce.size(1) == 1:
                # weights = F.softmax(self.alphas_reduce, dim=0)
                # weights = F.gumbel_softmax(self.alphas_reduce, temperature, dim=0)
                weights = gumbel_softmax(
                    F.log_softmax(self.alphas_reduce, dim=0),
                    tau=temperature, hard=True, dim=0)
            else:
                # weights = F.softmax(self.alphas_reduce, dim=-1)
                # weights = F.gumbel_softmax(self.alphas_reduce, temperature, dim=-1)
                weights = gumbel_softmax(
                    F.log_softmax(self.alphas_reduce, dim=-1),
                    tau=temperature, hard=True, dim=-1)
        else:
            if self.alphas_normal.size(1) == 1:
                # weights = F.softmax(self.alphas_normal, dim=0)
                # weights = F.gumbel_softmax(self.alphas_normal, temperature, dim=0)
                weights = gumbel_softmax(
                    F.log_softmax(self.alphas_normal, dim=0),
                    tau=temperature, hard=True, dim=0)
            else:
                # weights = F.softmax(self.alphas_normal, dim=-1)
                # weights = F.gumbel_softmax(self.alphas_normal, temperature, dim=-1)
                weights = gumbel_softmax(
                    F.log_softmax(self.alphas_normal, dim=-1),
                    tau=temperature, hard=True, dim=-1)
        s0, s1 = s1, cell(s0, s1, weights)
    out = self.global_pooling(s1)
    logits = self.classifier(out.view(out.size(0), -1))
    return logits
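# Sketch of a temperature schedule to pair with the forward above: anneal
# tau from its high starting value toward a small floor so the hard Gumbel
# samples sharpen over training. The exponential-decay shape and the
# constants are assumptions, not taken from this repo.
def anneal_tau(epoch, tau_max=5.0, tau_min=0.5, decay=0.95):
    return max(tau_min, tau_max * (decay ** epoch))

for epoch in range(50):
    tau = anneal_tau(epoch)
    # logits = model(input, temperature=tau)   # hypothetical training call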
def update(self):
    (batch_states, batch_actions, batch_rewards, batch_new_states,
     batch_dones) = self.replay_memory.sample_mini_batch(
        batch_size=self.batch_size)
    batch_states = batch_states.to(self.device)
    batch_actions = batch_actions.to(self.device)
    batch_rewards = batch_rewards.to(self.device)
    batch_new_states = batch_new_states.to(self.device)
    batch_dones = batch_dones.to(self.device)
    critic_loss_per_agent = []
    actor_loss_per_agent = []
    for idx in range(len(self.actors)):
        actor = self.actors[idx]
        critic = self.critics[idx]
        old_actor = self.old_actors[idx]
        old_critic = self.old_critics[idx]
        actor_optimizer = self.actor_optimizers[idx]
        critic_optimizer = self.critic_optimizers[idx]
        # update critic
        predict_Q = critic(state=batch_states,
                           actions=batch_actions).squeeze(-1)
        old_actor_actions = old_actor(batch_new_states)
        target_actions = batch_actions.clone().detach()
        target_actions[:, idx, :] = old_actor_actions
        target_actions = convert_to_onehot(target_actions,
                                           epsilon=self.epsilon)
        target_Q = self.gamma * old_critic(
            state=batch_new_states,
            actions=target_actions).squeeze(-1) * (1 - batch_dones) \
            + batch_rewards
        c_loss = self.critic_loss(input=predict_Q, target=target_Q.detach())
        c_loss.backward()
        # clip_grad_norm is deprecated; use the in-place clip_grad_norm_
        torch.nn.utils.clip_grad_norm_(critic.parameters(), 0.5)
        critic_optimizer.step()
        critic_optimizer.zero_grad()
        critic_loss_per_agent.append(c_loss.item())
        # update actor
        actor_actions = actor(batch_states)
        actor_actions = gumbel_softmax(actor_actions, hard=True)
        predict_actions = batch_actions.clone().detach()
        predict_actions[:, idx, :] = actor_actions
        a_loss = -critic(state=batch_states,
                         actions=predict_actions).squeeze(-1)
        a_loss = a_loss.mean()
        # backward() must run before clipping so gradients exist to clip
        a_loss.backward()
        torch.nn.utils.clip_grad_norm_(actor.parameters(), 0.5)
        actor_optimizer.step()
        actor_optimizer.zero_grad()
        actor_loss_per_agent.append(a_loss.item())
    return (sum(actor_loss_per_agent) / len(actor_loss_per_agent),
            sum(critic_loss_per_agent) / len(critic_loss_per_agent))
def select_action(self, state, temperature=None, is_tensor=False,
                  is_target=False):
    # TODO after finished: add temperature to Gumbel sampling
    st = state
    if not is_tensor:
        st = torch.from_numpy(state).view(1, -1).float().to(device)
    if is_target:
        action = self.policy_targ(st)
    else:
        action = self.policy(st)
    action_with_noise = gumbel_softmax(action, hard=True).detach()
    return action_with_noise
def select_action(self, state, temperature=None, is_tensor=False,
                  is_target=False):
    if self.use_warmup and self.action_count < WARMUP_STEPS:
        self.action_count += 1
    # TODO after finished: add temperature to Gumbel sampling
    # TODO add warmup steps to SAC
    st = state
    if not is_tensor:
        st = torch.from_numpy(state).view(1, -1).float().to(device)
    if is_target:
        action = self.policy_targ(st)
    else:
        action = self.policy(st)
    action_with_noise = gumbel_softmax(action, hard=True).detach()
    return action_with_noise
def forward(self, inputs, rel_rec, rel_send, tau, hard, pred_steps):
    # NOTE: Assumes that we have the same graph across all samples.
    # logits = self.rel_graph  # (inputs, rel_rec, rel_send)
    edges = gumbel_softmax(self.rel_graph, tau, hard)
    inputs = inputs.transpose(1, 2).contiguous()
    # NOTE: Assumes rel_type is constant (i.e. same across all time steps).
    time_steps = inputs.size(1)
    assert pred_steps <= time_steps
    preds = []
    # Only take every n-th time step as a starting point (n = pred_steps)
    last_pred = inputs[:, 0::pred_steps, :, :]
    # curr_rel_type = rel_type[:, 0::pred_steps, :, :]
    # Run n prediction steps
    for step in range(0, pred_steps):
        last_pred = self.single_step_forward(last_pred, rel_rec, rel_send,
                                             edges)
        preds.append(last_pred)
    sizes = [
        preds[0].size(0),
        preds[0].size(1) * pred_steps,
        preds[0].size(2),
        preds[0].size(3)
    ]
    output = Variable(torch.zeros(sizes))
    if inputs.is_cuda:
        output = output.cuda()
    # Re-assemble the correct timeline: preds[i] holds the i-th prediction
    # of every pred_steps-long subsequence.
    for i in range(len(preds)):
        output[:, i::pred_steps, :, :] = preds[i]
    pred_all = output[:, :(inputs.size(1) - 1), :, :]
    return pred_all.transpose(1, 2).contiguous(), \
        self.rel_graph.squeeze(1).expand(
            [inputs.size(0), self.num_nodes * (self.num_nodes - 1),
             self.edge_types])
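# Sketch of how rel_rec / rel_send are typically constructed in NRI-style
# code (written from scratch here, an assumption rather than this repo's
# helper): one row per directed edge among n atoms, excluding self-loops,
# one-hot over the receiving / sending node respectively.
import numpy as np
import torch

def build_rel_matrices(num_atoms):
    off_diag = np.ones((num_atoms, num_atoms)) - np.eye(num_atoms)
    recv_idx, send_idx = np.where(off_diag)
    eye = torch.eye(num_atoms)
    rel_rec = eye[torch.as_tensor(recv_idx)]    # [num_edges, num_atoms]
    rel_send = eye[torch.as_tensor(send_idx)]   # [num_edges, num_atoms]
    return rel_rec, rel_send

rel_rec, rel_send = build_rel_matrices(5)       # 5 * 4 = 20 directed edges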
def forward(self, input, img_emb, lengths):
    lengths, sorted_idx = torch.sort(lengths, descending=True)
    if len(input.size()) > 2:
        input = input[sorted_idx]
        img_emb = img_emb[sorted_idx]
        if not self.use_gumbel_generator:
            gumbel = torch.zeros_like(input)
            for i in range(input.size(1)):
                gumbel[:, i] = gumbel_softmax(input[:, i].squeeze(),
                                              tau=0.5).unsqueeze(1)
            input = gumbel
        input_emb = torch.mm(input.view(-1, input.size(-1)),
                             self.embedding.weight) \
            .view(input.size(0), -1, self.embedding.embedding_dim)
    else:
        input_emb = self.embedding(input)
    # input_emb = add_gaussian(input_emb, std=0.01)
    input_emb = self.word_dropout(input_emb)
    packed_input = rnn_utils.pack_padded_sequence(input_emb,
                                                  lengths.data.tolist(),
                                                  batch_first=True)
    outputs, hidden = self.rnn(packed_input, img_emb.unsqueeze(0))
    # process outputs
    outputs = rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
    # get the last time step for each sequence
    idx = (lengths - 1).view(-1, 1).expand(
        outputs.size(0), outputs.size(2)).unsqueeze(1)
    decoded = outputs.gather(1, idx).squeeze()
    img = self.img_embedding(img_emb[:, :self.masked_size])
    # augmented = torch.cat([decoded, img], 1)
    # augmented = torch.cat([outputs[:, -1, :].squeeze(), img], 1)
    # augmented = torch.cat([hidden.squeeze(), img], 1)
    pred = self.fc(decoded)
    return torch.sigmoid(pred).mean(0), hidden
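# Sketch of the relaxed embedding lookup used above: multiplying a (near)
# one-hot distribution by the embedding matrix replaces the integer lookup,
# so gradients can flow into the generator's token distribution. Vocab and
# embedding sizes are assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

emb = nn.Embedding(100, 32)
token_logits = torch.randn(4, 100)
onehot = F.gumbel_softmax(token_logits, tau=0.5, hard=True)  # [B, vocab]
token_emb = onehot @ emb.weight                              # [B, 32], differentiable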
def forward(self, image):
    batch_size = image.shape[0]
    h_img = self.beholder(image).detach()
    start = [self.w2i["<BOS>"] for _ in range(batch_size)]
    gen_idx = []
    done = np.array([False for _ in range(batch_size)])
    h_img = h_img.unsqueeze(0).view(1, -1, self.D_hid).repeat(1, 1, 1)
    hid = h_img
    ft = torch.tensor(start, dtype=torch.long,
                      device=device).view(-1).unsqueeze(1)
    input = self.emb(ft)
    msg_lens = [self.seq_len for _ in range(batch_size)]
    for idx in range(self.seq_len):
        input = F.relu(input)
        self.rnn.flatten_parameters()
        output, hid = self.rnn(input, hid)
        output = output.view(-1, self.D_hid)
        output = self.hid_to_voc(output)
        output = output.view(-1, self.vocab_size)
        top1, topi = U.gumbel_softmax(output, self.temp, self.hard)
        gen_idx.append(top1)
        for ii in range(batch_size):
            if topi[ii] == self.w2i["<EOS>"]:
                done[ii] = True
                msg_lens[ii] = idx + 1
        if np.array_equal(done,
                          np.array([True for _ in range(batch_size)])):
            break
        input = self.emb(topi)
    gen_idx = torch.stack(gen_idx).permute(1, 0, 2)
    msg_lens = torch.tensor(msg_lens, dtype=torch.long, device=device)
    return gen_idx, msg_lens
def test_prior(self, data):
    batch_size = data.shape[0]
    input_embs = self.sender_embedding(data)
    inputs = input_embs.view(
        batch_size,
        self.config['num_digits'] * self.config['embedding_size_sender'])
    hx = torch.zeros(batch_size, self.config['num_lstm_sender'],
                     device=self.config['device'])
    cx = torch.zeros(batch_size, self.config['num_lstm_sender'],
                     device=self.config['device'])
    samples = []
    log_probs = 0
    post_probs = 0
    for num in range(self.config['num_binary_messages']):
        hx, cx = self.sender_cell(inputs, (hx, cx))
        output = self.sender_project(hx)
        pre_logits = self.sender_out(output)
        posterior_prob = torch.log_softmax(pre_logits, -1)
        sample = utils.gumbel_softmax(pre_logits, self.temperature[num],
                                      self.config['device'])
        samples.append(sample)
        maxz = torch.argmax(sample, dim=-1, keepdim=True)
        h_z = torch.zeros(
            sample.shape,
            device=self.config['device']).scatter_(-1, maxz, 1)
        prior_dst = dists.OneHotCategorical(logits=self.prior[num])
        log_prob = prior_dst.log_prob(h_z).detach().cpu().numpy()
        log_probs += log_prob
        post_probs += posterior_prob[torch.arange(batch_size),
                                     maxz.squeeze()]
    samples = torch.stack(samples).permute(1, 0, 2)
    prior_prob = log_probs / self.config['num_binary_messages']
    post_prob = (post_probs.detach().cpu().numpy()
                 / self.config['num_binary_messages'])
    return post_prob, prior_prob, samples
def reconstruct(self, gen_images, gen_captions, gen_len):
    sorted_lengths, sorted_idx = torch.sort(gen_len, descending=True)
    gen_captions = gen_captions[sorted_idx]
    gumbel = torch.zeros_like(gen_captions)
    for i in range(gen_captions.size(1)):
        gumbel[:, i] = gumbel_softmax(gen_captions[:, i].squeeze(),
                                      tau=0.5).unsqueeze(1)
    gumbel_emb = torch.mm(gumbel.view(-1, gumbel.size(-1)),
                          self.embedding.weight) \
        .view(gumbel.size(0), -1, self.embedding.embedding_dim)
    img_enc = self.img_encoder_forward(gen_images)
    txt_enc = self.txt_encoder_forward(gumbel_emb)
    _, reversed_idx = torch.sort(sorted_idx)
    txt_enc = txt_enc[reversed_idx]
    img_mu, img_logv, img_z = self.Hidden2Z_img(img_enc)
    txt_mu, txt_logv, txt_z = self.Hidden2Z_txt(txt_enc)
    return img_mu, img_logv, img_z, txt_mu, txt_logv, txt_z
def forward(self, query, key, value, mask=None, tau=1):
    dot_products = (query.unsqueeze(2) * key.unsqueeze(1)).sum(
        -1)  # batch x query_len x key_len
    if self.attend_mode == "only_attend_front":
        assert query.shape[1] == key.shape[1]
        tri = cuda(torch.ones(key.shape[1], key.shape[1]).triu(1),
                   device=query) * 1e9
        dot_products = dot_products - tri.unsqueeze(0)
    elif self.attend_mode == "only_attend_back":
        assert query.shape[1] == key.shape[1]
        # tril(-1) masks the strictly earlier positions; the original
        # tril(1) would also have masked self and the next position.
        tri = cuda(torch.ones(key.shape[1], key.shape[1]).tril(-1),
                   device=query) * 1e9
        dot_products = dot_products - tri.unsqueeze(0)
    elif self.attend_mode == "not_attend_self":
        assert query.shape[1] == key.shape[1]
        eye = cuda(torch.eye(key.shape[1]), device=query) * 1e9
        dot_products = dot_products - eye.unsqueeze(0)
    if self.window > 0:
        assert query.shape[1] == key.shape[1]
        window_mask = cuda(torch.ones(key.shape[1], key.shape[1]),
                           device=query)
        # mask every position more than `window` steps away, i.e.
        # triu(window + 1) plus tril(-(window + 1))
        window_mask = (window_mask.triu(self.window + 1) +
                       window_mask.tril(-(self.window + 1))) * 1e9
        dot_products = dot_products - window_mask.unsqueeze(0)
    if mask is not None:
        dot_products -= (1 - mask) * 1e9
    logits = dot_products / self.scale
    if self.gumbel_attend and self.training:
        probs = gumbel_softmax(logits, tau, dim=-1)
    else:
        probs = torch.softmax(logits, dim=-1)
    # zero out query rows whose key positions are all masked
    probs = probs * ((dot_products <= -5e8).sum(-1, keepdim=True)
                     < dot_products.shape[-1]).float()
    return torch.matmul(self.dropout(probs), value)
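# Sketch of the fully-masked-row guard used in the attention above: if every
# key position of a query row is masked (all scores pushed below -5e8),
# softmax would emit a uniform row, so the row is zeroed instead. The toy
# values here are assumptions.
import torch

dot_products = torch.full((1, 2, 3), -1e9)   # second query row fully masked
dot_products[0, 0] = torch.tensor([0.5, -1e9, 0.2])
probs = torch.softmax(dot_products, dim=-1)
valid = ((dot_products <= -5e8).sum(-1, keepdim=True)
         < dot_products.shape[-1]).float()
probs = probs * valid   # row 0 keeps a masked softmax; row 1 becomes zeros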
def select_action(self, state, temperature=None, is_tensor=False,
                  is_target=False):
    self.action_count += 1
    if self.action_count < self.start_steps:
        # return a random action during the warm-up phase
        return self.random_action()
    st = state
    if not is_tensor:
        st = torch.from_numpy(state).view(1, -1).float().to(device)
    if is_target:
        action = self.policy_targ(st)
    else:
        action = self.policy(st)
    noise = (self.act_noise ** 0.5) * torch.randn(action.shape)
    action += noise
    action_with_noise = gumbel_softmax(action, hard=True).detach()
    return action_with_noise
def generator(self, input, input_step, input_size, hidden_size, batch_size,
              reuse=False):
    with tf.variable_scope("generator") as scope:
        # LSTM cells, optionally wrapped with attention, then with dropout
        g_lstm_cell = tf.contrib.rnn.BasicLSTMCell(input_size,
                                                   forget_bias=0.0,
                                                   state_is_tuple=True)
        g_lstm_cell_1 = tf.contrib.rnn.BasicLSTMCell(input_size,
                                                     forget_bias=0.0,
                                                     state_is_tuple=True)
        g_lstm_cell_attention = tf.contrib.rnn.AttentionCellWrapper(
            g_lstm_cell, attn_length=10)
        g_lstm_cell_attention_1 = tf.contrib.rnn.AttentionCellWrapper(
            g_lstm_cell_1, attn_length=10)
        if self.attention == 1:
            g_lstm_cell_drop = tf.contrib.rnn.DropoutWrapper(
                g_lstm_cell_attention, output_keep_prob=0.9)
            g_lstm_cell_drop_1 = tf.contrib.rnn.DropoutWrapper(
                g_lstm_cell_attention_1, output_keep_prob=0.9)
        else:
            g_lstm_cell_drop = tf.contrib.rnn.DropoutWrapper(
                g_lstm_cell, output_keep_prob=0.9)
            g_lstm_cell_drop_1 = tf.contrib.rnn.DropoutWrapper(
                g_lstm_cell_1, output_keep_prob=0.9)
        g_cell = tf.contrib.rnn.MultiRNNCell(
            [g_lstm_cell_drop, g_lstm_cell_drop_1], state_is_tuple=True)
        g_state_ = g_cell.zero_state(batch_size, tf.float32)
        # g_W_o = utils.glorot([hidden_size, input_size])
        # g_b_o = tf.Variable(tf.random_normal([input_size]))
        # unrolled recurrent network
        g_outputs = []
        g_state = g_state_
        for i in range(input_step):
            if i > 0:
                tf.get_variable_scope().reuse_variables()
            (g_cell_output, g_state) = g_cell(
                input[:, i, :], g_state)  # cell_out: [batch_size, hidden_size]
            g_outputs.append(g_cell_output)
            # g_outputs: shape [input_step][batch_size, hidden_size]
        if self.gumbel == 0:
            g_output = tf.reshape(tf.concat(g_outputs, axis=1),
                                  [-1, input_size])
            g_y_soft = tf.nn.softmax(g_output)
            self.z_ = tf.reshape(g_y_soft,
                                 [batch_size, input_step, input_size])
        else:
            g_output = tf.reshape(tf.concat(g_outputs, axis=1),
                                  [batch_size, input_step, input_size])
            self.z_ = utils.gumbel_softmax(g_output, self.temperature,
                                           hard=True)
        # concatenate the input and the output of the RNN
        x = tf.concat([input, self.z_], axis=1)
        return x
def forward(self, query, key, value, mask=None, tau=1):
    dot_products = (query.unsqueeze(2) * key.unsqueeze(1)).sum(
        -1)  # batch x query_len x key_len
    if self.relative_clip:
        dot_relative = torch.einsum(
            "ijk,tk->ijt", query,
            self.key_relative.weight)  # batch * query_len * relative_size
        batch_size, query_len, key_len = dot_products.shape
        diag_dim = max(query_len, key_len)
        if self.diag_id.shape[0] < diag_dim:
            self.diag_id = np.zeros((diag_dim, diag_dim))
            for i in range(diag_dim):
                for j in range(diag_dim):
                    if i <= j - self.relative_clip:
                        self.diag_id[i, j] = 0
                    elif i >= j + self.relative_clip:
                        self.diag_id[i, j] = self.relative_clip * 2
                    else:
                        self.diag_id[i, j] = i - j + self.relative_clip
        diag_id = LongTensor(self.diag_id[:query_len, :key_len])
        dot_relative = reshape(
            dot_relative, "bld", "bl_d", key_len).gather(
                -1, reshape(diag_id, "lm", "_lm_", batch_size,
                            -1))[:, :, :, 0]  # batch * query_len * key_len
        dot_products = dot_products + dot_relative
    if self.attend_mode == "only_attend_front":
        assert query.shape[1] == key.shape[1]
        tri = cuda(torch.ones(key.shape[1], key.shape[1]).triu(1),
                   device=query) * 1e9
        dot_products = dot_products - tri.unsqueeze(0)
    elif self.attend_mode == "only_attend_back":
        assert query.shape[1] == key.shape[1]
        # tril(-1) masks the strictly earlier positions; the original
        # tril(1) would also have masked self and the next position.
        tri = cuda(torch.ones(key.shape[1], key.shape[1]).tril(-1),
                   device=query) * 1e9
        dot_products = dot_products - tri.unsqueeze(0)
    elif self.attend_mode == "not_attend_self":
        assert query.shape[1] == key.shape[1]
        eye = cuda(torch.eye(key.shape[1]), device=query) * 1e9
        dot_products = dot_products - eye.unsqueeze(0)
    if self.window > 0:
        assert query.shape[1] == key.shape[1]
        window_mask = cuda(torch.ones(key.shape[1], key.shape[1]),
                           device=query)
        # mask every position more than `window` steps away, i.e.
        # triu(window + 1) plus tril(-(window + 1))
        window_mask = (window_mask.triu(self.window + 1) +
                       window_mask.tril(-(self.window + 1))) * 1e9
        dot_products = dot_products - window_mask.unsqueeze(0)
    if mask is not None:
        dot_products -= (1 - mask) * 1e9
    logits = dot_products / self.scale
    if self.gumbel_attend and self.training:
        probs = gumbel_softmax(logits, tau, dim=-1)
    else:
        probs = torch.softmax(logits, dim=-1)
    # zero out query rows whose key positions are all masked
    probs = probs * ((dot_products <= -5e8).sum(-1, keepdim=True)
                     < dot_products.shape[-1]).float()
    # batch_size * query_len * key_len
    probs = self.dropout(probs)
    res = torch.matmul(probs, value)  # batch_size * query_len * d_value
    if self.relative_clip:
        if self.recover_id.shape[0] < query_len:
            self.recover_id = np.zeros((query_len, self.relative_size))
            for i in range(query_len):
                for j in range(self.relative_size):
                    self.recover_id[i, j] = i + j - self.relative_clip
        recover_id = LongTensor(self.recover_id[:key_len])
        recover_id[recover_id < 0] = key_len
        recover_id[recover_id >= key_len] = key_len
        probs = torch.cat([probs, zeros(batch_size, query_len, 1)], -1)
        relative_probs = probs.gather(
            -1, reshape(recover_id, "qr", "_qr",
                        batch_size))  # batch_size * query_len * relative_size
        res = res + torch.einsum(
            "bqr,rd->bqd", relative_probs,
            self.value_relative.weight)  # batch_size * query_len * d_value
    return res
def train_td3(self, batch_size):
    self.n_updates += 1
    batch = self.memory.sample(min(batch_size, len(self.memory)))
    states_i, actions_i, rewards_i, next_states_i, dones_i = batch
    if self.use_maddpg:
        states_all = torch.cat(states_i, 1)
        next_states_all = torch.cat(next_states_i, 1)
        actions_all = torch.cat(actions_i, 1)
    for i, agent in enumerate(self.agents):
        if not self.use_maddpg:
            states_all = states_i[i]
            next_states_all = next_states_i[i]
            actions_all = actions_i[i]
        if self.use_maddpg:
            next_actions_all = [ag.policy(next_state)
                                for ag, next_state in
                                zip(self.agents, next_states_i)]
            # target policy smoothing: perturb each action batch in place
            for k, acts in enumerate(next_actions_all):
                self.batch_add_random_acts(acts, k)
            next_actions_all = [onehot_from_logits(e)
                                for e in next_actions_all]
        else:
            next_actions_all = [
                onehot_from_logits(agent.policy(next_states_i[i]))]
        total_obs = torch.cat(
            [next_states_all, torch.cat(next_actions_all, 1)], 1)
        qnet_targs = []
        for qnet in self.agents[i].qnet_targs:
            qnet_targs.append(qnet(total_obs).detach())
        rewards = rewards_i[i].view(-1, 1)
        dones = dones_i[i].view(-1, 1)
        # clipped double-Q target
        qnet_mins = torch.min(qnet_targs[0], qnet_targs[1])
        target_q = rewards + (1 - dones) * GAMMA * qnet_mins
        losses = []
        for j, qnet in enumerate(self.agents[i].qnets):
            input_q = qnet(torch.cat([states_all, actions_all], 1))
            self.agents[i].q_optimizers[j].zero_grad()
            loss = self.criterion(input_q, target_q.detach())
            losses.append(loss.item())
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(qnet.parameters(), 0.5)
            self.agents[i].q_optimizers[j].step()
        if self.args.use_writer:
            self.writer.add_scalar(f"Agent_{i}: q_net_loss: ",
                                   np.mean(losses), self.n_updates)
    # delayed policy updates
    if self.n_updates % 2 == 0:
        for i in range(self.n_agents):
            actor_loss = 0
            # ACTOR: gradient ascent of Q(s, π(s | θ)) with respect to θ,
            # using the straight-through Gumbel-softmax trick
            policy_out = self.agents[i].policy(states_i[i])
            gumbel_sample = gumbel_softmax(policy_out, hard=True)
            if self.use_maddpg:
                actions_curr_pols = [
                    onehot_from_logits(agent_.policy(state))
                    for agent_, state in zip(self.agents, states_i)]
                for action_batch in actions_curr_pols:
                    action_batch.detach_()
                actions_curr_pols[i] = gumbel_sample
                actor_loss = -self.agents[i].qnets[0](torch.cat(
                    [states_all.detach(),
                     torch.cat(actions_curr_pols, 1)], 1)).mean()
            else:
                actor_loss = -self.agents[i].qnets[0](torch.cat(
                    [states_all.detach(), gumbel_sample], 1)).mean()
            self.agents[i].p_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.agents[i].policy.parameters(), 0.5)
            self.agents[i].p_optimizer.step()
            actions_i[i].detach_()
            if self.args.use_writer:
                self.writer.add_scalar(f"Agent_{i}: policy_objective: ",
                                       actor_loss.item(), self.n_updates)
    self.update_all_targets()
def sample_and_train_sac(self, batch_size):
    # TODO: add model saving; optimize code
    batch = self.memory.sample(min(batch_size, len(self.memory)))
    states_i, actions_i, rewards_i, next_states_i, dones_i = batch
    if self.use_maddpg:
        states_all = torch.cat(states_i, 1)
        next_states_all = torch.cat(next_states_i, 1)
        actions_all = torch.cat(actions_i, 1)
    for i, agent in enumerate(self.agents):
        if not self.use_maddpg:
            states_all = states_i[i]
            next_states_all = next_states_i[i]
            actions_all = actions_i[i]
        if self.use_maddpg:
            actions_and_logits = [
                onehot_from_logits(ag.policy(next_state), logprobs=True)
                for ag, next_state in zip(self.agents, next_states_i)]
        else:
            actions_and_logits = [
                onehot_from_logits(agent.policy(next_states_i[i]),
                                   logprobs=True)]
        next_actions_all = [e[0] for e in actions_and_logits]
        next_logits_all = [self.sac_alpha * e[1]
                           for e in actions_and_logits]
        # computing the target
        total_obs = torch.cat(
            [next_states_all, torch.cat(next_actions_all, 1)], 1)
        # target_q = self.agents[i].qnet_targ(total_obs).detach()
        qnet_targs = []
        for qnet in self.agents[i].qnet_targs:
            qnet_targs.append(qnet(total_obs).detach())
        rewards = rewards_i[i].view(-1, 1)
        dones = dones_i[i].view(-1, 1)
        qnet_mins = torch.min(qnet_targs[0], qnet_targs[1])
        logits_idx = i if self.use_maddpg else 0
        logits_agent = next_logits_all[logits_idx]
        # soft (entropy-regularized) clipped double-Q target
        target_q = rewards + (1 - dones) * GAMMA * (
            qnet_mins - logits_agent.reshape(qnet_mins.shape))
        # computing the inputs
        for j, qnet in enumerate(self.agents[i].qnets):
            input_q = qnet(torch.cat([states_all, actions_all], 1))
            self.agents[i].q_optimizers[j].zero_grad()
            loss = self.criterion(input_q, target_q.detach())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(qnet.parameters(), 0.5)
            self.agents[i].q_optimizers[j].step()
        actor_loss = 0
        # ACTOR: gradient ascent of Q(s, π(s | θ)) with respect to θ,
        # using the straight-through Gumbel-softmax trick
        policy_out = self.agents[i].policy(states_i[i])
        gumbel_sample, act_logprobs = gumbel_softmax(policy_out, hard=True,
                                                     logprobs=True)
        act_logprobs = self.sac_alpha * act_logprobs
        if self.use_maddpg:
            with torch.no_grad():
                actions_curr_pols = [
                    onehot_from_logits(agent_.policy(state))
                    for agent_, state in zip(self.agents, states_i)]
            actions_curr_pols[i] = gumbel_sample
            total_obs = torch.cat(
                [states_all, torch.cat(actions_curr_pols, 1)], 1)
        else:
            total_obs = torch.cat([states_all, gumbel_sample], 1)
        qnet_outs = []
        for qnet in self.agents[i].qnets:
            qnet_outs.append(qnet(total_obs))
        qnet_mins = torch.min(qnet_outs[0], qnet_outs[1])
        actor_loss = -qnet_mins.mean()
        # actor_loss += (policy_out**2).mean() * 1e-3
        self.agents[i].p_optimizer.zero_grad()
        actor_loss.backward()
        # nn.utils.clip_grad_norm_(self.policy.parameters(), 5)
        torch.nn.utils.clip_grad_norm_(
            self.agents[i].policy.parameters(), 0.5)
        self.agents[i].p_optimizer.step()
        # detach the forward-propagated action samples
        actions_i[i].detach_()
        if self.args.use_writer:
            self.writer.add_scalars("Agent_%i" % i, {
                "vf_loss": loss,
                "actor_loss": actor_loss
            }, self.n_updates)
    self.update_all_targets()
    self.n_updates += 1
def gumbel_decoder(self, z=None):
    batch_size = self.batch_size
    hidden = self.latent2hidden(z)
    if self.bidirectional or self.num_layers > 1:
        # unflatten hidden state
        hidden = hidden.view(self.hidden_factor, batch_size,
                             self.hidden_size)
    else:
        hidden = hidden.unsqueeze(0)
    # required for dynamic stopping of sentence generation
    sequence_idx = torch.arange(
        0, batch_size, out=self.tensor()).long()  # all idx of batch
    # all idx of batch which are still generating
    sequence_running = torch.arange(0, batch_size,
                                    out=self.tensor()).long()
    sequence_mask = torch.ones(batch_size, out=self.tensor()).byte()
    # idx of still-generating sequences with respect to current loop
    running_seqs = torch.arange(0, batch_size, out=self.tensor()).long()
    generations = self.tensor(batch_size, self.max_sequence_length).fill_(
        self.pad_idx).long()
    t = 0
    while t < self.max_sequence_length and len(running_seqs) > 0:
        if t == 0:
            input_sequence = to_var(
                torch.Tensor(batch_size).fill_(self.sos_idx).long())
        input_sequence = input_sequence.unsqueeze(1)
        input_embedding = self.embedding(input_sequence)
        output, hidden = self.decoder_rnn(input_embedding, hidden)
        logits = self.outputs2vocab(output)
        # input_sequence = self._sample(logits)
        input_sequence = gumbel_softmax(logits)
        # save next input
        generations = self._save_sample(generations, input_sequence,
                                        sequence_running, t)
        # update global running sequences
        sequence_mask[sequence_running] = (input_sequence !=
                                           self.eos_idx).data
        sequence_running = sequence_idx.masked_select(sequence_mask)
        # update local running sequences
        running_mask = (input_sequence != self.eos_idx).data
        running_seqs = running_seqs.masked_select(running_mask)
        # prune input and hidden state according to local update
        if len(running_seqs) > 0:
            input_sequence = input_sequence[running_seqs]
            hidden = hidden[:, running_seqs]
            running_seqs = torch.arange(0, len(running_seqs),
                                        out=self.tensor()).long()
        t += 1
    return generations, z
def sample_and_train(self, batch_size):
    # TODO: add model saving; optimize code
    batch = self.memory.sample(min(batch_size, len(self.memory)))
    states_i, actions_i, rewards_i, next_states_i, dones_i = batch
    states_all = torch.cat(states_i, 1)
    next_states_all = torch.cat(next_states_i, 1)
    actions_all = torch.cat(actions_i, 1)
    for i, agent in enumerate(self.agents):
        next_actions_all = [
            onehot_from_logits(ag.policy_targ(next_state))
            for ag, next_state in zip(self.agents, next_states_i)
        ]
        # computing the target
        total_obs = torch.cat(
            [next_states_all, torch.cat(next_actions_all, 1)], 1)
        target_q = self.agents[i].qnet_targ(total_obs).detach()
        rewards = rewards_i[i].view(-1, 1)
        dones = dones_i[i].view(-1, 1)
        target_q = rewards + (1 - dones) * GAMMA * target_q
        # computing the inputs
        input_q = self.agents[i].qnet(
            torch.cat([states_all, actions_all], 1))
        self.agents[i].q_optimizer.zero_grad()
        loss = self.criterion(input_q, target_q.detach())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.agents[i].qnet.parameters(),
                                       0.5)
        self.agents[i].q_optimizer.step()
        actor_loss = 0
        # ACTOR: gradient ascent of Q(s, π(s | θ)) with respect to θ,
        # using the straight-through Gumbel-softmax trick
        policy_out = self.agents[i].policy(states_i[i])
        gumbel_sample = gumbel_softmax(policy_out, hard=True)
        actions_curr_pols = [
            onehot_from_logits(agent_.policy(state))
            for agent_, state in zip(self.agents, states_i)
        ]
        for action_batch in actions_curr_pols:
            action_batch.detach_()
        actions_curr_pols[i] = gumbel_sample
        actor_loss = -self.agents[i].qnet(
            torch.cat(
                [states_all.detach(),
                 torch.cat(actions_curr_pols, 1)], 1)).mean()
        actor_loss += (policy_out ** 2).mean() * 1e-3
        self.agents[i].p_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.agents[i].policy.parameters(),
                                       0.5)
        self.agents[i].p_optimizer.step()
        # detach the forward-propagated action samples
        actions_i[i].detach_()
        if self.args.use_writer:
            self.writer.add_scalars("Agent_%i" % i, {
                "vf_loss": loss,
                "actor_loss": actor_loss
            }, self.n_updates)
    self.update_all_targets()
    self.n_updates += 1
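# Sketch of the soft target update that update_all_targets() presumably
# performs (a common MADDPG choice, assumed here rather than taken from this
# repo): polyak-average each target network toward its online counterpart.
import torch

def soft_update(target, source, tau=0.01):
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(),
                                    source.parameters()):
            t_param.data.mul_(1.0 - tau).add_(tau * s_param.data)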
def forward(self, data, rel_type, pred_steps=1, burn_in=False,
            burn_in_steps=1, dynamic_graph=False, encoder=None, temp=None):
    # inputs shape [B, T, K, num_dims]
    inputs = data.transpose(1, 2).contiguous()
    time_steps = inputs.size(1)
    # rel_type: [B, 1 or T, K*(K-1), n_edge_types]
    if rel_type.size(1) == 1:
        # static graph case: tile the single graph across all time steps
        rt_shp = rel_type.size
        sizes = [rt_shp(0), inputs.size(1), rt_shp(-2), rt_shp(-1)]
        rel_type = rel_type.expand(sizes)
    zrs = torch.zeros(
        (1, inputs.size(0), inputs.size(2), self.msg_out_shape),
        device=inputs.device)
    hidden = (zrs, zrs)
    pred_all = []
    first_step = True
    rollout_zeros = self.rollout_zeros and \
        (self.rollout_zeros_in_train or not self.training)
    for step in range(0, inputs.size(1) - 1):
        if burn_in:
            if step <= burn_in_steps or self.input_delta:
                ins = inputs[:, step, :, :]
                if self.input_delta and not first_step:
                    ins = ins - pred_all[step - 1]
            else:
                ins = pred_all[step - 1]
        else:
            assert pred_steps <= time_steps
            # Use ground-truth trajectory inputs vs. last prediction
            if not step % pred_steps or self.input_delta:
                ins = inputs[:, step, :, :]
                if self.input_delta:
                    if not first_step and not rollout_zeros:
                        ins = ins - pred_all[step - 1]
                    elif rollout_zeros:
                        ins = torch.zeros_like(ins, device=ins.device)
            else:
                ins = pred_all[step - 1]
        if dynamic_graph and step >= burn_in_steps:
            # NOTE: assumes burn_in_steps == args.timesteps
            logits = encoder(
                data[:, :, step - burn_in_steps:step, :].contiguous(),
                self.rr_full, self.rs_full)
            curr_rel_type = gumbel_softmax(logits, tau=temp, hard=True)
        else:
            curr_rel_type = rel_type[:, step, :, :]
        pred, hidden = self.single_step_forward(ins, hidden, curr_rel_type)
        pred_all.append(pred)
        first_step = False
    preds = torch.stack(pred_all, dim=1)
    return preds.transpose(1, 2).contiguous()