from typing import Dict

import torch

from allennlp.modules import Seq2SeqEncoder
from allennlp.nn.util import get_text_field_mask, get_final_encoder_states

# NOTE: seq2vec_seq_aggregate is assumed to be defined in (or imported into) this module.


def embed_encode_and_aggregate_list_text_field_with_feats_only(texts_list: Dict[str, torch.LongTensor],
                                                               text_field_embedder,
                                                               embeddings_dropout,
                                                               encoder: Seq2SeqEncoder,
                                                               aggregation_type,
                                                               token_features=None,
                                                               init_hidden_states=None):
    embedded_texts = text_field_embedder(texts_list)
    embedded_texts = embeddings_dropout(embedded_texts)

    if token_features is not None:
        # "feats only": the token features replace the token embeddings entirely.
        embedded_texts = torch.cat([token_features], dim=-1)

    bs, ch_cnt, ch_tkn_cnt, d = tuple(embedded_texts.shape)

    embedded_texts_flattened = embedded_texts.view([bs * ch_cnt, ch_tkn_cnt, -1])

    # masks
    texts_mask_dim_3 = get_text_field_mask(texts_list, num_wrapping_dims=1).float()
    texts_mask_flatened = texts_mask_dim_3.view([-1, ch_tkn_cnt])

    # context encoding
    multiple_texts_init_states = None
    if init_hidden_states is not None:
        if init_hidden_states.shape[0] == bs and init_hidden_states.shape[1] != ch_cnt:
            if init_hidden_states.shape[1] != encoder.get_output_dim():
                raise ValueError("The shape of init_hidden_states is {0} "
                                 "but is expected to be {1} or {2}".format(str(init_hidden_states.shape),
                                                                           str([bs, encoder.get_output_dim()]),
                                                                           str([bs, ch_cnt, encoder.get_output_dim()])))
            # In this case we were passed only a 2D tensor, which is the default output of the question encoder.
            multiple_texts_init_states = init_hidden_states.unsqueeze(1) \
                .expand([bs, ch_cnt, encoder.get_output_dim()]).contiguous()

            # reshape this to match the flattened tokens
            multiple_texts_init_states = multiple_texts_init_states.view([bs * ch_cnt, encoder.get_output_dim()])
        else:
            multiple_texts_init_states = init_hidden_states.view([bs * ch_cnt, encoder.get_output_dim()])

    encoded_texts_flattened = encoder(embedded_texts_flattened, texts_mask_flatened,
                                      hidden_state=multiple_texts_init_states)

    aggregated_choice_flattened = seq2vec_seq_aggregate(encoded_texts_flattened, texts_mask_flatened,
                                                        aggregation_type,
                                                        encoder.is_bidirectional(), 1)  # bs*ch X d

    aggregated_choice_flattened_reshaped = aggregated_choice_flattened.view([bs, ch_cnt, -1])  # bs X ch X d

    return aggregated_choice_flattened_reshaped
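# Hedged usage sketch (illustration only, not part of the original module). It shows the
# tensor shapes the helper above expects when encoding a batch of answer choices, assuming
# an AllenNLP 0.x BasicTextFieldEmbedder over a plain "tokens" index and a bidirectional
# LSTM wrapped as a Seq2SeqEncoder. The helper name `_example_encode_choices`, the sizes,
# and the "max" aggregation type are all assumptions made for this example.
def _example_encode_choices():
    from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
    from allennlp.modules.token_embedders import Embedding

    bs, ch_cnt, ch_tkn_cnt, emb_dim = 2, 4, 7, 16
    embedder = BasicTextFieldEmbedder({"tokens": Embedding(num_embeddings=100, embedding_dim=emb_dim)})
    encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(emb_dim, 8, batch_first=True, bidirectional=True))
    dropout = torch.nn.Dropout(0.2)

    # A ListField of TextFields tensorises to (batch_size, num_choices, num_tokens); 0 is padding.
    choices = {"tokens": torch.randint(1, 100, (bs, ch_cnt, ch_tkn_cnt))}
    aggregated = embed_encode_and_aggregate_list_text_field_with_feats_only(
        choices, embedder, dropout, encoder, aggregation_type="max")

    # One vector per choice: (batch_size, num_choices, encoder output dim).
    assert aggregated.shape[:2] == (bs, ch_cnt)
    return aggregated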
def embedd_encode_and_aggregate_text_field(question: Dict[str, torch.LongTensor],
                                           text_field_embedder,
                                           embeddings_dropout,
                                           encoder,
                                           aggregation_type,
                                           get_last_states=False):
    embedded_question = text_field_embedder(question)
    question_mask = get_text_field_mask(question).float()
    embedded_question = embeddings_dropout(embedded_question)
    encoded_question = encoder(embedded_question, question_mask)

    # aggregate sequences to a single item
    encoded_question_aggregated = seq2vec_seq_aggregate(encoded_question, question_mask,
                                                        aggregation_type,
                                                        encoder.is_bidirectional(), 1)  # bs X d

    last_hidden_states = None
    if get_last_states:
        last_hidden_states = get_final_encoder_states(encoded_question, question_mask, encoder.is_bidirectional())

    return encoded_question_aggregated, last_hidden_states
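# Hedged usage sketch (illustration only, not part of the original module). It encodes a
# single question field to one vector per instance and also requests the final encoder
# states, i.e. the 2D tensor that the list helper above accepts as init_hidden_states.
# Module construction, sizes and the "max" aggregation type are assumptions; the helper
# name `_example_encode_question` is hypothetical.
def _example_encode_question():
    from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
    from allennlp.modules.token_embedders import Embedding

    bs, q_tkn_cnt, emb_dim = 2, 9, 16
    embedder = BasicTextFieldEmbedder({"tokens": Embedding(num_embeddings=100, embedding_dim=emb_dim)})
    encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(emb_dim, 8, batch_first=True, bidirectional=True))
    dropout = torch.nn.Dropout(0.2)

    question = {"tokens": torch.randint(1, 100, (bs, q_tkn_cnt))}
    question_agg, question_last_states = embedd_encode_and_aggregate_text_field(
        question, embedder, dropout, encoder, aggregation_type="max", get_last_states=True)

    # question_last_states is (batch_size, encoder output dim) and can seed the choice encoder.
    assert question_last_states.shape == (bs, encoder.get_output_dim())
    assert question_agg.shape[0] == bs
    return question_agg, question_last_states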
def forward(self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        From a ``TextField``.
    hypothesis : Dict[str, torch.LongTensor]
        From a ``TextField``.
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``.

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
        probabilities of the entailment label.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
        entailment label.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    embedded_premise = self._text_field_embedder(premise)
    embedded_premise = self._embeddings_dropout(embedded_premise)

    embedded_hypothesis = self._text_field_embedder(hypothesis)
    embedded_hypothesis = self._embeddings_dropout(embedded_hypothesis)

    premise_mask = get_text_field_mask(premise).float()
    hypothesis_mask = get_text_field_mask(hypothesis).float()

    if self._premise_encoder:
        embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
        embedded_premise = seq2vec_seq_aggregate(embedded_premise, premise_mask,
                                                 self._premise_aggregate,
                                                 self._premise_encoder.is_bidirectional(), 1)

    if self._hypothesis_encoder:
        embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask)
        embedded_hypothesis = seq2vec_seq_aggregate(embedded_hypothesis, hypothesis_mask,
                                                    self._hypothesis_aggregate,
                                                    self._hypothesis_encoder.is_bidirectional(), 1)

    # InferSent-style matching features: [u, v, |v - u|, u * v].
    aggregate_input = torch.cat([embedded_premise,
                                 embedded_hypothesis,
                                 torch.abs(embedded_hypothesis - embedded_premise),
                                 embedded_premise * embedded_hypothesis], dim=-1)

    label_logits = self._aggregate_feedforward(aggregate_input)
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {"label_logits": label_logits, "label_probs": label_probs}

    if label is not None:
        labels = label.long().view(-1)
        loss = self._loss(label_logits, labels)
        self._accuracy(label_logits, label.squeeze(-1))
        output_dict["loss"] = loss

    return output_dict
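# Hedged sketch (illustration only) of the matching-feature vector built in forward()
# above: for an aggregated premise u and hypothesis v of size d, the classifier input is
# the 4*d concatenation [u, v, |v - u|, u * v]. The sizes and the helper name
# `_example_matching_features` are assumptions made for this example.
def _example_matching_features():
    bs, d = 2, 16
    u = torch.randn(bs, d)  # aggregated premise vector
    v = torch.randn(bs, d)  # aggregated hypothesis vector
    features = torch.cat([u, v, torch.abs(v - u), u * v], dim=-1)
    assert features.shape == (bs, 4 * d)
    return features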