def forward(self, inputs):
    """Compute the action mean, the state value and (optionally) the Q-value.

    Args:
        inputs: Pair ``(x, u)``. ``x`` is passed through ``bn0``,
            ``linear1`` and ``linear2``; ``u`` is the action batch, or
            ``None`` when only ``mu`` and ``V`` are wanted.

    Returns:
        Tuple ``(mu, Q, V)``. ``Q`` is ``None`` when ``u`` is ``None``,
        otherwise ``V - 0.5 * (u - mu)^T P (u - mu)`` with ``P = L L^T``.
    """
    x, u = inputs
    x = self.bn0(x)
    # torch.tanh replaces the deprecated F.tanh (same values).
    x = torch.tanh(self.linear1(x))
    x = torch.tanh(self.linear2(x))

    V = self.V(x)
    mu = torch.tanh(self.mu(x))
    if u is None:
        return mu, None, V

    num_outputs = mu.size(1)
    L = self.L(x).view(-1, num_outputs, num_outputs)
    # Entries selected by tril_mask are kept as-is, entries selected by
    # diag_mask are exponentiated (presumably the strictly-lower triangle
    # and the diagonal, which makes P positive definite -- confirm masks).
    L = L * self.tril_mask.expand_as(L) + torch.exp(L) * self.diag_mask.expand_as(L)
    P = torch.bmm(L, L.transpose(2, 1))

    u_mu = (u - mu).unsqueeze(2)
    A = -0.5 * torch.bmm(torch.bmm(u_mu.transpose(2, 1), P), u_mu)[:, :, 0]
    return mu, A + V, V
def predict(self, x_de, x_en):
    """Teacher-forced prediction with dot-product attention over the source.

    Args:
        x_de: Source token ids, ``(bs, n_de)`` per the original notes.
        x_en: Target token ids, ``(bs, n_en)`` per the original notes.

    Returns:
        Tuple ``(tokens, attn_dist)``: the arg-max predictions prefixed with
        ``sos_token`` and the attention distribution over source positions
        (softmax taken along dim 1 of the ``(bs, n_de, n_en)`` scores).
    """
    bs = x_de.size(0)
    # Build helper tensors next to the inputs instead of hard-coding CUDA.
    device = x_de.device

    emb_de = self.embedding_de(x_de)  # bs,n_de,word_dim
    emb_en = self.embedding_en(x_en)  # bs,n_en,word_dim

    enc_shape = (self.n_layers * self.directions, bs, self.hidden_dim)
    dec_shape = (self.n_layers, bs, self.hidden_dim)
    h_enc = torch.zeros(enc_shape, device=device)
    c_enc = torch.zeros(enc_shape, device=device)
    h_dec = torch.zeros(dec_shape, device=device)
    c_dec = torch.zeros(dec_shape, device=device)

    enc_h, _ = self.encoder(emb_de, (h_enc, c_enc))
    dec_h, _ = self.decoder(emb_en, (h_dec, c_dec))

    # Scores use the reduced encoder states when the encoder is
    # bidirectional; the context below still uses the full enc_h.
    if self.directions == 2:
        scores = torch.bmm(self.dim_reduce(enc_h), dec_h.transpose(1, 2))
    else:
        scores = torch.bmm(enc_h, dec_h.transpose(1, 2))

    # Source padding must receive zero attention.
    pad_positions = (x_de == pad_token).unsqueeze(2)
    scores = scores.masked_fill(pad_positions, float('-inf'))
    attn_dist = F.softmax(scores, dim=1)  # bs,n_de,n_en

    context = torch.bmm(attn_dist.transpose(2, 1), enc_h)
    pred = self.vocab_layer(torch.cat([dec_h, context], 2))
    pred = pred[:, :-1, :]  # drop the last step so outputs align with targets
    _, tokens = pred.max(2)  # bs,n_en-1

    sauce = torch.full((bs, 1), sos_token, dtype=torch.long, device=device)
    return torch.cat([sauce, tokens], 1), attn_dist
def forward(self, feat, right, wrong, batch_wrong, fake=None, fake_diff_mask=None):
    """Ranking loss of the right answer against wrong / in-batch wrong answers.

    Args:
        feat: Features, viewed as ``(-1, ninp, 1)``.
        right: Right-answer embedding, viewed as ``(-1, 1, ninp)``.
        wrong: Wrong-answer embeddings, batch-multiplied with ``feat``.
        batch_wrong: In-batch wrong-answer embeddings, same treatment.
        fake: Optional generated-answer embedding; enables the margin term.
        fake_diff_mask: Mask passed to ``torch.masked_select`` to choose
            which fake scores contribute (required when ``fake`` is given).

    Returns:
        ``loss`` when ``fake`` is ``None``; otherwise
        ``(loss, loss_fake / batch_size)`` with the second item a float.
    """
    batch_size = feat.size(0)

    feat = feat.view(-1, self.ninp, 1)
    right_dis = torch.bmm(right.view(-1, 1, self.ninp), feat)
    wrong_dis = torch.bmm(wrong, feat)
    batch_wrong_dis = torch.bmm(batch_wrong, feat)

    wrong_score = (
        torch.sum(torch.exp(wrong_dis - right_dis.expand_as(wrong_dis)), 1)
        + torch.sum(torch.exp(batch_wrong_dis - right_dis.expand_as(batch_wrong_dis)), 1)
    )
    loss_dis = torch.sum(torch.log(wrong_score + 1))
    loss_norm = right.norm() + feat.norm() + wrong.norm() + batch_wrong.norm()

    # `if fake:` raises for tensors with more than one element, so test
    # for presence explicitly.
    has_fake = fake is not None
    loss_fake = None
    if has_fake:
        fake_dis = torch.bmm(fake.view(-1, 1, self.ninp), feat)
        fake_score = torch.masked_select(torch.exp(fake_dis - right_dis), fake_diff_mask)
        margin_score = F.relu(torch.log(fake_score + 1) - self.margin)
        loss_fake = torch.sum(margin_score)
        loss_dis = loss_dis + loss_fake
        loss_norm = loss_norm + fake.norm()

    loss = (loss_dis + 0.1 * loss_norm) / batch_size
    if has_fake:
        # .item() replaces `.data[0]`, which fails on 0-dim tensors.
        return loss, loss_fake.item() / batch_size
    return loss
def forward(self, vocab):
    """Attention-over-attention forward pass for a sentence/aspect batch.

    Args:
        vocab: Mapping with ``'sentence'`` and ``'aspect'`` index tensors
            and ``'sent_len'``, the sentence lengths handed to
            ``pack_padded_sequence`` (so sorted as that function requires).

    Returns:
        ``self.fc`` applied to the attention-weighted sentence
        representation.
    """
    # Follow the embedding table's device instead of hard-coding CUDA
    # (assumes self.embedding exposes `.weight`, as nn.Embedding does).
    device = self.embedding.weight.device

    # The embedding lookup stays outside autograd, as in the original,
    # so the embedding table is not updated.
    with torch.no_grad():
        s_embedding = self.embedding(vocab['sentence'].to(device))
        a_embedding = self.embedding(vocab['aspect'].to(device))
        packed_s = pack_padded_sequence(s_embedding, vocab['sent_len'], batch_first=True)

    out_s, _ = self.lstm_s(packed_s)  # packed output
    out_a, _ = self.lstm_a(a_embedding)

    # Unpacking must stay inside autograd: under no_grad the sentence LSTM
    # output is detached and lstm_s never receives a gradient.
    unpacked_out_s, _ = pad_packed_sequence(out_s, batch_first=True)

    # Pair-wise interaction matrix
    I_matrix = torch.bmm(unpacked_out_s, out_a.permute(0, 2, 1))

    # Column-wise softmax
    a2s_attn = F.softmax(I_matrix, dim=1)

    # Row-wise softmax => column-wise average => aspect attention
    s2a_attn = F.softmax(I_matrix, dim=2)
    a_attn = torch.mean(s2a_attn, dim=1)

    # Final sentence attention: weighted sum of the individual a2s_attn
    s_attn = torch.bmm(a2s_attn, a_attn.unsqueeze(-1))

    final_rep = torch.bmm(unpacked_out_s.permute(0, 2, 1), s_attn).squeeze(-1)
    return self.fc(final_rep)
def predict(self, x, attn_type = "hard"):
    """Greedy decoding with dot-product attention over the encoder states.

    One token is generated per input position, starting from token id 0.
    The per-step attention distributions are stored on ``self.attn`` as a
    ``(batch, steps, src_len)`` tensor.

    Args:
        x: Source token ids, ``(batch, src_len)``.
        attn_type: ``"hard"`` attends to the arg-max source position only;
            anything else uses the full (soft) distribution.

    Returns:
        ``(batch, src_len + 1)`` tensor of token ids, leading zero included.
    """
    batch = x.size(0)
    # Allocate state and start tokens on the input's device; the original
    # always built them on the CPU.
    device = x.device

    emb = self.embedding(x)
    h = torch.zeros(1, batch, self.hidden_dim, device=device)
    c = torch.zeros(1, batch, self.hidden_dim, device=device)
    enc_h, _ = self.encoder(emb, (h, c))

    y = [torch.zeros(batch, dtype=torch.long, device=device)]
    attn_steps = []
    for _ in range(x.size(1)):
        emb_t = self.embedding(y[-1])
        dec_h, (h, c) = self.decoder(emb_t.unsqueeze(1), (h, c))
        scores = torch.bmm(enc_h, dec_h.transpose(1, 2)).squeeze(2)
        attn_dist = F.softmax(scores, dim=1)
        attn_steps.append(attn_dist.detach())

        if attn_type == "hard":
            _, argmax = attn_dist.max(1)
            weights = torch.zeros_like(attn_dist).scatter_(-1, argmax.unsqueeze(1), 1)
        else:
            weights = attn_dist
        context = torch.bmm(weights.unsqueeze(1), enc_h).squeeze(1)

        pred = self.vocab_layer(torch.cat([dec_h.squeeze(1), context], 1))
        _, next_token = pred.max(1)
        y.append(next_token)

    self.attn = torch.stack(attn_steps, 0).transpose(0, 1)
    return torch.stack(y, 0).transpose(0, 1)
def forward(self, ht, hs, mask, weighted_ctx=True):
    '''
    Masked dot-product attention of one query against a sequence.

    ht: batch x ht_dim
    hs: (seq_len x batch x hs_dim, seq_len x batch x ht_dim)
    mask: seq_len x batch

    Returns (weight_hs, attn): attn is batch x 1 x seq_len; weight_hs is
    batch x hs_dim, or None when weighted_ctx is False.
    '''
    values, keys = hs
    # Switch both sequences to batch-major layout.
    values = values.transpose(0, 1)  # batch x seq_len x hs_dim
    keys = keys.transpose(0, 1)      # batch x seq_len x ht_dim

    # batch x seq_len alignment scores against the query column.
    query = ht.unsqueeze(2)
    score = torch.bmm(keys, query).squeeze(2)

    # Zero the masked positions, add EPSILON, then renormalise.
    attn = F.softmax(score, dim=-1) * mask.transpose(0, 1) + EPSILON
    attn = (attn / attn.sum(-1, keepdim=True)).unsqueeze(1)

    weight_hs = torch.bmm(attn, values).squeeze(1) if weighted_ctx else None
    return weight_hs, attn
def forward(self, agent_qs, states):
    """Forward pass for the mixer.

    Arguments:
        agent_qs: Per-agent Q-values; reshaped to ``[-1, 1, n_agents]``, so
            the trailing dimension must be ``n_agents`` (e.g.
            ``[B, T, n_agents]``).
        states: Tensor of shape ``[B, T, state_dim]``.

    Returns:
        ``q_tot`` of shape ``[B, -1, 1]``.
    """
    bs = agent_qs.size(0)
    states = states.reshape(-1, self.state_dim)
    # reshape (not view) so non-contiguous inputs work, matching `states`.
    agent_qs = agent_qs.reshape(-1, 1, self.n_agents)

    # First layer; abs() keeps the mixing weights non-negative.
    w1 = th.abs(self.hyper_w_1(states)).view(-1, self.n_agents, self.embed_dim)
    b1 = self.hyper_b_1(states).view(-1, 1, self.embed_dim)
    hidden = F.elu(th.bmm(agent_qs, w1) + b1)

    # Second layer
    w_final = th.abs(self.hyper_w_final(states)).view(-1, self.embed_dim, 1)

    # State-dependent bias
    v = self.V(states).view(-1, 1, 1)

    y = th.bmm(hidden, w_final) + v
    return y.view(bs, -1, 1)
def forward(self, output, context):
    """Attend over ``context`` and mix the result into ``output``.

    Args:
        output: Decoder states, ``(batch, out_len, dim)``.
        context: Encoder states, ``(batch, in_len, dim)``.

    Returns:
        Tuple ``(output, attn)``: the attended output
        ``(batch, out_len, dim)`` and the attention weights
        ``(batch, out_len, in_len)``.
    """
    batch_size = output.size(0)
    hidden_size = output.size(2)
    input_size = context.size(1)

    # (batch, out_len, dim) * (batch, dim, in_len) -> (batch, out_len, in_len)
    attn = torch.bmm(output, context.transpose(1, 2))
    # Scores that are exactly zero are excluded from the softmax (original
    # behaviour). A comparison mask works where byte masks are rejected,
    # and masked_fill avoids editing `.data` behind autograd's back.
    attn = attn.masked_fill(attn == 0, float('-inf'))
    attn = F.softmax(attn.view(-1, input_size), dim=1).view(batch_size, -1, input_size)

    # (batch, out_len, in_len) * (batch, in_len, dim) -> (batch, out_len, dim)
    mix = torch.bmm(attn, context)

    # concat -> (batch, out_len, 2*dim)
    combined = torch.cat((mix, output), dim=2)
    projected = self.linear_out(combined.view(-1, 2 * hidden_size))
    output = torch.tanh(projected).view(batch_size, -1, hidden_size)

    return output.contiguous(), attn
def forward(self, q, k, v):
    """Scaled dot-product attention with optional causal/key/query masks.

    Args:
        q: Queries, ``(b, t_q, dim)``.
        k: Keys, ``(b, t_k, dim)``.
        v: Values, ``(b, t_k, dim_v)``.

    Returns:
        Tuple ``(attended, weights)``: ``(b, t_q, dim_v)`` and the
        post-dropout attention weights ``(b, t_q, t_k)``.

    Non-zero entries of ``self.mask_k`` (``(b, t_k)``) and ``self.mask_q``
    (``(b, t_q)``) mark positions to suppress.
    """
    b_q, t_q, dim_q = list(q.size())
    b_k, t_k, dim_k = list(k.size())
    b_v, t_v, dim_v = list(v.size())
    assert(b_q == b_k and b_k == b_v)  # batch size should be equal
    assert(dim_q == dim_k)  # dims should be equal
    assert(t_k == t_v)  # times should be equal
    b = b_q

    qk = torch.bmm(q, k.transpose(1, 2)) / (dim_k ** 0.5)  # b x t_q x t_k

    # Masks become comparison results (`!= 0`) so the same code works with
    # byte or bool masks; masked_fill rejects uint8 masks on current PyTorch.
    mask = None
    if self.causal and t_q > 1:
        causal_mask = torch.ones(t_q, t_k, device=q.device).triu(1) != 0
        mask = causal_mask.unsqueeze(0).expand(b, t_q, t_k)
    if self.mask_k is not None:
        mask_k = (self.mask_k != 0).unsqueeze(1).expand(b, t_q, t_k)
        mask = mask_k if mask is None else mask | mask_k
    if self.mask_q is not None:
        mask_q = (self.mask_q != 0).unsqueeze(2).expand(b, t_q, t_k)
        mask = mask_q if mask is None else mask | mask_q
    if mask is not None:
        qk = qk.masked_fill(mask, -1e9)

    sm_qk = F.softmax(qk, dim=2)
    sm_qk = self.dropout(sm_qk)
    return torch.bmm(sm_qk, v), sm_qk  # b x t_q x dim_v
def forward_dot(self, hid, ctx, ctx_mask):
    r"""Computes Luong-style dot attention probabilities between
    decoder's hidden state and source annotations.

    Arguments:
        hid(Tensor): A set of decoder hidden states of shape `T*B*H`
            where `T` == 1, `B` is batch dim and `H` is hidden state dim.
        ctx(Tensor): A set of annotations of shape `S*B*C` where `S`
            is the source timestep dim, `B` is batch dim and `C`
            is annotation dim.
        ctx_mask(Tensor or None): A binary mask of shape `S*B` with
            zeroes in the padded timesteps. `None` disables masking.

    Returns:
        scores(Tensor): Normalized attention scores of shape `T*B*S`.
        z_t(Tensor): Attended context of shape `T*B*H`, i.e. the output
            of `ctx2hid`.
    """
    # Apply transformations first to make last dims both C and then
    # shuffle dims to prepare for batch mat-mult
    ctx_ = self.ctx2ctx(ctx).permute(1, 2, 0)  # S*B*C -> S*B*C -> B*C*S
    hid_ = self.hid2ctx(hid).permute(1, 0, 2)  # T*B*H -> T*B*C -> B*T*C

    # 'dot' logits of B*T*S
    logits = torch.bmm(hid_, ctx_)
    if ctx_mask is not None:
        # The mask was previously ignored, letting padded source positions
        # receive attention mass. S*B -> B*1*S, True where padded.
        padded = (ctx_mask == 0).t().unsqueeze(1)
        logits = logits.masked_fill(padded, float('-inf'))
    scores = F.softmax(logits, dim=-1)

    # Transform back to hidden_dim for further decoders
    # B*T*S x B*S*C -> B*T*C -> B*T*H
    z_t = self.ctx2hid(torch.bmm(scores, ctx.transpose(0, 1)))

    return scores.transpose(0, 1), z_t.transpose(0, 1)
def bnorm(x, U):
    """Centre and scale ``x`` using statistics produced by ``U``.

    ``U @ x`` is subtracted from ``x``; the result is divided by the
    square root of ``U @ (centred ** 2)`` (clamped at 1e-10) plus 1e-5.
    Both products are batched (``torch.bmm``).
    """
    centred = x - torch.bmm(U, x)
    spread = torch.bmm(U, centred * centred)
    scale = spread.clamp(min=1e-10).sqrt() + 1e-5
    return centred / scale
def forward(self, q, k, v, attn_mask=None):
    """Multi-head attention followed by projection, dropout and layer norm.

    Args:
        q: Queries, ``(mb_size, len_q, d_model)``.
        k: Keys, ``(mb_size, len_k, d_model)``.
        v: Values, ``(mb_size, len_v, d_model)``.
        attn_mask: Optional mask with a leading batch dimension; tiled
            once per head before being passed to ``self.attention``.

    Returns:
        Tuple ``(outputs, attns)``: the normalised residual output and
        whatever ``self.attention`` returns as attention weights.
    """
    d_k, d_v = self.d_k, self.d_v
    n_head = self.n_head

    residual = q

    mb_size, len_q, q_hidden_size = q.size()
    _, len_k, k_hidden_size = k.size()
    _, len_v, v_hidden_size = v.size()

    # treat as a (n_head) size batch: n_head x (mb_size*len) x d_model
    q_s = q.repeat(n_head, 1, 1).view(n_head, -1, q_hidden_size)
    k_s = k.repeat(n_head, 1, 1).view(n_head, -1, k_hidden_size)
    v_s = v.repeat(n_head, 1, 1).view(n_head, -1, v_hidden_size)

    # treat the result as a (n_head * mb_size) size batch
    q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)
    k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)
    v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)

    # The signature allows attn_mask=None, so only tile a real mask.
    if attn_mask is not None:
        attn_mask = attn_mask.repeat(n_head, 1, 1)

    # perform attention, result size = (n_head * mb_size) x len_q x d_v
    outputs, attns = self.attention.forward(q_s, k_s, v_s, attn_mask=attn_mask)

    # back to original mb_size batch: mb_size x len_q x (n_head*d_v)
    outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1)

    # project back to residual size
    outputs = self.proj.forward(outputs)
    outputs = self.dropout(outputs)

    return self.layer_norm(outputs + residual), attns
