def batchify(self, batch, device):
    """Collate a list of examples into padded tensors on ``device``.

    Each item of ``batch`` is indexed positionally:
    ``[1]`` source-sentence ids, ``[2]`` target-sentence ids,
    ``[3]`` event-1 ids, ``[4]`` event-2 ids and ``[5]`` the raw label.
    Labels are mapped through ``self.y_label``; labels not found there map to 0.

    Sentences are padded to the longest sentence of their kind in the batch.
    Event id lists are padded/truncated to 5 positions.

    Returns:
        A list ``[sentences_s, mask_s, sentences_t, mask_t, event1, mask_event1,
        event2, mask_event2, labels, examples]``. All entries except
        ``examples`` (the raw batch items) are moved to ``device``.
    """
    event_len = 5  # fixed padded length for event id lists

    examples = list(batch)
    sentences_s = [data[1] for data in batch]
    sentences_t = [data[2] for data in batch]
    event1 = [data[3] for data in batch]
    event2 = [data[4] for data in batch]
    data_y = [self.y_label[data[5]] if data[5] in self.y_label else 0
              for data in batch]

    sentence_len_s = [len(s) for s in sentences_s]
    sentence_len_t = [len(t) for t in sentences_t]
    # Event lengths must be measured on the same fields that are padded below
    # (data[3] / data[4]); otherwise the event masks describe the wrong data.
    event1_lens = [len(e) for e in event1]
    event2_lens = [len(e) for e in event2]
    max_sentence_len_s = max(sentence_len_s)
    max_sentence_len_t = max(sentence_len_t)

    sentences_s = [pad_sequence_to_length(x, max_sentence_len_s) for x in sentences_s]
    sentences_t = [pad_sequence_to_length(x, max_sentence_len_t) for x in sentences_t]
    event1 = [pad_sequence_to_length(x, event_len) for x in event1]
    event2 = [pad_sequence_to_length(x, event_len) for x in event2]

    mask_sentences_s = get_mask_from_sequence_lengths(
        torch.LongTensor(sentence_len_s), max_sentence_len_s)
    mask_sentences_t = get_mask_from_sequence_lengths(
        torch.LongTensor(sentence_len_t), max_sentence_len_t)
    mask_event1 = get_mask_from_sequence_lengths(
        torch.LongTensor(event1_lens), event_len)
    mask_event2 = get_mask_from_sequence_lengths(
        torch.LongTensor(event2_lens), event_len)

    return [
        torch.LongTensor(sentences_s).to(device),
        mask_sentences_s.to(device),
        torch.LongTensor(sentences_t).to(device),
        mask_sentences_t.to(device),
        torch.LongTensor(event1).to(device),
        mask_event1.to(device),
        torch.LongTensor(event2).to(device),
        mask_event2.to(device),
        torch.LongTensor(data_y).to(device),
        examples,
    ]
def predict_crf(self, hs, ls=None, lengths=None, calculate_loss=True, decode=False):
    """Decode tag paths with the CRF and optionally compute the CRF loss.

    Args:
        hs: emission scores handed to ``self.crf``. When ``lengths`` is None,
            each item yielded by iterating ``hs`` contributes ``shape[0]`` as
            its length.
        ls: gold tags; only used when ``calculate_loss`` is True.
        lengths: per-sequence lengths used to build the CRF mask.
        calculate_loss: when True, return the negative log-likelihood divided
            by the number of sequences.
        decode: when True and ``self.crf_top_k != 1``, collect top-k paths.

    Returns:
        ``(loss, ps)``. ``ps`` holds one best path per sequence or, in top-k
        mode, a tuple of k paths per sequence. ``loss`` is a float zero tensor
        when ``calculate_loss`` is False.
    """
    device = hs.device
    if lengths is None:
        lengths = torch.tensor([h.shape[0] for h in hs], device=device)
    mask = get_mask_from_sequence_lengths(lengths, max_length=max(lengths))

    if decode and self.crf_top_k != 1:
        ps = []
        for k_best in self.crf.viterbi_tags(hs, mask, top_k=self.crf_top_k):
            # Each k-best entry unpacks into (paths, scores); keep the paths.
            paths, _scores = zip(*k_best)
            ps.append(paths)
    else:
        ps, _scores = zip(*self.crf.viterbi_tags(hs, mask))

    if not calculate_loss:
        return torch.tensor(np.array(0), dtype=torch.float, device=device), ps

    log_likelihood = self.crf(hs, ls, mask)
    return -1 * log_likelihood / len(lengths), ps
def _encoding(self, word_inputs: torch.Tensor, chars_inputs: torch.Tensor, lengths: torch.Tensor,):
    """Embed and encode tokens, accumulating wall-clock timings on ``self``.

    Embedding time is added to ``self.token_embedding_time``. Mask creation
    plus encoding time is added to ``self.encoding_time``.

    Returns:
        ``(encoded, embedded_tokens, mask)``, or whatever
        ``self._remove_sentence_boundaries`` returns when
        ``self.add_sentence_boundary_ids`` is set.
    """
    # NOTE: there is no dropout on the last layer.
    tic = time.time()
    embedded_tokens = self.token_embedder(word_inputs, chars_inputs)
    self.token_embedding_time += time.time() - tic

    tic = time.time()
    mask = get_mask_from_sequence_lengths(lengths, lengths.max())

    if self.add_sentence_boundary:
        bounded_tokens, bounded_mask = self._add_sentence_boundary(embedded_tokens, mask)
        encoded = self.encoder(bounded_tokens, bounded_mask)
        self.encoding_time += time.time() - tic
        # Drop the first/last positions along axis 2, presumably the boundary
        # tokens added above (encoder output assumed 4-D) — TODO confirm.
        return encoded[:, :, 1:-1, :], embedded_tokens, mask

    strip_boundary_ids = self.add_sentence_boundary_ids
    encoded = self.encoder(embedded_tokens, mask)
    self.encoding_time += time.time() - tic
    if strip_boundary_ids:
        return self._remove_sentence_boundaries(encoded, embedded_tokens, mask)
    return encoded, embedded_tokens, mask
def batchify(self, batch):
    """Build padded encoder/decoder inputs and a length mask for a batch.

    The encoder input of each instance is the instance followed by
    ``<END>``. The decoder input is ``<BOS>`` followed by the instance. Both
    have the same length, so one mask serves both.

    Returns:
        ``[encode_sequence_ipt, decode_sequence_ipt, mask]``, all on ``self.device``.
    """
    encoder_rows = [instance[:] + [self.word2idx['<END>']] for instance in batch]
    decoder_rows = [[self.word2idx['<BOS>']] + instance[:] for instance in batch]

    row_lengths = [len(row) for row in encoder_rows]
    longest = max(row_lengths)

    encoder_rows = [pad_sequence_to_length(row, longest) for row in encoder_rows]
    decoder_rows = [pad_sequence_to_length(row, longest) for row in decoder_rows]
    mask = get_mask_from_sequence_lengths(torch.LongTensor(row_lengths), longest)

    return [
        torch.LongTensor(encoder_rows).to(self.device),
        torch.LongTensor(decoder_rows).to(self.device),
        mask.to(self.device),
    ]
def forward(self, inputs: Dict[str, torch.Tensor], targets: torch.Tensor):
    """Embed each named input, encode, and classify.

    Each entry of ``self.input_layers`` embeds the input field of the same
    name. Each input encoder consumes the embeddings named by its
    ``get_ordered_names()``. The dropped-out encoder outputs are concatenated
    on the last axis, passed through ``self.encoder`` with a mask built from
    ``inputs['length']``, and then classified.

    Returns:
        ``(output, loss)`` from ``self.classify_layer``.
    """
    # input_: (batch_size, seq_len)
    embedded_input = {name: layer(inputs[name])
                      for name, layer in self.input_layers.items()}

    encoded_inputs = torch.cat(
        [self.input_dropout(
            encoder_({name: embedded_input[name] for name in encoder_.get_ordered_names()}))
         for encoder_ in self.input_encoders],
        dim=-1)

    lengths = inputs['length']
    mask = get_mask_from_sequence_lengths(lengths, lengths.max())
    # encoded_input_: (batch_size, seq_len, dim)
    encoded_inputs = self.dropout(self.encoder(encoded_inputs, mask))

    output, loss = self.classify_layer(encoded_inputs, targets)
    return output, loss
def test_get_mask_from_sequence_lengths(self):
    """Positions before each sequence's length are 1, the rest are 0."""
    # Plain tensor: torch.autograd.Variable is a deprecated no-op wrapper.
    sequence_lengths = torch.LongTensor([4, 3, 1, 4, 2])
    mask = util.get_mask_from_sequence_lengths(sequence_lengths, 5).data.numpy()
    assert_almost_equal(mask, [[1, 1, 1, 1, 0],
                               [1, 1, 1, 0, 0],
                               [1, 0, 0, 0, 0],
                               [1, 1, 1, 1, 0],
                               [1, 1, 0, 0, 0]])
def forward(self, word_inputs: torch.Tensor, char_inputs: torch.Tensor):
    """Concatenate word and/or character embeddings and project them.

    ``char_inputs`` is a ``(char_ids, char_lengths)`` pair. Character
    embeddings of each token are masked by its length and summed over the
    character axis.
    """
    pieces = []

    if self.word_embedder is not None:
        word_ids = torch.autograd.Variable(word_inputs, requires_grad=False)
        pieces.append(self.word_embedder(word_ids))

    if self.char_embedder is not None:
        char_ids, char_lengths = char_inputs
        batch_size, seq_len = char_lengths.size()[:2]
        flat_ids = char_ids.view(batch_size * seq_len, -1)
        flat_lengths = char_lengths.view(batch_size * seq_len)
        # (batch_size * seq_len, max_char, dim)
        char_vectors = self.char_embedder(flat_ids)
        keep = get_mask_from_sequence_lengths(
            flat_lengths, flat_lengths.max()).unsqueeze(-1).float()
        summed = (char_vectors * keep).sum(dim=-2)
        pieces.append(summed.view(batch_size, seq_len, -1))

    return self.projection(torch.cat(pieces, dim=2))
