Example #1
    def forward(self, x, length):
        # x: padded batch of shape (batch, max_len, features); length: per-sample lengths.
        batch_size = x.size(0)
        maxlength = x.size(1)
        hidden = [self.init_hidden(batch_size) for _ in range(self.bi_num)]
        weight = self.init_attention_weight(batch_size, maxlength)
        out_final = self.init_final_out(batch_size)

        # Run the RNN twice, once on the original order and once on the
        # length-aware reversed sequences, to emulate a bidirectional RNN.
        out = [x, reverse_padded_sequence(x, length, batch_first=True)]
        for l in range(self.bi_num):
            out[l] = pack_padded_sequence(out[l], length, batch_first=True)
            out[l], hidden[l] = self.layer1[l](out[l], hidden[l])
            out[l], _ = pad_packed_sequence(out[l], batch_first=True)
            if l == 1:
                # Undo the reversal so both directions are aligned in time.
                out[l] = reverse_padded_sequence(out[l], length, batch_first=True)

        out = out[0] if self.bi_num == 1 else torch.cat(out, 2)

        # Attention weights: softmax over the valid (unpadded) steps of each sample.
        potential = self.simple_attention(out)
        for inx, l in enumerate(length):
            weight[inx, 0:l] = torch.squeeze(F.softmax(potential[inx, 0:l], dim=0), 1)

        # Attention-weighted sum of the RNN outputs, followed by the output layer.
        out_final = torch.squeeze(torch.bmm(torch.unsqueeze(weight, 1), out), 1)
        out_final = self.layer2(out_final)
        return out_final, length
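
All of these examples rely on a reverse_padded_sequence helper that is not part of torch.nn.utils.rnn; from the way it is called, it appears to reverse each padded sequence along the time axis up to that sequence's own length, so a second RNN can read the input right to left. Below is a minimal sketch of such a helper, written from that assumption with a hypothetical name; it is not the original implementation.

import torch

def reverse_padded_sequence_sketch(inputs, lengths, batch_first=True):
    # inputs: (batch, max_len, feat) when batch_first=True; lengths: iterable of ints.
    if not batch_first:
        inputs = inputs.transpose(0, 1)
    reversed_rows = []
    for row, length in zip(inputs, lengths):
        length = int(length)
        idx = torch.arange(length - 1, -1, -1, device=row.device)
        rev = row.clone()
        rev[:length] = row.index_select(0, idx)   # flip only the valid steps
        reversed_rows.append(rev)
    out = torch.stack(reversed_rows, dim=0)
    return out if batch_first else out.transpose(0, 1)
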
Example #2
    def forward(self, x, length, r):
        # Multi-hop self-attention: r attention distributions per sample.
        batch_size = x.size(0)
        maxlength = x.size(1)
        hidden = [self.init_hidden(batch_size) for _ in range(self.bi_num)]
        weight = self.init_attention_weight(batch_size, maxlength, r)
        # Old-style Variable/.cuda() handling kept as in the original snippet.
        oneMat = Variable(torch.ones(batch_size, r, r)).cuda()

        out = [x, reverse_padded_sequence(x, length, batch_first=True)]
        for l in range(self.bi_num):
            out[l] = pack_padded_sequence(out[l], length, batch_first=True)
            out[l], hidden[l] = self.layer1[l](out[l], hidden[l])
            out[l], _ = pad_packed_sequence(out[l], batch_first=True)
            if l == 1:
                out[l] = reverse_padded_sequence(out[l], length, batch_first=True)

        out = out[0] if self.bi_num == 1 else torch.cat(out, 2)

        # One softmax per attention hop, taken over the valid steps of each sample.
        potential = self.simple_attention(out)
        for inx, l in enumerate(length):
            weight[inx, 0:l, :] = F.softmax(potential[inx, 0:l, :], dim=0)
        weight = torch.transpose(weight, 1, 2)  # (batch, r, max_len)

        # r weighted sums of the RNN outputs, flattened and passed to the output layer.
        out_final = torch.bmm(weight, out)
        out_final = out_final.view(batch_size, -1)
        out_final = self.layer2(out_final)

        # Redundancy penalty: squared difference between the attention Gram matrix
        # A*A^T and oneMat, summed over all elements and over the batch.
        penalty = torch.sum(
            torch.pow(torch.bmm(weight, torch.transpose(weight, 1, 2)) - oneMat, 2.0))
        return out_final, length, penalty
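
The penalty above is the squared Frobenius norm of the attention Gram matrix A·Aᵀ minus an all-ones matrix, summed over the batch. The structured self-attention formulation of Lin et al. (2017) subtracts the identity matrix instead; whether the ones matrix here is intentional is not stated in the snippet. A compact standalone sketch of the same computation, with hypothetical sizes:

import torch

# Hypothetical sizes: batch of 4, r = 3 attention hops, max_len = 10.
A = torch.softmax(torch.randn(4, 3, 10), dim=2)   # each hop sums to 1 over time
ones = torch.ones(4, 3, 3)

gram = torch.bmm(A, A.transpose(1, 2))            # (batch, r, r)
penalty = ((gram - ones) ** 2).sum()              # scalar, same value as the nested sums
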
Example #3
    def forward(self, x1, x2, length):
        # x1: raw per-step input fed through 1-D convolutions; x2: extra per-step features.
        batch_size = x1.size(0)
        maxlength = int(np.max(length))
        hidden = [self.init_hidden(batch_size) for _ in range(self.bi_num)]
        # weight = self.init_attention_weight(batch_size, maxlength)
        # out_final = self.init_final_out(batch_size)

        # Fold batch and time together so the convolutions run over each step
        # independently, then restore (batch, time, features) and append x2.
        x1 = x1.view(-1, self.input_channels, self.input_dim1)
        out = self.cov1(x1)
        out = self.cov2(out)
        out = out.view(batch_size, maxlength, self.dimAfterCov)
        out = torch.cat((out, x2), dim=2)

        out = [out, reverse_padded_sequence(out, length, batch_first=True)]
        for l in range(self.bi_num):
            out[l] = pack_padded_sequence(out[l], length, batch_first=True)
            out[l], hidden[l] = self.layer1[l](out[l], hidden[l])
            out[l], _ = pad_packed_sequence(out[l], batch_first=True)
            if l == 1:
                out[l] = reverse_padded_sequence(out[l], length, batch_first=True)
        # potential = self.simple_attention(out)
        # for inx, l in enumerate(length): weight[inx, 0:l] = F.softmax(potential[inx, 0:l], dim=0)
        # for inx, l in enumerate(length): out_final[inx, :] = torch.matmul(weight[inx, :], torch.squeeze(out[inx, :, :]))

        out = out[0] if self.bi_num == 1 else torch.cat(out, 2)
        # Per-step outputs: layer2 is applied to every time step, then squeezed.
        out = self.layer2(out)
        out = torch.squeeze(out)
        return out, length
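
The convolutional front end above folds the batch and time dimensions together so the 1-D convolutions process every time step independently, then restores the (batch, time, features) layout and concatenates the extra features x2. A small sketch of that reshape pattern; the layer definitions (cov1/cov2) and all dimensions are hypothetical, since they are not given in the snippet.

import torch
import torch.nn as nn

# Hypothetical sizes: batch 4, 10 time steps, 1 input channel, 32 samples per step.
batch_size, max_len, in_channels, dim1 = 4, 10, 1, 32
x1 = torch.randn(batch_size, max_len, in_channels * dim1)

conv = nn.Sequential(nn.Conv1d(in_channels, 8, kernel_size=3, padding=1), nn.ReLU())

out = conv(x1.view(-1, in_channels, dim1))        # fold batch and time: (B*T, C, dim1)
out = out.view(batch_size, max_len, -1)           # unfold back to (B, T, dim_after_conv)
x2 = torch.randn(batch_size, max_len, 5)          # hypothetical extra per-step features
out = torch.cat((out, x2), dim=2)                 # concatenate along the feature axis
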
Example #4
    def forward(self, x, length):
        batch_size = x.size(0)
        hidden = [self.init_hidden(batch_size) for _ in range(self.bi_num)]

        out = [x, reverse_padded_sequence(x, length, batch_first=True)]
        for l in range(self.bi_num):
            out[l] = pack_padded_sequence(out[l], length, batch_first=True)
            out[l], hidden[l] = self.layer1[l](out[l], hidden[l])
            out[l], _ = pad_packed_sequence(out[l], batch_first=True)
            if l == 1:
                out[l] = reverse_padded_sequence(out[l], length, batch_first=True)

        out = out[0] if self.bi_num == 1 else torch.cat(out, 2)
        out = self.layer2(out)
        out = torch.squeeze(out)
        return out, length
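
All four forward methods pass the batch straight to pack_padded_sequence, so the sequences are expected to arrive sorted by decreasing length (older PyTorch releases have no enforce_sorted option). A hedged usage sketch with hypothetical shapes; the model construction is not shown in these snippets, so the call is left commented out.

import torch

# Hypothetical batch: 4 padded sequences, feature size 8, max length 10.
x = torch.randn(4, 10, 8)
lengths = torch.tensor([5, 10, 3, 7])

order = torch.argsort(lengths, descending=True)   # sort by decreasing length
x, lengths = x[order], lengths[order]

# model = ...                                     # hypothetical instance of one of the classes above
# out, lengths = model(x, lengths.tolist())
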