Example No. 1
    def forward(self, sents):
        # Embed the sequence
        x, lengths = to_input_tensor(self.language, sents, self.device)
        x_embed = self.embedding(x)
        # Pack the padded batch, run the GRU, then unpack back to (seq_len, batch, hidden)
        x = nn.utils.rnn.pack_padded_sequence(x_embed, lengths)
        x, _ = self.gru(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x)

        # (seq_len, batch, hidden) -> (batch, seq_len, hidden)
        x = x.transpose(0, 1)

        # Build an attention mask over the RNN outputs: flag the diagonal and
        # every padded position so the attention module can ignore them.
        I = torch.eye(max(lengths))
        attn_mask = torch.stack([I] * self.batch_size)
        for i, l in enumerate(lengths):
            attn_mask[i, :, l:] = 1
            attn_mask[i, l:, :] = 1

        # Attention weights over the sequence: (batch, seq_len, seq_len)
        attn = self.attention(x, attn_mask, self.device)
        # Weighted sum of the RNN outputs for every position, concatenated
        # with the raw RNN outputs.
        attn_vec = attn.unsqueeze(-1) * x.unsqueeze(1)
        attn_vec = attn_vec.sum(-2)
        attn_out = torch.cat([attn_vec, x], dim=-1)

        # Max-pool over the sequence dimension
        attn_out = attn_out.transpose(-1, -2)
        max_vec, _ = torch.max(attn_out, -1)
        max_vec = max_vec.unsqueeze(-2)

        # Sigmoid activation for binary classification
        y = torch.sigmoid(self.classify(max_vec)).squeeze()
        return y
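Only the forward methods appear on this page, not the class definitions. For Example No. 1 the snippet relies on self.language, self.device, self.batch_size, self.embedding, self.gru, self.attention and self.classify. The constructor below is a minimal sketch of what that setup might look like; the layer sizes, the unidirectional GRU and the DotProductSelfAttention helper are illustrative assumptions, not the original code.

import torch
import torch.nn as nn


class DotProductSelfAttention(nn.Module):
    # Hypothetical stand-in for self.attention: returns (batch, seq, seq)
    # attention weights, with masked positions (mask == 1) driven towards zero.
    def forward(self, x, attn_mask, device):
        scores = torch.bmm(x, x.transpose(1, 2)) / x.size(-1) ** 0.5
        scores = scores.masked_fill(attn_mask.to(device).bool(), -1e9)
        return torch.softmax(scores, dim=-1)


class GRUAttentionClassifier(nn.Module):   # hypothetical name
    def __init__(self, language, device, vocab_size, batch_size,
                 embed_size=128, hidden_size=256):
        super().__init__()
        self.language = language
        self.device = device
        self.batch_size = batch_size
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.gru = nn.GRU(embed_size, hidden_size)
        self.attention = DotProductSelfAttention()
        # Attended vectors are concatenated with the raw GRU outputs,
        # so the classifier sees 2 * hidden_size features.
        self.classify = nn.Linear(2 * hidden_size, 1)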
Example No. 2
    def forward(self, context, word):
        # convert word lists to indices
        context_input = to_input_tensor(context, self.embeddings_df, is_contexts=True)
        word_input = to_input_tensor(word, self.embeddings_df, is_contexts=False)

        # get embeddings of all words in context and the word
        context_embed = self.embedding(context_input)
        word_embed = self.embedding(word_input)

        # take average of context words to combine
        # context_embed = (b, con_len, embed_dim)
        context_embed = torch.mean(context_embed, dim=1)

        # Concatenate the word and context features and run them through the linear layer
        concat_features = torch.cat((word_embed, context_embed), dim=1)
        linear_output = self.linear(concat_features)
        return linear_output
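As with the other snippets, only the forward method is shown. Below is a minimal sketch of a constructor that would match Example No. 2; the vocabulary size, embedding dimension, number of classes and the class name are assumptions.

import torch.nn as nn


class CBOWWordInContextModel(nn.Module):   # hypothetical name
    def __init__(self, embeddings_df, vocab_size, embed_dim=300, num_classes=2):
        super().__init__()
        self.embeddings_df = embeddings_df   # lookup table used by to_input_tensor
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # The word embedding and the mean-pooled context embedding are
        # concatenated before the linear layer.
        self.linear = nn.Linear(2 * embed_dim, num_classes)

Averaging the context embeddings discards word order, so the linear layer only sees a bag-of-words summary of the context; Example No. 5 below replaces this average with an LSTM encoder.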
Example No. 3
    def forward(self, x, lang, device):
        # Convert the token lists to an index tensor: (seq_len, batch)
        x, _ = to_input_tensor(lang, x, device)
        # Position indices, shaped (seq_len, 1) so they broadcast over the batch
        positions = torch.arange(len(x), device=x.device).unsqueeze(-1)
        # Token embeddings plus learned positional embeddings, then dropout
        h = self.token_embeddings(x)
        h = h + self.pos_embeddings(positions).expand_as(h)
        h = self.dropout(h)

        return h, len(x)
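A sketch of the module this embedding layer could live in; the sizes and the class name are assumptions, while the attribute names (token_embeddings, pos_embeddings, dropout) come from the snippet.

import torch.nn as nn


class TransformerEmbedding(nn.Module):   # hypothetical name
    def __init__(self, vocab_size, embed_dim=256, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # Learned absolute position embeddings, one row per position index
        self.pos_embeddings = nn.Embedding(max_seq_len, embed_dim)
        self.dropout = nn.Dropout(dropout)

Because positions has shape (seq_len, 1), the same position embedding is broadcast across the whole batch by expand_as.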
Example No. 4
    def forward(self, sents):
        x, _ = to_input_tensor(self.language, sents, self.max_seq_len,
                               self.device)

        # Token embeddings plus learned positional embeddings: (seq_len, batch, embed_dim)
        positions = torch.arange(len(x), device=x.device).unsqueeze(-1)
        w_embed = self.w_embedding(x)
        h = w_embed + self.pos_embeddings(positions).expand_as(w_embed)

        # One transformer block per task: multi-head self-attention with a
        # residual connection and layer norm, then a feed-forward sub-layer.
        for task, mha, feed_forward, lnorm_1, lnorm_2 in zip(
                self.tasks, self.mhas, self.ff, self.ln_1, self.ln_2):
            # Self-attention sub-layer: (seq_len, batch, embed_dim)
            x, _ = mha(h, h, h)
            h = x + h
            h = lnorm_1(h)

            # Feed-forward sub-layer with dropout: (seq_len, batch, embed_dim)
            x = feed_forward(h)
            x = self.dropout(x)
            h = x + h
            h = lnorm_2(h)

        # BERT-style classification head: use the hidden state of the first
        # token as the sequence representation -> (batch, embed_dim)
        x = h[0, :, :]

        # Sigmoid activation for binary classification
        y = torch.sigmoid(self.classify(x)).squeeze()
        return y
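The forward pass above iterates over parallel per-task module lists. A minimal sketch of a constructor that would support it, assuming one transformer block per task and the standard nn.MultiheadAttention; all sizes and the class name are illustrative.

import torch.nn as nn


class PerTaskTransformerClassifier(nn.Module):   # hypothetical name
    def __init__(self, language, device, tasks, vocab_size,
                 max_seq_len=512, embed_dim=256, n_heads=4,
                 ff_dim=1024, dropout=0.1):
        super().__init__()
        self.language, self.device = language, device
        self.tasks, self.max_seq_len = tasks, max_seq_len
        self.w_embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_embeddings = nn.Embedding(max_seq_len, embed_dim)
        # One attention / feed-forward / layer-norm stack per task
        self.mhas = nn.ModuleList(
            nn.MultiheadAttention(embed_dim, n_heads) for _ in tasks)
        self.ff = nn.ModuleList(
            nn.Sequential(nn.Linear(embed_dim, ff_dim), nn.ReLU(),
                          nn.Linear(ff_dim, embed_dim)) for _ in tasks)
        self.ln_1 = nn.ModuleList(nn.LayerNorm(embed_dim) for _ in tasks)
        self.ln_2 = nn.ModuleList(nn.LayerNorm(embed_dim) for _ in tasks)
        self.dropout = nn.Dropout(dropout)
        self.classify = nn.Linear(embed_dim, 1)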
Example No. 5
    def forward(self, context, word):
        # convert word lists to indices
        context_input = to_input_tensor(context, self.embeddings_df, is_contexts=True)
        word_input = to_input_tensor(word, self.embeddings_df, is_contexts=False)

        # get embeddings of all words in context and the word
        context_embed = self.embedding(context_input)
        word_embed = self.embedding(word_input)
        batch_size = word_embed.size(0)

        # Encode the context with the LSTM.
        # context_embed: (b, con_len, embed_dim); the LSTM expects (con_len, b, embed_dim)
        context_embed_permuted = context_embed.permute(1, 0, 2)
        _, (encoded_context, _) = self.encoder(context_embed_permuted)
        # encoded_context is the final hidden state: (layers*dirs, b, hidden_size) = (2, b, h)
        # we want (b, 2*h)
        encoded_context_permuted = encoded_context.permute(1, 0, 2).contiguous()  # (b, 2, h)
        encoded_context_squashed = encoded_context_permuted.view(batch_size, -1)  # (b, 2*h)

        # Concatenate the word embedding with the encoded context, then apply the linear layer
        concat_features = torch.cat((word_embed, encoded_context_squashed), dim=1)
        linear_output = self.linear(concat_features)
        return linear_output
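Example No. 5 swaps the averaged context of Example No. 2 for an LSTM encoder. A constructor sketch that would fit the attribute names used above; the sizes and the class name are assumptions.

import torch.nn as nn


class LSTMWordInContextModel(nn.Module):   # hypothetical name
    def __init__(self, embeddings_df, vocab_size, embed_dim=300,
                 hidden_size=128, num_classes=2):
        super().__init__()
        self.embeddings_df = embeddings_df
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Single-layer bidirectional LSTM: the final hidden state has
        # shape (layers * dirs, b, hidden_size) = (2, b, hidden_size).
        self.encoder = nn.LSTM(embed_dim, hidden_size, bidirectional=True)
        # Word embedding concatenated with the flattened (b, 2 * hidden_size) context
        self.linear = nn.Linear(embed_dim + 2 * hidden_size, num_classes)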
Example No. 6

    def forward(self, sents):
        s_tensor, lengths = to_input_tensor(self.lang, sents, self.device)
        emb = self.embed(s_tensor)

        # pack + rnn sequence + unpack
        x = nn.utils.rnn.pack_padded_sequence(emb, lengths)
        output, hidden = self.gru(x)
        output, _ = nn.utils.rnn.pad_packed_sequence(output)

        # batch_size, seq_len, hidden_size
        output_batch = output.transpose(0, 1)

        # Sum the GRU outputs over the sequence: (batch_size, hidden_size)
        out_sum = output_batch.sum(dim=1)

        # (batch_size, 1) logit -> sigmoid for binary classification
        linear_out = self.l1(out_sum)
        out = torch.sigmoid(linear_out).squeeze(-1)

        return out
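Finally, a sketch of what the last classifier's constructor might look like, given that the forward pass sums the GRU outputs and feeds them to a single linear layer. Everything except the attribute names used above (lang, device, embed, gru, l1) is an assumption.

import torch.nn as nn


class GRUSumClassifier(nn.Module):   # hypothetical name
    def __init__(self, lang, device, vocab_size, embed_dim=128, hidden_size=256):
        super().__init__()
        self.lang, self.device = lang, device
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.gru = nn.GRU(embed_dim, hidden_size)
        # Single logit per example for binary classification
        self.l1 = nn.Linear(hidden_size, 1)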