def forward(  # type: ignore
    self, tokens: TextFieldTensors, labels: torch.IntTensor = None
) -> Dict[str, torch.Tensor]:
    """
    # Parameters

    tokens : `TextFieldTensors`
        From a `TextField`
    labels : `torch.IntTensor`, optional (default = `None`)
        From a `MultiLabelField`

    # Returns

    An output dictionary consisting of:

    - `logits` (`torch.FloatTensor`) :
        A tensor of shape `(batch_size, num_labels)` representing the
        unnormalized log probabilities of the labels.
    - `probs` (`torch.FloatTensor`) :
        A tensor of shape `(batch_size, num_labels)` representing the
        probabilities of the labels.
    - `loss` (`torch.FloatTensor`, optional) :
        A scalar loss to be optimised.
    """
    embedded_text = self._text_field_embedder(tokens)
    mask = get_text_field_mask(tokens)

    if self._seq2seq_encoder:
        embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

    embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

    if self._dropout:
        embedded_text = self._dropout(embedded_text)

    if self._feedforward is not None:
        embedded_text = self._feedforward(embedded_text)

    logits = self._classification_layer(embedded_text)
    # Multi-label: a sigmoid gives each label an independent probability.
    probs = torch.sigmoid(logits)

    output_dict = {"logits": logits, "probs": probs}
    output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(tokens)

    if labels is not None:
        loss = self._loss(logits, labels.float().view(-1, self._num_labels))
        output_dict["loss"] = loss
        # TODO (John): This shouldn't be necessary, as `__call__` of the metrics
        # detaches these tensors anyway.
        cloned_logits, cloned_labels = logits.clone(), labels.clone()
        self._micro_f1(cloned_logits, cloned_labels)
        self._macro_f1(cloned_logits, cloned_labels)

    return output_dict
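# A minimal, self-contained sketch of the multi-label probability/loss setup the
# forward above relies on. Shapes and values are illustrative, and using
# `BCEWithLogitsLoss` for `self._loss` is an assumption (consistent with the
# sigmoid over the logits), not something this snippet confirms.
#
#     import torch
#
#     batch_size, num_labels = 2, 4
#     logits = torch.randn(batch_size, num_labels)        # unnormalized per-label scores
#     labels = torch.tensor([[1, 0, 0, 1],                # multi-hot targets, as produced
#                            [0, 1, 0, 0]])               # by a `MultiLabelField`
#
#     probs = torch.sigmoid(logits)                       # independent per-label probabilities
#     loss = torch.nn.BCEWithLogitsLoss()(logits, labels.float().view(-1, num_labels))
#     predicted_labels = (probs > 0.5).long()             # threshold to recover label sets
#
# Unlike a softmax classifier, nothing forces the per-label probabilities to sum
# to one, which is what lets an instance carry several labels at once.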
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            list_,
            passages_length: torch.LongTensor = None,
            correct_passage: torch.LongTensor = None,
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    # shape: B x N x T x E
    embedded_passage_list = self._embedder(list_)
    (batch_size, num_passages, max_p, embedding_size) = embedded_passage_list.size()
    # shape: B x Tq x E
    embedded_question = self._embedder(question)
    # Flatten the N passages into a single concatenated passage per instance.
    # shape: B x (N*T) x E
    embedded_passage = embedded_passage_list.view(batch_size, -1, embedding_size)
    total_passage_length = embedded_passage.size(1)

    question_mask = util.get_text_field_mask(question)
    passage_list_mask = util.get_text_field_mask(list_, num_wrapping_dims=1)
    passage_mask = passage_list_mask.view(batch_size, -1)

    # shape: B x T x 2H
    encoded_question = self._dropout(self._question_encoder(embedded_question, question_mask))
    encoded_passage = self._dropout(self._passage_encoder(embedded_passage, passage_mask))
    passage_mask = passage_mask.float()
    question_mask = question_mask.float()
    encoding_dim = encoded_question.size(-1)

    # Initial hidden state for the gated attention-based recurrent network.
    # shape: B x 2H
    gru_hidden = torch.zeros(batch_size, encoding_dim, device=encoded_passage.device)

    question_awared_passage = []
    for timestep in range(total_passage_length):
        # shape: B x 2H
        u_t_P = encoded_passage[:, timestep, :]
        # shape: B x Tq = attention(B x 2H, B x Tq x 2H)
        attn_weights = self._question_attention_for_passage(u_t_P, encoded_question, question_mask)
        # shape: B x 2H = weighted_sum(B x Tq x 2H, B x Tq)
        attended_question = util.weighted_sum(encoded_question, attn_weights)
        # shape: B x 4H
        passage_question_combined = torch.cat([u_t_P, attended_question], dim=-1)
        # shape: B x 4H
        gate = torch.sigmoid(self._gate(passage_question_combined))
        gru_input = gate * passage_question_combined
        # shape: B x 2H
        gru_hidden = self._dropout(self._gru_cell(gru_input, gru_hidden))
        question_awared_passage.append(gru_hidden)

    # Question-aware passage representation v_P.
    # shape: B x T x 2H
    question_awared_passage = torch.stack(question_awared_passage, dim=1)

    # Compute the question vector r_Q.
    # shape: B x T = attention(B x 2H, B x T x 2H)
    v_r_Q_tiled = self._v_r_Q.unsqueeze(0).expand(batch_size, encoding_dim)
    attn_weights = self._question_attention_for_question(v_r_Q_tiled, encoded_question, question_mask)
    # shape: B x 2H
    r_Q = util.weighted_sum(encoded_question, attn_weights)

    # Pointer network, step 1: span start distribution over the passage.
    # shape: B x T = attention(B x 2H, B x T x 2H)
    span_start_logits = self._passage_attention_for_answer(r_Q, question_awared_passage, passage_mask)
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
    span_start_log_probs = util.masked_log_softmax(span_start_logits, passage_mask)
    # shape: B x 2H
    c_t = util.weighted_sum(question_awared_passage, span_start_probs)
    # shape: B x 2H
    h_1 = self._dropout(self._answer_net(c_t, r_Q))

    # Pointer network, step 2: span end distribution from the updated state.
    span_end_logits = self._passage_attention_for_answer(h_1, question_awared_passage, passage_mask)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    span_end_log_probs = util.masked_log_softmax(span_end_logits, passage_mask)

    # Passage ranking: score each of the N passages against the question vector r_Q.
    g = []
    for i in range(num_passages):
        # shape: B x T_i = attention(B x 2H, B x T_i x 2H)
        attn_weights = self._passage_attention_for_ranking(
            r_Q,
            question_awared_passage[:, i * max_p: (i + 1) * max_p, :],
            passage_mask[:, i * max_p: (i + 1) * max_p])
        # shape: B x 2H
        r_P = util.weighted_sum(question_awared_passage[:, i * max_p: (i + 1) * max_p, :], attn_weights)
        # shape: B x 4H
        question_passage_combined = torch.cat([r_Q, r_P], dim=-1)
        # shape: B x 1
        g_i = self._dropout(self._match_layer_2(torch.tanh(self._match_layer_1(question_passage_combined))))
        g.append(g_i)

    # shape: B x N
    g = torch.cat(g, dim=1)
    passage_log_probs = F.log_softmax(g, dim=-1)

    # The cumulative sum over the mask counts real tokens only; it is used below
    # to convert between unpadded and padded token positions.
    cumsum = torch.cumsum(passage_mask.long(), dim=1)

    output_dict = {}
    if span_start is not None:
        # Map the gold span indices (given relative to the unpadded, concatenated
        # passages) onto the padded concatenation: the padded position of
        # unpadded index s is the first position where the cumsum reaches s + 1.
        padded_span_start = span_start.clone()
        padded_span_end = span_end.clone()
        for b in range(batch_size):
            padded_span_start[b] = (cumsum[b] == span_start[b] + 1).nonzero()[0][0]
            padded_span_end[b] = (cumsum[b] == span_end[b] + 1).nonzero()[0][0]

        # Joint loss: answer-pointer (AP) loss plus passage-ranking (PR) loss.
        AP_loss = (F.nll_loss(span_start_log_probs, padded_span_start.squeeze(-1))
                   + F.nll_loss(span_end_log_probs, padded_span_end.squeeze(-1)))
        PR_loss = F.nll_loss(passage_log_probs, correct_passage.squeeze(-1))
        loss = self._r * AP_loss + self._r * PR_loss
        output_dict['loss'] = loss

    _, max_start = torch.max(span_start_probs, dim=1)
    _, max_end = torch.max(span_end_probs, dim=1)
    # Un-pad the predictions: convert indices on the padded concatenation back
    # to indices over the real (unpadded) tokens.
    for b in range(batch_size):
        max_start[b] = cumsum[b, max_start[b]] - 1
        max_end[b] = cumsum[b, max_end[b]] - 1
    output_dict['span_start_idx'] = max_start
    output_dict['span_end_idx'] = max_end

    self._num_iter += 1
    if span_start is not None and self._num_iter % 50 == 0:
        print(" gold %i:%i|predicted %i:%i" % (
            span_start.squeeze(-1)[0].item(), span_end.squeeze(-1)[0].item(),
            max_start[0].item(), max_end[0].item()))

    return output_dict
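# A minimal, standalone sketch of the cumsum index-remapping used above (toy
# values; the variable names here are illustrative, not part of the model).
# Real tokens are 1s in the mask, so the cumulative sum assigns each real token
# its 1-based unpadded position: searching for `span + 1` recovers the padded
# position, and reading the cumsum at a padded position (minus 1) inverts it.
#
#     import torch
#
#     passage_mask = torch.tensor([[1, 1, 0, 0, 1, 1, 1, 0]])  # two padded passages, flattened
#     cumsum = torch.cumsum(passage_mask.long(), dim=1)        # [[1, 2, 2, 2, 3, 4, 5, 5]]
#
#     span_start = torch.tensor([2])                           # 3rd real token (0-based, unpadded)
#     padded_start = (cumsum[0] == span_start[0] + 1).nonzero()[0][0]   # tensor(4)
#     unpadded_start = cumsum[0, padded_start] - 1             # tensor(2), the original index
#
# Taking the first match via `nonzero()[0][0]` matters when padding follows the
# token: the cumsum plateaus over padding, and only the first position of the
# plateau is a real token.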