Example No. 1
    def forward(self, inputs):
        x, u = inputs
        x = self.bn0(x)
        x = F.tanh(self.linear1(x))
        x = F.tanh(self.linear2(x))

        V = self.V(x)
        mu = F.tanh(self.mu(x))

        Q = None
        if u is not None:
            num_outputs = mu.size(1)
            L = self.L(x).view(-1, num_outputs, num_outputs)
            L = L * self.tril_mask.expand_as(L) + torch.exp(L) * self.diag_mask.expand_as(L)
            P = torch.bmm(L, L.transpose(2, 1))

            u_mu = (u - mu).unsqueeze(2)
            A = -0.5 * torch.bmm(torch.bmm(u_mu.transpose(2, 1), P), u_mu)[:, :, 0]

            Q = A + V

        return mu, Q, V
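
This NAF-style forward constrains L to be lower triangular with an exponentiated (hence positive) diagonal, so that P = L·Lᵀ is positive definite and the quadratic advantage term is non-positive. Below is a minimal standalone sketch of that construction, assuming a 3-dimensional action space; the mask tensors and the random stand-ins for self.L(x) and (u - mu) are illustrative, not taken from the class above.

import torch

num_outputs, bs = 3, 5

# strictly-lower-triangular mask and diagonal mask, as used by tril_mask / diag_mask above
tril_mask = torch.tril(torch.ones(num_outputs, num_outputs), diagonal=-1).unsqueeze(0)
diag_mask = torch.eye(num_outputs).unsqueeze(0)

raw_L = torch.randn(bs, num_outputs, num_outputs)        # stand-in for self.L(x)
L = raw_L * tril_mask + torch.exp(raw_L) * diag_mask     # positive diagonal
P = torch.bmm(L, L.transpose(2, 1))                      # (bs, n, n), positive definite

u_mu = torch.randn(bs, num_outputs).unsqueeze(2)         # stand-in for (u - mu)
A = -0.5 * torch.bmm(torch.bmm(u_mu.transpose(2, 1), P), u_mu)[:, :, 0]
print(A.shape, (A <= 0).all().item())                    # torch.Size([5, 1]) True
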
 def predict(self, x_de, x_en):
     bs = x_de.size(0)
     emb_de = self.embedding_de(x_de) # bs,n_de,word_dim
     emb_en = self.embedding_en(x_en) # bs,n_en,word_dim
     h_enc = Variable(torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda())
     c_enc = Variable(torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda())
     h_dec = Variable(torch.zeros(self.n_layers, bs, self.hidden_dim).cuda())
     c_dec = Variable(torch.zeros(self.n_layers, bs, self.hidden_dim).cuda())
     enc_h, _ = self.encoder(emb_de, (h_enc, c_enc)) # (bs,n_de,hiddensz*2)
     dec_h, _ = self.decoder(emb_en, (h_dec, c_dec)) # (bs,n_en,hiddensz)
     # all the same. enc_h is bs,n_de,hiddensz*n_directions. h and c are both n_layers*n_directions,bs,hiddensz
     if self.directions == 2:
         scores = torch.bmm(self.dim_reduce(enc_h), dec_h.transpose(1,2))
     else:
         scores = torch.bmm(enc_h, dec_h.transpose(1,2))
     # (bs,n_de,hiddensz) * (bs,hiddensz,n_en) = (bs,n_de,n_en)
     scores[(x_de == pad_token).unsqueeze(2).expand(scores.size())] = -math.inf # binary mask
     attn_dist = F.softmax(scores,dim=1) # bs,n_de,n_en
     context = torch.bmm(attn_dist.transpose(2,1),enc_h)
     # (bs,n_en,n_de) * (bs,n_de,hiddensz*ndirections) = (bs,n_en,hiddensz*ndirections)
     pred = self.vocab_layer(torch.cat([dec_h,context],2)) # bs,n_en,len(EN.vocab)
     pred = pred[:,:-1,:] # alignment
     _, tokens = pred.max(2) # bs,n_en-1
     sauce = Variable(torch.cuda.LongTensor([[sos_token]]*bs)) # bs,1 column of <s> tokens
     return torch.cat([sauce,tokens],1), attn_dist
Example No. 3
    def forward(self, feat, right, wrong, batch_wrong, fake=None, fake_diff_mask=None):

        num_wrong = wrong.size(1)
        batch_size = feat.size(0)

        feat = feat.view(-1, self.ninp, 1)
        right_dis = torch.bmm(right.view(-1, 1, self.ninp), feat)
        wrong_dis = torch.bmm(wrong, feat)
        batch_wrong_dis = torch.bmm(batch_wrong, feat)

        wrong_score = torch.sum(torch.exp(wrong_dis - right_dis.expand_as(wrong_dis)),1) \
                + torch.sum(torch.exp(batch_wrong_dis - right_dis.expand_as(batch_wrong_dis)),1)

        loss_dis = torch.sum(torch.log(wrong_score + 1))
        loss_norm = right.norm() + feat.norm() + wrong.norm() + batch_wrong.norm()

        if fake is not None:
            fake_dis = torch.bmm(fake.view(-1, 1, self.ninp), feat)
            fake_score = torch.masked_select(torch.exp(fake_dis - right_dis), fake_diff_mask)

            margin_score = F.relu(torch.log(fake_score + 1) - self.margin)
            loss_fake = torch.sum(margin_score)
            loss_dis += loss_fake
            loss_norm += fake.norm()

        loss = (loss_dis + 0.1 * loss_norm) / batch_size
        if fake is not None:
            return loss, loss_fake.data[0] / batch_size
        else:
            return loss
Example No. 4
    def forward(self, vocab):
        with torch.no_grad():
            batch_shape = vocab['sentence'].shape
            s_embedding = self.embedding(vocab['sentence'].cuda())
            a_embedding = self.embedding(vocab['aspect'].cuda())

            packed_s = pack_padded_sequence(s_embedding, vocab['sent_len'], batch_first=True)

        out_s, (h_s, c1) = self.lstm_s(packed_s) # packed output
        out_a, (h_a, c2) = self.lstm_a(a_embedding)

        with torch.no_grad():
            unpacked_out_s, _ = pad_packed_sequence(out_s, batch_first=True)

        # Pair-wise interaction matrix
        I_matrix = torch.bmm(unpacked_out_s, out_a.permute(0,2,1))

        # Column-wise softmax
        a2s_attn = F.softmax(I_matrix, dim=1)

        # Row-wise softmax => Column-wise average => aspect attention
        s2a_attn = F.softmax(I_matrix, dim=2)
        a_attn = torch.mean(s2a_attn, dim=1)

        # Final sentence attn => weighted sum of each individual a2s_attn
        s_attn = torch.bmm(a2s_attn, a_attn.unsqueeze(-1))

        final_rep = torch.bmm(unpacked_out_s.permute(0,2,1), s_attn).squeeze(-1)
        pred = self.fc(final_rep)
        return pred
Example No. 5
 def predict(self, x, attn_type = "hard"):
     #predict with greedy decoding
     emb = self.embedding(x)
     h = Variable(torch.zeros(1, x.size(0), self.hidden_dim))
     c = Variable(torch.zeros(1, x.size(0), self.hidden_dim))
     enc_h, _ = self.encoder(emb, (h, c))
     y = [Variable(torch.zeros(x.size(0)).long())]
     self.attn = []        
     for t in range(x.size(1)):
         emb_t = self.embedding(y[-1])
         dec_h, (h, c) = self.decoder(emb_t.unsqueeze(1), (h, c))
         scores = torch.bmm(enc_h, dec_h.transpose(1,2)).squeeze(2)
         attn_dist = F.softmax(scores, dim = 1)
         self.attn.append(attn_dist.data)
         if attn_type == "hard":
             _, argmax = attn_dist.max(1)
             one_hot = Variable(torch.zeros_like(attn_dist.data).scatter_(-1, argmax.data.unsqueeze(1), 1))
             context = torch.bmm(one_hot.unsqueeze(1), enc_h).squeeze(1)                    
         else:                
             context = torch.bmm(attn_dist.unsqueeze(1), enc_h).squeeze(1)
         pred = self.vocab_layer(torch.cat([dec_h.squeeze(1), context], 1))
         _, next_token = pred.max(1)
         y.append(next_token)
     self.attn = torch.stack(self.attn, 0).transpose(0, 1)
     return torch.stack(y, 0).transpose(0, 1)
Example No. 6
    def forward(self, ht, hs, mask, weighted_ctx=True):
        '''
        ht: batch x ht_dim
        hs: (seq_len x batch x hs_dim, seq_len x batch x ht_dim)
        mask: seq_len x batch
        '''
        hs, hs_ = hs
        # seq_len, batch, _ = hs.size()
        hs = hs.transpose(0, 1)
        hs_ = hs_.transpose(0, 1)
        # hs: batch x seq_len x hs_dim
        # hs_: batch x seq_len x ht_dim
        # hs_ = self.hs2ht(hs)
        # Alignment/Attention Function
        # batch x ht_dim x 1
        ht = ht.unsqueeze(2)
        # batch x seq_len
        score = torch.bmm(hs_, ht).squeeze(2)
        # attn = F.softmax(score, dim=-1)
        attn = F.softmax(score, dim=-1) * mask.transpose(0, 1) + EPSILON
        attn = attn / attn.sum(-1, keepdim=True)

        # Compute weighted sum of hs by attention.
        # batch x 1 x seq_len
        attn = attn.unsqueeze(1)
        if weighted_ctx:
            # batch x hs_dim
            weight_hs = torch.bmm(attn, hs).squeeze(1)
        else:
            weight_hs = None

        return weight_hs, attn
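
The masking here avoids masked_fill: the softmax output is multiplied by the transposed mask, EPSILON keeps every row strictly positive, and the row is renormalized to sum to one. A small self-contained sketch of that renormalization trick; the EPSILON value and the shapes are illustrative.

import torch
import torch.nn.functional as F

EPSILON = 1e-8
score = torch.randn(2, 5)                       # batch x seq_len
mask = torch.tensor([[1., 1., 1., 0., 0.],
                     [1., 1., 1., 1., 1.]])     # 0 marks padded positions

attn = F.softmax(score, dim=-1) * mask + EPSILON
attn = attn / attn.sum(-1, keepdim=True)        # rows sum to 1; padded weights are ~0
print(attn.sum(-1))                             # tensor([1.0000, 1.0000])
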
Example No. 7
    def forward(self, agent_qs, states):
        """Forward pass for the mixer.

        Arguments:
            agent_qs: Tensor of shape [B, T, n_agents] (per-agent chosen-action Q-values)
            states: Tensor of shape [B, T, state_dim]
        """
        bs = agent_qs.size(0)
        states = states.reshape(-1, self.state_dim)
        agent_qs = agent_qs.view(-1, 1, self.n_agents)
        # First layer
        w1 = th.abs(self.hyper_w_1(states))
        b1 = self.hyper_b_1(states)
        w1 = w1.view(-1, self.n_agents, self.embed_dim)
        b1 = b1.view(-1, 1, self.embed_dim)
        hidden = F.elu(th.bmm(agent_qs, w1) + b1)
        # Second layer
        w_final = th.abs(self.hyper_w_final(states))
        w_final = w_final.view(-1, self.embed_dim, 1)
        # State-dependent bias
        v = self.V(states).view(-1, 1, 1)
        # Compute final output
        y = th.bmm(hidden, w_final) + v
        # Reshape and return
        q_tot = y.view(bs, -1, 1)
        return q_tot
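
The mixer is essentially two batched matrix products: per (batch*time) step the agent Q-values form a (1, n_agents) row vector that is multiplied by hypernetwork-generated weights. A shape-only sketch with random stand-ins for the hypernetwork outputs (hyper_w_1, hyper_b_1, hyper_w_final, V); all sizes are illustrative.

import torch
import torch.nn.functional as F

B, T, n_agents, embed_dim = 4, 10, 3, 32
agent_qs = torch.randn(B, T, n_agents).view(-1, 1, n_agents)   # (B*T, 1, n_agents)

w1 = torch.abs(torch.randn(B * T, n_agents, embed_dim))        # stand-in for hyper_w_1(states)
b1 = torch.randn(B * T, 1, embed_dim)                          # stand-in for hyper_b_1(states)
hidden = F.elu(torch.bmm(agent_qs, w1) + b1)                   # (B*T, 1, embed_dim)

w_final = torch.abs(torch.randn(B * T, embed_dim, 1))          # stand-in for hyper_w_final(states)
v = torch.randn(B * T, 1, 1)                                   # state-dependent bias
q_tot = (torch.bmm(hidden, w_final) + v).view(B, -1, 1)
print(q_tot.shape)                                             # torch.Size([4, 10, 1])
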
    def forward(self, output, context):
        batch_size = output.size(0)
        hidden_size = output.size(2)
        input_size = context.size(1)

        # (batch, out_len, dim) * (batch, in_len, dim) -> (batch, out_len, in_len)
        attn = torch.bmm(output, context.transpose(1, 2))
        mask = torch.eq(attn, 0).data.byte()
        attn.data.masked_fill_(mask, -float('inf'))
        attn = F.softmax(attn.view(-1, input_size), dim=1).view(batch_size, -1, input_size)

        # (batch, out_len, in_len) * (batch, in_len, dim) -> (batch, out_len, dim)
        mix = torch.bmm(attn, context)

        # concat -> (batch, out_len, 2*dim)
        combined = torch.cat((mix, output), dim=2)

        # output -> (batch, out_len, dim)
        output = F.tanh(self.linear_out(combined.view(-1, 2 * hidden_size))).view(batch_size, -1, hidden_size)


        if not output.is_contiguous():
            output = output.contiguous()

        return output, attn
Example No. 9
    def forward(self, q, k, v):
        b_q, t_q, dim_q = list(q.size())
        b_k, t_k, dim_k = list(k.size())
        b_v, t_v, dim_v = list(v.size())
        assert(b_q == b_k and b_k == b_v)  # batch size should be equal
        assert(dim_q == dim_k)  # dims should be equal
        assert(t_k == t_v)  # times should be equal
        b = b_q
        qk = torch.bmm(q, k.transpose(1, 2))  # b x t_q x t_k
        qk.div_(dim_k ** 0.5)
        mask = None
        if self.causal and t_q > 1:
            causal_mask = q.data.new(t_q, t_k).byte().fill_(1).triu_(1)
            mask = Variable(causal_mask.unsqueeze(0).expand(b, t_q, t_k),
                            requires_grad=False)
        if self.mask_k is not None:
            mask_k = self.mask_k.unsqueeze(1).expand(b, t_q, t_k)
            mask = mask_k if mask is None else mask | mask_k
        if self.mask_q is not None:
            mask_q = self.mask_q.unsqueeze(2).expand(b, t_q, t_k)
            mask = mask_q if mask is None else mask | mask_q
        if mask is not None:
            qk.masked_fill_(mask, -1e9)

        sm_qk = F.softmax(qk, dim=2)
        sm_qk = self.dropout(sm_qk)
        return torch.bmm(sm_qk, v), sm_qk  # b x t_q x dim_v
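
The causal branch above builds a byte mask with triu_(1) so a query position cannot attend to later key positions. A minimal sketch of the same masking with current PyTorch (bool masks rather than the byte/Variable API used above); shapes are illustrative.

import torch
import torch.nn.functional as F

b, t, d = 2, 4, 8
q, k, v = torch.randn(b, t, d), torch.randn(b, t, d), torch.randn(b, t, d)

qk = torch.bmm(q, k.transpose(1, 2)) / d ** 0.5           # b x t x t
causal = torch.triu(torch.ones(t, t), diagonal=1).bool()  # True strictly above the diagonal
qk = qk.masked_fill(causal.unsqueeze(0), -1e9)
attn = F.softmax(qk, dim=2)
out = torch.bmm(attn, v)                                  # b x t x d
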
Example No. 10
    def forward_dot(self, hid, ctx, ctx_mask):
        r"""Computes Luong-style dot attention probabilities between
        decoder's hidden state and source annotations.

        Arguments:
            hid(Variable): A set of decoder hidden states of shape `T*B*H`
                where `T` == 1, `B` is batch dim and `H` is hidden state dim.
            ctx(Variable): A set of annotations of shape `S*B*C` where `S`
                is the source timestep dim, `B` is batch dim and `C`
                is annotation dim.
            ctx_mask(FloatTensor): A binary mask of shape `S*B` with zeroes
                in the padded timesteps.

        Returns:
            scores(Variable): A variable of shape `S*B` containing normalized
                attention scores for each position and sample.
            z_t(Variable): A variable of shape `B*H` containing the final
                attended context vector for this target decoding timestep.
        """
        # Apply transformations first to make last dims both C and then
        # shuffle dims to prepare for batch mat-mult
        ctx_ = self.ctx2ctx(ctx).permute(1, 2, 0)   # S*B*C -> S*B*C -> B*C*S
        hid_ = self.hid2ctx(hid).permute(1, 0, 2)   # T*B*H -> T*B*C -> B*T*C

        # 'dot' scores of B*T*S
        scores = F.softmax(torch.bmm(hid_, ctx_), dim=-1)

        # Transform back to hidden_dim for further decoders
        # B*T*S x B*S*C -> B*T*C -> B*T*H
        z_t = self.ctx2hid(torch.bmm(scores, ctx.transpose(0, 1)))

        return scores.transpose(0, 1), z_t.transpose(0, 1)
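
The permutes in forward_dot exist purely to line shapes up for bmm: the projected hidden state becomes B*T*C and the projected annotations become B*C*S, so the first bmm yields B*T*S scores and the second bmm pools the B*S*C annotations into a context. A shape-only sketch under those assumptions (the linear projections are omitted; sizes are illustrative).

import torch
import torch.nn.functional as F

S, T, B, C = 7, 1, 3, 16         # source len, target len (== 1), batch, annotation dim
ctx = torch.randn(S, B, C)
hid = torch.randn(T, B, C)       # already projected to C for this sketch

ctx_ = ctx.permute(1, 2, 0)      # B x C x S
hid_ = hid.permute(1, 0, 2)      # B x T x C
scores = F.softmax(torch.bmm(hid_, ctx_), dim=-1)   # B x T x S
z_t = torch.bmm(scores, ctx.transpose(0, 1))        # B x T x C
print(scores.shape, z_t.shape)
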
Example No. 11
def bnorm(x, U):
    mx = torch.bmm(U, x)                              # local mean of x under the averaging operator U
    subs = x - mx                                     # center
    subs2 = subs * subs
    vx = torch.bmm(U, subs2)                          # local variance
    out = subs / (vx.clamp(min=1e-10).sqrt() + 1e-5)  # standardize with numerical guards
    return out
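
bnorm standardizes x by statistics computed through U: bmm(U, x) is a per-node weighted mean and bmm(U, (x - mean)**2) the matching variance, so U is expected to be an averaging operator such as a row-normalized adjacency matrix. A small usage sketch with a uniform averaging operator (this choice of U is an assumption for illustration):

import torch

bs, n, d = 2, 5, 3
x = torch.randn(bs, n, d)
U = torch.full((bs, n, n), 1.0 / n)   # every node averages over all nodes

out = bnorm(x, U)
print(out.shape)                      # torch.Size([2, 5, 3])
print(out.mean().item())              # close to zero after standardization
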
Example No. 12
    def forward(self, q, k, v, attn_mask=None):

        d_k, d_v = self.d_k, self.d_v
        n_head = self.n_head

        residual = q
        #print('q,k,v:',q.size(),k.size(),v.size())
        mb_size, len_q, q_hidden_size = q.size()
        mb_size, len_k, k_hidden_size = k.size()
        mb_size, len_v, v_hidden_size = v.size()

        # treat as a (n_head) size batch
        q_s = q.repeat(n_head, 1, 1).view(n_head, -1, q_hidden_size) # n_head x (mb_size*len_q) x d_model
        k_s = k.repeat(n_head, 1, 1).view(n_head, -1, k_hidden_size) # n_head x (mb_size*len_k) x d_model
        v_s = v.repeat(n_head, 1, 1).view(n_head, -1, v_hidden_size) # n_head x (mb_size*len_v) x d_model
        #print('q_s,k_s,v_s:',q_s.size(),k_s.size(),v_s.size())
        #print('w_qs',self.w_qs.size())
        # treat the result as a (n_head * mb_size) size batch
        q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)   # (n_head*mb_size) x len_q x d_k
        k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)   # (n_head*mb_size) x len_k x d_k
        v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)   # (n_head*mb_size) x len_v x d_v

        # perform attention, result size = (n_head * mb_size) x len_q x d_v
        #print('attn_mask:',attn_mask.size())
        #print(attn_mask)
        outputs, attns = self.attention.forward(q_s, k_s, v_s, attn_mask=attn_mask.repeat(n_head,1,1))

        # back to original mb_size batch, result size = mb_size x len_q x (n_head*d_v)
        outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1) 

        # project back to residual size
        outputs = self.proj.forward(outputs)
        outputs = self.dropout(outputs)

        return self.layer_norm(outputs + residual), attns
Example No. 13
 def predict2(self, x_de, beamsz, gen_len):
     emb_de = self.embedding_de(x_de) # "batch size",n_de,word_dim, but "batch size" is 1 in this case!
     h0 = Variable(torch.zeros(self.n_layers*self.directions, 1, self.hidden_dim).cuda())
     c0 = Variable(torch.zeros(self.n_layers*self.directions, 1, self.hidden_dim).cuda())
     enc_h, _ = self.encoder(emb_de, (h0, c0))
     # since enc batch size=1, enc_h is 1,n_de,hiddensz*n_directions
     if self.directions == 2:
         enc_h = self.dim_reduce(enc_h) # 1,n_de,hiddensz
     masterheap = CandList(self.n_layers,self.hidden_dim,enc_h.size(1),beamsz)
     # in the following loop, the beam holds a single candidate on the first iteration and the true beamsz (100) afterward
     for i in range(gen_len):
         prev = masterheap.get_prev() # beamsz
         emb_t = self.embedding_en(prev) # embed the last thing we generated. beamsz,word_dim
         enc_h_expand = enc_h.expand(prev.size(0),-1,-1) # beamsz,n_de,hiddensz
         
         h, c = masterheap.get_hiddens() # (n_layers,beamsz,hiddensz),(n_layers,beamsz,hiddensz)
         dec_h, (h, c) = self.decoder(emb_t.unsqueeze(1), (h, c)) # dec_h is beamsz,1,hiddensz (batch_first=True)
         scores = torch.bmm(enc_h_expand, dec_h.transpose(1,2)).squeeze(2)
         # (beamsz,n_de,hiddensz) * (beamsz,hiddensz,1) = (beamsz,n_de,1). squeeze to beamsz,n_de
         attn_dist = F.softmax(scores,dim=1)
         if self.attn_type == "hard":
             _, argmax = attn_dist.max(1) # beamsz for each batch, select most likely german word to pay attention to
             one_hot = Variable(torch.zeros_like(attn_dist.data).scatter_(-1, argmax.data.unsqueeze(1), 1).cuda())
             context = torch.bmm(one_hot.unsqueeze(1), enc_h_expand).squeeze(1)
         else:
             context = torch.bmm(attn_dist.unsqueeze(1), enc_h_expand).squeeze(1)
         # the difference btwn hard and soft is just whether we use a one_hot or a distribution
         # context is beamsz,hiddensz*n_directions
         pred = self.vocab_layer(torch.cat([dec_h.squeeze(1), context], 1)) # beamsz,len(EN.vocab)
         # TODO: set the columns corresponding to <pad>,<unk>,</s>,etc to 0
         masterheap.update_beam(pred)
         masterheap.update_hiddens(h,c)
         masterheap.update_attentions(attn_dist)
         masterheap.firstloop = False
     return masterheap.probs,masterheap.wordlist,masterheap.attentions
Example No. 14
 def predict(self, x_de, x_en):
     bs = x_de.size(0)
     emb_de = self.embedding_de(x_de) # bs,n_de,word_dim
     emb_en = self.embedding_en(x_en) # bs,n_en,word_dim
     h = Variable(torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda())
     c = Variable(torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda())
     enc_h, _ = self.encoder(emb_de, (h, c))
     dec_h, _ = self.decoder(emb_en, (h, c))
     # all the same. enc_h is bs,n_de,hiddensz*n_directions. h and c are both n_layers*n_directions,bs,hiddensz
     if self.directions == 2:
         enc_h = self.dim_reduce(enc_h) # bs,n_de,hiddensz
     scores = torch.bmm(enc_h, dec_h.transpose(1,2))
     # (bs,n_de,hiddensz) * (bs,hiddensz,n_en) = (bs,n_de,n_en)
     y = [Variable(torch.cuda.LongTensor([sos_token]*bs))] # bs
     self.attn = []
     for t in range(x_en.size(1)-1): # iterate over english words, with teacher forcing
         attn_dist = F.softmax(scores[:,:,t],dim=1) # bs,n_de
         self.attn.append(attn_dist.data)
         if self.attn_type == "hard":
             _, argmax = attn_dist.max(1) # bs. for each batch, select most likely german word to pay attention to
             one_hot = Variable(torch.zeros_like(attn_dist.data).scatter_(-1, argmax.data.unsqueeze(1), 1).cuda())
             context = torch.bmm(one_hot.unsqueeze(1), enc_h).squeeze(1)
         else:
             context = torch.bmm(attn_dist.unsqueeze(1), enc_h).squeeze(1)
         # the difference btwn hard and soft is just whether we use a one_hot or a distribution
         # context is bs,hiddensz
         pred = self.vocab_layer(torch.cat([dec_h[:,t,:], context], 1)) # bs,len(EN.vocab)
         _, next_token = pred.max(1) # bs
         y.append(next_token)
     self.attn = torch.stack(self.attn, 0).transpose(0, 1) # bs,n_en,n_de (for visualization!)
     y = torch.stack(y,0).transpose(0,1) # bs,n_en
     return y,self.attn
Example No. 15
    def backward(ctx, grad_output):
        batch1, batch2 = ctx.saved_variables
        grad_add_matrix = grad_batch1 = grad_batch2 = None

        if ctx.needs_input_grad[0]:
            grad_add_matrix = maybe_unexpand(grad_output, ctx.add_matrix_size)
            if ctx.alpha != 1:
                grad_add_matrix = grad_add_matrix.mul(ctx.alpha)

        if any(ctx.needs_input_grad[1:]):
            batch_grad_output = (grad_output
                                 .unsqueeze(0)
                                 .expand(batch1.size(0), batch1.size(1), batch2.size(2)))

        if ctx.needs_input_grad[1]:
            grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2))
            if ctx.beta != 1:
                grad_batch1 *= ctx.beta

        if ctx.needs_input_grad[2]:
            grad_batch2 = torch.bmm(batch1.transpose(1, 2), batch_grad_output)
            if ctx.beta != 1:
                grad_batch2 *= ctx.beta

        return grad_add_matrix, grad_batch1, grad_batch2, None, None, None
Example No. 16
def lstsq(b, y, alpha=0.01):
    """
    Batched linear least-squares for PyTorch with optional L2 (ridge) regularization.

    Parameters
    ----------

    b : shape(L, M, N)
    y : shape(L, M)

    Returns
    -------
    tuple of (coefficients, model, residuals)

    """
    bT = b.transpose(-1, -2)
    AA = torch.bmm(bT, b)
    if alpha != 0:
        diag = torch.diagonal(AA, dim1=1, dim2=2)
        diag += alpha
    RHS = torch.bmm(bT, y[:, :, None])
    X, LU = torch.gesv(RHS, AA)
    fit = torch.bmm(b, X)[..., 0]
    res = y - fit
    return X[..., 0], fit, res
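
torch.gesv was the pre-1.2 solver API and has since been removed; on current PyTorch the same ridge-regularized normal-equations solve can be written with torch.linalg.solve. A minimal modern sketch of the core computation (not a drop-in replacement for the function above; sizes are illustrative):

import torch

L, M, N = 4, 10, 3
b = torch.randn(L, M, N)
y = torch.randn(L, M)
alpha = 0.01

bT = b.transpose(-1, -2)
AA = torch.bmm(bT, b) + alpha * torch.eye(N)     # (L, N, N) ridge-regularized normal matrix
RHS = torch.bmm(bT, y[:, :, None])               # (L, N, 1)
X = torch.linalg.solve(AA, RHS)                  # (L, N, 1)
fit = torch.bmm(b, X)[..., 0]
res = y - fit
print(X[..., 0].shape, fit.shape, res.shape)
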
Example No. 17
 def forward(self, x_de, x_en, update_baseline=True):
     bs = x_de.size(0)
     # x_de is bs,n_de. x_en is bs,n_en
     emb_de = self.embedding_de(x_de) # bs,n_de,word_dim
     emb_en = self.embedding_en(x_en) # bs,n_en,word_dim
     h0_enc = torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda()
     c0_enc = torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda()
     h0_dec = torch.zeros(self.n_layers, bs, self.hidden_dim).cuda()
     c0_dec = torch.zeros(self.n_layers, bs, self.hidden_dim).cuda()
     # hidden vars have dimension n_layers*n_directions,bs,hiddensz
     enc_h, _ = self.encoder(emb_de, (Variable(h0_enc), Variable(c0_enc)))
     # enc_h is bs,n_de,hiddensz*n_directions. ordering is different from last week because batch_first=True
     dec_h, _ = self.decoder(emb_en, (Variable(h0_dec), Variable(c0_dec)))
     # dec_h is bs,n_en,hidden_size*n_directions
     # we've gotten our encoder/decoder hidden states so we are ready to do attention
     # first let's get all our scores, which we can do easily since we are using dot-prod attention
     if self.directions == 2:
         scores = torch.bmm(self.dim_reduce(enc_h), dec_h.transpose(1,2))
         # TODO: any easier ways to reduce dimension?
     else:
         scores = torch.bmm(enc_h, dec_h.transpose(1,2))
     # (bs,n_de,hiddensz*n_directions) * (bs,hiddensz*n_directions,n_en) = (bs,n_de,n_en)
     reinforce_loss = 0 # we only use this variable for hard attention
     loss = 0
     avg_reward = 0
     # we just iterate to dec_h.size(1)-1, since there's </s> at the end of each sentence
     for t in range(dec_h.size(1)-1): # iterate over english words, with teacher forcing
         attn_dist = F.softmax(scores[:, :, t],dim=1) # bs,n_de. these are the alphas (attention scores for each german word)
         if self.attn_type == "hard":
             cat = torch.distributions.Categorical(attn_dist) 
             attn_samples = cat.sample() # bs. each element is a sample from categorical distribution
             one_hot = Variable(torch.zeros_like(attn_dist.data).scatter_(-1, attn_samples.data.unsqueeze(1), 1).cuda()) # bs,n_de
             # made a bunch of one-hot vectors
             context = torch.bmm(one_hot.unsqueeze(1), enc_h).squeeze(1)
             # now we use the one-hot vectors to select correct hidden vectors from enc_h
             # (bs,1,n_de) * (bs,n_de,hiddensz*n_directions) = (bs,1,hiddensz*n_directions). squeeze to bs,hiddensz*n_directions
         else:
             context = torch.bmm(attn_dist.unsqueeze(1), enc_h).squeeze(1) # same dimensions
             # (bs,1,n_de) * (bs,n_de,hiddensz*n_directions) = (bs,1,hiddensz*n_directions)
         # context is bs,hidden_size*n_directions
         # the rnn output and the context together make the decoder "hidden state", which is bs,2*hidden_size*n_directions
         pred = self.vocab_layer(torch.cat([dec_h[:,t,:], context], 1)) # bs,len(EN.vocab)
         y = x_en[:, t+1] # bs. these are our labels
         no_pad = (y != pad_token) # exclude english padding tokens
         reward = torch.gather(pred, 1, y.unsqueeze(1)) # bs,1
         # reward[i,1] = pred[i,y[i]]. this gets log prob of correct word for each batch. similar to -crossentropy
         reward = reward.squeeze(1)[no_pad] # less than bs
         if self.attn_type == "hard":
             reinforce_loss -= (cat.log_prob(attn_samples)[no_pad] * (reward-self.baseline).detach()).sum()
             # reinforce rule (just read the formula), with special baseline
         loss -= reward.sum() # minimizing loss is maximizing reward
     no_pad_total = (x_en[:,1:] != pad_token).data.sum() # TODO: i think this is right, right?
     loss /= no_pad_total
     reinforce_loss /= no_pad_total
     avg_reward = -loss.data[0]
     if update_baseline: # update baseline as a moving average
         self.baseline = Variable(0.95*self.baseline.data + 0.05*avg_reward)
     return loss, reinforce_loss,avg_reward
    def forward(self, match_encoders):
        
        '''
        match_encoders (pn_steps, batch, hidden_size*2)
        '''
        vh_matrix = self.vh_net(match_encoders) # pn_steps, batch, hidden_size
        
        # prediction start
        h0 = Variable(torch.zeros(match_encoders.size(1), self.hidden_size)).cuda()
        c0 = Variable(torch.zeros(match_encoders.size(1), self.hidden_size)).cuda()
        
        wha1 = self.wa_net(h0) # batch, hidden_size
        wha1 = wha1.expand(match_encoders.size(0), wha1.size(0), wha1.size(1)) # pn_steps, batch, hidden_size
        #print ('_sum.size() ', _sum.size())
        #print ('vh_matrix.size() ', vh_matrix.size())
        f1 = self.tanh(vh_matrix + wha1) # pn_steps, batch, hidden_size
        #print ('f1.size() ', f1.size())
        vf1 = self.v_net(f1.transpose(0, 1)).squeeze(-1) #batch, pn_steps
        
        beta1 = self.softmax(vf1) #batch, pn_steps
        softmax_beta1 = self.softmax(beta1).view(beta1.size(0), 1, beta1.size(1)) #batch, 1, pn_steps
        
        inp = torch.bmm(softmax_beta1, match_encoders.transpose(0, 1)) # batch, 1, hidden_size
        inp = inp.squeeze(1) # batch, hidden_size
        
        h1, c1 = self.pointer_lstm(inp, (h0, c0))
        
        
        wha2 = self.wa_net(h1) # batch, hidden_size
        wha2 = wha2.expand(match_encoders.size(0), wha2.size(0), wha2.size(1)) # pn_steps, batch, hidden_size
        f2 = self.tanh(vh_matrix + wha2) # pn_steps, batch, hidden_size
        vf2 = self.v_net(f2.transpose(0, 1)).squeeze(-1) #batch, pn_steps
        
        beta2 = self.softmax(vf2)#batch, pn_steps
        softmax_beta2 = self.softmax(beta2).view(beta2.size(0), 1, beta2.size(1)) #batch, 1, pn_steps
        
        inp = torch.bmm(softmax_beta2, match_encoders.transpose(0, 1)) # batch, 1, hidden_size
        inp = inp.squeeze(1) # batch, hidden_size
        
        h2, c2 = self.pointer_lstm(inp, (h1, c1))
            
        _, start = torch.max(beta1, 1)
        _, end = torch.max(beta2, 1)
        
        beta1 = beta1.view(1, beta1.size(0), beta1.size(1))
        beta2 = beta2.view(1, beta2.size(0), beta2.size(1))
        
        logits = torch.cat([beta1, beta2])
        
        start = start.view(1, start.size(0))
        end = end.view(1, end.size(0))
        
        prediction = torch.cat([start, end]).transpose(0, 1).cpu().data.numpy()
        

        return logits, prediction
Example No. 19
 def forward(self, query_embeddings, in_memory_embeddings, out_memory_embeddings, attention_mask=None):
     attention = torch.bmm(in_memory_embeddings, query_embeddings.unsqueeze(2)).squeeze(2)
     if attention_mask is not None:
         # exclude masked elements from the softmax
         attention = attention_mask.float() * attention + (1 - attention_mask.float()) * -1e20
     probs = softmax(attention).unsqueeze(1)
     memory_output = torch.bmm(probs, out_memory_embeddings).squeeze(1)
     query_embeddings = self.linear(query_embeddings)
     output = memory_output + query_embeddings
     return output
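
Padding is excluded here by blending the logits with the mask, mask*logit + (1-mask)*(-1e20), before the softmax instead of using masked_fill. A small self-contained sketch of the same pattern (shapes and the -1e20 constant follow the snippet above; softmax is taken from torch.nn.functional):

import torch
import torch.nn.functional as F

b, mem, d = 2, 6, 8
query = torch.randn(b, d)
in_memory = torch.randn(b, mem, d)
out_memory = torch.randn(b, mem, d)
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]])

attention = torch.bmm(in_memory, query.unsqueeze(2)).squeeze(2)                       # (b, mem)
attention = attention_mask.float() * attention + (1 - attention_mask.float()) * -1e20
probs = F.softmax(attention, dim=-1).unsqueeze(1)                                     # (b, 1, mem)
memory_output = torch.bmm(probs, out_memory).squeeze(1)                               # (b, d)
print(memory_output.shape)
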
def lk_tensor_track_batch(feature_old, feature_new, pts_locations, patch_size, max_step, feature_template=None):
  # feature[old,new] : 4-D tensor [1, C, H, W]
  # pts_locations is a 2-D tensor [Num-Pts, (Y,X)]
  if feature_new.dim() == 3:
    feature_new = feature_new.unsqueeze(0)
  if feature_old is not None and feature_old.dim() == 3:
    feature_old = feature_old.unsqueeze(0)
  assert feature_new.dim() == 4, 'The dimension of feature-new is not right : {}.'.format(feature_new.dim())
  BB, C, H, W = list(feature_new.size())
  if feature_old is not None:
    assert 1 == feature_old.size(0) and 1 == BB, 'The first dimension of feature should be one not {}'.format(feature_old.size())
    assert C == feature_old.size(1) and H == feature_old.size(2) and W == feature_old.size(3), 'The size is not right : {}'.format(feature_old.size())
  assert isinstance(patch_size, int), 'The format of lk-parameters are not right : {}'.format(patch_size)
  num_pts = pts_locations.size(0)
  device = feature_new.device

  weight_map = Generate_Weight( [patch_size*2+1, patch_size*2+1] ) # [H, W]
  with torch.no_grad():
    weight_map = torch.tensor(weight_map).view(1, 1, 1, patch_size*2+1, patch_size*2+1).to(device)

    sobelconvx = SobelConv('x', feature_new.dtype).to(device)
    sobelconvy = SobelConv('y', feature_new.dtype).to(device)
  
  # feature_T should be a [num_pts, C, patch, patch] tensor
  if feature_template is None:
    feature_T = warp_feature_batch(feature_old, pts_locations, patch_size)
  else:
    assert feature_old is None, 'When feature_template is not None, feature_old must be None'
    feature_T = feature_template
  assert feature_T.size(2) == patch_size * 2 + 1 and feature_T.size(3) == patch_size * 2 + 1, 'The size of feature-template is not ok : {}'.format(feature_T.size())
  gradiant_x = sobelconvx(feature_T)
  gradiant_y = sobelconvy(feature_T)
  J = torch.stack([gradiant_x, gradiant_y], dim=1)
  weightedJ = J * weight_map
  H = torch.bmm( weightedJ.view(num_pts,2,-1), J.view(num_pts, 2, -1).transpose(2,1) )
  inverseH = torch_inverse_batch(H)

  #print ('PTS : {}'.format(pts_locations))
  for step in range(max_step):
    # Step-1 Warp I with W(x,p) to compute I(W(x;p))
    feature_I = warp_feature_batch(feature_new, pts_locations, patch_size)
    # Step-2 Compute the error feature
    r = feature_I - feature_T
    # Step-7 Compute sigma
    sigma = torch.bmm(weightedJ.view(num_pts,2,-1), r.view(num_pts,-1, 1))
    # Step-8 Compute delta-p
    deltap = torch.bmm(inverseH, sigma).squeeze(-1)
    pts_locations = pts_locations - deltap

  return pts_locations
Example No. 21
    def updateGradInput(self, input, gradOutput):
        M, v = input
        self.gradInput[0].resize_as_(M)
        self.gradInput[1].resize_as_(v)
        gradOutput = gradOutput.contiguous()

        assert gradOutput.ndimension() == 1 or gradOutput.ndimension() == 2

        if gradOutput.ndimension() == 2:
            assert M.ndimension() == 3
            assert v.ndimension() == 2
            bdim = M.size(0)
            odim = M.size(1)
            idim = M.size(2)

            if self.trans:
                torch.bmm(v.view(bdim, odim, 1), gradOutput.view(bdim, 1, idim), out=self.gradInput[0])
                torch.bmm(M, gradOutput.view(bdim, idim, 1), out=self.gradInput[1].view(bdim, odim, 1))
            else:
                torch.bmm(gradOutput.view(bdim, odim, 1), v.view(bdim, 1, idim), out=self.gradInput[0])
                torch.bmm(M.transpose(1, 2), gradOutput.view(bdim, odim, 1), out=self.gradInput[1].view(bdim, idim, 1))
        else:
            assert M.ndimension() == 2
            assert v.ndimension() == 1

            if self.trans:
                torch.ger(v, gradOutput, out=self.gradInput[0])
                self.gradInput[1] = M * gradOutput
            else:
                torch.ger(gradOutput, v, out=self.gradInput[0])
                self.gradInput[1] = M.t() * gradOutput

        return self.gradInput
Example No. 22
    def dot_prod_attention(self, h_t: torch.Tensor, src_encoding: torch.Tensor, src_encoding_att_linear: torch.Tensor,
                           mask: torch.Tensor=None) -> Tuple[torch.Tensor, torch.Tensor]:
        # (batch_size, src_sent_len)
        att_weight = torch.bmm(src_encoding_att_linear, h_t.unsqueeze(2)).squeeze(2)

        if mask is not None:
            att_weight.data.masked_fill_(mask.byte(), -float('inf'))

        softmaxed_att_weight = F.softmax(att_weight, dim=-1)

        att_view = (att_weight.size(0), 1, att_weight.size(1))
        # (batch_size, hidden_size)
        ctx_vec = torch.bmm(softmaxed_att_weight.view(*att_view), src_encoding).squeeze(1)

        return ctx_vec, softmaxed_att_weight
Example No. 23
    def forward(self, input, context):
        if not self.dot:
            targetT = self.linear_in(input).unsqueeze(2)  # batch x dim x 1
        else:
            targetT = input.unsqueeze(2)

        context_scores = torch.bmm(context, targetT).squeeze(2)
        context_scores.data.masked_fill_(self.context_mask, -float('inf'))
        context_attention = F.softmax(context_scores, dim=-1) + EPSILON
        context_alignment = torch.bmm(context_attention.unsqueeze(1), context).squeeze(1)

        combined_representation = torch.cat([input, context_alignment], 1)
        output = self.tanh(self.linear_out(combined_representation))

        return output, context_attention, context_alignment
Example No. 24
    def forward_predict(self, x, graphs, data_agent, constrain_=True):
        graphs = deepcopy(graphs)
        opt = self.opt
        batch_size = x.size(0)
        h_0 = Variable(torch.zeros(opt['rnn_layers'], batch_size, opt['rnn_h']))
        if opt['cuda']:
            h_0 = h_0.cuda()

        enc_out, hidden = self.encoder(self.input_emb(x), h_0) # [batch, seq_in, h], [layer, batch, h]
        text_out = [[] for _ in range(batch_size)]
        for i in range(opt['max_seq_out']):
            if i == 0:
                y_in = Variable(torch.zeros(batch_size, 1, self.y_dim))
                if opt['cuda']:
                    y_in = y_in.cuda()
            else:
                y_in = y_onehot.unsqueeze(1)

            dec_out, hidden = self.decoder(y_in, hidden)
            alpha = F.softmax(torch.bmm(enc_out, hidden[-1].unsqueeze(2)))
            attention = torch.bmm(enc_out.transpose(1, 2), alpha).squeeze(2)
            dec_out = self.mapping(torch.cat([attention, dec_out.squeeze(1)], dim=1)) # [batch, y_dim]

            y_mask = torch.zeros(batch_size, self.y_dim)
            for j in range(batch_size):
                data_agent.get_mask(graphs[j], y_mask[j])
            if opt['cuda']:
                y_mask = y_mask.cuda()
            y_mask = Variable(y_mask, volatile=True)
            if constrain_:
                dec_out = dec_out * y_mask + -1e7 * (1 - y_mask)

            y_out = torch.max(dec_out, 1, keepdim=True)[1].data # [batch, 1]
            y_onehot = torch.zeros(batch_size, self.y_dim) # [batch, y_dim]
            y_onehot.scatter_(1, y_out.cpu(), 1)
            y_onehot = Variable(y_onehot)
            if opt['cuda']:
                y_onehot = y_onehot.cuda()

            y_out = y_out.squeeze()
            for j in range(batch_size):
                if len(text_out[j]) > 0 and text_out[j][-1] == 'STOP': continue
                text_out[j].append(data_agent.reverse_parse_action(data_agent.get_action_tuple(y_out[j])))
                if text_out[j][-1] != 'STOP':
                    exec_result = graphs[j].parse_exec(text_out[j][-1])
                    if constrain_:
                        assert exec_result, text_out[j][-1]
        return text_out
Example No. 25
    def forward(self, video_instances, resnet_ftrs, optical_ftrs, object_ftrs):

        #################### Attention Level 1 #################################
        total_instances = sum(video_instances)

        attn1 = self.attn1(object_ftrs)
        # [700, 100, 4, 1]
        attn1 = attn1.view(total_instances*self.num_frames, self.obj_per_frame, 1)
        # [70000, 4, 1]
        object_ftrs = object_ftrs.view(total_instances*self.num_frames, \
                        resnet_dim, self.obj_per_frame)
        # [700,100,4,2048] to [70000,2048,4] for bmm
        object_attended = torch.bmm(object_ftrs, attn1)
        # [70000, 2048, 4] and [70000, 4, 1] to [70000, 2048, 1]
        object_attended = object_attended.view(total_instances, self.num_frames, resnet_dim)
        # [70000, 2048, 1] to [700, 100, 2048]

        ###################### Attention Level 2 ###############################

        all_features = torch.cat((object_attended, resnet_ftrs, optical_ftrs), 2)
        # [700, 100, 3*2048]
        attn2 = self.attn2(all_features)
        # [700, 100, 3]
        attn2 = attn2.view(total_instances*self.num_frames, 3, 1)
        # [70000, 3, 1]

        all_features = all_features.view(total_instances*self.num_frames, \
                                         resnet_dim, 3)
        # [700, 100, 3*2048] to [70000, 2048, 3]
        features_attended = torch.bmm(all_features, attn2)
        # [70000, 2048, 3] and [70000, 3, 1] to [70000, 2048, 1]
        features_attended = features_attended.view(total_instances, self.num_frames, resnet_dim)
        # [70000, 2048, 1] to [700, 100, 2048]

        ##################### Attention Level 3 ################################

        video_feature = features_attended.view(total_instances, self.num_frames* resnet_dim)
        # [700, 100, 2048] to [700, 100*2048]

        attn3 = self.attn3(video_feature)
        # [700, 100]
        attn3 = attn3.unsqueeze(2).repeat(1, 1, resnet_dim)
        # [700, 100] to [700, 100, 1] to [700, 100, 2048]

        video_feature = video_feature.view(total_instances, self.num_frames, resnet_dim)
        video_attended = video_feature * attn3
        # [700, 100, 2048]
        return video_attended, self.attn1, self.attn2, self.attn3
Example No. 26
    def forward(self, context_ids, doc_ids, target_noise_ids):
        """Sparse computation of scores (unnormalized log probabilities)
        that should be passed to the negative sampling loss.

        Parameters
        ----------
        context_ids: torch.Tensor of size (batch_size, num_context_words)
            Vocabulary indices of context words.

        doc_ids: torch.Tensor of size (batch_size,)
            Document indices of paragraphs.

        target_noise_ids: torch.Tensor of size (batch_size, num_noise_words + 1)
            Vocabulary indices of target and noise words. The first element in
            each row is the ground truth index (i.e. the target), other
            elements are indices of samples from the noise distribution.

        Returns
        -------
            autograd.Variable of size (batch_size, num_noise_words + 1)
        """
        # combine a paragraph vector with word vectors of
        # input (context) words
        x = torch.add(
            self._D[doc_ids, :], torch.sum(self._W[context_ids, :], dim=1))

        # sparse computation of scores (unnormalized log probabilities)
        # for negative sampling
        return torch.bmm(
            x.unsqueeze(1),
            self._O[:, target_noise_ids].permute(1, 0, 2)).squeeze()
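
The scoring bmm relies on these shapes: x is (batch, 1, dim) after unsqueeze, and the gathered output vectors O[:, target_noise_ids] are permuted from (dim, batch, k) to (batch, dim, k), so the product is (batch, 1, k) and squeezes to one score per candidate word. A shape-only sketch with random stand-ins for _D/_W/_O (all names and sizes are illustrative):

import torch

batch, dim, vocab, k = 4, 16, 1000, 6
x = torch.randn(batch, dim)                       # paragraph vector + summed context words
O = torch.randn(dim, vocab)                       # stand-in for the output embedding matrix
target_noise_ids = torch.randint(0, vocab, (batch, k))

scores = torch.bmm(
    x.unsqueeze(1),                               # (batch, 1, dim)
    O[:, target_noise_ids].permute(1, 0, 2)       # (dim, batch, k) -> (batch, dim, k)
).squeeze()                                       # (batch, k)
print(scores.shape)
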
Example No. 27
    def forward(self, s_t_hat, h, enc_padding_mask, coverage):
        b, t_k, n = list(h.size())
        h = h.view(-1, n)  # B * t_k x 2*hidden_dim
        encoder_feature = self.W_h(h)

        dec_fea = self.decode_proj(s_t_hat) # B x 2*hidden_dim
        dec_fea_expanded = dec_fea.unsqueeze(1).expand(b, t_k, n).contiguous() # B x t_k x 2*hidden_dim
        dec_fea_expanded = dec_fea_expanded.view(-1, n)  # B * t_k x 2*hidden_dim

        att_features = encoder_feature + dec_fea_expanded # B * t_k x 2*hidden_dim
        if config.is_coverage:
            coverage_input = coverage.view(-1, 1)  # B * t_k x 1
            coverage_feature = self.W_c(coverage_input)  # B * t_k x 2*hidden_dim
            att_features = att_features + coverage_feature

        e = F.tanh(att_features) # B * t_k x 2*hidden_dim
        scores = self.v(e)  # B * t_k x 1
        scores = scores.view(-1, t_k)  # B x t_k

        attn_dist_ = F.softmax(scores, dim=1)*enc_padding_mask # B x t_k
        normalization_factor = attn_dist_.sum(1, keepdim=True)
        attn_dist = attn_dist_ / normalization_factor

        attn_dist = attn_dist.unsqueeze(1)  # B x 1 x t_k
        h = h.view(-1, t_k, n)  # B x t_k x 2*hidden_dim
        c_t = torch.bmm(attn_dist, h)  # B x 1 x n
        c_t = c_t.view(-1, config.hidden_dim * 2)  # B x 2*hidden_dim

        attn_dist = attn_dist.view(-1, t_k)  # B x t_k

        if config.is_coverage:
            coverage = coverage.view(-1, t_k)
            coverage = coverage + attn_dist

        return c_t, attn_dist, coverage
Example No. 28
def outer(vec1, vec2=None):
    '''Batch support for vectors outer products.

    This function is broadcast-able,
    so you can provide batched vec1 or batched vec2 or both.

    Args:
        vec1: A vector of size (Batch, Size1).
        vec2: A vector of size (Batch, Size2)
            if vec2 is None, vec2 = vec1.

    Returns:
        The outer product of vec1 and vec2 (Batch, Size1, Size2).
    '''
    if vec2 is None:
        vec2 = vec1
    if len(vec1.size()) == 1 and len(vec2.size()) == 1:
        return torch.ger(vec1, vec2)
    else:  # batch outer product
        if len(vec1.size()) == 1:
            vec1 = torch.unsqueeze(vec1, 0)
        if len(vec2.size()) == 1:
            vec2 = torch.unsqueeze(vec2, 0)
        vec1 = torch.unsqueeze(vec1, -1)
        vec2 = torch.unsqueeze(vec2, -2)
        if vec1.size(0) == vec2.size(0):
            return torch.bmm(vec1, vec2)
        else:
            return vec1.matmul(vec2)
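
A short usage check for outer(): plain vectors fall through to torch.ger, equally sized batches go through torch.bmm, and mismatched leading dimensions rely on matmul broadcasting.

import torch

print(outer(torch.randn(3), torch.randn(4)).shape)        # torch.Size([3, 4])
print(outer(torch.randn(5, 3), torch.randn(5, 4)).shape)  # torch.Size([5, 3, 4])
print(outer(torch.randn(5, 3)).shape)                     # torch.Size([5, 3, 3]); vec2 defaults to vec1
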
    def forward(self, xt, fc_feats, att_feats, p_att_feats, state):
        # The p_att_feats here is already projected
        att_size = att_feats.numel() // att_feats.size(0) // self.att_feat_size
        att = p_att_feats.view(-1, att_size, self.att_hid_size)
        
        att_h = self.h2att(state[0][-1])                        # batch * att_hid_size
        att_h = att_h.unsqueeze(1).expand_as(att)            # batch * att_size * att_hid_size
        dot = att + att_h                                   # batch * att_size * att_hid_size
        dot = F.tanh(dot)                                # batch * att_size * att_hid_size
        dot = dot.view(-1, self.att_hid_size)               # (batch * att_size) * att_hid_size
        dot = self.alpha_net(dot)                           # (batch * att_size) * 1
        dot = dot.view(-1, att_size)                        # batch * att_size
        
        weight = F.softmax(dot)                             # batch * att_size
        att_feats_ = att_feats.view(-1, att_size, self.att_feat_size) # batch * att_size * att_feat_size
        att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size

        all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
        sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
        sigmoid_chunk = F.sigmoid(sigmoid_chunk)
        in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
        forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
        out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)

        in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \
            self.a2c(att_res)
        in_transform = torch.max(\
            in_transform.narrow(1, 0, self.rnn_size),
            in_transform.narrow(1, self.rnn_size, self.rnn_size))
        next_c = forget_gate * state[1][-1] + in_gate * in_transform
        next_h = out_gate * F.tanh(next_c)

        output = self.dropout(next_h)
        state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
        return output, state
Example No. 30
 def score(self, hidden, encoder_outputs):
     # [B*T*2H]->[B*T*H]
     energy = self.attn(torch.cat([hidden, encoder_outputs], 2))
     energy = energy.transpose(1, 2)  # [B*H*T]
     v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)  # [B*1*H]
     energy = torch.bmm(v, energy)  # [B*1*T]
     return energy.squeeze(1)  # [B*T]
Example No. 31
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num, gt_where, gt_cond, reinforce):
        max_x_len = max(x_len)
        B = len(x_len)
        if reinforce:
            raise NotImplementedError('Our model doesn\'t have RL')

        # Predict the number of conditions
        # First use column embeddings to calculate the initial hidden unit
        # Then run the LSTM and predict condition number.
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                                             col_len, self.cond_num_name_enc)
        num_col_att_val = self.cond_num_col_att(e_num_col).squeeze()
        for idx, num in enumerate(col_num):
            if num < max(col_num):
                num_col_att_val[idx, num:] = -100
        num_col_att = self.softmax(num_col_att_val)
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
            B, 4, self.N_h // 2).transpose(0, 1).contiguous()
        cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
            B, 4, self.N_h // 2).transpose(0, 1).contiguous()

        h_num_enc, _ = run_lstm(self.cond_num_lstm,
                                x_emb_var,
                                x_len,
                                hidden=(cond_num_h1, cond_num_h2))

        num_att_val = self.cond_num_att(h_num_enc).squeeze()

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                num_att_val[idx, num:] = -100
        num_att = self.softmax(num_att_val)

        K_cond_num = (h_num_enc *
                      num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
        cond_num_score = self.cond_num_out(K_cond_num)

        #Predict the columns of conditions
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                        self.cond_col_name_enc)
        h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)

        col_att_val = torch.bmm(e_cond_col,
                                self.cond_col_att(h_col_enc).transpose(1, 2))
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                col_att_val[idx, :, num:] = -100
        col_att = self.softmax(col_att_val.view(
            (-1, max_x_len))).view(B, -1, max_x_len)
        K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)

        cond_col_score = self.cond_col_out(
            self.cond_col_out_K(K_cond_col) +
            self.cond_col_out_col(e_cond_col)).squeeze()
        max_col_num = max(col_num)
        for b, num in enumerate(col_num):
            if num < max_col_num:
                cond_col_score[b, num:] = -100

        #Predict the operator of conditions
        chosen_col_gt = []
        if gt_cond is None:
            cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
            col_scores = cond_col_score.data.cpu().numpy()
            chosen_col_gt = [
                list(np.argsort(-col_scores[b])[:cond_nums[b]])
                for b in range(len(cond_nums))
            ]
        else:
            # print gt_cond
            chosen_col_gt = [[x[0] for x in one_gt_cond]
                             for one_gt_cond in gt_cond]

        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                        self.cond_op_name_enc)
        h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
        col_emb = []
        for b in range(B):
            cur_col_emb = torch.stack(
                [e_cond_col[b, x]
                 for x in chosen_col_gt[b]] + [e_cond_col[b, 0]] *
                (4 - len(chosen_col_gt[b])))  # Pad the columns to maximum (4)
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)

        op_att_val = torch.matmul(
            self.cond_op_att(h_op_enc).unsqueeze(1),
            col_emb.unsqueeze(3)).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                op_att_val[idx, :, num:] = -100
        op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
        K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)

        cond_op_score = self.cond_op_out(
            self.cond_op_out_K(K_cond_op) +
            self.cond_op_out_col(col_emb)).squeeze()

        #Predict the string of conditions
        h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                        self.cond_str_name_enc)
        col_emb = []
        for b in range(B):
            cur_col_emb = torch.stack(
                [e_cond_col[b, x] for x in chosen_col_gt[b]] +
                [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)

        if gt_where is not None:
            gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where)
            g_str_s_flat, _ = self.cond_str_decoder(
                gt_tok_seq.view(B * 4, -1, self.max_tok_num))
            g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)

            h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
            g_ext = g_str_s.unsqueeze(3)
            col_ext = col_emb.unsqueeze(2).unsqueeze(2)

            cond_str_score = self.cond_str_out(
                self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                self.cond_str_out_col(col_ext)).squeeze(4)
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cond_str_score[b, :, :, num:] = -100
        else:
            h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
            col_ext = col_emb.unsqueeze(2).unsqueeze(2)
            scores = []

            t = 0
            init_inp = np.zeros((B * 4, 1, self.max_tok_num), dtype=np.float32)
            init_inp[:, 0, 0] = 1  #Set the <BEG> token
            if self.gpu:
                cur_inp = Variable(torch.from_numpy(init_inp).cuda())
            else:
                cur_inp = Variable(torch.from_numpy(init_inp))
            cur_h = None
            while t < 50:
                if cur_h:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
                else:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
                g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
                g_ext = g_str_s.unsqueeze(3)

                cur_cond_str_score = self.cond_str_out(
                    self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                    self.cond_str_out_col(col_ext)).squeeze()
                for b, num in enumerate(x_len):
                    if num < max_x_len:
                        cur_cond_str_score[b, :, num:] = -100
                scores.append(cur_cond_str_score)

                _, ans_tok_var = cur_cond_str_score.view(B * 4,
                                                         max_x_len).max(1)
                ans_tok = ans_tok_var.data.cpu()
                data = torch.zeros(B * 4, self.max_tok_num).scatter_(
                    1, ans_tok.unsqueeze(1), 1)
                if self.gpu:  #To one-hot
                    cur_inp = Variable(data.cuda())
                else:
                    cur_inp = Variable(data)
                cur_inp = cur_inp.unsqueeze(1)

                t += 1

            cond_str_score = torch.stack(scores, 2)
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cond_str_score[b, :, :, num:] = -100  #[B, IDX, T, TOK_NUM]

        cond_score = (cond_num_score, cond_col_score, cond_op_score,
                      cond_str_score)

        return cond_score
Example No. 32
 def forward(self, x, A1, diag):
     Amatrix = (A1 + A1.transpose(-2,-1))               # symmetrize A1
     A = Amatrix + diag.diag_embed(dim1=-2, dim2=-1)    # add diag as a per-sample diagonal
     y = torch.bmm(A, x)                                # apply the operator batch-wise
     return y
Example No. 33
    def step(self, Ybar_t: torch.tensor,
             dec_state: Tuple[torch.tensor, torch.tensor],
             enc_hiddens: torch.tensor,
             enc_hiddens_proj: torch.tensor,
             enc_masks: torch.tensor) -> Tuple[Tuple, torch.tensor, torch.tensor]:
        """ Compute one forward step of the LSTM decoder, including the attention computation.

        @param Ybar_t (Tensor): Concatenated Tensor of [Y_t o_prev], with shape (b, e + h_e). The input for the decoder,
                                where b = batch size, e = embedding size, h = hidden size.
        @param dec_state (tuple(Tensor, Tensor)): Tuple of tensors both with shape (b, h_d),
                where b = batch size, h_d = hidden_size_dec.
                First tensor is decoder's prev hidden state, second tensor is decoder's prev cell.
        @param enc_hiddens (Tensor): Encoder hidden states Tensor, with shape (b, src_len, h_e * 2), where b = batch size,
                                    src_len = maximum source length, h = hidden size.
        @param enc_hiddens_proj (Tensor): Encoder hidden states Tensor, projected from (h_e * 2) to h.
                Tensor is with shape (b, src_len, h),
                where b = batch size, src_len = maximum source length, h = hidden size.
        @param enc_masks (Tensor): Tensor of sentence masks shape (b, src_len),
                                    where b = batch size, src_len is maximum source length.

        @returns dec_state (tuple (Tensor, Tensor)): Tuple of tensors both shape (b, h),
                where b = batch size, h = hidden size.
                First tensor is decoder's new hidden state, second tensor is decoder's new cell.
        @returns combined_output (Tensor): Combined output Tensor at timestep t, shape (b, h),
                where b = batch size, h = hidden size.
        @returns e_t (Tensor): Tensor of shape (b, src_len). It is attention scores distribution.
                                Note: You will not use this outside of this function.
                                      We are simply returning this value so that we can sanity check
                                      your implementation.
        """

        combined_output = None

        e_t = None
        # YOUR CODE HERE (~3 Lines)
        # TODO:
        #     1. Apply the decoder to `Ybar_t` and `dec_state`to obtain the new dec_state.
        #     2. Split dec_state into its two parts (dec_hidden, dec_cell)
        #     3. Compute the attention scores e_t [src_len*2h*1], and alpha, a Tensor shape (b, src_len).
        #        Note: b = batch_size, src_len = maximum source length, h = hidden size.
        #
        #       Hints:
        #         - dec_hidden is shape (b, h) and corresponds to h^dec_t in the PDF (batched)
        #         - enc_hiddens_proj is shape (b, src_len, h) and corresponds to W_{attProj} h^enc (batched).
        #         - Use batched matrix multiplication (torch.bmm) to compute e_t.
        #         - To get the tensors into the right shapes for bmm, you'll need to do some squeezing and unsqueezing.
        #         - When using the squeeze() function make sure to specify the dimension you want to squeeze
        #             over. Otherwise, you will remove the batch dimension accidentally, if batch_size = 1.
        #
        # Use the following docs to implement this functionality:
        #     Batch Multiplication:
        #        https://pytorch.org/docs/stable/torch.html#torch.bmm
        #     Tensor Unsqueeze:
        #         https://pytorch.org/docs/stable/torch.html#torch.unsqueeze
        #     Tensor Squeeze:
        #         https://pytorch.org/docs/stable/torch.html#torch.squeeze

        # INPUTS:
        #
        # Ybar_t: [b x (e + h_enc)]                 <-- in pdf, Y_t is [e+h_dec x 1]
        # dec_state (OG): ([b x h_dec], [b x h_dec])

        # DECODER:
        # self.decoder = nn.LSTMCell(embed_size + hidden_size_enc, hidden_size_dec)

        dec_state = dec_hidden, dec_cell = self.decoder(Ybar_t, dec_state) # ([b x h], [b x h])

        #
        # COMPUTE E_T
        #

        e_t = self.attention_function(dec_hidden, enc_hiddens_proj)


        # END YOUR CODE

        # Set e_t to -inf where enc_masks has 1
        if enc_masks is not None:
            e_t.data.masked_fill_(enc_masks.bool(), -float('inf'))

        # YOUR CODE HERE (~6 Lines)
        # TODO:
        #     1. Apply softmax to e_t to yield alpha_t
        #     2. Use batched matrix multiplication between alpha_t and enc_hiddens to obtain the
        #         attention output vector, a_t.
        #     Hints:
        #           - alpha_t is shape (b, src_len)
        #           - enc_hiddens is shape (b, src_len, 2h)
        #           - a_t should be shape (b, 2h)
        #           - You will need to do some squeezing and unsqueezing.
        #     Note: b = batch size, src_len = maximum source length, h = hidden size.
        #     3. Concatenate dec_hidden with a_t to compute tensor U_t
        #     4. Use the output projection layer to compute tensor V_t
        #     5. Compute tensor O_t using the Tanh function and the dropout layer.
        #
        # Use the following docs to implement this functionality:
        #     Softmax:
        #         https://pytorch.org/docs/stable/nn.html#torch.nn.functional.softmax
        #     Batch Multiplication:
        #        https://pytorch.org/docs/stable/torch.html#torch.bmm
        #     Tensor View:
        #         https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view
        #     Tensor Concatenation:
        #         https://pytorch.org/docs/stable/torch.html#torch.cat
        #     Tanh:
        #         https://pytorch.org/docs/stable/torch.html#torch.tanh

        #
        # COMPUTE A
        #
        alpha_t = torch.nn.functional.softmax(e_t, 1)              # [b x src_len]
        alpha_t = alpha_t.unsqueeze(1)                  # [b x 1 x src_len]
        a_t = torch.bmm(alpha_t, enc_hiddens)                 # [b,1,sl]*[b,sl,2h] -> [b,1,2h]
        a_t = a_t.squeeze(1)                            # [b,1,2h] -> [b,2h]

        U_t = torch.cat((dec_hidden, a_t), 1)           # [b x h] + [b x 2h] = [b x 3h]
        V_t = self.combined_output_projection(U_t)           # [h x 3h] * [b x 3h (x 1)] -> [b x h (x 1)]

        O_t = self.dropout( torch.tanh(V_t) )
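        # (Note) This corresponds to O_t = dropout(tanh(W_u U_t)), with W_u assumed to be
        # the weight of `combined_output_projection`, following the shape comments above.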

        # END YOUR CODE

        combined_output = O_t
        return dec_state, combined_output, e_t
Ejemplo n.º 34
0
    def test(self, partition='test'):

        num_ways_test = self.config['num_ways']
        num_shots_test = self.config['num_shots']
        test_batch_size = self.config['num_batch']

        test_iteration = 150
        support_shot_nums = num_shots_test
        query_shot_nums = 1
        val_data_loader = VRDDataLoader("test")
        log_flag = True
        best_acc = 0

        predicate_detection = self.config['predicate_detection']

        num_supports = num_ways_test * num_shots_test
        num_queries = num_ways_test * 1
        num_samples = num_supports + num_queries

        support_edge_mask = torch.zeros(test_batch_size, num_samples, num_samples).cuda()
        support_edge_mask[:, :num_supports, :num_supports] = 1
        query_edge_mask = 1 - support_edge_mask
        evaluation_mask = torch.ones(test_batch_size, num_samples, num_samples).cuda()

        query_edge_losses = []
        query_edge_accrs = []
        query_node_accrs = []

        total_nodes_acc = []
        test_acc_vec = []


        feature_dim = self.config['feature_dim']


        for iter in range(test_iteration // test_batch_size):

            support_all_input, query_all_input, os_label, [support_label, query_label], [idx_for_class,
                                                                                         idx_for_data] = val_data_loader.get_task_batch(
                num_tasks=test_batch_size,
                num_ways=num_ways_test,
                num_shots=num_shots_test,
                seed=iter)

            os_support_full_label = torch.cat(
                [os_label[0].view(test_batch_size, -1).cuda(), os_label[1].view(test_batch_size, -1).cuda()], 1)
            os_query_full_label = torch.cat(
                [os_label[2].view(test_batch_size, -1).cuda(), os_label[3].view(test_batch_size, -1).cuda()], 1)

            os_full_label = torch.cat([os_support_full_label, os_query_full_label], 1)
            os_full_edge = self.label2edge(os_full_label)

            support_label = support_label.cuda()
            query_label = query_label.cuda()
            full_label = torch.cat([support_label, query_label], 1)
            full_edge = self.label2edge(full_label)

            init_edge = full_edge.clone()
            init_edge[:, :, num_supports:, :] = 0.5
            init_edge[:, :, :, num_supports:] = 0.5
            for i in range(num_queries):
                init_edge[:, 0, num_supports + i, num_supports + i] = 1.0
                init_edge[:, 1, num_supports + i, num_supports + i] = 0.0

            # set to eval mode
            self.enc_module.eval()
            self.gnn_module.eval()


            support_subject_emb = support_all_input[4].cuda()
            support_object_emb = support_all_input[5].cuda()
            query_subject_emb = query_all_input[4].cuda()
            query_object_emb = query_all_input[5].cuda()


            support_predicate_input = support_all_input[1].view(test_batch_size * num_ways_test, 1, num_shots_test,
                                                                -1).cuda()
            query_predicate_input = query_all_input[1].view(test_batch_size * num_ways_test, 1, 1, -1).cuda()

            all_support_full_data, support_mapping_subj_feature, support_mapping_obj_feature = self.enc_module(
                support_all_input[0].cuda(),
                support_predicate_input,
                support_all_input[2].cuda(),
                support_all_input[3].cuda(),
                support_subject_emb, support_object_emb, support_shot_nums)

            all_query_full_data, query_mapping_subj_feature, query_mapping_obj_feature = self.enc_module(
                query_all_input[0].cuda(),
                query_predicate_input,
                query_all_input[2].cuda(),
                query_all_input[3].cuda(),
                query_subject_emb, query_object_emb, query_shot_nums)

            support_subject_input = support_mapping_subj_feature.view(test_batch_size, num_supports, feature_dim).cuda()
            support_object_input = support_mapping_obj_feature.view(test_batch_size, num_supports, feature_dim).cuda()
            support_full_data = torch.cat([support_subject_input, support_object_input], 1)

            query_subject_input = query_mapping_subj_feature.view(test_batch_size, num_queries, feature_dim).cuda()
            query_object_input = query_mapping_obj_feature.view(test_batch_size, num_queries, feature_dim).cuda()
            query_full_data = torch.cat([query_subject_input, query_object_input], 1)

            feature_full_data = torch.cat([support_full_data, query_full_data], 1)

            all_support_full_data = all_support_full_data.view(test_batch_size, num_supports, feature_dim)

            all_query_full_data = all_query_full_data.view(test_batch_size, num_queries, feature_dim)

            full_data = torch.cat([all_support_full_data, all_query_full_data], 1)

            full_logit_all, object_full_logit_layers, object_out = self.gnn_module(node_feat=full_data,
                                                                                   edge_feat=init_edge,
                                                                                   object_node_feat=feature_full_data,
                                                                                   object_edge_feat=os_full_edge)



            full_logit = full_logit_all[-1]

            full_edge_loss = self.edge_loss(1 - full_logit[:, 0], 1 - full_edge[:, 0])

            query_edge_loss = torch.sum(full_edge_loss * query_edge_mask * evaluation_mask) / torch.sum(
                query_edge_mask * evaluation_mask)

            pos_query_edge_loss = torch.sum(
                full_edge_loss * query_edge_mask * full_edge[:, 0] * evaluation_mask) / torch.sum(
                query_edge_mask * full_edge[:, 0] * evaluation_mask)
            neg_query_edge_loss = torch.sum(
                full_edge_loss * query_edge_mask * (1 - full_edge[:, 0]) * evaluation_mask) / torch.sum(
                query_edge_mask * (1 - full_edge[:, 0]) * evaluation_mask)
            query_edge_loss = pos_query_edge_loss + neg_query_edge_loss

            full_edge_accr = self.hit(full_logit, 1 - full_edge[:, 0].long())
            query_edge_accr = torch.sum(full_edge_accr * query_edge_mask * evaluation_mask) / torch.sum(
                query_edge_mask * evaluation_mask)

            query_node_pred = torch.bmm(full_logit[:, 0, num_supports:, :num_supports],
                                        self.one_hot_encode(num_ways_test,
                                                            support_label.long()))  # (num_tasks x num_queries x num_supports) * (num_tasks x num_supports x num_ways)
            query_node_accr = torch.eq(torch.max(query_node_pred, -1)[1], query_label.long()).float().mean()
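            # (Note) The bmm above turns query-to-support edge scores into per-class scores:
            # (b, num_queries, num_supports) x (b, num_supports, num_ways) -> (b, num_queries, num_ways);
            # the argmax over the last dimension is then compared against the query labels.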


            query_edge_losses += [query_edge_loss.item()]
            query_edge_accrs += [query_edge_accr.item()]
            query_node_accrs += [query_node_accr.item()]


            print('evaluation: mean=%.2f%%\n' % (np.array(query_node_accrs).mean() * 100,))


        return np.array(query_node_accrs).mean()
Ejemplo n.º 35
0
    def train_self_supervised(self):
        seed = 0
        import numpy as np
        np.random.seed(seed)
        import random as rn
        rn.seed(seed)
        import os
        os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        from networks import Dense_Net, Dense_Net_with_softmax
        Nets = []
        AEs = []

        for view_id in range(self.n_view):
            Net = Dense_Net(input_dim=self.input_shape[view_id],
                            out_dim=self.output_shape)
            Nets.append(Net)
            AE = Dense_Net(input_dim=self.output_shape,
                           out_dim=self.input_shape[view_id])
            AEs.append(AE)

        if torch.cuda.is_available():
            for view_id in range(self.n_view):
                Nets[view_id].cuda()
                AEs[view_id].cuda()

        W = torch.tensor(self.W)
        W = Variable(W.cuda(), requires_grad=False)
        get_grad_params = lambda model: [
            x for x in model.parameters() if x.requires_grad
        ]

        params = []
        optims = []
        for view_id in range(self.n_view):
            params.append(
                get_grad_params(Nets[view_id]) + get_grad_params(AEs[view_id]))
            optims.append(
                optim.Adam(params[view_id], self.lr[view_id],
                           [self.beta1, self.beta2]))

        discriminator_losses, losses, valid_results = [], [], []

        for view_id in range(self.n_view):
            discriminator_losses.append([])
            losses.append([])
            valid_results.append([])

        criterion = lambda x, y: (((x - y)**2).sum(1).sqrt()).mean()
        tr_d_loss, tr_ae_loss, val_d_loss, val_ae_loss = [], [], [], []
        for view_id in range(self.n_view):
            tr_d_loss.append([])
            tr_ae_loss.append([])
            val_d_loss.append([])
            val_ae_loss.append([])

        valid_loss_min = 1e9
        for epoch in range(self.epochs):
            rand_idx = np.arange(self.train_data[view_id].shape[0])
            np.random.shuffle(rand_idx)
            batch_count = int(self.train_data[view_id].shape[0] /
                              float(self.batch_size))

            k = 0
            mean_loss = []
            mean_tr_d_loss, mean_tr_ae_loss = [], []

            for view_id in range(self.n_view):
                mean_loss.append([])
                mean_tr_d_loss.append([])
                mean_tr_ae_loss.append([])

            for batch_idx in range(batch_count):

                self_supervised_features = []

                ae_loss = 0

                for view_id in range(self.n_view):
                    print(('ViewID: %d, Epoch %d/%d') %
                          (view_id, epoch + 1, self.epochs))

                    idx = rand_idx[batch_idx *
                                   self.batch_size:(batch_idx + 1) *
                                   self.batch_size]
                    train_y = self.to_one_hot(self.train_labels[view_id][idx])
                    train_x = self.to_var(
                        torch.tensor(self.train_data[view_id][idx]))

                    optimizer = optims[view_id]
                    Net = Nets[view_id]
                    AE = AEs[view_id]

                    optimizer.zero_grad()

                    network_outputs = Net(train_x)
                    ae_input = network_outputs[-1]
                    pred = ae_input.view([ae_input.shape[0], -1]).mm(W)
                    self_supervised_features.append(pred)

                    ae_data = train_x
                    ae_pred = AE(ae_input)[-1]
                    ae_loss += criterion(ae_pred, ae_data)

                labeled_inx = train_y.sum(1) > 0

                d_loss = 0

                if (not self.use_nce):
                    identity_mat = torch.eye(
                        self_supervised_features[0].shape[0]).reshape(
                            1, self_supervised_features[0].shape[0],
                            self_supervised_features[0].shape[0]).repeat(
                                self_supervised_features[0].shape[1], 1, 1)

                    if torch.cuda.is_available():
                        identity_mat = identity_mat.cuda()

                    f_1 = self_supervised_features[0].view(
                        self_supervised_features[0].shape[1],
                        self_supervised_features[0].shape[0], -1)
                    f_2 = self_supervised_features[1].view(
                        self_supervised_features[1].shape[1], -1,
                        self_supervised_features[0].shape[0])

                    loss_pred = torch.bmm(f_1, f_2)
                    cpc_loss = self.cpc_loss_func(loss_pred, identity_mat)

                else:
                    # NCE CPC loss implementation (Eq. 6) [parts of code referred from https://github.com/HobbitLong/CMC/blob/master/NCE/NCECriterion.py]
                    eps = 1e-7

                    anchor_samples_indices = self.sample_anchor_points(
                        self_supervised_features)
                    pos_samples_indices = self.sample_positives(
                        anchor_samples_indices, self_supervised_features)
                    neg_samples_indices = self.sample_negatives(
                        anchor_samples_indices, self_supervised_features)

                    anchor_samples = torch.zeros(
                        (self.num_anchors,
                         self_supervised_features[0].shape[1]))
                    pos_samples = torch.zeros(
                        (self.num_anchors,
                         self_supervised_features[0].shape[1]))
                    neg_samples = torch.zeros(
                        (self.num_anchors, self.num_negative_samples,
                         self_supervised_features[0].shape[1]))

                    for anchor_id in range(len(anchor_samples_indices)):
                        anchor_point = anchor_samples_indices[anchor_id]
                        anchor_samples[anchor_id] = self_supervised_features[
                            anchor_point[0]][anchor_point[1]]

                    for pos_id in range(len(pos_samples_indices)):
                        pos_point = pos_samples_indices[pos_id]
                        pos_samples[pos_id] = self_supervised_features[
                            pos_point[0]][pos_point[1]]

                    for neg_id in range(len(neg_samples_indices)):
                        for neg_sample_id in range(
                                len(neg_samples_indices[neg_id])):
                            neg_point = neg_samples_indices[neg_id][
                                neg_sample_id]
                            neg_samples[neg_id][
                                neg_sample_id] = self_supervised_features[
                                    neg_point[0]][neg_point[1]]

                    if (torch.cuda.is_available()):
                        anchor_samples = anchor_samples.cuda()
                        pos_samples = pos_samples.cuda()
                        neg_samples = neg_samples.cuda()

                    # noise distribution
                    Pn = 1 / float(self_supervised_features[0].shape[0])

                    # number of noise samples
                    m = self_supervised_features[0].shape[1]

                    D1 = torch.div(
                        pos_samples,
                        torch.zeros_like(pos_samples).fill_(eps) +
                        pos_samples.add(m * Pn + eps))
                    D1 = torch.clamp(D1, min=eps)
                    log_D1 = D1.log()

                    D0 = torch.div(
                        neg_samples.clone().fill_(m * Pn),
                        torch.zeros_like(neg_samples).fill_(eps) +
                        neg_samples.add(m * Pn + eps))
                    D0 = torch.clamp(D0, min=eps)
                    log_D0 = D0.log()

                    cpc_loss = torch.mean(
                        -(log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) /
                        self_supervised_features[0].shape[0]**2)

                for f_ind in range(len(self_supervised_features)):
                    train_y = self.to_one_hot(self.train_labels[f_ind][idx])
                    feature = self_supervised_features[f_ind]

                    curr_loss = criterion(feature[labeled_inx],
                                          train_y[labeled_inx])
                    d_loss += curr_loss

                # cross modal loss (Eq. 2)
                cross_loss = 0

                for f_ind1 in range(len(self_supervised_features)):
                    for f_ind2 in range(len(self_supervised_features)):
                        if (f_ind1 == f_ind2):
                            continue
                        else:
                            feature_1 = self_supervised_features[f_ind1]
                            feature_2 = self_supervised_features[f_ind2]
                            train_y1 = self.to_one_hot(
                                self.train_labels[f_ind1][idx])
                            train_y2 = self.to_one_hot(
                                self.train_labels[f_ind2][idx])

                            curr_loss = criterion(feature_1[labeled_inx],
                                                  feature_2[labeled_inx])
                            cross_loss += curr_loss

                ae_loss *= self.alpha
                cross_loss *= self.beta
                cpc_loss *= self.gamma
                d_loss *= self.delta

                loss = ae_loss + cpc_loss + d_loss + cross_loss
                loss.backward()

                for view_id in range(self.n_view):
                    optims[view_id].step()
                    mean_loss[view_id].append(self.to_data(loss))
                    mean_tr_d_loss[view_id].append(self.to_data(d_loss))
                    mean_tr_ae_loss[view_id].append(self.to_data(ae_loss))

                if ((epoch + 1) % self.sample_interval
                        == 0) and (batch_idx == batch_count - 1):
                    valid_pres = []
                    test_pres = []

                    for view_id in range(self.n_view):
                        losses[view_id].append(np.mean(mean_loss[view_id]))
                        utils.show_progressbar([batch_idx, batch_count],
                                               mean_loss=np.mean(
                                                   mean_loss[view_id]))

                        pre_labels = utils.predict(
                            lambda x: Nets[view_id](x)[-1].view(
                                [x.shape[0], -1]).mm(W).view([x.shape[0], -1]),
                            self.valid_data[view_id], self.batch_size).reshape(
                                [self.valid_data[view_id].shape[0], -1])
                        valid_labels = self.to_one_hot(
                            self.valid_labels[view_id])
                        valid_d_loss = self.to_data(
                            criterion(self.to_var(torch.tensor(pre_labels)),
                                      valid_labels))
                        if valid_loss_min > valid_d_loss and not self.just_valid:
                            valid_loss_min = valid_d_loss

                            valid_pre_0 = utils.predict(
                                lambda x: Nets[0]
                                (x)[-1].view([x.shape[0], -1]),
                                self.valid_data[0], self.batch_size).reshape(
                                    [self.valid_data[0].shape[0], -1])
                            valid_pre_1 = utils.predict(
                                lambda x: Nets[1]
                                (x)[-1].view([x.shape[0], -1]),
                                self.valid_data[1], self.batch_size).reshape(
                                    [self.valid_data[1].shape[0], -1])

                            test_pre_0 = utils.predict(
                                lambda x: Nets[0]
                                (x)[-1].view([x.shape[0], -1]),
                                self.test_data[0], self.batch_size).reshape(
                                    [self.test_data[0].shape[0], -1])
                            test_pre_1 = utils.predict(
                                lambda x: Nets[1]
                                (x)[-1].view([x.shape[0], -1]),
                                self.test_data[1], self.batch_size).reshape(
                                    [self.test_data[1].shape[0], -1])

                            valid_pres.append(valid_pre_0)
                            test_pres.append(test_pre_0)
                        elif self.just_valid:
                            tr_d_loss[view_id].append(
                                np.mean(mean_tr_d_loss[view_id]))
                            val_d_loss[view_id].append(valid_d_loss[view_id])
                            tr_ae_loss[view_id].append(
                                np.mean(mean_tr_ae_loss[view_id]))
                elif batch_idx == batch_count - 1:
                    utils.show_progressbar([batch_idx, batch_count],
                                           mean_loss=np.mean(
                                               mean_loss[view_id]))
                    losses[view_id].append(np.mean(mean_loss[view_id]))
                else:
                    utils.show_progressbar([batch_idx, batch_count], loss=loss)
                k += 1

        torch.save(Nets[0].state_dict(),
                   './features/' + self.datasets + '_image_encoder_weights')
        torch.save(Nets[1].state_dict(),
                   './features/' + self.datasets + '_text_encoder_weights')
        torch.save(AEs[0].state_dict(),
                   './features/' + self.datasets + '_image_decoder_weights')
        torch.save(AEs[1].state_dict(),
                   './features/' + self.datasets + '_text_decoder_weights')

        valid_pres_all = [valid_pre_0, valid_pre_1]
        test_pres_all = [test_pre_0, test_pre_1]

        if not self.just_valid:
            for view_id in range(self.n_view):
                sio.savemat(
                    'features/' + self.datasets + '_' + str(view_id) + '.mat',
                    {
                        'valid_fea': valid_pres_all[view_id],
                        'valid_lab': self.valid_labels[view_id],
                        'test_fea': test_pres_all[view_id],
                        'test_lab': self.test_labels[view_id]
                    })
            return [valid_pre_0, test_pre_0]
        else:
            self.tr_d_loss[view_id] = tr_d_loss[view_id]
            self.tr_ae_loss[view_id] = tr_ae_loss[view_id]
            self.val_d_loss[view_id] = val_d_loss[view_id]
Ejemplo n.º 36
0
 def _rank_k_bmm(self, x, u, v):
     xu = torch.bmm(x[:, None], u.view(x.shape[0], x.shape[-1], self.rank))
     xuv = torch.bmm(xu, v.view(x.shape[0], self.rank, -1))
     return xuv[:, 0]
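 # (Note) A shape sketch of the rank-k factorization assumed above:
 #   x: (B, d_in), u: (B, d_in * rank), v: (B, rank * d_out)
 #   bmm((B, 1, d_in), (B, d_in, rank)) -> (B, 1, rank)
 #   bmm((B, 1, rank), (B, rank, d_out)) -> (B, 1, d_out); [:, 0] -> (B, d_out)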
 def gram_matrix(self, inp):
   b, c, h, w = inp.size()
   features = inp.view(b, c, h*w)
   G = torch.bmm(features, features.transpose(1, 2))
   return G.div(h*w)
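   # (Sketch, not part of the original snippet) A Gram matrix like this is typically
   # compared to a target Gram matrix with an MSE criterion, e.g. for a style loss:
   #   style_loss = torch.nn.functional.mse_loss(self.gram_matrix(inp), target_G)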
    def forward(self,
                text_inputs,
                mask_input,
                len_seq,
                len_sents,
                tid,
                mode=""):
        #
        if self.pad_level == "sent" or self.pad_level == "sentence":
            text_inputs = text_inputs.view(
                text_inputs.shape[0], self.max_num_sents * self.max_len_sent)
        mask = mask_input.view(text_inputs.shape)

        #
        encoder_out = self.base_encoder(text_inputs, mask_input, len_seq)

        # applying conv1d after rnn
        avg_pooled = torch.zeros(text_inputs.shape[0], text_inputs.shape[1],
                                 self.conv_output_size)
        avg_pooled = utils.cast_type(avg_pooled, FLOAT, self.use_gpu)
        for cur_batch, cur_tensor in enumerate(encoder_out):
            ## Actual length version
            if self.target_model == "conll17_al":
                cur_seq_len = int(len_seq[cur_batch])
                cur_tensor = cur_tensor.unsqueeze(0)
                crop_tensor = cur_tensor.narrow(1, 0, cur_seq_len)
                crop_tensor = crop_tensor.transpose(1, 0)
                cur_tensor = crop_tensor
            ## published version: do not consider actual length
            else:
                cur_tensor = cur_tensor.unsqueeze(1)

            # applying conv
            cur_tensor = self.conv(cur_tensor)
            cur_tensor = self.leak_relu(cur_tensor)
            cur_tensor = self.dropout_layer(cur_tensor)
            # cur_tensor = self.avg_pool_1d(cur_tensor)
            cur_tensor = self.avg_adapt_pool1(cur_tensor)
            cur_tensor = cur_tensor.view(cur_tensor.shape[0],
                                         self.conv_output_size)
            avg_pooled[cur_batch, :cur_tensor.shape[0], :] = cur_tensor

        len_seq = utils.cast_type(len_seq, FLOAT, self.use_gpu)

        ## implement attention by parameters
        context_weight = self.context_weight.unsqueeze(1)
        context_weight = context_weight.expand(text_inputs.shape[0],
                                               self.conv_output_size, 1)
        attn_weight = torch.bmm(avg_pooled, context_weight).squeeze(2)
        attn_weight = self.tanh(attn_weight)
        attn_weight = self.softmax(attn_weight)
        # attention applied
        attn_vec = torch.bmm(avg_pooled.transpose(1, 2),
                             attn_weight.unsqueeze(2))

        ilc_vec = attn_vec.squeeze(2)

        ## implement attention by linear
        #attn_vec = self.attn(encoder_out.view(self.batch_size, -1)).unsqueeze(2)
        #attn_vec = self.softmax(attn_vec)
        #ilc_vec_attn = torch.bmm(encoder_out.transpose(1, 2), attn_vec).squeeze(2)

        ## FC

        # fully connected stage
        fc_out = self.linear_1(ilc_vec)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_2(fc_out)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_out(fc_out)
        if self.corpus_target.lower() == "asap":
            fc_out = self.sigmoid(fc_out)

        return fc_out
Ejemplo n.º 39
0
    def forward(self, inputs, predict_M, getAttention=False):
        """
        do the LSTM process
        :param inputs: the input sequcen
        :param pred_len: the length of the predict value
        :return: outputs:the ouput sequence
        """
        assert len(inputs.size()) == 3, \
            '[LSTM]: input dimension must be of length 3 i.e. [M*S*D]'

        # encoder
        inputs = self.bn(inputs)
        inputs, h, r_batch_size = self._prepare_input_h(inputs)
        c, h = self.encoder(inputs, h)

        #attention and decoder
        outs = []
        predict_M, _, _ = self._prepare_input_h(predict_M)

        # this must be in the outside of the for loop
        if not self.args.batch_first:
            c = c.permute(1, 0, 2)
        h_ = h  # we use the h from encoder h_ = [(args.encoder_num_layer*binum,r_batch_size,encoder_hidden_dim),
        #                                       (args.encoder_num_layer*binum,r_batch_size,encoder_hidden_dim)]

        if getAttention:
            attention = []
        for i in range(self.args.pred_len):
            #attention
            if not self.args.batch_first:
                h_temp = h_[0].permute(1, 0, 2)
            else:
                h_temp = h_[0]
            h_temp = h_temp.contiguous().view(r_batch_size, -1)
            weights = F.softmax(self.sequenceForAttention(h_temp),
                                dim=-1)  # [r_batch_size,args.seq_len]

            if getAttention:
                temp_weights = weights.squeeze()
                temp_weights = temp_weights.cpu().detach().numpy()
                attention.append(temp_weights.tolist())

            weights = torch.unsqueeze(weights, 1)
            out = torch.bmm(weights, c)  #[r_batch_size,1,encoder_hidden_dim]
            if not self.args.batch_first:
                out = out.permute(1, 0, -1)
            #decoder
            out, h_ = self.decoder(out, h_)
            outs.append(out)
        outs = torch.stack(
            outs, dim=0) if not self.args.batch_first else torch.stack(outs,
                                                                       dim=1)
        outs = torch.squeeze(
            outs, dim=2) if self.args.batch_first else torch.squeeze(outs,
                                                                     dim=1)
        outs = self.sequenceForOut(outs)

        if not self.args.batch_first:
            outs = outs.permute(1, 0, 2)
        if getAttention:
            return outs, attention
        else:
            return outs
Ejemplo n.º 40
0
 def forward(self, x0, x):
     x0xl = torch.bmm(x0.unsqueeze(-1), x.unsqueeze(-2))
     return torch.tensordot(x0xl, self.weights, [[-1], [0]]) + self.bias + x
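 # (Note) This looks like a DCN-style cross layer: assuming self.weights has shape (d,),
 # the update is roughly x_out = x0 * (x . w) + b + x, where x0 x^T is the batched
 # outer product built by the bmm above.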
Ejemplo n.º 41
0
 def forward(self, x, A1):
     Amatrix = torch.tril(A1)
     y = torch.bmm(Amatrix, x)
     return y
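 # (Note) torch.tril keeps only the lower triangle of A1, so row i of the output mixes
 # only inputs j <= i -- a causal (autoregressive) mixing pattern.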
Ejemplo n.º 42
0
 def forward(self, hidden, encoder_outputs, normalize=True):
     encoder_outputs = encoder_outputs.transpose(0, 1)  # [B,T,H]
     attn_energies = self.score(hidden, encoder_outputs)
     normalized_energy = F.softmax(attn_energies, dim=2)  # [B,1,T]
     context = torch.bmm(normalized_energy, encoder_outputs)  # [B,1,H]
     return context.transpose(0, 1)  # [1,B,H]
 def score(self, hidden, encoder_outputs):
     energy = F.tanh(self.attn(torch.cat([hidden, encoder_outputs], 2))) # [B*T*2H]->[B*T*H]
     energy = energy.transpose(2,1) # [B*H*T]
     v = self.v.repeat(encoder_outputs.data.shape[0],1).unsqueeze(1) #[B*1*H]
     energy = torch.bmm(v,energy) # [B*1*T]
     return energy.squeeze(1) #[B*T]
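     # (Note) This is concat/additive attention: energy_i = v^T tanh(W [h ; enc_i]),
     # where `self.v` is assumed to be a learned parameter of shape [H].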
Ejemplo n.º 44
0
 def forward(self, input):
     b, c, h, w = input.size()
     F = input.view(b, c, h * w)
     G = torch.bmm(F, F.transpose(1, 2))
     G.div_(h * w)
     return G
Ejemplo n.º 45
0
 def _layer(self, t, x, weight, bias):
     # weights is (batch, in_dim + 1, out_dim)
     ttx = self._pack_inputs(t, x)  # (batch, in_dim + 1)
     ttx = ttx.view(ttx.size(0), 1, ttx.size(1))  # (batch, 1, in_dim + 1)
     xw = torch.bmm(ttx, weight)[:, 0, :]  # (batch, out_dim)
     return xw + bias
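     # (Note) `_pack_inputs` is assumed to append the scalar time t to x, so the bmm is
     # a per-sample affine map with batched weights (e.g. produced by a hypernetwork):
     #   out[b] = [x[b], t] @ weight[b] + bias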
    def forward(self, contents, question_ans, logics, contents_char=None, question_ans_char=None):

        # assert contents_char is not None and question_ans_char is not None
        batch_size = question_ans.size()[0]
        max_content_len = contents.size()[2]
        max_question_len = question_ans.size()[1]
        contents_num = contents.size()[1]
        # word-level embedding: (seq_len, batch, embedding_size)
        content_vec = []
        content_mask = []
        question_vec, question_mask = self.embedding.forward(question_ans)
        for i in range(contents_num):
            cur_content = contents[:, i, :]
            cur_content_vec, cur_content_mask = self.embedding.forward(cur_content)
            content_vec.append(cur_content_vec)
            content_mask.append(cur_content_mask)

        # char-level embedding: (seq_len, batch, char_embedding_size)
        # context_emb_char, context_char_mask = self.char_embedding.forward(context_char)
        # question_emb_char, question_char_mask = self.char_embedding.forward(question_char)
        question_encode, _ = self.context_layer.forward(question_vec,question_mask)  # size=(cur_batch_max_questionans_len, batch, 256)
        content_encode = []  # word-level encode: (seq_len, batch, hidden_size)
        for i in range(contents_num):
            cur_content_vec = content_vec[i]
            cur_content_mask = content_mask[i]
            cur_content_encode, _ = self.context_layer.forward(cur_content_vec,cur_content_mask)  # size=(cur_batch_max_content_len, batch, 256)
            content_encode.append(cur_content_encode)

        # Pad all content encodings to the same length (200) and all question encodings to the same length (100)
        same_sized_content_encode = []
        for i in range(contents_num):
            cur_content_encode = content_encode[i]
            cur_content_encode = self.full_matrix_to_specify_size(cur_content_encode, [max_content_len, batch_size,cur_content_encode.size()[2]])  # size=(200,16,256)
            same_sized_content_encode.append(cur_content_encode)
        same_sized_question_encode = self.full_matrix_to_specify_size(question_encode, [max_question_len, batch_size,question_encode.size()[2]])  # size=(100,16,256)


        # Compute the gating-layer values
        reasoning_content_gating_val = []
        reasoning_question_gating_val = None
        decision_content_gating_val = []
        decision_question_gating_val = None
        for i in range(contents_num):
            cur_content_encode = same_sized_content_encode[i]  # size=(200,16,256)
            cur_gating_input = cur_content_encode.permute(1,2,0)  # size=(16,256,200)
            cur_reasoning_content_gating_val = self.reasoning_gating_layer(cur_gating_input)  # size=(16,1,200)
            cur_reasoning_content_gating_val = cur_reasoning_content_gating_val + 0.00001  # avoid gates of exactly 0, which would break pad_sequence later

            cur_decision_content_gating_val = self.decision_gating_layer(cur_gating_input)  # size=(16,1,200)
            cur_decision_content_gating_val = cur_decision_content_gating_val + 0.00001  # avoid gates of exactly 0, which would break pad_sequence later
            reasoning_content_gating_val.append(cur_reasoning_content_gating_val)
            decision_content_gating_val.append(cur_decision_content_gating_val)

        question_gating_input = same_sized_question_encode.permute(1,2,0)  # size=(16,256,100)
        reasoning_question_gating_val = self.reasoning_gating_layer(question_gating_input)  # size=(16,1,100)
        reasoning_question_gating_val = reasoning_question_gating_val + 0.00001  # avoid gates of exactly 0, which would break pad_sequence later
        decision_question_gating_val = self.decision_gating_layer(question_gating_input)  # size=(16,1,100)
        decision_question_gating_val = decision_question_gating_val + 0.00001  # avoid gates of exactly 0, which would break pad_sequence later


        # Compute the gate loss. TODO: it seems multiple values cannot be returned here, so this is unused for now
        # question_gate_val = torch.cat([reasoning_question_gating_val.view(-1), decision_question_gating_val.view(-1)])
        # reasoning_gate_val = torch.cat([ele.view(-1) for ele in reasoning_content_gating_val])
        # decision_gate_val = torch.cat([ele.view(-1) for ele in decision_content_gating_val])
        # all_gate_val = torch.cat([question_gate_val, reasoning_gate_val, decision_gate_val])
        # mean_gate_val = torch.mean(all_gate_val)


        # Matching-matrix computation: a matching matrix is computed between the question and every content
        Matching_matrix = []
        for i in range(contents_num):
            cur_content_encode = same_sized_content_encode[i]
            cur_Matching_matrix = self.compute_matching_matrix(same_sized_question_encode,
                                                               cur_content_encode)  # (batch, question_len , content_len) eg(16,100,200)
            Matching_matrix.append(cur_Matching_matrix)

        # compute an & bn
        an_matrix = []
        bn_matrix = []
        for i in range(contents_num):
            cur_Matching_matrix = Matching_matrix[i]
            cur_an_matrix = torch.nn.functional.softmax(cur_Matching_matrix, dim=2)  # softmax over dim=2: each row of the matching matrix sums to 1; size=(batch, question_len, content_len)
            cur_bn_matrix = torch.nn.functional.softmax(cur_Matching_matrix, dim=1)  # softmax over dim=1: each column of the matching matrix sums to 1; size=(batch, question_len, content_len)
            an_matrix.append(cur_an_matrix)
            bn_matrix.append(cur_bn_matrix)


        # compute RnQ & RnD
        RnQ = []  # list[tensor[16,100,256]]
        RnD = []
        for i in range(contents_num):
            cur_an_matrix = an_matrix[i]
            cur_content_encode = same_sized_content_encode[i]
            cur_bn_matrix = bn_matrix[i]
            cur_RnQ = self.compute_RnQ(cur_an_matrix, cur_content_encode)  # size=(batch, curbatch_max_question_len , 256)     eg[16,100,256]
            cur_RnD = self.compute_RnD(cur_bn_matrix,same_sized_question_encode)  # size=(batch, curbatch_max_content_len , 256)    eg[16,200,256]
            RnQ.append(cur_RnQ)
            RnD.append(cur_RnD)


        ########### compute Mmn' ##############
        D_RnD = []  # first build the concatenation of D and RnD
        for i in range(contents_num):
            cur_content_encode = same_sized_content_encode[i].transpose(0, 1)  # size=(16,200,256)
            cur_RnD = RnD[i]  # size=(16,200,256)
            # embed()
            cur_D_RnD = torch.cat([cur_content_encode, cur_RnD], dim=2)  # size=(16,200,512)
            D_RnD.append(cur_D_RnD)

        RmD = []  # list[tensor(16,200,512)]
        for i in range(contents_num):
            D_RnD_m = D_RnD[i]  # size=(16,200,512)
            Mmn_i=[]
            RmD_i = []
            for j in range(contents_num):
                D_RnD_n = D_RnD[j]  # size=(16,200,512)
                Mmn_i_j = self.compute_cross_document_attention(D_RnD_m, D_RnD_n)  # attention between every pair of documents; Mmn_i_j size=(16,200,200)
                Mmn_i.append(Mmn_i_j)

            Mmn_i=torch.stack(Mmn_i).permute(1,2,3,0)# size=(16,200,200,10)
            softmax_Mmn_i=self.reduce_softmax(Mmn_i) # size=(16,200,200,10)

            for j in range(contents_num):
                D_RnD_n = D_RnD[j]  # size=(16,200,512)
                beta_mn_i_j = softmax_Mmn_i[:,:,:,j]
                cur_RmD = torch.bmm(beta_mn_i_j, D_RnD_n)  # size=(16,200,512)
                RmD_i.append(cur_RmD)


            RmD_i = torch.stack(RmD_i)  # size=(10,16,200,512)
            RmD_i = RmD_i.transpose(0, 1)  # size=(16,10,200,512)
            RmD_i = torch.sum(RmD_i, dim=1)  # size=(16,200,512)
            RmD.append(RmD_i)

        # RmD=torch.stack(RmD).transpose(0,1) #size=(16,10,200,512)


        matching_feature_row = []  # list[tensor(16,200,2)]
        matching_feature_col = []  # list[tensor(16,100,2)]
        for i in range(contents_num):
            cur_Matching_matrix = Matching_matrix[i]  # size=(16,100,200)
            cur_max_pooling_feature_row, _ = torch.max(cur_Matching_matrix, dim=1)  # size=(16,200)
            cur_mean_pooling_feature_row = torch.mean(cur_Matching_matrix, dim=1)  # size=(16,200)
            cur_matching_feature_row = torch.stack([cur_max_pooling_feature_row, cur_mean_pooling_feature_row]).permute(1,2,0)  # size=(16,200,2)
            matching_feature_row.append(cur_matching_feature_row)

            cur_max_pooling_feature_col, _ = torch.max(cur_Matching_matrix, dim=2)  # size=(16,100)
            cur_mean_pooling_feature_col = torch.mean(cur_Matching_matrix, dim=2)  # size=(16,100)
            cur_matching_feature_col = torch.stack([cur_max_pooling_feature_col, cur_mean_pooling_feature_col]).permute(1,2,0)  # size=(16,100,2)
            matching_feature_col.append(cur_matching_feature_col)
        # print(253)
        # embed()
        reasoning_feature = []
        RnQ_reasoning_out=[]
        RmD_reasoning_out=[]
        for i in range(contents_num):
            cur_RnQ = RnQ[i]  # size=(16,100,256)
            cur_RmD = RmD[i]  # size=(16,200,512)
            cur_matching_feature_col = matching_feature_col[i]  # size=(16,100,2)
            cur_matching_feature_row = matching_feature_row[i]  # size=(16,200,2)

            cur_RnQ = torch.cat([cur_RnQ, cur_matching_feature_col], dim=2)  # size=(16,100,258)
            cur_RmD = torch.cat([cur_RmD, cur_matching_feature_row], dim=2)  # size=(16,200,514)

            cur_RnQ_mask = compute_mask(cur_RnQ.mean(dim=2), PreprocessData.padding_idx)
            cur_RmD_mask = compute_mask(cur_RmD.mean(dim=2), PreprocessData.padding_idx)

            gated_cur_RnQ=self.compute_gated_value(cur_RnQ,reasoning_question_gating_val)# size=(16,100,258)
            gated_cur_RmD=self.compute_gated_value(cur_RmD,reasoning_content_gating_val[i])# size=(16,200,514)

            # pass through the reasoning layer
            cur_RnQ_reasoning_out, _ = self.question_reasoning_layer.forward(gated_cur_RnQ.transpose(0,1),cur_RnQ_mask)  # size=(max_sequence_len,16,256)
            cur_RmD_reasoning_out, _ = self.content_reasoning_layer.forward(gated_cur_RmD.transpose(0,1),cur_RmD_mask)  # size=(max_sequence_len,16,256)

            # pad all matrices to the same size
            cur_RnQ_reasoning_out = self.full_matrix_to_specify_size(cur_RnQ_reasoning_out,
                                                                     [max_question_len, batch_size,
                                                                      cur_RnQ_reasoning_out.size()[2]])  # size=(100,16,256)
            cur_RmD_reasoning_out = self.full_matrix_to_specify_size(cur_RmD_reasoning_out,
                                                                     [max_content_len, batch_size,
                                                                      cur_RmD_reasoning_out.size()[2]])  # size=(200,16,256)

            # gate the inputs to the decision layer
            cur_RnQ_reasoning_out=cur_RnQ_reasoning_out.transpose(0,1) #size(16,100,256)
            cur_RmD_reasoning_out=cur_RmD_reasoning_out.transpose(0,1) #size(16,200,256)

            RnQ_reasoning_out.append(cur_RnQ_reasoning_out)
            RmD_reasoning_out.append(cur_RmD_reasoning_out)

            # gated_RnQ_out=self.compute_gated_value(cur_RnQ_reasoning_out,decision_question_gating_val)#size(16,100,256)
            # gated_RmD_out=self.compute_gated_value(cur_RmD_reasoning_out,decision_content_gating_val[i])#size(16,200,256)
            #
            # # concatenate the two kinds of features into a 300*256 representation
            # cur_reasoning_feature = torch.cat([gated_RnQ_out, gated_RmD_out], dim=1)  # size(16,300,256) ||  when content=100 size(16,200,256)
            # reasoning_feature.append(cur_reasoning_feature)

        # concatenate the results from the 10 documents
        RnQ_reasoning_out=torch.stack(RnQ_reasoning_out).transpose(0,1) #size=(16,10,100,256)
        RmD_reasoning_out=torch.stack(RmD_reasoning_out).transpose(0,1) #size=(16,10,200,256)

        RnQ_reasoning_out_maxpool,_=torch.max(RnQ_reasoning_out,dim=1) #size=(16,100,256)
        RmD_reasoning_out_maxpool,_=torch.max(RmD_reasoning_out,dim=1) #size=(16,200,256)

        # gated_RnQ_reasoning_out_maxpool=self.compute_gated_value(RnQ_reasoning_out_maxpool,decision_question_gating_val) #size(16,100,256)
        # gated_RmD_reasoning_out_maxpool=self.compute_gated_value(RmD_reasoning_out_maxpool,decision_content_gating_val)

        fc_input_RnQ_maxpool,_=torch.max(RnQ_reasoning_out_maxpool,dim=1) # size(16,256)
        fc_input_RnQ_meanpool=torch.mean(RnQ_reasoning_out_maxpool,dim=1) # size(16,256)

        fc_input_RmD_maxpool,_=torch.max(RmD_reasoning_out_maxpool,dim=1)#size(16,256)
        fc_input_RmD_meanpool=torch.mean(RmD_reasoning_out_maxpool,dim=1) #size(16,256)

        fc_input=torch.cat([fc_input_RnQ_maxpool,fc_input_RnQ_meanpool,fc_input_RmD_maxpool,fc_input_RmD_meanpool],dim=1) #size(16,1024)

        # reasoning_feature = torch.cat(reasoning_feature, dim=1)  # size=(16,3000,256)    |  when content=100 size(16,2000,256)
        # # print(299)
        # # embed()
        # maxpooling_reasoning_feature_column, _ = torch.max(reasoning_feature, dim=1)  # size(16,256)
        # meanpooling_reasoning_feature_column = torch.mean(reasoning_feature, dim=1)  # size(16,256)
        #
        # maxpooling_reasoning_feature_row, _ = torch.max(reasoning_feature, dim=2)  # size=(16,3000)    |  when content=100 size(16,2000)
        # meanpooling_reasoning_feature_row = torch.mean(reasoning_feature, dim=2)  # size=(16,3000)      |  when content=100 size(16,2000)
        # print(228, "============================")

        # pooling_reasoning_feature = torch.cat([maxpooling_reasoning_feature_row, meanpooling_reasoning_feature_row, maxpooling_reasoning_feature_column,meanpooling_reasoning_feature_column], dim=1)
        decision_input = fc_input.view(int(batch_size / 5), 5120)  # size=(16, 1024*5), 5-way classification
        #
        # print(312)
        # embed()
        output = self.decision_layer.forward(decision_input)  # size=(batchsize/5,5)



        # temp_gate_val=torch.stack([mean_gate_val,torch.tensor(0.0).to(self.device)]).resize_(1,2)
        # output_with_gate_val=torch.cat([output,temp_gate_val],dim=0)
        # logics=logics.resize_(logics.size()[0],1)
        return output  # logics: multiply by -1 if the polarity is reversed, by 1 if forward
Ejemplo n.º 47
0
    def train(self):

        num_ways_train = self.config['num_ways']
        num_shots_train = self.config['num_shots']
        support_shot_nums = num_shots_train
        query_shot_nums = 1
        meta_batch_size = self.config['num_batch']
        num_layers = self.config['num_layers']

        train_iteration = self.config['train_iteration']
        lr = 1e-3
        test_interval = 150
        predicate_detection = self.config['predicate_detection']

        feature_dim = self.config['feature_dim']

        val_acc = self.val_acc


        num_supports = num_ways_train * num_shots_train
        num_queries = num_ways_train * 1
        num_samples = num_supports + num_queries


        support_edge_mask = torch.zeros(meta_batch_size, num_samples, num_samples).cuda()
        support_edge_mask[:, :num_supports, :num_supports] = 1
        query_edge_mask = 1 - support_edge_mask
        evaluation_mask = torch.ones(meta_batch_size, num_samples, num_samples).cuda()

        object_support_edge_mask = torch.zeros(meta_batch_size, 2 * num_samples, 2 * num_samples).cuda()
        object_query_edge_mask = 1 - object_support_edge_mask
        object_evaluation_mask = torch.ones(meta_batch_size, 2 * num_samples, 2 * num_samples).cuda()


        for iter in range(self.global_step + 1, train_iteration + 1):

            self.optimizer.zero_grad()

            self.global_step = iter

            support_all_input, query_all_input, os_label, [support_label, query_label], [idx_for_class,
                                                                                         idx_for_data] = self.data_loader.get_task_batch(
                num_tasks=meta_batch_size,
                num_ways=num_ways_train,
                num_shots=num_shots_train,
                seed=iter)

            os_support_full_label = torch.cat(
                [os_label[0].view(meta_batch_size, -1).cuda(), os_label[1].view(meta_batch_size, -1).cuda()], 1)
            os_query_full_label = torch.cat(
                [os_label[2].view(meta_batch_size, -1).cuda(), os_label[3].view(meta_batch_size, -1).cuda()], 1)

            os_full_label = torch.cat([os_support_full_label, os_query_full_label], 1)
            os_full_edge = self.label2edge(os_full_label)

            support_label = support_label.cuda()
            query_label = query_label.cuda()

            full_label = torch.cat([support_label, query_label], 1)
            full_edge = self.label2edge(full_label)

            init_edge = full_edge.clone()
            init_edge[:, :, num_supports:, :] = 0.5
            init_edge[:, :, :, num_supports:] = 0.5
            for i in range(num_queries):
                init_edge[:, 0, num_supports + i, num_supports + i] = 1.0
                init_edge[:, 1, num_supports + i, num_supports + i] = 0.0

            self.enc_module.train()
            self.gnn_module.train()


            support_predicate_input = support_all_input[1].view(meta_batch_size * num_ways_train, 1, num_shots_train,
                                                                -1).cuda()
            query_predicate_input = query_all_input[1].view(meta_batch_size * num_ways_train, 1, 1, -1).cuda()




            support_subject_emb = support_all_input[4].cuda()
            support_object_emb = support_all_input[5].cuda()
            query_subject_emb = query_all_input[4].cuda()
            query_object_emb = query_all_input[5].cuda()


            all_support_full_data, support_mapping_subj_feature, support_mapping_obj_feature = self.enc_module(
                support_all_input[0].cuda(),
                support_predicate_input,
                support_all_input[2].cuda(),
                support_all_input[3].cuda(),
                support_subject_emb, support_object_emb, support_shot_nums)

            all_query_full_data, query_mapping_subj_feature, query_mapping_obj_feature = self.enc_module(
                query_all_input[0].cuda(),
                query_predicate_input,
                query_all_input[2].cuda(),
                query_all_input[3].cuda(),
                query_subject_emb, query_object_emb, query_shot_nums)

            all_support_full_data = all_support_full_data.view(meta_batch_size, num_supports, feature_dim)
            all_query_full_data = all_query_full_data.view(meta_batch_size, num_queries, feature_dim)

            full_data = torch.cat([all_support_full_data, all_query_full_data], 1)  #

            support_subject_input = support_mapping_subj_feature.view(meta_batch_size, num_supports, feature_dim).cuda()
            support_object_input = support_mapping_obj_feature.view(meta_batch_size, num_supports, feature_dim).cuda()
            support_full_data = torch.cat([support_subject_input, support_object_input], 1)

            query_subject_input = query_mapping_subj_feature.view(meta_batch_size, num_queries, feature_dim).cuda()
            query_object_input = query_mapping_obj_feature.view(meta_batch_size, num_queries, feature_dim).cuda()
            query_full_data = torch.cat([query_subject_input, query_object_input], 1)

            feature_full_data = torch.cat([support_full_data, query_full_data], 1)

            full_logit_layers, object_full_logit_layers, object_out = self.gnn_module(node_feat=full_data,
                                                                                      edge_feat=init_edge,
                                                                                      object_node_feat=feature_full_data,
                                                                                      object_edge_feat=os_full_edge)

            support_subject_output_label = object_out[0]
            support_object_output_label = object_out[1]
            query_subject_output_label = object_out[2]
            query_object_output_label = object_out[3]

            subject_support_loss = self.criterion(object_out[0], os_label[0].cuda().long())
            object_support_loss = self.criterion(object_out[1], os_label[1].cuda().long())
            subject_query_loss = self.criterion(object_out[2], os_label[2].cuda().long())

            object_query_loss = self.criterion(object_out[3], os_label[3].cuda().long())

            object_class_loss = (subject_support_loss + object_support_loss +
                                 subject_query_loss + object_query_loss) / 4

            full_edge_loss_layers = [self.edge_loss((1 - full_logit_layer[:, 0]), (1 - full_edge[:, 0])) for
                                     full_logit_layer in full_logit_layers]

            object_full_edge_loss_layers = [self.edge_loss((1 - full_logit_layer[:, 0]), (1 - os_full_edge[:, 0])) for
                                            full_logit_layer in object_full_logit_layers]


            object_pos_query_edge_loss_layers = [torch.sum(
                full_edge_loss_layer * object_query_edge_mask * os_full_edge[:,
                                                                0] * object_evaluation_mask) / torch.sum(
                object_query_edge_mask * os_full_edge[:, 0] * object_evaluation_mask) for full_edge_loss_layer in
                                                 object_full_edge_loss_layers]
            object_neg_query_edge_loss_layers = [torch.sum(full_edge_loss_layer * object_query_edge_mask * (
                        1 - os_full_edge[:, 0]) * object_evaluation_mask) / torch.sum(
                object_query_edge_mask * (1 - os_full_edge[:, 0])) for full_edge_loss_layer in
                                                 object_full_edge_loss_layers]

            object_query_edge_loss_layers = [pos_query_edge_loss_layer + neg_query_edge_loss_layer for
                                             (pos_query_edge_loss_layer, neg_query_edge_loss_layer) in
                                             zip(object_pos_query_edge_loss_layers, object_neg_query_edge_loss_layers)]

            pos_query_edge_loss_layers = [
                torch.sum(full_edge_loss_layer * query_edge_mask * full_edge[:, 0] * evaluation_mask) / torch.sum(
                    query_edge_mask * full_edge[:, 0]) for full_edge_loss_layer in full_edge_loss_layers]
            neg_query_edge_loss_layers = [
                torch.sum(full_edge_loss_layer * query_edge_mask * (1 - full_edge[:, 0]) * evaluation_mask) / torch.sum(
                    query_edge_mask * (1 - full_edge[:, 0]) * evaluation_mask) for full_edge_loss_layer in
                full_edge_loss_layers]
            query_edge_loss_layers = [pos_query_edge_loss_layer + neg_query_edge_loss_layer for
                                      (pos_query_edge_loss_layer, neg_query_edge_loss_layer) in
                                      zip(pos_query_edge_loss_layers, neg_query_edge_loss_layers)]


            full_edge_accr_layers = [self.hit(full_logit_layer, 1 - full_edge[:, 0].long()) for full_logit_layer in
                                     full_logit_layers]
            query_edge_accr_layers = [torch.sum(full_edge_accr_layer * query_edge_mask * evaluation_mask) / torch.sum(
                query_edge_mask * evaluation_mask) for full_edge_accr_layer in full_edge_accr_layers]


            query_node_pred_layers = [torch.bmm(full_logit_layer[:, 0, num_supports:, :num_supports],
                                                self.one_hot_encode(num_ways_train, support_label.long())) for
                                      full_logit_layer in
                                      full_logit_layers]
            query_node_accr_layers = [
                torch.eq(torch.max(query_node_pred_layer, -1)[1], query_label.long()).float().mean() for
                query_node_pred_layer in query_node_pred_layers]


            query_nodes_acc = query_node_accr_layers[-1]


            total_loss_layers = query_edge_loss_layers

            total_loss = []
            for l in range(num_layers - 1):
                total_loss += [total_loss_layers[l].view(-1) * 0.5]
            total_loss += [total_loss_layers[-1].view(-1) * 1.0]
            total_loss = torch.mean(torch.cat(total_loss, 0))

            object_total_loss_layers = object_query_edge_loss_layers
            object_total_loss = []
            for l in range(num_layers - 1):
                object_total_loss += [object_total_loss_layers[l].view(-1) * 0.5]
            object_total_loss += [object_total_loss_layers[-1].view(-1) * 1.0]
            object_total_loss = torch.mean(torch.cat(object_total_loss, 0))


            total_loss = total_loss + object_total_loss + 1 * object_class_loss

            total_loss.backward()

            self.optimizer.step()

            print('train/edge_loss ' + str(query_edge_loss_layers[-1]))

            print("\t train/edge_accr " + str(query_edge_accr_layers[-1]))
            print("\t train/node_accr " + str(query_nodes_acc) + "\n")


            if self.global_step % test_interval == 0:

                val_acc = self.test(partition='val')

                is_best = 0

                if val_acc >= self.val_acc:
                    test_acc = self.test(partition='test')
                    self.test_acc = test_acc
                    self.val_acc = val_acc
                    is_best = 1

                print("val/best_accr"+ str(self.val_acc))
                print("\t val/best_accr" + str(self.test_acc)+"\n" )
 def compute_matching_matrix(self, question_encode, content_encode):
     question_encode_trans = question_encode.transpose(0, 1)  # (batch, seq_len, embedding_size)
     content_encode_trans = content_encode.transpose(0, 1)  # (batch, seq_len, embedding_size)
     content_encode_trans = content_encode_trans.transpose(1, 2)  # (batch, embedding_size, seq_len)
     Matching_matrix = torch.bmm(question_encode_trans, content_encode_trans)  # (batch, question_len , content_len)
     return Matching_matrix
 def compute_cross_document_attention(self, content_m, content_n):
     content_n_trans = content_n.transpose(1, 2)  # (batch, 512, 200)
     Matching_matrix = torch.bmm(content_m, content_n_trans)  # (batch, question_len , content_len)
     return Matching_matrix
Ejemplo n.º 50
0
def multi_head_attention_forward(
    query,  # type: Tensor
    key,  # type: Tensor
    value,  # type: Tensor
    embed_dim_to_check,  # type: int
    num_heads,  # type: int
    in_proj_weight,  # type: Tensor
    in_proj_bias,  # type: Tensor
    bias_k,  # type: Optional[Tensor]
    bias_v,  # type: Optional[Tensor]
    add_zero_attn,  # type: bool
    dropout_p,  # type: float
    out_proj_weight,  # type: Tensor
    out_proj_bias,  # type: Tensor
    training=True,  # type: bool
    key_padding_mask=None,  # type: Optional[Tensor]
    need_weights=True,  # type: bool
    attn_mask=None,  # type: Optional[Tensor]
    use_separate_proj_weight=False,  # type: bool
    q_proj_weight=None,  # type: Optional[Tensor]
    k_proj_weight=None,  # type: Optional[Tensor]
    v_proj_weight=None,  # type: Optional[Tensor]
    static_k=None,  # type: Optional[Tensor]
    static_v=None,  # type: Optional[Tensor]
    relative_atten_weights=None,  # type: Tensor
    app_relation=True,
):
    # type: (...) -> Tuple[Tensor, Optional[Tensor]]
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an
            output. See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at
            dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key
            will be ignored by the attention. This is a binary mask. When the
            value is True, the corresponding value on the attention layer will
            be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: mask that prevents attention to certain positions.
            This is an additive mask
            (i.e. the values will be added to the attention layer).
        use_separate_proj_weight: the function accepts separate projection
            weights for query, key, and value. If false, in_proj_weight
            will be used, which is a concatenation of q_proj_weight,
            k_proj_weight, and v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input
            projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
        relative_atten_weights: relative (e.g. geometric/positional) attention
            weights, added to the attention logits before the softmax.
    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length,
            N is the batch size, E is the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length,
            N is the batch size, E is the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length,
            N is the batch size, E is the embedding dimension.
        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch
            size, S is the source sequence length.
        - attn_mask: :math:`(L, S)` where L is the target sequence length,
            S is the source sequence length.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the
            source sequence length, N is the batch size, E is the embedding
            dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the
            source sequence length, N is the batch size, E is the embedding
            dimension. E/num_heads is the head dimension.
        - relative_atten_weights: :math:`(N, num_heads, L, S)`, where N is the
            batch size, L is the target sequence length, S is the source
            sequence length.
        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length,
            N is the batch size, E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """

    qkv_same = torch.equal(query, key) and torch.equal(key, value)
    kv_same = torch.equal(key, value)

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    assert list(query.size()) == [tgt_len, bsz, embed_dim]
    assert key.size() == value.size()

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, \
        'embed_dim must be divisible by num_heads'
    scaling = float(head_dim)**-0.5

    if use_separate_proj_weight is not True:
        if qkv_same:
            # self-attention
            q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(
                3, dim=-1)

        elif kv_same:
            # encoder-decoder attention
            # This is inline in_proj function
            # with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:
                # This is inline in_proj function
                # with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = F.linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function
            # with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            # This is inline in_proj function
            # with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = F.linear(key, _w, _b)

            # This is inline in_proj function
            # with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = F.linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = F.linear(query, q_proj_weight_non_opt,
                         in_proj_bias[0:embed_dim])
            k = F.linear(key, k_proj_weight_non_opt,
                         in_proj_bias[embed_dim:(embed_dim * 2)])
            v = F.linear(value, v_proj_weight_non_opt,
                         in_proj_bias[(embed_dim * 2):])
        else:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [
                        attn_mask,
                        torch.zeros(
                            (attn_mask.size(0), 1),
                            dtype=attn_mask.dtype,
                            device=attn_mask.device,
                        ),
                    ],
                    dim=1,
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(
                            (key_padding_mask.size(0), 1),
                            dtype=key_padding_mask.dtype,
                            device=key_padding_mask.device,
                        ),
                    ],
                    dim=1,
                )
        else:
            assert static_k is None, 'bias cannot be added to static key.'
            assert static_v is None, 'bias cannot be added to static value.'
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat(
            [
                k,
                torch.zeros(
                    (k.size(0), 1) + k.size()[2:],
                    dtype=k.dtype,
                    device=k.device),
            ],
            dim=1,
        )
        v = torch.cat(
            [
                v,
                torch.zeros(
                    (v.size(0), 1) + v.size()[2:],
                    dtype=v.dtype,
                    device=v.device),
            ],
            dim=1,
        )
        if attn_mask is not None:
            attn_mask = torch.cat(
                [
                    attn_mask,
                    torch.zeros(
                        (attn_mask.size(0), 1),
                        dtype=attn_mask.dtype,
                        device=attn_mask.device,
                    ),
                ],
                dim=1,
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    torch.zeros(
                        (key_padding_mask.size(0), 1),
                        dtype=key_padding_mask.dtype,
                        device=key_padding_mask.device,
                    ),
                ],
                dim=1,
            )
    if app_relation:
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(
            attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
    else:
        attn_output_weights = None
    if relative_atten_weights is not None:
        if app_relation:
            attn_output_weights = attn_output_weights.view(
                bsz, num_heads, tgt_len, src_len)
        else:
            attn_output_weights = 0
        attn_output_weights += relative_atten_weights
        attn_output_weights = attn_output_weights.reshape(
            bsz * num_heads, tgt_len, src_len)

    assert attn_output_weights is not None, \
        ('Please either specify relative position relation '
         'or appearance relation.')

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
                                                       src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads,
                                                       tgt_len, src_len)

    attn_output_weights = F.softmax(attn_output_weights, dim=-1)
    attn_output_weights = F.dropout(
        attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(
        tgt_len, bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
                                                       src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
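A minimal call sketch for the function above, assuming it is in scope together with its torch / torch.nn.functional imports; all tensors and sizes here are made up purely to show the expected shapes:

import torch

L, S, N, E, heads = 4, 6, 2, 16, 4
query = torch.randn(L, N, E)
key = torch.randn(S, N, E)
value = torch.randn(S, N, E)
in_proj_weight = torch.randn(3 * E, E)
in_proj_bias = torch.zeros(3 * E)
out_proj_weight = torch.randn(E, E)
out_proj_bias = torch.zeros(E)

attn_output, attn_weights = multi_head_attention_forward(
    query, key, value, E, heads,
    in_proj_weight, in_proj_bias,
    None, None,   # bias_k, bias_v
    False,        # add_zero_attn
    0.0,          # dropout_p
    out_proj_weight, out_proj_bias,
    training=False)
# attn_output: (L, N, E); attn_weights: (N, L, S), averaged over heads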
Ejemplo n.º 51
0
 def matrix_power3(self, Input):
     B = torch.bmm(Input, Input)
     return torch.bmm(B, Input)
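For the batched cube above, a quick sanity check against torch.linalg.matrix_power (available in recent PyTorch releases, which also batches over leading dimensions):

import torch

Input = torch.randn(8, 5, 5)  # batch of square matrices
cube = torch.bmm(torch.bmm(Input, Input), Input)
assert torch.allclose(cube, torch.linalg.matrix_power(Input, 3), atol=1e-4)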
Ejemplo n.º 52
0
    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        number_indices: torch.LongTensor,
        answer_as_passage_spans: torch.LongTensor = None,
        answer_as_question_spans: torch.LongTensor = None,
        answer_as_add_sub_expressions: torch.LongTensor = None,
        answer_as_counts: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)
        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(
            self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(
            self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(
            embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(
            embedded_passage)

        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2),
            passage_mask,
            memory_efficient=True)

        # Shape: (batch_size, passage_length, passage_length)
        passage_attention_over_attention = torch.bmm(
            passage_question_attention, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(
            encoded_passage, passage_attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat(
                [
                    encoded_passage,
                    passage_question_vectors,
                    encoded_passage * passage_question_vectors,
                    encoded_passage * passage_passage_vectors,
                ],
                dim=-1,
            ))

        # The recurrent modeling layers. Since these layers share the same parameters,
        # we don't construct them conditioned on answering abilities.
        modeled_passage_list = [
            self._modeling_proj_layer(merged_passage_attention_vectors)
        ]
        for _ in range(4):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask))
            modeled_passage_list.append(modeled_passage)
        # Pop the first one, which is input
        modeled_passage_list.pop(0)

        # The first modeling layer is used to calculate the vector representation of passage
        passage_weights = self._passage_weights_predictor(
            modeled_passage_list[0]).squeeze(-1)
        passage_weights = masked_softmax(passage_weights, passage_mask)
        passage_vector = util.weighted_sum(modeled_passage_list[0],
                                           passage_weights)
        # The vector representation of question is calculated based on the unmatched encoding,
        # because we may want to infer the answer ability only based on the question words.
        question_weights = self._question_weights_predictor(
            encoded_question).squeeze(-1)
        question_weights = masked_softmax(question_weights, question_mask)
        question_vector = util.weighted_sum(encoded_question, question_weights)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = self._answer_ability_predictor(
                torch.cat([passage_vector, question_vector], -1))
            answer_ability_log_probs = torch.nn.functional.log_softmax(
                answer_ability_logits, -1)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "counting" in self.answering_abilities:
            # Shape: (batch_size, 10)
            count_number_logits = self._count_number_predictor(passage_vector)
            count_number_log_probs = torch.nn.functional.log_softmax(
                count_number_logits, -1)
            # Info about the best count number prediction
            # Shape: (batch_size,)
            best_count_number = torch.argmax(count_number_log_probs, -1)
            best_count_log_prob = torch.gather(
                count_number_log_probs, 1,
                best_count_number.unsqueeze(-1)).squeeze(-1)
            if len(self.answering_abilities) > 1:
                best_count_log_prob += answer_ability_log_probs[
                    :, self._counting_index]

        if "passage_span_extraction" in self.answering_abilities:
            # Shape: (batch_size, passage_length, modeling_dim * 2))
            passage_for_span_start = torch.cat(
                [modeled_passage_list[0], modeled_passage_list[1]], dim=-1)
            # Shape: (batch_size, passage_length)
            passage_span_start_logits = self._passage_span_start_predictor(
                passage_for_span_start).squeeze(-1)
            # Shape: (batch_size, passage_length, modeling_dim * 2)
            passage_for_span_end = torch.cat(
                [modeled_passage_list[0], modeled_passage_list[2]], dim=-1)
            # Shape: (batch_size, passage_length)
            passage_span_end_logits = self._passage_span_end_predictor(
                passage_for_span_end).squeeze(-1)
            # Shape: (batch_size, passage_length)
            passage_span_start_log_probs = util.masked_log_softmax(
                passage_span_start_logits, passage_mask)
            passage_span_end_log_probs = util.masked_log_softmax(
                passage_span_end_logits, passage_mask)

            # Info about the best passage span prediction
            passage_span_start_logits = replace_masked_values_with_big_negative_number(
                passage_span_start_logits, passage_mask)
            passage_span_end_logits = replace_masked_values_with_big_negative_number(
                passage_span_end_logits, passage_mask)
            # Shape: (batch_size, 2)
            best_passage_span = get_best_span(passage_span_start_logits,
                                              passage_span_end_logits)
            # Shape: (batch_size, 2)
            best_passage_start_log_probs = torch.gather(
                passage_span_start_log_probs, 1,
                best_passage_span[:, 0].unsqueeze(-1)).squeeze(-1)
            best_passage_end_log_probs = torch.gather(
                passage_span_end_log_probs, 1,
                best_passage_span[:, 1].unsqueeze(-1)).squeeze(-1)
            # Shape: (batch_size,)
            best_passage_span_log_prob = best_passage_start_log_probs + best_passage_end_log_probs
            if len(self.answering_abilities) > 1:
                best_passage_span_log_prob += answer_ability_log_probs[
                    :, self._passage_span_extraction_index]

        if "question_span_extraction" in self.answering_abilities:
            # Shape: (batch_size, question_length)
            encoded_question_for_span_prediction = torch.cat(
                [
                    encoded_question,
                    passage_vector.unsqueeze(1).repeat(
                        1, encoded_question.size(1), 1),
                ],
                -1,
            )
            question_span_start_logits = self._question_span_start_predictor(
                encoded_question_for_span_prediction).squeeze(-1)
            # Shape: (batch_size, question_length)
            question_span_end_logits = self._question_span_end_predictor(
                encoded_question_for_span_prediction).squeeze(-1)
            question_span_start_log_probs = util.masked_log_softmax(
                question_span_start_logits, question_mask)
            question_span_end_log_probs = util.masked_log_softmax(
                question_span_end_logits, question_mask)

            # Info about the best question span prediction
            question_span_start_logits = replace_masked_values_with_big_negative_number(
                question_span_start_logits, question_mask)
            question_span_end_logits = replace_masked_values_with_big_negative_number(
                question_span_end_logits, question_mask)
            # Shape: (batch_size, 2)
            best_question_span = get_best_span(question_span_start_logits,
                                               question_span_end_logits)
            # Shape: (batch_size, 2)
            best_question_start_log_probs = torch.gather(
                question_span_start_log_probs, 1,
                best_question_span[:, 0].unsqueeze(-1)).squeeze(-1)
            best_question_end_log_probs = torch.gather(
                question_span_end_log_probs, 1,
                best_question_span[:, 1].unsqueeze(-1)).squeeze(-1)
            # Shape: (batch_size,)
            best_question_span_log_prob = (best_question_start_log_probs +
                                           best_question_end_log_probs)
            if len(self.answering_abilities) > 1:
                best_question_span_log_prob += answer_ability_log_probs[
                    :, self._question_span_extraction_index]

        if "addition_subtraction" in self.answering_abilities:
            # Shape: (batch_size, # of numbers in the passage)
            number_indices = number_indices.squeeze(-1)
            number_mask = number_indices != -1
            clamped_number_indices = util.replace_masked_values(
                number_indices, number_mask, 0)
            encoded_passage_for_numbers = torch.cat(
                [modeled_passage_list[0], modeled_passage_list[3]], dim=-1)
            # Shape: (batch_size, # of numbers in the passage, encoding_dim)
            encoded_numbers = torch.gather(
                encoded_passage_for_numbers,
                1,
                clamped_number_indices.unsqueeze(-1).expand(
                    -1, -1, encoded_passage_for_numbers.size(-1)),
            )
            # Shape: (batch_size, # of numbers in the passage)
            encoded_numbers = torch.cat(
                [
                    encoded_numbers,
                    passage_vector.unsqueeze(1).repeat(
                        1, encoded_numbers.size(1), 1),
                ],
                -1,
            )

            # Shape: (batch_size, # of numbers in the passage, 3)
            number_sign_logits = self._number_sign_predictor(encoded_numbers)
            number_sign_log_probs = torch.nn.functional.log_softmax(
                number_sign_logits, -1)

            # Shape: (batch_size, # of numbers in passage).
            best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
            # For padding numbers, the best sign masked as 0 (not included).
            best_signs_for_numbers = util.replace_masked_values(
                best_signs_for_numbers, number_mask, 0)
            # Shape: (batch_size, # of numbers in passage)
            best_signs_log_probs = torch.gather(
                number_sign_log_probs, 2,
                best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
            # the probs of the masked positions should be 1 so that it will not affect the joint probability
            # TODO: this is not quite right, since if there are many numbers in the passage,
            # TODO: the joint probability would be very small.
            best_signs_log_probs = util.replace_masked_values(
                best_signs_log_probs, number_mask, 0)
            # Shape: (batch_size,)
            best_combination_log_prob = best_signs_log_probs.sum(-1)
            if len(self.answering_abilities) > 1:
                best_combination_log_prob += answer_ability_log_probs[
                    :, self._addition_subtraction_index]

        output_dict = {}

        # If answer is given, compute the loss.
        if (answer_as_passage_spans is not None
                or answer_as_question_spans is not None
                or answer_as_add_sub_expressions is not None
                or answer_as_counts is not None):

            log_marginal_likelihood_list = []

            for answering_ability in self.answering_abilities:
                if answering_ability == "passage_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
                    gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_passage_span_mask = gold_passage_span_starts != -1
                    clamped_gold_passage_span_starts = util.replace_masked_values(
                        gold_passage_span_starts, gold_passage_span_mask, 0)
                    clamped_gold_passage_span_ends = util.replace_masked_values(
                        gold_passage_span_ends, gold_passage_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_span_starts = torch.gather(
                        passage_span_start_log_probs, 1,
                        clamped_gold_passage_span_starts)
                    log_likelihood_for_passage_span_ends = torch.gather(
                        passage_span_end_log_probs, 1,
                        clamped_gold_passage_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_spans = (
                        log_likelihood_for_passage_span_starts +
                        log_likelihood_for_passage_span_ends)
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_passage_spans = (
                        replace_masked_values_with_big_negative_number(
                            log_likelihood_for_passage_spans,
                            gold_passage_span_mask,
                        ))
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_passage_span = util.logsumexp(
                        log_likelihood_for_passage_spans)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_passage_span)

                elif answering_ability == "question_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_question_span_starts = answer_as_question_spans[:, :, 0]
                    gold_question_span_ends = answer_as_question_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_question_span_mask = gold_question_span_starts != -1
                    clamped_gold_question_span_starts = util.replace_masked_values(
                        gold_question_span_starts, gold_question_span_mask, 0)
                    clamped_gold_question_span_ends = util.replace_masked_values(
                        gold_question_span_ends, gold_question_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_span_starts = torch.gather(
                        question_span_start_log_probs, 1,
                        clamped_gold_question_span_starts)
                    log_likelihood_for_question_span_ends = torch.gather(
                        question_span_end_log_probs, 1,
                        clamped_gold_question_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_spans = (
                        log_likelihood_for_question_span_starts +
                        log_likelihood_for_question_span_ends)
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_question_spans = (
                        replace_masked_values_with_big_negative_number(
                            log_likelihood_for_question_spans,
                            gold_question_span_mask,
                        ))
                    # Shape: (batch_size, )

                    log_marginal_likelihood_for_question_span = util.logsumexp(
                        log_likelihood_for_question_spans)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_question_span)

                elif answering_ability == "addition_subtraction":
                    # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                    # Shape: (batch_size, # of combinations)
                    gold_add_sub_mask = answer_as_add_sub_expressions.sum(
                        -1) > 0
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    gold_add_sub_signs = answer_as_add_sub_expressions.transpose(
                        1, 2)
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    log_likelihood_for_number_signs = torch.gather(
                        number_sign_log_probs, 2, gold_add_sub_signs)
                    # the log likelihood of the masked positions should be 0
                    # so that it will not affect the joint probability
                    log_likelihood_for_number_signs = util.replace_masked_values(
                        log_likelihood_for_number_signs,
                        number_mask.unsqueeze(-1), 0)
                    # Shape: (batch_size, # of combinations)
                    log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(
                        1)
                    # For those padded combinations, we set their log probabilities to be very small negative value
                    log_likelihood_for_add_subs = replace_masked_values_with_big_negative_number(
                        log_likelihood_for_add_subs, gold_add_sub_mask)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_add_sub = util.logsumexp(
                        log_likelihood_for_add_subs)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_add_sub)

                elif answering_ability == "counting":
                    # Count answers are padded with label -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    # Shape: (batch_size, # of count answers)
                    gold_count_mask = answer_as_counts != -1
                    # Shape: (batch_size, # of count answers)
                    clamped_gold_counts = util.replace_masked_values(
                        answer_as_counts, gold_count_mask, 0)
                    log_likelihood_for_counts = torch.gather(
                        count_number_log_probs, 1, clamped_gold_counts)
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_counts = replace_masked_values_with_big_negative_number(
                        log_likelihood_for_counts, gold_count_mask)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_count = util.logsumexp(
                        log_likelihood_for_counts)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_count)

                else:
                    raise ValueError(
                        f"Unsupported answering ability: {answering_ability}")

            if len(self.answering_abilities) > 1:
                # Add the ability probabilities if there are more than one abilities
                all_log_marginal_likelihoods = torch.stack(
                    log_marginal_likelihood_list, dim=-1)
                all_log_marginal_likelihoods = (all_log_marginal_likelihoods +
                                                answer_ability_log_probs)
                marginal_log_likelihood = util.logsumexp(
                    all_log_marginal_likelihoods)
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]

            output_dict["loss"] = -marginal_log_likelihood.mean()

        # Compute the metrics and add the tokenized input to the output.
        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])

                if len(self.answering_abilities) > 1:
                    predicted_ability_str = self.answering_abilities[
                        best_answer_ability[i].detach().cpu().numpy()]
                else:
                    predicted_ability_str = self.answering_abilities[0]

                answer_json: Dict[str, Any] = {}

                # We did not consider multi-mention answers here
                if predicted_ability_str == "passage_span_extraction":
                    answer_json["answer_type"] = "passage_span"
                    passage_str = metadata[i]["original_passage"]
                    offsets = metadata[i]["passage_token_offsets"]
                    predicted_span = tuple(
                        best_passage_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    predicted_answer = passage_str[start_offset:end_offset]
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = [(start_offset, end_offset)]
                elif predicted_ability_str == "question_span_extraction":
                    answer_json["answer_type"] = "question_span"
                    question_str = metadata[i]["original_question"]
                    offsets = metadata[i]["question_token_offsets"]
                    predicted_span = tuple(
                        best_question_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    predicted_answer = question_str[start_offset:end_offset]
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = [(start_offset, end_offset)]
                elif (predicted_ability_str == "addition_subtraction"
                      ):  # plus_minus combination answer
                    answer_json["answer_type"] = "arithmetic"
                    original_numbers = metadata[i]["original_numbers"]
                    sign_remap = {0: 0, 1: 1, 2: -1}
                    predicted_signs = [
                        sign_remap[it] for it in
                        best_signs_for_numbers[i].detach().cpu().numpy()
                    ]
                    result = sum([
                        sign * number for sign, number in zip(
                            predicted_signs, original_numbers)
                    ])
                    predicted_answer = str(result)
                    offsets = metadata[i]["passage_token_offsets"]
                    number_indices = metadata[i]["number_indices"]
                    number_positions = [
                        offsets[index] for index in number_indices
                    ]
                    answer_json["numbers"] = []
                    for offset, value, sign in zip(number_positions,
                                                   original_numbers,
                                                   predicted_signs):
                        answer_json["numbers"].append({
                            "span": offset,
                            "value": value,
                            "sign": sign
                        })
                    if number_indices[-1] == -1:
                        # There is a dummy 0 number at position -1 added in some cases; we are
                        # removing that here.
                        answer_json["numbers"].pop()
                    answer_json["value"] = result
                elif predicted_ability_str == "counting":
                    answer_json["answer_type"] = "count"
                    predicted_count = best_count_number[i].detach().cpu().numpy()
                    predicted_answer = str(predicted_count)
                    answer_json["count"] = predicted_count
                else:
                    raise ValueError(
                        f"Unsupported answer ability: {predicted_ability_str}")

                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(answer_json)
                answer_annotations = metadata[i].get("answer_annotations", [])
                if answer_annotations:
                    self._drop_metrics(predicted_answer, answer_annotations)
            # This is used for the demo.
            output_dict[
                "passage_question_attention"] = passage_question_attention
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict
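The span branches above marginalise the log-likelihood over several gold spans per example: padded spans (index -1) are clamped to 0 for the gather, pushed to a large negative value, and then combined with logsumexp. A condensed sketch of that pattern for start positions only, with invented sizes and plain torch in place of the AllenNLP util helpers:

import torch

B, T, S = 2, 50, 3                                 # batch, passage length, gold spans per example
start_log_probs = torch.log_softmax(torch.randn(B, T), dim=-1)
gold_starts = torch.randint(-1, T, (B, S))         # -1 marks padded gold spans
mask = gold_starts != -1
clamped = gold_starts.clamp(min=0)
ll = torch.gather(start_log_probs, 1, clamped)     # log-likelihood of each gold start
ll = ll.masked_fill(~mask, -1e7)                   # padded spans contribute ~nothing
marginal_ll = torch.logsumexp(ll, dim=-1)          # marginalise over all gold spans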
 def compute_RnQ(self, an_matrix, content_encode):
     content_encode_trans = content_encode.transpose(0, 1)  ## (batch, content_len, embedding_size)
     RnQ = torch.bmm(an_matrix, content_encode_trans)
     return RnQ
Ejemplo n.º 54
0
                processed['attention_mask']).to(device)
            token_type_ids = torch.tensor(
                processed['token_type_ids']).to(device)

            output = model(
                input_ids=input_ids.view(M, max_seq_length),
                attention_mask=attention_mask.view(M, max_seq_length),
                token_type_ids=token_type_ids.view(M, max_seq_length))

            start_logits = output.start_logits
            end_logits = output.end_logits

            start_prob = softmax(start_logits)
            end_prob = softmax(end_logits)

            span_prob = torch.bmm(start_prob.view(M, -1, 1),
                                  end_prob.view(M, 1, -1))
            span_prob = torch.triu(span_prob)

            # mask to limit the length of the span
            mask = torch.ones_like(span_prob)
            mask = torch.triu(mask, diagonal=max_answer_length)
            span_prob[mask == 1] = 0

            # mask out the question
            span_prob[token_type_ids == 0] = 0

            best_span = torch.argmax(span_prob.view(M, -1), dim=1)

            offset = processed['offset_mapping']
            new_prediction = [
                processed['backs'][idx]
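The snippet above scores every candidate answer span by taking the outer product of start and end probabilities with torch.bmm, keeping only start <= end via torch.triu, and capping the span length with a second triangular mask. A self-contained sketch of that trick with invented sizes:

import torch

M, T, max_answer_length = 2, 10, 5
start_prob = torch.softmax(torch.randn(M, T), dim=-1)
end_prob = torch.softmax(torch.randn(M, T), dim=-1)

span_prob = torch.bmm(start_prob.view(M, -1, 1), end_prob.view(M, 1, -1))  # (M, T, T)
span_prob = torch.triu(span_prob)                                          # keep start <= end
length_mask = torch.triu(torch.ones_like(span_prob), diagonal=max_answer_length)
span_prob[length_mask == 1] = 0                                            # drop over-long spans

best_span = torch.argmax(span_prob.view(M, -1), dim=1)
start_idx, end_idx = best_span // T, best_span % T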
 def compute_RnD(self, bn_matrix, question_encode):
     bn_matrix_trans = bn_matrix.transpose(1, 2)  # size=(batch,content_len,question_len)
     question_encode_trans = question_encode.transpose(0, 1)  # (batch, question_len, embedding_size)
     RnD = torch.bmm(bn_matrix_trans, question_encode_trans)
     return RnD
Ejemplo n.º 56
0
    def forward(self, u_enc_out, z_tm1, last_hidden, u_input_np, pv_z_enc_out,
                prev_z_input_np, u_emb, pv_z_emb):

        sparse_u_input = Variable(get_sparse_input_aug(u_input_np),
                                  requires_grad=False)

        if pv_z_enc_out is not None:
            context = self.attn_u(last_hidden,
                                  torch.cat([pv_z_enc_out, u_enc_out], dim=0))
        else:
            context = self.attn_u(last_hidden, u_enc_out)
        embed_z = self.emb(z_tm1)
        #embed_z = F.dropout(embed_z, self.dropout_rate)
        #embed_z = self.inp_dropout(embed_z)

        gru_in = torch.cat([embed_z, context], 2)
        gru_out, last_hidden = self.gru(gru_in, last_hidden)
        #gru_out = F.dropout(gru_out, self.dropout_rate)
        #gru_out = self.inp_dropout(gru_out)
        gen_score = self.proj(torch.cat([gru_out, context], 2)).squeeze(0)
        #gen_score = F.dropout(gen_score, self.dropout_rate)
        #gen_score = self.inp_dropout(gen_score)
        u_copy_score = F.tanh(self.proj_copy1(u_enc_out.transpose(
            0, 1)))  # [B,T,H]
        # stable version of copynet
        u_copy_score = torch.matmul(u_copy_score,
                                    gru_out.squeeze(0).unsqueeze(2)).squeeze(2)
        u_copy_score = u_copy_score.cpu()
        u_copy_score_max = torch.max(u_copy_score, dim=1, keepdim=True)[0]
        u_copy_score = torch.exp(u_copy_score - u_copy_score_max)  # [B,T]
        u_copy_score = torch.log(
            torch.bmm(u_copy_score.unsqueeze(1),
                      sparse_u_input)).squeeze(1) + u_copy_score_max  # [B,V]
        u_copy_score = cuda_(u_copy_score)
        if pv_z_enc_out is None:
            #u_copy_score = F.dropout(u_copy_score, self.dropout_rate)
            #u_copy_score = self.inp_dropout(u_copy_score)
            scores = F.softmax(torch.cat([gen_score, u_copy_score], dim=1),
                               dim=1)
            gen_score, u_copy_score = scores[:, :cfg.vocab_size], \
                                      scores[:, cfg.vocab_size:]
            proba = gen_score + u_copy_score[:, :cfg.vocab_size]  # [B,V]
            proba = torch.cat([proba, u_copy_score[:, cfg.vocab_size:]], 1)
        else:
            sparse_pv_z_input = Variable(get_sparse_input_aug(prev_z_input_np),
                                         requires_grad=False)
            pv_z_copy_score = F.tanh(
                self.proj_copy2(pv_z_enc_out.transpose(0, 1)))  # [B,T,H]
            pv_z_copy_score = torch.matmul(
                pv_z_copy_score,
                gru_out.squeeze(0).unsqueeze(2)).squeeze(2)
            pv_z_copy_score = pv_z_copy_score.cpu()
            pv_z_copy_score_max = torch.max(pv_z_copy_score,
                                            dim=1,
                                            keepdim=True)[0]
            pv_z_copy_score = torch.exp(pv_z_copy_score -
                                        pv_z_copy_score_max)  # [B,T]
            pv_z_copy_score = torch.log(
                torch.bmm(pv_z_copy_score.unsqueeze(1), sparse_pv_z_input)
            ).squeeze(1) + pv_z_copy_score_max  # [B,V]
            pv_z_copy_score = cuda_(pv_z_copy_score)
            scores = F.softmax(torch.cat(
                [gen_score, u_copy_score, pv_z_copy_score], dim=1),
                               dim=1)
            gen_score = scores[:, :cfg.vocab_size]
            u_copy_score = scores[:, cfg.vocab_size:2 * cfg.vocab_size + u_input_np.shape[0]]
            pv_z_copy_score = scores[:, 2 * cfg.vocab_size + u_input_np.shape[0]:]
            proba = gen_score + u_copy_score[:, :cfg.vocab_size] \
                + pv_z_copy_score[:, :cfg.vocab_size]  # [B,V]
            proba = torch.cat([
                proba, pv_z_copy_score[:, cfg.vocab_size:],
                u_copy_score[:, cfg.vocab_size:]
            ], 1)
        return gru_out, last_hidden, proba
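The "stable version of copynet" above avoids overflow by subtracting the per-row maximum before exponentiating, aggregating the scores into vocabulary space with an indicator matrix via torch.bmm, and adding the maximum back after the log. A standalone sketch of that stabilisation, using a dense indicator matrix for clarity (all names and sizes are illustrative):

import torch

B, T, V = 2, 7, 20
copy_score = torch.randn(B, T)            # unnormalised copy score per source token
token_to_vocab = torch.zeros(B, T, V)     # indicator: which vocab id each source token maps to
token_to_vocab[torch.arange(B).unsqueeze(1), torch.arange(T), torch.randint(0, V, (B, T))] = 1.0

score_max = torch.max(copy_score, dim=1, keepdim=True)[0]
stable = torch.exp(copy_score - score_max)                                  # exp in a safe range
vocab_copy_score = torch.log(
    torch.bmm(stable.unsqueeze(1), token_to_vocab)).squeeze(1) + score_max  # (B, V)
# vocab ids that never occur in the source come out as -inf, i.e. zero mass after a later softmax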
Ejemplo n.º 57
0
    def step(self, Ybar_t: torch.Tensor,
            dec_state: Tuple[torch.Tensor, torch.Tensor],
            enc_hiddens: torch.Tensor,
            enc_hiddens_proj: torch.Tensor,
            enc_masks: torch.Tensor) -> Tuple[Tuple, torch.Tensor, torch.Tensor]:
        """ Compute one forward step of the LSTM decoder, including the attention computation.

        @param Ybar_t (Tensor): Concatenated Tensor of [Y_t o_prev], with shape (b, e + h). The input for the decoder,
                                where b = batch size, e = embedding size, h = hidden size.
        @param dec_state (tuple(Tensor, Tensor)): Tuple of tensors both with shape (b, h), where b = batch size, h = hidden size.
                First tensor is decoder's prev hidden state, second tensor is decoder's prev cell.
        @param enc_hiddens (Tensor): Encoder hidden states Tensor, with shape (b, src_len, h * 2), where b = batch size,
                                    src_len = maximum source length, h = hidden size.
        @param enc_hiddens_proj (Tensor): Encoder hidden states Tensor, projected from (h * 2) to h. Tensor is with shape (b, src_len, h),
                                    where b = batch size, src_len = maximum source length, h = hidden size.
        @param enc_masks (Tensor): Tensor of sentence masks shape (b, src_len),
                                    where b = batch size, src_len is maximum source length. 

        @returns dec_state (tuple (Tensor, Tensor)): Tuple of tensors both shape (b, h), where b = batch size, h = hidden size.
                First tensor is decoder's new hidden state, second tensor is decoder's new cell.
        @returns combined_output (Tensor): Combined output Tensor at timestep t, shape (b, h), where b = batch size, h = hidden size.
        @returns e_t (Tensor): Tensor of shape (b, src_len). It is attention scores distribution.
                                Note: You will not use this outside of this function.
                                      We are simply returning this value so that we can sanity check
                                      your implementation.
        """

        combined_output = None

        ### YOUR CODE HERE (~3 Lines)
        ### TODO:
        ###     1. Apply the decoder to `Ybar_t` and `dec_state`to obtain the new dec_state.
        ###     2. Split dec_state into its two parts (dec_hidden, dec_cell)
        ###     3. Compute the attention scores e_t, a Tensor shape (b, src_len). 
        ###        Note: b = batch_size, src_len = maximum source length, h = hidden size.
        ###
        ###       Hints:
        ###         - dec_hidden is shape (b, h) and corresponds to h^dec_t in the PDF (batched)
        ###         - enc_hiddens_proj is shape (b, src_len, h) and corresponds to W_{attProj} h^enc (batched).
        ###         - Use batched matrix multiplication (torch.bmm) to compute e_t.
        ###         - To get the tensors into the right shapes for bmm, you will need to do some squeezing and unsqueezing.
        ###         - When using the squeeze() function make sure to specify the dimension you want to squeeze
        ###             over. Otherwise, you will remove the batch dimension accidentally, if batch_size = 1.
        ###
        ### Use the following docs to implement this functionality:
        ###     Batch Multiplication:
        ###        https://pytorch.org/docs/stable/torch.html#torch.bmm
        ###     Tensor Unsqueeze:
        ###         https://pytorch.org/docs/stable/torch.html#torch.unsqueeze
        ###     Tensor Squeeze:
        ###         https://pytorch.org/docs/stable/torch.html#torch.squeeze
        dec_state = self.decoder(Ybar_t, dec_state)
        dec_hidden, dec_cell = dec_state
        e_t = torch.bmm(enc_hiddens_proj, dec_hidden.unsqueeze(2))
        e_t = e_t.squeeze(2)
        ### END YOUR CODE

        # Set e_t to -inf where enc_masks has 1
        if enc_masks is not None:
            e_t.data.masked_fill_(enc_masks.byte(), -float('inf'))

        ### YOUR CODE HERE (~6 Lines)
        ### TODO:
        ###     1. Apply softmax to e_t to yield alpha_t
        ###     2. Use batched matrix multiplication between alpha_t and enc_hiddens to obtain the
        ###         attention output vector, a_t.
        ###     Hints:
        ###           - alpha_t is shape (b, src_len)
        ###           - enc_hiddens is shape (b, src_len, 2h)
        ###           - a_t should be shape (b, 2h)
        ###           - You will need to do some squeezing and unsqueezing.
        ###     Note: b = batch size, src_len = maximum source length, h = hidden size.
        ###
        ###     3. Concatenate dec_hidden with a_t to compute tensor U_t
        ###     4. Apply the combined output projection layer to U_t to compute tensor V_t
        ###     5. Compute tensor O_t by first applying the Tanh function and then the dropout layer.
        ###
        ### Use the following docs to implement this functionality:
        ###     Softmax:
        ###         https://pytorch.org/docs/stable/nn.html#torch.nn.functional.softmax
        ###     Batch Multiplication:
        ###        https://pytorch.org/docs/stable/torch.html#torch.bmm
        ###     Tensor View:
        ###         https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view
        ###     Tensor Concatenation:
        ###         https://pytorch.org/docs/stable/torch.html#torch.cat
        ###     Tanh:
        ###         https://pytorch.org/docs/stable/torch.html#torch.tanh
        alpha_t = F.softmax(e_t, dim = 1).unsqueeze(1)
        a_t = torch.bmm(alpha_t, enc_hiddens).squeeze(1)
        u_t = torch.cat((dec_hidden, a_t), dim = 1)
        v_t = self.combined_output_projection(u_t)
        O_t = self.dropout(torch.tanh(v_t))
        ### END YOUR CODE

        combined_output = O_t
        return dec_state, combined_output, e_t
Ejemplo n.º 58
0
    def forward(self, x):

        # input transform
        if self.use_point_stn:
            # from tuples to list of single points
            x = x.view(x.size(0), 3, -1)
            trans = self.stn1(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans)
            x = x.transpose(2, 1)
            x = x.contiguous().view(x.size(0), 3 * self.point_tuple, -1)
        else:
            trans = None

        # if self.use_mask:
        #     mask = self.mask(x)

        #     x = torch.cat([x, mask], 1)
        # else:
        #     mask=None

        x = F.relu(self.bn0a(self.conv0a(x)))
        x = F.relu(self.bn0b(self.conv0b(x)))

        # mlp (64,64)

        # feature transform
        if self.use_feat_stn:
            trans2 = self.stn2(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans2)
            x = x.transpose(2, 1)
        else:
            trans2 = None

        # mlp (64,128,1024)
        x = F.relu(self.bn1(self.conv1(x)))
        if self.use_mask:
            pointfeat = x
            n_pts = x.size()[2]
        x = F.relu(self.bn2(self.conv2(x)))

        x = self.bn3(self.conv3(x))

        # mlp (1024,1024*num_scales)
        if self.num_scales > 1:
            x = self.bn4(self.conv4(F.relu(x)))

        if self.get_pointfvals:
            pointfvals = x
        else:
            pointfvals = None  # so the intermediate result can be forgotten if it is not needed

        # symmetric max operation over all points
        if self.num_scales == 1:
            if self.sym_op == 'max':
                x = self.mp1(x)
            elif self.sym_op == 'sum':
                x = torch.sum(x, 2, keepdim=True)
            else:
                raise ValueError('Unsupported symmetric operation: %s' %
                                 (self.sym_op))

        else:
            x_scales = x.new_empty(x.size(0), 1024 * self.num_scales**2, 1)
            if self.sym_op == 'max':
                for s in range(self.num_scales):
                    x_scales[:, s * self.num_scales * 1024:(s + 1) *
                             self.num_scales * 1024, :] = self.mp1(
                                 x[:, :, s * self.num_points:(s + 1) *
                                   self.num_points])
            elif self.sym_op == 'sum':
                for s in range(self.num_scales):
                    x_scales[:, s * self.num_scales * 1024:(s + 1) *
                             self.num_scales * 1024, :] = torch.sum(
                                 x[:, :, s * self.num_points:(s + 1) *
                                   self.num_points],
                                 2,
                                 keepdim=True)
            else:
                raise ValueError('Unsupported symmetric operation: %s' %
                                 (self.sym_op))
            x = x_scales
        #x= x.transpose(2,1)
        x = x.contiguous().view(-1, 1024 * self.num_scales**2)
        if self.use_mask:
            x2 = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            x2 = torch.cat([x2, pointfeat], 1)
            x2 = F.relu(self.bna(self.conva(x2)))
            #x2 = F.relu(self.bnb(self.convb(x2)))
            x2 = F.relu(self.bnc(self.convc(x2)))
            mask = self.convd(x2)
        else:
            mask = None
        return x, trans, trans2, mask
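The two torch.bmm calls above apply the learned spatial and feature transforms predicted by the STNs to every point (or feature) in the batch. A minimal sketch of the input-transform step alone, using an identity matrix as a stand-in for the network's stn1 output:

import torch

B, n_points = 4, 1024
x = torch.randn(B, 3, n_points)                    # (batch, xyz, points)
trans = torch.eye(3).unsqueeze(0).repeat(B, 1, 1)  # stand-in for self.stn1(x)

x = x.transpose(2, 1)       # (B, n_points, 3)
x = torch.bmm(x, trans)     # apply the 3x3 transform to every point
x = x.transpose(2, 1)       # back to (B, 3, n_points)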
Ejemplo n.º 59
0
 def transpose(self, x, A1):
     Amatrix = torch.tril(A1).transpose(-2,-1)
     y = torch.bmm(Amatrix, x)
     return y
Ejemplo n.º 60
0
 def read_vectors(self, memory, read_weights):
     return torch.bmm(read_weights, memory)
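A small usage sketch for the read above, with DNC-style shapes assumed purely for illustration (R read heads over N memory cells of width W):

import torch

B, R, N, W = 2, 4, 16, 8
memory = torch.randn(B, N, W)                               # (batch, cells, cell width)
read_weights = torch.softmax(torch.randn(B, R, N), dim=-1)  # one weighting per read head
read_vectors = torch.bmm(read_weights, memory)              # (B, R, W): one vector per head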