def predict2(self, x_de, beamsz, gen_len):
    """Beam-search decoding for a single source sentence.

    Args:
        x_de: Source token ids with batch size 1 (``(1, n_de)``).
        beamsz: Beam width handed to ``CandList``.
        gen_len: Number of decoding steps.

    Returns:
        ``(masterheap.probs, masterheap.wordlist, masterheap.attentions)``
        as maintained by ``CandList``.
    """
    # Allocate the initial state on the input's device rather than
    # assuming CUDA.
    device = x_de.device
    emb_de = self.embedding_de(x_de)  # "batch size",n_de,word_dim, with batch size 1

    state_shape = (self.n_layers * self.directions, 1, self.hidden_dim)
    h0 = torch.zeros(state_shape, device=device)
    c0 = torch.zeros(state_shape, device=device)
    enc_h, _ = self.encoder(emb_de, (h0, c0))  # 1,n_de,hiddensz*n_directions
    if self.directions == 2:
        enc_h = self.dim_reduce(enc_h)  # 1,n_de,hiddensz

    masterheap = CandList(self.n_layers, self.hidden_dim, enc_h.size(1), beamsz)
    # Per the original notes the beam holds one hypothesis on the first
    # iteration and `beamsz` hypotheses afterwards.
    for _ in range(gen_len):
        prev = masterheap.get_prev()  # beamsz
        emb_t = self.embedding_en(prev)  # beamsz,word_dim
        enc_h_expand = enc_h.expand(prev.size(0), -1, -1)  # beamsz,n_de,hiddensz

        h, c = masterheap.get_hiddens()  # (n_layers,beamsz,hiddensz) each
        dec_h, (h, c) = self.decoder(emb_t.unsqueeze(1), (h, c))  # beamsz,1,hiddensz

        # (beamsz,n_de,hiddensz) * (beamsz,hiddensz,1) -> beamsz,n_de
        scores = torch.bmm(enc_h_expand, dec_h.transpose(1, 2)).squeeze(2)
        attn_dist = F.softmax(scores, dim=1)

        # Hard attention uses a one-hot on the arg-max source position;
        # soft attention uses the whole distribution.
        if self.attn_type == "hard":
            _, argmax = attn_dist.max(1)  # beamsz
            weights = torch.zeros_like(attn_dist).scatter_(-1, argmax.unsqueeze(1), 1)
        else:
            weights = attn_dist
        context = torch.bmm(weights.unsqueeze(1), enc_h_expand).squeeze(1)

        pred = self.vocab_layer(torch.cat([dec_h.squeeze(1), context], 1))  # beamsz,len(EN.vocab)
        # TODO: set the columns corresponding to <pad>,<unk>,</s>,etc to 0
        masterheap.update_beam(pred)
        masterheap.update_hiddens(h, c)
        masterheap.update_attentions(attn_dist)
        masterheap.firstloop = False

    return masterheap.probs, masterheap.wordlist, masterheap.attentions
def predict(self, x_de, x_en):
    """Teacher-forced prediction that also records per-step attention.

    Args:
        x_de: Source token ids, ``(bs, n_de)``.
        x_en: Target token ids, ``(bs, n_en)``.

    Returns:
        Tuple ``(y, attn)``: predicted ids ``(bs, n_en)`` starting with
        ``sos_token`` and the attention weights ``(bs, n_en - 1, n_de)``,
        which are also stored on ``self.attn``.
    """
    bs = x_de.size(0)
    # Allocate state next to the inputs instead of assuming CUDA.
    device = x_de.device

    emb_de = self.embedding_de(x_de)  # bs,n_de,word_dim
    emb_en = self.embedding_en(x_en)  # bs,n_en,word_dim

    # The decoder gets its own n_layers-sized state, as in the sibling
    # predict/forward methods; reusing the encoder-sized state only worked
    # when directions == 1.
    enc_shape = (self.n_layers * self.directions, bs, self.hidden_dim)
    dec_shape = (self.n_layers, bs, self.hidden_dim)
    enc_h, _ = self.encoder(
        emb_de,
        (torch.zeros(enc_shape, device=device), torch.zeros(enc_shape, device=device)),
    )
    dec_h, _ = self.decoder(
        emb_en,
        (torch.zeros(dec_shape, device=device), torch.zeros(dec_shape, device=device)),
    )

    if self.directions == 2:
        enc_h = self.dim_reduce(enc_h)  # bs,n_de,hiddensz
    # (bs,n_de,hiddensz) * (bs,hiddensz,n_en) = (bs,n_de,n_en)
    scores = torch.bmm(enc_h, dec_h.transpose(1, 2))

    y = [torch.full((bs,), sos_token, dtype=torch.long, device=device)]
    attn_steps = []
    for t in range(x_en.size(1) - 1):  # iterate over english words, teacher forcing
        attn_dist = F.softmax(scores[:, :, t], dim=1)  # bs,n_de
        attn_steps.append(attn_dist.detach())

        # Hard attention: one-hot on the most likely source word.
        if self.attn_type == "hard":
            _, argmax = attn_dist.max(1)
            weights = torch.zeros_like(attn_dist).scatter_(-1, argmax.unsqueeze(1), 1)
        else:
            weights = attn_dist
        context = torch.bmm(weights.unsqueeze(1), enc_h).squeeze(1)  # bs,hiddensz

        pred = self.vocab_layer(torch.cat([dec_h[:, t, :], context], 1))  # bs,len(EN.vocab)
        _, next_token = pred.max(1)  # bs
        y.append(next_token)

    self.attn = torch.stack(attn_steps, 0).transpose(0, 1)  # kept for visualization
    y = torch.stack(y, 0).transpose(0, 1)  # bs,n_en
    return y, self.attn
def backward(ctx, grad_output):
    """Gradients for the additive matrix and the two batched operands.

    Returns a 6-tuple ``(grad_add_matrix, grad_batch1, grad_batch2, None,
    None, None)``; an entry stays ``None`` when ``ctx.needs_input_grad``
    does not request it.
    """
    batch1, batch2 = ctx.saved_variables
    needs = ctx.needs_input_grad

    grad_add_matrix = None
    if needs[0]:
        grad_add_matrix = maybe_unexpand(grad_output, ctx.add_matrix_size)
        if ctx.alpha != 1:
            grad_add_matrix = grad_add_matrix.mul(ctx.alpha)

    grad_batch1 = None
    grad_batch2 = None
    if any(needs[1:]):
        # The same upstream gradient applies to every batch entry.
        expanded = grad_output.unsqueeze(0).expand(
            batch1.size(0), batch1.size(1), batch2.size(2))

        if needs[1]:
            grad_batch1 = torch.bmm(expanded, batch2.transpose(1, 2))
            if ctx.beta != 1:
                grad_batch1 *= ctx.beta

        if needs[2]:
            grad_batch2 = torch.bmm(batch1.transpose(1, 2), expanded)
            if ctx.beta != 1:
                grad_batch2 *= ctx.beta

    return grad_add_matrix, grad_batch1, grad_batch2, None, None, None
def lstsq(b, y, alpha=0.01):
    """
    Batched linear least-squares for pytorch with optional ridge (L2)
    regularization.

    Solves the normal equations ``(b^T b + alpha * I) x = b^T y`` for
    every batch entry.

    Parameters
    ----------
    b : shape(L, M, N)
    y : shape(L, M)
    alpha : float
        Value added to the diagonal of ``b^T b``; ``0`` disables it.

    Returns
    -------
    tuple of (coefficients, model, residuals)
    """
    bT = b.transpose(-1, -2)
    AA = torch.bmm(bT, b)
    if alpha != 0:
        # diagonal() returns a view, so the in-place add shifts AA itself.
        torch.diagonal(AA, dim1=1, dim2=2).add_(alpha)
    RHS = torch.bmm(bT, y[:, :, None])

    # torch.gesv was removed from PyTorch; prefer torch.linalg.solve and
    # keep gesv only as a fallback for very old installations.
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
        X = torch.linalg.solve(AA, RHS)
    else:
        X, _ = torch.gesv(RHS, AA)

    fit = torch.bmm(b, X)[..., 0]
    res = y - fit
    return X[..., 0], fit, res
def forward(self, x_de, x_en, update_baseline=True):
    """Training pass with soft or hard (REINFORCE) dot-product attention.

    Args:
        x_de: Source token ids, ``(bs, n_de)``.
        x_en: Target token ids, ``(bs, n_en)``.
        update_baseline: When true, ``self.baseline`` is updated as a
            moving average (0.95 / 0.05) of the average reward.

    Returns:
        Tuple ``(loss, reinforce_loss, avg_reward)``. ``loss`` is the
        negative mean reward over non-pad target tokens; ``reinforce_loss``
        only accumulates for hard attention; ``avg_reward`` is a float.
    """
    bs = x_de.size(0)
    # Allocate state on the input's device instead of assuming CUDA.
    device = x_de.device

    emb_de = self.embedding_de(x_de)  # bs,n_de,word_dim
    emb_en = self.embedding_en(x_en)  # bs,n_en,word_dim

    # hidden vars have dimension n_layers*n_directions,bs,hiddensz
    enc_shape = (self.n_layers * self.directions, bs, self.hidden_dim)
    dec_shape = (self.n_layers, bs, self.hidden_dim)
    h0_enc = torch.zeros(enc_shape, device=device)
    c0_enc = torch.zeros(enc_shape, device=device)
    h0_dec = torch.zeros(dec_shape, device=device)
    c0_dec = torch.zeros(dec_shape, device=device)

    enc_h, _ = self.encoder(emb_de, (h0_enc, c0_enc))  # bs,n_de,hiddensz*n_directions
    dec_h, _ = self.decoder(emb_en, (h0_dec, c0_dec))  # bs,n_en,hiddensz

    # Dot-product attention: all scores at once, (bs,n_de,n_en).
    if self.directions == 2:
        scores = torch.bmm(self.dim_reduce(enc_h), dec_h.transpose(1, 2))
    else:
        scores = torch.bmm(enc_h, dec_h.transpose(1, 2))

    reinforce_loss = 0  # only used for hard attention
    loss = 0
    # Stop one step early: the last target position has no next token.
    for t in range(dec_h.size(1) - 1):
        attn_dist = F.softmax(scores[:, :, t], dim=1)  # bs,n_de
        if self.attn_type == "hard":
            cat = torch.distributions.Categorical(attn_dist)
            attn_samples = cat.sample()  # bs
            # One-hot vectors select the sampled encoder state.
            one_hot = torch.zeros_like(attn_dist).scatter_(-1, attn_samples.unsqueeze(1), 1)
            context = torch.bmm(one_hot.unsqueeze(1), enc_h).squeeze(1)
        else:
            context = torch.bmm(attn_dist.unsqueeze(1), enc_h).squeeze(1)
        # context is bs,hiddensz*n_directions

        pred = self.vocab_layer(torch.cat([dec_h[:, t, :], context], 1))  # bs,len(EN.vocab)
        y = x_en[:, t + 1]  # bs, labels
        no_pad = (y != pad_token)  # exclude english padding tokens
        # reward[i] = pred[i, y[i]]: score of the correct word.
        reward = torch.gather(pred, 1, y.unsqueeze(1)).squeeze(1)[no_pad]

        if self.attn_type == "hard":
            # Score all samples first, then drop padded rows: `cat` has
            # batch size bs, so a pre-filtered sample vector no longer
            # lines up with it.
            log_probs = cat.log_prob(attn_samples)[no_pad]
            reinforce_loss -= (log_probs * (reward - self.baseline).detach()).sum()
        loss -= reward.sum()  # minimizing loss is maximizing reward

    no_pad_total = (x_en[:, 1:] != pad_token).sum()
    loss = loss / no_pad_total
    reinforce_loss = reinforce_loss / no_pad_total
    # .item() replaces `.data[0]`, which fails on 0-dim tensors.
    avg_reward = -loss.item()
    if update_baseline:
        self.baseline = 0.95 * self.baseline.data + 0.05 * avg_reward
    return loss, reinforce_loss, avg_reward
def forward(self, match_encoders):
    '''
    Two-step pointer decoding of a start and an end position.

    match_encoders (pn_steps, batch, hidden_size*2)

    Returns (logits, prediction): logits is (2, batch, pn_steps) holding
    the start/end distributions, prediction is a (batch, 2) numpy array of
    arg-max start/end indices.
    '''
    pn_steps = match_encoders.size(0)
    batch = match_encoders.size(1)
    memory = match_encoders.transpose(0, 1)  # batch, pn_steps, hidden_size*2
    vh_matrix = self.vh_net(match_encoders)  # pn_steps, batch, hidden_size

    def pointer_distribution(hidden):
        # Distribution over pn_steps for one decoder state (batch, hidden_size).
        wha = self.wa_net(hidden)  # batch, hidden_size
        wha = wha.expand(pn_steps, wha.size(0), wha.size(1))
        f = self.tanh(vh_matrix + wha)  # pn_steps, batch, hidden_size
        vf = self.v_net(f.transpose(0, 1)).squeeze(-1)  # batch, pn_steps
        return self.softmax(vf)

    # The initial state follows the input's device/dtype instead of .cuda().
    h0 = match_encoders.new_zeros(batch, self.hidden_size)
    c0 = match_encoders.new_zeros(batch, self.hidden_size)

    # prediction start
    beta1 = pointer_distribution(h0)
    # NOTE(review): softmax is applied a second time to the already
    # normalised beta1 before pooling; kept as-is -- confirm it is intended.
    softmax_beta1 = self.softmax(beta1).view(beta1.size(0), 1, beta1.size(1))
    inp = torch.bmm(softmax_beta1, memory).squeeze(1)  # batch, hidden_size*2
    h1, _ = self.pointer_lstm(inp, (h0, c0))

    # prediction end; the original also ran a second LSTM step here whose
    # result was never used, so it is omitted.
    beta2 = pointer_distribution(h1)

    _, start = torch.max(beta1, 1)
    _, end = torch.max(beta2, 1)

    logits = torch.stack([beta1, beta2])  # 2, batch, pn_steps
    prediction = torch.stack([start, end]).transpose(0, 1).cpu().numpy()
    return logits, prediction
def forward(self, query_embeddings, in_memory_embeddings, out_memory_embeddings, attention_mask=None):
    """Read from memory with query-driven attention and add the projected query.

    Scores are batched products of ``in_memory_embeddings`` with the
    query; positions where ``attention_mask`` is 0 are pushed to -1e20
    before ``softmax``. The attended ``out_memory_embeddings`` are added
    to ``self.linear(query_embeddings)``.

    NOTE(review): ``softmax`` is resolved from module scope and called
    without an explicit axis -- confirm which dimension it normalises.
    """
    scores = torch.bmm(in_memory_embeddings, query_embeddings.unsqueeze(2)).squeeze(2)

    if attention_mask is not None:
        # exclude masked elements from the softmax
        keep = attention_mask.float()
        scores = keep * scores + (1 - keep) * -1e20

    weights = softmax(scores).unsqueeze(1)
    memory_output = torch.bmm(weights, out_memory_embeddings).squeeze(1)

    projected_query = self.linear(query_embeddings)
    return memory_output + projected_query
def lk_tensor_track_batch(feature_old, feature_new, pts_locations, patch_size, max_step, feature_template=None):
    # Iteratively refine point locations so that patches sampled from
    # `feature_new` match a template patch, using a fixed 2x2 matrix built
    # from Sobel gradients of the template.
    #
    # feature[old,new] : 4-D tensor [1, C, H, W] (3-D inputs get a leading
    #   batch dimension added)
    # pts_locations is a 2-D tensor [Num-Pts, (Y,X)]
    # patch_size : int; patches have side patch_size*2+1
    # max_step : number of update iterations
    # feature_template : optional precomputed template patches; when given,
    #   feature_old must be None
    # Returns the updated pts_locations.
    if feature_new.dim() == 3:
        feature_new = feature_new.unsqueeze(0)
    if feature_old is not None and feature_old.dim() == 3:
        feature_old = feature_old.unsqueeze(0)
    assert feature_new.dim() == 4, 'The dimension of feature-new is not right : {}.'.format(feature_new.dim())
    BB, C, H, W = list(feature_new.size())
    if feature_old is not None:
        assert 1 == feature_old.size(0) and 1 == BB, 'The first dimension of feature should be one not {}'.format(feature_old.size())
        assert C == feature_old.size(1) and H == feature_old.size(2) and W == feature_old.size(3), 'The size is not right : {}'.format(feature_old.size())
    assert isinstance(patch_size, int), 'The format of lk-parameters are not right : {}'.format(patch_size)
    num_pts = pts_locations.size(0)
    device = feature_new.device

    weight_map = Generate_Weight( [patch_size*2+1, patch_size*2+1] ) # [H, W]
    # The weight map and the Sobel filters are built outside autograd.
    with torch.no_grad():
        weight_map = torch.tensor(weight_map).view(1, 1, 1, patch_size*2+1, patch_size*2+1).to(device)

        sobelconvx = SobelConv('x', feature_new.dtype).to(device)
        sobelconvy = SobelConv('y', feature_new.dtype).to(device)

    # feature_T should be a [num_pts, C, patch, patch] tensor
    if feature_template is None:
        feature_T = warp_feature_batch(feature_old, pts_locations, patch_size)
    else:
        assert feature_old is None, 'When feature_template is not None. feature_old must be None'
        feature_T = feature_template
    assert feature_T.size(2) == patch_size * 2 + 1 and feature_T.size(3) == patch_size * 2 + 1, 'The size of feature-template is not ok : {}'.format(feature_T.size())
    # Template gradients, stacked along dim 1 as (x, y).
    gradiant_x = sobelconvx(feature_T)
    gradiant_y = sobelconvy(feature_T)
    J = torch.stack([gradiant_x, gradiant_y], dim=1)
    weightedJ = J * weight_map
    # NOTE: `H` (the feature height above) is rebound here to the per-point
    # 2x2 matrix weightedJ @ J^T; it is inverted once, before the loop.
    H = torch.bmm( weightedJ.view(num_pts,2,-1), J.view(num_pts, 2, -1).transpose(2,1) )
    inverseH = torch_inverse_batch(H)

    for step in range(max_step):
        # Step-1 Warp I with W(x,p) to compute I(W(x;p))
        feature_I = warp_feature_batch(feature_new, pts_locations, patch_size)
        # Step-2 Compute the error feature
        r = feature_I - feature_T
        # Step-7 Compute sigma
        sigma = torch.bmm(weightedJ.view(num_pts,2,-1), r.view(num_pts,-1, 1))
        # Step-8 Compute delta-p
        deltap = torch.bmm(inverseH, sigma).squeeze(-1)
        pts_locations = pts_locations - deltap

    return pts_locations
def updateGradInput(self, input, gradOutput):
    """Compute gradients of a (batched) matrix-vector product.

    Args:
        input: Pair ``(M, v)``: ``M`` is 2-D with 1-D ``v``, or 3-D
            (batched) with 2-D ``v``.
        gradOutput: Gradient w.r.t. the product; 1-D or 2-D (batched).

    Returns:
        ``self.gradInput``, holding ``[grad_M, grad_v]``.
    """
    M, v = input
    self.gradInput[0].resize_as_(M)
    self.gradInput[1].resize_as_(v)

    gradOutput = gradOutput.contiguous()

    assert gradOutput.ndimension() == 1 or gradOutput.ndimension() == 2

    if gradOutput.ndimension() == 2:
        assert M.ndimension() == 3
        assert v.ndimension() == 2
        bdim = M.size(0)
        odim = M.size(1)
        idim = M.size(2)

        # The `out=` views share storage with gradInput, so results land
        # in the buffers resized above.
        if self.trans:
            torch.bmm(v.view(bdim, odim, 1), gradOutput.view(bdim, 1, idim), out=self.gradInput[0])
            torch.bmm(M, gradOutput.view(bdim, idim, 1), out=self.gradInput[1].view(bdim, odim, 1))
        else:
            torch.bmm(gradOutput.view(bdim, odim, 1), v.view(bdim, 1, idim), out=self.gradInput[0])
            torch.bmm(M.transpose(1, 2), gradOutput.view(bdim, odim, 1), out=self.gradInput[1].view(bdim, idim, 1))
    else:
        assert M.ndimension() == 2
        assert v.ndimension() == 1

        # grad_v is a matrix-vector product. The previous `M * gradOutput`
        # multiplied elementwise and did not yield a vector shaped like v.
        if self.trans:
            torch.ger(v, gradOutput, out=self.gradInput[0])
            torch.mv(M, gradOutput, out=self.gradInput[1])
        else:
            torch.ger(gradOutput, v, out=self.gradInput[0])
            torch.mv(M.t(), gradOutput, out=self.gradInput[1])

    return self.gradInput
