def forward(self, ids, lengths, word_embeddings, hidden):
    sorted_lengths = sorted(lengths, reverse=True)
    is_sorted = sorted_lengths == lengths
    is_varlen = sorted_lengths[0] != sorted_lengths[-1]
    if not is_sorted:
        true2sorted = sorted(range(len(lengths)), key=lambda x: -lengths[x])
        sorted2true = sorted(range(len(lengths)), key=lambda x: true2sorted[x])
        ids = torch.stack([ids[:, i] for i in true2sorted], dim=1)
        lengths = [lengths[i] for i in true2sorted]
    embeddings = (word_embeddings(data.word_ids(ids))
                  + self.special_embeddings(data.special_ids(ids)))
    if is_varlen:
        embeddings = nn.utils.rnn.pack_padded_sequence(embeddings, lengths)
    output, hidden = self.rnn(embeddings, hidden)
    if self.bidirectional:
        hidden = torch.stack([
            torch.cat((hidden[2 * i], hidden[2 * i + 1]), dim=1)
            for i in range(self.layers)
        ])
    if is_varlen:
        output = nn.utils.rnn.pad_packed_sequence(output)[0]
    if not is_sorted:
        hidden = torch.stack([hidden[:, i, :] for i in sorted2true], dim=1)
        output = torch.stack([output[:, i, :] for i in sorted2true], dim=1)
    return hidden, output
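# --- Illustrative sketch (not part of the encoder above; sizes and the plain
# nn.GRU are stand-ins) showing the sort -> pack -> unsort pattern it relies on:
# pack_padded_sequence requires the batch ordered by decreasing length, so
# true2sorted reorders it and sorted2true (the inverse permutation) restores
# the original order afterwards.
import torch
import torch.nn as nn

def packed_rnn_sketch():
    lengths = [3, 5, 4]                                # unsorted lengths
    emb = torch.randn(max(lengths), len(lengths), 8)   # time x batch x dim
    rnn = nn.GRU(8, 16)

    true2sorted = sorted(range(len(lengths)), key=lambda x: -lengths[x])
    sorted2true = sorted(range(len(lengths)), key=lambda x: true2sorted[x])
    emb = emb[:, true2sorted]
    sorted_lengths = [lengths[i] for i in true2sorted]

    packed = nn.utils.rnn.pack_padded_sequence(emb, sorted_lengths)
    output, hidden = rnn(packed)
    output = nn.utils.rnn.pad_packed_sequence(output)[0]

    # undo the sorting so outputs line up with the original batch order
    return output[:, sorted2true], hidden[:, sorted2true]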
def forward(self, ids, lengths, word_embeddings, hidden, context,
            context_mask, prev_output, generator):
    embeddings = (word_embeddings(data.word_ids(ids))
                  + self.special_embeddings(data.special_ids(ids)))
    output = prev_output
    scores = []
    for emb in embeddings.split(1):
        if self.input_feeding:
            input = torch.cat([emb.squeeze(0), output], 1)
        else:
            input = emb.squeeze(0)
        output, hidden = self.stacked_rnn(input, hidden)
        if not self.modify:
            l_context, l_query = self.attention(output, context, context_mask)
            output = self.tanh(l_context + l_query)
        else:
            # put the auxiliary attention module in train/eval mode
            # according to self.update_attnF
            self.attention_from.train(self.update_attnF)
            f_context, _ = self.attention_from(output, context, context_mask)
            l_context, l_query = self.attention(output, context, context_mask)
            # add the auxiliary context on top of the standard attention terms
            output = self.tanh(l_context + l_query + f_context)
        output = self.dropout(output)
        scores.append(generator(output))
    return torch.stack(scores), hidden, output
def forward(self, ids, lengths, word_embeddings, hidden):
    # hidden: (bidirectional * layers) * batch_size * hidden_size
    sorted_lengths = sorted(lengths, reverse=True)
    is_sorted = sorted_lengths == lengths
    is_varlen = sorted_lengths[0] != sorted_lengths[-1]  # if variable length, needs padding
    if not is_sorted:
        true2sorted = sorted(range(len(lengths)), key=lambda x: -lengths[x])
        sorted2true = sorted(range(len(lengths)), key=lambda x: true2sorted[x])
        ids = torch.stack([ids[:, i] for i in true2sorted], dim=1)
        lengths = [lengths[i] for i in true2sorted]
    # matrix of word embeddings: sentence_size * batch_size * embedding_dims (150)
    embeddings = (word_embeddings(data.word_ids(ids))
                  + self.special_embeddings(data.special_ids(ids)))
    if is_varlen:
        embeddings = nn.utils.rnn.pack_padded_sequence(embeddings, lengths)
    output, hidden = self.rnn(embeddings, hidden)
    # hidden should be (bidirectional * layers) * batch_size * hidden_size/2 (because bidirectional) (301)
    # output should be sentence_size * batch_size * hidden_size (602)
    if self.bidirectional:
        hidden = torch.stack([
            torch.cat((hidden[2 * i], hidden[2 * i + 1]), dim=1)
            for i in range(self.layers)
        ])
    if is_varlen:
        output = nn.utils.rnn.pad_packed_sequence(output)[0]
    if not is_sorted:
        hidden = torch.stack([hidden[:, i, :] for i in sorted2true], dim=1)
        output = torch.stack([output[:, i, :] for i in sorted2true], dim=1)
    return hidden, output
def forward(self, ids, lengths, word_embeddings, hidden, context,
            context_mask, prev_output, generator):
    embeddings = (word_embeddings(data.word_ids(ids))
                  + self.special_embeddings(data.special_ids(ids)))
    output = prev_output
    scores = []
    for emb in embeddings.split(1):
        if self.input_feeding:
            input = torch.cat([emb.squeeze(0), output], 1)
        else:
            input = emb.squeeze(0)
        output, hidden = self.stacked_rnn(input, hidden)
        output = self.attention(output, context, context_mask)
        output = self.dropout(output)
        scores.append(generator(output))
    return torch.stack(scores), hidden, output
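# --- Illustrative sketch (hypothetical sizes; nn.GRUCell stands in for the
# repo's stacked_rnn) of one input-feeding step as used by the decoders above:
# the attentional output of step t-1 is concatenated with the embedding of
# step t before the recurrent cell.
import torch
import torch.nn as nn

emb_dim, hid_dim, batch = 8, 16, 4
cell = nn.GRUCell(emb_dim + hid_dim, hid_dim)
emb_t = torch.randn(batch, emb_dim)              # target embedding at step t
prev_attn_output = torch.zeros(batch, hid_dim)   # attentional output at t-1
h_prev = torch.zeros(batch, hid_dim)

rnn_input = torch.cat([emb_t, prev_attn_output], dim=1)
h_t = cell(rnn_input, h_prev)
# h_t is then combined with the encoder context by the attention module,
# and that attentional output is fed back in at step t+1.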
def forward(self, ids, lengths, word_embeddings, hidden, pass_embedds=False):
    sorted_lengths = sorted(lengths, reverse=True)
    is_sorted = sorted_lengths == lengths
    is_varlen = sorted_lengths[0] != sorted_lengths[-1]
    if not is_sorted:
        true2sorted = sorted(range(len(lengths)), key=lambda x: -lengths[x])
        sorted2true = sorted(range(len(lengths)), key=lambda x: true2sorted[x])
        ids = torch.stack([ids[:, i] for i in true2sorted], dim=1)
        lengths = [lengths[i] for i in true2sorted]
    embeddings = (word_embeddings(data.word_ids(ids))
                  + self.special_embeddings(data.special_ids(ids)))
    # keep a reference to the (unpacked) embeddings so they can be returned
    passembeddings = embeddings
    # print("EMBEDDINGS TENSOR SIZE: ", passembeddings, passembeddings.requires_grad)
    # print("EMBEDDINGS ENCODER: ", word_embeddings.weight.requires_grad)
    if is_varlen:
        embeddings = nn.utils.rnn.pack_padded_sequence(embeddings, lengths)
    output, hidden = self.rnn(embeddings, hidden)
    if self.bidirectional:
        hidden = torch.stack([
            torch.cat((hidden[2 * i], hidden[2 * i + 1]), dim=1)
            for i in range(self.layers)
        ])
    if is_varlen:
        output = nn.utils.rnn.pad_packed_sequence(output)[0]
    if not is_sorted:
        hidden = torch.stack([hidden[:, i, :] for i in sorted2true], dim=1)
        output = torch.stack([output[:, i, :] for i in sorted2true], dim=1)
    if not pass_embedds:
        return hidden, output
    else:
        return hidden, output, passembeddings
def call(self, ids, lengths, word_embeddings, hidden):
    sorted_lengths = sorted(lengths, reverse=True)
    is_sorted = sorted_lengths == lengths
    is_varlen = sorted_lengths[0] != sorted_lengths[-1]
    if tf.reduce_sum(hidden) != 0:
        print('****need to pass hidden as initial_state in GRU****')
        sys.exit(-1)
    if not is_sorted:
        true2sorted = sorted(range(len(lengths)), key=lambda x: -lengths[x])
        sorted2true = sorted(range(len(lengths)), key=lambda x: true2sorted[x])
        ids = tf.stack([ids[:, i] for i in true2sorted], axis=1)
        lengths = [lengths[i] for i in true2sorted]
    embeddings = (word_embeddings(data.word_ids(ids))
                  + self.special_embeddings(data.special_ids(ids)))
    # if is_varlen:
    #     embeddings = nn.utils.rnn.pack_padded_sequence(embeddings, lengths)
    embeddings = tf.transpose(embeddings, perm=[1, 0, 2])
    embeddings_mask = tf.transpose(
        word_embeddings.compute_mask(data.word_ids(ids)), perm=[1, 0])
    # if (embeddings_mask.numpy().sum(axis=1) == 0).any():
    #     print(embeddings_mask.numpy().sum(axis=1))
    output, hidden = self.rnn(embeddings, mask=embeddings_mask)
    hidden = tf.convert_to_tensor(hidden)
    output = tf.transpose(output, perm=[1, 0, 2])
    if self.bidirectional:
        hidden = tf.squeeze(hidden, [1])
    if not self.bidirectional:
        print('****Encoder not bidirectional was not Tested****')
    # if is_varlen:  # TODO didn't touch that possibility
    #     output = nn.utils.rnn.pad_packed_sequence(output)[0]
    if not is_sorted:  # TODO didn't touch that possibility
        hidden = tf.stack([hidden[:, i, :] for i in sorted2true], axis=1)
        output = tf.stack([output[:, i, :] for i in sorted2true], axis=1)
    # print(output.shape, hidden.shape)
    # at the end, want torch.Size([3, 5, 600]), torch.Size([2, 5, 600])
    return hidden, output
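# --- Illustrative sketch (made-up sizes; a plain Keras Bidirectional GRU stands
# in for self.rnn): TF/Keras has no pack_padded_sequence, so variable-length
# batches are handled by passing a boolean mask to a batch-major recurrent
# layer, which is why the call above transposes to batch x time and builds
# embeddings_mask from the embedding layer's compute_mask.
import tensorflow as tf

gru = tf.keras.layers.Bidirectional(
    tf.keras.layers.GRU(16, return_sequences=True, return_state=True))
x = tf.random.normal([3, 5, 8])                 # batch x time x features
mask = tf.sequence_mask([5, 4, 2], maxlen=5)    # True at real (non-pad) steps
output, fwd_state, bwd_state = gru(x, mask=mask)
hidden = tf.concat([fwd_state, bwd_state], axis=-1)   # batch x 32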
def call(self, ids, lengths, word_embeddings, hidden, context, context_mask,
         prev_output, generator):
    embeddings = (word_embeddings(data.word_ids(ids))
                  + self.special_embeddings(data.special_ids(ids)))
    output = prev_output
    scores = []
    for emb in embeddings:
        if self.input_feeding:
            input = tf.concat([emb, output], 1)
        else:
            input = emb
        # output, hidden = self.stacked_rnn(input, hidden)
        input = tf.expand_dims(input, 1)
        output, hidden = self.stacked_rnn(input, initial_state=hidden)
        hidden = tf.convert_to_tensor(hidden)
        output, hidden = tf.squeeze(output, [1]), tf.squeeze(hidden, [1])
        output = self.attention(output, context, context_mask)
        output = self.dropout(output)
        scores.append(generator(output))
    return tf.stack(scores), hidden, output
def forward(self, ids, lengths, word_embeddings, hidden, context, context_mask,
            prev_output, generator, att_embeddings=None, pass_att=False,
            pass_context=False, detach_encoder=False, ncontrol=None):
    if ncontrol is None:
        embeddings = (word_embeddings(data.word_ids(ids))
                      + self.special_embeddings(data.special_ids(ids)))
    else:
        embeddings = (word_embeddings(data.word_ids(ids))
                      + self.special_embeddings(data.special_ids_nosos(ids))
                      + self.sosembeddings(data.sos_ids(ids).div(3).mul(ncontrol)))
    output = prev_output
    scores = []
    find_cosine = att_embeddings is not None
    cosineloss = Variable(gpu(torch.FloatTensor(1).fill_(0)))
    att_scores = []
    att_contexts = []
    for emb in embeddings.split(1):
        if self.input_feeding:
            input = torch.cat([emb.squeeze(0), output], 1)
        else:
            input = emb.squeeze(0)
        output, hidden = self.stacked_rnn(input, hidden)
        output, att_weights, weighted_context = self.attention(
            output, context, context_mask, pass_weights=True,
            pass_context=True, detach_encoder=detach_encoder)
        output = self.dropout(output)
        score = generator(output)
        if pass_context:
            # print('weighted_context size:', weighted_context.size())
            att_contexts.append(weighted_context)
        if find_cosine:
            # cosine similarity between the attention-weighted source embeddings
            # and the probability-weighted output embeddings
            # print("att_weights:", att_weights.requires_grad)
            att_embeddings = att_embeddings.detach()
            # print("att_embeddings:", att_embeddings.requires_grad)
            weighted_embedd = att_weights.unsqueeze(1).bmm(
                att_embeddings.transpose(0, 1)).squeeze(1)
            # print("score: ", score.exp())
            # print("special_embeddings: ", self.special_embeddings.weight.size())
            # print("word_embeddings: ", word_embeddings.weight.size())
            weighted_predembedd = score.exp().unsqueeze(1).matmul(
                torch.cat([
                    self.special_embeddings.weight[1:],
                    word_embeddings.weight[1:]
                ])).squeeze(1)
            # print("weighted_predembedd: ", weighted_predembedd.size())
            att_cosine = torch.sum(
                F.cosine_similarity(weighted_embedd, weighted_predembedd))
            cosineloss += att_cosine
        att_scores.append(att_weights)
        scores.append(score)
    if not pass_context:
        if not pass_att:
            if not find_cosine:
                return torch.stack(scores), hidden, output
            else:
                return torch.stack(scores), hidden, output, cosineloss
        else:
            att_scores = torch.stack(att_scores)
            if not find_cosine:
                return torch.stack(scores), hidden, output, att_scores
            else:
                return torch.stack(scores), hidden, output, cosineloss, att_scores
    else:
        att_contexts = torch.stack(att_contexts)
        # print('att_contexts size', att_contexts.size())
        if not pass_att:
            if not find_cosine:
                return torch.stack(scores), hidden, output, att_contexts
            else:
                return torch.stack(scores), hidden, output, cosineloss, att_contexts
        else:
            att_scores = torch.stack(att_scores)
            if not find_cosine:
                return torch.stack(scores), hidden, output, att_scores, att_contexts
            else:
                return torch.stack(scores), hidden, output, cosineloss, att_scores, att_contexts
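# --- Illustrative sketch (random stand-in tensors, hypothetical sizes) of the
# shape logic behind the cosine term above: att_weights (batch x src_len)
# average the source embeddings, exp(score) (batch x vocab) averages the output
# embedding table, and the loss compares the two averaged vectors.
import torch
import torch.nn.functional as F

batch, src_len, vocab, dim = 4, 7, 50, 8
att_weights = torch.softmax(torch.randn(batch, src_len), dim=1)
att_embeddings = torch.randn(src_len, batch, dim)   # source-side embeddings
log_probs = torch.log_softmax(torch.randn(batch, vocab), dim=1)
embedding_table = torch.randn(vocab, dim)           # output-side embeddings

weighted_embedd = att_weights.unsqueeze(1).bmm(
    att_embeddings.transpose(0, 1)).squeeze(1)      # batch x dim
weighted_predembedd = log_probs.exp().unsqueeze(1).matmul(
    embedding_table).squeeze(1)                     # batch x dim
cosine_term = torch.sum(F.cosine_similarity(weighted_embedd, weighted_predembedd))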