def forward(self, word_inputs: torch.Tensor, char_inputs: torch.Tensor):
    """Concatenate word and/or attention-pooled character features, then project.

    ``char_inputs`` is a ``(char_ids, char_lengths)`` pair. Characters are
    embedded, run through ``self.char_encoder``, and pooled with a masked
    softmax over ``self.char_attention`` scores.
    """
    parts = []

    if self.word_embedder is not None:
        word_ids = torch.autograd.Variable(word_inputs, requires_grad=False)
        if self.use_cuda:
            word_ids = word_ids.cuda()
        parts.append(self.word_embedder(word_ids))

    if self.char_embedder is not None:
        char_ids, char_lengths = char_inputs
        batch_size, seq_len = char_lengths.size()
        flat_ids = char_ids.view(batch_size * seq_len, -1)
        flat_lengths = char_lengths.view(-1)
        char_mask = get_mask_from_sequence_lengths(flat_lengths, flat_lengths.max())

        char_states, _ = self.char_encoder(self.char_embedder(flat_ids))
        weights = masked_softmax(
            self.char_attention(char_states).squeeze(-1), char_mask, dim=-1)
        # Attention-weighted sum over the character axis; assumes the encoder
        # output is (tokens, chars, hidden) — TODO confirm batch_first.
        pooled = torch.bmm(char_states.permute(0, 2, 1), weights.unsqueeze(-1))
        parts.append(pooled.view(batch_size, seq_len, -1))

    return self.projection(torch.cat(parts, dim=2))
def forward(self,
            question_and_answers: Dict[str, torch.LongTensor],
            video_features: Optional[torch.Tensor] = None,
            frame_count: Optional[torch.LongTensor] = None,
            label: Optional[torch.LongTensor] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
    """Score every candidate answer, optionally conditioned on video features.

    Video features and their frame mask (built from ``frame_count``) are
    expanded per answer with ``self._expand_to_num_answers``. When ``label``
    is given, ``"loss"`` is added to the output and every metric in
    ``self.metrics`` is updated.

    Returns:
        A dict with ``"scores"`` and, when ``label`` is given, ``"loss"``.
    """
    # This supposes a fixed number of answers, by grabbing any of the dict values available.
    num_answers = list(question_and_answers.values())[0].shape[1]

    video_features_mask = None
    if video_features is not None:
        video_features = self._expand_to_num_answers(video_features, num_answers)
        # The frame mask is sized from dim 2 of the already-expanded features.
        frame_mask = util.get_mask_from_sequence_lengths(frame_count, video_features.shape[2])
        video_features_mask = self._expand_to_num_answers(frame_mask, num_answers)

    embedded_text = self.text_field_embedder(question_and_answers, num_wrapping_dims=1)
    text_mask = util.get_text_field_mask(question_and_answers, num_wrapping_dims=1)

    scores = self.answer_scorer(video_features=video_features,
                                video_features_mask=video_features_mask,
                                embedded_question_and_answers=embedded_text,
                                question_and_answers_mask=text_mask)

    output_dict = {'scores': scores}
    if label is None:
        return output_dict

    output_dict['loss'] = self.loss(scores, label)
    for metric in self.metrics.values():
        metric(scores, label)
    return output_dict
def test_get_mask_from_sequence_lengths(self):
    """Positions before each sequence's length are 1, the rest are 0."""
    expected = [
        [1, 1, 1, 1, 0],
        [1, 1, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 0, 0, 0],
    ]
    lengths = torch.LongTensor([4, 3, 1, 4, 2])
    actual = util.get_mask_from_sequence_lengths(lengths, 5).data.numpy()
    assert_almost_equal(actual, expected)
def batchify(self, batch, device):
    """Collate ``(data, data_mask)`` pairs into padded tensors on ``device``.

    ``data`` is indexed positionally: ``[1]`` sentence ids, ``[3]`` event-1
    ids, ``[4]`` event-2 ids and ``[5]`` the raw label, which is looked up in
    ``self.y_label`` (unknown labels raise ``KeyError``). ``data_mask[1]`` is
    a per-token sequence padded alongside the sentence; its meaning is not
    visible here — TODO confirm with the dataset code.

    Sentences are padded/truncated to ``self.max_len``. Event id lists are
    padded/truncated to 5 positions.

    Returns:
        ``[sentences, sentence_mask, padded data_mask[1], event1, event1_mask,
        event2, event2_mask, labels]``, all on ``device``.
    """
    event_len = 5  # fixed padded length for event id lists
    max_sentence_len_s = self.max_len

    sentences_s, sentences_s_mask, event1, event2, data_y = [], [], [], [], []
    for data, data_mask in batch:
        sentences_s.append(data[1])
        sentences_s_mask.append(data_mask[1])
        event1.append(data[3])
        event2.append(data[4])
        data_y.append(self.y_label[data[5]])

    sentence_len_s = [len(s) for s in sentences_s]
    # Event lengths must come from the fields actually padded below
    # (data[3] / data[4]); otherwise the event masks describe the wrong data.
    event1_lens = [len(e) for e in event1]
    event2_lens = [len(e) for e in event2]

    sentences_s = [pad_sequence_to_length(x, max_sentence_len_s) for x in sentences_s]
    sentences_s_mask = [pad_sequence_to_length(x, max_sentence_len_s) for x in sentences_s_mask]
    event1 = [pad_sequence_to_length(x, event_len) for x in event1]
    event2 = [pad_sequence_to_length(x, event_len) for x in event2]

    mask_sentences_s = get_mask_from_sequence_lengths(
        torch.LongTensor(sentence_len_s), max_sentence_len_s)
    mask_event1 = get_mask_from_sequence_lengths(
        torch.LongTensor(event1_lens), event_len)
    mask_event2 = get_mask_from_sequence_lengths(
        torch.LongTensor(event2_lens), event_len)

    return [
        torch.LongTensor(sentences_s).to(device),
        mask_sentences_s.to(device),
        torch.LongTensor(sentences_s_mask).to(device),
        torch.LongTensor(event1).to(device),
        mask_event1.to(device),
        torch.LongTensor(event2).to(device),
        mask_event2.to(device),
        torch.LongTensor(data_y).to(device),
    ]
def _encode(self, source_features: torch.FloatTensor,
            source_lengths: torch.LongTensor) -> Dict[str, torch.Tensor]:
    """Run the optional CNN and ConvLSTM front-ends, then the encoder.

    When the encoder is an ``AWDRNN`` it returns its own lengths, and the
    mask is rebuilt from them against the encoder output's length.

    Returns:
        ``{"source_mask": ..., "encoder_outputs": ...}``.
    """
    # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
    if self._cnn is not None:
        source_features, source_lengths = self._cnn(source_features, source_lengths)

    source_mask = util.get_mask_from_sequence_lengths(source_lengths, source_features.size(1))

    if self._conv_lstm is not None:
        source_features = self._conv_lstm(source_features, source_mask)

    if isinstance(self._encoder, AWDRNN):
        encoder_outputs, _, source_lengths = self._encoder(
            source_features, source_lengths, self._output_layer_num)
        source_mask = util.get_mask_from_sequence_lengths(
            source_lengths, encoder_outputs.size(1))
    else:
        encoder_outputs = self._encoder(source_features, source_mask)

    # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
    return {"source_mask": source_mask, "encoder_outputs": encoder_outputs}
def position_mask(self) -> torch.Tensor:
    """
    Which elements are actual words in the sentence?

    Each sentence contributes ``len(sentence) + 1`` positions. On first call
    the lengths tensor (placed on ``get_device_id(self.constants)``) and the
    padded length are cached as ``self.lengths`` and ``self.max_len``.

    :return: shape (batch_size, input_seq_len)
    """
    if not hasattr(self, "lengths"):
        # The extra position per sentence is presumably a special token
        # (e.g. a root) — TODO confirm against the model.
        sizes = [len(sentence) + 1 for sentence in self.sentences]
        self.lengths = torch.tensor(sizes, device=get_device_id(self.constants))
        self.max_len = max(sizes)
    return get_mask_from_sequence_lengths(self.lengths, self.max_len)
def _encode(self, source_features: torch.FloatTensor,
            source_lengths: torch.LongTensor) -> Dict[str, torch.Tensor]:
    """Encode the source and build a mask from the lengths the encoder returns.

    Returns:
        ``{"source_mask": ..., "encoder_outputs": ...}``.
    """
    # source_features: (batch_size, max_input_sequence_length, encoder_input_dim)
    encoder_outputs, _unused, output_lengths = self._encoder(source_features, source_lengths)
    longest = torch.max(output_lengths)
    # encoder_outputs: (batch_size, max_input_sequence_length, encoder_output_dim)
    return {
        "source_mask": util.get_mask_from_sequence_lengths(output_lengths, longest),
        "encoder_outputs": encoder_outputs,
    }
def _get_phn_level_representations(
        self, features: torch.FloatTensor, mask: torch.BoolTensor,
        phn_log_probs: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Average features over runs of the same phone label.

    Returns:
        ``{"encoder_outputs": segment-averaged features, "source_mask": mask
        over the segments}``, with the mask sized to the longest segment count.
    """
    segment_features, segment_lengths = averaging_tensor_of_same_label(
        features, phn_log_probs, mask=mask)
    longest = int(max(segment_lengths))
    segment_mask = util.get_mask_from_sequence_lengths(segment_lengths, longest)
    return {"encoder_outputs": segment_features, "source_mask": segment_mask}
def _get_action_embeddings(
        state: NlvrDecoderState,
        actions_to_embed: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    This method is identical to ``WikiTablesDecoderStep._get_action_embeddings``

    Returns an embedded representation for all actions in ``actions_to_embed``, using the
    state in ``NlvrDecoderState``.

    Parameters
    ----------
    state : ``NlvrDecoderState``
        The current state. We'll use this to get the global action embeddings.
    actions_to_embed : ``List[List[int]]``
        A list of _global_ action indices for each group element. Should have shape
        (group_size, num_actions), unpadded.

    Returns
    -------
    action_embeddings : ``torch.FloatTensor``
        An embedded representation of all of the given actions. Shape is ``(group_size,
        num_actions, action_embedding_dim)``, where ``num_actions`` is the maximum number of
        considered actions for any group element.
    action_mask : ``torch.Tensor``
        A mask of shape ``(group_size, num_actions)``: 1 for real actions, 0 for the
        positions added as padding. (dtype depends on ``nn_util.get_mask_from_sequence_lengths``.)
    """
    num_actions = [len(action_list) for action_list in actions_to_embed]
    max_num_actions = max(num_actions)
    padded_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions)
                      for action_list in actions_to_embed]
    # Shape: (group_size, num_actions). new_tensor replaces the deprecated
    # Variable(tensor.data.new(...)) idiom; the result carries no autograd history.
    action_tensor = state.score[0].new_tensor(padded_actions, dtype=torch.long)

    # `state.action_embeddings` is shape (total_num_actions, action_embedding_dim).
    # We want to select from state.action_embeddings using `action_tensor` to get a tensor of
    # shape (group_size, num_actions, action_embedding_dim).  Unfortunately, the index_select
    # functions in nn.util don't do this operation.  So we'll do some reshapes and do the
    # index_select ourselves.
    group_size = len(state.batch_indices)
    action_embedding_dim = state.action_embeddings.size(-1)
    flattened_actions = action_tensor.view(-1)
    flattened_action_embeddings = state.action_embeddings.index_select(0, flattened_actions)
    action_embeddings = flattened_action_embeddings.view(group_size,
                                                         max_num_actions,
                                                         action_embedding_dim)
    sequence_lengths = action_embeddings.new_tensor(num_actions)
    action_mask = nn_util.get_mask_from_sequence_lengths(sequence_lengths, max_num_actions)
    return action_embeddings, action_mask
