import torch
import torch.nn as nn


def forward(self, sents):
    # Embed the sequence
    x, lengths = to_input_tensor(self.language, sents, self.device)
    x_embed = self.embedding(x)

    # RNN encoding: pack, run through the GRU, then unpack
    x = nn.utils.rnn.pack_padded_sequence(x_embed, lengths)
    x, _ = self.gru(x)
    x, _ = nn.utils.rnn.pad_packed_sequence(x)
    x = x.transpose(0, 1)  # (batch, seq_len, hidden)

    # Get attention over the RNN outputs, masking out padded positions
    I = torch.eye(max(lengths))
    attn_mask = torch.stack([I] * self.batch_size)
    for i, l in zip(range(self.batch_size), lengths):
        attn_mask[i, :, l:] = 1
        attn_mask[i, l:, :] = 1
    attn = self.attention(x, attn_mask, self.device)
    attn_vec = attn.unsqueeze(-1) * x.unsqueeze(1)
    attn_vec = attn_vec.sum(-2)
    attn_out = torch.cat([attn_vec, x], dim=-1)

    # Max pool over the sequence dimension
    attn_out = attn_out.transpose(-1, -2)
    max_vec, _ = torch.max(attn_out, -1)
    max_vec = max_vec.unsqueeze(-2)

    # Binary classification activation
    y = torch.sigmoid(self.classify(max_vec)).squeeze()
    return y
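A minimal, standalone sketch of the mask construction above, with a made-up batch and lengths; it assumes, as the code implies, that 1 marks a position the attention module should ignore (padding plus the diagonal):

import torch

# Toy batch: 3 sequences padded to length 4 (values here are made up)
lengths = [4, 3, 2]
batch_size = len(lengths)
max_len = max(lengths)

# Start from an identity matrix per example (diagonal pre-masked), then
# mark every row and column in the padding region with 1 as well.
attn_mask = torch.stack([torch.eye(max_len)] * batch_size)
for i, l in enumerate(lengths):
    attn_mask[i, :, l:] = 1
    attn_mask[i, l:, :] = 1

print(attn_mask[2])
# tensor([[1., 0., 1., 1.],
#         [0., 1., 1., 1.],
#         [1., 1., 1., 1.],
#         [1., 1., 1., 1.]])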
def forward(self, context, word):
    # Convert word lists to indices
    context_input = to_input_tensor(context, self.embeddings_df, is_contexts=True)
    word_input = to_input_tensor(word, self.embeddings_df, is_contexts=False)

    # Get embeddings of all words in the context and of the target word
    context_embed = self.embedding(context_input)
    word_embed = self.embedding(word_input)

    # Take the average of the context words to combine them
    # context_embed: (b, con_len, embed_dim) -> (b, embed_dim)
    context_embed = torch.mean(context_embed, dim=1)

    # Run through the linear layer
    concat_features = torch.cat((word_embed, context_embed), dim=1)
    linear_output = self.linear(concat_features)
    return linear_output
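The context combination above is just a mean over the context dimension followed by concatenation with the target word embedding; a quick shape check with made-up sizes:

import torch

# Hypothetical sizes for illustration only
batch, con_len, embed_dim = 2, 5, 8
context_embed = torch.randn(batch, con_len, embed_dim)
word_embed = torch.randn(batch, embed_dim)

context_avg = context_embed.mean(dim=1)                  # (batch, embed_dim)
features = torch.cat((word_embed, context_avg), dim=1)   # (batch, 2 * embed_dim)
print(features.shape)  # torch.Size([2, 16])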
def forward(self, x, lang, device):
    x, _ = to_input_tensor(lang, x, device)
    positions = torch.arange(len(x), device=x.device).unsqueeze(-1)
    h = self.token_embeddings(x)
    h = h + self.pos_embeddings(positions).expand_as(h)
    h = self.dropout(h)
    return h, len(x)
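A small sketch, with made-up sizes, of how the (seq, 1) position indices broadcast against the (seq, batch, embed) token embeddings via expand_as:

import torch
import torch.nn as nn

# Hypothetical sizes for illustration only
seq_len, batch, embed_dim, vocab, max_pos = 6, 2, 16, 100, 512
token_embeddings = nn.Embedding(vocab, embed_dim)
pos_embeddings = nn.Embedding(max_pos, embed_dim)

x = torch.randint(0, vocab, (seq_len, batch))      # (seq, batch) token ids
positions = torch.arange(seq_len).unsqueeze(-1)    # (seq, 1)
h = token_embeddings(x)                            # (seq, batch, embed_dim)
# pos_embeddings(positions) is (seq, 1, embed_dim); expand_as broadcasts it
# across the batch dimension before the addition.
h = h + pos_embeddings(positions).expand_as(h)
print(h.shape)  # torch.Size([6, 2, 16])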
def forward(self, sents):
    batch_size = len(sents)
    x, _ = to_input_tensor(self.language, sents, self.max_seq_len, self.device)

    # Token + positional embeddings
    positions = torch.arange(len(x), device=x.device).unsqueeze(-1)
    w_embed = self.w_embedding(x)
    h = w_embed + self.pos_embeddings(positions).expand_as(w_embed)

    # Stack of transformer encoder blocks
    for task, mha, feed_forward, lnorm_1, lnorm_2 in zip(
            self.tasks, self.mhas, self.ff, self.ln_1, self.ln_2):
        # for task, mha, lnorm_1 in zip(self.tasks, self.mhas, self.ln_1):
        # tasks = torch.tensor([task] * batch_size, device=self.device)
        # te = self.t_embedding(tasks).unsqueeze(-1)
        # ffe = self.ff_embedding(tasks).unsqueeze(-1)

        # Self-attention sub-layer: (seq, bs, embed)
        x, _ = mha(h, h, h)
        # x = self.weight1 * x * self.attention(x, te)
        # x = self.weight1 * x * self.attention(w_embed, te)
        # x = self.weight1 * x
        # x = self.weight1 * self.attention(x, te)
        # x = self.attention(w_embed, te) + self.attention(x, te)
        # x = x + self.weight1 * self.attention(w_embed, te) * w_embed
        # if self.training:
        #     x = x * self.attention(x, te)
        # h = x + w_embed * self.attention(w_embed, te)
        h = x + h
        h = lnorm_1(h)

        # Feed-forward sub-layer: (seq, bs, embed)
        x = feed_forward(h)
        x = self.dropout(x)
        # x = self.weight2 * x * self.attention(x, ffe)
        # x = self.weight2 * x * self.attention(w_embed, ffe)
        # x = self.weight2 * x * self.attention(h, ffe)
        # x = self.weight2 * self.attention(x, ffe)
        # x = self.attention(w_embed, ffe) + self.attention(x, ffe) * x
        # x = x + self.weight2 * self.attention(w_embed, ffe) * w_embed
        # h = x + h * self.attention(h, ffe)
        # h = x + w_embed * self.attention(w_embed, ffe)
        # if self.training:
        #     x = x * self.attention(x, ffe)
        h = x + h
        h = lnorm_2(h)

    # h = h.transpose(0, 1)  # (bs, seq, embed_dim)
    # BERT-style classification head: first token's representation, (bs, embed_dim)
    x = h[0, :, :]
    # m, _ = torch.max(h, -2)
    y = torch.sigmoid(self.classify(x)).squeeze()
    return y
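The live (uncommented) path through the loop is a standard post-norm transformer encoder block. A self-contained sketch of one such block follows; the sizes and the feed-forward definition are assumptions for illustration, not the model's actual configuration:

import torch
import torch.nn as nn

# One post-norm encoder block: self-attention + residual + LayerNorm,
# then feed-forward + dropout + residual + LayerNorm.
embed_dim, n_heads, ff_dim = 32, 4, 64
mha = nn.MultiheadAttention(embed_dim, n_heads)    # expects (seq, batch, embed)
feed_forward = nn.Sequential(
    nn.Linear(embed_dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, embed_dim))
lnorm_1, lnorm_2 = nn.LayerNorm(embed_dim), nn.LayerNorm(embed_dim)
dropout = nn.Dropout(0.1)

h = torch.randn(10, 2, embed_dim)                  # (seq, batch, embed)
x, _ = mha(h, h, h)
h = lnorm_1(x + h)
x = dropout(feed_forward(h))
h = lnorm_2(x + h)
print(h.shape)  # torch.Size([10, 2, 32])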
def forward(self, context, word):
    # Convert word lists to indices
    context_input = to_input_tensor(context, self.embeddings_df, is_contexts=True)
    word_input = to_input_tensor(word, self.embeddings_df, is_contexts=False)

    # Get embeddings of all words in the context and of the target word
    context_embed = self.embedding(context_input)
    word_embed = self.embedding(word_input)
    batch_size = word_embed.size(0)

    # Encode the context with the LSTM.
    # context_embed is (b, con_len, embed_dim); the LSTM expects (con_len, b, embed_dim).
    context_embed_permuted = context_embed.permute(1, 0, 2)
    _, (encoded_context, _) = self.encoder(context_embed_permuted)

    # encoded_context is the final hidden state: (layers * dirs, b, hidden_size) = (2, b, h);
    # reshape it to (b, 2 * h)
    encoded_context_permuted = encoded_context.permute(1, 0, 2).contiguous()  # (b, 2, h)
    encoded_context_squashed = encoded_context_permuted.view(batch_size, -1)  # (b, 2 * h)

    # Run through the linear layer
    concat_features = torch.cat((word_embed, encoded_context_squashed), dim=1)
    linear_output = self.linear(concat_features)
    return linear_output
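A standalone sketch, with made-up sizes, of the reshape that turns a bidirectional LSTM's final hidden state (layers * dirs, b, h) into (b, 2 * h):

import torch
import torch.nn as nn

# Hypothetical sizes for illustration: 1 layer, 2 directions
batch, con_len, embed_dim, hidden = 3, 5, 8, 6
encoder = nn.LSTM(embed_dim, hidden, bidirectional=True)

context = torch.randn(con_len, batch, embed_dim)           # (con_len, b, embed_dim)
_, (h_n, _) = encoder(context)                             # h_n: (2, b, hidden)
flat = h_n.permute(1, 0, 2).contiguous().view(batch, -1)   # (b, 2 * hidden)
print(flat.shape)  # torch.Size([3, 12])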
def forward(self, sents):
    s_tensor, lengths = to_input_tensor(self.lang, sents, self.device)
    emb = self.embed(s_tensor)

    # Pack, run through the GRU, then unpack
    x = nn.utils.rnn.pack_padded_sequence(emb, lengths)
    output, hidden = self.gru(x)
    output, _ = nn.utils.rnn.pad_packed_sequence(output)

    # (batch_size, seq_len, hidden_size)
    output_batch = output.transpose(0, 1)
    # Sum over the sequence dimension -> (batch_size, hidden_size);
    # padded positions are zero after unpacking, so they do not contribute
    out_avg = output_batch.sum(dim=1)

    # (batch_size, 1)
    linear_out = self.l1(out_avg)
    out = torch.sigmoid(linear_out).squeeze(-1)
    return out
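A small sketch of the pack → GRU → unpack round trip with made-up sizes and lengths; after pad_packed_sequence, positions beyond each sequence's length are zero, which is why summing over time only aggregates real steps:

import torch
import torch.nn as nn

embed_dim, hidden = 8, 6
gru = nn.GRU(embed_dim, hidden)

emb = torch.randn(5, 3, embed_dim)                       # (max_len, batch, embed_dim)
lengths = [5, 4, 2]                                      # sorted descending for pack_padded_sequence
packed = nn.utils.rnn.pack_padded_sequence(emb, lengths)
output, hidden_state = gru(packed)
output, _ = nn.utils.rnn.pad_packed_sequence(output)     # (max_len, batch, hidden)
pooled = output.transpose(0, 1).sum(dim=1)               # (batch, hidden)
print(pooled.shape)  # torch.Size([3, 6])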