Example #1
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.autograd import Variable


class Attention(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 num_classes,
                 num_embeddings=128,
                 CUDA=True):
        super(Attention, self).__init__()
        self.attention_cell = AttentionCell(input_size,
                                            hidden_size,
                                            num_embeddings,
                                            CUDA=CUDA)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.generator = nn.Linear(hidden_size, num_classes)
        self.char_embeddings = Parameter(
            torch.randn(num_classes + 1, num_embeddings))  # row 0 is the start token
        self.num_embeddings = num_embeddings
        self.num_classes = num_classes
        self.use_cuda = CUDA  # avoid shadowing nn.Module.cuda()

    # feats: nT x nB x nC encoder features; text: flat, concatenated label
    # indices; text_length: per-sample label lengths (nB entries).
    # targets built below is (num_steps + 1) x nB.
    def forward(self, feats, text_length, text, test=False):

        nT = feats.size(0)
        nB = feats.size(1)
        nC = feats.size(2)
        hidden_size = self.hidden_size
        input_size = self.input_size
        assert (input_size == nC)
        assert (nB == text_length.numel())

        num_steps = text_length.data.max()
        num_labels = text_length.data.sum()

        if not test:

            targets = torch.zeros(nB, num_steps + 1).long()
            if self.use_cuda:
                targets = targets.cuda()
            start_id = 0

            for i in range(nB):
                cur_length = text_length.data[i]
                # labels are shifted by 1 so that index 0 is the start token
                targets[i][1:1 + cur_length] = \
                    text.data[start_id:start_id + cur_length] + 1
                start_id = start_id + cur_length
            targets = Variable(targets.transpose(0, 1).contiguous())

            output_hiddens = Variable(
                torch.zeros(num_steps, nB, hidden_size).type_as(feats.data))
            hidden = Variable(torch.zeros(nB, hidden_size).type_as(feats.data))

            for i in range(num_steps):
                cur_embeddings = self.char_embeddings.index_select(
                    0, targets[i])
                hidden, alpha = self.attention_cell(hidden, feats,
                                                    cur_embeddings, test)
                output_hiddens[i] = hidden

            new_hiddens = Variable(
                torch.zeros(num_labels, hidden_size).type_as(feats.data))
            b = 0
            start = 0

            for length in text_length.data:
                new_hiddens[start:start + length] = output_hiddens[0:length,
                                                                   b, :]
                start = start + length
                b = b + 1

            probs = self.generator(new_hiddens)
            return probs

        else:

            hidden = Variable(torch.zeros(nB, hidden_size).type_as(feats.data))
            targets_temp = Variable(torch.zeros(nB).long().contiguous())
            probs = Variable(torch.zeros(nB * num_steps, self.num_classes))
            if self.use_cuda:
                targets_temp = targets_temp.cuda()
                probs = probs.cuda()

            for i in range(num_steps):
                cur_embeddings = self.char_embeddings.index_select(
                    0, targets_temp)
                hidden, alpha = self.attention_cell(hidden, feats,
                                                    cur_embeddings, test)
                hidden2class = self.generator(hidden)
                probs[i * nB:(i + 1) * nB] = hidden2class
                _, targets_temp = hidden2class.max(1)
                # shift by 1: row 0 of char_embeddings is the start token
                targets_temp += 1

            probs = probs.view(num_steps, nB,
                               self.num_classes).permute(1, 0, 2).contiguous()
            probs = probs.view(-1, self.num_classes).contiguous()
            probs_res = Variable(
                torch.zeros(num_labels, self.num_classes).type_as(feats.data))
            b = 0
            start = 0

            for length in text_length.data:
                probs_res[start:start +
                          length] = probs[b * num_steps:b * num_steps + length]
                start = start + length
                b = b + 1

            return probs_res
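
The AttentionCell used above is defined elsewhere in the repository. Below is a minimal CPU smoke test, assuming a hypothetical stub with the (prev_hidden, feats, cur_embeddings, test) call signature the decoding loop expects; a real cell would compute content-based attention weights rather than the uniform pooling used here:

class AttentionCell(nn.Module):
    # hypothetical stub for shape-checking only, not the original cell
    def __init__(self, input_size, hidden_size, num_embeddings, CUDA=True):
        super(AttentionCell, self).__init__()
        self.rnn = nn.GRUCell(input_size + num_embeddings, hidden_size)

    def forward(self, prev_hidden, feats, cur_embeddings, test):
        context = feats.mean(0)  # uniform pooling over the nT steps: [nB, nC]
        alpha = torch.ones(feats.size(1), feats.size(0)) / feats.size(0)
        return self.rnn(torch.cat((context, cur_embeddings), 1), prev_hidden), alpha

nT, nB, nC, nclass = 26, 4, 512, 37
model = Attention(input_size=nC, hidden_size=256, num_classes=nclass, CUDA=False)
feats = Variable(torch.randn(nT, nB, nC))
text_length = Variable(torch.LongTensor([5, 3, 7, 2]))
text = Variable(torch.LongTensor(int(text_length.data.sum())).random_(0, nclass))
probs = model(feats, text_length, text)  # [sum(text_length), num_classes]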
Example #2
import torch
import torch.nn as nn
from torch.nn import Parameter


class GCNLayer(nn.Module):
    """ Graph convolutional neural network encoder.

    """
    def __init__(self,
                 num_inputs,
                 num_units,
                 num_labels,
                 in_arcs=True,
                 out_arcs=True,
                 batch_first=False,
                 use_gates=True,
                 use_glus=False):
        super(GCNLayer, self).__init__()

        self.in_arcs = in_arcs
        self.out_arcs = out_arcs

        self.num_inputs = num_inputs
        self.num_units = num_units
        self.num_labels = num_labels
        self.batch_first = batch_first

        self.glu = nn.GLU(3)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.use_gates = use_gates
        self.use_glus = use_glus
        # edge gating as in gated graph networks:
        # https://www.cs.toronto.edu/~yujiali/files/talks/iclr16_ggnn_talk.pdf
        # gated linear units (GLU): https://arxiv.org/pdf/1612.08083.pdf

        if in_arcs:
            self.V_in = Parameter(torch.Tensor(self.num_inputs,
                                               self.num_units))
            nn.init.xavier_normal_(self.V_in)

            self.b_in = Parameter(torch.Tensor(num_labels, self.num_units))
            nn.init.constant_(self.b_in, 0)

            if self.use_gates:
                self.V_in_gate = Parameter(torch.Tensor(self.num_inputs, 1))
                nn.init.xavier_normal_(self.V_in_gate)
                self.b_in_gate = Parameter(torch.Tensor(num_labels, 1))
                nn.init.constant_(self.b_in_gate, 1)

        if out_arcs:
            self.V_out = Parameter(
                torch.Tensor(self.num_inputs, self.num_units))
            nn.init.xavier_normal_(self.V_out)

            self.b_out = Parameter(torch.Tensor(num_labels, self.num_units))
            nn.init.constant_(self.b_out, 0)

            if self.use_gates:
                self.V_out_gate = Parameter(torch.Tensor(self.num_inputs, 1))
                nn.init.xavier_normal_(self.V_out_gate)
                self.b_out_gate = Parameter(torch.Tensor(num_labels, 1))
                nn.init.constant_(self.b_out_gate, 1)

        self.W_self_loop = Parameter(
            torch.Tensor(self.num_inputs, self.num_units))

        nn.init.xavier_normal_(self.W_self_loop)

        if self.use_gates:
            self.W_self_loop_gate = Parameter(torch.Tensor(self.num_inputs, 1))
            nn.init.xavier_normal_(self.W_self_loop_gate)

    def forward(
            self,
            src,
            lengths=None,
            arc_tensor_in=None,
            arc_tensor_out=None,
            label_tensor_in=None,
            label_tensor_out=None,
            mask_in=None,  # [batch * t, degree]
            mask_out=None,  # [batch * t, degree]
            mask_loop=None,  # [batch * t, 1]
            sent_mask=None):  # [t, batch]

        if not self.batch_first:
            encoder_outputs = src.permute(1, 0, 2).contiguous()
        else:
            encoder_outputs = src.contiguous()
        batch_size = encoder_outputs.size()[0]
        seq_len = encoder_outputs.size()[1]
        max_degree = 1
        input_ = encoder_outputs.view(
            (batch_size * seq_len, self.num_inputs))  # [b* t, h]

        if self.in_arcs:
            input_in = torch.mm(input_,
                                self.V_in)  # [b* t, h] * [h,h] = [b*t, h]
            first_in = input_in.index_select(
                0, arc_tensor_in[0] * seq_len +
                arc_tensor_in[1])  # [b* t* degr, h]
            second_in = self.b_in.index_select(
                0, label_tensor_in[0])  # [b* t* degr, h]
            in_ = first_in + second_in
            degr = int(first_in.size()[0] / batch_size / seq_len)

            in_ = in_.view((batch_size, seq_len, degr, self.num_units))

            if self.use_glus:
                # gate the information of each neighbour, self nodes are in here too.
                in_ = torch.cat((in_, in_), 3)
                in_ = self.glu(in_)

            if self.use_gates:
                # compute gate weights
                input_in_gate = torch.mm(
                    input_, self.V_in_gate)  # [b* t, h] * [h,h] = [b*t, h]
                first_in_gate = input_in_gate.index_select(
                    0, arc_tensor_in[0] * seq_len +
                    arc_tensor_in[1])  # [b* t* mxdeg, h]
                second_in_gate = self.b_in_gate.index_select(
                    0, label_tensor_in[0])
                in_gate = (first_in_gate + second_in_gate).view(
                    (batch_size, seq_len, degr))

            max_degree += degr

        if self.out_arcs:
            input_out = torch.mm(input_,
                                 self.V_out)  # [b* t, h] * [h,h] = [b* t, h]
            first_out = input_out.index_select(
                0, arc_tensor_out[0] * seq_len +
                arc_tensor_out[1])  # [b* t* mxdeg, h]
            second_out = self.b_out.index_select(0, label_tensor_out[0])

            degr = int(first_out.size()[0] / batch_size / seq_len)
            max_degree += degr

            out_ = (first_out + second_out).view(
                (batch_size, seq_len, degr, self.num_units))

            if self.use_glus:
                # gate the information of each neighbour, self nodes are in here too.
                out_ = torch.cat((out_, out_), 3)
                out_ = self.glu(out_)

            if self.use_gates:
                # compute gate weights
                input_out_gate = torch.mm(
                    input_, self.V_out_gate)  # [b* t, h] * [h,h] = [b* t, h]
                first_out_gate = input_out_gate.index_select(
                    0, arc_tensor_out[0] * seq_len +
                    arc_tensor_out[1])  # [b* t* mxdeg, h]
                second_out_gate = self.b_out_gate.index_select(
                    0, label_tensor_out[0])
                out_gate = (first_out_gate + second_out_gate).view(
                    (batch_size, seq_len, degr))


        same_input = torch.mm(
            encoder_outputs.view(-1, encoder_outputs.size(2)),
            self.W_self_loop).view(batch_size, seq_len, 1,
                                   self.W_self_loop.size(1))
        if self.use_gates:
            same_input_gate = torch.mm(
                encoder_outputs.view(-1, encoder_outputs.size(2)),
                self.W_self_loop_gate).view(batch_size, seq_len, -1)

        if self.in_arcs and self.out_arcs:
            potentials = torch.cat((in_, out_, same_input),
                                   dim=2)  # [b, t,  mxdeg, h]
            if self.use_gates:
                potentials_gate = torch.cat(
                    (in_gate, out_gate, same_input_gate),
                    dim=2)  # [b, t,  mxdeg, h]
            mask_soft = torch.cat((mask_in, mask_out, mask_loop),
                                  dim=1)  # [b* t, mxdeg]
        elif self.out_arcs:
            potentials = torch.cat((out_, same_input),
                                   dim=2)  # [b, t, mxdeg, h]
            if self.use_gates:
                potentials_gate = torch.cat((out_gate, same_input_gate),
                                            dim=2)  # [b, t,  mxdeg, h]
            mask_soft = torch.cat((mask_out, mask_loop),
                                  dim=1)  # [b* t, mxdeg]
        elif self.in_arcs:
            potentials = torch.cat((in_, same_input),
                                   dim=2)  # [b, t, mxdeg, h]
            if self.use_gates:
                potentials_gate = torch.cat((in_gate, same_input_gate),
                                            dim=2)  # [b, t,  mxdeg, h]
            mask_soft = torch.cat((mask_in, mask_loop), dim=1)  # [b* t, mxdeg]
        else:
            potentials = same_input  # [b, t,  2*mxdeg+1, h]
            if self.use_gates:
                potentials_gate = same_input_gate  # [b, t,  mxdeg, h]
            mask_soft = mask_loop  # [b* t, mxdeg]

        potentials_resh = potentials.view(
            (batch_size * seq_len, max_degree,
             self.num_units))  # [b * t, mxdeg, h]

        if self.use_gates:
            potentials_r = potentials_gate.view(
                (batch_size * seq_len, max_degree))  # [b * t, mxdeg]

            probs_det_ = (self.sigmoid(potentials_r) * mask_soft).unsqueeze(
                2)  # [b * t, mxdeg, 1]
            potentials_masked = potentials_resh * probs_det_  # [b * t, mxdeg, h]
        else:
            # NO Gates
            potentials_masked = potentials_resh * mask_soft.unsqueeze(2)

        potentials_masked_ = potentials_masked.sum(dim=1)  # [b * t, h]
        potentials_masked_ = self.relu(potentials_masked_)  # [b * t, h]

        result_ = potentials_masked_.view(
            (batch_size, seq_len, self.num_units))  # [ b, t, h]

        result_ = result_ * sent_mask.permute(1, 0).contiguous().unsqueeze(
            2)  # [b, t, h]

        memory_bank = result_.permute(1, 0, 2).contiguous()  # [t, b, h]

        return memory_bank
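
A minimal shape check for GCNLayer, assuming the simplest possible graph: every token has exactly one incoming and one outgoing arc (a self-arc with label 0). The arc tensors hold (batch index, token index) pairs that forward flattens into row indices; all names and sizes below are illustrative:

batch_size, seq_len, hidden = 2, 3, 16
layer = GCNLayer(num_inputs=hidden, num_units=hidden, num_labels=4)

n = batch_size * seq_len
batch_idx = torch.arange(batch_size).view(-1, 1).repeat(1, seq_len).view(-1)
tok_idx = torch.arange(seq_len).repeat(batch_size)
arc = torch.stack((batch_idx, tok_idx)).long()  # [2, b * t], degree 1
labels = torch.zeros(1, n).long()               # every arc carries label 0
ones = torch.ones(n, 1)                         # masks: [b * t, degree]

src = torch.randn(seq_len, batch_size, hidden)  # batch_first=False layout
memory_bank = layer(src,
                    arc_tensor_in=arc, arc_tensor_out=arc,
                    label_tensor_in=labels, label_tensor_out=labels,
                    mask_in=ones, mask_out=ones, mask_loop=ones,
                    sent_mask=torch.ones(seq_len, batch_size))
print(memory_bank.size())  # torch.Size([3, 2, 16])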
Example #3
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class DTD(nn.Module):
    # DTD: bidirectional LSTM over the attended features, then GRU-cell decoding
    def __init__(self, nclass, nchannel, dropout=0.3):
        super(DTD, self).__init__()
        self.nclass = nclass
        self.nchannel = nchannel
        self.dropout = dropout
        self.pre_lstm = nn.LSTM(nchannel,
                                int(nchannel / 2),
                                bidirectional=True)
        self.rnn = nn.GRUCell(nchannel * 2, nchannel)
        self.generator = nn.Sequential(nn.Dropout(p=dropout),
                                       nn.Linear(nchannel, nclass))
        self.char_embeddings = Parameter(torch.randn(nclass, nchannel))

    def forward(self, feature, A, text, text_length, test=False):
        # feature: [nB, nC, nH, nW] CNN features; A: [nB, nT, nH, nW] attention maps
        nB, nC, nH, nW = feature.size()
        nT = A.size()[1]
        # Normalize
        A = A / A.view(nB, nT, -1).sum(2).view(nB, nT, 1, 1)
        # weighted sum
        C = feature.view(nB, 1, nC, nH, nW) * A.view(nB, nT, 1, nH, nW)
        C = C.view(nB, nT, nC, -1).sum(3).transpose(1, 0)
        C, _ = self.pre_lstm(C)
        C = F.dropout(C, p=self.dropout, training=self.training)
        if not test:
            lenText = int(text_length.sum())
            nsteps = int(text_length.max())

            gru_res = torch.zeros(C.size()).type_as(C.data)
            out_res = torch.zeros(lenText, self.nclass).type_as(feature.data)
            out_attns = torch.zeros(lenText, nH, nW).type_as(A.data)

            hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
            prev_emb = self.char_embeddings.index_select(
                0,
                torch.zeros(nB).long().type_as(text.data))
            for i in range(0, nsteps):
                hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
                                  hidden)
                gru_res[i, :, :] = hidden
                prev_emb = self.char_embeddings.index_select(0, text[:, i])
            gru_res = self.generator(gru_res)

            start = 0
            for i in range(0, nB):
                cur_length = int(text_length[i])
                out_res[start:start + cur_length] = gru_res[0:cur_length, i, :]
                out_attns[start:start + cur_length] = A[i, 0:cur_length, :, :]
                start += cur_length

            return out_res, out_attns

        else:
            lenText = nT
            nsteps = nT
            out_res = torch.zeros(lenText, nB,
                                  self.nclass).type_as(feature.data)

            hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
            prev_emb = self.char_embeddings.index_select(
                0,
                torch.zeros(nB).long().type_as(text.data))
            out_length = torch.zeros(nB)
            now_step = 0
            while 0 in out_length and now_step < nsteps:
                hidden = self.rnn(
                    torch.cat((C[now_step, :, :], prev_emb), dim=1), hidden)
                tmp_result = self.generator(hidden)
                out_res[now_step] = tmp_result
                # greedy argmax; squeeze(1) keeps the batch dim when nB == 1
                tmp_result = tmp_result.topk(1)[1].squeeze(1)
                for j in range(nB):
                    # class 0 acts as the end-of-sequence symbol
                    if out_length[j] == 0 and tmp_result[j] == 0:
                        out_length[j] = now_step + 1
                prev_emb = self.char_embeddings.index_select(0, tmp_result)
                now_step += 1
            for j in range(0, nB):
                if int(out_length[j]) == 0:
                    out_length[j] = nsteps

            start = 0
            output = torch.zeros(int(out_length.sum()),
                                 self.nclass).type_as(feature.data)
            for i in range(0, nB):
                cur_length = int(out_length[i])
                output[start:start + cur_length] = out_res[0:cur_length, i, :]
                start += cur_length

            return output, out_length
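
A minimal smoke test for DTD with random features and attention maps; the sizes are illustrative, and class 0 doubles as the end-of-sequence symbol in the greedy test branch:

nB, nC, nH, nW, nT, nclass = 2, 32, 8, 25, 10, 37
model = DTD(nclass=nclass, nchannel=nC)
feature = torch.randn(nB, nC, nH, nW)
A = torch.rand(nB, nT, nH, nW)             # one attention map per decoding step
text = torch.randint(1, nclass, (nB, nT))  # ground-truth labels; 0 reserved for EOS
text_length = torch.LongTensor([5, 7])

out_res, out_attns = model(feature, A, text, text_length)  # training path
print(out_res.size())  # [sum(text_length), nclass] = [12, 37]
output, out_length = model(feature, A, text, text_length, test=True)  # greedy decoding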