def forward(self, data: BatchHolder):
    if self.use_attention:
        output = data.hidden
        mask = data.masks
        attn = self.attention(data)

        if self.use_regulariser_attention:
            data.reg_loss = 5 * self.regularizer_attention.regularise(data.seq, output, mask, attn)

        if isTrue(data, 'detach'):
            attn = attn.detach()

        if isTrue(data, 'permute'):
            permutation = data.generate_permutation()
            attn = torch.gather(attn, -1, torch.LongTensor(permutation).to(device))

        context = (attn.unsqueeze(-1) * output).sum(1)
        data.attn = attn
    else:
        context = data.last_hidden

    predict = self.decode(context)
    data.predict = predict
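# The forward passes in this section repeatedly call isTrue(data, '<flag>') to check
# run-time switches ('detach', 'permute', 'keep_grads', the erase_* modes) carried on the
# batch object. The helper itself is defined elsewhere in the repository; the function
# below is only a hypothetical stand-in, assuming it checks for a truthy attribute of
# that name on the batch.
def isTrue_sketch(obj, name):
    # returns False when the attribute is absent, its value otherwise
    return getattr(obj, name, False)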
def forward(self, data):
    # input_seq = (B, L), hidden : (B, L, H), masks : (B, L)
    input_seq, hidden, masks = data.seq, data.hidden, data.masks
    lengths = data.lengths

    # additive attention: tanh projection followed by a scalar scoring layer
    attn1 = nn.Tanh()(self.attn1(hidden))
    attn2 = self.attn2(attn1).squeeze(-1)
    data.attn_logit = attn2
    attn = masked_softmax(attn2, masks)

    inf = 1e9
    # row indices used for per-example erasure below
    batch_idx = torch.arange(attn2.shape[0]).to(device)

    if isTrue(data, 'erase_max'):
        # erase each example's own highest-attention position
        attn2[batch_idx, attn.max(dim=1)[1]] = -1 * inf
        attn = masked_softmax(attn2, masks)

    if isTrue(data, 'erase_random'):
        # erase one uniformly sampled non-padding position per example
        rand_idx = (torch.rand(size=lengths.size()).to(device) * lengths.float()).long()
        attn2[batch_idx, rand_idx] = -1 * inf
        attn = masked_softmax(attn2, masks)

    if isTrue(data, 'erase_given'):
        attn2[:, data.erase_attn] = -1 * inf
        attn = masked_softmax(attn2, masks)

    return attn
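# All attention modules here rely on a masked_softmax(logits, masks) helper defined
# elsewhere in the repository. The function below is a minimal sketch, not the
# repository's implementation; it assumes masks == 1 marks padding positions (the same
# convention as the (1 - masks) factor in the CNN encoder further down) and that
# torch / torch.nn are already imported as in the surrounding module.
def masked_softmax_sketch(logits, masks, inf=1e9):
    # push padded positions towards -inf so they receive (near) zero attention mass
    logits = logits.masked_fill(masks.bool(), -inf)
    return nn.Softmax(dim=-1)(logits)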
def get_context(self, data: BatchHolder):
    output = data.hidden
    mask = data.masks
    attn = self.attention(data.seq, output, mask)

    if self.use_regulariser_attention:
        data.reg_loss = 5 * self.regularizer_attention.regularise(data.seq, output, mask, attn)

    if isTrue(data, 'detach'):
        attn = attn.detach()

    if isTrue(data, 'permute'):
        permutation = data.generate_permutation()
        attn = torch.gather(attn, -1, torch.LongTensor(permutation).to(device))

    return (attn.unsqueeze(-1) * output).sum(1)
def forward(self, data, is_embds=False):
    seq = data.seq
    lengths = data.lengths

    if len(seq.shape) == 2:
        # look up embeds table
        embedding = self.embedding(seq)
    else:
        # skip lookup
        embedding = data.seq
        embedding.requires_grad = True

    batch_size = embedding.size()[0]
    self.init_states = [
        [LSTMState(cx=torch.zeros(batch_size, self.hidden_size),
                   hx=torch.zeros(batch_size, self.hidden_size))
         for _ in range(2)]
        for _ in range(self.num_layers)
    ]

    output, cell_output, output_state = self.rnn(embedding, lengths)
    h, c = output_state[self.num_layers - 1]

    data.hidden = output
    data.cell_state = cell_output
    data.last_hidden = h
    data.embedding = embedding

    if isTrue(data, 'keep_grads'):
        data.embedding.retain_grad()
        data.hidden.retain_grad()
def forward(self, data):
    seq = data.seq
    lengths = data.lengths

    if len(data.seq.shape) == 2:
        # look up embeds table
        embedding = self.embedding(seq)
    else:
        # skip lookup
        embedding = data.seq.type(torch.FloatTensor)
        embedding.requires_grad = True

    packseq = nn.utils.rnn.pack_padded_sequence(embedding, lengths, batch_first=True, enforce_sorted=False)
    output, (h, c) = self.rnn(packseq)
    output, lengths = nn.utils.rnn.pad_packed_sequence(output, batch_first=True, padding_value=0)

    data.hidden = output
    data.last_hidden = torch.cat([h[0], h[1]], dim=-1)

    if isTrue(data, 'keep_grads'):
        data.embedding = embedding
        data.embedding.retain_grad()
        data.hidden.retain_grad()
def forward(self, data):
    # input_seq = (B, L), hidden : (B, L, H), masks : (B, L)
    input_seq, hidden, masks = data.seq, data.hidden, data.masks
    lengths = data.lengths

    b, l, h = hidden.shape
    # project the hidden states and split them into one (B, L, H) tensor per head
    proj_hiddens = torch.unbind(self.projection(hidden).view(b, l, self.heads, h), dim=2)
    data.proj_hiddens = proj_hiddens

    attns = []
    erase_head = getattr(data, 'erase_head', -1)
    # row indices used for per-example erasure below
    batch_idx = torch.arange(b).to(device)

    for idx, head_hidden in enumerate(proj_hiddens):
        head_hidden = head_hidden.to(device)
        attn1 = nn.Tanh()(self.attn1[idx](head_hidden))
        attn2 = self.attn2[idx](attn1).squeeze(-1)
        data.attn_logit = attn2
        attn = masked_softmax(attn2, masks)

        inf = 1e9
        if isTrue(data, 'erase_max') and (erase_head == idx):
            # erase each example's own highest-attention position on this head
            attn2[batch_idx, attn.max(dim=1)[1]] = -1 * inf
            attn = masked_softmax(attn2, masks)

        if isTrue(data, 'erase_random') and (erase_head == idx):
            # erase one uniformly sampled non-padding position per example
            rand_idx = (torch.rand(size=lengths.size()).to(device) * lengths.float()).long()
            attn2[batch_idx, rand_idx] = -1 * inf
            attn = masked_softmax(attn2, masks)

        if isTrue(data, 'erase_given') and (erase_head == idx):
            attn2[:, data.erase_attn] = -1 * inf
            attn = masked_softmax(attn2, masks)

        attns.append(attn)

    return attns
def forward(self, data: BatchMultiHolder):
    Poutput = data.P.hidden       # (B, L, H)
    Qoutput = data.Q.last_hidden  # (B, H)

    if self.use_attention:
        mask = data.P.masks
        attn = self.attention(data.P.seq, Poutput, Qoutput, mask, data)  # (B, L)

        if isTrue(data, 'detach'):
            attn = attn.detach()

        if isTrue(data, 'permute'):
            permutation = data.P.generate_permutation()
            attn = torch.gather(attn, -1, torch.LongTensor(permutation).to(device))

        context = (attn.unsqueeze(-1) * Poutput).sum(1)  # (B, H)
        data.attn = attn
    else:
        context = data.P.last_hidden

    # Qoutput is needed by decode() in both branches, hence it is read before the if/else
    predict = self.decode(context, Qoutput, data.entity_mask)
    data.predict = predict
def forward(self, input_seq, hidden_1, hidden_2, masks, data):
    # input_seq = (B, L), hidden : (B, L, H), masks : (B, L)
    attn1 = nn.Tanh()(self.attn1p(hidden_1) + self.attn1q(hidden_2).unsqueeze(1))
    attn2 = self.attn2(attn1).squeeze(-1)
    attn = masked_softmax(attn2, masks)

    inf = 1e9
    if isTrue(data, 'erase_given'):
        attn2[:, data.erase_attn] = -1 * inf
        attn = masked_softmax(attn2, masks)

    return attn
def forward(self, data):
    # data is a BatchHolder: data.seq is a batch of padded sequences, data.lengths the
    # length of each sequence. self.embedding is an nn.Embedding (no hooks) initialised
    # with the pre_embed weights.
    seq = data.seq
    lengths = data.lengths

    if len(data.seq.shape) == 2:
        # token ids: look up the embedding table, (B, L) -> (B, L, E)
        embedding = self.embedding(seq)
    else:
        # the batch already carries embeddings, so skip the lookup; the lookup above also
        # yields float tensors, so replicate that cast here
        embedding = data.seq.type(torch.FloatTensor)
        embedding.requires_grad = True  # ensures the embedding computation is tracked

    # Z = tanh(W X + b): project the embeddings (B, L, E) to hidden states (B, L, H)
    output = self.activation(self.projection(embedding))
    # average the hidden states over the sequence dimension, (B, L, H) -> (B, H)
    h = output.mean(1)

    data.hidden = output
    data.last_hidden = h

    if isTrue(data, 'keep_grads'):
        data.embedding = embedding
        data.embedding.retain_grad()
        data.hidden.retain_grad()
def forward(self, data):
    seq = data.seq          # (B, L)
    lengths = data.lengths  # (B, )
    masks = data.masks      # (B, L)

    embedding = self.embedding(seq)    # (B, L, E)
    seq_t = embedding.transpose(1, 2)  # (B, E, L), channels-first for Conv1d

    # apply each convolution in key order and concatenate along the channel dimension
    outputs = [self.convolutions[i](seq_t) for i in sorted(self.convolutions.keys())]
    output = self.activation(torch.cat(outputs, dim=1))

    # zero out padding positions (masks == 1) before max-pooling over the sequence
    output = output * (1 - masks.unsqueeze(1)).float()
    h = nn.functional.max_pool1d(output, kernel_size=output.size(-1)).squeeze(-1)

    data.hidden = output.transpose(1, 2)
    data.last_hidden = h

    if isTrue(data, 'keep_grads'):
        data.embedding = embedding
        data.embedding.retain_grad()
        data.hidden.retain_grad()
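# The CNN forward above assumes self.convolutions is a dict-like module (e.g. an
# nn.ModuleDict) whose values each map (B, E, L) -> (B, C, L) so their outputs can be
# concatenated on the channel dimension. The builder below is only a hypothetical
# illustration consistent with that usage; the kernel sizes and dimensions
# (embed_size=300, hidden_size=64) are assumptions, not the repository's actual
# configuration.
def build_convolutions_sketch(embed_size=300, hidden_size=64, kernel_sizes=(1, 3, 5)):
    return nn.ModuleDict({
        str(k): nn.Conv1d(embed_size, hidden_size, kernel_size=k, padding=k // 2)
        for k in kernel_sizes
    })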