def dot_prod_attention(self, h_t: torch.Tensor, src_encoding: torch.Tensor, src_encoding_att_linear: torch.Tensor, mask: torch.Tensor=None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dot-product attention of one decoder state over the source encodings.

    Args:
        h_t: Decoder state, ``(batch_size, hidden_size)``.
        src_encoding: Source encodings that are averaged,
            ``(batch_size, src_sent_len, *)``.
        src_encoding_att_linear: Projected source encodings scored against
            ``h_t``, ``(batch_size, src_sent_len, hidden_size)``.
        mask: Optional ``(batch_size, src_sent_len)`` mask; non-zero
            entries are excluded from the attention.

    Returns:
        Tuple ``(ctx_vec, softmaxed_att_weight)``.
    """
    # (batch_size, src_sent_len)
    att_weight = torch.bmm(src_encoding_att_linear, h_t.unsqueeze(2)).squeeze(2)

    if mask is not None:
        # `mask != 0` works for byte and bool masks alike (uint8 masks are
        # rejected by masked_fill on current PyTorch) and the fill no
        # longer edits `.data` behind autograd's back.
        att_weight = att_weight.masked_fill(mask != 0, -float('inf'))

    softmaxed_att_weight = F.softmax(att_weight, dim=-1)

    # (batch_size, 1, src_sent_len) x (batch_size, src_sent_len, *)
    ctx_vec = torch.bmm(softmaxed_att_weight.unsqueeze(1), src_encoding).squeeze(1)

    return ctx_vec, softmaxed_att_weight
def forward(self, input, context):
    """Attend over ``context`` with ``input`` as the query.

    When ``self.dot`` is false the query is first projected by
    ``self.linear_in``. Positions flagged in ``self.context_mask`` are
    filled with -inf before the softmax, and ``EPSILON`` is added to the
    resulting weights.

    Returns:
        Tuple ``(output, context_attention, context_alignment)``.
    """
    query = input if self.dot else self.linear_in(input)
    targetT = query.unsqueeze(2)  # batch x dim x 1

    context_scores = torch.bmm(context, targetT).squeeze(2)
    context_scores.data.masked_fill_(self.context_mask, -float('inf'))
    context_attention = F.softmax(context_scores, dim=-1) + EPSILON

    context_alignment = torch.bmm(context_attention.unsqueeze(1), context).squeeze(1)

    combined_representation = torch.cat([input, context_alignment], 1)
    output = self.tanh(self.linear_out(combined_representation))
    return output, context_attention, context_alignment
def forward_predict(self, x, graphs, data_agent, constrain_=True):
    """Greedily decode one action sequence per batch element.

    Args:
        x: Input token ids, ``(batch, seq_in)``.
        graphs: One graph per batch element; deep-copied before use and
            advanced through ``parse_exec`` for every non-STOP action.
        data_agent: Supplies ``get_mask``, ``get_action_tuple`` and
            ``reverse_parse_action``.
        constrain_: When true, actions outside the mask are suppressed and
            a failed ``parse_exec`` raises ``AssertionError``.

    Returns:
        List (one entry per batch element) of decoded action strings; an
        element's list stops growing once it ends with ``'STOP'``.
    """
    graphs = deepcopy(graphs)
    opt = self.opt
    batch_size = x.size(0)
    use_cuda = opt['cuda']
    text_out = [[] for _ in range(batch_size)]

    # Inference only; replaces the removed `volatile=True` flag.
    with torch.no_grad():
        h_0 = torch.zeros(opt['rnn_layers'], batch_size, opt['rnn_h'])
        if use_cuda:
            h_0 = h_0.cuda()

        enc_out, hidden = self.encoder(self.input_emb(x), h_0)  # [batch, seq_in, h], [layer, batch, h]

        # The first decoder input is all zeros; later ones are the one-hot
        # of the previously chosen action.
        y_in = torch.zeros(batch_size, 1, self.y_dim)
        if use_cuda:
            y_in = y_in.cuda()

        for _ in range(opt['max_seq_out']):
            dec_out, hidden = self.decoder(y_in, hidden)
            # Normalise over source positions (dim 1). Without an explicit
            # dim, softmax on this 3-D tensor ran over the batch dimension.
            # NOTE(review): keep the training-time forward on the same axis.
            alpha = F.softmax(torch.bmm(enc_out, hidden[-1].unsqueeze(2)), dim=1)
            attention = torch.bmm(enc_out.transpose(1, 2), alpha).squeeze(2)
            dec_out = self.mapping(torch.cat([attention, dec_out.squeeze(1)], dim=1))  # [batch, y_dim]

            y_mask = torch.zeros(batch_size, self.y_dim)
            for j in range(batch_size):
                data_agent.get_mask(graphs[j], y_mask[j])
            if use_cuda:
                y_mask = y_mask.cuda()
            if constrain_:
                dec_out = dec_out * y_mask + -1e7 * (1 - y_mask)

            y_out = torch.max(dec_out, 1, keepdim=True)[1]  # [batch, 1]
            y_onehot = torch.zeros(batch_size, self.y_dim)  # [batch, y_dim]
            y_onehot.scatter_(1, y_out.cpu(), 1)
            if use_cuda:
                y_onehot = y_onehot.cuda()
            y_in = y_onehot.unsqueeze(1)

            # squeeze(1) keeps the batch axis when batch_size == 1, and
            # tolist() hands plain ints to the agent instead of tensors.
            actions = y_out.squeeze(1).tolist()
            for j in range(batch_size):
                if len(text_out[j]) > 0 and text_out[j][-1] == 'STOP':
                    continue
                text_out[j].append(
                    data_agent.reverse_parse_action(data_agent.get_action_tuple(actions[j])))
                if text_out[j][-1] != 'STOP':
                    exec_result = graphs[j].parse_exec(text_out[j][-1])
                    if constrain_:
                        assert exec_result, text_out[j][-1]
    return text_out
def forward(self, video_instances, resnet_ftrs, optical_ftrs, object_ftrs):
    """Three-level attention over objects, feature streams and frames.

    Args:
        video_instances: Iterable of instance counts; their sum is the
            number of instances in the batch.
        resnet_ftrs: Per-frame features, concatenated on dim 2 with the
            attended object features.
        optical_ftrs: Per-frame features, concatenated likewise.
        object_ftrs: Per-object features (the original notes give
            ``[instances, frames, objects, resnet_dim]``).

    Returns:
        Tuple ``(video_attended, self.attn1, self.attn2, self.attn3)``.
        NOTE(review): the last three items are the attention *modules*, not
        the weights computed here -- confirm callers expect that.
    """
    total_instances = sum(video_instances)
    n_frames = total_instances * self.num_frames

    # ------------------------- Attention level 1 -------------------------
    attn1 = self.attn1(object_ftrs)
    attn1 = attn1.view(n_frames, self.obj_per_frame, 1)
    # NOTE(review): if object_ftrs really is [.., objects, resnet_dim],
    # this view reinterprets memory instead of transposing; left unchanged
    # because the input layout is not visible here -- confirm.
    object_ftrs = object_ftrs.view(n_frames, resnet_dim, self.obj_per_frame)
    object_attended = torch.bmm(object_ftrs, attn1)  # [n_frames, resnet_dim, 1]
    object_attended = object_attended.view(total_instances, self.num_frames, resnet_dim)

    # ------------------------- Attention level 2 -------------------------
    all_features = torch.cat((object_attended, resnet_ftrs, optical_ftrs), 2)
    attn2 = self.attn2(all_features)
    attn2 = attn2.view(n_frames, 3, 1)
    # The concatenation lays the three streams out back to back, so the
    # stream axis comes first and must be transposed into place; a direct
    # view to (.., resnet_dim, 3) would interleave unrelated features.
    all_features = all_features.view(n_frames, 3, resnet_dim).transpose(1, 2)
    features_attended = torch.bmm(all_features, attn2)  # [n_frames, resnet_dim, 1]
    features_attended = features_attended.view(total_instances, self.num_frames, resnet_dim)

    # ------------------------- Attention level 3 -------------------------
    video_feature = features_attended.view(total_instances, self.num_frames * resnet_dim)
    attn3 = self.attn3(video_feature)  # [instances, frames]
    # Broadcasting over the feature axis replaces the explicit repeat().
    video_attended = features_attended * attn3.unsqueeze(2)

    return video_attended, self.attn1, self.attn2, self.attn3
def forward(self, context_ids, doc_ids, target_noise_ids):
    """Sparse computation of scores (unnormalized log probabilities)
    that should be passed to the negative sampling loss.

    Parameters
    ----------
    context_ids: torch.Tensor of size (batch_size, num_context_words)
        Vocabulary indices of context words.

    doc_ids: torch.Tensor of size (batch_size,)
        Document indices of paragraphs.

    target_noise_ids: torch.Tensor of size (batch_size, num_noise_words + 1)
        Vocabulary indices of target and noise words. The first element in
        each row is the ground truth index (i.e. the target), other
        elements are indices of samples from the noise distribution.

    Returns
    -------
    torch.Tensor of size (batch_size, num_noise_words + 1)
    """
    # combine a paragraph vector with word vectors of
    # input (context) words
    x = torch.add(
        self._D[doc_ids, :], torch.sum(self._W[context_ids, :], dim=1))

    # sparse computation of scores (unnormalized log probabilities)
    # for negative sampling; only the singleton row axis is squeezed so a
    # batch of one still yields a 2-D result.
    return torch.bmm(
        x.unsqueeze(1),
        self._O[:, target_noise_ids].permute(1, 0, 2)).squeeze(1)
def forward(self, s_t_hat, h, enc_padding_mask, coverage):
    """Additive attention over encoder states with optional coverage.

    Args:
        s_t_hat: Decoder state fed to ``decode_proj``.
        h: Encoder states, ``B x t_k x n``.
        enc_padding_mask: ``B x t_k`` multiplier applied to the softmax
            output before renormalising.
        coverage: Coverage values; only used when ``config.is_coverage``.

    Returns:
        Tuple ``(c_t, attn_dist, coverage)``: the context vector
        ``B x 2*hidden_dim``, the attention ``B x t_k`` and the (possibly
        updated) coverage.
    """
    b, t_k, n = list(h.size())
    flat_h = h.view(-1, n)  # B * t_k x n

    encoder_feature = self.W_h(flat_h)
    dec_fea = self.decode_proj(s_t_hat)  # B x n
    dec_fea_expanded = dec_fea.unsqueeze(1).expand(b, t_k, n).contiguous().view(-1, n)

    att_features = encoder_feature + dec_fea_expanded  # B * t_k x n
    if config.is_coverage:
        coverage_feature = self.W_c(coverage.view(-1, 1))  # B * t_k x n
        att_features = att_features + coverage_feature

    scores = self.v(F.tanh(att_features)).view(-1, t_k)  # B x t_k

    # Mask out padding, then renormalise so each row sums to one again.
    masked = F.softmax(scores, dim=1) * enc_padding_mask
    attn_dist = (masked / masked.sum(1, keepdim=True)).unsqueeze(1)  # B x 1 x t_k

    c_t = torch.bmm(attn_dist, flat_h.view(-1, t_k, n))  # B x 1 x n
    c_t = c_t.view(-1, config.hidden_dim * 2)  # B x 2*hidden_dim

    attn_dist = attn_dist.view(-1, t_k)  # B x t_k
    if config.is_coverage:
        coverage = coverage.view(-1, t_k) + attn_dist

    return c_t, attn_dist, coverage
def outer(vec1, vec2=None):
    '''Outer product of two vectors, with optional batching.

    Either operand may carry a leading batch axis; an un-batched operand is
    broadcast against a batched one.

    Args:
        vec1: A vector of size (Size1) or (Batch, Size1).
        vec2: A vector of size (Size2) or (Batch, Size2). Defaults to vec1.

    Returns:
        (Size1, Size2) when both inputs are 1-D, otherwise
        (Batch, Size1, Size2).
    '''
    if vec2 is None:
        vec2 = vec1

    # Plain vectors: no batching involved.
    if vec1.dim() == 1 and vec2.dim() == 1:
        return torch.ger(vec1, vec2)

    # Promote any un-batched operand to a batch of one.
    left = vec1.unsqueeze(0) if vec1.dim() == 1 else vec1
    right = vec2.unsqueeze(0) if vec2.dim() == 1 else vec2

    # Column vector times row vector gives the outer product.
    left = left.unsqueeze(-1)
    right = right.unsqueeze(-2)

    if left.size(0) != right.size(0):
        # Batch sizes differ: rely on matmul broadcasting.
        return left.matmul(right)
    return torch.bmm(left, right)
def forward(self, xt, fc_feats, att_feats, p_att_feats, state):
    """Run one step of the attention LSTM core.

    Args:
        xt: input for this step, fed to ``self.i2h``.
        fc_feats: not read by this method.
        att_feats: attention features; viewed as
            (batch, att_size, att_feat_size).
        p_att_feats: already-projected attention features; viewed as
            (batch, att_size, att_hid_size).
        state: tuple ``(h, c)``; only the last entry along the first axis of
            each is used.

    Returns:
        ``(output, state)`` where ``output`` is the dropped-out new hidden
        state and ``state`` is ``(next_h.unsqueeze(0), next_c.unsqueeze(0))``.
    """
    prev_h = state[0][-1]
    prev_c = state[1][-1]

    # The p_att_feats here is already projected
    att_size = att_feats.numel() // att_feats.size(0) // self.att_feat_size
    att = p_att_feats.view(-1, att_size, self.att_hid_size)

    att_h = self.h2att(prev_h)                      # batch * att_hid_size
    att_h = att_h.unsqueeze(1).expand_as(att)       # batch * att_size * att_hid_size
    dot = torch.tanh(att + att_h)                   # batch * att_size * att_hid_size
    dot = dot.view(-1, self.att_hid_size)           # (batch * att_size) * att_hid_size
    dot = self.alpha_net(dot)                       # (batch * att_size) * 1
    dot = dot.view(-1, att_size)                    # batch * att_size

    # Normalise over the attention locations. The dim is explicit because
    # the implicit-dim form of softmax is deprecated; for this 2-D input
    # the implicit rule also selected dim 1, so results are unchanged.
    weight = F.softmax(dot, dim=1)                  # batch * att_size
    att_feats_ = att_feats.view(-1, att_size, self.att_feat_size)
    att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1)  # batch * att_feat_size

    # The first 3 * rnn_size columns are the sigmoid gates, the remaining
    # 2 * rnn_size columns are the two candidates combined by max below.
    all_input_sums = self.i2h(xt) + self.h2h(prev_h)
    sigmoid_chunk = torch.sigmoid(all_input_sums.narrow(1, 0, 3 * self.rnn_size))
    in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
    forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
    out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)

    in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \
        self.a2c(att_res)
    in_transform = torch.max(
        in_transform.narrow(1, 0, self.rnn_size),
        in_transform.narrow(1, self.rnn_size, self.rnn_size))

    next_c = forget_gate * prev_c + in_gate * in_transform
    next_h = out_gate * torch.tanh(next_c)

    output = self.dropout(next_h)
    state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
    return output, state
def score(self, hidden, encoder_outputs):
    """Return one attention energy per encoder position.

    `hidden` and `encoder_outputs` are concatenated on the last axis,
    projected by ``self.attn`` and reduced against the vector ``self.v``.
    Shapes follow the original annotations: the result is [B*T].
    """
    batch_size = encoder_outputs.size(0)

    joint = torch.cat([hidden, encoder_outputs], 2)        # [B*T*2H]
    projected = self.attn(joint).transpose(1, 2)           # [B*H*T]

    # One copy of v per batch element, shaped as a row vector.
    v_rows = self.v.repeat(batch_size, 1).unsqueeze(1)     # [B*1*H]

    return torch.bmm(v_rows, projected).squeeze(1)         # [B*T]
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num, gt_where, gt_cond, reinforce):
    """Score the four parts of the conditions: how many there are, which
    columns they use, which operator each uses, and each one's string.

    Args:
        x_emb_var: embedded input sequence, passed to ``run_lstm``.
        x_len: per-example lengths; ``len(x_len)`` is the batch size and
            positions at or beyond each length are masked with -100.
        col_inp_var, col_name_len, col_len: forwarded to
            ``col_name_encode`` unchanged.
        col_num: overwritten by the value returned from the first
            ``col_name_encode`` call before it is read.
        gt_where: when not None, the string decoder is driven by
            ``self.gen_gt_batch(gt_where)``; otherwise it decodes greedily
            for 50 steps.
        gt_cond: when None, condition columns are chosen from the predicted
            scores; otherwise the first element of every ground-truth
            condition is used.
        reinforce: must be falsy.

    Returns:
        Tuple ``(cond_num_score, cond_col_score, cond_op_score,
        cond_str_score)``.

    Raises:
        NotImplementedError: if ``reinforce`` is truthy.
    """
    max_x_len = max(x_len)
    B = len(x_len)
    if reinforce:
        raise NotImplementedError('Our model doesn\'t have RL')

    # Predict the number of conditions
    # First use column embeddings to calculate the initial hidden unit
    # Then run the LSTM and predict condition number.
    e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len,
                                         self.cond_num_name_enc)
    # NOTE(review): squeeze() without a dim also drops the batch axis when
    # B == 1 -- confirm callers never use a batch of one.
    num_col_att_val = self.cond_num_col_att(e_num_col).squeeze()
    for idx, num in enumerate(col_num):
        if num < max(col_num):
            # -100 before the softmax pushes padded columns towards zero weight
            num_col_att_val[idx, num:] = -100
    num_col_att = self.softmax(num_col_att_val)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
    # The attended column summary is turned into the two initial LSTM states.
    cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
        B, 4, self.N_h // 2).transpose(0, 1).contiguous()
    cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
        B, 4, self.N_h // 2).transpose(0, 1).contiguous()

    h_num_enc, _ = run_lstm(self.cond_num_lstm, x_emb_var, x_len,
                            hidden=(cond_num_h1, cond_num_h2))

    num_att_val = self.cond_num_att(h_num_enc).squeeze()

    for idx, num in enumerate(x_len):
        if num < max_x_len:
            num_att_val[idx, num:] = -100
    num_att = self.softmax(num_att_val)

    K_cond_num = (h_num_enc *
                  num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
    cond_num_score = self.cond_num_out(K_cond_num)

    # Predict the columns of conditions
    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_col_name_enc)

    h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)
    # One attention row over the input positions for every column.
    col_att_val = torch.bmm(e_cond_col,
                            self.cond_col_att(h_col_enc).transpose(1, 2))
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            col_att_val[idx, :, num:] = -100
    col_att = self.softmax(col_att_val.view(
        (-1, max_x_len))).view(B, -1, max_x_len)
    K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)

    cond_col_score = self.cond_col_out(
        self.cond_col_out_K(K_cond_col) +
        self.cond_col_out_col(e_cond_col)).squeeze()
    max_col_num = max(col_num)
    for b, num in enumerate(col_num):
        if num < max_col_num:
            cond_col_score[b, num:] = -100

    # Predict the operator of conditions
    chosen_col_gt = []
    if gt_cond is None:
        # No ground truth: take the top-k scored columns, with k the
        # predicted number of conditions for that example.
        cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
        col_scores = cond_col_score.data.cpu().numpy()
        chosen_col_gt = [
            list(np.argsort(-col_scores[b])[:cond_nums[b]])
            for b in range(len(cond_nums))
        ]
    else:
        # Ground truth given: use the first element of each condition.
        chosen_col_gt = [[x[0] for x in one_gt_cond]
                         for one_gt_cond in gt_cond]

    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_op_name_enc)
    h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
    col_emb = []
    for b in range(B):
        cur_col_emb = torch.stack(
            [e_cond_col[b, x] for x in chosen_col_gt[b]] +
            [e_cond_col[b, 0]] *
            (4 - len(chosen_col_gt[b])))  # Pad the columns to maximum (4)
        col_emb.append(cur_col_emb)
    col_emb = torch.stack(col_emb)

    op_att_val = torch.matmul(
        self.cond_op_att(h_op_enc).unsqueeze(1),
        col_emb.unsqueeze(3)).squeeze()
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            op_att_val[idx, :, num:] = -100
    op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
    K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)

    cond_op_score = self.cond_op_out(
        self.cond_op_out_K(K_cond_op) +
        self.cond_op_out_col(col_emb)).squeeze()

    # Predict the string of conditions
    h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_str_name_enc)
    # Same gather-and-pad as for the operators, with the string encoder's
    # column embeddings.
    col_emb = []
    for b in range(B):
        cur_col_emb = torch.stack(
            [e_cond_col[b, x] for x in chosen_col_gt[b]] +
            [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
        col_emb.append(cur_col_emb)
    col_emb = torch.stack(col_emb)

    if gt_where is not None:
        # Decoder inputs come from the ground truth; all steps are scored
        # in a single decoder call.
        gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where)
        g_str_s_flat, _ = self.cond_str_decoder(
            gt_tok_seq.view(B * 4, -1, self.max_tok_num))
        g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)

        # Extra singleton axes let the three terms broadcast when summed.
        h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
        g_ext = g_str_s.unsqueeze(3)
        col_ext = col_emb.unsqueeze(2).unsqueeze(2)

        cond_str_score = self.cond_str_out(
            self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
            self.cond_str_out_col(col_ext)).squeeze(4)
        for b, num in enumerate(x_len):
            if num < max_x_len:
                cond_str_score[b, :, :, num:] = -100
    else:
        # Greedy decoding: feed the arg-max position back in as a one-hot.
        h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
        col_ext = col_emb.unsqueeze(2).unsqueeze(2)
        scores = []

        t = 0
        init_inp = np.zeros((B * 4, 1, self.max_tok_num), dtype=np.float32)
        init_inp[:, 0, 0] = 1  # Set the <BEG> token
        if self.gpu:
            cur_inp = Variable(torch.from_numpy(init_inp).cuda())
        else:
            cur_inp = Variable(torch.from_numpy(init_inp))
        cur_h = None
        # Fixed budget of 50 steps; there is no early stop in this loop.
        while t < 50:
            # cur_h is None only on the first step; afterwards it holds
            # whatever state the decoder returned.
            if cur_h:
                g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
            else:
                g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
            g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
            g_ext = g_str_s.unsqueeze(3)

            cur_cond_str_score = self.cond_str_out(
                self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                self.cond_str_out_col(col_ext)).squeeze()
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cur_cond_str_score[b, :, num:] = -100
            scores.append(cur_cond_str_score)

            _, ans_tok_var = cur_cond_str_score.view(B * 4,
                                                     max_x_len).max(1)
            ans_tok = ans_tok_var.data.cpu()
            data = torch.zeros(B * 4, self.max_tok_num).scatter_(
                1, ans_tok.unsqueeze(1), 1)
            if self.gpu:  # To one-hot
                cur_inp = Variable(data.cuda())
            else:
                cur_inp = Variable(data)
            cur_inp = cur_inp.unsqueeze(1)

            t += 1

        cond_str_score = torch.stack(scores, 2)
        for b, num in enumerate(x_len):
            if num < max_x_len:
                cond_str_score[b, :, :, num:] = -100  # [B, IDX, T, TOK_NUM]

    cond_score = (cond_num_score, cond_col_score, cond_op_score,
                  cond_str_score)

    return cond_score
def forward(self, x, A1, diag):
    """Multiply `x` by the batched matrix ``A1 + A1^T + diag_embed(diag)``.

    The operator is symmetric by construction: `A1` is added to its own
    transpose and `diag` is placed on the main diagonal.
    """
    symmetric_part = A1 + A1.transpose(-2, -1)
    diagonal_part = torch.diag_embed(diag, dim1=-2, dim2=-1)
    return torch.bmm(symmetric_part + diagonal_part, x)
def step(self, Ybar_t: torch.tensor,
         dec_state: Tuple[torch.tensor, torch.tensor],
         enc_hiddens: torch.tensor,
         enc_hiddens_proj: torch.tensor,
         enc_masks: torch.tensor) -> Tuple[Tuple, torch.tensor, torch.tensor]:
    """Advance the LSTM decoder by one step and apply attention.

    @param Ybar_t (Tensor): Decoder input [Y_t o_prev], shape (b, e + h_e).
    @param dec_state (tuple(Tensor, Tensor)): Previous decoder hidden state
            and cell, both of shape (b, h_d).
    @param enc_hiddens (Tensor): Encoder hidden states, shape
            (b, src_len, h_e * 2).
    @param enc_hiddens_proj (Tensor): Encoder hidden states projected from
            (h_e * 2) to h, shape (b, src_len, h).
    @param enc_masks (Tensor): Sentence masks of shape (b, src_len), or
            None to skip masking. Positions where the mask is non-zero get
            a score of -inf.

    @returns dec_state (tuple (Tensor, Tensor)): New decoder hidden state
            and cell, both of shape (b, h).
    @returns combined_output (Tensor): Combined output at this timestep,
            shape (b, h).
    @returns e_t (Tensor): Attention scores (after masking), shape
            (b, src_len). Returned for sanity checking only.
    """
    # 1. One step of the decoder cell: ([b x h], [b x h]).
    dec_state = self.decoder(Ybar_t, dec_state)
    hidden_now, _cell_now = dec_state

    # 2. Raw attention scores of the new hidden state against every
    #    projected encoder state.
    e_t = self.attention_function(hidden_now, enc_hiddens_proj)

    # 3. Masked source positions must receive zero attention weight.
    if enc_masks is not None:
        e_t.data.masked_fill_(enc_masks.bool(), -float('inf'))

    # 4. Attention distribution and the resulting context vector.
    attn_weights = torch.nn.functional.softmax(e_t, 1)         # [b x src_len]
    context = torch.bmm(attn_weights.unsqueeze(1), enc_hiddens)  # [b x 1 x 2h]
    context = context.squeeze(1)                               # [b x 2h]

    # 5. Combine hidden state and context, project, squash, drop out.
    joined = torch.cat((hidden_now, context), 1)               # [b x 3h]
    projected = self.combined_output_projection(joined)        # [b x h]
    combined_output = self.dropout(torch.tanh(projected))

    return dec_state, combined_output, e_t
def test(self, partition='test'):
    """Evaluate the encoder and GNN modules on sampled few-shot tasks.

    Runs ``150 // num_batch`` batches of tasks, prints the mean query-node
    accuracy and returns it.

    Args:
        partition: kept for interface compatibility.
            NOTE(review): the data loader below is hard-coded to "test",
            so this argument is currently not used -- confirm intent.

    Returns:
        Mean query-node accuracy over all evaluated batches (NaN when no
        batch was evaluated, i.e. when ``num_batch`` exceeds 150).
    """
    num_ways_test = self.config['num_ways']
    num_shots_test = self.config['num_shots']
    test_batch_size = self.config['num_batch']
    feature_dim = self.config['feature_dim']
    test_iteration = 150
    support_shot_nums = num_shots_test
    query_shot_nums = 1

    val_data_loader = VRDDataLoader("test")

    num_supports = num_ways_test * num_shots_test
    num_queries = num_ways_test * 1
    num_samples = num_supports + num_queries

    # Edges between two support nodes are "support" edges; every edge that
    # touches a query node is a "query" edge.
    support_edge_mask = torch.zeros(test_batch_size, num_samples, num_samples).cuda()
    support_edge_mask[:, :num_supports, :num_supports] = 1
    query_edge_mask = 1 - support_edge_mask
    evaluation_mask = torch.ones(test_batch_size, num_samples, num_samples).cuda()

    query_edge_losses = []
    query_edge_accrs = []
    query_node_accrs = []

    # set as eval mode (no dropout / batch-norm updates while testing)
    self.enc_module.eval()
    self.gnn_module.eval()

    for task_iter in range(test_iteration // test_batch_size):
        # The iteration index doubles as the sampling seed, so the same
        # tasks are drawn on every call.
        support_all_input, query_all_input, os_label, [support_label, query_label], [idx_for_class, idx_for_data] = \
            val_data_loader.get_task_batch(
                num_tasks=test_batch_size,
                num_ways=num_ways_test,
                num_shots=num_shots_test,
                seed=task_iter)

        os_support_full_label = torch.cat(
            [os_label[0].view(test_batch_size, -1).cuda(),
             os_label[1].view(test_batch_size, -1).cuda()], 1)
        os_query_full_label = torch.cat(
            [os_label[2].view(test_batch_size, -1).cuda(),
             os_label[3].view(test_batch_size, -1).cuda()], 1)
        os_full_label = torch.cat([os_support_full_label, os_query_full_label], 1)
        os_full_edge = self.label2edge(os_full_label)

        support_label = support_label.cuda()
        query_label = query_label.cuda()
        full_label = torch.cat([support_label, query_label], 1)
        full_edge = self.label2edge(full_label)

        # Hide every edge that touches a query node (0.5 = unknown), then
        # restore the trivially known query self-edges.
        init_edge = full_edge.clone()
        init_edge[:, :, num_supports:, :] = 0.5
        init_edge[:, :, :, num_supports:] = 0.5
        for i in range(num_queries):
            init_edge[:, 0, num_supports + i, num_supports + i] = 1.0
            init_edge[:, 1, num_supports + i, num_supports + i] = 0.0

        support_subject_emb = support_all_input[4].cuda()
        support_object_emb = support_all_input[5].cuda()
        query_subject_emb = query_all_input[4].cuda()
        query_object_emb = query_all_input[5].cuda()
        support_predicate_input = support_all_input[1].view(
            test_batch_size * num_ways_test, 1, num_shots_test, -1).cuda()
        query_predicate_input = query_all_input[1].view(
            test_batch_size * num_ways_test, 1, 1, -1).cuda()

        all_support_full_data, support_mapping_subj_feature, support_mapping_obj_feature = self.enc_module(
            support_all_input[0].cuda(), support_predicate_input,
            support_all_input[2].cuda(), support_all_input[3].cuda(),
            support_subject_emb, support_object_emb, support_shot_nums)
        all_query_full_data, query_mapping_subj_feature, query_mapping_obj_feature = self.enc_module(
            query_all_input[0].cuda(), query_predicate_input,
            query_all_input[2].cuda(), query_all_input[3].cuda(),
            query_subject_emb, query_object_emb, query_shot_nums)

        support_subject_input = support_mapping_subj_feature.view(
            test_batch_size, num_supports, feature_dim).cuda()
        support_object_input = support_mapping_obj_feature.view(
            test_batch_size, num_supports, feature_dim).cuda()
        support_full_data = torch.cat([support_subject_input, support_object_input], 1)

        query_subject_input = query_mapping_subj_feature.view(
            test_batch_size, num_queries, feature_dim).cuda()
        query_object_input = query_mapping_obj_feature.view(
            test_batch_size, num_queries, feature_dim).cuda()
        query_full_data = torch.cat([query_subject_input, query_object_input], 1)
        feature_full_data = torch.cat([support_full_data, query_full_data], 1)

        all_support_full_data = all_support_full_data.view(
            test_batch_size, num_supports, feature_dim)
        all_query_full_data = all_query_full_data.view(
            test_batch_size, num_queries, feature_dim)
        full_data = torch.cat([all_support_full_data, all_query_full_data], 1)

        full_logit_all, object_full_logit_layers, object_out = self.gnn_module(
            node_feat=full_data,
            edge_feat=init_edge,
            object_node_feat=feature_full_data,
            object_edge_feat=os_full_edge)

        # Only the last entry of the returned logits is evaluated.
        full_logit = full_logit_all[-1]

        full_edge_loss = self.edge_loss(1 - full_logit[:, 0], 1 - full_edge[:, 0])

        # Positive and negative query edges are averaged separately and
        # then summed.
        pos_query_edge_loss = torch.sum(
            full_edge_loss * query_edge_mask * full_edge[:, 0] * evaluation_mask) / torch.sum(
            query_edge_mask * full_edge[:, 0] * evaluation_mask)
        neg_query_edge_loss = torch.sum(
            full_edge_loss * query_edge_mask * (1 - full_edge[:, 0]) * evaluation_mask) / torch.sum(
            query_edge_mask * (1 - full_edge[:, 0]) * evaluation_mask)
        query_edge_loss = pos_query_edge_loss + neg_query_edge_loss

        full_edge_accr = self.hit(full_logit, 1 - full_edge[:, 0].long())
        query_edge_accr = torch.sum(full_edge_accr * query_edge_mask * evaluation_mask) / torch.sum(
            query_edge_mask * evaluation_mask)

        # (num_tasks x num_queries x num_supports) * (num_tasks x num_supports x num_ways)
        query_node_pred = torch.bmm(
            full_logit[:, 0, num_supports:, :num_supports],
            self.one_hot_encode(num_ways_test, support_label.long()))
        query_node_accr = torch.eq(
            torch.max(query_node_pred, -1)[1], query_label.long()).float().mean()

        query_edge_losses += [query_edge_loss.item()]
        query_edge_accrs += [query_edge_accr.item()]
        query_node_accrs += [query_node_accr.item()]

    # Report and return the same number. Returning the mean of a list that
    # is never filled would always yield NaN.
    mean_node_accr = np.array(query_node_accrs).mean()
    print('evaluation: mean=%.2f%%, \n' % (mean_node_accr * 100,))

    return mean_node_accr
def train_self_supervised(self):
    """Jointly train one encoder and one decoder per view, then save the
    weights and the encoded validation/test features.

    The loss for a batch is the sum of four weighted terms: reconstruction
    (``self.alpha``), cross-view (``self.beta``), CPC (``self.gamma``) and
    label regression (``self.delta``).

    Returns:
        ``[valid_pre_0, test_pre_0]`` when ``self.just_valid`` is false;
        otherwise nothing is returned and loss histories are stored on
        ``self``.
    """
    # Fix every random source so runs are repeatable.
    seed = 0
    import numpy as np
    np.random.seed(seed)
    import random as rn
    rn.seed(seed)
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    import torch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    from networks import Dense_Net, Dense_Net_with_softmax

    # One encoder (input -> common space) and one decoder (common space ->
    # input) per view.
    Nets = []
    AEs = []
    for view_id in range(self.n_view):
        Net = Dense_Net(input_dim=self.input_shape[view_id],
                        out_dim=self.output_shape)
        Nets.append(Net)
        AE = Dense_Net(input_dim=self.output_shape,
                       out_dim=self.input_shape[view_id])
        AEs.append(AE)

    if torch.cuda.is_available():
        for view_id in range(self.n_view):
            Nets[view_id].cuda()
            AEs[view_id].cuda()

    # Fixed (non-trainable) matrix applied to every encoder output.
    # NOTE(review): .cuda() is called unconditionally here, unlike the
    # guarded calls above -- confirm CPU-only runs are not expected.
    W = torch.tensor(self.W)
    W = Variable(W.cuda(), requires_grad=False)

    get_grad_params = lambda model: [
        x for x in model.parameters() if x.requires_grad
    ]
    # One Adam optimizer per view, covering that view's encoder and decoder.
    params = []
    optims = []
    for view_id in range(self.n_view):
        params.append(
            get_grad_params(Nets[view_id]) + get_grad_params(AEs[view_id]))
        optims.append(
            optim.Adam(params[view_id], self.lr[view_id],
                       [self.beta1, self.beta2]))

    discriminator_losses, losses, valid_results = [], [], []
    for view_id in range(self.n_view):
        discriminator_losses.append([])
        losses.append([])
        valid_results.append([])

    # Mean Euclidean distance between matching rows.
    criterion = lambda x, y: (((x - y)**2).sum(1).sqrt()).mean()
    tr_d_loss, tr_ae_loss, val_d_loss, val_ae_loss = [], [], [], []
    for view_id in range(self.n_view):
        tr_d_loss.append([])
        tr_ae_loss.append([])
        val_d_loss.append([])
        val_ae_loss.append([])

    valid_loss_min = 1e9
    for epoch in range(self.epochs):
        # NOTE(review): view_id here is whatever the previous loop left
        # behind (the last view); presumably all views have the same number
        # of training rows -- confirm.
        rand_idx = np.arange(self.train_data[view_id].shape[0])
        np.random.shuffle(rand_idx)
        # Integer division: a trailing partial batch is dropped.
        batch_count = int(self.train_data[view_id].shape[0] /
                          float(self.batch_size))

        k = 0
        mean_loss = []
        mean_tr_d_loss, mean_tr_ae_loss = [], []
        for view_id in range(self.n_view):
            mean_loss.append([])
            mean_tr_d_loss.append([])
            mean_tr_ae_loss.append([])

        for batch_idx in range(batch_count):
            self_supervised_features = []
            ae_loss = 0
            # Encode the same row indices in every view and accumulate the
            # reconstruction loss.
            for view_id in range(self.n_view):
                print(('ViewID: %d, Epoch %d/%d') %
                      (view_id, epoch + 1, self.epochs))
                idx = rand_idx[batch_idx * self.batch_size:(batch_idx + 1) *
                               self.batch_size]
                train_y = self.to_one_hot(self.train_labels[view_id][idx])
                train_x = self.to_var(
                    torch.tensor(self.train_data[view_id][idx]))
                optimizer = optims[view_id]
                Net = Nets[view_id]
                AE = AEs[view_id]
                optimizer.zero_grad()
                network_outputs = Net(train_x)
                ae_input = network_outputs[-1]
                pred = ae_input.view([ae_input.shape[0], -1]).mm(W)
                self_supervised_features.append(pred)
                ae_data = train_x
                ae_pred = AE(ae_input)[-1]
                ae_loss += criterion(ae_pred, ae_data)

            # Rows with an all-zero one-hot label are treated as unlabeled.
            # train_y is the one from the last view processed above.
            labeled_inx = train_y.sum(1) > 0
            d_loss = 0

            if (not self.use_nce):
                # CPC loss between the first two views, compared against an
                # identity target.
                identity_mat = torch.eye(
                    self_supervised_features[0].shape[0]).reshape(
                        1, self_supervised_features[0].shape[0],
                        self_supervised_features[0].shape[0]).repeat(
                            self_supervised_features[0].shape[1], 1, 1)
                if torch.cuda.is_available():
                    identity_mat = identity_mat.cuda()
                # NOTE(review): these are view() calls, not transposes, so
                # the memory is reinterpreted rather than permuted --
                # confirm this is the intended layout.
                f_1 = self_supervised_features[0].view(
                    self_supervised_features[0].shape[1],
                    self_supervised_features[0].shape[0], -1)
                f_2 = self_supervised_features[1].view(
                    self_supervised_features[1].shape[1], -1,
                    self_supervised_features[0].shape[0])
                loss_pred = torch.bmm(f_1, f_2)
                cpc_loss = self.cpc_loss_func(loss_pred, identity_mat)
            else:
                # NCE CPC loss implementation (Eq. 6) [parts of code referred from https://github.com/HobbitLong/CMC/blob/master/NCE/NCECriterion.py]
                eps = 1e-7
                anchor_samples_indices = self.sample_anchor_points(
                    self_supervised_features)
                pos_samples_indices = self.sample_positives(
                    anchor_samples_indices, self_supervised_features)
                neg_samples_indices = self.sample_negatives(
                    anchor_samples_indices, self_supervised_features)
                anchor_samples = torch.zeros(
                    (self.num_anchors,
                     self_supervised_features[0].shape[1]))
                pos_samples = torch.zeros(
                    (self.num_anchors,
                     self_supervised_features[0].shape[1]))
                neg_samples = torch.zeros(
                    (self.num_anchors, self.num_negative_samples,
                     self_supervised_features[0].shape[1]))
                # Each sampled index is used as a (view, row) pair into the
                # per-view feature list.
                for anchor_id in range(len(anchor_samples_indices)):
                    anchor_point = anchor_samples_indices[anchor_id]
                    anchor_samples[anchor_id] = self_supervised_features[
                        anchor_point[0]][anchor_point[1]]
                for pos_id in range(len(pos_samples_indices)):
                    pos_point = pos_samples_indices[pos_id]
                    pos_samples[pos_id] = self_supervised_features[
                        pos_point[0]][pos_point[1]]
                for neg_id in range(len(neg_samples_indices)):
                    for neg_sample_id in range(
                            len(neg_samples_indices[neg_id])):
                        neg_point = neg_samples_indices[neg_id][
                            neg_sample_id]
                        neg_samples[neg_id][
                            neg_sample_id] = self_supervised_features[
                                neg_point[0]][neg_point[1]]
                if (torch.cuda.is_available()):
                    anchor_samples = anchor_samples.cuda()
                    pos_samples = pos_samples.cuda()
                    neg_samples = neg_samples.cuda()
                # NOTE(review): anchor_samples is filled in but not read by
                # the loss below -- confirm.

                # noise distribution
                Pn = 1 / float(self_supervised_features[0].shape[0])

                # number of noise samples
                m = self_supervised_features[0].shape[1]

                D1 = torch.div(
                    pos_samples,
                    torch.zeros_like(pos_samples).fill_(eps) +
                    pos_samples.add(m * Pn + eps))
                # Clamp before the log so it never sees a value below eps.
                D1 = torch.clamp(D1, min=eps)
                log_D1 = D1.log()

                D0 = torch.div(
                    neg_samples.clone().fill_(m * Pn),
                    torch.zeros_like(neg_samples).fill_(eps) +
                    neg_samples.add(m * Pn + eps))
                D0 = torch.clamp(D0, min=eps)
                log_D0 = D0.log()

                cpc_loss = torch.mean(
                    -(log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) /
                    self_supervised_features[0].shape[0]**2)

            # Label regression: labeled rows should land on their one-hot
            # label.
            for f_ind in range(len(self_supervised_features)):
                train_y = self.to_one_hot(self.train_labels[f_ind][idx])
                feature = self_supervised_features[f_ind]
                curr_loss = criterion(feature[labeled_inx],
                                      train_y[labeled_inx])
                d_loss += curr_loss

            # cross modal loss (Eq. 2)
            cross_loss = 0
            for f_ind1 in range(len(self_supervised_features)):
                for f_ind2 in range(len(self_supervised_features)):
                    if (f_ind1 == f_ind2):
                        continue
                    else:
                        feature_1 = self_supervised_features[f_ind1]
                        feature_2 = self_supervised_features[f_ind2]
                        # train_y1 / train_y2 are computed but not used by
                        # the distance below.
                        train_y1 = self.to_one_hot(
                            self.train_labels[f_ind1][idx])
                        train_y2 = self.to_one_hot(
                            self.train_labels[f_ind2][idx])
                        curr_loss = criterion(feature_1[labeled_inx],
                                              feature_2[labeled_inx])
                        cross_loss += curr_loss

            ae_loss *= self.alpha
            cross_loss *= self.beta
            cpc_loss *= self.gamma
            d_loss *= self.delta
            loss = ae_loss + cpc_loss + d_loss + cross_loss
            loss.backward()

            # A single backward pass feeds every view's optimizer; the same
            # total loss is logged for each view.
            for view_id in range(self.n_view):
                optims[view_id].step()
                mean_loss[view_id].append(self.to_data(loss))
                mean_tr_d_loss[view_id].append(self.to_data(d_loss))
                mean_tr_ae_loss[view_id].append(self.to_data(ae_loss))

            # Validate on the last batch of every sample_interval-th epoch.
            if ((epoch + 1) % self.sample_interval
                    == 0) and (batch_idx == batch_count - 1):
                valid_pres = []
                test_pres = []
                for view_id in range(self.n_view):
                    losses[view_id].append(np.mean(mean_loss[view_id]))
                    utils.show_progressbar([batch_idx, batch_count],
                                           mean_loss=np.mean(
                                               mean_loss[view_id]))

                    pre_labels = utils.predict(
                        lambda x: Nets[view_id](x)[-1].view(
                            [x.shape[0], -1]).mm(W).view([x.shape[0], -1]),
                        self.valid_data[view_id], self.batch_size).reshape(
                            [self.valid_data[view_id].shape[0], -1])
                    valid_labels = self.to_one_hot(
                        self.valid_labels[view_id])
                    valid_d_loss = self.to_data(
                        criterion(self.to_var(torch.tensor(pre_labels)),
                                  valid_labels))
                    if valid_loss_min > valid_d_loss and not self.just_valid:
                        # New best validation loss: refresh the cached
                        # features of views 0 and 1 (taken before the W
                        # projection).
                        valid_loss_min = valid_d_loss
                        valid_pre_0 = utils.predict(
                            lambda x: Nets[0]
                            (x)[-1].view([x.shape[0], -1]),
                            self.valid_data[0], self.batch_size).reshape(
                                [self.valid_data[0].shape[0], -1])
                        valid_pre_1 = utils.predict(
                            lambda x: Nets[1]
                            (x)[-1].view([x.shape[0], -1]),
                            self.valid_data[1], self.batch_size).reshape(
                                [self.valid_data[1].shape[0], -1])
                        test_pre_0 = utils.predict(
                            lambda x: Nets[0]
                            (x)[-1].view([x.shape[0], -1]),
                            self.test_data[0], self.batch_size).reshape(
                                [self.test_data[0].shape[0], -1])
                        test_pre_1 = utils.predict(
                            lambda x: Nets[1]
                            (x)[-1].view([x.shape[0], -1]),
                            self.test_data[1], self.batch_size).reshape(
                                [self.test_data[1].shape[0], -1])
                        valid_pres.append(valid_pre_0)
                        test_pres.append(test_pre_0)
                    elif self.just_valid:
                        tr_d_loss[view_id].append(
                            np.mean(mean_tr_d_loss[view_id]))
                        # NOTE(review): valid_d_loss comes from
                        # self.to_data(...) of a mean; indexing it by
                        # view_id looks suspicious -- confirm.
                        val_d_loss[view_id].append(valid_d_loss[view_id])
                        tr_ae_loss[view_id].append(
                            np.mean(mean_tr_ae_loss[view_id]))
            elif batch_idx == batch_count - 1:
                # view_id is the value left by the optimizer loop above, so
                # only the last view is logged in this branch.
                utils.show_progressbar([batch_idx, batch_count],
                                       mean_loss=np.mean(
                                           mean_loss[view_id]))
                losses[view_id].append(np.mean(mean_loss[view_id]))
            else:
                utils.show_progressbar([batch_idx, batch_count], loss=loss)
            k += 1

    # Views 0 and 1 are saved as the "image" and "text" models, so at least
    # two views are assumed.
    torch.save(Nets[0].state_dict(),
               './features/' + self.datasets + '_image_encoder_weights')
    torch.save(Nets[1].state_dict(),
               './features/' + self.datasets + '_text_encoder_weights')
    torch.save(AEs[0].state_dict(),
               './features/' + self.datasets + '_image_decoder_weights')
    torch.save(AEs[1].state_dict(),
               './features/' + self.datasets + '_text_decoder_weights')

    # NOTE(review): valid_pre_* / test_pre_* are only assigned when a
    # validation pass improves valid_loss_min with just_valid false; if that
    # never happens these lines would raise NameError -- confirm.
    valid_pres_all = [valid_pre_0, valid_pre_1]
    test_pres_all = [test_pre_0, test_pre_1]

    if not self.just_valid:
        for view_id in range(self.n_view):
            sio.savemat(
                'features/' + self.datasets + '_' + str(view_id) + '.mat', {
                    'valid_fea': valid_pres_all[view_id],
                    'valid_lab': self.valid_labels[view_id],
                    'test_fea': test_pres_all[view_id],
                    'test_lab': self.test_labels[view_id]
                })
        return [valid_pre_0, test_pre_0]
    else:
        # view_id is the leftover loop value, so only that one view's
        # histories are stored.
        self.tr_d_loss[view_id] = tr_d_loss[view_id]
        self.tr_ae_loss[view_id] = tr_ae_loss[view_id]
        self.val_d_loss[view_id] = val_d_loss[view_id]
def _rank_k_bmm(self, x, u, v): xu = torch.bmm(x[:, None], u.view(x.shape[0], x.shape[-1], self.rank)) xuv = torch.bmm(xu, v.view(x.shape[0], self.rank, -1)) return xuv[:, 0]
def gram_matrix(self, inp):
    """Return the per-sample Gram matrix of a (b, c, h, w) feature map.

    Each channel is flattened to a vector of length h*w; the result holds
    the channel-by-channel inner products divided by h*w, shape (b, c, c).
    """
    batch, channels, height, width = inp.size()
    positions = height * width

    flat = inp.view(batch, channels, positions)
    gram = torch.bmm(flat, flat.transpose(1, 2))
    return gram.div(positions)
def forward(self, text_inputs, mask_input, len_seq, len_sents, tid, mode=""):
    """Encode the input, pool it with a conv layer, attend, and score it.

    Args:
        text_inputs: reshaped to
            (batch, self.max_num_sents * self.max_len_sent).
        mask_input: passed to ``self.base_encoder``.
        len_seq: per-example lengths; read only for the "conll17_al"
            target model.
        len_sents, tid, mode: not read by this method.

    Returns:
        Output of ``self.linear_out``, passed through ``self.sigmoid`` when
        ``self.corpus_target`` is "asap" (case-insensitive).
    """
    # Flatten to one token sequence per example (the pad-level check that
    # once guarded this is disabled).
    text_inputs = text_inputs.view(
        text_inputs.shape[0], self.max_num_sents * self.max_len_sent)
    mask = mask_input.view(text_inputs.shape)  # not used below

    # Base encoder over the flattened sequence.
    encoder_out = self.base_encoder(text_inputs, mask_input, len_seq)

    # applying conv1d after rnn
    # Zero buffer: rows that the loop below does not fill stay zero.
    avg_pooled = torch.zeros(text_inputs.shape[0], text_inputs.shape[1],
                             self.conv_output_size)
    avg_pooled = utils.cast_type(avg_pooled, FLOAT, self.use_gpu)
    for cur_batch, cur_tensor in enumerate(encoder_out):
        ## Actual length version
        if self.target_model == "conll17_al":
            # Keep only the first cur_seq_len positions of this example.
            cur_seq_len = int(len_seq[cur_batch])
            cur_tensor = cur_tensor.unsqueeze(0)
            crop_tensor = cur_tensor.narrow(1, 0, cur_seq_len)
            crop_tensor = crop_tensor.transpose(1, 0)
            cur_tensor = crop_tensor
        ## published version: do not consider actual length
        else:
            cur_tensor = cur_tensor.unsqueeze(1)

        # applying conv
        cur_tensor = self.conv(cur_tensor)
        cur_tensor = self.leak_relu(cur_tensor)
        cur_tensor = self.dropout_layer(cur_tensor)
        # (the adaptive pool replaced an earlier self.avg_pool_1d call)
        cur_tensor = self.avg_adapt_pool1(cur_tensor)
        cur_tensor = cur_tensor.view(cur_tensor.shape[0],
                                     self.conv_output_size)

        avg_pooled[cur_batch, :cur_tensor.shape[0], :] = cur_tensor

    len_seq = utils.cast_type(len_seq, FLOAT, self.use_gpu)  # not used below

    ## implement attention by parameters
    # The learned context vector is shared by every example in the batch.
    context_weight = self.context_weight.unsqueeze(1)
    context_weight = context_weight.expand(text_inputs.shape[0],
                                           self.conv_output_size, 1)
    attn_weight = torch.bmm(avg_pooled, context_weight).squeeze(2)
    attn_weight = self.tanh(attn_weight)
    attn_weight = self.softmax(attn_weight)
    # attention applied: weighted sum of the pooled rows
    attn_vec = torch.bmm(avg_pooled.transpose(1, 2),
                         attn_weight.unsqueeze(2))
    ilc_vec = attn_vec.squeeze(2)

    ## implement attention by linear (disabled alternative)
    #attn_vec = self.attn(encoder_out.view(self.batch_size, -1)).unsqueeze(2)
    #attn_vec = self.softmax(attn_vec)
    #ilc_vec_attn = torch.bmm(encoder_out.transpose(1, 2), attn_vec).squeeze(2)

    ## FC
    # fully connected stage
    fc_out = self.linear_1(ilc_vec)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)
    fc_out = self.linear_2(fc_out)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)
    fc_out = self.linear_out(fc_out)
    if self.corpus_target.lower() == "asap":
        fc_out = self.sigmoid(fc_out)

    return fc_out
def forward(self, inputs, predict_M, getAttention=False):
    """Encode `inputs`, then decode ``self.args.pred_len`` steps with
    attention over the encoder outputs.

    :param inputs: the input sequence, a 3-D tensor [M*S*D]
    :param predict_M: prepared via ``self._prepare_input_h`` but not
        otherwise used by the decoding loop
    :param getAttention: when True, also return the attention weights of
        every decoding step
    :return: ``outs``, or ``(outs, attention)`` when `getAttention` is set;
        ``attention`` is a list with one (nested) list of weights per step
    """
    assert len(inputs.size(
    )) == 3, '[LSTM]: input dimension must be of length 3 i.e. [M*S*D]'

    batch_first = self.args.batch_first

    # encoder
    inputs = self.bn(inputs)
    inputs, h, r_batch_size = self._prepare_input_h(inputs)
    c, h = self.encoder(inputs, h)

    # attention and decoder
    outs = []
    predict_M, _, _ = self._prepare_input_h(predict_M)

    # Bring the encoder outputs to batch-first once, outside the loop, so
    # they can be used directly in the batched product below.
    if not batch_first:
        c = c.permute(1, 0, 2)

    # The decoder starts from the encoder's final state:
    # h_ = [(encoder_num_layer*binum, r_batch_size, encoder_hidden_dim),
    #       (encoder_num_layer*binum, r_batch_size, encoder_hidden_dim)]
    h_ = h
    if getAttention:
        attention = []

    for i in range(self.args.pred_len):
        # attention
        # The hidden state is laid out (layers, batch, hidden) whatever
        # batch_first is, so it is always moved to batch-major before being
        # flattened. Doing this only for batch_first=False left h_temp
        # undefined for batch_first=True.
        h_temp = h_[0].permute(1, 0, 2)
        h_temp = h_temp.contiguous().view(r_batch_size, -1)
        weights = F.softmax(self.sequenceForAttention(h_temp),
                            dim=-1)  # [r_batch_size, args.seq_len]
        if getAttention:
            # NOTE: squeeze() also drops the batch axis for a batch of one.
            temp_weights = weights.squeeze()
            temp_weights = temp_weights.cpu().detach().numpy()
            attention.append(temp_weights.tolist())
        weights = torch.unsqueeze(weights, 1)
        out = torch.bmm(weights, c)  # [r_batch_size, 1, encoder_hidden_dim]
        if not batch_first:
            out = out.permute(1, 0, -1)

        # decoder
        out, h_ = self.decoder(out, h_)
        outs.append(out)

    # Stack the single-step outputs along the time axis and drop the
    # length-1 step axis each of them carries.
    if batch_first:
        outs = torch.squeeze(torch.stack(outs, dim=1), dim=2)
    else:
        outs = torch.squeeze(torch.stack(outs, dim=0), dim=1)
    outs = self.sequenceForOut(outs)
    if not batch_first:
        outs = outs.permute(1, 0, 2)

    if getAttention:
        return outs, attention
    return outs
def forward(self, x0, x):
    """Apply one feature-crossing step.

    Builds the per-sample outer product of ``x0`` and ``x``, contracts its
    last axis with ``self.weights``, then adds ``self.bias`` and ``x``.
    """
    column = torch.unsqueeze(x0, -1)
    row = torch.unsqueeze(x, -2)
    outer = torch.bmm(column, row)
    crossed = torch.tensordot(outer, self.weights, [[-1], [0]])
    return crossed + self.bias + x
def forward(self, x, A1):
    """Batch-multiply the lower-triangular part of ``A1`` with ``x``."""
    lower = A1.tril()
    return lower.bmm(x)
def forward(self, hidden, encoder_outputs, normalize=True):
    """Return the attention-weighted sum of the encoder outputs.

    The first two axes of ``encoder_outputs`` are swapped, ``self.score`` is
    asked for energies, the energies are softmax-normalised over axis 2 and
    used to weight the swapped encoder outputs.  The result is returned with
    its first two axes swapped back.  ``normalize`` is accepted but not used.
    """
    batch_major = encoder_outputs.transpose(0, 1)  # [B,T,H] per the original note
    energies = self.score(hidden, batch_major)
    weights = F.softmax(energies, dim=2)  # [B,1,T]
    return torch.bmm(weights, batch_major).transpose(0, 1)  # [1,B,H]
def score(self, hidden, encoder_outputs):
    """Compute one attention energy per (batch, position) pair.

    ``hidden`` and ``encoder_outputs`` are concatenated on axis 2, passed
    through ``self.attn`` and ``tanh``, and each resulting feature vector is
    reduced to a scalar by a dot product with ``self.v``.
    """
    batch = encoder_outputs.data.shape[0]
    joined = torch.cat([hidden, encoder_outputs], 2)
    features = F.tanh(self.attn(joined)).transpose(2, 1)  # [B, H, T]
    v_rows = self.v.repeat(batch, 1).unsqueeze(1)  # [B, 1, H]
    energies = torch.bmm(v_rows, features)  # [B, 1, T]
    return energies.squeeze(1)  # [B, T]
def forward(self, input):
    """Return the Gram matrix of a 4-D feature map, divided by ``h * w``.

    :param input: tensor whose ``size()`` unpacks to ``(b, c, h, w)``.
    :return: tensor of shape ``(b, c, c)`` holding, per sample, the inner
        products between the flattened channels, divided by ``h * w``.
    """
    b, c, h, w = input.size()
    # reshape instead of view: view raises on non-contiguous inputs (e.g. a
    # permuted or channels-last feature map), reshape copies only when needed.
    features = input.reshape(b, c, h * w)
    gram = torch.bmm(features, features.transpose(1, 2))
    gram.div_(h * w)
    return gram
def _layer(self, t, x, weight, bias): # weights is (batch, in_dim + 1, out_dim) ttx = self._pack_inputs(t, x) # (batch, in_dim + 1) ttx = ttx.view(ttx.size(0), 1, ttx.size(1)) # (batch, 1, in_dim + 1) xw = torch.bmm(ttx, weight)[:, 0, :] # (batch, out_dim) return xw + bias
def forward(self, contents, question_ans, logics, contents_char=None, question_ans_char=None):
    """Compute decision-layer outputs for a question/answer batch and its contents.

    Pipeline, as implemented below: embed and encode the question and every
    content slice, bring the encodings to fixed lengths, build the
    question-content matching matrices and their softmax attentions,
    aggregate attention across contents, run the gated reasoning layers,
    pool, and feed the pooled features to ``self.decision_layer``.

    :param contents: indexed as ``contents[:, i, :]``; ``size()[1]`` is used
        as the number of contents and ``size()[2]`` as the content length.
    :param question_ans: ``size()[0]`` is used as the batch size and
        ``size()[1]`` as the question length.
    :param logics: not used in this method.
    :param contents_char: not used in this method.
    :param question_ans_char: not used in this method.
    :return: the result of ``self.decision_layer.forward``.

    NOTE: the numeric shapes in the inline comments (16, 100, 200, 256, ...)
    are the original author's example sizes, not values enforced here.
    """
    # (disabled) assert contents_char is not None and question_ans_char is not None
    batch_size = question_ans.size()[0]
    max_content_len = contents.size()[2]
    max_question_len = question_ans.size()[1]
    contents_num = contents.size()[1]

    # word-level embedding: (seq_len, batch, embedding_size)
    content_vec = []
    content_mask = []
    question_vec, question_mask = self.embedding.forward(question_ans)
    for i in range(contents_num):
        cur_content = contents[:, i, :]
        cur_content_vec, cur_content_mask = self.embedding.forward(cur_content)
        content_vec.append(cur_content_vec)
        content_mask.append(cur_content_mask)

    # char-level embedding: (seq_len, batch, char_embedding_size)
    # (disabled in the original; the *_char arguments are therefore unused)

    question_encode, _ = self.context_layer.forward(question_vec,question_mask)  # size=(cur_batch_max_questionans_len, batch, 256)
    content_encode = []
    # word-level encode: (seq_len, batch, hidden_size)
    for i in range(contents_num):
        cur_content_vec = content_vec[i]
        cur_content_mask = content_mask[i]
        cur_content_encode, _ = self.context_layer.forward(cur_content_vec,cur_content_mask)  # size=(cur_batch_max_content_len, batch, 256)
        content_encode.append(cur_content_encode)

    # Bring every encoded content to the same length (e.g. 200) and the
    # encoded question to the same length (e.g. 100).
    same_sized_content_encode = []
    for i in range(contents_num):
        cur_content_encode = content_encode[i]
        cur_content_encode = self.full_matrix_to_specify_size(cur_content_encode, [max_content_len, batch_size,cur_content_encode.size()[2]])  # size=(200,16,256)
        same_sized_content_encode.append(cur_content_encode)
    same_sized_question_encode = self.full_matrix_to_specify_size(question_encode, [max_question_len, batch_size,question_encode.size()[2]])  # size=(100,16,256)

    # Compute the gating-layer values.
    # NOTE(review): the decision_* gating values are computed but only
    # referenced by commented-out code further down.
    reasoning_content_gating_val = []
    reasoning_question_gating_val = None
    decision_content_gating_val = []
    decision_question_gating_val = None
    for i in range(contents_num):
        cur_content_encode = same_sized_content_encode[i]  # size=(200,16,256)
        cur_gating_input = cur_content_encode.permute(1,2,0)  # size=(16,256,200)
        cur_reasoning_content_gating_val = self.reasoning_gating_layer(cur_gating_input)  # size=(16,1,200)
        # Small offset so a gate is never exactly 0, which (per the original
        # note) would break the later pad-sequence step.
        cur_reasoning_content_gating_val =cur_reasoning_content_gating_val+0.00001
        cur_decision_content_gating_val = self.decision_gating_layer(cur_gating_input)  # size=(16,1,200)
        cur_decision_content_gating_val =cur_decision_content_gating_val+0.00001  # same zero-gate guard
        reasoning_content_gating_val.append(cur_reasoning_content_gating_val)
        decision_content_gating_val.append(cur_decision_content_gating_val)

    question_gating_input = same_sized_question_encode.permute(1,2,0)  # size=(16,256,100)
    reasoning_question_gating_val = self.reasoning_gating_layer(question_gating_input)  # size=(16,1,100)
    reasoning_question_gating_val=reasoning_question_gating_val+0.00001  # same zero-gate guard
    decision_question_gating_val = self.decision_gating_layer(question_gating_input)  # size=(16,1,100)
    decision_question_gating_val=decision_question_gating_val+0.00001  # same zero-gate guard

    # Gate loss (disabled). Original TODO: it seemed impossible to return
    # several values, so the mean gate value is currently not computed.

    # Matching matrix: one is computed between the question and every content.
    Matching_matrix = []
    for i in range(contents_num):
        cur_content_encode = same_sized_content_encode[i]
        cur_Matching_matrix = self.compute_matching_matrix(same_sized_question_encode,
                                                           cur_content_encode)  # (batch, question_len , content_len) eg(16,100,200)
        Matching_matrix.append(cur_Matching_matrix)

    # compute an & bn
    an_matrix = []
    bn_matrix = []
    for i in range(contents_num):
        cur_Matching_matrix = Matching_matrix[i]
        # softmax over dim=2: every row of the matching matrix sums to 1
        cur_an_matrix = torch.nn.functional.softmax(cur_Matching_matrix,dim=2)  # size=(batch, question_len , content_len)
        # softmax over dim=1: every column of the matching matrix sums to 1
        cur_bn_matrix = torch.nn.functional.softmax(cur_Matching_matrix,dim=1)  # size=(batch, question_len , content_len)
        an_matrix.append(cur_an_matrix)
        bn_matrix.append(cur_bn_matrix)

    # compute RnQ & RnD
    RnQ = []  # list[tensor[16,100,256]]
    RnD = []
    for i in range(contents_num):
        cur_an_matrix = an_matrix[i]
        cur_content_encode = same_sized_content_encode[i]
        cur_bn_matrix = bn_matrix[i]
        cur_RnQ = self.compute_RnQ(cur_an_matrix, cur_content_encode)  # size=(batch, curbatch_max_question_len , 256) eg[16,100,256]
        cur_RnD = self.compute_RnD(cur_bn_matrix,same_sized_question_encode)  # size=(batch, curbatch_max_content_len , 256) eg[16,200,256]
        RnQ.append(cur_RnQ)
        RnD.append(cur_RnD)

    # ---- compute Mmn' ----
    D_RnD = []  # first build the concatenation of D and RnD
    for i in range(contents_num):
        cur_content_encode = same_sized_content_encode[i].transpose(0, 1)  # size=(16,200,256)
        cur_RnD = RnD[i]  # size=(16,200,256)
        cur_D_RnD = torch.cat([cur_content_encode, cur_RnD], dim=2)  # size=(16,200,512)
        D_RnD.append(cur_D_RnD)

    RmD = []  # list[tensor(16,200,512)]
    for i in range(contents_num):
        D_RnD_m = D_RnD[i]  # size=(16,200,512)
        Mmn_i=[]
        RmD_i = []
        for j in range(contents_num):
            D_RnD_n = D_RnD[j]  # size=(16,200,512)
            # attention between content i and content j
            Mmn_i_j = self.compute_cross_document_attention(D_RnD_m,D_RnD_n)  # size=(16,200,200)
            Mmn_i.append(Mmn_i_j)
        # move the content axis last before the reduce-softmax
        Mmn_i=torch.stack(Mmn_i).permute(1,2,3,0)# size=(16,200,200,10)
        softmax_Mmn_i=self.reduce_softmax(Mmn_i)  # size=(16,200,200,10)
        for j in range(contents_num):
            D_RnD_n = D_RnD[j]  # size=(16,200,512)
            beta_mn_i_j = softmax_Mmn_i[:,:,:,j]
            cur_RmD = torch.bmm(beta_mn_i_j, D_RnD_n)  # size=(16,200,512)
            RmD_i.append(cur_RmD)
        RmD_i = torch.stack(RmD_i)  # size=(10,16,200,512)
        RmD_i = RmD_i.transpose(0, 1)  # size=(16,10,200,512)
        RmD_i = torch.sum(RmD_i, dim=1)  # size=(16,200,512)
        RmD.append(RmD_i)

    # Max/mean pooling of each matching matrix along both axes.
    matching_feature_row = []  # list[tensor(16,200,2)]
    matching_feature_col = []  # list[tensor(16,100,2)]
    for i in range(contents_num):
        cur_Matching_matrix = Matching_matrix[i]  # size=(16,100,200)
        cur_max_pooling_feature_row, _ = torch.max(cur_Matching_matrix, dim=1)  # size=(16,200)
        cur_mean_pooling_feature_row = torch.mean(cur_Matching_matrix, dim=1)  # size=(16,200)
        cur_matching_feature_row = torch.stack([cur_max_pooling_feature_row, cur_mean_pooling_feature_row]).permute(1,2,0)  # size=(16,200,2)
        matching_feature_row.append(cur_matching_feature_row)
        cur_max_pooling_feature_col, _ = torch.max(cur_Matching_matrix, dim=2)  # size=(16,100)
        cur_mean_pooling_feature_col = torch.mean(cur_Matching_matrix, dim=2)  # size=(16,100)
        cur_matching_feature_col = torch.stack([cur_max_pooling_feature_col, cur_mean_pooling_feature_col]).permute(1,2,0)  # size=(16,100,2)
        matching_feature_col.append(cur_matching_feature_col)

    # NOTE(review): reasoning_feature is never filled; only commented-out
    # code referred to it.
    reasoning_feature = []
    RnQ_reasoning_out=[]
    RmD_reasoning_out=[]
    for i in range(contents_num):
        cur_RnQ = RnQ[i]  # size=(16,100,256)
        cur_RmD = RmD[i]  # size=(16,200,512)
        cur_matching_feature_col = matching_feature_col[i]  # size=(16,100,2)
        cur_matching_feature_row = matching_feature_row[i]  # size=(16,200,2)
        cur_RnQ = torch.cat([cur_RnQ, cur_matching_feature_col], dim=2)  # size=(16,100,258)
        cur_RmD = torch.cat([cur_RmD, cur_matching_feature_row], dim=2)  # size=(16,200,514)
        # Masks are derived from the feature-wise mean of each position.
        cur_RnQ_mask = compute_mask(cur_RnQ.mean(dim=2), PreprocessData.padding_idx)
        cur_RmD_mask = compute_mask(cur_RmD.mean(dim=2), PreprocessData.padding_idx)
        gated_cur_RnQ=self.compute_gated_value(cur_RnQ,reasoning_question_gating_val)# size=(16,100,258)
        gated_cur_RmD=self.compute_gated_value(cur_RmD,reasoning_content_gating_val[i])# size=(16,200,514)

        # pass through the reasoning layers
        cur_RnQ_reasoning_out, _ = self.question_reasoning_layer.forward(gated_cur_RnQ.transpose(0,1),cur_RnQ_mask)  # size=(max_sequence_len,16,256)
        cur_RmD_reasoning_out, _ = self.content_reasoning_layer.forward(gated_cur_RmD.transpose(0,1),cur_RmD_mask)  # size=(max_sequence_len,16,256)

        # bring all matrices to the same size
        cur_RnQ_reasoning_out = self.full_matrix_to_specify_size(cur_RnQ_reasoning_out, [max_question_len, batch_size, cur_RnQ_reasoning_out.size()[2]])  # size=(100,16,256)
        cur_RmD_reasoning_out = self.full_matrix_to_specify_size(cur_RmD_reasoning_out, [max_content_len, batch_size, cur_RmD_reasoning_out.size()[2]])  # size=(200,16,256)

        # The decision-layer gate that used to follow is disabled; only the
        # transposed reasoning outputs are collected.
        cur_RnQ_reasoning_out=cur_RnQ_reasoning_out.transpose(0,1)  # size(16,100,256)
        cur_RmD_reasoning_out=cur_RmD_reasoning_out.transpose(0,1)  # size(16,200,256)
        RnQ_reasoning_out.append(cur_RnQ_reasoning_out)
        RmD_reasoning_out.append(cur_RmD_reasoning_out)

    # Combine the per-content outputs (10 contents in the author's example)
    # and max-pool across the content axis.
    RnQ_reasoning_out=torch.stack(RnQ_reasoning_out).transpose(0,1)  # size=(16,10,100,256)
    RmD_reasoning_out=torch.stack(RmD_reasoning_out).transpose(0,1)  # size=(16,10,200,256)
    RnQ_reasoning_out_maxpool,_=torch.max(RnQ_reasoning_out,dim=1)  # size=(16,100,256)
    RmD_reasoning_out_maxpool,_=torch.max(RmD_reasoning_out,dim=1)  # size=(16,200,256)

    # Max and mean pooling over the sequence axis (decision gating disabled).
    fc_input_RnQ_maxpool,_=torch.max(RnQ_reasoning_out_maxpool,dim=1)  # size(16,256)
    fc_input_RnQ_meanpool=torch.mean(RnQ_reasoning_out_maxpool,dim=1)  # size(16,256)
    fc_input_RmD_maxpool,_=torch.max(RmD_reasoning_out_maxpool,dim=1)  # size(16,256)
    fc_input_RmD_meanpool=torch.mean(RmD_reasoning_out_maxpool,dim=1)  # size(16,256)
    fc_input=torch.cat([fc_input_RnQ_maxpool,fc_input_RnQ_meanpool,fc_input_RmD_maxpool,fc_input_RmD_meanpool],dim=1)  # size(16,1024)

    # Five consecutive rows are folded into one decision input (the original
    # note calls this a five-way classification problem).
    # NOTE(review): 5 and 5120 are hard-coded; the view only succeeds when
    # batch_size is a multiple of 5 and each row carries 1024 features.
    decision_input=fc_input.view(int(batch_size/5), 5120)  # size=(16,1024*5)
    output = self.decision_layer.forward(decision_input)  # size=(batchsize/5,5)
    # Original note: multiply by -1 when logics is reversed and by 1 when it
    # is forward -- not applied here, `logics` is unused.
    return output
def train(self):
    """Run the training loop from ``self.global_step + 1`` up to
    ``self.config['train_iteration']`` (inclusive).

    Each iteration draws one task batch from ``self.data_loader``, encodes
    support and query samples with ``self.enc_module``, runs
    ``self.gnn_module`` on the node/edge inputs, combines the query edge
    losses of both graphs with the object classification loss, and takes one
    optimizer step.  Every ``test_interval`` steps ``self.test`` is called on
    the ``'val'`` partition, and on the ``'test'`` partition when validation
    accuracy did not decrease.

    All tensors are moved with ``.cuda()``, so a CUDA device is required.

    NOTE(review): the block nesting below was reconstructed from a
    whitespace-collapsed source; confirm the placement of the final two
    ``print`` calls against the original file.
    """
    num_ways_train = self.config['num_ways']
    num_shots_train = self.config['num_shots']
    support_shot_nums = num_shots_train
    query_shot_nums = 1
    meta_batch_size = self.config['num_batch']
    num_layers = self.config['num_layers']
    train_iteration = self.config['train_iteration']
    lr = 1e-3  # NOTE(review): not used below
    test_interval = 150
    predicate_detection = self.config['predicate_detection']  # NOTE(review): not used below
    feature_dim = self.config['feature_dim']
    val_acc = self.val_acc

    num_supports = num_ways_train * num_shots_train
    num_queries = num_ways_train * 1
    num_samples = num_supports + num_queries

    # Edge masks of the main graph: the support-support square is 1 in
    # support_edge_mask; query_edge_mask is its complement.
    support_edge_mask = torch.zeros(meta_batch_size, num_samples, num_samples).cuda()
    support_edge_mask[:, :num_supports, :num_supports] = 1
    query_edge_mask = 1 - support_edge_mask
    evaluation_mask = torch.ones(meta_batch_size, num_samples, num_samples).cuda()

    # Masks of the object graph (twice as many nodes).  The support mask is
    # left all-zero here, so object_query_edge_mask is all ones.
    object_support_edge_mask = torch.zeros(meta_batch_size, 2 * num_samples, 2 * num_samples).cuda()
    object_query_edge_mask = 1 - object_support_edge_mask
    object_evaluation_mask = torch.ones(meta_batch_size, 2 * num_samples, 2 * num_samples).cuda()

    for iter in range(self.global_step + 1, train_iteration + 1):
        self.optimizer.zero_grad()
        self.global_step = iter

        # The iteration number doubles as the sampling seed.
        support_all_input, query_all_input, os_label, [support_label, query_label], [idx_for_class, idx_for_data] = self.data_loader.get_task_batch(
            num_tasks=meta_batch_size, num_ways=num_ways_train, num_shots=num_shots_train, seed=iter)

        # Labels of the object graph: os_label[0..1] for supports and
        # os_label[2..3] for queries, concatenated along axis 1.
        os_support_full_label = torch.cat(
            [os_label[0].view(meta_batch_size, -1).cuda(), os_label[1].view(meta_batch_size, -1).cuda()], 1)
        os_query_full_label = torch.cat(
            [os_label[2].view(meta_batch_size, -1).cuda(), os_label[3].view(meta_batch_size, -1).cuda()], 1)
        os_full_label = torch.cat([os_support_full_label, os_query_full_label], 1)
        os_full_edge = self.label2edge(os_full_label)

        support_label = support_label.cuda()
        query_label = query_label.cuda()
        full_label = torch.cat([support_label, query_label], 1)
        full_edge = self.label2edge(full_label)

        # Initial edges: every edge touching a query node is set to 0.5,
        # then each query's self-edge is set to (1.0, 0.0) on channels 0/1.
        init_edge = full_edge.clone()
        init_edge[:, :, num_supports:, :] = 0.5
        init_edge[:, :, :, num_supports:] = 0.5
        for i in range(num_queries):
            init_edge[:, 0, num_supports + i, num_supports + i] = 1.0
            init_edge[:, 1, num_supports + i, num_supports + i] = 0.0

        # Switch both modules to training mode every iteration.
        self.enc_module.train()
        self.gnn_module.train()

        support_predicate_input = support_all_input[1].view(meta_batch_size * num_ways_train, 1, num_shots_train, -1).cuda()
        query_predicate_input = query_all_input[1].view(meta_batch_size * num_ways_train, 1, 1, -1).cuda()
        support_subject_emb = support_all_input[4].cuda()
        support_object_emb = support_all_input[5].cuda()
        query_subject_emb = query_all_input[4].cuda()
        query_object_emb = query_all_input[5].cuda()

        # Encode supports and queries with the same encoder.
        all_support_full_data, support_mapping_subj_feature, support_mapping_obj_feature = self.enc_module(
            support_all_input[0].cuda(), support_predicate_input, support_all_input[2].cuda(),
            support_all_input[3].cuda(), support_subject_emb, support_object_emb, support_shot_nums)
        all_query_full_data, query_mapping_subj_feature, query_mapping_obj_feature = self.enc_module(
            query_all_input[0].cuda(), query_predicate_input, query_all_input[2].cuda(),
            query_all_input[3].cuda(), query_subject_emb, query_object_emb, query_shot_nums)

        # Node features of the main graph.
        all_support_full_data = all_support_full_data.view(meta_batch_size, num_supports, feature_dim)
        all_query_full_data = all_query_full_data.view(meta_batch_size, num_queries, feature_dim)
        full_data = torch.cat([all_support_full_data, all_query_full_data], 1)

        # Node features of the object graph: subjects then objects, supports
        # before queries.
        support_subject_input = support_mapping_subj_feature.view(meta_batch_size, num_supports, feature_dim).cuda()
        support_object_input = support_mapping_obj_feature.view(meta_batch_size, num_supports, feature_dim).cuda()
        support_full_data = torch.cat([support_subject_input, support_object_input], 1)
        query_subject_input = query_mapping_subj_feature.view(meta_batch_size, num_queries, feature_dim).cuda()
        query_object_input = query_mapping_obj_feature.view(meta_batch_size, num_queries, feature_dim).cuda()
        query_full_data = torch.cat([query_subject_input,
                                     query_object_input], 1)
        feature_full_data = torch.cat([support_full_data, query_full_data], 1)

        full_logit_layers, object_full_logit_layers, object_out = self.gnn_module(node_feat=full_data, edge_feat=init_edge, object_node_feat=feature_full_data, object_edge_feat=os_full_edge)

        # NOTE(review): these four aliases are not used below.
        support_subject_output_label = object_out[0]
        support_object_output_label = object_out[1]
        query_subject_output_label = object_out[2]
        query_object_output_label = object_out[3]

        # Object classification loss: mean of the four criterion terms.
        subject_support_loss = self.criterion(object_out[0], os_label[0].cuda().long())
        object_support_loss = self.criterion(object_out[1], os_label[1].cuda().long())
        subject_query_loss = self.criterion(object_out[2], os_label[2].cuda().long())
        object_query_loss = self.criterion(object_out[3], os_label[3].cuda().long())
        object_class_loss = (
            subject_support_loss + object_support_loss + subject_query_loss + object_query_loss) / 4

        # Per-layer edge losses for the main graph and the object graph.
        full_edge_loss_layers = [self.edge_loss((1 - full_logit_layer[:, 0]), (1 - full_edge[:, 0])) for
                                 full_logit_layer in full_logit_layers]
        object_full_edge_loss_layers = [self.edge_loss((1 - full_logit_layer[:, 0]), (1 - os_full_edge[:, 0])) for
                                        full_logit_layer in object_full_logit_layers]

        # Object graph: masked mean loss over positive and negative query
        # edges, per layer.
        object_pos_query_edge_loss_layers = [torch.sum(
            full_edge_loss_layer * object_query_edge_mask * os_full_edge[:, 0] * object_evaluation_mask) / torch.sum(
            object_query_edge_mask * os_full_edge[:, 0] * object_evaluation_mask) for full_edge_loss_layer in
            object_full_edge_loss_layers]
        # NOTE(review): this denominator omits object_evaluation_mask, unlike
        # the positive case (no numeric effect while the mask is all ones).
        object_neg_query_edge_loss_layers = [torch.sum(full_edge_loss_layer * object_query_edge_mask * (
            1 - os_full_edge[:, 0]) * object_evaluation_mask) / torch.sum(
            object_query_edge_mask * (1 - os_full_edge[:, 0])) for full_edge_loss_layer in
            object_full_edge_loss_layers]
        object_query_edge_loss_layers = [pos_query_edge_loss_layer + neg_query_edge_loss_layer for
                                         (pos_query_edge_loss_layer, neg_query_edge_loss_layer) in
                                         zip(object_pos_query_edge_loss_layers, object_neg_query_edge_loss_layers)]

        # Main graph: same positive/negative split.
        # NOTE(review): here the positive denominator is the one that omits
        # evaluation_mask.
        pos_query_edge_loss_layers = [
            torch.sum(full_edge_loss_layer * query_edge_mask * full_edge[:, 0] * evaluation_mask) / torch.sum(
                query_edge_mask * full_edge[:, 0]) for full_edge_loss_layer in full_edge_loss_layers]
        neg_query_edge_loss_layers = [
            torch.sum(full_edge_loss_layer * query_edge_mask * (1 - full_edge[:, 0]) * evaluation_mask) / torch.sum(
                query_edge_mask * (1 - full_edge[:, 0]) * evaluation_mask) for full_edge_loss_layer in
            full_edge_loss_layers]
        query_edge_loss_layers = [pos_query_edge_loss_layer + neg_query_edge_loss_layer for
                                  (pos_query_edge_loss_layer, neg_query_edge_loss_layer) in
                                  zip(pos_query_edge_loss_layers, neg_query_edge_loss_layers)]

        # Edge accuracy on query edges, per layer.
        full_edge_accr_layers = [self.hit(full_logit_layer, 1 - full_edge[:, 0].long()) for full_logit_layer in
                                 full_logit_layers]
        query_edge_accr_layers = [torch.sum(full_edge_accr_layer * query_edge_mask * evaluation_mask) / torch.sum(
            query_edge_mask * evaluation_mask) for full_edge_accr_layer in full_edge_accr_layers]

        # Node prediction: query-to-support edge values (channel 0) times the
        # one-hot support labels; accuracy is the mean arg-max match with the
        # query labels.
        query_node_pred_layers = [torch.bmm(full_logit_layer[:, 0, num_supports:, :num_supports],
                                            self.one_hot_encode(num_ways_train, support_label.long())) for
                                  full_logit_layer in full_logit_layers]
        query_node_accr_layers = [
            torch.eq(torch.max(query_node_pred_layer, -1)[1], query_label.long()).float().mean() for
            query_node_pred_layer in query_node_pred_layers]
        query_nodes_acc = query_node_accr_layers[-1]

        # Layer-weighted loss: intermediate layers weigh 0.5, the last 1.0.
        total_loss_layers = query_edge_loss_layers
        total_loss = []
        for l in range(num_layers - 1):
            total_loss += [total_loss_layers[l].view(-1) * 0.5]
        total_loss += [total_loss_layers[-1].view(-1) * 1.0]
        total_loss = torch.mean(torch.cat(total_loss, 0))

        # Same weighting for the object graph.
        object_total_loss_layers = object_query_edge_loss_layers
        object_total_loss = []
        for l in range(num_layers - 1):
            object_total_loss += [object_total_loss_layers[l].view(-1) * 0.5]
        object_total_loss += [object_total_loss_layers[-1].view(-1) * 1.0]
        object_total_loss = torch.mean(torch.cat(object_total_loss, 0))

        total_loss = total_loss + object_total_loss + 1 * object_class_loss
        total_loss.backward()
        self.optimizer.step()

        # Progress report for the last layer.
        print('train/edge_loss'+ str( query_edge_loss_layers[-1]))
        print("\t train/edge_accr" + str(query_edge_accr_layers[-1]))
        print("\t train/node_accr" + str(query_nodes_acc)+"\n")

        # Periodic evaluation; the test partition is only evaluated when the
        # validation accuracy is at least the best seen so far.
        if self.global_step % test_interval == 0:
            val_acc = self.test(partition='val')
            is_best = 0  # NOTE(review): set but never read
            if val_acc >= self.val_acc:
                test_acc = self.test(partition='test')
                self.test_acc = test_acc
                self.val_acc = val_acc
                is_best = 1
            print("val/best_accr"+ str(self.val_acc))
            # NOTE(review): this prints self.test_acc under a "val/best_accr"
            # label -- the label looks like a copy/paste slip; confirm.
            print("\t val/best_accr" + str(self.test_acc)+"\n" )
def compute_matching_matrix(self, question_encode, content_encode):
    """Return the batched dot-product matrix between question and content.

    Both inputs are sequence-major (the original notes give
    ``(seq_len, batch, embedding_size)`` after the first transpose is undone);
    the result has shape ``(batch, question_len, content_len)``.
    """
    question_bse = question_encode.transpose(0, 1)  # (batch, question_len, emb)
    content_bes = content_encode.transpose(0, 1).transpose(1, 2)  # (batch, emb, content_len)
    return torch.bmm(question_bse, content_bes)
def compute_cross_document_attention(self, content_m, content_n):
    """Return ``content_m @ content_n^T`` for every batch element.

    The last two axes of ``content_n`` are swapped before the batched product,
    so entry ``[b, i, j]`` is the dot product of row ``i`` of ``content_m[b]``
    with row ``j`` of ``content_n[b]``.
    """
    n_transposed = torch.transpose(content_n, 1, 2)
    return content_m.bmm(n_transposed)
def _mha_packed_in_proj(x, weight, bias, start, end):
    """Project ``x`` with rows ``start:end`` of a packed q/k/v weight and bias."""
    sliced_bias = None if bias is None else bias[start:end]
    return F.linear(x, weight[start:end, :], sliced_bias)


def _mha_append_zero_column(mask):
    """Return ``mask`` with one zero column appended (``None`` passes through)."""
    if mask is None:
        return None
    pad = torch.zeros((mask.size(0), 1), dtype=mask.dtype, device=mask.device)
    return torch.cat([mask, pad], dim=1)


def multi_head_attention_forward(
        query,  # type: Tensor
        key,  # type: Tensor
        value,  # type: Tensor
        embed_dim_to_check,  # type: int
        num_heads,  # type: int
        in_proj_weight,  # type: Tensor
        in_proj_bias,  # type: Tensor
        bias_k,  # type: Optional[Tensor]
        bias_v,  # type: Optional[Tensor]
        add_zero_attn,  # type: bool
        dropout_p,  # type: float
        out_proj_weight,  # type: Tensor
        out_proj_bias,  # type: Tensor
        training=True,  # type: bool
        key_padding_mask=None,  # type: Optional[Tensor]
        need_weights=True,  # type: bool
        attn_mask=None,  # type: Optional[Tensor]
        use_separate_proj_weight=False,  # type: bool
        q_proj_weight=None,  # type: Optional[Tensor]
        k_proj_weight=None,  # type: Optional[Tensor]
        v_proj_weight=None,  # type: Optional[Tensor]
        static_k=None,  # type: Optional[Tensor]
        static_v=None,  # type: Optional[Tensor]
        relative_atten_weights=None,  # type: Optional[Tensor]
        app_relation=True,  # type: bool
):
    # type: (...) -> Tuple[Tensor, Optional[Tensor]]
    r"""Multi-head attention with an optional additive relative-relation term.

    Args:
        query, key, value: map a query and a set of key-value pairs to an
            output. See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at
            dim=0.
        add_zero_attn: add a new batch of zeros to the key and value
            sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key
            will be ignored by the attention. Non-zero / ``True`` entries are
            filled with ``-inf`` before the softmax. Boolean and byte masks
            are both accepted; the mask is converted to ``torch.bool``.
        need_weights: output attn_output_weights.
        attn_mask: mask that prevents attention to certain positions. This is
            an additive mask (i.e. the values will be added to the attention
            layer).
        use_separate_proj_weight: the function accepts the projection weights
            for query, key, and value in different forms. If not ``True``,
            in_proj_weight will be used, which is a combination of
            q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight: separate input
            projection weights (``in_proj_bias`` is still sliced per input).
        static_k, static_v: static key and value used for attention operators.
        relative_atten_weights: used to deal with relative relationship, added
            to the attention weights before softmax.
        app_relation: when true, the scaled dot-product ("appearance") scores
            are computed; when false, only ``relative_atten_weights`` is used
            and it must therefore be given.

    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N
          is the batch size, E is the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length.
        - value: :math:`(S, N, E)`.
        - key_padding_mask: :math:`(N, S)`.
        - attn_mask: :math:`(L, S)`.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`.
        - relative_atten_weights: :math:`(N, num_heads, L, S)`.

        Outputs:
        - attn_output: :math:`(L, N, E)`.
        - attn_output_weights: :math:`(N, L, S)`, averaged over heads, or
          ``None`` when ``need_weights`` is false.
    """
    # Identity is checked first: it is free, whereas torch.equal compares
    # every element (and synchronises the device) on each call.
    kv_same = key is value or torch.equal(key, value)
    qkv_same = kv_same and (query is key or torch.equal(query, key))

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    assert key.size() == value.size()

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, \
        'embed_dim must be divisible by num_heads'
    scaling = float(head_dim)**-0.5

    if use_separate_proj_weight is not True:
        if qkv_same:
            # self-attention: one packed projection, split three ways
            q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(
                3, dim=-1)
        elif kv_same:
            # encoder-decoder attention: key and value share one projection
            q = _mha_packed_in_proj(
                query, in_proj_weight, in_proj_bias, 0, embed_dim)
            k, v = _mha_packed_in_proj(
                key, in_proj_weight, in_proj_bias, embed_dim, None).chunk(
                    2, dim=-1)
        else:
            q = _mha_packed_in_proj(
                query, in_proj_weight, in_proj_bias, 0, embed_dim)
            k = _mha_packed_in_proj(
                key, in_proj_weight, in_proj_bias, embed_dim, embed_dim * 2)
            v = _mha_packed_in_proj(
                value, in_proj_weight, in_proj_bias, embed_dim * 2, None)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = F.linear(query, q_proj_weight_non_opt,
                         in_proj_bias[0:embed_dim])
            k = F.linear(key, k_proj_weight_non_opt,
                         in_proj_bias[embed_dim:(embed_dim * 2)])
            v = F.linear(value, v_proj_weight_non_opt,
                         in_proj_bias[(embed_dim * 2):])
        else:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            # One extra key/value position, so both masks grow by a column.
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            attn_mask = _mha_append_zero_column(attn_mask)
            key_padding_mask = _mha_append_zero_column(key_padding_mask)
        else:
            assert static_k is None, 'bias cannot be added to static key.'
            assert static_v is None, 'bias cannot be added to static value.'
    else:
        assert bias_k is None
        assert bias_v is None

    # Split the embedding into heads: (len, N, E) -> (N * heads, len, head_dim)
    q = q.contiguous().view(tgt_len, bsz * num_heads,
                            head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat(
            [
                k,
                torch.zeros(
                    (k.size(0), 1) + k.size()[2:],
                    dtype=k.dtype,
                    device=k.device),
            ],
            dim=1,
        )
        v = torch.cat(
            [
                v,
                torch.zeros(
                    (v.size(0), 1) + v.size()[2:],
                    dtype=v.dtype,
                    device=v.device),
            ],
            dim=1,
        )
        attn_mask = _mha_append_zero_column(attn_mask)
        key_padding_mask = _mha_append_zero_column(key_padding_mask)

    if app_relation:
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(
            attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
    else:
        attn_output_weights = None

    if relative_atten_weights is not None:
        if app_relation:
            attn_output_weights = attn_output_weights.view(
                bsz, num_heads, tgt_len, src_len)
        else:
            # Starting from the scalar 0 makes the += below build a fresh
            # tensor, so later in-place mask additions never touch the
            # caller's relative_atten_weights.
            attn_output_weights = 0
        attn_output_weights += relative_atten_weights
        attn_output_weights = attn_output_weights.reshape(
            bsz * num_heads, tgt_len, src_len)

    assert attn_output_weights is not None, \
        ('Please either specify relative position relation '
         'or appearance relation.')

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads,
                                                       tgt_len, src_len)
        # masked_fill expects a boolean mask; convert so byte (and other
        # numeric) masks, as described in the docstring, keep working.
        padding = key_padding_mask.to(torch.bool).unsqueeze(1).unsqueeze(2)
        attn_output_weights = attn_output_weights.masked_fill(
            padding, float('-inf'))
        attn_output_weights = attn_output_weights.view(bsz * num_heads,
                                                       tgt_len, src_len)

    attn_output_weights = F.softmax(attn_output_weights, dim=-1)
    attn_output_weights = F.dropout(
        attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    # Merge the heads back: (N * heads, L, head_dim) -> (L, N, E)
    attn_output = attn_output.transpose(0, 1).contiguous().view(
        tgt_len, bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads,
                                                       tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    return attn_output, None
def matrix_power3(self, Input):
    """Return the batched matrix cube ``Input @ Input @ Input``.

    Both products go through ``torch.bmm``, so ``Input`` has to be a 3-D
    tensor holding a batch of square matrices, i.e. ``(batch, n, n)``.
    """
    return torch.bmm(torch.bmm(Input, Input), Input)
def forward(  # type: ignore
    self,
    question: Dict[str, torch.LongTensor],
    passage: Dict[str, torch.LongTensor],
    number_indices: torch.LongTensor,
    answer_as_passage_spans: torch.LongTensor = None,
    answer_as_question_spans: torch.LongTensor = None,
    answer_as_add_sub_expressions: torch.LongTensor = None,
    answer_as_counts: torch.LongTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """Encode a question/passage batch, score every answering ability, and
    optionally compute the loss and decode the best answer per instance.

    The pass runs in four stages:

    1. Embed and encode question and passage, then build passage-question
       attention and an attention-over-attention passage representation.
    2. Run the shared modeling layer four times over the merged passage
       representation and derive one passage vector and one question vector.
    3. For each ability named in ``self.answering_abilities`` ("counting",
       "passage_span_extraction", "question_span_extraction",
       "addition_subtraction") compute log-probabilities and the best
       prediction. When more than one ability is configured, an ability
       classifier's log-probability is added to each best-prediction score.
    4. If any gold answer tensor is given, compute the negative mean marginal
       log-likelihood as ``"loss"``. If ``metadata`` is given, decode one
       answer per instance and update ``self._drop_metrics``.

    Parameters
    ----------
    question, passage :
        Text-field tensor dicts; passed to ``util.get_text_field_mask`` and
        ``self._text_field_embedder``.
    number_indices :
        Passage positions of numbers; entries equal to ``-1`` are treated as
        padding. A trailing singleton dimension is squeezed away.
    answer_as_passage_spans, answer_as_question_spans :
        Gold spans indexed as ``[:, :, 0]`` (start) and ``[:, :, 1]`` (end);
        spans whose start is ``-1`` are treated as padding.
    answer_as_add_sub_expressions :
        Gold sign class per number and per combination; combinations whose
        signs sum to ``0`` are treated as padding.
    answer_as_counts :
        Gold count labels; ``-1`` is treated as padding.
    metadata :
        Per-instance dicts. The keys read here are ``question_tokens``,
        ``passage_tokens``, ``question_id``, ``original_passage``,
        ``passage_token_offsets``, ``original_question``,
        ``question_token_offsets``, ``original_numbers``, ``number_indices``
        and (optionally) ``answer_annotations``.

    Returns
    -------
    Dict with ``"loss"`` when any gold answer tensor is supplied, and with
    ``"question_id"``, ``"answer"``, ``"passage_question_attention"``,
    ``"question_tokens"`` and ``"passage_tokens"`` when ``metadata`` is
    supplied.

    Raises
    ------
    ValueError
        If an entry of ``self.answering_abilities`` is not one of the four
        supported ability names.

    NOTE(review): the loss section indexes every gold tensor that belongs to a
    configured ability, so it presumably expects all of them to be provided
    together even though each parameter defaults to ``None`` -- confirm with
    the dataset reader.
    """
    question_mask = util.get_text_field_mask(question)
    passage_mask = util.get_text_field_mask(passage)
    embedded_question = self._dropout(self._text_field_embedder(question))
    embedded_passage = self._dropout(self._text_field_embedder(passage))
    embedded_question = self._highway_layer(
        self._embedding_proj_layer(embedded_question))
    embedded_passage = self._highway_layer(
        self._embedding_proj_layer(embedded_passage))

    batch_size = embedded_question.size(0)

    # Question and passage share the same projection and phrase-level encoder.
    projected_embedded_question = self._encoding_proj_layer(
        embedded_question)
    projected_embedded_passage = self._encoding_proj_layer(
        embedded_passage)

    encoded_question = self._dropout(
        self._phrase_layer(projected_embedded_question, question_mask))
    encoded_passage = self._dropout(
        self._phrase_layer(projected_embedded_passage, passage_mask))

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(
        encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = masked_softmax(
        passage_question_similarity, question_mask, memory_efficient=True)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(
        encoded_question, passage_question_attention)

    # Shape: (batch_size, question_length, passage_length)
    question_passage_attention = masked_softmax(
        passage_question_similarity.transpose(1, 2),
        passage_mask,
        memory_efficient=True)
    # Shape: (batch_size, passage_length, passage_length)
    passsage_attention_over_attention = torch.bmm(
        passage_question_attention, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_passage_vectors = util.weighted_sum(
        encoded_passage, passsage_attention_over_attention)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    merged_passage_attention_vectors = self._dropout(
        torch.cat(
            [
                encoded_passage,
                passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * passage_passage_vectors,
            ],
            dim=-1,
        ))

    # The recurrent modeling layers. Since these layers share the same parameters,
    # we don't construct them conditioned on answering abilities.
    modeled_passage_list = [
        self._modeling_proj_layer(merged_passage_attention_vectors)
    ]
    for _ in range(4):
        modeled_passage = self._dropout(
            self._modeling_layer(modeled_passage_list[-1], passage_mask))
        modeled_passage_list.append(modeled_passage)
    # Pop the first one, which is input
    modeled_passage_list.pop(0)

    # The first modeling layer is used to calculate the vector representation of passage
    passage_weights = self._passage_weights_predictor(
        modeled_passage_list[0]).squeeze(-1)
    passage_weights = masked_softmax(passage_weights, passage_mask)
    passage_vector = util.weighted_sum(modeled_passage_list[0],
                                       passage_weights)
    # The vector representation of question is calculated based on the unmatched encoding,
    # because we may want to infer the answer ability only based on the question words.
    question_weights = self._question_weights_predictor(
        encoded_question).squeeze(-1)
    question_weights = masked_softmax(question_weights, question_mask)
    question_vector = util.weighted_sum(encoded_question, question_weights)

    # The ability classifier is only evaluated when there is a choice to make.
    if len(self.answering_abilities) > 1:
        # Shape: (batch_size, number_of_abilities)
        answer_ability_logits = self._answer_ability_predictor(
            torch.cat([passage_vector, question_vector], -1))
        answer_ability_log_probs = torch.nn.functional.log_softmax(
            answer_ability_logits, -1)
        best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

    if "counting" in self.answering_abilities:
        # Shape: (batch_size, 10)
        count_number_logits = self._count_number_predictor(passage_vector)
        count_number_log_probs = torch.nn.functional.log_softmax(
            count_number_logits, -1)
        # Info about the best count number prediction
        # Shape: (batch_size,)
        best_count_number = torch.argmax(count_number_log_probs, -1)
        best_count_log_prob = torch.gather(
            count_number_log_probs, 1,
            best_count_number.unsqueeze(-1)).squeeze(-1)
        if len(self.answering_abilities) > 1:
            best_count_log_prob += answer_ability_log_probs[:, self.
                                                            _counting_index]

    if "passage_span_extraction" in self.answering_abilities:
        # Span start uses modeling passes 0 and 1, span end uses passes 0 and 2.
        # Shape: (batch_size, passage_length, modeling_dim * 2))
        passage_for_span_start = torch.cat(
            [modeled_passage_list[0], modeled_passage_list[1]], dim=-1)
        # Shape: (batch_size, passage_length)
        passage_span_start_logits = self._passage_span_start_predictor(
            passage_for_span_start).squeeze(-1)
        # Shape: (batch_size, passage_length, modeling_dim * 2)
        passage_for_span_end = torch.cat(
            [modeled_passage_list[0], modeled_passage_list[2]], dim=-1)
        # Shape: (batch_size, passage_length)
        passage_span_end_logits = self._passage_span_end_predictor(
            passage_for_span_end).squeeze(-1)
        # Shape: (batch_size, passage_length)
        passage_span_start_log_probs = util.masked_log_softmax(
            passage_span_start_logits, passage_mask)
        passage_span_end_log_probs = util.masked_log_softmax(
            passage_span_end_logits, passage_mask)

        # Info about the best passage span prediction
        passage_span_start_logits = replace_masked_values_with_big_negative_number(
            passage_span_start_logits, passage_mask)
        passage_span_end_logits = replace_masked_values_with_big_negative_number(
            passage_span_end_logits, passage_mask)
        # Shape: (batch_size, 2)
        best_passage_span = get_best_span(passage_span_start_logits,
                                          passage_span_end_logits)
        # Shape: (batch_size, 2)
        best_passage_start_log_probs = torch.gather(
            passage_span_start_log_probs, 1,
            best_passage_span[:, 0].unsqueeze(-1)).squeeze(-1)
        best_passage_end_log_probs = torch.gather(
            passage_span_end_log_probs, 1,
            best_passage_span[:, 1].unsqueeze(-1)).squeeze(-1)
        # Shape: (batch_size,)
        best_passage_span_log_prob = best_passage_start_log_probs + best_passage_end_log_probs
        if len(self.answering_abilities) > 1:
            best_passage_span_log_prob += answer_ability_log_probs[:, self.
                                                                   _passage_span_extraction_index]

    if "question_span_extraction" in self.answering_abilities:
        # Every question token is paired with a copy of the passage vector.
        # Shape: (batch_size, question_length)
        encoded_question_for_span_prediction = torch.cat(
            [
                encoded_question,
                passage_vector.unsqueeze(1).repeat(
                    1, encoded_question.size(1), 1),
            ],
            -1,
        )
        question_span_start_logits = self._question_span_start_predictor(
            encoded_question_for_span_prediction).squeeze(-1)
        # Shape: (batch_size, question_length)
        question_span_end_logits = self._question_span_end_predictor(
            encoded_question_for_span_prediction).squeeze(-1)
        question_span_start_log_probs = util.masked_log_softmax(
            question_span_start_logits, question_mask)
        question_span_end_log_probs = util.masked_log_softmax(
            question_span_end_logits, question_mask)

        # Info about the best question span prediction
        question_span_start_logits = replace_masked_values_with_big_negative_number(
            question_span_start_logits, question_mask)
        question_span_end_logits = replace_masked_values_with_big_negative_number(
            question_span_end_logits, question_mask)
        # Shape: (batch_size, 2)
        best_question_span = get_best_span(question_span_start_logits,
                                           question_span_end_logits)
        # Shape: (batch_size, 2)
        best_question_start_log_probs = torch.gather(
            question_span_start_log_probs, 1,
            best_question_span[:, 0].unsqueeze(-1)).squeeze(-1)
        best_question_end_log_probs = torch.gather(
            question_span_end_log_probs, 1,
            best_question_span[:, 1].unsqueeze(-1)).squeeze(-1)
        # Shape: (batch_size,)
        best_question_span_log_prob = (best_question_start_log_probs +
                                       best_question_end_log_probs)
        if len(self.answering_abilities) > 1:
            best_question_span_log_prob += answer_ability_log_probs[:, self.
                                                                    _question_span_extraction_index]

    if "addition_subtraction" in self.answering_abilities:
        # Shape: (batch_size, # of numbers in the passage)
        number_indices = number_indices.squeeze(-1)
        # -1 marks padded number slots; clamp them to 0 so `gather` stays in range.
        number_mask = number_indices != -1
        clamped_number_indices = util.replace_masked_values(
            number_indices, number_mask, 0)
        encoded_passage_for_numbers = torch.cat(
            [modeled_passage_list[0], modeled_passage_list[3]], dim=-1)
        # Shape: (batch_size, # of numbers in the passage, encoding_dim)
        encoded_numbers = torch.gather(
            encoded_passage_for_numbers,
            1,
            clamped_number_indices.unsqueeze(-1).expand(
                -1, -1, encoded_passage_for_numbers.size(-1)),
        )
        # Shape: (batch_size, # of numbers in the passage)
        encoded_numbers = torch.cat(
            [
                encoded_numbers,
                passage_vector.unsqueeze(1).repeat(
                    1, encoded_numbers.size(1), 1),
            ],
            -1,
        )

        # Shape: (batch_size, # of numbers in the passage, 3)
        number_sign_logits = self._number_sign_predictor(encoded_numbers)
        number_sign_log_probs = torch.nn.functional.log_softmax(
            number_sign_logits, -1)

        # Shape: (batch_size, # of numbers in passage).
        best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
        # For padding numbers, the best sign masked as 0 (not included).
        best_signs_for_numbers = util.replace_masked_values(
            best_signs_for_numbers, number_mask, 0)
        # Shape: (batch_size, # of numbers in passage)
        best_signs_log_probs = torch.gather(
            number_sign_log_probs, 2,
            best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
        # the probs of the masked positions should be 1 so that it will not affect the joint probability
        # TODO: this is not quite right, since if there are many numbers in the passage,
        # TODO: the joint probability would be very small.
        best_signs_log_probs = util.replace_masked_values(
            best_signs_log_probs, number_mask, 0)
        # Shape: (batch_size,)
        best_combination_log_prob = best_signs_log_probs.sum(-1)
        if len(self.answering_abilities) > 1:
            best_combination_log_prob += answer_ability_log_probs[:, self.
                                                                  _addition_subtraction_index]

    output_dict = {}

    # If answer is given, compute the loss.
    if (answer_as_passage_spans is not None
            or answer_as_question_spans is not None
            or answer_as_add_sub_expressions is not None
            or answer_as_counts is not None):

        # One (batch_size,) entry per ability, appended in the order of
        # `self.answering_abilities`.
        log_marginal_likelihood_list = []

        for answering_ability in self.answering_abilities:
            if answering_ability == "passage_span_extraction":
                # Shape: (batch_size, # of answer spans)
                gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
                gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
                # Some spans are padded with index -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                gold_passage_span_mask = gold_passage_span_starts != -1
                clamped_gold_passage_span_starts = util.replace_masked_values(
                    gold_passage_span_starts, gold_passage_span_mask, 0)
                clamped_gold_passage_span_ends = util.replace_masked_values(
                    gold_passage_span_ends, gold_passage_span_mask, 0)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_passage_span_starts = torch.gather(
                    passage_span_start_log_probs, 1,
                    clamped_gold_passage_span_starts)
                log_likelihood_for_passage_span_ends = torch.gather(
                    passage_span_end_log_probs, 1,
                    clamped_gold_passage_span_ends)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_passage_spans = (
                    log_likelihood_for_passage_span_starts +
                    log_likelihood_for_passage_span_ends)
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_passage_spans = (
                    replace_masked_values_with_big_negative_number(
                        log_likelihood_for_passage_spans,
                        gold_passage_span_mask,
                    ))
                # Shape: (batch_size, )
                log_marginal_likelihood_for_passage_span = util.logsumexp(
                    log_likelihood_for_passage_spans)
                log_marginal_likelihood_list.append(
                    log_marginal_likelihood_for_passage_span)

            elif answering_ability == "question_span_extraction":
                # Shape: (batch_size, # of answer spans)
                gold_question_span_starts = answer_as_question_spans[:, :, 0]
                gold_question_span_ends = answer_as_question_spans[:, :, 1]
                # Some spans are padded with index -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                gold_question_span_mask = gold_question_span_starts != -1
                clamped_gold_question_span_starts = util.replace_masked_values(
                    gold_question_span_starts, gold_question_span_mask, 0)
                clamped_gold_question_span_ends = util.replace_masked_values(
                    gold_question_span_ends, gold_question_span_mask, 0)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_question_span_starts = torch.gather(
                    question_span_start_log_probs, 1,
                    clamped_gold_question_span_starts)
                log_likelihood_for_question_span_ends = torch.gather(
                    question_span_end_log_probs, 1,
                    clamped_gold_question_span_ends)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_question_spans = (
                    log_likelihood_for_question_span_starts +
                    log_likelihood_for_question_span_ends)
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_question_spans = (
                    replace_masked_values_with_big_negative_number(
                        log_likelihood_for_question_spans,
                        gold_question_span_mask,
                    ))
                # Shape: (batch_size, )
                log_marginal_likelihood_for_question_span = util.logsumexp(
                    log_likelihood_for_question_spans)
                log_marginal_likelihood_list.append(
                    log_marginal_likelihood_for_question_span)

            elif answering_ability == "addition_subtraction":
                # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                # Shape: (batch_size, # of combinations)
                gold_add_sub_mask = answer_as_add_sub_expressions.sum(
                    -1) > 0
                # Shape: (batch_size, # of numbers in the passage, # of combinations)
                gold_add_sub_signs = answer_as_add_sub_expressions.transpose(
                    1, 2)
                # Shape: (batch_size, # of numbers in the passage, # of combinations)
                log_likelihood_for_number_signs = torch.gather(
                    number_sign_log_probs, 2, gold_add_sub_signs)
                # the log likelihood of the masked positions should be 0
                # so that it will not affect the joint probability
                log_likelihood_for_number_signs = util.replace_masked_values(
                    log_likelihood_for_number_signs,
                    number_mask.unsqueeze(-1), 0)
                # Shape: (batch_size, # of combinations)
                log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(
                    1)
                # For those padded combinations, we set their log probabilities to be very small negative value
                log_likelihood_for_add_subs = replace_masked_values_with_big_negative_number(
                    log_likelihood_for_add_subs, gold_add_sub_mask)
                # Shape: (batch_size, )
                log_marginal_likelihood_for_add_sub = util.logsumexp(
                    log_likelihood_for_add_subs)
                log_marginal_likelihood_list.append(
                    log_marginal_likelihood_for_add_sub)

            elif answering_ability == "counting":
                # Count answers are padded with label -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                # Shape: (batch_size, # of count answers)
                gold_count_mask = answer_as_counts != -1
                # Shape: (batch_size, # of count answers)
                clamped_gold_counts = util.replace_masked_values(
                    answer_as_counts, gold_count_mask, 0)
                log_likelihood_for_counts = torch.gather(
                    count_number_log_probs, 1, clamped_gold_counts)
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_counts = replace_masked_values_with_big_negative_number(
                    log_likelihood_for_counts, gold_count_mask)
                # Shape: (batch_size, )
                log_marginal_likelihood_for_count = util.logsumexp(
                    log_likelihood_for_counts)
                log_marginal_likelihood_list.append(
                    log_marginal_likelihood_for_count)

            else:
                raise ValueError(
                    f"Unsupported answering ability: {answering_ability}")

        if len(self.answering_abilities) > 1:
            # Add the ability probabilities if there are more than one abilities
            all_log_marginal_likelihoods = torch.stack(
                log_marginal_likelihood_list, dim=-1)
            all_log_marginal_likelihoods = (all_log_marginal_likelihoods +
                                            answer_ability_log_probs)
            marginal_log_likelihood = util.logsumexp(
                all_log_marginal_likelihoods)
        else:
            marginal_log_likelihood = log_marginal_likelihood_list[0]

        output_dict["loss"] = -marginal_log_likelihood.mean()

    # Compute the metrics and add the tokenized input to the output.
    if metadata is not None:
        output_dict["question_id"] = []
        output_dict["answer"] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]["question_tokens"])
            passage_tokens.append(metadata[i]["passage_tokens"])

            # With a single configured ability there is no classifier output to consult.
            if len(self.answering_abilities) > 1:
                predicted_ability_str = self.answering_abilities[
                    best_answer_ability[i].detach().cpu().numpy()]
            else:
                predicted_ability_str = self.answering_abilities[0]

            answer_json: Dict[str, Any] = {}

            # We did not consider multi-mention answers here
            if predicted_ability_str == "passage_span_extraction":
                answer_json["answer_type"] = "passage_span"
                passage_str = metadata[i]["original_passage"]
                offsets = metadata[i]["passage_token_offsets"]
                predicted_span = tuple(
                    best_passage_span[i].detach().cpu().numpy())
                # Token indices are mapped to offsets into the original text.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                predicted_answer = passage_str[start_offset:end_offset]
                answer_json["value"] = predicted_answer
                answer_json["spans"] = [(start_offset, end_offset)]
            elif predicted_ability_str == "question_span_extraction":
                answer_json["answer_type"] = "question_span"
                question_str = metadata[i]["original_question"]
                offsets = metadata[i]["question_token_offsets"]
                predicted_span = tuple(
                    best_question_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                predicted_answer = question_str[start_offset:end_offset]
                answer_json["value"] = predicted_answer
                answer_json["spans"] = [(start_offset, end_offset)]
            elif (predicted_ability_str == "addition_subtraction"
                  ):  # plus_minus combination answer
                answer_json["answer_type"] = "arithmetic"
                original_numbers = metadata[i]["original_numbers"]
                # Sign class -> multiplier: class 0 drops the number, 1 adds it, 2 subtracts it.
                sign_remap = {0: 0, 1: 1, 2: -1}
                predicted_signs = [
                    sign_remap[it] for it in
                    best_signs_for_numbers[i].detach().cpu().numpy()
                ]
                result = sum([
                    sign * number for sign, number in zip(
                        predicted_signs, original_numbers)
                ])
                predicted_answer = str(result)
                offsets = metadata[i]["passage_token_offsets"]
                # NOTE(review): this rebinds the `number_indices` parameter to the
                # per-instance list from metadata; the tensor is not read again
                # after this point.
                number_indices = metadata[i]["number_indices"]
                number_positions = [
                    offsets[index] for index in number_indices
                ]
                answer_json["numbers"] = []
                for offset, value, sign in zip(number_positions,
                                               original_numbers,
                                               predicted_signs):
                    answer_json["numbers"].append({
                        "span": offset,
                        "value": value,
                        "sign": sign
                    })
                if number_indices[-1] == -1:
                    # There is a dummy 0 number at position -1 added in some cases; we are
                    # removing that here.
                    answer_json["numbers"].pop()
                answer_json["value"] = result
            elif predicted_ability_str == "counting":
                answer_json["answer_type"] = "count"
                predicted_count = best_count_number[i].detach().cpu(
                ).numpy()
                predicted_answer = str(predicted_count)
                answer_json["count"] = predicted_count
            else:
                raise ValueError(
                    f"Unsupported answer ability: {predicted_ability_str}")

            output_dict["question_id"].append(metadata[i]["question_id"])
            output_dict["answer"].append(answer_json)
            # Metrics are only updated for instances that carry gold annotations.
            answer_annotations = metadata[i].get("answer_annotations", [])
            if answer_annotations:
                self._drop_metrics(predicted_answer, answer_annotations)
        # This is used for the demo.
        output_dict[
            "passage_question_attention"] = passage_question_attention
        output_dict["question_tokens"] = question_tokens
        output_dict["passage_tokens"] = passage_tokens
    return output_dict
def compute_RnQ(self, an_matrix, content_encode):
    """Batch-multiply ``an_matrix`` with the batch-first view of ``content_encode``.

    The first two axes of ``content_encode`` are swapped before the product;
    per the original annotation the swapped tensor is laid out as
    (batch, content_len, embedding_size). The result is
    ``torch.bmm(an_matrix, swapped)``.
    """
    return torch.bmm(an_matrix, torch.transpose(content_encode, 0, 1))
processed['attention_mask']).to(device) token_type_ids = torch.tensor( processed['token_type_ids']).to(device) output = model( input_ids=input_ids.view(M, max_seq_length), attention_mask=attention_mask.view(M, max_seq_length), token_type_ids=token_type_ids.view(M, max_seq_length)) start_logits = output.start_logits end_logits = output.end_logits start_prob = softmax(start_logits) end_prob = softmax(end_logits) span_prob = torch.bmm(start_prob.view(M, -1, 1), end_prob.view(M, 1, -1)) span_prob = torch.triu(span_prob) # mask to limit the length of the span mask = torch.ones_like(span_prob) mask = torch.triu(mask, diagonal=max_answer_length) span_prob[mask == 1] = 0 # mask out the question span_prob[token_type_ids == 0] = 0 best_span = torch.argmax(span_prob.view(M, -1), dim=1) offset = processed['offset_mapping'] new_prediction = [ processed['backs'][idx]
def compute_RnD(self, bn_matrix, question_encode):
    """Batch-multiply the transposed ``bn_matrix`` with the batch-first question encoding.

    ``bn_matrix`` has its last two axes swapped (annotated by the original
    author as size=(batch, content_len, question_len) afterwards) and
    ``question_encode`` has its first two axes swapped (annotated as
    (batch, question_len, embedding_size)). The two are then combined with
    ``torch.bmm``.
    """
    weights_by_content = torch.transpose(bn_matrix, 1, 2)
    question_batch_first = torch.transpose(question_encode, 0, 1)
    return torch.bmm(weights_by_content, question_batch_first)
def forward(self, u_enc_out, z_tm1, last_hidden, u_input_np, pv_z_enc_out,
            prev_z_input_np, u_emb, pv_z_emb):
    """Run one decoder step that mixes generation with copy scores.

    The step attends over the user encoding (concatenated with the previous
    ``z`` encoding along dim 0 when ``pv_z_enc_out`` is given), feeds the
    embedded previous token plus the attention context through the GRU, and
    normalises generation and copy scores with a single softmax. Copy scores
    are scattered through the sparse matrices built by
    ``get_sparse_input_aug`` in a max-shifted (numerically stable) log-space
    form.

    ``u_emb`` and ``pv_z_emb`` are accepted but not read here.

    Returns ``(gru_out, last_hidden, proba)``. ``proba`` holds the merged
    probabilities for the first ``cfg.vocab_size`` columns followed by the
    copy-only columns: previous-``z`` columns first (when present), then the
    user-input columns.
    """
    sparse_u_input = Variable(get_sparse_input_aug(u_input_np),
                              requires_grad=False)

    if pv_z_enc_out is None:
        attn_memory = u_enc_out
    else:
        attn_memory = torch.cat([pv_z_enc_out, u_enc_out], dim=0)
    context = self.attn_u(last_hidden, attn_memory)

    gru_in = torch.cat([self.emb(z_tm1), context], 2)
    gru_out, last_hidden = self.gru(gru_in, last_hidden)
    gen_score = self.proj(torch.cat([gru_out, context], 2)).squeeze(0)

    def copy_log_scores(projection, enc_out, sparse_input):
        # stable version of copynet: shift by the row maximum before exp,
        # scatter through the sparse matrix on CPU, then undo the shift.
        attn = F.tanh(projection(enc_out.transpose(0, 1)))  # [B,T,H]
        attn = torch.matmul(attn,
                            gru_out.squeeze(0).unsqueeze(2)).squeeze(2)
        attn = attn.cpu()
        peak = torch.max(attn, dim=1, keepdim=True)[0]
        shifted = torch.exp(attn - peak)  # [B,T]
        scattered = torch.bmm(shifted.unsqueeze(1), sparse_input)
        return cuda_(torch.log(scattered).squeeze(1) + peak)  # [B,V]

    u_copy_score = copy_log_scores(self.proj_copy1, u_enc_out,
                                   sparse_u_input)

    if pv_z_enc_out is None:
        scores = F.softmax(torch.cat([gen_score, u_copy_score], dim=1),
                           dim=1)
        gen_part = scores[:, :cfg.vocab_size]
        u_part = scores[:, cfg.vocab_size:]
        proba = gen_part + u_part[:, :cfg.vocab_size]  # [B,V]
        proba = torch.cat([proba, u_part[:, cfg.vocab_size:]], 1)
        return gru_out, last_hidden, proba

    pv_z_copy_score = copy_log_scores(
        self.proj_copy2,
        pv_z_enc_out,
        Variable(get_sparse_input_aug(prev_z_input_np), requires_grad=False),
    )
    scores = F.softmax(
        torch.cat([gen_score, u_copy_score, pv_z_copy_score], dim=1), dim=1)
    # Column boundary between the user-input block and the previous-z block.
    u_end = 2 * cfg.vocab_size + u_input_np.shape[0]
    gen_part = scores[:, :cfg.vocab_size]
    u_part = scores[:, cfg.vocab_size:u_end]
    pv_z_part = scores[:, u_end:]
    proba = (gen_part + u_part[:, :cfg.vocab_size] +
             pv_z_part[:, :cfg.vocab_size])  # [B,V]
    proba = torch.cat([
        proba, pv_z_part[:, cfg.vocab_size:], u_part[:, cfg.vocab_size:]
    ], 1)
    return gru_out, last_hidden, proba
def step(self, Ybar_t: torch.Tensor,
         dec_state: Tuple[torch.Tensor, torch.Tensor],
         enc_hiddens: torch.Tensor,
         enc_hiddens_proj: torch.Tensor,
         enc_masks: torch.Tensor) -> Tuple[Tuple, torch.Tensor, torch.Tensor]:
    """ Compute one forward step of the LSTM decoder, including the attention computation.

    @param Ybar_t (Tensor): Concatenated Tensor of [Y_t o_prev], with shape (b, e + h). The input for the decoder,
                            where b = batch size, e = embedding size, h = hidden size.
    @param dec_state (tuple(Tensor, Tensor)): Tuple of tensors both with shape (b, h), where b = batch size, h = hidden size.
            First tensor is decoder's prev hidden state, second tensor is decoder's prev cell.
    @param enc_hiddens (Tensor): Encoder hidden states Tensor, with shape (b, src_len, h * 2), where b = batch size,
                                src_len = maximum source length, h = hidden size.
    @param enc_hiddens_proj (Tensor): Encoder hidden states Tensor, projected from (h * 2) to h. Tensor is with shape (b, src_len, h),
                                where b = batch size, src_len = maximum source length, h = hidden size.
    @param enc_masks (Tensor): Tensor of sentence masks shape (b, src_len),
                                where b = batch size, src_len is maximum source length.
                                Non-zero entries mark positions that must not receive attention.
                                May be None, in which case nothing is masked.

    @returns dec_state (tuple (Tensor, Tensor)): Tuple of tensors both shape (b, h), where b = batch size, h = hidden size.
            First tensor is decoder's new hidden state, second tensor is decoder's new cell.
    @returns combined_output (Tensor): Combined output Tensor at timestep t, shape (b, h), where b = batch size, h = hidden size.
    @returns e_t (Tensor): Tensor of shape (b, src_len). It is the attention scores (pre-softmax),
                            with masked positions set to -inf. Returned only so that the
                            implementation can be sanity checked.
    """
    # Advance the decoder cell by one step.
    dec_state = self.decoder(Ybar_t, dec_state)
    dec_hidden, _dec_cell = dec_state

    # Attention scores: (b, src_len, h) x (b, h, 1) -> (b, src_len, 1) -> (b, src_len).
    # The squeeze dimension is given explicitly so that a batch of size 1 is not
    # collapsed by accident.
    e_t = torch.bmm(enc_hiddens_proj, dec_hidden.unsqueeze(2)).squeeze(2)

    # Set e_t to -inf where enc_masks is non-zero. The mask is cast to bool
    # because uint8 masks are deprecated for masked_fill_ and newer PyTorch
    # releases only accept boolean masks.
    if enc_masks is not None:
        e_t.data.masked_fill_(enc_masks.bool(), -float('inf'))

    # Attention distribution and the attention output a_t:
    # (b, 1, src_len) x (b, src_len, 2h) -> (b, 1, 2h) -> (b, 2h).
    alpha_t = F.softmax(e_t, dim=1).unsqueeze(1)
    a_t = torch.bmm(alpha_t, enc_hiddens).squeeze(1)

    # Combine decoder state and attention output, then project, squash and drop out.
    u_t = torch.cat((dec_hidden, a_t), dim=1)
    v_t = self.combined_output_projection(u_t)
    combined_output = self.dropout(torch.tanh(v_t))

    return dec_state, combined_output, e_t
def forward(self, x):
    """Extract a global feature vector (and optionally a per-point mask) from ``x``.

    Pipeline, in order:

    * optional input transform (``use_point_stn``): points are regrouped to
      3 channels, multiplied by the matrix predicted by ``stn1``, then
      regrouped to ``3 * point_tuple`` channels;
    * two conv/bn/relu layers, followed by an optional feature transform
      (``use_feat_stn``) using the matrix predicted by ``stn2``;
    * conv/bn layers 1-3 (plus layer 4 when ``num_scales > 1``);
    * a symmetric pooling over points (``sym_op`` is ``'max'`` or ``'sum'``),
      done once for a single scale or per slice of ``num_points`` points for
      several scales;
    * optional mask head (``use_mask``) that tiles the pooled feature over the
      points and concatenates it with the layer-1 point features.

    Returns ``(x, trans, trans2, mask)`` where ``x`` is reshaped to
    ``(-1, 1024 * num_scales**2)`` and each of ``trans``, ``trans2``,
    ``mask`` is ``None`` when the corresponding option is disabled.

    Raises ``ValueError`` for any ``sym_op`` other than ``'max'``/``'sum'``.
    """
    # input transform
    trans = None
    if self.use_point_stn:
        # from tuples to list of single points
        x = x.view(x.size(0), 3, -1)
        trans = self.stn1(x)
        x = torch.bmm(x.transpose(2, 1), trans).transpose(2, 1)
        x = x.contiguous().view(x.size(0), 3 * self.point_tuple, -1)

    # mlp (64,64)
    x = F.relu(self.bn0a(self.conv0a(x)))
    x = F.relu(self.bn0b(self.conv0b(x)))

    # feature transform
    trans2 = None
    if self.use_feat_stn:
        trans2 = self.stn2(x)
        x = torch.bmm(x.transpose(2, 1), trans2).transpose(2, 1)

    # mlp (64,128,1024)
    x = F.relu(self.bn1(self.conv1(x)))
    if self.use_mask:
        # kept for the mask head below
        pointfeat, n_pts = x, x.size(2)
    x = F.relu(self.bn2(self.conv2(x)))
    x = self.bn3(self.conv3(x))

    # mlp (1024,1024*num_scales)
    if self.num_scales > 1:
        x = self.bn4(self.conv4(F.relu(x)))

    # Not returned by this variant; None lets the intermediate result be
    # forgotten when it is not needed.
    pointfvals = x if self.get_pointfvals else None  # noqa: F841

    # symmetric operation over all points
    if self.sym_op == 'max':
        def pool(points):
            return self.mp1(points)
    elif self.sym_op == 'sum':
        def pool(points):
            return torch.sum(points, 2, keepdim=True)
    else:
        raise ValueError('Unsupported symmetric operation: %s' %
                         (self.sym_op))

    if self.num_scales == 1:
        x = pool(x)
    else:
        pooled = x.new_empty(x.size(0), 1024 * self.num_scales**2, 1)
        for s in range(self.num_scales):
            channels = slice(s * self.num_scales * 1024,
                             (s + 1) * self.num_scales * 1024)
            points = slice(s * self.num_points, (s + 1) * self.num_points)
            pooled[:, channels, :] = pool(x[:, :, points])
        x = pooled

    x = x.contiguous().view(-1, 1024 * self.num_scales**2)

    mask = None
    if self.use_mask:
        tiled = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
        fused = torch.cat([tiled, pointfeat], 1)
        fused = F.relu(self.bna(self.conva(fused)))
        fused = F.relu(self.bnc(self.convc(fused)))
        mask = self.convd(fused)

    return x, trans, trans2, mask
def transpose(self, x, A1):
    """Apply the transposed lower-triangular part of ``A1`` to ``x``.

    Entries of ``A1`` above the main diagonal are zeroed, the last two axes
    are swapped, and the result is batch-multiplied with ``x``.
    """
    lower_transposed = A1.tril().transpose(-2, -1)
    return lower_transposed.bmm(x)
def read_vectors(self, memory, read_weights):
    """Return the batched product ``read_weights @ memory``.

    Each row of ``read_weights`` yields one read vector: the combination of
    the rows of ``memory`` weighted by that row. Both arguments must be 3-D,
    as required by ``bmm``.
    """
    return read_weights.bmm(memory)