def forward(self, input_: Tuple[torch.Tensor, torch.Tensor]):
    """Encode each token's characters, attention-pool them, and project.

    ``input_`` is ``(chars, lengths)`` where ``chars`` is
    ``(batch_size, seq_len, max_chars)``. Returns the projection of the
    pooled vectors reshaped to ``(batch_size, seq_len, -1)``.
    """
    chars, lengths = input_
    batch_size, seq_len, max_chars = chars.size()
    num_tokens = batch_size * seq_len

    flat_chars = chars.view(num_tokens, -1)
    flat_lengths = lengths.view(num_tokens)
    char_mask = get_mask_from_sequence_lengths(flat_lengths, max_chars)
    flat_chars = torch.autograd.Variable(flat_chars, requires_grad=False)

    states, _ = self.encoder_(self.embeddings(flat_chars))
    weights = masked_softmax(self.attention(states).squeeze(-1), char_mask, dim=-1)
    # Attention-weighted sum over the character axis.
    pooled = torch.bmm(states.permute(0, 2, 1), weights.unsqueeze(-1))
    return self.projection(pooled.view(batch_size, seq_len, -1))
def forward(self, input_: Tuple[torch.Tensor, torch.Tensor]):
    """Encode each token's characters and sum the masked attention output.

    ``input_`` is ``(chars, lengths)`` where ``chars`` is
    ``(batch_size, seq_len, max_chars)``. Returns a tensor reshaped to
    ``(batch_size, seq_len, -1)``.
    """
    chars, lengths = input_
    batch_size, seq_len, max_chars = chars.size()
    num_tokens = batch_size * seq_len

    flat_chars = chars.view(num_tokens, -1)
    flat_lengths = lengths.view(num_tokens)
    char_mask = get_mask_from_sequence_lengths(flat_lengths, max_chars)
    flat_chars = torch.autograd.Variable(flat_chars, requires_grad=False)

    states, _ = self.encoder_(self.embeddings(flat_chars))
    pooled = self.attention(states, char_mask).sum(dim=-2)
    return pooled.view(batch_size, seq_len, -1)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, context_span=None,
            gt_span=None, max_context_length=0, mode=ForwardMode.TRAIN):
    """Predict answer-span start/end logits restricted to the context window.

    Args:
        input_ids, token_type_ids, attention_mask: BERT encoder inputs.
        context_span: passed to ``span_util.span_select`` to cut out the context.
        gt_span: gold ``[start, end]`` per example; required in TRAIN mode.
        max_context_length: padded context length; must be precomputed (see below).
        mode: ``BertSpan.ForwardMode``.

    Returns:
        In TRAIN mode, the summed start and end NLL losses. Otherwise
        ``(start_logits, end_logits, context_length)``. Positions past each
        context length hold -1e18 in the logits.
    """
    # Precomputing of the max_context_length is important
    # because we want the same value to be shared to different GPUs, dynamic calculating is not feasible.
    sequence_output, _ = self.bert_encoder(input_ids, token_type_ids, attention_mask,
                                           output_all_encoded_layers=False)

    joint_seq_logits = self.qa_outputs(sequence_output)
    context_logits, context_length = span_util.span_select(
        joint_seq_logits, context_span, max_context_length)
    context_mask = allen_util.get_mask_from_sequence_lengths(context_length, max_context_length)

    # The following line is from AllenNLP bidaf.
    # context_logits is presumably (B, T, 2): channel 0 = start, channel 1 = end.
    start_logits = allen_util.replace_masked_values(context_logits[:, :, 0], context_mask, -1e18)
    end_logits = allen_util.replace_masked_values(context_logits[:, :, 1], context_mask, -1e18)

    if mode == BertSpan.ForwardMode.TRAIN:
        assert gt_span is not None
        # gt_span: [B, 2]. reshape(-1) instead of squeeze(-1) keeps a 1-D [B]
        # target even when B == 1 (squeeze would yield a 0-d tensor that
        # nll_loss rejects for a [1, T] input).
        gt_start = gt_span[:, 0].reshape(-1)
        gt_end = gt_span[:, 1].reshape(-1)
        start_loss = nll_loss(allen_util.masked_log_softmax(start_logits, context_mask), gt_start)
        end_loss = nll_loss(allen_util.masked_log_softmax(end_logits, context_mask), gt_end)
        return start_loss + end_loss
    return start_logits, end_logits, context_length
