Example #1
    def forward(self, data: BatchHolder):
        if self.use_attention:
            output = data.hidden  #(B, L, H) encoder states
            mask = data.masks  #(B, L)
            attn = self.attention(data.seq, output, mask)  #(B, L)

            if self.use_regulariser_attention:
                data.reg_loss = 5 * self.regularizer_attention.regularise(
                    data.seq, output, mask, attn)

            if isTrue(data, 'detach'):
                attn = attn.detach()

            if isTrue(data, 'permute'):
                permutation = data.generate_permutation()
                attn = torch.gather(attn, -1,
                                    torch.LongTensor(permutation).to(device))

            context = (attn.unsqueeze(-1) * output).sum(1)  # attention-weighted sum over L -> (B, H)
            data.attn = attn
        else:
            context = data.last_hidden

        predict = self.decode(context)
        data.predict = predict
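
The methods in this listing gate optional behaviour through an isTrue(data, flag) helper that is not shown here. A minimal sketch of such a helper, assuming the flags are plain attributes set on the batch object (the name and exact semantics are an assumption, not taken from this listing):

    def isTrue(obj, flag):
        # hypothetical helper: treat a missing flag as False instead of raising
        return bool(getattr(obj, flag, False))
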
Example #2
    def forward(self, data):
        #input_seq = (B, L), hidden : (B, L, H), masks : (B, L)
        input_seq, hidden, masks = data.seq, data.hidden, data.masks
        lengths = data.lengths

        attn1 = nn.Tanh()(self.attn1(hidden))
        attn2 = self.attn2(attn1).squeeze(-1)
        data.attn_logit = attn2
        attn = masked_softmax(attn2, masks)

        inf = 1e9
        # evaluation-time ablations: push selected attention logits to -inf and recompute the softmax
        if isTrue(data, 'erase_max'):
            attn2[:, attn.max(dim=1)[1]] = -1 * inf
            attn = masked_softmax(attn2, masks)

        if isTrue(data, 'erase_random'):
            rand_len = (torch.rand(size=lengths.size()).to(device) *
                        (lengths).float()).long()
            attn2[:, rand_len] = -1 * inf
            attn = masked_softmax(attn2, masks)

        if isTrue(data, 'erase_given'):
            attn2[:, data.erase_attn] = -1 * inf
            attn = masked_softmax(attn2, masks)

        return attn
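
masked_softmax is used throughout this listing but never defined here. A minimal sketch, assuming masks holds 1 at padded positions and 0 at real tokens (the convention suggested by the masking in Example #9):

    import torch
    import torch.nn as nn

    def masked_softmax(logits, masks):
        # hypothetical stand-in: push padded positions to -inf so they receive ~0 probability
        logits = logits.masked_fill(masks.bool(), -1e9)
        return nn.Softmax(dim=-1)(logits)
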
Example #3
    def forward(self, data, is_embds=False):
        seq = data.seq
        lengths = data.lengths

        if len(seq.shape) == 2:  # token ids: look up the embedding table
            embedding = self.embedding(seq)
        else:  # already embedded: skip the lookup
            embedding = data.seq
            embedding.requires_grad = True

        #print("JPRD I EMBEDDED", embedding.shape)
        batch_size = embedding.size()[0]

        # zero-initialised LSTM states: two per layer
        self.init_states = [[
            LSTMState(cx=torch.zeros(batch_size, self.hidden_size),
                      hx=torch.zeros(batch_size, self.hidden_size))
            for _ in range(2)
        ] for _ in range(self.num_layers)]

        output, cell_output, output_state = self.rnn(embedding, lengths)
        h, c = output_state[self.num_layers - 1]

        data.hidden = output
        data.cell_state = cell_output
        data.last_hidden = h
        data.embedding = embedding

        if isTrue(data, 'keep_grads'):
            data.embedding.retain_grad()
            data.hidden.retain_grad()
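
LSTMState is referenced above but not defined in this listing. In custom-LSTM code it is commonly a plain named tuple holding the hidden and cell tensors; a sketch under that assumption:

    from collections import namedtuple

    # assumed container for one LSTM state: hidden (hx) and cell (cx) tensors
    LSTMState = namedtuple('LSTMState', ['hx', 'cx'])
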
Example #4
    def forward(self, data):
        #input_seq = (B, L), hidden : (B, L, H), masks : (B, L)
        input_seq, hidden, masks = data.seq, data.hidden, data.masks
        lengths = data.lengths

        b, l, h = hidden.shape

        # project and split per head: (B, L, heads*H) -> heads x (B, L, H)
        proj_hiddens = torch.unbind(
            self.projection(hidden).view(b, l, self.heads, h), dim=2)
        data.proj_hiddens = proj_hiddens

        attns = []

        erase_head = getattr(data, 'erase_head', -1)

        for idx, hidden in enumerate(proj_hiddens):

            hidden = hidden.to(device)
            attn1 = nn.Tanh()(self.attn1[idx](hidden))
            attn2 = self.attn2[idx](attn1).squeeze(-1)
            data.attn_logit = attn2
            attn = masked_softmax(attn2, masks)

            inf = 1e9
            if isTrue(data, 'erase_max') and (erase_head == idx):
                attn2[:, attn.max(dim=1)[1]] = -1 * inf
                attn = masked_softmax(attn2, masks)

            if isTrue(data, 'erase_random') and (erase_head == idx):
                rand_len = (torch.rand(size=lengths.size()).to(device) *
                            (lengths).float()).long()
                attn2[:, rand_len] = -1 * inf
                attn = masked_softmax(attn2, masks)

            if isTrue(data, 'erase_given') and (erase_head == idx):
                attn2[:, data.erase_attn] = -1 * inf
                attn = masked_softmax(attn2, masks)

            attns.append(attn)
        return attns
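
The per-head split in Example #4 hinges on reshaping the projected states to (B, L, heads, H) and unbinding along the head dimension. A standalone shape check with dummy sizes (nothing here is taken from the real model):

    import torch

    B, L, heads, H = 2, 5, 3, 8
    projected = torch.randn(B, L, heads * H)   # e.g. the output of a Linear layer with heads * H features
    per_head = torch.unbind(projected.view(B, L, heads, H), dim=2)
    print(len(per_head), per_head[0].shape)    # 3 torch.Size([2, 5, 8])
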
Example #5
    def forward(self, data: BatchMultiHolder):
        Qoutput = data.Q.last_hidden  #(B, H), needed by decode() in both branches
        if self.use_attention:
            Poutput = data.P.hidden  #(B, L, H)
            mask = data.P.masks

            attn = self.attention(data.P.seq, Poutput, Qoutput, mask)  #(B, L)

            if isTrue(data, 'detach'):
                attn = attn.detach()

            if isTrue(data, 'permute'):
                permutation = data.P.generate_permutation()
                attn = torch.gather(attn, -1,
                                    torch.LongTensor(permutation).to(device))

            context = (attn.unsqueeze(-1) * Poutput).sum(1)  #(B, H)
            data.attn = attn
        else:
            context = data.P.last_hidden

        predict = self.decode(context, Qoutput, data.entity_mask)
        data.predict = predict
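
Examples #1 and #5 both collapse the encoder states into a single context vector via (attn.unsqueeze(-1) * output).sum(1). A quick shape check with dummy tensors:

    import torch

    B, L, H = 2, 7, 16
    output = torch.randn(B, L, H)                     # encoder states
    attn = torch.softmax(torch.randn(B, L), dim=-1)   # attention weights over tokens
    context = (attn.unsqueeze(-1) * output).sum(1)    # weighted sum over L
    print(context.shape)                              # torch.Size([2, 16])
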
Example #6
    def forward(self, data) :
        seq = data.seq
        lengths = data.lengths
        embedding = self.embedding(seq) #(B, L, E)
        packseq = nn.utils.rnn.pack_padded_sequence(embedding, lengths, batch_first=True, enforce_sorted=False)
        output, (h, c) = self.rnn(packseq)
        output, lengths = nn.utils.rnn.pad_packed_sequence(output, batch_first=True, padding_value=0)

        data.hidden = output
        data.last_hidden = torch.cat([h[0], h[1]], dim=-1)

        if isTrue(data, 'keep_grads') :
            data.embedding = embedding
            data.embedding.retain_grad()
            data.hidden.retain_grad()
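
Example #6 packs variable-length sequences, runs an RNN, and concatenates the two slices of the final hidden state. A self-contained sketch of the same pattern, assuming a one-layer bidirectional LSTM (the actual RNN configuration is not shown in this listing):

    import torch
    import torch.nn as nn

    rnn = nn.LSTM(input_size=10, hidden_size=4, bidirectional=True, batch_first=True)
    x = torch.randn(3, 6, 10)                       # (B, L, E)
    lengths = torch.tensor([6, 4, 2])

    packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    output, (h, c) = rnn(packed)
    output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

    print(output.shape)                             # torch.Size([3, 6, 8]), i.e. 2 * hidden_size
    print(torch.cat([h[0], h[1]], dim=-1).shape)    # torch.Size([3, 8]), forward and backward final states
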
Example #7
    def forward(self, data) :
        seq = data.seq
        lengths = data.lengths
        embedding = self.embedding(seq) #(B, L, E)

        output = self.activation(self.projection(embedding)) #(B, L, H)
        h = output.mean(1) # mean over the length dimension -> (B, H)

        data.hidden = output
        data.last_hidden = h

        if isTrue(data, 'keep_grads') :
            data.embedding = embedding
            data.embedding.retain_grad()
            data.hidden.retain_grad()
Example #8
    def forward(self, input_seq, hidden_1, hidden_2, masks, data):
        #input_seq = (B, L), hidden : (B, L, H), masks : (B, L)

        # additive attention: per-token projection of hidden_1 plus a projection of
        # hidden_2 broadcast over the length dimension
        attn1 = nn.Tanh()(self.attn1p(hidden_1) +
                          self.attn1q(hidden_2).unsqueeze(1))
        attn2 = self.attn2(attn1).squeeze(-1)
        attn = masked_softmax(attn2, masks)

        inf = 1e9

        if isTrue(data, 'erase_given'):
            attn2[:, data.erase_attn] = -1 * inf
            attn = masked_softmax(attn2, masks)

        return attn
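
The additive combination in Example #8 adds a per-token projection of hidden_1 to a projection of a per-example summary hidden_2, broadcast over the length dimension via unsqueeze(1). A minimal shape illustration with dummy tensors:

    import torch

    B, L, H = 2, 5, 8
    per_token = torch.randn(B, L, H)                  # stands in for self.attn1p(hidden_1)
    per_example = torch.randn(B, H)                   # stands in for self.attn1q(hidden_2)
    combined = per_token + per_example.unsqueeze(1)   # broadcast over L -> (B, L, H)
    print(combined.shape)                             # torch.Size([2, 5, 8])
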
Example #9
    def forward(self, data) :
        seq = data.seq #(B, L)
        lengths = data.lengths #(B, )
        masks = data.masks #(B, L)
        embedding = self.embedding(seq) #(B, L, E)

        seq_t = embedding.transpose(1, 2) #(B, E, L), the layout Conv1d expects
        outputs = [self.convolutions[i](seq_t) for i in sorted(self.convolutions.keys())]

        output = self.activation(torch.cat(outputs, dim=1)) # concatenate the channels of all convolutions
        output = output * (1 - masks.unsqueeze(1)).float()  # zero positions where masks == 1 (padding)
        h = nn.functional.max_pool1d(output, kernel_size=output.size(-1)).squeeze(-1) # max over length -> (B, C)

        data.hidden = output.transpose(1, 2)
        data.last_hidden = h

        if isTrue(data, 'keep_grads') :
            data.embedding = embedding
            data.embedding.retain_grad()
            data.hidden.retain_grad()
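
Example #9 zeroes the convolution outputs at padded positions and then max-pools over the whole length dimension to obtain one feature per channel. A small standalone check of that pooling step (dummy tensors, same masking convention):

    import torch
    import torch.nn as nn

    B, C, L = 2, 4, 6
    output = torch.relu(torch.randn(B, C, L))         # stand-in for the conv activations
    masks = torch.tensor([[0, 0, 0, 1, 1, 1],
                          [0, 0, 0, 0, 0, 1]])        # 1 marks padding
    output = output * (1 - masks.unsqueeze(1)).float()
    h = nn.functional.max_pool1d(output, kernel_size=output.size(-1)).squeeze(-1)
    print(h.shape)                                    # torch.Size([2, 4])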