def _get_action_embeddings(state: NlvrDecoderState,
                           actions_to_embed: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    This method is identical to ``WikiTablesDecoderStep._get_action_embeddings``

    Look up global action embeddings for every group element, padding the
    per-element action lists to a common length.

    Parameters
    ----------
    state : ``NlvrDecoderState``
        Supplies ``action_embeddings`` (total_num_actions, action_embedding_dim),
        ``batch_indices`` (its length is the group size) and ``score`` (used only
        as a device/dtype template for new tensors).
    actions_to_embed : ``List[List[int]]``
        Unpadded global action indices, one list per group element.

    Returns
    -------
    action_embeddings : ``torch.FloatTensor``
        Shape ``(group_size, num_actions, action_embedding_dim)`` where
        ``num_actions`` is the longest list in ``actions_to_embed``.
    action_mask : ``torch.Tensor``
        Shape ``(group_size, num_actions)``; 1 for real actions, 0 for padding.
    """
    action_counts = [len(actions) for actions in actions_to_embed]
    longest = max(action_counts)
    padded = [common_util.pad_sequence_to_length(actions, longest)
              for actions in actions_to_embed]
    # Shape: (group_size, num_actions)
    action_ids = state.score[0].new_tensor(padded, dtype=torch.long)

    group_size = len(state.batch_indices)
    embedding_dim = state.action_embeddings.size(-1)
    # index_select only takes 1-D indices, so gather with the flattened id
    # matrix and restore the (group, action) layout afterwards.
    gathered = state.action_embeddings.index_select(0, action_ids.view(-1))
    action_embeddings = gathered.view(group_size, longest, embedding_dim)

    count_tensor = action_embeddings.new_tensor(action_counts)
    action_mask = nn_util.get_mask_from_sequence_lengths(count_tensor, longest)
    return action_embeddings, action_mask
def forward(
        self,
        word_embs: torch.Tensor,  # Float[Batch, Word, Embedding]
        mask: torch.Tensor,  # Byte[Batch, Word]
        left: bool = False,
) -> torch.Tensor:  # Float[Batch, Embedding]
    """Run a packed LSTM over the unpadded words and max-pool over time.

    Sequences are sorted longest-first for ``pack_padded_sequence`` and
    restored to their original order afterwards. Padded time steps are set to
    ``-inf`` before the max so they never win. The pooled vector goes
    through ``self.final_dropout`` and, when both ``self.with_linear_transform``
    and ``left`` are true, through ``self.linear``.
    """
    device = word_embs.device
    lengths = mask.long().sum(dim=1).cpu().numpy()  # Long[Batch]

    # Longest-first ordering and its inverse permutation.
    order = np.argsort(-lengths)
    restore = np.argsort(order)
    sorted_lengths = np.sort(lengths)[::-1]

    reordered = word_embs.index_select(0, torch.from_numpy(order).to(device=device))
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        reordered, lengths=sorted_lengths.copy(), batch_first=True)
    packed_out, _ = self.lstm(packed)
    padded_out, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
    encoded = padded_out.index_select(
        0, torch.from_numpy(restore).to(device=device))  # Float[Batch, Word, Encoding]

    valid = get_mask_from_sequence_lengths(
        torch.tensor(lengths).to(device=device), max_length=encoded.size(1))
    valid = valid.unsqueeze(dim=2).expand(-1, -1, encoded.size(2))
    masked = torch.where(valid, encoded, torch.ones_like(encoded) * float('-inf'))

    pooled, _ = torch.max(masked, dim=1)
    output = self.final_dropout(pooled)
    if self.with_linear_transform and left:
        output = self.linear(output)
    return output
def test_average_tensor_of_same_labels(self):
    """Check ``averaging_tensor_of_same_label`` against a loop-based reference.

    The reference walks each sequence and averages consecutive encoder frames
    that share the same argmax label. It then zero-pads the per-sequence
    results to the longest one. The function under test is called with a
    lengths tensor and with a mask, and both results must match.
    """
    batch_size = 10
    max_len = 16
    feat_dim = 32
    label_dim = 4
    for _ in range(10):
        phn_logits = torch.randn(batch_size, max_len, label_dim)
        # Normalise over the label axis explicitly. Without ``dim``, PyTorch's
        # legacy implicit-dim rule picks dim=0 (the batch axis) for 3-D input.
        phn_log_probs = F.log_softmax(phn_logits, dim=-1)
        # NOTE(review): lengths are drawn from [0, label_dim) while the
        # reference below ignores lengths entirely; confirm this is intended.
        lengths = torch.randint(label_dim, (batch_size, ))
        mask = get_mask_from_sequence_lengths(lengths, int(max(lengths)))
        enc_outs = torch.randn((batch_size, max_len, feat_dim))

        _, max_ids = phn_log_probs.max(dim=-1)
        phn_enc_out_list = []
        for b in range(batch_size):
            count = 1
            phn_enc_out = []
            feat = enc_outs[b, 0].clone()
            prev_id = None
            for t, max_id in enumerate(max_ids[b]):
                if prev_id is not None:
                    if max_id == prev_id:
                        feat += enc_outs[b, t].clone()
                        count += 1
                    else:
                        phn_enc_out.append(feat.div(count))
                        feat = enc_outs[b, t].clone()
                        count = 1
                prev_id = max_id
            phn_enc_out.append(feat / float(count))
            phn_enc_out_list.append(phn_enc_out)

        phn_max_len = max(len(x) for x in phn_enc_out_list)
        phn_enc_outs = enc_outs.new_zeros(batch_size, phn_max_len, feat_dim)
        for idx, phn_enc_out in enumerate(phn_enc_out_list):
            phn_enc_outs[idx, :len(phn_enc_out)] = torch.stack(phn_enc_out)

        len_phn_enc_outs, _ = averaging_tensor_of_same_label(
            enc_outs, phn_log_probs, lengths)
        torch.testing.assert_allclose(phn_enc_outs, len_phn_enc_outs)
        mask_phn_enc_outs, _ = averaging_tensor_of_same_label(
            enc_outs, phn_log_probs, mask)
        torch.testing.assert_allclose(phn_enc_outs, mask_phn_enc_outs)
def pad_contextualizer_output(seqs: List[torch.Tensor]):
    """
    Zero-pad a list of per-sentence contextualizer outputs into one tensor.

    ``seqs`` has one tensor per batch element, each of shape
    ``(seq_len_i, repr_dim)`` with possibly different ``seq_len_i``.

    Returns
    -------
    padded_representations: torch.FloatTensor
        Shape ``(batch_size, max_len, repr_dim)``, zero beyond each sequence's length.
    mask: torch.Tensor
        Shape ``(batch_size, max_len)``; 1 at real positions and 0 at padding.
        The dtype is whatever ``get_mask_from_sequence_lengths`` produces.
    """
    lengths = [len(seq) for seq in seqs]
    max_len = max(lengths)
    mask = get_mask_from_sequence_lengths(seqs[0].new_tensor(lengths), max_len)

    padded = []
    for seq, seq_len in zip(seqs, lengths):
        filler = seq.new_zeros(max_len - seq_len, seq.size(-1))
        padded.append(torch.cat([seq, filler], dim=0))
    return torch.stack(padded), mask
def forward(
        self,  # type: ignore
        utterance: Dict[str, torch.LongTensor],
        valid_actions: List[List[ProductionRule]],
        world: List[SpiderWorld],
        schema: Dict[str, torch.LongTensor],
        action_sequence: torch.LongTensor = None
) -> Dict[str, torch.Tensor]:
    """Compute the graph-relevance loss and the query decoding loss / predictions.

    When ``action_sequence`` is given, builds oracle entity-relevance targets
    and a masked binary cross-entropy loss against
    ``self.predicted_relevance_logits``, scaled by ``self._graph_loss_lambda``.
    That attribute is not assigned here; presumably ``_get_initial_state`` sets
    it — TODO confirm.

    In training mode, adds ``(1 - self._graph_loss_lambda)`` times the decoder
    trainer loss and returns ``{'loss': loss}``. A ``ZeroDivisionError`` from
    the decoder trainer returns a zero loss instead.

    In evaluation mode, optionally adds the decoder loss (only when the gold
    sequence has more than one action), runs beam search, and fills the output
    dict via ``self._compute_validation_outputs``.

    Side effect: sets ``self.graph_mask`` (float mask over entities per world).
    """
    max_len_entities = max(
        [len(w.db_context.knowledge_graph.entities) for w in world])
    batch_size = len(world)
    device = utterance['tokens'].device

    oracle_entities = []
    oracle_relevance_score = None
    if action_sequence is not None:
        # we want oracle supervision for which entities should be in the query, for the loss calculation
        for batch_index, batch_actions in enumerate(
                action_sequence.squeeze(-1)):
            # Collect the right-hand side (stripped of '[', '"', ']') of each
            # rule whose flag at index [1] is falsy. The rule is indexed
            # before the `action >= 0` check runs, so negative (padding)
            # indices are read but then filtered out.
            oracle_entities.append(
                set([
                    valid_actions[batch_index][action][0].split(
                        ' -> ')[1].strip('["]')
                    for action in batch_actions
                    if not valid_actions[batch_index][action][1]
                    and action >= 0
                ]))
        oracle_relevance_score = [
            pad_sequence_to_length(w.get_oracle_relevance_score(oe),
                                   max_len_entities)
            for w, oe in zip(world, oracle_entities)
        ]
        oracle_relevance_score = torch.tensor(oracle_relevance_score,
                                              dtype=torch.float,
                                              device=device)

    initial_state = self._get_initial_state(utterance, world, schema,
                                            valid_actions)

    if action_sequence is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        action_sequence = action_sequence.squeeze(-1)
        action_mask = action_sequence != self._action_padding_index
    else:
        action_mask = None

    # Mask over each world's entities (entities_names), padded to max_len_entities.
    self.graph_mask = util.get_mask_from_sequence_lengths(
        torch.tensor([len(w.entities_names) for w in world], device=device),
        max_len_entities).float()

    loss = torch.tensor([0]).float().to(device)

    if action_sequence is not None:
        graph_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            self.predicted_relevance_logits.squeeze(-1),
            oracle_relevance_score,
            reduction='none')
        # Average only over real (unmasked) entity positions.
        graph_loss = (graph_loss * self.graph_mask).sum() / self.graph_mask.sum()
        graph_loss *= self._graph_loss_lambda
        loss += graph_loss

    if self.training:
        # NOTE(review): this branch dereferences action_sequence/action_mask
        # unconditionally, so training assumes a gold sequence is supplied.
        try:
            decode_output = self._decoder_trainer.decode(
                initial_state, self._transition_function,
                (action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))
            query_loss = decode_output['loss']
        except ZeroDivisionError:
            return {
                'loss': Parameter(torch.tensor([0]).float()).to(
                    action_sequence.device)
            }

        loss += ((1 - self._graph_loss_lambda) * query_loss)

        return {'loss': loss}
    else:
        if action_sequence is not None and action_sequence.size(1) > 1:
            try:
                query_loss = self._decoder_trainer.decode(
                    initial_state, self._transition_function,
                    (action_sequence.unsqueeze(1),
                     action_mask.unsqueeze(1)))['loss']
                loss += query_loss
            except ZeroDivisionError:
                # Deliberately ignored: evaluation continues without the query loss.
                pass

        outputs: Dict[str, Any] = {'loss': loss}

        num_steps = self._max_decoding_steps
        # This tells the state to start keeping track of debug info, which we'll pass along in
        # our output dictionary.
        initial_state.debug_info = [[] for _ in range(batch_size)]

        best_final_states = self._beam_search.search(
            num_steps,
            initial_state,
            self._transition_function,
            keep_final_unfinished_states=False)

        self._compute_validation_outputs(valid_actions, best_final_states,
                                         world, action_sequence, outputs)
        return outputs
def span_pruner(embeddings, scores, mask, seq_length, spans_per_word=1, num_keep=None):
    """Keep the top-scoring spans of each sentence, preserving their original order.

    Based on AllenNLP ``allennlp.modules.Pruner`` (release 0.8.4), but takes
    precomputed ``scores`` instead of a scoring module.

    Parameters
    ----------
    embeddings : torch.Tensor
        Span representations gathered along dim 1 by span index; presumably
        ``(batch_size, num_spans, embedding_dim)`` — TODO confirm.
    scores : torch.Tensor
        ``(batch_size, num_spans)`` span scores.
    mask : torch.Tensor
        ``(batch_size, num_spans)``; nonzero for real spans. Masked spans get
        score ``NEG_FILL`` (assumed very negative) so they rank last.
    seq_length : torch.Tensor
        Only used when ``num_keep`` is None: ``num_keep = seq_length * spans_per_word``,
        clamped to at least 1. Presumably per-sentence token counts, ``(batch_size,)``.
    spans_per_word : int or float
        Multiplier applied to ``seq_length`` when ``num_keep`` is None.
    num_keep : int or torch.LongTensor, optional
        Number of spans to keep: a ``(batch_size,)`` tensor with one count per
        sentence, or a single int applied to every sentence.

    Returns
    -------
    top_indices : ``(batch_size, max_keep)`` indices of kept spans, ascending.
    top_embeddings : ``embeddings`` gathered at ``top_indices``.
    top_scores : ``(batch_size, max_keep)`` scores gathered at ``top_indices``.
    top_mask : ``(batch_size, max_keep)`` LongTensor; 1 where the slot holds a
        kept, unmasked span, 0 otherwise.

    Raises
    ------
    ValueError
        If ``scores`` is not 2-D.
    """
    batch_size, num_items = tuple(scores.shape)

    # Number to keep not provided, so use spans per word (at least one).
    if num_keep is None:
        num_keep = seq_length * spans_per_word
        num_keep = torch.max(num_keep, torch.ones_like(num_keep))

    # If an int was given for number of items to keep, construct tensor by repeating the value.
    if isinstance(num_keep, int):
        num_keep = num_keep * torch.ones([batch_size], dtype=torch.long, device=mask.device)

    # Maximum number to keep across the batch.
    max_keep = num_keep.max()

    # Shape: (batch_size, num_items, 1)
    scores = scores.unsqueeze(-1)
    if scores.size(-1) != 1 or scores.dim() != 3:
        raise ValueError(
            f"The scorer passed to Pruner must produce a tensor of shape "
            f"(batch_size, num_items, 1), but found shape {scores.size()}")

    # Make sure that we don't select any masked items by setting their scores to be very
    # negative.
    mask = mask.unsqueeze(-1).bool()
    scores = util.replace_masked_values(scores, mask, NEG_FILL)

    # Shape: (batch_size, max_num_items_to_keep, 1)
    _, top_indices = scores.topk(max_keep, 1)

    # Mask based on number of items to keep for each sentence.
    # Shape: (batch_size, max_num_items_to_keep)
    top_indices_mask = util.get_mask_from_sequence_lengths(num_keep, max_keep).bool()

    # Shape: (batch_size, max_num_items_to_keep)
    top_indices = top_indices.squeeze(-1)

    # Fill all masked indices with largest "top" index for that sentence, so that all masked
    # indices will be sorted to the end.
    # Shape: (batch_size, 1)
    fill_value, _ = top_indices.max(dim=1)
    fill_value = fill_value.unsqueeze(-1)
    # Shape: (batch_size, max_num_items_to_keep)
    top_indices = torch.where(top_indices_mask, top_indices, fill_value)

    # Now we order the selected indices in increasing order with
    # respect to their indices (and hence, with respect to the
    # order they originally appeared in the ``embeddings`` tensor).
    top_indices, _ = torch.sort(top_indices, 1)

    # Shape: (batch_size * max_num_items_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select items for each element in the batch.
    flat_indices = util.flatten_and_batch_shift_indices(top_indices, num_items)

    # Combine the masks on spans that are out-of-bounds, and the mask on spans that are outside
    # the top k for each sentence.
    # Shape: (batch_size, max_num_items_to_keep)
    sequence_mask = util.batched_index_select(mask, top_indices, flat_indices)
    sequence_mask = sequence_mask.squeeze(-1).bool()
    top_mask = (top_indices_mask & sequence_mask).long()

    # Shape: (batch_size, max_num_items_to_keep, 1)
    top_scores = util.batched_index_select(scores, top_indices, flat_indices)
    top_embeddings = util.batched_index_select(embeddings, top_indices, flat_indices)

    # Shape: (batch_size, max_num_items_to_keep)
    top_scores = top_scores.squeeze(-1)

    return (top_indices, top_embeddings, top_scores, top_mask)
def forward(
        self,
        sentence: torch.Tensor,  # R[Batch, Word, Emb]
        sentence_lengths: torch.Tensor,  # Z_Word[Batch]
        span: torch.Tensor,  # R[Batch, Word, Emb]
        span_lengths: torch.Tensor,  # Z_Word[Batch]
        span_left: torch.Tensor,  # Z_Word[Batch]
        span_right: torch.Tensor  # Z_Word[Batch]
) -> torch.Tensor:  # R[Batch, Feature]
    """Pool a mention span into a feature vector, optionally adding sentence context.

    ``span`` and ``sentence`` both pass through ``self.dropout`` and
    ``self.projection``. The span is pooled according to
    ``self.mention_pooling`` ("max", "mean" or "attention" with ``self.query``).
    If ``self.with_context`` is set, the pooled span is transformed by
    ``self.mention_query_transform``. It then attends over the real
    (unpadded) sentence tokens, and the attended context vector is
    concatenated to the pooled span.

    ``span_left`` / ``span_right`` are accepted for interface compatibility
    but are not used by this computation.
    """
    sentence_max_len = sentence.size(1)
    span_max_len = span.size(1)
    device = sentence.device
    neg_inf = torch.tensor(-10000, dtype=torch.float32, device=device)
    zero = torch.tensor(0, dtype=torch.float32, device=device)

    span = self.projection(self.dropout(span))
    sentence = self.projection(self.dropout(sentence))

    # Bool mask sized to span.size(1): torch.where needs a bool condition, and
    # the mask must line up with `span` even when the longest span in the
    # batch is shorter than the padded width.
    span_mask = get_mask_from_sequence_lengths(span_lengths, span_max_len).bool()  # B[Batch, Word]

    def attention_pool():
        span_attn_scores = torch.einsum('e,bwe->bw', self.query, span)
        masked_span_attn_scores = torch.where(span_mask, span_attn_scores, neg_inf)
        normalized_span_attn_scores = F.softmax(masked_span_attn_scores, dim=1)
        return torch.einsum('bwe,bw->be', span, normalized_span_attn_scores)

    def max_pool():
        masked = torch.where(span_mask.unsqueeze(dim=2).expand_as(span), span, neg_inf)
        return torch.max(masked, dim=1)[0]

    def mean_pool():
        summed = torch.sum(
            torch.where(span_mask.unsqueeze(dim=2).expand_as(span), span, zero), dim=1)
        # Broadcast each span's length over the (post-projection) feature axis.
        return summed / span_lengths.unsqueeze(dim=1)

    span_pooled = {
        "max": max_pool,
        "mean": mean_pool,
        "attention": attention_pool,
    }[self.mention_pooling]()  # R[Batch, Emb]

    features = span_pooled
    if self.with_context:
        sentence_mask = get_mask_from_sequence_lengths(
            sentence_lengths, sentence_max_len).bool()  # B[Batch, Word]
        span_queries = self.mention_query_transform(span_pooled)
        attn_scores = torch.einsum('be,bwe->bw', span_queries, sentence)  # R[Batch, Word]
        masked_attn_scores = torch.where(sentence_mask, attn_scores, neg_inf)  # R[Batch, Word]
        normalized_attn_scores = F.softmax(masked_attn_scores, dim=1)
        context_pooled = torch.einsum('bwe,bw->be', sentence, normalized_attn_scores)  # R[Batch, Emb]
        features = torch.cat([span_pooled, context_pooled], dim=1)  # R[Batch, Emb*2]

    return features  # R[Batch, Emb], or R[Batch, Emb*2] with context
def forward(
        self,
        context_ids: TextFieldTensors,
        query_ids: TextFieldTensors,
        context_lens: torch.Tensor,
        query_lens: torch.Tensor,
        mask_label: Optional[torch.Tensor] = None,
        cls_label: Optional[torch.Tensor] = None,
        metadata: Optional[List[Dict[str, Any]]] = None) -> Dict[str, torch.Tensor]:
    """Encode context and query jointly and score the enabled sub-tasks.

    For every indexer and every tensor it produced, the context and query are
    concatenated along the length dimension, embedded in one pass, and the
    output is split back at ``max(context_lens)``.

    * cls task (``self._cls_task``): from the first context position, decide
      whether the query needs rewriting -> ``output_dict["cls_logits"]``, [B, 2].
    * mask task (``self._mask_task``): for every query position, decide whether
      it needs to be filled -> ``output_dict["mask_logits"]``, [B, query_len, 2].

    Args:
        context_ids: indexed context text field.
        query_ids: indexed query text field; must have the same indexers/keys.
        context_lens: [B] true context lengths.
        query_lens: [B] true query lengths.
        mask_label: optional mask-task labels, forwarded to ``self._calc_loss``.
        cls_label: optional cls-task labels; when given, ``output_dict["loss"]``
            is computed by ``self._calc_loss``.
        metadata: not used by this method.

    Returns:
        dict with ``cls_logits`` / ``mask_logits`` for the enabled tasks, plus
        ``loss`` when ``cls_label`` is provided.
    """
    # Lengths of the (padded) context and query.
    context_len = torch.max(context_lens).item()
    query_len = torch.max(query_lens).item()
    # [B, query_len]
    query_mask = get_mask_from_sequence_lengths(query_lens, query_len)

    # Concatenate context and query along the length dimension, for every
    # indexer and every tensor that indexer produced.
    dialogue_ids = {
        indexer: {
            key: torch.cat([context_ids[indexer][key], query_ids[indexer][key]], dim=1)
            for key in context_ids[indexer]
        }
        for indexer in context_ids
    }

    # Encode the concatenated dialogue.
    if isinstance(self._text_field_embedder, TextFieldEmbedder):
        embedder_outputs = self._text_field_embedder(dialogue_ids)
    else:
        embedder_outputs = self._text_field_embedder(
            **dialogue_ids[self._index_name])

    # Split the joint encoding back into context and query, [B, _len, embed_size].
    # NOTE(review): assumes the context tensors are padded to exactly
    # max(context_lens) positions - confirm against the dataset reader.
    context_last_layer = embedder_outputs[:, :context_len].contiguous()
    query_last_layer = embedder_outputs[:, context_len:].contiguous()

    output_dict = {}

    # --------- cls task: decide whether the query needs rewriting ---------
    cls_logits = None
    if self._cls_task:
        # Representation of the first context position, [B, embed_size].
        cls_embed = context_last_layer[:, 0]
        # Linear classification layer, [B, 2].
        cls_logits = self._cls_linear(cls_embed)
        output_dict["cls_logits"] = cls_logits

    # --------- mask task: find the query positions that need filling ---------
    mask_logits = None
    if self._mask_task:
        # Linear layer per query position, [B, _len, 2].
        mask_logits = self._mask_linear(query_last_layer)
        output_dict["mask_logits"] = mask_logits

    if cls_label is not None:
        output_dict["loss"] = self._calc_loss(cls_label, mask_label, cls_logits,
                                              mask_logits, query_mask)
    return output_dict
def _get_action_embeddings(state: WikiTablesDecoderState,
                           actions_to_embed: List[List[int]]) -> Tuple[torch.Tensor,
                                                                       torch.Tensor,
                                                                       torch.Tensor,
                                                                       torch.Tensor]:
    """
    Returns an embedded representation for all actions in ``actions_to_embed``, using the state
    in ``WikiTablesDecoderState``.

    Parameters
    ----------
    state : ``WikiTablesDecoderState``
        The current state.  We'll use this to get the global action embeddings.
    actions_to_embed : ``List[List[int]]``
        A list of _global_ action indices for each group element.  Should have shape
        (group_size, num_actions), unpadded.  This is expected to be output from
        :func:`_get_actions_to_consider`.

    Returns
    -------
    action_embeddings : ``torch.FloatTensor``
        An embedded representation of all of the given actions.  Shape is ``(group_size,
        num_actions, action_embedding_dim)``, where ``num_actions`` is the maximum number of
        considered actions for any group element.
    output_action_embeddings : ``torch.FloatTensor``
        A second embedded representation of all of the given actions.  The first is used when
        selecting actions, the second is used as the decoder output (which is the input at the
        next timestep).  This is similar to having separate word embeddings and softmax layer
        weights in a language model or MT model.
    action_biases : ``torch.FloatTensor``
        A bias weight for predicting each action.  Shape is ``(group_size, num_actions, 1)``.
    action_mask : ``torch.LongTensor``
        A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
        action_index)`` pairs were merely added as padding.
    """
    num_actions = [len(action_list) for action_list in actions_to_embed]
    max_num_actions = max(num_actions)
    padded_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions)
                      for action_list in actions_to_embed]
    # Shape: (group_size, num_actions)
    # Built directly as a long tensor on the scores' device. Going through the scores'
    # floating-point dtype first cannot represent every large integer index exactly.
    action_tensor = state.score[0].new_tensor(padded_actions, dtype=torch.long)

    # `state.action_embeddings` is shape (total_num_actions, action_embedding_dim).
    # We want to select from state.action_embeddings using `action_tensor` to get a tensor of
    # shape (group_size, num_actions, action_embedding_dim).  Unfortunately, the index_select
    # functions in nn.util don't do this operation.  So we'll do some reshapes and do the
    # index_select ourselves.
    group_size = len(state.batch_indices)
    action_embedding_dim = state.action_embeddings.size(-1)
    flattened_actions = action_tensor.view(-1)
    flattened_action_embeddings = state.action_embeddings.index_select(0, flattened_actions)
    action_embeddings = flattened_action_embeddings.view(group_size, max_num_actions,
                                                         action_embedding_dim)
    flattened_output_embeddings = state.output_action_embeddings.index_select(0, flattened_actions)
    output_embeddings = flattened_output_embeddings.view(group_size, max_num_actions,
                                                         action_embedding_dim)
    flattened_biases = state.action_biases.index_select(0, flattened_actions)
    biases = flattened_biases.view(group_size, max_num_actions, 1)

    # Lengths are integer counts, so keep them integral.
    sequence_lengths = action_embeddings.new_tensor(num_actions, dtype=torch.long)
    action_mask = util.get_mask_from_sequence_lengths(sequence_lengths, max_num_actions)
    return action_embeddings, output_embeddings, biases, action_mask
def _get_entity_action_logits(self,
                              state: WikiTablesDecoderState,
                              actions_to_link: List[List[int]],
                              attention_weights: torch.Tensor,
                              linked_checklist_balance: torch.Tensor = None) -> Tuple[torch.FloatTensor,
                                                                                      torch.LongTensor,
                                                                                      torch.FloatTensor]:
    """
    Returns scores for each action in ``actions_to_link`` that are derived from the linking scores
    between the question and the table entities, and the current attention on the question.  The
    intuition is that if we're paying attention to a particular word in the question, we should
    tend to select entity productions that we think that word refers to.  We additionally return a
    mask representing which elements in the returned ``action_logits`` tensor are just padding, and
    an embedded representation of each action that can be used as input to the next step of the
    encoder.  That embedded representation is derived from the type of the entity produced by the
    action.

    The ``actions_to_link`` are in terms of the `batch` action list passed to
    ``model.forward()``.  We need to convert these integers into indices into the linking score
    tensor, which has shape (batch_size, num_entities, num_question_tokens), look up the linking
    score for each entity, then aggregate the scores using the current question attention.

    Parameters
    ----------
    state : ``WikiTablesDecoderState``
        The current state.  We'll use this to get the linking scores.
    actions_to_link : ``List[List[int]]``
        A list of _batch_ action indices for each group element.  Should have shape
        (group_size, num_actions), unpadded.  This is expected to be output from
        :func:`_get_actions_to_consider`.
    attention_weights : ``torch.Tensor``
        The current attention weights over the question tokens.  Should have shape
        ``(group_size, num_question_tokens)``.
    linked_checklist_balance : ``torch.Tensor``, optional (default=None)
        If the parser is being trained to maximize coverage over an agenda, this is the balance
        vector corresponding to entity actions, containing 1s and 0s, with 1s showing the actions
        that are yet to be produced. Required only if the parser is being trained to maximize
        coverage.

    Returns
    -------
    action_logits : ``torch.FloatTensor``
        A score for each of the given actions.  Shape is ``(group_size, num_actions)``, where
        ``num_actions`` is the maximum number of considered actions for any group element.
    action_mask : ``torch.LongTensor``
        A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
        action_index)`` pairs were merely added as padding.
    type_embeddings : ``torch.FloatTensor``
        A tensor of shape ``(group_size, num_actions, action_embedding_dim)``, with an embedded
        representation of the `type` of the entity corresponding to each action (the output of
        ``self._entity_type_embedding``).
    """
    # First we map the actions to entity indices, using state.actions_to_entities, and find the
    # type of each entity using state.entity_types.
    action_entities: List[List[int]] = []
    entity_types: List[List[int]] = []
    for batch_index, action_list in zip(state.batch_indices, actions_to_link):
        action_entities.append([])
        entity_types.append([])
        for action_index in action_list:
            entity_index = state.actions_to_entities[(batch_index, action_index)]
            action_entities[-1].append(entity_index)
            entity_types[-1].append(state.entity_types[entity_index])

    # Then we create a padded tensor suitable for use with
    # `state.flattened_linking_scores.index_select()`.
    num_actions = [len(action_list) for action_list in action_entities]
    max_num_actions = max(num_actions)
    padded_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions)
                      for action_list in action_entities]
    padded_types = [common_util.pad_sequence_to_length(type_list, max_num_actions)
                    for type_list in entity_types]
    # Shape: (group_size, num_actions)
    action_tensor = state.score[0].new_tensor(padded_actions, dtype=torch.long)
    type_tensor = state.score[0].new_tensor(padded_types, dtype=torch.long)

    # To get the type embedding tensor, we just use an embedding matrix on the list of entity
    # types.
    type_embeddings = self._entity_type_embedding(type_tensor)

    # `state.flattened_linking_scores` is shape (batch_size * num_entities, num_question_tokens).
    # We want to select from this using `action_tensor` to get a tensor of shape (group_size,
    # num_actions, num_question_tokens).  Unfortunately, the index_select functions in nn.util
    # don't do this operation.  So we'll do some reshapes and do the index_select ourselves.
    group_size = len(state.batch_indices)
    num_question_tokens = state.flattened_linking_scores.size(-1)
    flattened_actions = action_tensor.view(-1)
    # (group_size * num_actions, num_question_tokens)
    flattened_action_linking = state.flattened_linking_scores.index_select(0, flattened_actions)
    # (group_size, num_actions, num_question_tokens)
    action_linking = flattened_action_linking.view(group_size, max_num_actions, num_question_tokens)

    # Now we get action logits by weighting these entity x token scores by the attention over
    # the question tokens.  We can do this efficiently with torch.bmm.
    action_logits = action_linking.bmm(attention_weights.unsqueeze(-1)).squeeze(-1)

    if linked_checklist_balance is not None:
        # ``linked_checklist_balance`` is a binary tensor of size (group_size, num_actions) with
        # 1s indicating the linked actions that the agenda wants the decoder to produce, but
        # haven't been produced yet.  Each such logit gets ``self._linked_checklist_multiplier``
        # times itself added, i.e. it is scaled by ``1 + multiplier`` (a doubling only when the
        # multiplier is 1); all other logits are unchanged.
        action_logits_addition = action_logits * linked_checklist_balance
        action_logits = action_logits + self._linked_checklist_multiplier * action_logits_addition

    # Finally, we make a mask for our action logit tensor.
    sequence_lengths = action_linking.new_tensor(num_actions)
    action_mask = util.get_mask_from_sequence_lengths(sequence_lengths, max_num_actions)
    return action_logits, action_mask, type_embeddings
def _get_action_embeddings(state: WikiTablesDecoderState,
                           actions_to_embed: List[List[int]]) -> Tuple[torch.Tensor,
                                                                       torch.Tensor,
                                                                       torch.Tensor,
                                                                       torch.Tensor]:
    """
    Gather the per-action tensors needed to score and feed the given global actions.

    Every group element's list of _global_ action indices (unpadded, as produced by
    :func:`_get_actions_to_consider`) is right-padded to the longest list.  The padded indices
    are then looked up in three tables held by ``state``: ``action_embeddings`` (used when
    selecting actions), ``output_action_embeddings`` (used as the decoder output, i.e. the input
    at the next timestep) and ``action_biases``.

    Parameters
    ----------
    state : ``WikiTablesDecoderState``
        Supplies the global embedding and bias tables, ``batch_indices`` and ``score``.
    actions_to_embed : ``List[List[int]]``
        Global action indices, one list per group element.

    Returns
    -------
    action_embeddings : ``torch.FloatTensor``
        ``(group_size, num_actions, action_embedding_dim)``, ``num_actions`` being the length of
        the longest list.
    output_action_embeddings : ``torch.FloatTensor``
        Same shape, gathered from ``state.output_action_embeddings``.
    action_biases : ``torch.FloatTensor``
        ``(group_size, num_actions, 1)``.
    action_mask : ``torch.LongTensor``
        ``(group_size, num_actions)``; marks which positions are real actions as opposed to
        padding.
    """
    lengths = [len(actions) for actions in actions_to_embed]
    width = max(lengths)

    # Pad each list to ``width`` and flatten, so one ``index_select`` per table fetches every
    # row; the results are reshaped back to (group_size, width, ...).
    padded = [common_util.pad_sequence_to_length(actions, width) for actions in actions_to_embed]
    flat_indices = state.score[0].new_tensor(padded, dtype=torch.long).view(-1)

    group_size = len(state.batch_indices)
    embedding_dim = state.action_embeddings.size(-1)

    selection_embeddings = state.action_embeddings.index_select(0, flat_indices) \
        .view(group_size, width, embedding_dim)
    next_input_embeddings = state.output_action_embeddings.index_select(0, flat_indices) \
        .view(group_size, width, embedding_dim)
    per_action_biases = state.action_biases.index_select(0, flat_indices) \
        .view(group_size, width, 1)

    padding_mask = util.get_mask_from_sequence_lengths(selection_embeddings.new_tensor(lengths),
                                                       width)
    return selection_embeddings, next_input_embeddings, per_action_biases, padding_mask
def forward(self, inputs, mask, sent_counts, sent_lens, prompt_inputs, prompt_mask,
            prompt_sent_counts, prompt_sent_lens, manual_feature, label=None):
    """Classify a batch of multi-sentence documents together with one shared prompt.

    Each sentence is encoded separately by ``self.bert``.  The unpadded tokens of a
    document's sentences are concatenated, positionally encoded and pooled with
    ``self.doc_global_attention``; the prompt is handled the same way with
    ``self.prompt_global_attention``.  Both pooled vectors go through tanh and
    dropout, are concatenated, and are fed to ``self.linear_layer`` + log-softmax.

    :param inputs: [batch size, max sent count, max sent len] token ids.
    :param mask: [batch size, max sent count, max sent len] attention mask.
    :param sent_counts: [batch size] number of real sentences per document.
    :param sent_lens: [batch size, max sent count] number of real tokens per sentence.
    :param prompt_inputs: prompt token ids; flattened to one row per prompt sentence.
    :param prompt_mask: attention mask for ``prompt_inputs``, same layout.
    :param prompt_sent_counts: number of prompt sentences (used as ``range(prompt_sent_counts)``).
    :param prompt_sent_lens: prompt sentence lengths; only row 0 is read.
    :param manual_feature: not used by this method.
    :param label: [batch size] gold labels; when given, ``loss`` is computed with ``self.criterion``.
    :return: dict with ``loss`` (None without ``label``) and ``prediction`` ([batch size]
        argmax class indices).
    """
    batch_size = inputs.shape[0]
    max_sent_count = inputs.shape[1]
    max_sent_length = inputs.shape[2]

    # Encode every sentence independently:
    # [batch size * max sent count, max sent len, hid size]
    last_hidden_states = self.bert(
        input_ids=inputs.view(-1, inputs.shape[-1]),
        attention_mask=mask.view(-1, mask.shape[-1]))[0]
    # -> [batch size, max sent count, max sent len, hid size]
    last_hidden_states = last_hidden_states.view(
        batch_size, max_sent_count, max_sent_length, -1)

    # One row per prompt sentence: [num prompt rows, prompt sent len, hid size]
    prompt_hidden_states = self.bert(
        input_ids=prompt_inputs.view(-1, prompt_inputs.shape[-1]),
        attention_mask=prompt_mask.view(-1, prompt_mask.shape[-1]))[0]

    # Concatenate (not average) the unpadded tokens of each document's real
    # sentences and add positional information: each entry is [1, doc len, hid size].
    docs = []
    lens = []
    for i in range(batch_size):
        doc_vec = torch.cat(
            [last_hidden_states[i, j, :sent_lens[i][j], :] for j in range(sent_counts[i])],
            dim=0).unsqueeze(0)
        doc_vec = self.positional_encoding.forward(doc_vec)
        lens.append(doc_vec.shape[1])
        docs.append(doc_vec)

    # Right-pad every document with zeros along the length dimension and stack:
    # [batch size, batch max len, hid size]
    batch_max_len = max(lens)
    docs = torch.cat(
        [torch.nn.functional.pad(doc, (0, 0, 0, batch_max_len - doc.shape[1])) for doc in docs],
        dim=0)
    docs_mask = get_mask_from_sequence_lengths(
        torch.tensor(lens, device=docs.device), max_length=batch_max_len)

    # Same treatment for the prompt: [1, prompt len, hid size]
    prompt_vec = torch.cat(
        [prompt_hidden_states[j, :prompt_sent_lens[0][j], :] for j in range(prompt_sent_counts)],
        dim=0).unsqueeze(0)
    prompt_vec = self.positional_encoding.forward(prompt_vec)
    prompt_len = prompt_vec.shape[1]
    prompt_attention_mask = get_mask_from_sequence_lengths(
        torch.tensor([prompt_len], device=prompt_vec.device), max_length=prompt_len)

    # Attention-weighted pooling of the prompt and of each document.
    prompt_vec_weights = self.prompt_global_attention(prompt_vec, prompt_attention_mask)
    prompt_vec = torch.bmm(prompt_vec_weights.unsqueeze(1), prompt_vec).squeeze(1)  # [1, hid size]
    doc_weights = self.doc_global_attention(docs, docs_mask)
    doc_vec = torch.bmm(doc_weights.unsqueeze(1), docs).squeeze(1)  # [batch size, hid size]

    doc_feature = self.dropout_layer(torch.tanh(doc_vec))
    # The single prompt vector is shared by every document in the batch.
    prompt_feature = self.dropout_layer(torch.tanh(prompt_vec.expand_as(doc_feature)))
    feature = torch.cat([doc_feature, prompt_feature], dim=-1)
    log_probs = torch.log_softmax(self.linear_layer(feature), dim=-1)

    loss = None
    if label is not None:
        loss = self.criterion(input=log_probs.contiguous().view(-1, log_probs.shape[-1]),
                              target=label.contiguous().view(-1))
    prediction = torch.max(log_probs, dim=1)[1]
    return {'loss': loss, 'prediction': prediction}
def forward(
        self,  # pylint: disable=arguments-differ
        embeddings: torch.FloatTensor,
        mask: torch.LongTensor,
        num_items_to_keep: Union[int, torch.LongTensor]
) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor, torch.FloatTensor]:
    """
    Extracts the top-k scoring items with respect to the scorer. We additionally return
    the indices of the top-k in their original order, not ordered by score, so that downstream
    components can rely on the original ordering (e.g., for knowing what spans are valid
    antecedents in a coreference resolution model). May use the same k for all sentences in
    minibatch, or different k for each.

    Parameters
    ----------
    embeddings : ``torch.FloatTensor``, required.
        A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for
        each item in the list that we want to prune.
    mask : ``torch.LongTensor``, required.
        A tensor of shape (batch_size, num_items), denoting unpadded elements of
        ``embeddings``.
    num_items_to_keep : ``Union[int, torch.LongTensor]``, required.
        If a tensor of shape (batch_size), specifies the number of items to keep for each
        individual sentence in minibatch.
        If an int, keep the same number of items for all sentences.

    Returns
    -------
    top_embeddings : ``torch.FloatTensor``
        The representations of the top-k scoring items.
        Has shape (batch_size, max_num_items_to_keep, embedding_size).
    top_mask : ``torch.LongTensor``
        The corresponding mask for ``top_embeddings``.
        Has shape (batch_size, max_num_items_to_keep).
    top_indices : ``torch.LongTensor``
        The indices of the top-k scoring items into the original ``embeddings``
        tensor, in increasing order. This is returned because it can be useful to retain
        pointers to the original items, if each item is being scored by multiple distinct
        scorers, for instance. Has shape (batch_size, max_num_items_to_keep).
    top_item_scores : ``torch.FloatTensor``
        The values of the top-k scoring items.
        Has shape (batch_size, max_num_items_to_keep, 1).

    Raises
    ------
    ValueError
        If ``self._scorer`` does not return a (batch_size, num_items, 1) tensor.
    """
    # If an int was given for number of items to keep, construct tensor by repeating the value.
    if isinstance(num_items_to_keep, int):
        batch_size = mask.size(0)
        # Put the tensor on same device as the mask.
        num_items_to_keep = num_items_to_keep * torch.ones(
            [batch_size], dtype=torch.long, device=mask.device)

    # A plain Python int is the unambiguous argument for ``topk`` and the mask length.
    max_items_to_keep = int(num_items_to_keep.max().item())

    mask = mask.unsqueeze(-1)
    num_items = embeddings.size(1)

    # Shape: (batch_size, num_items, 1)
    scores = self._scorer(embeddings)

    if scores.size(-1) != 1 or scores.dim() != 3:
        raise ValueError(
            f"The scorer passed to Pruner must produce a tensor of shape "
            f"(batch_size, num_items, 1), but found shape {scores.size()}")
    # Make sure that we don't select any masked items by setting their scores to be very
    # negative.  These are logits, typically, so -1e20 should be plenty negative.
    scores = util.replace_masked_values(scores, mask, -1e20)

    # Shape: (batch_size, max_num_items_to_keep, 1)
    _, top_indices = scores.topk(max_items_to_keep, 1)

    # Mask based on number of items to keep for each sentence.  Bool, because it is used as a
    # ``torch.where`` condition below (uint8 conditions are deprecated/rejected by PyTorch).
    # Shape: (batch_size, max_num_items_to_keep)
    top_indices_mask = util.get_mask_from_sequence_lengths(
        num_items_to_keep, max_items_to_keep).bool()

    # Shape: (batch_size, max_num_items_to_keep)
    top_indices = top_indices.squeeze(-1)

    # Fill all masked indices with largest "top" index for that sentence, so that all masked
    # indices will be sorted to the end.
    # Shape: (batch_size, 1)
    fill_value, _ = top_indices.max(dim=1)
    fill_value = fill_value.unsqueeze(-1)
    # Shape: (batch_size, max_num_items_to_keep)
    top_indices = torch.where(top_indices_mask, top_indices, fill_value)

    # Now we order the selected indices in increasing order with
    # respect to their indices (and hence, with respect to the
    # order they originally appeared in the ``embeddings`` tensor).
    top_indices, _ = torch.sort(top_indices, 1)

    # Shape: (batch_size * max_num_items_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select items for each element in the batch.
    flat_top_indices = util.flatten_and_batch_shift_indices(top_indices, num_items)

    # Shape: (batch_size, max_num_items_to_keep, embedding_size)
    top_embeddings = util.batched_index_select(embeddings, top_indices, flat_top_indices)

    # Combine the masks on spans that are out-of-bounds, and the mask on spans that are outside
    # the top k for each sentence.
    # Shape: (batch_size, max_num_items_to_keep)
    sequence_mask = util.batched_index_select(mask, top_indices, flat_top_indices)
    sequence_mask = sequence_mask.squeeze(-1).bool()
    top_mask = (top_indices_mask & sequence_mask).long()

    # Shape: (batch_size, max_num_items_to_keep, 1)
    top_scores = util.batched_index_select(scores, top_indices, flat_top_indices)

    return top_embeddings, top_mask, top_indices, top_scores
def _get_entity_action_logits(self,
                              state: WikiTablesDecoderState,
                              actions_to_link: List[List[int]],
                              attention_weights: torch.Tensor) -> Tuple[torch.FloatTensor,
                                                                        torch.LongTensor,
                                                                        torch.FloatTensor]:
    """
    Returns scores for each action in ``actions_to_link`` that are derived from the linking scores
    between the question and the table entities, and the current attention on the question.  The
    intuition is that if we're paying attention to a particular word in the question, we should
    tend to select entity productions that we think that word refers to.  We additionally return a
    mask representing which elements in the returned ``action_logits`` tensor are just padding, and
    an embedded representation of each action that can be used as input to the next step of the
    encoder.  That embedded representation is derived from the type of the entity produced by the
    action.

    The ``actions_to_link`` are in terms of the `batch` action list passed to
    ``model.forward()``.  We need to convert these integers into indices into the linking score
    tensor, which has shape (batch_size, num_entities, num_question_tokens), look up the linking
    score for each entity, then aggregate the scores using the current question attention.

    Parameters
    ----------
    state : ``WikiTablesDecoderState``
        The current state.  We'll use this to get the linking scores.
    actions_to_link : ``List[List[int]]``
        A list of _batch_ action indices for each group element.  Should have shape
        (group_size, num_actions), unpadded.  This is expected to be output from
        :func:`_get_actions_to_consider`.
    attention_weights : ``torch.Tensor``
        The current attention weights over the question tokens.  Should have shape
        ``(group_size, num_question_tokens)``.

    Returns
    -------
    action_logits : ``torch.FloatTensor``
        A score for each of the given actions.  Shape is ``(group_size, num_actions)``, where
        ``num_actions`` is the maximum number of considered actions for any group element.
    action_mask : ``torch.LongTensor``
        A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
        action_index)`` pairs were merely added as padding.
    type_embeddings : ``torch.FloatTensor``
        A tensor of shape ``(group_size, num_actions, action_embedding_dim)``, with an embedded
        representation of the `type` of the entity corresponding to each action (the output of
        ``self._entity_type_embedding``).
    """
    # First we map the actions to entity indices, using state.actions_to_entities, and find the
    # type of each entity using state.entity_types.
    action_entities: List[List[int]] = []
    entity_types: List[List[int]] = []
    for batch_index, action_list in zip(state.batch_indices, actions_to_link):
        action_entities.append([])
        entity_types.append([])
        for action_index in action_list:
            entity_index = state.actions_to_entities[(batch_index, action_index)]
            action_entities[-1].append(entity_index)
            entity_types[-1].append(state.entity_types[entity_index])

    # Then we create a padded tensor suitable for use with
    # `state.flattened_linking_scores.index_select()`.
    num_actions = [len(action_list) for action_list in action_entities]
    max_num_actions = max(num_actions)
    padded_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions)
                      for action_list in action_entities]
    padded_types = [common_util.pad_sequence_to_length(type_list, max_num_actions)
                    for type_list in entity_types]
    # Shape: (group_size, num_actions)
    # Built directly as long tensors on the scores' device. Going through the scores'
    # floating-point dtype first cannot represent every large integer index exactly.
    action_tensor = state.score[0].new_tensor(padded_actions, dtype=torch.long)
    type_tensor = state.score[0].new_tensor(padded_types, dtype=torch.long)

    # To get the type embedding tensor, we just use an embedding matrix on the list of entity
    # types.
    type_embeddings = self._entity_type_embedding(type_tensor)

    # `state.flattened_linking_scores` is shape (batch_size * num_entities, num_question_tokens).
    # We want to select from this using `action_tensor` to get a tensor of shape (group_size,
    # num_actions, num_question_tokens).  Unfortunately, the index_select functions in nn.util
    # don't do this operation.  So we'll do some reshapes and do the index_select ourselves.
    group_size = len(state.batch_indices)
    num_question_tokens = state.flattened_linking_scores.size(-1)
    flattened_actions = action_tensor.view(-1)
    # (group_size * num_actions, num_question_tokens)
    flattened_action_linking = state.flattened_linking_scores.index_select(0, flattened_actions)
    # (group_size, num_actions, num_question_tokens)
    action_linking = flattened_action_linking.view(group_size, max_num_actions, num_question_tokens)

    # Now we get action logits by weighting these entity x token scores by the attention over
    # the question tokens.  We can do this efficiently with torch.bmm.
    action_logits = action_linking.bmm(attention_weights.unsqueeze(-1)).squeeze(-1)

    # Finally, we make a mask for our action logit tensor.  Lengths are integer counts.
    sequence_lengths = action_linking.new_tensor(num_actions, dtype=torch.long)
    action_mask = util.get_mask_from_sequence_lengths(sequence_lengths, max_num_actions)
    return action_logits, action_mask, type_embeddings
def forward(
        self,  # type: ignore
        source_features: torch.FloatTensor,
        source_lengths: torch.LongTensor,
        target_tokens: Dict[str, torch.LongTensor] = None,
        words: Dict[str, torch.LongTensor] = None,
        segments: torch.LongTensor = None,
        pos_tags: torch.LongTensor = None,
        head_tags: torch.LongTensor = None,
        head_indices: torch.LongTensor = None,
        epoch_num: int = None,
        dataset: str = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Make a forward pass: encode the source features, run the enabled CTC / RNN-T /
    attention heads (plus the optional dependency parser and POS tagger), and, in
    training mode, combine the losses.

    Parameters
    ----------
    source_features : ``torch.FloatTensor``
        Padded input features; dimension 1 is treated as time (it sizes the source mask) and
        a channel dimension is inserted at position 1 internally.
    source_lengths : ``torch.LongTensor``
        Number of valid timesteps of each batch element.
    target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
        Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
        target tokens are also represented as a `TextField`.  When given, a ``"mask"`` entry is
        added to this dict in place.
    segments, pos_tags, head_tags, head_indices, metadata : optional
        Inputs for the optional dependency parser / POS tagger heads.
    epoch_num : optional
        Indexed as ``epoch_num[0]``; stored as ``self._epoch_num`` and used to name the
        ``preds.<epoch>.txt`` / ``golds.<epoch>.txt`` files appended to during evaluation.
    dataset : optional
        Indexed as ``dataset[0]``; when given it becomes ``self._target_granularity``.

    Returns
    -------
    Dict[str, torch.Tensor]
    """
    output_dict = {}
    if dataset is not None:
        self._target_granularity = dataset[0]
    if epoch_num is not None:
        self._epoch_num = epoch_num[0]
        # NOTE(review): assumed to be refreshed only when a new epoch number arrives - confirm.
        self.set_output_layer_num()

    source_mask = util.get_mask_from_sequence_lengths(
        source_lengths, source_features.size(1)).bool()

    source_features = source_features.unsqueeze(1)  # make a channel dim
    if self._delta:
        source_features = self._delta(source_features)

    batch_size, n_channels, timesteps, feature_size = source_features.size()
    # ``_input_norm`` sees a (batch, channel * feature, time) view; the result is
    # reshaped back to (batch, channel, time, feature).
    source_features = self._input_norm(
        source_features.transpose(-2, -1).reshape(batch_size, -1, timesteps)) \
        .view(batch_size, n_channels, feature_size, timesteps).transpose(-2, -1)

    source_features = self.time_mask(source_features, source_mask)
    source_features = self.freq_mask(source_features, source_mask)
    # Zero out padded timesteps.
    source_features = source_features.masked_fill(
        ~source_mask.unsqueeze(1).unsqueeze(-1).expand_as(source_features), 0.0)

    state = self._encode(source_features, source_lengths)
    source_lengths = util.get_lengths_from_binary_sequence_mask(state["source_mask"])

    # ``target_tokens`` is optional (the code below checks ``if target_tokens``), so only
    # derive its mask when it was actually given instead of crashing on ``None``.
    if target_tokens is not None:
        target_tokens["mask"] = \
            (target_tokens[self._target_namespace] != self._pad_index).bool()

    if self._phn_ctc_layer and \
            (self._phn_target_namespace in self._target_granularity or self._train_at_phn_level):
        # Phone-level CTC (projection + self._phn_ctc_layer, outputs prefixed "phn_ctc_")
        # is not implemented yet.
        raise NotImplementedError

    if self._rnnt_layer is not None and self._rnnt_layer.loss_ratio > 0.0:
        rnnt_output_dict = self._rnnt_layer(state["encoder_outputs"], source_lengths, target_tokens)
        output_dict.update({f"rnnt_{key}": value for key, value in rnnt_output_dict.items()})

    if self._ctc_layer is not None and self._ctc_layer.loss_ratio > 0.0:
        logits = self._projection_layer(state["encoder_outputs"])
        ctc_output_dict = self._ctc_layer(logits, source_lengths, target_tokens)
        output_dict.update({f"ctc_{key}": value for key, value in ctc_output_dict.items()})

    if target_tokens and self._att_ratio > 0.0 and \
            self._target_granularity == self._target_namespace:
        targets = target_tokens[self._target_namespace]
        output_dict["target_tokens"] = targets
        target_mask = util.get_text_field_mask(target_tokens)
        if self._train_at_phn_level:
            # Attention decoding over phone-level representations is not implemented yet.
            raise NotImplementedError
        state = self._init_decoder_state(state)
        output_dict.update(self._forward_loop(state, target_tokens))
        self._logs["att_wer"](output_dict["predictions"], targets)

        if self._dep_parser or self._pos_tagger:
            relevant_mask = target_mask[:, 1:]
            attention_contexts, _ = _remove_eos(output_dict["attention_contexts"], relevant_mask)
            if segments is not None:
                segments, _ = remove_sentence_boundaries(segments, target_mask)
                attention_contexts, _ = char_to_word(attention_contexts, segments)
            contexts = {"tokens": attention_contexts}
            if self._dep_parser:
                parser_outputs = self._dep_parser(contexts, pos_tags, metadata,
                                                  head_tags, head_indices)
                parser_outputs["dep_loss"] = parser_outputs.pop("loss")
                output_dict.update(parser_outputs)
            if self._pos_tagger:
                tagger_outputs = self._pos_tagger(contexts, pos_tags, metadata)
                tagger_outputs["pos_loss"] = tagger_outputs.pop("loss")
                output_dict.update(tagger_outputs)

    if not self.training:
        if self._target_granularity == self._target_namespace and self._att_ratio > 0.0:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens:
                targets = target_tokens[self._target_namespace]
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                self._logs["att_bleu"](best_predictions, targets)
                self._logs["att_wer"](best_predictions, targets)

                log_dict = self.decode(output_dict)
                # Drop the first target position before converting to tokens.
                verbose_target = [
                    self._indices_to_tokens(tokens.tolist()[1:])
                    for tokens in target_tokens[self._target_namespace]
                ]
                verbose_best_pred = [beams[0] for beams in log_dict["predicted_tokens"]]

                # Append best hypotheses and references to per-epoch text files. The file
                # names need ``epoch_num``, so the dump is skipped when it is not provided.
                if epoch_num is not None:
                    sep = " " if self._target_namespace == 'tokens' else ""
                    for prefix, sequences in (("preds", verbose_best_pred),
                                              ("golds", verbose_target)):
                        with open(f"{prefix}.{epoch_num[0]}.txt", "a+", encoding="utf-8") as fp:
                            fp.write("\n".join(
                                sep.join(re.sub(self._blank, " ", s) for s in words)
                                for words in sequences))
                            fp.write("\n")

    if self.training:
        output_dict = self._collect_losses(
            output_dict,
            ctc=(self._ctc_layer.loss_ratio if self._ctc_layer else 0),
            rnnt=(self._rnnt_layer.loss_ratio if self._rnnt_layer else 0),
            att=self._att_ratio,
            dal=self._latency_penalty,
            dep=self._dep_ratio,
            pos=self._pos_ratio)
        # If the total loss is NaN or +/-inf, every entry whose key contains "loss" is
        # replaced by a fresh zero scalar that still requires grad.
        if torch.isnan(output_dict["loss"]).any() or \
                (torch.abs(output_dict["loss"]) == float('inf')).any():
            for key in output_dict:
                if "loss" in key:
                    output_dict[key] = output_dict[key].new_zeros(
                        size=(), requires_grad=True).clone()

    self._update_metrics(output_dict)
    return output_dict