def test_get_final_encoder_states(self):
    encoder_outputs = torch.Tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                                    [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]])
    mask = torch.Tensor([[1, 1, 1], [1, 1, 0]])
    final_states = util.get_final_encoder_states(encoder_outputs, mask, bidirectional=False)
    assert_almost_equal(final_states.data.numpy(), [[9, 10, 11, 12], [17, 18, 19, 20]])
    final_states = util.get_final_encoder_states(encoder_outputs, mask, bidirectional=True)
    assert_almost_equal(final_states.data.numpy(), [[9, 10, 3, 4], [17, 18, 15, 16]])
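The test above pins down the contract of `util.get_final_encoder_states`: for each batch entry it returns the encoder output at the last unmasked timestep, and in the bidirectional case it keeps the forward half of that state while taking the backward half from timestep 0. Below is a minimal sketch of a function with that behavior; it assumes the AllenNLP-style signature used in the test and is not the library's verbatim implementation.

```python
import torch

def final_encoder_states_sketch(encoder_outputs: torch.Tensor,
                                mask: torch.Tensor,
                                bidirectional: bool = False) -> torch.Tensor:
    # Index of the last unmasked timestep for each batch entry.
    last_indices = mask.sum(dim=1).long() - 1
    batch_size, _, output_dim = encoder_outputs.size()
    expanded_indices = last_indices.view(-1, 1, 1).expand(batch_size, 1, output_dim)
    # (batch_size, output_dim): encoder output at the final unmasked position.
    final_states = encoder_outputs.gather(1, expanded_indices).squeeze(1)
    if bidirectional:
        # Forward half comes from the last unmasked timestep,
        # backward half from timestep 0.
        final_forward = final_states[:, :(output_dim // 2)]
        final_backward = encoder_outputs[:, 0, (output_dim // 2):]
        final_states = torch.cat([final_forward, final_backward], dim=-1)
    return final_states
```

Run against the tensors in the test, this sketch reproduces the asserted outputs for both the unidirectional and bidirectional cases.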
def _get_initial_rnn_state(self, sentence: Dict[str, torch.LongTensor]):
    embedded_input = self._sentence_embedder(sentence)
    # (batch_size, sentence_length)
    sentence_mask = util.get_text_field_mask(sentence).float()
    batch_size = embedded_input.size(0)
    # (batch_size, sentence_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(embedded_input, sentence_mask))
    final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                         sentence_mask,
                                                         self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
    attended_sentence, _ = self._decoder_step.attend_on_question(final_encoder_output,
                                                                 encoder_outputs,
                                                                 sentence_mask)
    encoder_outputs_list = [encoder_outputs[i] for i in range(batch_size)]
    sentence_mask_list = [sentence_mask[i] for i in range(batch_size)]
    initial_rnn_state = []
    for i in range(batch_size):
        initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                             memory_cell[i],
                                             self._first_action_embedding,
                                             attended_sentence[i],
                                             encoder_outputs_list,
                                             sentence_mask_list))
    return initial_rnn_state
def _get_initial_state(self,
                       encoder_outputs: torch.Tensor,
                       mask: torch.Tensor,
                       actions: List[List[ProductionRule]]) -> GrammarBasedState:
    batch_size = encoder_outputs.size(0)
    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                         mask,
                                                         self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
    initial_score = encoder_outputs.data.new_zeros(batch_size)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    utterance_mask_list = [mask[i] for i in range(batch_size)]
    initial_rnn_state = []
    for i in range(batch_size):
        initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                             memory_cell[i],
                                             self._first_action_embedding,
                                             self._first_attended_utterance,
                                             encoder_output_list,
                                             utterance_mask_list))
    initial_grammar_state = [self._create_grammar_state(actions[i]) for i in range(batch_size)]
    initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=initial_rnn_state,
                                      grammar_state=initial_grammar_state,
                                      possible_actions=actions,
                                      debug_info=None)
    return initial_state
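The inline comment above motivates the outer-list conversion: once every per-instance tensor lives in a plain Python list, the decoder can regroup states from different batch entries with a single `torch.stack`/`torch.cat` instead of index-selecting into batched tensors. A minimal illustration of that regrouping, with hypothetical shapes and variable names that are not taken from the source:

```python
import torch

# Hypothetical per-instance decoder states kept in plain Python lists.
hidden_states = [torch.randn(64) for _ in range(3)]        # one (hidden_dim,) tensor per instance
encoder_outputs = [torch.randn(12, 64) for _ in range(3)]  # one (seq_len, dim) tensor per instance

# Regrouping instances 0 and 2 into one decoder "group" is just a stack,
# with no index_select over a batched tensor.
group = [0, 2]
grouped_hidden = torch.stack([hidden_states[i] for i in group], dim=0)      # (2, 64)
grouped_outputs = torch.stack([encoder_outputs[i] for i in group], dim=0)   # (2, 12, 64)
```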
def _get_initial_state(self, utterance: Dict[str, torch.LongTensor], worlds: List[AtisWorld], actions: List[List[ProductionRuleArray]], linking_scores: torch.Tensor) -> GrammarBasedState: embedded_utterance = self._utterance_embedder(utterance) utterance_mask = util.get_text_field_mask(utterance).float() batch_size = embedded_utterance.size(0) num_entities = max([len(world.entities) for world in worlds]) # entity_types: tensor with shape (batch_size, num_entities) entity_types, _ = self._get_type_vector(worlds, num_entities, embedded_utterance) # (batch_size, num_utterance_tokens, embedding_dim) encoder_input = embedded_utterance # (batch_size, utterance_length, encoder_output_dim) encoder_outputs = self._dropout( self._encoder(encoder_input, utterance_mask)) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states( encoder_outputs, utterance_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) initial_score = embedded_utterance.data.new_zeros(batch_size) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] utterance_mask_list = [utterance_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): if self._decoder_num_layers > 1: initial_rnn_state.append( RnnStatelet( final_encoder_output[i].repeat( self._decoder_num_layers, 1), memory_cell[i].repeat(self._decoder_num_layers, 1), self._first_action_embedding, self._first_attended_utterance, encoder_output_list, utterance_mask_list)) else: initial_rnn_state.append( RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_utterance, encoder_output_list, utterance_mask_list)) initial_grammar_state = [ self._create_grammar_state(worlds[i], actions[i], linking_scores[i], entity_types[i]) for i in range(batch_size) ] initial_state = GrammarBasedState( batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, possible_actions=actions, debug_info=None) return initial_state
def forward( self, # type: ignore tokens: Dict[str, torch.LongTensor], spans: torch.LongTensor, metadata: List[Dict[str, Any]], pos_tags: Dict[str, torch.LongTensor] = None, span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- tokens : Dict[str, torch.LongTensor], required The output of ``TextField.as_array()``, which should typically be passed directly to a ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used for the ``TokenIndexers`` when you created the ``TextField`` representing your sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``, which knows how to combine different word representations into a single vector per token in your input. spans : ``torch.LongTensor``, required. A tensor of shape ``(batch_size, num_spans, 2)`` representing the inclusive start and end indices of all possible spans in the sentence. metadata : List[Dict[str, Any]], required. A dictionary of metadata for each batch element which has keys: tokens : ``List[str]``, required. The original string tokens in the sentence. gold_tree : ``nltk.Tree``, optional (default = None) Gold NLTK trees for use in evaluation. pos_tags : ``List[str]``, optional. The POS tags for the sentence. These can be used in the model as embedded features, but they are passed here in addition for use in constructing the tree. pos_tags : ``torch.LongTensor``, optional (default = None) The output of a ``SequenceLabelField`` containing POS tags. span_labels : ``torch.LongTensor``, optional (default = None) A torch tensor representing the integer gold class labels for all possible spans, of shape ``(batch_size, num_spans)``. Returns ------- An output dictionary consisting of: class_probabilities : ``torch.FloatTensor`` A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)`` representing a distribution over the label classes per span. spans : ``torch.LongTensor`` The original spans tensor. tokens : ``List[List[str]]``, required. A list of tokens in the sentence for each element in the batch. pos_tags : ``List[List[str]]``, required. A list of POS tags in the sentence for each element in the batch. num_spans : ``torch.LongTensor``, required. A tensor of shape (batch_size), representing the lengths of non-padded spans in ``enumerated_spans``. loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. """ embedded_text_input = self.text_field_embedder(tokens) if pos_tags is not None and self.pos_tag_embedding is not None: embedded_pos_tags = self.pos_tag_embedding(pos_tags) embedded_text_input = torch.cat( [embedded_text_input, embedded_pos_tags], -1) elif self.pos_tag_embedding is not None: raise ConfigurationError( "Model uses a POS embedding, but no POS tags were passed.") mask = get_text_field_mask(tokens) # Looking at the span start index is enough to know if # this is padding or not. Shape: (batch_size, num_spans) span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long() if span_mask.dim() == 1: # This happens if you use batch_size 1 and encounter # a length 1 sentence in PTB, which do exist. 
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)
        encoded_text = self.encoder(embedded_text_input, mask)
        encoder_final_state = get_final_encoder_states(encoded_text, mask)
        span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)
        logits = self.tag_projection_layer(span_representations)
        class_probabilities = masked_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
                "encoder_final_state": encoder_final_state,
                "encoded_text": encoded_text,
                "class_probabilities": class_probabilities,
                "spans": spans,
                "tokens": [meta["tokens"] for meta in metadata],
                "pos_tags": [meta.get("pos_tags") for meta in metadata],
                "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1]
                                              for tree in batch_gold_trees]
            predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                                   spans.cpu().data,
                                                   num_spans.data,
                                                   output_dict["tokens"],
                                                   gold_pos_tags)
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict
def forward( self, # type: ignore question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[QuarelWorld], actions: List[List[ProductionRule]], entity_bits: torch.Tensor = None, denotation_target: torch.Tensor = None, target_action_sequences: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None, ) -> Dict[str, torch.Tensor]: """ In this method we encode the table entities, link them to words in the question, then encode the question. Then we set up the initial state for the decoder, and pass that state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference, if we're not. Parameters ---------- question : Dict[str, torch.LongTensor] The output of ``TextField.as_array()`` applied on the question ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. table : ``Dict[str, torch.LongTensor]`` The output of ``KnowledgeGraphField.as_array()`` applied on the table ``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to get embeddings for each entity. world : ``List[QuarelWorld]`` We use a ``MetadataField`` to get the ``World`` for each input instance. Because of how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``, actions : ``List[List[ProductionRule]]`` A list of all possible actions for each ``World`` in the batch, indexed into a ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and use the embeddings to determine which action to take at each timestep in the decoder. entity_bits : ``torch.Tensor``, optional (default=None) Tensor encoding bits for the world entities. denotation_target : ``torch.Tensor``, optional (default=None) If model's field ``denotation_only`` is True, this is the tensor target denotation. target_action_sequences : torch.Tensor, optional (default=None) A list of possibly valid action sequences, where each action is an index into the list of possible actions. This tensor has shape ``(batch_size, num_action_sequences, sequence_length)``. metadata : List[Dict[str, Any]], optional (default=None). A dictionary of metadata for each batch element which has keys: question_tokens : ``List[str]``, optional. The original string tokens in the question. world_extractions : ``nltk.Tree``, optional. Extracted worlds from the question. answer_index : ``List[str]``, optional. Index of the correct answer. """ table_text = table["text"] self._debug_count -= 1 # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types) # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector( world, num_entities, embedded_table) if self._use_entities: if self._entity_similarity_mode == "dot_product": # Compute entity and question word cosine similarity. 
                # Need to add a small value to the table norm since there are padding
                # values which cause a divide by 0.
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                question_entity_similarity = torch.bmm(
                        embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim),
                        torch.transpose(embedded_question, 1, 2))
                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                              num_entities,
                                                                              num_entity_tokens,
                                                                              num_question_tokens)
                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                eqe = embedded_question.unsqueeze(1).expand(-1, num_entities * num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim)
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                product = torch.mul(eqe, ete)
                product = product.view(batch_size,
                                       num_question_tokens * num_entities * num_entity_tokens,
                                       self._embedding_dim)
                question_entity_similarity = self._entity_similarity_layer(product)
                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                              num_entities,
                                                                              num_entity_tokens,
                                                                              num_question_tokens)
                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score

            # (batch_size, num_entities, num_question_tokens, num_features)
            linking_features = table["linking"]
            if self._linking_params is not None:
                feature_scores = self._linking_params(linking_features).squeeze(3)
                linking_scores = linking_scores + feature_scores

            # (batch_size, num_question_tokens, num_entities)
            linking_probabilities = self._get_linking_probabilities(world,
                                                                    linking_scores.transpose(1, 2),
                                                                    question_mask,
                                                                    entity_type_dict)
            encoder_input = embedded_question
        else:
            if entity_bits is not None and not self._entity_bits_output:
                encoder_input = torch.cat([embedded_question, entity_bits], 2)
            else:
                encoder_input = embedded_question

            # Fake linking_scores added for downstream code to not object
            linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
            linking_probabilities = None

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))
        if self._entity_bits_output and entity_bits is not None:
            encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
final_encoder_output = util.get_final_encoder_states( encoder_outputs, question_mask, self._encoder.is_bidirectional()) # For predicting a categorical denotation directly if self._denotation_only: denotation_logits = self._denotation_classifier( final_encoder_output) loss = torch.nn.functional.cross_entropy( denotation_logits, denotation_target.view(-1)) self._denotation_accuracy_cat(denotation_logits, denotation_target) return {"loss": loss} memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim) _, num_entities, num_question_tokens = linking_scores.size() if target_action_sequences is not None: # Remove the trailing dimension (from ListField[ListField[IndexField]]). target_action_sequences = target_action_sequences.squeeze(-1) target_mask = target_action_sequences != self._action_padding_index else: target_mask = None # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append( RnnStatelet( final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list, )) initial_grammar_state = [ self._create_grammar_state(world[i], actions[i], linking_scores[i], entity_types[i]) for i in range(batch_size) ] initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size) initial_score_list = [initial_score[i] for i in range(batch_size)] initial_state = GrammarBasedState( batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, possible_actions=actions, extras=None, debug_info=None, ) if self.training: outputs = self._decoder_trainer.decode( initial_state, self._decoder_step, (target_action_sequences, target_mask)) return outputs else: action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] outputs = {"action_mapping": action_mapping} if target_action_sequences is not None: outputs["loss"] = self._decoder_trainer.decode( initial_state, self._decoder_step, (target_action_sequences, target_mask))["loss"] num_steps = self._max_decoding_steps # This tells the state to start keeping track of debug info, which we'll pass along in # our output dictionary. 
initial_state.debug_info = [[] for _ in range(batch_size)] best_final_states = self._beam_search.search( num_steps, initial_state, self._decoder_step, keep_final_unfinished_states=False) outputs["best_action_sequence"] = [] outputs["debug_info"] = [] outputs["entities"] = [] if self._linking_params is not None: outputs["linking_scores"] = linking_scores outputs["feature_scores"] = feature_scores outputs["linking_features"] = linking_features if self._use_entities: outputs["linking_probabilities"] = linking_probabilities if entity_bits is not None: outputs["entity_bits"] = entity_bits # outputs['similarity_scores'] = question_entity_similarity_max_score outputs["logical_form"] = [] outputs["denotation_acc"] = [] outputs["score"] = [] outputs["parse_acc"] = [] outputs["answer_index"] = [] if metadata is not None: outputs["question_tokens"] = [] outputs["world_extractions"] = [] for i in range(batch_size): if metadata is not None: outputs["question_tokens"].append(metadata[i].get( "question_tokens", [])) if metadata is not None: outputs["world_extractions"].append(metadata[i].get( "world_extractions", {})) outputs["entities"].append(world[i].table_graph.entities) # Decoding may not have terminated with any completed logical forms, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). if i in best_final_states: best_action_indices = best_final_states[i][ 0].action_history[0] sequence_in_targets = 0 if target_action_sequences is not None: targets = target_action_sequences[i].data sequence_in_targets = self._action_history_match( best_action_indices, targets) self._action_sequence_accuracy(sequence_in_targets) action_strings = [ action_mapping[(i, action_index)] for action_index in best_action_indices ] try: self._has_logical_form(1.0) logical_form = world[i].get_logical_form( action_strings, add_var_function=False) except ParsingError: self._has_logical_form(0.0) logical_form = "Error producing logical form" denotation_accuracy = 0.0 predicted_answer_index = world[i].execute(logical_form) if metadata is not None and "answer_index" in metadata[i]: answer_index = metadata[i]["answer_index"] denotation_accuracy = self._denotation_match( predicted_answer_index, answer_index) self._denotation_accuracy(denotation_accuracy) score = math.exp( best_final_states[i][0].score[0].data.cpu().item()) outputs["answer_index"].append(predicted_answer_index) outputs["score"].append(score) outputs["parse_acc"].append(sequence_in_targets) outputs["best_action_sequence"].append(action_strings) outputs["logical_form"].append(logical_form) outputs["denotation_acc"].append(denotation_accuracy) outputs["debug_info"].append( best_final_states[i][0].debug_info[0]) # type: ignore else: outputs["parse_acc"].append(0) outputs["logical_form"].append("") outputs["denotation_acc"].append(0) outputs["score"].append(0) outputs["answer_index"].append(-1) outputs["best_action_sequence"].append([]) outputs["debug_info"].append([]) self._has_logical_form(0.0) return outputs
def forward(self, # type: ignore context_tokens: Dict[str, torch.LongTensor], tokens: Dict[str, torch.LongTensor], tags: torch.LongTensor = None, intents: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None, # pylint: disable=unused-argument **kwargs) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- Returns ------- """ if self.context_for_intent or self.context_for_tag or \ self.attention_for_intent or self.attention_for_tag: embedded_context_input = self.text_field_embedder(context_tokens) if self.dropout: embedded_context_input = self.dropout(embedded_context_input) context_mask = util.get_text_field_mask(context_tokens) encoded_context = self.encoder(embedded_context_input, context_mask) if self.dropout: encoded_context = self.dropout(encoded_context) encoded_context_summary = util.get_final_encoder_states( encoded_context, context_mask, self.encoder.is_bidirectional()) embedded_text_input = self.text_field_embedder(tokens) mask = util.get_text_field_mask(tokens) if self.dropout: embedded_text_input = self.dropout(embedded_text_input) encoded_text = self.encoder(embedded_text_input, mask) if self.dropout: encoded_text = self.dropout(encoded_text) intent_encoded_text = self.intent_encoder(encoded_text, mask) if self.intent_encoder else encoded_text if self.dropout and self.intent_encoder: intent_encoded_text = self.dropout(intent_encoded_text) is_bidirectional = self.intent_encoder.is_bidirectional() if self.intent_encoder else self.encoder.is_bidirectional() if self._feedforward is not None: encoded_summary = self._feedforward(util.get_final_encoder_states( intent_encoded_text, mask, is_bidirectional)) else: encoded_summary = util.get_final_encoder_states( intent_encoded_text, mask, is_bidirectional) tag_encoded_text = self.tag_encoder(encoded_text, mask) if self.tag_encoder else encoded_text if self.dropout and self.tag_encoder: tag_encoded_text = self.dropout(tag_encoded_text) if self.attention_for_intent or self.attention_for_tag: attention_weights = self.attention(encoded_summary, encoded_context, context_mask.float()) attended_context = util.weighted_sum(encoded_context, attention_weights) if self.context_for_intent: encoded_summary = torch.cat([encoded_summary, encoded_context_summary], dim=-1) if self.attention_for_intent: encoded_summary = torch.cat([encoded_summary, attended_context], dim=-1) if self.context_for_tag: tag_encoded_text = torch.cat([tag_encoded_text, encoded_context_summary.unsqueeze(dim=1).expand( encoded_context_summary.size(0), tag_encoded_text.size(1), encoded_context_summary.size(1))], dim=-1) if self.attention_for_tag: tag_encoded_text = torch.cat([tag_encoded_text, attended_context.unsqueeze(dim=1).expand( attended_context.size(0), tag_encoded_text.size(1), attended_context.size(1))], dim=-1) intent_logits = self.intent_projection_layer(encoded_summary) intent_probs = torch.sigmoid(intent_logits) predicted_intents = (intent_probs > 0.5).long() sequence_logits = self.tag_projection_layer(tag_encoded_text) if self.crf is not None: best_paths = self.crf.viterbi_tags(sequence_logits, mask) # Just get the tags and ignore the score. 
predicted_tags = [x for x, y in best_paths] else: predicted_tags = self.get_predicted_tags(sequence_logits) output = {"sequence_logits": sequence_logits, "mask": mask, "tags": predicted_tags, "intent_logits": intent_logits, "intent_probs": intent_probs, "intents": predicted_intents} if tags is not None: if self.crf is not None: # Add negative log-likelihood as loss log_likelihood = self.crf(sequence_logits, tags, mask) output["loss"] = -log_likelihood # Represent viterbi tags as "class probabilities" that we can # feed into the metrics class_probabilities = sequence_logits * 0. for i, instance_tags in enumerate(predicted_tags): for j, tag_id in enumerate(instance_tags): class_probabilities[i, j, tag_id] = 1 else: loss = sequence_cross_entropy_with_logits(sequence_logits, tags, mask) class_probabilities = sequence_logits output["loss"] = loss if self.calculate_span_f1: self._f1_metric(class_probabilities, tags, mask.float()) if metadata is not None: output["words"] = [x["words"] for x in metadata] if tags is not None and metadata: self.decode(output) self._dai_f1_metric(output["dialog_act"], [x["dialog_act"] for x in metadata]) rewards = self.get_rewards(output["dialog_act"], [x["dialog_act"] for x in metadata]) if self.rl else None if intents is not None: output["loss"] += torch.mean(self.intent_loss(intent_logits, intents.float())) self._intent_f1_metric(predicted_intents, intents) return output
def _get_initial_state(self, utterance: Dict[str, torch.LongTensor], worlds: List[AtisWorld], actions: List[List[ProductionRule]], linking_scores: torch.Tensor) -> GrammarBasedState: embedded_utterance = self._utterance_embedder(utterance) utterance_mask = util.get_text_field_mask(utterance).float() batch_size = embedded_utterance.size(0) num_entities = max([len(world.entities) for world in worlds]) # entity_types: tensor with shape (batch_size, num_entities) entity_types, _ = self._get_type_vector(worlds, num_entities, embedded_utterance) # (batch_size, num_utterance_tokens, embedding_dim) encoder_input = embedded_utterance # (batch_size, utterance_length, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask)) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states(encoder_outputs, utterance_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) initial_score = embedded_utterance.data.new_zeros(batch_size) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] utterance_mask_list = [utterance_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): if self._decoder_num_layers > 1: initial_rnn_state.append(RnnStatelet(final_encoder_output[i].repeat(self._decoder_num_layers, 1), memory_cell[i].repeat(self._decoder_num_layers, 1), self._first_action_embedding, self._first_attended_utterance, encoder_output_list, utterance_mask_list)) else: initial_rnn_state.append(RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_utterance, encoder_output_list, utterance_mask_list)) initial_grammar_state = [self._create_grammar_state(worlds[i], actions[i], linking_scores[i], entity_types[i]) for i in range(batch_size)] initial_state = GrammarBasedState(batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, possible_actions=actions, debug_info=None) return initial_state
def _get_initial_state( self, utterance: Dict[str, torch.LongTensor], worlds: List[SpiderWorld], schema: Dict[str, torch.LongTensor], actions: List[List[ProductionRule]]) -> GrammarBasedState: schema_text = schema['text'] embedded_schema = self._question_embedder(schema_text, num_wrapping_dims=1) schema_mask = util.get_text_field_mask(schema_text, num_wrapping_dims=1).float() embedded_utterance = self._question_embedder(utterance) utterance_mask = util.get_text_field_mask(utterance).float() batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size() num_entities = max([ len(world.db_context.knowledge_graph.entities) for world in worlds ]) num_question_tokens = utterance['tokens'].size(1) # entity_types: tensor with shape (batch_size, num_entities), where each entry is the # entity's type id. # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector( worlds, num_entities, embedded_schema.device) entity_type_embeddings = self._entity_type_encoder_embedding( entity_types) # Compute entity and question word similarity. We tried using cosine distance here, but # because this similarity is the main mechanism that the model can use to push apart logit # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger # output range than [-1, 1]. question_entity_similarity = torch.bmm( embedded_schema.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_utterance, 1, 2)) question_entity_similarity = question_entity_similarity.view( batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max( question_entity_similarity, 2) # (batch_size, num_entities, num_question_tokens, num_features) linking_features = schema['linking'] linking_scores = question_entity_similarity_max_score feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities( worlds, linking_scores.transpose(1, 2), utterance_mask, entity_type_dict) # (batch_size, num_entities, num_neighbors) or None neighbor_indices = self._get_neighbor_indices(worlds, num_entities, linking_scores.device) if self._use_neighbor_similarity_for_linking and neighbor_indices is not None: # (batch_size, num_entities, embedding_dim) encoded_table = self._entity_encoder(embedded_schema, schema_mask) # Neighbor_indices is padded with -1 since 0 is a potential neighbor index. # Thus, the absolute value needs to be taken in the index_select, and 1 needs to # be added for the mask since that method expects 0 for padding. # (batch_size, num_entities, num_neighbors, embedding_dim) embedded_neighbors = util.batched_index_select( encoded_table, torch.abs(neighbor_indices)) neighbor_mask = util.get_text_field_mask( { 'ignored': neighbor_indices + 1 }, num_wrapping_dims=1).float() # Encoder initialized to easily obtain a masked average. 
neighbor_encoder = TimeDistributed( BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True)) # (batch_size, num_entities, embedding_dim) embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask) projected_neighbor_embeddings = self._neighbor_params( embedded_neighbors.float()) # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings) else: # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.tanh(entity_type_embeddings) link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities) encoder_input = torch.cat([link_embedding, embedded_utterance], 2) # (batch_size, utterance_length, encoder_output_dim) encoder_outputs = self._dropout( self._encoder(encoder_input, utterance_mask)) max_entities_relevance = linking_probabilities.max(dim=1)[0] entities_relevance = max_entities_relevance.unsqueeze(-1).detach() graph_initial_embedding = entity_type_embeddings * entities_relevance encoder_output_dim = self._encoder.get_output_dim() if self._gnn: entities_graph_encoding = self._get_schema_graph_encoding( worlds, graph_initial_embedding) graph_link_embedding = util.weighted_sum(entities_graph_encoding, linking_probabilities) encoder_outputs = torch.cat( (encoder_outputs, graph_link_embedding), dim=-1) encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim( ) else: entities_graph_encoding = None if self._self_attend: # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding) entities_ff = self._ent2ent_ff(entities_graph_encoding) linked_actions_linking_scores = torch.bmm( entities_ff, entities_ff.transpose(1, 2)) else: linked_actions_linking_scores = [None] * batch_size # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states( encoder_outputs, utterance_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim) initial_score = embedded_utterance.data.new_zeros(batch_size) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. 
initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] utterance_mask_list = [utterance_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append( RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_utterance, encoder_output_list, utterance_mask_list)) initial_grammar_state = [ self._create_grammar_state( worlds[i], actions[i], linking_scores[i], linked_actions_linking_scores[i], entity_types[i], entities_graph_encoding[i] if entities_graph_encoding is not None else None) for i in range(batch_size) ] initial_sql_state = [ SqlState(actions[i], self.parse_sql_on_decoding) for i in range(batch_size) ] initial_state = GrammarBasedState( batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, sql_state=initial_sql_state, possible_actions=actions, action_entity_mapping=[ w.get_action_entity_mapping() for w in worlds ]) return initial_state
def _get_initial_state( self, utterance: Dict[str, torch.LongTensor], worlds: List[SpiderWorld], schema: Dict[str, torch.LongTensor], valid_actions: List[List[ProductionRule]]) -> GrammarBasedState: utterance_mask = util.get_text_field_mask(utterance).float() embedded_utterance = self.question_embedder(utterance) batch_size, _, _ = embedded_utterance.size() encoder_outputs = self._dropout( self._question_encoder(embedded_utterance, utterance_mask)) schema_text = schema['text'] input_mm_schema = self._input_mm_embedder(schema_text, num_wrapping_dims=1) output_mm_schema = self._output_mm_embedder(schema_text, num_wrapping_dims=1) batch_size, num_entities, num_entity_tokens, _ = input_mm_schema.size() schema_mask = util.get_text_field_mask(schema_text, num_wrapping_dims=1).float() # TODO # entity_types: tensor with shape (batch_size, num_entities), where each entry is the # entity's type id. # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector( worlds, num_entities, input_mm_schema.device) # (batch_size, num_entities, embedding_dim) entity_type_embeddings = self._entity_type_encoder_embedding( entity_types) # (batch_size, num_entities, embedding_dim) # An entity memory-representation is concatenated with two parts: # 1. Entity tokens embedding # 2. Entity type embedding K = torch.cat([ self._input_mm_encoder(input_mm_schema, schema_mask), entity_type_embeddings ], dim=2) V = torch.cat([ self._output_mm_encoder(output_mm_schema, schema_mask), entity_type_embeddings ], dim=2) encoder_output_dim = self._question_encoder.get_output_dim() # Encodes utterance in the context of the schema, which is stored in external memory encoder_outputs_with_context, attn_weights = self._mm_attn( encoder_outputs, K, V) attn_weights = attn_weights.transpose(1, 2) final_encoder_output = util.get_final_encoder_states( encoder_outputs_with_context, utterance_mask, self._question_encoder.is_bidirectional()) max_entities_relevance = attn_weights.max(dim=2)[0] entities_relevance = max_entities_relevance.unsqueeze(-1).detach() if self._self_attend: entities_ff = self._ent2ent_ff(entity_type_embeddings * entities_relevance) linked_actions_linking_scores = torch.bmm( entities_ff, entities_ff.transpose(1, 2)) else: linked_actions_linking_scores = [None] * batch_size memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim) initial_score = embedded_utterance.data.new_zeros(batch_size) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. 
initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] utterance_mask_list = [utterance_mask[i] for i in range(batch_size)] # RnnStatelet is using to keep track of the internal state of a decoder RNN: initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append( RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_utterance, encoder_output_list, utterance_mask_list)) initial_grammar_state = [ self._create_grammar_state(worlds[i], valid_actions[i], attn_weights[i], linked_actions_linking_scores[i], entity_types[i]) for i in range(batch_size) ] initial_sql_state = [ SqlState(valid_actions[i], self.parse_sql_on_decoding) for i in range(batch_size) ] initial_state = GrammarBasedState( batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, sql_state=initial_sql_state, possible_actions=valid_actions, action_entity_mapping=[ w.get_action_entity_mapping() for w in worlds ]) return initial_state
def forward(self, # type: ignore premise: Dict[str, torch.LongTensor], premise_tags, hypothesis: Dict[str, torch.LongTensor], hypothesis_tags, label: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None # pylint:disable=unused-argument ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- premise : Dict[str, torch.LongTensor] From a ``TextField`` hypothesis : Dict[str, torch.LongTensor] From a ``TextField`` label : torch.IntTensor, optional (default = None) From a ``LabelField`` metadata : ``List[Dict[str, Any]]``, optional, (default = None) Metadata containing the original tokenization of the premise and hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively. Returns ------- An output dictionary consisting of: label_logits : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log probabilities of the entailment label. label_probs : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the entailment label. loss : torch.FloatTensor, optional A scalar loss to be optimised. """ embedded_premise = self._text_field_embedder(premise) embedded_hypothesis = self._text_field_embedder(hypothesis) premise_mask = get_text_field_mask(premise).float() hypothesis_mask = get_text_field_mask(hypothesis).float() # apply dropout for LSTM if self.rnn_input_dropout: embedded_premise = self.rnn_input_dropout(embedded_premise) embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis) # encode premise and hypothesis encoded_premise = self._encoder(embedded_premise, premise_mask) encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask) # Shape: (batch_size, premise_length, hypothesis_length) similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis) # Shape: (batch_size, premise_length, hypothesis_length) p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask) # Shape: (batch_size, premise_length, embedding_dim) attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention) # Shape: (batch_size, hypothesis_length, premise_length) h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask) # Shape: (batch_size, hypothesis_length, embedding_dim) attended_premise = weighted_sum(encoded_premise, h2p_attention) # the "enhancement" layer premise_enhanced = torch.cat( [encoded_premise, attended_hypothesis, encoded_premise - attended_hypothesis, encoded_premise * attended_hypothesis], dim=-1 ) hypothesis_enhanced = torch.cat( [encoded_hypothesis, attended_premise, encoded_hypothesis - attended_premise, encoded_hypothesis * attended_premise], dim=-1 ) # The projection layer down to the model dimension. Dropout is not applied before # projection. projected_enhanced_premise = self._projection_feedforward(premise_enhanced) projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced) # Run the inference layer if self.rnn_input_dropout: projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise) projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis) v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask) v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask) # The pooling layer -- max and avg pooling. 
# (batch_size, model_dim) v_a_max, _ = replace_masked_values( v_ai, premise_mask.unsqueeze(-1), -1e7 ).max(dim=1) v_b_max, _ = replace_masked_values( v_bi, hypothesis_mask.unsqueeze(-1), -1e7 ).max(dim=1) v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum( premise_mask, 1, keepdim=True ) v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum( hypothesis_mask, 1, keepdim=True ) # running the parser encoded_p_parse, p_parse_mask = self._parser(premise, premise_tags) p_parse_encoder_final_state = get_final_encoder_states(encoded_p_parse, p_parse_mask) encoded_h_parse, h_parse_mask = self._parser(hypothesis, hypothesis_tags) h_parse_encoder_final_state = get_final_encoder_states(encoded_h_parse, h_parse_mask) # Now concat # (batch_size, model_dim * 2 * 4) v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max, p_parse_encoder_final_state, h_parse_encoder_final_state], dim=1) # the final MLP -- apply dropout to input, and MLP applies to output & hidden if self.dropout: v_all = self.dropout(v_all) output_hidden = self._output_feedforward(v_all) label_logits = self._output_logit(output_hidden) label_probs = torch.nn.functional.softmax(label_logits, dim=-1) output_dict = {"label_logits": label_logits, "label_probs": label_probs} if label is not None: loss = self._loss(label_logits, label.long().view(-1)) self._accuracy(label_logits, label) output_dict["loss"] = loss return output_dict
def forward( self, # type: ignore premise: Dict[str, torch.LongTensor], premise_tags: torch.LongTensor, hypothesis: Dict[str, torch.LongTensor], hypothesis_tags: torch.LongTensor, label: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- premise : Dict[str, torch.LongTensor] From a ``TextField`` premise_tags : torch.LongTensor The POS tags of the premise. hypothesis : Dict[str, torch.LongTensor] From a ``TextField``. hypothesis_tags: torch.LongTensor The POS tags of the hypothesis. label : torch.IntTensor, optional, (default = None) From a ``LabelField``. metadata : ``List[Dict[str, Any]]``, optional, (default = None) Metadata containing the original tokenization of the premise and hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively. Returns ------- An output dictionary consisting of: label_logits : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log probabilities of the entailment label. label_probs : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the entailment label. loss : torch.FloatTensor, optional A scalar loss to be optimised. """ embedded_premise = self._text_field_embedder(premise) embedded_hypothesis = self._text_field_embedder(hypothesis) premise_mask = get_text_field_mask(premise).float() hypothesis_mask = get_text_field_mask(hypothesis).float() if self._premise_encoder: embedded_premise = self._premise_encoder(embedded_premise, premise_mask) if self._hypothesis_encoder: embedded_hypothesis = self._hypothesis_encoder( embedded_hypothesis, hypothesis_mask) projected_premise = self._attend_feedforward(embedded_premise) projected_hypothesis = self._attend_feedforward(embedded_hypothesis) # Shape: (batch_size, premise_length, hypothesis_length) similarity_matrix = self._attention(projected_premise, projected_hypothesis) # Shape: (batch_size, premise_length, hypothesis_length) p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask) # Shape: (batch_size, premise_length, embedding_dim) attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention) # Shape: (batch_size, hypothesis_length, premise_length) h2p_attention = masked_softmax( similarity_matrix.transpose(1, 2).contiguous(), premise_mask) # Shape: (batch_size, hypothesis_length, embedding_dim) attended_premise = weighted_sum(embedded_premise, h2p_attention) premise_compare_input = torch.cat( [embedded_premise, attended_hypothesis], dim=-1) hypothesis_compare_input = torch.cat( [embedded_hypothesis, attended_premise], dim=-1) compared_premise = self._compare_feedforward(premise_compare_input) compared_premise = compared_premise * premise_mask.unsqueeze(-1) # Shape: (batch_size, compare_dim) compared_premise = compared_premise.sum(dim=1) compared_hypothesis = self._compare_feedforward( hypothesis_compare_input) compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze( -1) # Shape: (batch_size, compare_dim) compared_hypothesis = compared_hypothesis.sum(dim=1) # running the parser encoded_p_parse, p_parse_mask = self._parser(premise, premise_tags) p_parse_encoder_final_state = get_final_encoder_states( encoded_p_parse, p_parse_mask) encoded_h_parse, h_parse_mask = self._parser(hypothesis, hypothesis_tags) h_parse_encoder_final_state = get_final_encoder_states( encoded_h_parse, h_parse_mask) compared_premise = torch.cat( [compared_premise, p_parse_encoder_final_state], dim=-1) 
        compared_hypothesis = torch.cat([compared_hypothesis, h_parse_encoder_final_state], dim=-1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {'logits': label_logits, 'label_probs': label_probs}
        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict['loss'] = loss
        if metadata is not None:
            output_dict['premise_tokens'] = [x['premise_tokens'] for x in metadata]
            output_dict['hypothesis_tokens'] = [x['hypothesis_tokens'] for x in metadata]
        return output_dict
def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
    # Encode the sequence (no mask is passed to the encoder here), then keep only
    # the final unmasked encoder state for each instance.
    return get_final_encoder_states(self.seq2seq(inputs, None), mask)
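Taken on its own, this forward is a sequence-to-vector pooler: it runs a seq2seq encoder over embedded tokens and keeps one vector per instance. A hypothetical usage sketch follows; the wrapper class name and the LSTM encoder are assumptions for illustration, not taken from the source.

```python
import torch
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.nn.util import get_final_encoder_states


class FinalStatePooler(torch.nn.Module):
    """Hypothetical wrapper module matching the forward() above."""

    def __init__(self, seq2seq):
        super().__init__()
        self.seq2seq = seq2seq

    def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
        return get_final_encoder_states(self.seq2seq(inputs, None), mask)


encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True))
pooler = FinalStatePooler(encoder)
embeddings = torch.randn(2, 5, 8)                       # (batch_size, seq_len, embedding_dim)
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
vectors = pooler(embeddings, mask)                      # (batch_size, 16)
```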
def forward( self, # type: ignore tokens: Dict[str, torch.LongTensor], tags: torch.LongTensor = None, intents: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None, # pylint: disable=unused-argument **kwargs) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- tokens : ``Dict[str, torch.LongTensor]``, required The output of ``TextField.as_array()``, which should typically be passed directly to a ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used for the ``TokenIndexers`` when you created the ``TextField`` representing your sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``, which knows how to combine different word representations into a single vector per token in your input. tags : ``torch.LongTensor``, optional (default = ``None``) A torch tensor representing the sequence of integer gold class labels of shape ``(batch_size, num_tokens)``. metadata : ``List[Dict[str, Any]]``, optional, (default = None) metadata containg the original words in the sentence to be tagged under a 'words' key. Returns ------- An output dictionary consisting of: logits : ``torch.FloatTensor`` The logits that are the output of the ``tag_projection_layer`` mask : ``torch.LongTensor`` The text field mask for the input tokens tags : ``List[List[int]]`` The predicted tags using the Viterbi algorithm. loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. Only computed if gold label ``tags`` are provided. """ embedded_text_input = self.text_field_embedder(tokens) mask = util.get_text_field_mask(tokens) if self.dropout: embedded_text_input = self.dropout(embedded_text_input) encoded_text = self.encoder(embedded_text_input, mask) if self.dropout: encoded_text = self.dropout(encoded_text) intent_encoded_text = self.intent_encoder( encoded_text, mask) if self.intent_encoder else encoded_text if self.dropout and self.intent_encoder: intent_encoded_text = self.dropout(intent_encoded_text) is_bidirectional = self.intent_encoder.is_bidirectional( ) if self.intent_encoder else self.encoder.is_bidirectional() if self._feedforward is not None: encoded_summary = self._feedforward( util.get_final_encoder_states(intent_encoded_text, mask, is_bidirectional)) else: encoded_summary = util.get_final_encoder_states( intent_encoded_text, mask, is_bidirectional) sequence_logits = self.tag_projection_layer(encoded_text) if self.crf is not None: best_paths = self.crf.viterbi_tags(sequence_logits, mask) # Just get the tags and ignore the score. predicted_tags = [x for x, y in best_paths] else: predicted_tags = self.get_predicted_tags(sequence_logits) intent_logits = self.intent_projection_layer(encoded_summary) predicted_intents = (torch.sigmoid(intent_logits) > 0.5).long() output = { "sequence_logits": sequence_logits, "mask": mask, "tags": predicted_tags, "intent_logits": intent_logits, "intents": predicted_intents } if tags is not None: if self.crf is not None: # Add negative log-likelihood as loss log_likelihood = self.crf(sequence_logits, tags, mask) output["loss"] = -log_likelihood # Represent viterbi tags as "class probabilities" that we can # feed into the metrics class_probabilities = sequence_logits * 0. 
for i, instance_tags in enumerate(predicted_tags): for j, tag_id in enumerate(instance_tags): class_probabilities[i, j, tag_id] = 1 else: loss = sequence_cross_entropy_with_logits( sequence_logits, tags, mask) class_probabilities = sequence_logits output["loss"] = loss # self.metrics['tag_acc'](class_probabilities, tags, mask.float()) if self.calculate_span_f1: self._f1_metric(class_probabilities, tags, mask.float()) if intents is not None: output["loss"] += self.intent_loss(intent_logits, intents.float()) # bloss = self.intent_loss2(intent_logits, intents.float()) # self.metrics['int_acc'](predicted_intents, intents) self._intent_f1_metric(predicted_intents, intents) # print(list([self.vocab.get_token_from_index(intent[0], namespace=self.intent_label_namespace) # for intent in instance_intents.nonzero().tolist()] for instance_intents in predicted_intents)) # print(list([self.vocab.get_token_from_index(intent[0], namespace=self.intent_label_namespace) # for intent in instance_intents.nonzero().tolist()] for instance_intents in intents)) if metadata is not None: output["words"] = [x["words"] for x in metadata] if tags is not None and metadata: self.decode(output) # print(output) # print(metadata) self._dai_f1_metric(output["dialog_act"], [x["dialog_act"] for x in metadata]) return output
def forward(self, tokens: TextFieldTensors, targets: TextFieldTensors, target_sentiments: torch.LongTensor = None, target_sequences: Optional[torch.LongTensor] = None, metadata: torch.LongTensor = None, position_weights: Optional[torch.LongTensor] = None, position_embeddings: Optional[Dict[str, torch.LongTensor]] = None, **kwargs) -> Dict[str, torch.Tensor]: ''' The text and targets are Dictionaries as they are text fields they can be represented many different ways e.g. just words or words and chars etc therefore the dictionary represents these different ways e.g. {'words': words_tensor_ids, 'chars': char_tensor_ids} ''' # Get masks for the targets before they get manipulated targets_mask = util.get_text_field_mask(targets, num_wrapping_dims=1) # This is required if the input is of shape greater than 3 dim e.g. # character input where it is # (batch size, number targets, token length, char length) label_mask = (targets_mask.sum(dim=-1) >= 1).type(torch.int64) batch_size, number_targets = label_mask.shape batch_size_num_targets = batch_size * number_targets # Embed and encode text as a sequence embedded_context = self.context_field_embedder(tokens) embedded_context = self._variational_dropout(embedded_context) context_mask = util.get_text_field_mask(tokens) # Need to repeat the so it is of shape: # (Batch Size * Number Targets, Sequence Length, Dim) Currently: # (Batch Size, Sequence Length, Dim) batch_size, context_sequence_length, context_embed_dim = embedded_context.shape reshaped_embedding_context = embedded_context.unsqueeze(1).repeat( 1, number_targets, 1, 1) reshaped_embedding_context = reshaped_embedding_context.view( batch_size_num_targets, context_sequence_length, context_embed_dim) # Embed and encode target as a sequence. If True here the target # embeddings come from the context. if self._use_target_sequences: _, _, target_sequence_length, target_index_length = target_sequences.shape target_index_len_err = ( 'The size of the context sequence ' f'{context_sequence_length} is not the same' ' as the target index sequence ' f'{target_index_length}. 
This is to get ' 'the contextualized target through the context') assert context_sequence_length == target_index_length, target_index_len_err seq_targets_mask = target_sequences.view(batch_size_num_targets, target_sequence_length, target_index_length) reshaped_embedding_targets = torch.matmul( seq_targets_mask.type(torch.float32), reshaped_embedding_context) else: temp_targets = elmo_input_reshape(targets, batch_size, number_targets, batch_size_num_targets) if self.target_field_embedder: embedded_targets = self.target_field_embedder(temp_targets) else: embedded_targets = self.context_field_embedder(temp_targets) embedded_targets = elmo_input_reverse(embedded_targets, targets, batch_size, number_targets, batch_size_num_targets) # Size (batch size, num targets, target sequence length, embedding dim) embedded_targets = self._time_variational_dropout(embedded_targets) batch_size, number_targets, target_sequence_length, target_embed_dim = embedded_targets.shape reshaped_embedding_targets = embedded_targets.view( batch_size_num_targets, target_sequence_length, target_embed_dim) encoded_targets_mask = targets_mask.view(batch_size_num_targets, target_sequence_length) # Shape (Batch Size * Number targets), encoded dim encoded_targets_seq = self.target_encoder(reshaped_embedding_targets, encoded_targets_mask) encoded_targets_seq = self._naive_dropout(encoded_targets_seq) repeated_context_mask = context_mask.unsqueeze(1).repeat( 1, number_targets, 1) repeated_context_mask = repeated_context_mask.view( batch_size_num_targets, context_sequence_length) # Need to concat the target embeddings to the context words repeated_encoded_targets = encoded_targets_seq.unsqueeze(1).repeat( 1, context_sequence_length, 1) if self._AE: reshaped_embedding_context = torch.cat( (reshaped_embedding_context, repeated_encoded_targets), -1) # add position embeddings if required. reshaped_embedding_context = concat_position_embeddings( reshaped_embedding_context, position_embeddings, self.target_position_embedding) # Size (batch size * number targets, sequence length, embedding dim) reshaped_encoded_context_seq = self.context_encoder( reshaped_embedding_context, repeated_context_mask) reshaped_encoded_context_seq = self._variational_dropout( reshaped_encoded_context_seq) # Weighted position information encoded into the context sequence. 
if self.target_position_weight is not None: if position_weights is None: raise ValueError( 'This model requires `position_weights` to ' 'better encode the target but none were given') position_output = self.target_position_weight( reshaped_encoded_context_seq, position_weights, repeated_context_mask) reshaped_encoded_context_seq, weighted_position_weights = position_output # Whether to concat the aspect embeddings on to the contextualised word # representations attention_encoded_context_seq = reshaped_encoded_context_seq if self._AttentionAE: attention_encoded_context_seq = torch.cat( (attention_encoded_context_seq, repeated_encoded_targets), -1) _, _, attention_encoded_dim = attention_encoded_context_seq.shape # Projection layer before the attention layer attention_encoded_context_seq = self.attention_project_layer( attention_encoded_context_seq) attention_encoded_context_seq = self._context_attention_activation_function( attention_encoded_context_seq) attention_encoded_context_seq = self._variational_dropout( attention_encoded_context_seq) # Attention over the context sequence attention_vector = self.attention_vector.unsqueeze(0).repeat( batch_size_num_targets, 1) attention_weights = self.context_attention_layer( attention_vector, attention_encoded_context_seq, repeated_context_mask) expanded_attention_weights = attention_weights.unsqueeze(-1) weighted_encoded_context_seq = reshaped_encoded_context_seq * expanded_attention_weights weighted_encoded_context_vec = weighted_encoded_context_seq.sum(dim=1) # Add the last hidden state of the context vector, with the attention vector context_final_states = util.get_final_encoder_states( reshaped_encoded_context_seq, repeated_context_mask, bidirectional=self.context_encoder_bidirectional) context_final_states = self.final_hidden_state_projection_layer( context_final_states) weighted_encoded_context_vec = self.final_attention_projection_layer( weighted_encoded_context_vec) feature_vector = context_final_states + weighted_encoded_context_vec feature_vector = self._naive_dropout(feature_vector) # Reshape the vector into (Batch Size, Number Targets, number labels) _, feature_dim = feature_vector.shape feature_target_seq = feature_vector.view(batch_size, number_targets, feature_dim) if self.inter_target_encoding is not None: feature_target_seq = self.inter_target_encoding( feature_target_seq, label_mask) feature_target_seq = self._variational_dropout(feature_target_seq) if self.feedforward is not None: feature_target_seq = self.feedforward(feature_target_seq) logits = self.label_projection(feature_target_seq) masked_class_probabilities = util.masked_softmax( logits, label_mask.unsqueeze(-1)) output_dict = { "class_probabilities": masked_class_probabilities, "targets_mask": label_mask } # Convert it to bool tensor. 
label_mask = label_mask == 1 if target_sentiments is not None: # gets the loss per target instance due to the average=`token` if self.loss_weights is not None: loss = util.sequence_cross_entropy_with_logits( logits, target_sentiments, label_mask, average='token', alpha=self.loss_weights) else: loss = util.sequence_cross_entropy_with_logits( logits, target_sentiments, label_mask, average='token') for metrics in [self.metrics, self.f1_metrics]: for metric in metrics.values(): metric(logits, target_sentiments, label_mask) output_dict["loss"] = loss if metadata is not None: words = [] texts = [] targets = [] target_words = [] for batch_index, sample in enumerate(metadata): words.append(sample['text words']) texts.append(sample['text']) targets.append(sample['targets']) target_words.append(sample['target words']) output_dict["words"] = words output_dict["text"] = texts word_attention_weights = attention_weights.view( batch_size, number_targets, context_sequence_length) output_dict["word_attention"] = word_attention_weights output_dict["targets"] = targets output_dict["target words"] = target_words output_dict["context_mask"] = context_mask return output_dict
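# Standalone sketch of the reshape used above: the encoded context is duplicated once per
# target and flattened to (batch_size * num_targets, seq_len, dim) so that every
# (sentence, target) pair can be processed as an independent row. Names are illustrative.
import torch

batch_size, num_targets, seq_len, dim = 2, 3, 5, 4
embedded_context = torch.randn(batch_size, seq_len, dim)

repeated = embedded_context.unsqueeze(1).repeat(1, num_targets, 1, 1)
flattened = repeated.view(batch_size * num_targets, seq_len, dim)

# Every target of instance 0 sees the same context row.
assert torch.equal(flattened[0], flattened[num_targets - 1])
assert torch.equal(flattened[0], embedded_context[0])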
def get_BILOU_features(self, token_indices, sent_len, span_len): #print(token_indices) span_level_token_indices = {} for ky,val in list(token_indices.items()): if ky == 'elmo': continue val = val.unsqueeze(1) span_level_token_indices[ky] = torch.cat([val[:, :, i:i + span_len + 1] for i in range(sent_len - 1 - span_len)], 1) ''' print(span_level_token_indices) t = span_level_token_indices["tokens"][0].cpu().numpy().tolist() import json with open("./data/dict.json", "r", encoding="utf-8") as df: dic = json.load(df) a = [[dic[str(word)] for word in span] for span in t] print(a) ''' ori_seq = [self.id2words(each.cpu().numpy().tolist()) for each in span_level_token_indices["tokens"]] att_logits = torch.Tensor([self.span_score(seq) for seq in ori_seq]) spans_embedded = self.softdict_text_field_embedder(span_level_token_indices, num_wrapping_dims=1) spans_mask = util.get_text_field_mask(span_level_token_indices, num_wrapping_dims=1) ''' for param in self.softdict_text_field_embedder.parameters(): #np.save("embed.npy", param.detach().numpy()) print(param.size()), exit(0) ''' #print(spans_mask) #print(spans_mask.size()) if util.get_device_of(spans_mask) >= 0: att_mask = torch.ge(torch.mean(spans_mask.float(), -1), (torch.ones(spans_mask.size(0), spans_mask.size(1)) - 2e-6).cuda(util.get_device_of(spans_mask))) else: att_mask = torch.ge(torch.mean(spans_mask.float(), -1), (torch.ones(spans_mask.size(0), spans_mask.size(1)) - 2e-6)) dim_2_pad = self.ALLOWED_SPANLEN - spans_embedded.size(2) p2d = (0,0,0, dim_2_pad) # now shape (batch_size, num_span, max_span_width, dim) spans_embedded = F.pad(spans_embedded, p2d, "constant", 0.) spans_mask = F.pad(spans_mask, (0, dim_2_pad), "constant", 0.) #print("embed:") #print(spans_embedded) ''' tt = {"tokens":torch.LongTensor([ 50, 1138, 84, 7, 645, 1135, 7386, 1123, 4979, 952, 2, 381, 173, 128, 8932, 9, 95, 1098, 16550, 524, 3897, 5190, 8242, 22, 2112, 6912, 1408, 814, 9853, 128])} t = self.softdict_text_field_embedder(tt) print(t),exit(0) ''' batch_size = spans_mask.size(0) num_spans = spans_mask.size(1) if util.get_device_of(spans_mask) >= 0: length_vec = torch.autograd.Variable(torch.LongTensor(range(self.ALLOWED_SPANLEN))).cuda(util.get_device_of(spans_mask)) else: length_vec = torch.autograd.Variable(torch.LongTensor(range(self.ALLOWED_SPANLEN))) length_vec = self.length_embedder(length_vec).unsqueeze(0).unsqueeze(0).expand(batch_size, num_spans, -1,-1) spans_encoded = self.encoder(spans_embedded, spans_mask) #BiLSTM #spans_encoded = torch.cat((spans_encoded, length_vec), 3).contiguous() #print(spans_encoded) spans_encoded = spans_encoded.reshape([batch_size * num_spans, self.ALLOWED_SPANLEN, -1]) ''' [batch_size * num_spans, self.ALLOWED_SPANLEN] shaped mask may occur whole zero like tensor([[1, 0, 0, ..., 0, 0, 0], [1, 0, 0, ..., 0, 0, 0], [1, 0, 0, ..., 0, 0, 0], ..., [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0]], device='cuda:0'), and use 'get_final_encoder_states()' will lead to the error RuntimeError: cuda runtime error (59) : device-side assert triggered at /pytorch/aten/src/THC/THCReduceAll.cuh:327 change to the tensor masked on span_sequence level is still assign as the unmasked tensor: [1 * 1:0*other] remain one 1 ''' spans_mask = spans_mask.reshape([batch_size * num_spans, self.ALLOWED_SPANLEN]) if util.get_device_of(spans_mask) >= 0: tmp = torch.zeros(self.ALLOWED_SPANLEN, dtype=torch.int64).cuda(util.get_device_of(spans_mask)) else: tmp = torch.zeros(self.ALLOWED_SPANLEN, dtype=torch.int64) tmp[0] = 1 tmp = 
tmp.expand([batch_size * num_spans, self.ALLOWED_SPANLEN]) new_spans_mask = spans_mask | tmp #print(new_spans_mask) last_state = get_final_encoder_states(spans_encoded, new_spans_mask) attention_coe, attention_out, attention_logits = self.attention(lstm_output=spans_encoded, final_state=last_state, mask_cuda=util.get_device_of(spans_mask)) #print(attention_logits),exit(0) attention_logits = attention_logits.reshape([batch_size, num_spans, -1])[:,:,1] # here 0 stand for true / 1 #print(attention_logits) #print(attention_logits.size()) attention_logits = attention_logits * att_mask.float() #print(attention_logits), exit(0) attention_out = attention_out.reshape([batch_size, num_spans, -1]) #print(attention_coe.size()) #attention_coe = attention_coe * spans_mask.float() attention_coe = attention_coe.reshape([batch_size, num_spans, -1]) attention_coe = attention_coe.unsqueeze(-1) #print(attention_coe) #attention_coe = torch.gt(attention_coe, 0.1).float() attention_coe = attention_coe.expand([batch_size, num_spans, attention_coe.size(2), 1]) #print(attention_coe.size()), exit(0) attention_coe = torch.cat([attention_coe, attention_coe.new_zeros(batch_size, 1, attention_coe.size(2), attention_coe.size(3))], dim=1) attention_out = torch.cat([attention_out, attention_out.new_zeros(batch_size, 1, attention_out.size(-1))], dim=1) attention_logits = torch.cat([attention_logits, attention_logits.new_zeros(batch_size, 1)], dim=1) #print(attention_logits.size(), att_logits.size()),exit(0) att_logits = torch.cat([att_logits, att_logits.new_zeros(batch_size, 1)], dim=1) return attention_coe[:,:,:span_len+1,:].detach(), att_logits
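# Standalone sketch of the mask repair applied above: some flattened span rows can be entirely
# padded (all-zero mask), and taking "the last unmasked state" from such a row would index
# position -1 and trigger a device-side assert. OR-ing in a tensor that marks position 0 as
# always valid guarantees every row keeps at least one unmasked timestep.
import torch

spans_mask = torch.tensor([[1, 1, 0, 0],
                           [1, 0, 0, 0],
                           [0, 0, 0, 0]])   # the last row is fully padded

tmp = torch.zeros(spans_mask.size(1), dtype=torch.int64)
tmp[0] = 1
tmp = tmp.expand_as(spans_mask)

new_spans_mask = spans_mask | tmp
assert (new_spans_mask.sum(dim=-1) >= 1).all()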
def forward( self, # type: ignore tokens: Dict[str, torch.LongTensor], label: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None # pylint:disable=unused-argument ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- tokens : Dict[str, torch.LongTensor] From a ``TextField`` label : torch.IntTensor, optional (default = None) From a ``LabelField`` metadata : ``List[Dict[str, Any]]``, optional, (default = None) Metadata to persist Returns ------- An output dictionary consisting of: label_logits : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing unnormalized log probabilities of the label. label_probs : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the label. loss : torch.FloatTensor, optional A scalar loss to be optimised. """ embedded_text = self._text_field_embedder(tokens) mask = get_text_field_mask(tokens).float() encoder_output = self._encoder(embedded_text, mask) encoded_repr = [] for aggregation in self._aggregations: if aggregation == "meanpool": broadcast_mask = mask.unsqueeze(-1).float() context_vectors = encoder_output * broadcast_mask encoded_text = masked_mean(context_vectors, broadcast_mask, dim=1, keepdim=False) elif aggregation == 'maxpool': broadcast_mask = mask.unsqueeze(-1).float() context_vectors = encoder_output * broadcast_mask encoded_text = masked_max(context_vectors, broadcast_mask, dim=1) elif aggregation == 'final_state': is_bi = self._encoder.is_bidirectional() encoded_text = get_final_encoder_states( encoder_output, mask, is_bi) encoded_repr.append(encoded_text) encoded_repr = torch.cat(encoded_repr, 1) if self.dropout: encoded_repr = self.dropout(encoded_repr) output_hidden = self._output_feedforward(encoded_repr) label_logits = self._classification_layer(output_hidden) label_probs = torch.nn.functional.softmax(label_logits, dim=-1) output_dict = { "label_logits": label_logits, "label_probs": label_probs } if label is not None: loss = self._loss(label_logits, label.long().view(-1)) self._accuracy(label_logits, label) output_dict["loss"] = loss return output_dict
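# Plain-PyTorch sketch of the three aggregations combined above (mean pool, max pool, final
# state), without the AllenNLP helpers. `mask` is 1 for real tokens and 0 for padding; the
# function name is illustrative.
import torch

def aggregate(encoder_output: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # encoder_output: (batch, seq_len, dim); mask: (batch, seq_len)
    broadcast = mask.unsqueeze(-1).float()
    mean_pool = (encoder_output * broadcast).sum(1) / broadcast.sum(1).clamp(min=1e-13)
    max_pool = encoder_output.masked_fill(broadcast == 0, float("-inf")).max(dim=1).values
    # "final state" here is simply the last real timestep (unidirectional case; a
    # bidirectional encoder would splice in the backward direction's first state).
    last_index = mask.sum(1).long().clamp(min=1) - 1
    final_state = encoder_output[torch.arange(mask.size(0)), last_index]
    return torch.cat([mean_pool, max_pool, final_state], dim=-1)

out = aggregate(torch.randn(2, 5, 3), torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]))
assert out.shape == (2, 9)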
def bidaf_reprs(self, question, contexts): # Shape: (B, ques_len, D), (B, num_contexts, context_len, D) (embedded_question_tensor, embedded_passages_tensor, question_mask_tensor, passages_mask_tensor) = self.embed_ques_passages(question, contexts) batch_size = embedded_question_tensor.size()[0] num_contexts = embedded_passages_tensor.size()[1] embedded_questions = [] questions_mask = [] embedded_contexts = [] contexts_mask = [] for i in range(0, batch_size): embedded_questions.append(embedded_question_tensor[i]) embedded_contexts.append(embedded_passages_tensor[i]) questions_mask.append(question_mask_tensor[i]) contexts_mask.append(passages_mask_tensor[i]) # Shape: (B, ques_len, D) encoded_ques_tensor = self.encode_question( embedded_question=embedded_question_tensor, question_lstm_mask=question_mask_tensor) # Shape: (B, D) ques_encoded_final_state = allenutil.get_final_encoder_states( encoded_ques_tensor, question_mask_tensor, self.bidaf_encoder_bidirectional) # List of tensors: (question_len, D) encoded_questions = [] # List of tensors: (num_contexts, context_len, D) encoded_contexts = [] for i in range(0, batch_size): # Shape: (1, ques_len, D) # encoded_ques = self.encode_question(embedded_question=embedded_questions[i].unsqueeze(0), # question_lstm_mask=questions_mask[i].unsqueeze(0)) encoded_questions.append(encoded_ques_tensor[i]) # Shape: (num_contexts, context_len, D) encoded_context = self.encode_context( embedded_passage=embedded_contexts[i], passage_lstm_mask=contexts_mask[i]) encoded_contexts.append(encoded_context) modeled_contexts = [] for i in range(0, batch_size): # Shape: (question_len, D) encoded_ques = encoded_questions[i] ques_mask = questions_mask[i] encoded_ques_ex = encoded_ques.unsqueeze(0).expand( num_contexts, *encoded_ques.size()) ques_mask_ex = ques_mask.unsqueeze(0).expand( num_contexts, *ques_mask.size()) output_dict = self.forward_bidaf( encoded_question=encoded_ques_ex, encoded_passage=encoded_contexts[i], question_lstm_mask=ques_mask_ex, passage_lstm_mask=contexts_mask[i]) # Shape: (num_contexts, context_len, D) modeled_context = output_dict['modeled_passage'] modeled_contexts.append(modeled_context) return (ques_encoded_final_state, encoded_ques_tensor, question_mask_tensor, embedded_questions, questions_mask, embedded_contexts, contexts_mask, encoded_questions, encoded_contexts, modeled_contexts)
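# Standalone sketch of the expansion above: a single encoded question of shape
# (ques_len, dim) is broadcast across all of an instance's contexts so that BiDAF can be run
# once per context. Dimensions are illustrative.
import torch

num_contexts, ques_len, dim = 3, 6, 4
encoded_ques = torch.randn(ques_len, dim)
encoded_ques_ex = encoded_ques.unsqueeze(0).expand(num_contexts, *encoded_ques.size())
assert encoded_ques_ex.shape == (num_contexts, ques_len, dim)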
def forward( self, # type: ignore source_tokens: Dict[str, torch.LongTensor], target_tokens: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Decoder logic for producing the entire target sequence. Parameters ---------- source_tokens : Dict[str, torch.LongTensor] The output of ``TextField.as_array()`` applied on the source ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. target_tokens : Dict[str, torch.LongTensor], optional (default = None) Output of ``Textfield.as_array()`` applied on target ``TextField``. We assume that the target tokens are also represented as a ``TextField``. """ # (batch_size, input_sequence_length, encoder_output_dim) embedded_input = self._source_embedder(source_tokens) batch_size, _, _ = embedded_input.size() source_mask = util.get_text_field_mask(source_tokens) encoder_outputs = self._encoder(embedded_input, source_mask) # (batch_size, encoder_output_dim) final_encoder_output = util.get_final_encoder_states( encoder_outputs, source_mask, self._encoder.is_bidirectional()) if target_tokens: targets = target_tokens["tokens"] target_sequence_length = targets.size()[1] # The last input from the target is either padding or the end symbol. Either way, we # don't have to process it. num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps decoder_hidden = final_encoder_output decoder_context = encoder_outputs.new_zeros(batch_size, self._decoder_output_dim) last_predictions = None step_logits = [] step_probabilities = [] step_predictions = [] for timestep in range(num_decoding_steps): use_gold_targets = False # Use gold tokens at test time when provided and at a rate of 1 - # _scheduled_sampling_ratio during training. if self.training: if torch.rand(1).item() >= self._scheduled_sampling_ratio: use_gold_targets = True elif target_tokens: use_gold_targets = True if use_gold_targets: input_choices = targets[:, timestep] else: if timestep == 0: # For the first timestep, when we do not have targets, we input start symbols. 
# (batch_size,) input_choices = source_mask.new_full( (batch_size, ), fill_value=self._start_index) else: input_choices = last_predictions decoder_input = self._prepare_decode_step_input( input_choices, decoder_hidden, encoder_outputs, source_mask) decoder_hidden, decoder_context = self._decoder_cell( decoder_input, (decoder_hidden, decoder_context)) # (batch_size, num_classes) output_projections = self._output_projection_layer(decoder_hidden) # list of (batch_size, 1, num_classes) step_logits.append(output_projections.unsqueeze(1)) class_probabilities = F.softmax(output_projections, dim=-1) _, predicted_classes = torch.max(class_probabilities, 1) step_probabilities.append(class_probabilities.unsqueeze(1)) last_predictions = predicted_classes # (batch_size, 1) step_predictions.append(last_predictions.unsqueeze(1)) # step_logits is a list containing tensors of shape (batch_size, 1, num_classes) # This is (batch_size, num_decoding_steps, num_classes) logits = torch.cat(step_logits, 1) class_probabilities = torch.cat(step_probabilities, 1) all_predictions = torch.cat(step_predictions, 1) output_dict = { "logits": logits, "class_probabilities": class_probabilities, "predictions": all_predictions } if target_tokens: target_mask = util.get_text_field_mask(target_tokens) loss = self._get_loss(logits, targets, target_mask) output_dict["loss"] = loss # TODO: Define metrics relevant_targets = targets[:, 1:].contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask[:, 1:].contiguous() self.__sequence_accuracy(all_predictions.unsqueeze(1), relevant_targets, relevant_mask) return output_dict
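# Standalone sketch of the scheduled-sampling choice made in the decoding loop above: during
# training, gold tokens are fed with probability (1 - scheduled_sampling_ratio), otherwise the
# model's own previous prediction is fed back in; at the first step without gold tokens, start
# symbols are used. Names and the helper function are illustrative.
import torch

def choose_decoder_input(timestep: int,
                         training: bool,
                         scheduled_sampling_ratio: float,
                         gold_targets: torch.Tensor,          # (batch, target_len) or None
                         last_predictions: torch.Tensor,      # (batch,) or None
                         start_index: int,
                         batch_size: int) -> torch.Tensor:
    use_gold = False
    if training:
        if torch.rand(1).item() >= scheduled_sampling_ratio:
            use_gold = True
    elif gold_targets is not None:
        use_gold = True
    if use_gold:
        return gold_targets[:, timestep]
    if timestep == 0:
        return torch.full((batch_size,), start_index, dtype=torch.long)
    return last_predictions

# At inference without gold targets, the first step is fed start symbols.
first_input = choose_decoder_input(0, training=False, scheduled_sampling_ratio=0.0,
                                   gold_targets=None, last_predictions=None,
                                   start_index=1, batch_size=2)
assert first_input.tolist() == [1, 1]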
def forward(self, spans_tensor: torch.FloatTensor, spans_mask: torch.FloatTensor, question_tensor: torch.FloatTensor, question_mask: torch.FloatTensor, evd_chain_labels: torch.FloatTensor, self_att_layer: Seq2SeqEncoder, sent_encoder: Seq2SeqEncoder, get_all_beam: bool = False): print("spans_tensor", spans_tensor.shape) batch_size, num_spans, max_batch_span_width = spans_mask.size() # shape: (batch_size, num_spans, max_batch_span_width, embedding_dim) spans_tensor = spans_tensor.view(batch_size, num_spans, max_batch_span_width, spans_tensor.size(2)) # shape: (batch_size, num_spans) max_pooled_span_mask = (torch.sum(spans_mask, dim=-1) >= 1).float() att_score = None # extract the final hidden states as the question vector # Shape: (batch_size, embedding_dim) question_emb = util.get_final_encoder_states(question_tensor, question_mask, True) # decode the most likely evidence path # shape (all_predictions): (batch_size, K, num_decoding_steps) # shape (all_logprobs): (batch_size, K, num_decoding_steps) # shape (seq_logprobs): (batch_size, K) # shape (final_hidden): (batch_size, K, decoder_output_dim) all_predictions, all_logprobs, seq_logprobs, final_hidden = self.evd_decoder( spans_tensor, spans_mask, question_emb, aux_input=None, #question_emb,#None transition_mask=None, labels=evd_chain_labels) if self._pass_label: all_predictions = evd_chain_labels.long().unsqueeze(1) all_logprobs = torch.zeros_like(all_predictions).float() #print("batch:", batch_size) #print("predict num:", torch.sum((all_predictions > 0).float(), dim=1)) print("all prediction:", all_predictions) # The selection order of each sentence. Set to -1 if not being chosen # shape: (batch_size, K, num_spans) _, beam, num_steps = all_predictions.size() orders = spans_tensor.new_ones((batch_size, beam, 1 + num_spans)) * -1 indices = util.get_range_vector(num_steps, util.get_device_of(spans_tensor)).\ float().\ unsqueeze(0).\ unsqueeze(0).\ expand(batch_size, beam, num_steps) orders.scatter_(2, all_predictions, indices) orders = orders[:, :, 1:] # For beamsearch, get the top one. For other helpers, just like squeeze if not get_all_beam: all_predictions = all_predictions[:, 0, :] all_logprobs = all_logprobs[:, 0, :] seq_logprobs = seq_logprobs[:, 0] final_hidden = final_hidden[:, 0, :] # build the gate. The dim is set to 1 + num_spans to account for the end embedding # shape: (batch_size, 1+num_spans) or (batch_size, K, 1+num_spans) if not get_all_beam: gate = spans_tensor.new_zeros((batch_size, 1 + num_spans)) else: gate = spans_tensor.new_zeros((batch_size, beam, 1 + num_spans)) gate.scatter_(-1, all_predictions, 1.) # remove the column for end embedding # shape: (batch_size, num_spans) or (batch_size, K, num_spans) gate = gate[..., 1:] #print("gate:", gate) #print("real num:", torch.sum(gate, dim=1)) #print("seq probs:", torch.exp(seq_logprobs)) # shape: (batch_size * num_spans, 1) or (batch_size * K * num_spans, 1) if not get_all_beam: gate = gate.reshape(batch_size * num_spans, 1) else: gate = gate.reshape(batch_size * beam * num_spans, 1) # The probability of each selected sentence being selected. If not selected, set to 0. 
# shape: (batch_size * num_spans, 1) or (batch_size * K * num_spans, 1) if not get_all_beam: gate_probs = spans_tensor.new_zeros((batch_size, 1 + num_spans)) else: gate_probs = spans_tensor.new_zeros( (batch_size, beam, 1 + num_spans)) gate_probs.scatter_(-1, all_predictions, all_logprobs.exp()) gate_probs = gate_probs[..., 1:] if not get_all_beam: gate_probs = gate_probs.reshape(batch_size * num_spans, 1) else: gate_probs = gate_probs.reshape(batch_size * beam * num_spans, 1) return all_predictions, all_logprobs, seq_logprobs, gate, gate_probs, max_pooled_span_mask, att_score, orders
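# Standalone sketch of the gate construction above: the decoder emits sentence indices (index 0
# is reserved for the "end" embedding), and a 0/1 gate over sentences is built by scattering
# ones at the predicted positions and then dropping the reserved column. Values are illustrative.
import torch

batch_size, num_spans = 2, 5
# predicted indices into [0, num_spans]; 0 means "stop"
all_predictions = torch.tensor([[3, 1, 0], [2, 0, 0]])

gate = torch.zeros(batch_size, 1 + num_spans)
gate.scatter_(-1, all_predictions, 1.0)
gate = gate[..., 1:]                      # drop the reserved end column

assert gate.tolist() == [[1.0, 0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0]]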
def __init__( self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, text_encoder: Seq2SeqEncoder, classifier_feedforward: FeedForward, verbose_metrics: bool = False, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None, loss: Optional[dict] = None, ) -> None: super(MultilabelTextClassifier, self).__init__(vocab, regularizer) self.log = logging.getLogger(__name__) self.text_field_embedder = text_field_embedder self.num_classes = self.vocab.get_vocab_size("labels") self.log.warning(f'num_classes: {self.num_classes}') self.text_encoder = text_encoder self.classifier_feedforward = classifier_feedforward self.log.warning( f'output_dim: {self.classifier_feedforward.get_output_dim()}') self.prediction_layer = torch.nn.Linear( self.classifier_feedforward.get_output_dim(), self.num_classes) self.pool = lambda text, mask: util.get_final_encoder_states( text, mask, bidirectional=True) self.label_accuracy = CategoricalAccuracy() self.label_f1_metrics = OrderedDict() self.verbose_metrics = verbose_metrics for i in range(self.num_classes): label = vocab.get_token_from_index(index=i, namespace="labels") self.log.warning(f'label {i}: {label}') self.label_f1_metrics[label] = F1Measure(positive_label=i) self.micro_f1 = MultiLabelF1Measure() self.label_f1 = OrderedDict() for i in range(self.num_classes): label = vocab.get_token_from_index(index=i, namespace="labels") self.label_f1[label] = MultiLabelF1Measure() if loss is not None: alpha = loss.get('alpha') gamma = loss.get('gamma') weight = loss.get('weight') if alpha is not None: alpha = float(alpha) if gamma is not None: gamma = float(gamma) if weight is not None: weight = torch.tensor(ast.literal_eval(weight)) if loss is None or loss.get('type') == 'CrossEntropyLoss': self.loss = torch.nn.CrossEntropyLoss() elif loss.get('type') == 'BinaryFocalLoss': self.loss = BinaryFocalLoss(alpha=alpha, gamma=gamma) elif loss.get('type') == 'FocalLoss': self.loss = FocalLoss(alpha=alpha, gamma=gamma) elif loss.get('type') == 'MultiLabelMarginLoss': self.loss = torch.nn.MultiLabelMarginLoss() elif loss.get('type') == 'MultiLabelSoftMarginLoss': self.loss = torch.nn.MultiLabelSoftMarginLoss(weight) else: raise ValueError(f'Unexpected loss "{loss}"') initializer(self)
def forward(self, tags, history, next_sym, source_tokens, his_symptoms, target_tokens, **args): bs = len(tags) # self.flatten_parameters() # get the embeddings of the dialogue history embeddings = self._source_embedder(history) mask = get_text_field_mask(history, num_wrapping_dims=1) # num_wrapping_dims adds an extra dimension sz = list(embeddings.size()) embeddings = embeddings.view(sz[0] * sz[1], sz[2], sz[3]) mask = mask.view(sz[0] * sz[1], sz[2]) # get the hidden state of each utterance: bs * sen_num * hidden utter_hidden = self._vecoder(embeddings, mask) utter_hidden = utter_hidden.view(sz[0], sz[1], -1) # bs * sen_num * hidden dialog_mask = get_text_field_mask(history) dialog_hidden = self._sen_encoder(utter_hidden, dialog_mask) # HRED-style hierarchical encoding # print("dialog_hidden: ",dialog_hidden.size()) # initialize every graph node symp_state = torch.zeros( bs, self.symp_size, self.outfeature).cuda() # bs * symp_size * hidden symp_state += self.symp_state # every node starts from the same initial embedding; is that a problem? # connect sentence nodes to each other; what if we did not use cuda here? sym_mat = torch.zeros(bs, self.symp_size, self.symp_size) for i in range(bs): dic = {} for j in range(len(tags[i])): # this could be changed: here we connect edges to every related preceding sentence symp_state[i][self.topic + j] = utter_hidden[i][j] dic[j] = set(list(tags[i][j])) for k in range(j): for aa in dic[j]: if aa in dic[k] and aa != -1: # sym_mat[i][self.topic+j][self.topic+k] += 1 sym_mat[i][self.topic + k][self.topic + j] += 1 last_h = self.attn_one(symp_state, sym_mat) sym_mat = torch.zeros(bs, self.symp_size, self.symp_size) for i in range(bs): for j in range(len(tags[i])): for tt in tags[i][j]: if tt != -1: sym_mat[i][self.topic + j][tt] += 1 # last_h = self.attn_two(last_h, sym_mat) # # connect topic nodes to each other sym_mat = torch.zeros(bs, self.symp_size, self.symp_size) # add edges # for symp_i in his_symptoms: # for symp_j in his_symptoms: # self.evovl_mat[symp_i][symp_j] = 1 # temp_mat = (torch.nn.functional.relu(self.symp_mat) + self.evovl_mat).cpu() # with open('visulize_graph.txt', 'a') as fout: # fout.write('evovl_mat is: \n') # for i in self.evovl_mat.detach().cpu().numpy(): # fout.write(str(i) + '\n') # fout.write('temp_mat is: \n') # for i in temp_mat.detach().cpu().numpy(): # fout.write(str(i) + '\n') # print('[info] temp_mat is:{}'.format(temp_mat)) sym_mat[:, :self.topic, :self.topic] += self.symp_mat last_h = self.attn_three(last_h, sym_mat) # last_h = self.attn_three(last_h, sym_mat) topic_pre = torch.sum(self.predict_layer * last_h, dim=-1) + self.predict_bias topic_probs = torch.sigmoid(topic_pre) topics_weight = torch.ones_like(topic_probs) + 5 * next_sym.float() topic_loss = torch.nn.functional.binary_cross_entropy( topic_probs, next_sym.float(), weight=topics_weight) ans = (topic_probs > 0.5).long() # his_symptoms bs * sym_size? # his_mask = torch.where(his_symptoms > 0, torch.full_like(his_symptoms, 0), torch.full_like(his_symptoms,1)).long() # mask out the sentence nodes # his_mask his_sentence_mask = torch.zeros(bs, self.sen_num).long() total_mask = torch.cat( (torch.ones(bs, self.topic).long(), his_sentence_mask), -1) if self.training and torch.rand( 1).item() < self._scheduled_sampling_ratio: aa = next_sym.long() else: aa = ans # total_mask = torch.ones(bs, self.symp_size).cuda() # total_mask = total_mask.long() & his_mask.long() topic_embedding = aa.float().matmul(self.symp_state) topic_hidden = last_h # compute topic f1, acc and rec pre_total = torch.sum(ans).item() true_total = torch.sum(next_sym).item() pre_right = torch.sum((ans == next_sym).long() * next_sym).item() # print(pre_total,pre_right) self.topic_acc(pre_right, pre_total) self.topic_rec(pre_right, true_total) acc = self.topic_acc.get_metric(False) rec = self.topic_rec.get_metric(False) f1 = 0.
if acc + rec > 0: f1 = acc * rec * 2 / (acc + rec) self.topic_f1(f1) # Encoding source_tokens embedded_input = self._source_embedder(source_tokens) source_mask = util.get_text_field_mask(source_tokens) encoder_outputs = self._encoder(embedded_input, source_mask) final_encoder_output = util.get_final_encoder_states( encoder_outputs, source_mask, self._encoder.is_bidirectional()) # if self.training: # ff = next_sym.float().matmul(symp_state) # else: # ff = topics_weight.matmul(symp_state) # print('[info]final_encoder_output is:{}, ff:{}'.format(final_encoder_output.size(), ff.size())) state = { "source_mask": source_mask, "encoder_outputs": encoder_outputs, # bs * seq_len * dim "decoder_hidden": dialog_hidden, # bs * dim, output of the HRED encoder # "decoder_hidden": torch.cat((topic_embedding, dialog_hidden), -1), "decoder_context": encoder_outputs.new_zeros(bs, self._decoder_output_dim), "topic_embedding": topic_embedding } # state[''] = topic_embedding # run the decoder once output_dict = self._forward_loop(state, topic_hidden, total_mask.cuda(), target_tokens) best_predictions = output_dict["predictions"] # output something references, hypothesis = [], [] for i in range(bs): cut_hypo = best_predictions[i][:] if self._end_index in list(best_predictions[i]): cut_hypo = best_predictions[i][:list(best_predictions[i]). index(self._end_index)] hypothesis.append([ self.vocab.get_token_from_index(idx.item()) for idx in cut_hypo ]) flag = 1 for i in range(bs): cut_ref = target_tokens['tokens'][i][1:] if self._end_index in list(target_tokens['tokens'][i]): cut_ref = target_tokens['tokens'][i][ 1:list(target_tokens['tokens'][i]).index(self._end_index)] references.append([ self.vocab.get_token_from_index(idx.item()) for idx in cut_ref ]) if random.random() <= 0.001 and flag == 1: #not self.training and flag = 0 for jj in range(i): print('___hypo___', ''.join(hypothesis[jj]), end=' ## ') print(''.join(references[jj])) print("") self.bleu_aver(references, hypothesis) self.bleu1(references, hypothesis) self.bleu2(references, hypothesis) self.bleu4(references, hypothesis) self.kd_metric(references, hypothesis) self.dink1(hypothesis) self.dink2(hypothesis) if self.training: output_dict['loss'] = output_dict['loss'] + 8 * topic_loss else: output_dict['loss'] = topic_loss return output_dict
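# Standalone sketch of the topic-prediction loss used above: multi-label sigmoid outputs with a
# per-element weight that upweights the positive (gold next-symptom) entries, plus the 0.5
# threshold used to obtain hard predictions. Shapes and the 5x upweighting are illustrative.
import torch
import torch.nn.functional as F

topic_probs = torch.sigmoid(torch.randn(2, 6))            # (batch, num_topics)
next_sym = torch.randint(0, 2, (2, 6)).float()            # gold multi-hot labels

weights = torch.ones_like(topic_probs) + 5 * next_sym     # positives weighted 6x
topic_loss = F.binary_cross_entropy(topic_probs, next_sym, weight=weights)
predictions = (topic_probs > 0.5).long()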
def forward(self, # type: ignore question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[QuarelWorld], actions: List[List[ProductionRule]], entity_bits: torch.Tensor = None, denotation_target: torch.Tensor = None, target_action_sequences: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ # pylint: disable=unused-argument """ In this method we encode the table entities, link them to words in the question, then encode the question. Then we set up the initial state for the decoder, and pass that state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference, if we're not. Parameters ---------- question : Dict[str, torch.LongTensor] The output of ``TextField.as_array()`` applied on the question ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. table : ``Dict[str, torch.LongTensor]`` The output of ``KnowledgeGraphField.as_array()`` applied on the table ``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to get embeddings for each entity. world : ``List[QuarelWorld]`` We use a ``MetadataField`` to get the ``World`` for each input instance. Because of how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``, actions : ``List[List[ProductionRule]]`` A list of all possible actions for each ``World`` in the batch, indexed into a ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and use the embeddings to determine which action to take at each timestep in the decoder. target_action_sequences : torch.Tensor, optional (default=None) A list of possibly valid action sequences, where each action is an index into the list of possible actions. This tensor has shape ``(batch_size, num_action_sequences, sequence_length)``. """ table_text = table['text'] self._debug_count -= 1 # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types) # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table) if self._use_entities: if self._entity_similarity_mode == "dot_product": # Compute entity and question word cosine similarity. Need to add a small value to # to the table norm since there are padding values which cause a divide by 0. 
embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13) embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13) question_entity_similarity = torch.bmm(embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2)) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) linking_scores = question_entity_similarity_max_score elif self._entity_similarity_mode == "weighted_dot_product": embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13) embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13) eqe = embedded_question.unsqueeze(1).expand(-1, num_entities*num_entity_tokens, -1, -1) ete = embedded_table.view(batch_size, num_entities*num_entity_tokens, self._embedding_dim) ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1) product = torch.mul(eqe, ete) product = product.view(batch_size, num_question_tokens*num_entities*num_entity_tokens, self._embedding_dim) question_entity_similarity = self._entity_similarity_layer(product) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) linking_scores = question_entity_similarity_max_score # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table['linking'] if self._linking_params is not None: feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2), question_mask, entity_type_dict) encoder_input = embedded_question else: if entity_bits is not None and not self._entity_bits_output: encoder_input = torch.cat([embedded_question, entity_bits], 2) else: encoder_input = embedded_question # Fake linking_scores added for downstream code to not object linking_scores = question_mask.clone().fill_(0).unsqueeze(1) linking_probabilities = None # (batch_size, question_length, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask)) if self._entity_bits_output and entity_bits is not None: encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states(encoder_outputs, question_mask, self._encoder.is_bidirectional()) # For predicting a categorical denotation directly if self._denotation_only: denotation_logits = self._denotation_classifier(final_encoder_output) loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1)) self._denotation_accuracy_cat(denotation_logits, denotation_target) return {"loss": loss} memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim) _, num_entities, num_question_tokens = linking_scores.size() if target_action_sequences is not None: # Remove the trailing dimension (from ListField[ListField[IndexField]]). 
target_action_sequences = target_action_sequences.squeeze(-1) target_mask = target_action_sequences != self._action_padding_index else: target_mask = None # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append(RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list)) initial_grammar_state = [self._create_grammar_state(world[i], actions[i], linking_scores[i], entity_types[i]) for i in range(batch_size)] initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size) initial_score_list = [initial_score[i] for i in range(batch_size)] initial_state = GrammarBasedState(batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, possible_actions=actions, extras=None, debug_info=None) if self.training: outputs = self._decoder_trainer.decode(initial_state, self._decoder_step, (target_action_sequences, target_mask)) return outputs else: action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] outputs = {'action_mapping': action_mapping} if target_action_sequences is not None: outputs['loss'] = self._decoder_trainer.decode(initial_state, self._decoder_step, (target_action_sequences, target_mask))['loss'] num_steps = self._max_decoding_steps # This tells the state to start keeping track of debug info, which we'll pass along in # our output dictionary. 
initial_state.debug_info = [[] for _ in range(batch_size)] best_final_states = self._beam_search.search(num_steps, initial_state, self._decoder_step, keep_final_unfinished_states=False) outputs['best_action_sequence'] = [] outputs['debug_info'] = [] outputs['entities'] = [] if self._linking_params is not None: outputs['linking_scores'] = linking_scores outputs['feature_scores'] = feature_scores outputs['linking_features'] = linking_features if self._use_entities: outputs['linking_probabilities'] = linking_probabilities if entity_bits is not None: outputs['entity_bits'] = entity_bits # outputs['similarity_scores'] = question_entity_similarity_max_score outputs['logical_form'] = [] outputs['denotation_acc'] = [] outputs['score'] = [] outputs['parse_acc'] = [] outputs['answer_index'] = [] if metadata is not None: outputs['question_tokens'] = [] outputs['world_extractions'] = [] for i in range(batch_size): if metadata is not None: outputs['question_tokens'].append(metadata[i].get('question_tokens', [])) if metadata is not None: outputs['world_extractions'].append(metadata[i].get('world_extractions', {})) outputs['entities'].append(world[i].table_graph.entities) # Decoding may not have terminated with any completed logical forms, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). if i in best_final_states: best_action_indices = best_final_states[i][0].action_history[0] sequence_in_targets = 0 if target_action_sequences is not None: targets = target_action_sequences[i].data sequence_in_targets = self._action_history_match(best_action_indices, targets) self._action_sequence_accuracy(sequence_in_targets) action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices] try: self._has_logical_form(1.0) logical_form = world[i].get_logical_form(action_strings, add_var_function=False) except ParsingError: self._has_logical_form(0.0) logical_form = 'Error producing logical form' denotation_accuracy = 0.0 predicted_answer_index = world[i].execute(logical_form) if metadata is not None and 'answer_index' in metadata[i]: answer_index = metadata[i]['answer_index'] denotation_accuracy = self._denotation_match(predicted_answer_index, answer_index) self._denotation_accuracy(denotation_accuracy) score = math.exp(best_final_states[i][0].score[0].data.cpu().item()) outputs['answer_index'].append(predicted_answer_index) outputs['score'].append(score) outputs['parse_acc'].append(sequence_in_targets) outputs['best_action_sequence'].append(action_strings) outputs['logical_form'].append(logical_form) outputs['denotation_acc'].append(denotation_accuracy) outputs['debug_info'].append(best_final_states[i][0].debug_info[0]) # type: ignore else: outputs['parse_acc'].append(0) outputs['logical_form'].append('') outputs['denotation_acc'].append(0) outputs['score'].append(0) outputs['answer_index'].append(-1) outputs['best_action_sequence'].append([]) outputs['debug_info'].append([]) self._has_logical_form(0.0) return outputs
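# Standalone sketch of the "dot_product" entity-linking similarity used above: table-entity
# token embeddings and question token embeddings are L2-normalised (with a small epsilon for
# all-zero padding rows), compared with a batched matrix product, and max-pooled over each
# entity's tokens. Shapes are illustrative.
import torch

batch, num_entities, num_entity_tokens, num_question_tokens, dim = 2, 3, 4, 6, 5
embedded_table = torch.randn(batch, num_entities, num_entity_tokens, dim)
embedded_question = torch.randn(batch, num_question_tokens, dim)

embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)

similarity = torch.bmm(embedded_table.view(batch, num_entities * num_entity_tokens, dim),
                       embedded_question.transpose(1, 2))
similarity = similarity.view(batch, num_entities, num_entity_tokens, num_question_tokens)
linking_scores, _ = similarity.max(dim=2)     # (batch, num_entities, num_question_tokens)
assert linking_scores.shape == (batch, num_entities, num_question_tokens)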
def _get_initial_state_and_scores( self, question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[WikiTablesWorld], actions: List[List[ProductionRuleArray]], example_lisp_string: List[str] = None, add_world_to_initial_state: bool = False, checklist_states: List[ChecklistState] = None) -> Dict: """ Does initial preparation and creates an intiial state for both the semantic parsers. Note that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to pass it. """ table_text = table['text'] # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float() batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, embedding_dim) encoded_table = self._entity_encoder(embedded_table, table_mask) # (batch_size, num_entities, num_neighbors) neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table) # Neighbor_indices is padded with -1 since 0 is a potential neighbor index. # Thus, the absolute value needs to be taken in the index_select, and 1 needs to # be added for the mask since that method expects 0 for padding. # (batch_size, num_entities, num_neighbors, embedding_dim) embedded_neighbors = util.batched_index_select( encoded_table, torch.abs(neighbor_indices)) neighbor_mask = util.get_text_field_mask( { 'ignored': neighbor_indices + 1 }, num_wrapping_dims=1).float() # Encoder initialized to easily obtain a masked average. neighbor_encoder = TimeDistributed( BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True)) # (batch_size, num_entities, embedding_dim) embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask) # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types) # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector( world, num_entities, encoded_table) entity_type_embeddings = self._type_params(entity_types.float()) projected_neighbor_embeddings = self._neighbor_params( embedded_neighbors.float()) # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.nn.functional.tanh( entity_type_embeddings + projected_neighbor_embeddings) # Compute entity and question word similarity. We tried using cosine distance here, but # because this similarity is the main mechanism that the model can use to push apart logit # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger # output range than [-1, 1]. 
question_entity_similarity = torch.bmm( embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2)) question_entity_similarity = question_entity_similarity.view( batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max( question_entity_similarity, 2) # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table['linking'] linking_scores = question_entity_similarity_max_score if self._use_neighbor_similarity_for_linking: # The linking score is computed as a linear projection of two terms. The first is the # maximum similarity score over the entity's words and the question token. The second # is the maximum similarity over the words in the entity's neighbors and the question # token. # # The second term, projected_question_neighbor_similarity, is useful when a column # needs to be selected. For example, the question token might have no similarity with # the column name, but is similar with the cells in the column. # # Note that projected_question_neighbor_similarity is intended to capture the same # information as the related_column feature. # # Also note that this block needs to be _before_ the `linking_params` block, because # we're overwriting `linking_scores`, not adding to it. # (batch_size, num_entities, num_neighbors, num_question_tokens) question_neighbor_similarity = util.batched_index_select( question_entity_similarity_max_score, torch.abs(neighbor_indices)) # (batch_size, num_entities, num_question_tokens) question_neighbor_similarity_max_score, _ = torch.max( question_neighbor_similarity, 2) projected_question_entity_similarity = self._question_entity_params( question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1) projected_question_neighbor_similarity = self._question_neighbor_params( question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze( -1) linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity feature_scores = None if self._linking_params is not None: feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities( world, linking_scores.transpose(1, 2), question_mask, entity_type_dict) # (batch_size, num_question_tokens, embedding_dim) link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities) encoder_input = torch.cat([link_embedding, embedded_question], 2) # (batch_size, question_length, encoder_output_dim) encoder_outputs = self._dropout( self._encoder(encoder_input, question_mask)) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states( encoder_outputs, question_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) initial_score = embedded_question.data.new_zeros(batch_size) action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions( actions) _, num_entities, num_question_tokens = linking_scores.size() flattened_linking_scores, actions_to_entities = self._map_entity_productions( linking_scores, world, actions) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. 
For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append( RnnState(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list)) initial_grammar_state = [ self._create_grammar_state(world[i], actions[i]) for i in range(batch_size) ] initial_state_world = world if add_world_to_initial_state else None initial_state = WikiTablesDecoderState( batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, action_embeddings=action_embeddings, output_action_embeddings=output_action_embeddings, action_biases=action_biases, action_indices=action_indices, possible_actions=actions, flattened_linking_scores=flattened_linking_scores, actions_to_entities=actions_to_entities, entity_types=entity_type_dict, world=initial_state_world, example_lisp_string=example_lisp_string, checklist_state=checklist_states, debug_info=None) return { "initial_state": initial_state, "linking_scores": linking_scores, "feature_scores": feature_scores, "similarity_scores": question_entity_similarity_max_score }
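# Standalone sketch of the batch-to-list conversion repeated in the state-initialisation code
# above: every batched tensor is split into a list of per-instance tensors so that the decoder
# can regroup instances freely with torch.cat() instead of index_select(). Dimensions are
# illustrative.
import torch

batch_size, question_length, encoder_output_dim = 3, 7, 4
encoder_outputs = torch.randn(batch_size, question_length, encoder_output_dim)
question_mask = torch.ones(batch_size, question_length)

encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
question_mask_list = [question_mask[i] for i in range(batch_size)]

assert len(encoder_output_list) == batch_size
assert encoder_output_list[0].shape == (question_length, encoder_output_dim)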
def forward(self, inputs: torch.Tensor, mask: torch.Tensor): # https://github.com/allenai/allennlp/issues/2411 return get_final_encoder_states(self._seq2seq(inputs, None), mask)
def forward(self, inputs: torch.Tensor, mask: torch.Tensor): out = self.stacked_self_att_enc(inputs, mask) return get_final_encoder_states(out, mask)
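# A plain-PyTorch sketch of the behaviour the get_final_encoder_states calls throughout this
# file rely on: pick each sequence's last unmasked timestep, and for a bidirectional encoder
# concatenate the forward half of that state with the backward half of the first timestep's
# state. The helper name is illustrative.
import torch

def final_encoder_states(encoder_outputs: torch.Tensor,
                         mask: torch.Tensor,
                         bidirectional: bool = False) -> torch.Tensor:
    last_index = mask.sum(1).long() - 1                              # (batch,)
    batch_range = torch.arange(encoder_outputs.size(0))
    final_states = encoder_outputs[batch_range, last_index]          # (batch, dim)
    if bidirectional:
        half = encoder_outputs.size(-1) // 2
        forward_part = final_states[:, :half]
        backward_part = encoder_outputs[:, 0, half:]
        final_states = torch.cat([forward_part, backward_part], dim=-1)
    return final_states

outputs = torch.tensor([[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]]])
mask = torch.tensor([[1, 1, 0]])
assert final_encoder_states(outputs, mask, bidirectional=True).tolist() == [[5., 6., 3., 4.]]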
def forward( self, # type: ignore source_tokens: Dict[str, torch.LongTensor] = None, target_tokens: Dict[str, torch.LongTensor] = None, source_tokens_raw=None, target_tokens_raw=None, predict: bool = False) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ embedded_input = self._source_embedder(source_tokens) batch_size, _, _ = embedded_input.size() source_mask = get_text_field_mask(source_tokens) encoder_outputs = self._encoder(embedded_input, source_mask) final_encoder_output = get_final_encoder_states( encoder_outputs, source_mask ) #encoder_outputs[:, -1] # (batch_size, encoder_output_dim) if target_tokens: targets = target_tokens["tokens"] target_sequence_length = targets.size()[1] # The last input from the target is either padding or the end symbol. Either way, we # don't have to process it. num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps decoder_hidden = self.decode_h0_projection_layer(final_encoder_output) decoder_context = self.decode_h0_projection_layer(final_encoder_output) last_predictions = None step_attensions = [] step_probabilities = [] step_predictions = [] step_p_gen = [] for timestep in range(num_decoding_steps): if self.training and all( torch.rand(1) >= self._scheduled_sampling_ratio): input_choices = targets[:, timestep] else: if timestep == 0: # For the first timestep, when we do not have targets, we input start symbols. # (batch_size,) input_choices = source_mask.new().resize_( batch_size).fill_(self._start_index) else: input_choices = last_predictions # input_indices : (batch_size,) since we are processing these one timestep at a time. # (batch_size, target_embedding_dim) input_choices = {'tokens': input_choices} decoder_input = self._target_embedder(input_choices) #Dh_t(S_t),Dc_t decoder_hidden, decoder_context = self._decoder_cell( decoder_input, (decoder_hidden, decoder_context)) #cat[S_t,H*_t(short memory)] P_attensions, decoder_output = self._decode_step_output( decoder_hidden, encoder_outputs, source_mask) # (batch_size, num_classes) # W[S_t,H*_t]+b output_attention = self._output_attention_layer(decoder_output) output_projections = self._output_projection_layer( output_attention) # P_vocab class_probabilities = F.softmax(output_projections, dim=-1) # generation probability #P_gen = F.sigmoid(self._pointer_gen_layer(torch.cat((decoder_output,decoder_input),-1))) #class_probabilities = P_gen*class_probabilities #step_p_gen.append(P_gen.unsqueeze(1)) #print(f'P_gen:{P_gen.data.mean()}') # list of (batch_size, 1, num_classes) step_attensions.append(P_attensions.unsqueeze(1)) _, predicted_classes = torch.max(class_probabilities, 1) step_probabilities.append(class_probabilities.unsqueeze(1)) last_predictions = predicted_classes # (batch_size, 1) step_predictions.append(last_predictions.unsqueeze(1)) # This is (batch_size, num_decoding_steps, num_classes) all_attensions = torch.cat(step_attensions, 1) #all_p_gens = torch.cat(step_p_gen,1) class_probabilities = torch.cat(step_probabilities, 1) all_predictions = torch.cat(step_predictions, 1) output_dict = { "all_attensions": all_attensions, #"all_p_gens": all_p_gens, "source_tokens": source_tokens_raw, "class_probabilities": class_probabilities, "predictions": all_predictions } #att_dists = self._att_dists(all_predictions,all_attensions,source_tokens_raw) #output_dict.update({"att_dists":att_dists}) if target_tokens: target_mask = get_text_field_mask(target_tokens) gen_loss = self._get_loss(class_probabilities, targets, target_mask) import pdb 
#copy_loss = self._get_copy_loss(all_p_gens,att_dists,target_tokens_raw) #copy_loss = self._get_copy_loss(att_dists,target_tokens_raw) #loss = gen_loss#+copy_loss print(f'gen_loss:{gen_loss.data.mean()}' ) #,copy_loss:{copy_loss.data.mean()}') output_dict["loss"] = gen_loss for metric in self.metrics.values(): evaluated_sentences = [ ''.join(i) for i in self.decode(output_dict)["predicted_tokens"] ] reference_sentences = [ ''.join([j.text for j in i]) for i in target_tokens_raw ] #print(f'evaluated_sentences:{evaluated_sentences},reference_sentences:{reference_sentences}') metric(evaluated_sentences, reference_sentences) return output_dict
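# Standalone sketch of the greedy accumulation pattern used in the decoding loop above:
# per-step class probabilities and argmax predictions are collected as (batch, 1, ...) slices
# and concatenated along the time dimension at the end. The random projections stand in for
# real decoder outputs and are illustrative only.
import torch
import torch.nn.functional as F

batch_size, num_classes, num_steps = 2, 7, 4
step_probabilities, step_predictions = [], []
for _ in range(num_steps):
    output_projections = torch.randn(batch_size, num_classes)      # stand-in decoder output
    class_probabilities = F.softmax(output_projections, dim=-1)
    _, predicted_classes = torch.max(class_probabilities, 1)
    step_probabilities.append(class_probabilities.unsqueeze(1))
    step_predictions.append(predicted_classes.unsqueeze(1))

class_probabilities = torch.cat(step_probabilities, 1)   # (batch, num_steps, num_classes)
all_predictions = torch.cat(step_predictions, 1)         # (batch, num_steps)
assert all_predictions.shape == (batch_size, num_steps)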
def forward(self, # type: ignore question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[WikiTablesWorld], actions: List[List[ProductionRuleArray]], example_lisp_string: List[str] = None, target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ # pylint: disable=unused-argument """ In this method we encode the table entities, link them to words in the question, then encode the question. Then we set up the initial state for the decoder, and pass that state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference, if we're not. Parameters ---------- question : Dict[str, torch.LongTensor] The output of ``TextField.as_array()`` applied on the question ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. table : ``Dict[str, torch.LongTensor]`` The output of ``KnowledgeGraphField.as_array()`` applied on the table ``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to get embeddings for each entity. world : ``List[WikiTablesWorld]`` We use a ``MetadataField`` to get the ``World`` for each input instance. Because of how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``, actions : ``List[List[ProductionRuleArray]]`` A list of all possible actions for each ``World`` in the batch, indexed into a ``ProductionRuleArray`` using a ``ProductionRuleField``. We will embed all of these and use the embeddings to determine which action to take at each timestep in the decoder. example_lisp_string : ``List[str]``, optional (default=None) The example (lisp-formatted) string corresponding to the given input. This comes directly from the ``.examples`` file provided with the dataset. We pass this to SEMPRE when evaluating denotation accuracy; it is otherwise unused. target_action_sequences : torch.Tensor, optional (default=None) A list of possibly valid action sequences, where each action is an index into the list of possible actions. This tensor has shape ``(batch_size, num_action_sequences, sequence_length)``. """ table_text = table['text'] # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float() batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, embedding_dim) encoded_table = self._entity_encoder(embedded_table, table_mask) # (batch_size, num_entities, num_neighbors) neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table) # Neighbor_indices is padded with -1 since 0 is a potential neighbor index. # Thus, the absolute value needs to be taken in the index_select, and 1 needs to # be added for the mask since that method expects 0 for padding. # (batch_size, num_entities, num_neighbors, embedding_dim) embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices)) neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1}, num_wrapping_dims=1).float() # Encoder initialized to easily obtain a masked average. 
neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True)) # (batch_size, num_entities, embedding_dim) embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask) # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types) # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table) entity_type_embeddings = self._type_params(entity_types.float()) projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float()) # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.nn.functional.tanh(entity_type_embeddings + projected_neighbor_embeddings) # Compute entity and question word cosine similarity. Need to add a small value to # to the table norm since there are padding values which cause a divide by 0. embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13) embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13) question_entity_similarity = torch.bmm(embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2)) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table['linking'] if self._linking_params is not None: feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = question_entity_similarity_max_score + feature_scores else: # The linking score is computed as a linear projection of two terms. The first is the maximum # similarity score over the entity's words and the question token. The second is the maximum # similarity over the words in the entity's neighbors and the question token. # The second term, projected_question_neighbor_similarity, is useful when # a column needs to be selected. For example, the question token might have no similarity # with the column name, but is similar with the cells in the column. # Note that projected_question_neighbor_similarity is intended to capture the same information # as the related_column feature. 
# (batch_size, num_entities, num_neighbors, num_question_tokens) question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score, torch.abs(neighbor_indices)) # (batch_size, num_entities, num_question_tokens) question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2) projected_question_entity_similarity = self._question_entity_params( question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1) projected_question_neighbor_similarity = self._question_neighbor_params( question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1) linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2), question_mask, entity_type_dict) # (batch_size, num_question_tokens, embedding_dim) link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities) encoder_input = torch.cat([link_embedding, embedded_question], 2) # (batch_size, question_length, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask)) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states(encoder_outputs, question_mask, self._encoder.is_bidirectional()) memory_cell = Variable(encoder_outputs.data.new(batch_size, self._encoder.get_output_dim()).fill_(0)) initial_score = Variable(embedded_question.data.new(batch_size).fill_(0)) action_embeddings, action_indices = self._embed_actions(actions) _, num_entities, num_question_tokens = linking_scores.size() flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores, world, actions) if target_action_sequences is not None: # Remove the trailing dimension (from ListField[ListField[IndexField]]). target_action_sequences = target_action_sequences.squeeze(-1) target_mask = target_action_sequences != self._action_padding_index else: target_mask = None # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. 
initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append(RnnState(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list)) initial_grammar_state = [self._create_grammar_state(world[i], actions[i]) for i in range(batch_size)] initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, action_embeddings=action_embeddings, action_indices=action_indices, possible_actions=actions, flattened_linking_scores=flattened_linking_scores, actions_to_entities=actions_to_entities, entity_types=entity_type_dict, debug_info=None) if self.training: return self._decoder_trainer.decode(initial_state, self._decoder_step, (target_action_sequences, target_mask)) else: action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] outputs: Dict[str, Any] = {'action_mapping': action_mapping} if target_action_sequences is not None: outputs['loss'] = self._decoder_trainer.decode(initial_state, self._decoder_step, (target_action_sequences, target_mask))['loss'] num_steps = self._max_decoding_steps # This tells the state to start keeping track of debug info, which we'll pass along in # our output dictionary. initial_state.debug_info = [[] for _ in range(batch_size)] best_final_states = self._beam_search.search(num_steps, initial_state, self._decoder_step, keep_final_unfinished_states=False) outputs['best_action_sequence'] = [] outputs['debug_info'] = [] outputs['entities'] = [] outputs['linking_scores'] = linking_scores if self._linking_params is not None: outputs['feature_scores'] = feature_scores outputs['similarity_scores'] = question_entity_similarity_max_score outputs['logical_form'] = [] for i in range(batch_size): # Decoding may not have terminated with any completed logical forms, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). if i in best_final_states: best_action_indices = best_final_states[i][0].action_history[0] if target_action_sequences is not None: # Use a Tensor, not a Variable, to avoid a memory leak. 
targets = target_action_sequences[i].data sequence_in_targets = 0 sequence_in_targets = self._action_history_match(best_action_indices, targets) self._action_sequence_accuracy(sequence_in_targets) action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices] try: self._has_logical_form(1.0) logical_form = world[i].get_logical_form(action_strings, add_var_function=False) except ParsingError: self._has_logical_form(0.0) logical_form = 'Error producing logical form' if example_lisp_string: self._denotation_accuracy(logical_form, example_lisp_string[i]) outputs['best_action_sequence'].append(action_strings) outputs['logical_form'].append(logical_form) outputs['debug_info'].append(best_final_states[i][0].debug_info[0]) # type: ignore outputs['entities'].append(world[i].table_graph.entities) else: outputs['logical_form'].append('') self._has_logical_form(0.0) if example_lisp_string: self._denotation_accuracy(None, example_lisp_string[i]) return outputs
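# Hypothetical sketch of a padding-aware action-sequence match in the spirit of
# `self._action_history_match(best_action_indices, targets)` above; the real helper (and its
# padding index) may differ.  A prediction counts as correct if it exactly matches any of the
# un-padded gold action sequences.
import torch

def action_history_match(predicted: list, targets: torch.Tensor, padding_index: int = -1) -> int:
    # targets: (num_target_sequences, sequence_length), padded with `padding_index`.
    for target_sequence in targets:
        gold = [int(index) for index in target_sequence if int(index) != padding_index]
        if gold == predicted:
            return 1
    return 0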
def forward(self, inputs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # inputs: (batch_size, seq_len, input_dim), mask: (batch_size, seq_len)
    output_seq = self.encoder(inputs, mask)
    # (batch_size, encoder_output_dim); note that get_final_encoder_states defaults to
    # bidirectional=False, so this is the full output at each sequence's last unmasked step.
    output_vec = get_final_encoder_states(output_seq, mask)
    return output_vec
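# A small usage sketch for the pooling forward() above, assuming `self.encoder` is an AllenNLP
# Seq2SeqEncoder (here a wrapped LSTM) and an older AllenNLP API where masks are plain 0/1
# tensors; the dimensions are made up.  Passing bidirectional=True instead would splice the
# forward half of the last step with the backward half of the first step.
import torch
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.nn.util import get_final_encoder_states

encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(input_size=8, hidden_size=16,
                                              batch_first=True, bidirectional=True))
inputs = torch.randn(2, 5, 8)                              # (batch_size, seq_len, input_dim)
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])    # (batch_size, seq_len)
output_seq = encoder(inputs, mask)                         # (batch_size, seq_len, 32)
output_vec = get_final_encoder_states(output_seq, mask)    # (batch_size, 32)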
def _get_initial_state_and_scores(self, question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[WikiTablesWorld], actions: List[List[ProductionRuleArray]], example_lisp_string: List[str] = None, add_world_to_initial_state: bool = False, checklist_states: List[ChecklistState] = None) -> Dict: """ Does initial preparation and creates an intiial state for both the semantic parsers. Note that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to pass it. """ table_text = table['text'] # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float() batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, embedding_dim) encoded_table = self._entity_encoder(embedded_table, table_mask) # (batch_size, num_entities, num_neighbors) neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table) # Neighbor_indices is padded with -1 since 0 is a potential neighbor index. # Thus, the absolute value needs to be taken in the index_select, and 1 needs to # be added for the mask since that method expects 0 for padding. # (batch_size, num_entities, num_neighbors, embedding_dim) embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices)) neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1}, num_wrapping_dims=1).float() # Encoder initialized to easily obtain a masked average. neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True)) # (batch_size, num_entities, embedding_dim) embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask) # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types) # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table) entity_type_embeddings = self._type_params(entity_types.float()) projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float()) # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.nn.functional.tanh(entity_type_embeddings + projected_neighbor_embeddings) # Compute entity and question word similarity. We tried using cosine distance here, but # because this similarity is the main mechanism that the model can use to push apart logit # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger # output range than [-1, 1]. 
question_entity_similarity = torch.bmm(embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2)) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table['linking'] linking_scores = question_entity_similarity_max_score if self._use_neighbor_similarity_for_linking: # The linking score is computed as a linear projection of two terms. The first is the # maximum similarity score over the entity's words and the question token. The second # is the maximum similarity over the words in the entity's neighbors and the question # token. # # The second term, projected_question_neighbor_similarity, is useful when a column # needs to be selected. For example, the question token might have no similarity with # the column name, but is similar with the cells in the column. # # Note that projected_question_neighbor_similarity is intended to capture the same # information as the related_column feature. # # Also note that this block needs to be _before_ the `linking_params` block, because # we're overwriting `linking_scores`, not adding to it. # (batch_size, num_entities, num_neighbors, num_question_tokens) question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score, torch.abs(neighbor_indices)) # (batch_size, num_entities, num_question_tokens) question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2) projected_question_entity_similarity = self._question_entity_params( question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1) projected_question_neighbor_similarity = self._question_neighbor_params( question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1) linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity feature_scores = None if self._linking_params is not None: feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2), question_mask, entity_type_dict) # (batch_size, num_question_tokens, embedding_dim) link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities) encoder_input = torch.cat([link_embedding, embedded_question], 2) # (batch_size, question_length, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask)) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states(encoder_outputs, question_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) initial_score = embedded_question.data.new_zeros(batch_size) action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(actions) _, num_entities, num_question_tokens = linking_scores.size() flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores, world, actions) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. 
# For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append(RnnState(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list)) initial_grammar_state = [self._create_grammar_state(world[i], actions[i]) for i in range(batch_size)] initial_state_world = world if add_world_to_initial_state else None initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, action_embeddings=action_embeddings, output_action_embeddings=output_action_embeddings, action_biases=action_biases, action_indices=action_indices, possible_actions=actions, flattened_linking_scores=flattened_linking_scores, actions_to_entities=actions_to_entities, entity_types=entity_type_dict, world=initial_state_world, example_lisp_string=example_lisp_string, checklist_state=checklist_states, debug_info=None) return {"initial_state": initial_state, "linking_scores": linking_scores, "feature_scores": feature_scores, "similarity_scores": question_entity_similarity_max_score}
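# Tiny illustration (with made-up shapes) of the batch-dimension-to-list conversion used in the
# functions above: keeping per-instance tensors in a plain list lets the decoder regroup
# whatever subset of instances is still being expanded with a single torch.stack / torch.cat,
# instead of repeated index_selects on the batched tensor.
import torch

encoder_outputs = torch.randn(3, 7, 10)    # (batch_size, question_length, encoder_output_dim)
encoder_output_list = [encoder_outputs[i] for i in range(3)]

# A later decoder step that only expands instances 0 and 2 can simply do:
group_indices = [0, 2]
grouped = torch.stack([encoder_output_list[i] for i in group_indices])
assert grouped.shape == (2, 7, 10)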
def _get_initial_rnn_and_grammar_state( self, question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[WikiTablesWorld], actions: List[List[ProductionRuleArray]], outputs: Dict[str, Any]) -> Tuple[List[RnnState], List[GrammarState]]: """ Encodes the question and table, computes a linking between the two, and constructs an initial RnnState and GrammarState for each batch instance to pass to the decoder. We take ``outputs`` as a parameter here and `modify` it, adding things that we want to visualize in a demo. """ table_text = table['text'] # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float() batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, embedding_dim) encoded_table = self._entity_encoder(embedded_table, table_mask) # (batch_size, num_entities, num_neighbors) neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table) # Neighbor_indices is padded with -1 since 0 is a potential neighbor index. # Thus, the absolute value needs to be taken in the index_select, and 1 needs to # be added for the mask since that method expects 0 for padding. # (batch_size, num_entities, num_neighbors, embedding_dim) embedded_neighbors = util.batched_index_select( encoded_table, torch.abs(neighbor_indices)) neighbor_mask = util.get_text_field_mask( { 'ignored': neighbor_indices + 1 }, num_wrapping_dims=1).float() # Encoder initialized to easily obtain a masked average. neighbor_encoder = TimeDistributed( BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True)) # (batch_size, num_entities, embedding_dim) embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask) # entity_types: tensor with shape (batch_size, num_entities), where each entry is the # entity's type id. # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector( world, num_entities, encoded_table) entity_type_embeddings = self._entity_type_encoder_embedding( entity_types) projected_neighbor_embeddings = self._neighbor_params( embedded_neighbors.float()) # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings) # Compute entity and question word similarity. We tried using cosine distance here, but # because this similarity is the main mechanism that the model can use to push apart logit # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger # output range than [-1, 1]. 
question_entity_similarity = torch.bmm( embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2)) question_entity_similarity = question_entity_similarity.view( batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max( question_entity_similarity, 2) # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table['linking'] linking_scores = question_entity_similarity_max_score if self._use_neighbor_similarity_for_linking: # The linking score is computed as a linear projection of two terms. The first is the # maximum similarity score over the entity's words and the question token. The second # is the maximum similarity over the words in the entity's neighbors and the question # token. # # The second term, projected_question_neighbor_similarity, is useful when a column # needs to be selected. For example, the question token might have no similarity with # the column name, but is similar with the cells in the column. # # Note that projected_question_neighbor_similarity is intended to capture the same # information as the related_column feature. # # Also note that this block needs to be _before_ the `linking_params` block, because # we're overwriting `linking_scores`, not adding to it. # (batch_size, num_entities, num_neighbors, num_question_tokens) question_neighbor_similarity = util.batched_index_select( question_entity_similarity_max_score, torch.abs(neighbor_indices)) # (batch_size, num_entities, num_question_tokens) question_neighbor_similarity_max_score, _ = torch.max( question_neighbor_similarity, 2) projected_question_entity_similarity = self._question_entity_params( question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1) projected_question_neighbor_similarity = self._question_neighbor_params( question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze( -1) linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity feature_scores = None if self._linking_params is not None: feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities( world, linking_scores.transpose(1, 2), question_mask, entity_type_dict) # (batch_size, num_question_tokens, embedding_dim) link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities) encoder_input = torch.cat([link_embedding, embedded_question], 2) # (batch_size, question_length, encoder_output_dim) encoder_outputs = self._dropout( self._encoder(encoder_input, question_mask)) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states( encoder_outputs, question_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. 
encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append( RnnState(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list)) initial_grammar_state = [ self._create_grammar_state(world[i], actions[i], linking_scores[i], entity_types[i]) for i in range(batch_size) ] if not self.training: # We add a few things to the outputs that will be returned from `forward` at evaluation # time, for visualization in a demo. outputs['linking_scores'] = linking_scores if feature_scores is not None: outputs['feature_scores'] = feature_scores outputs['similarity_scores'] = question_entity_similarity_max_score return initial_rnn_state, initial_grammar_state
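# The `link_embedding` computed above is, for each question token, a probability-weighted
# average of the entity embeddings.  With a 2-D attention matrix over entities,
# util.weighted_sum reduces to a batched matrix product; the shapes below are made up.
import torch
from allennlp.nn import util

entity_embeddings = torch.randn(2, 4, 6)                             # (batch_size, num_entities, embedding_dim)
linking_probabilities = torch.softmax(torch.randn(2, 5, 4), dim=-1)  # (batch_size, num_question_tokens, num_entities)
link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
assert link_embedding.shape == (2, 5, 6)
assert torch.allclose(link_embedding, torch.bmm(linking_probabilities, entity_embeddings), atol=1e-6)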
def _get_initial_rnn_and_grammar_state(self, question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[WikiTablesWorld], actions: List[List[ProductionRule]], outputs: Dict[str, Any]) -> Tuple[List[RnnStatelet], List[LambdaGrammarStatelet]]: """ Encodes the question and table, computes a linking between the two, and constructs an initial RnnStatelet and LambdaGrammarStatelet for each batch instance to pass to the decoder. We take ``outputs`` as a parameter here and `modify` it, adding things that we want to visualize in a demo. """ table_text = table['text'] # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float() batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, embedding_dim) encoded_table = self._entity_encoder(embedded_table, table_mask) # (batch_size, num_entities, num_neighbors) neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table) # Neighbor_indices is padded with -1 since 0 is a potential neighbor index. # Thus, the absolute value needs to be taken in the index_select, and 1 needs to # be added for the mask since that method expects 0 for padding. # (batch_size, num_entities, num_neighbors, embedding_dim) embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices)) neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1}, num_wrapping_dims=1).float() # Encoder initialized to easily obtain a masked average. neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True)) # (batch_size, num_entities, embedding_dim) embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask) # entity_types: tensor with shape (batch_size, num_entities), where each entry is the # entity's type id. # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table) entity_type_embeddings = self._entity_type_encoder_embedding(entity_types) projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float()) # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings) # Compute entity and question word similarity. We tried using cosine distance here, but # because this similarity is the main mechanism that the model can use to push apart logit # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger # output range than [-1, 1]. 
question_entity_similarity = torch.bmm(embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2)) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table['linking'] linking_scores = question_entity_similarity_max_score if self._use_neighbor_similarity_for_linking: # The linking score is computed as a linear projection of two terms. The first is the # maximum similarity score over the entity's words and the question token. The second # is the maximum similarity over the words in the entity's neighbors and the question # token. # # The second term, projected_question_neighbor_similarity, is useful when a column # needs to be selected. For example, the question token might have no similarity with # the column name, but is similar with the cells in the column. # # Note that projected_question_neighbor_similarity is intended to capture the same # information as the related_column feature. # # Also note that this block needs to be _before_ the `linking_params` block, because # we're overwriting `linking_scores`, not adding to it. # (batch_size, num_entities, num_neighbors, num_question_tokens) question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score, torch.abs(neighbor_indices)) # (batch_size, num_entities, num_question_tokens) question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2) projected_question_entity_similarity = self._question_entity_params( question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1) projected_question_neighbor_similarity = self._question_neighbor_params( question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1) linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity feature_scores = None if self._linking_params is not None: feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2), question_mask, entity_type_dict) # (batch_size, num_question_tokens, embedding_dim) link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities) encoder_input = torch.cat([link_embedding, embedded_question], 2) # (batch_size, question_length, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask)) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states(encoder_outputs, question_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. 
encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append(RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list)) initial_grammar_state = [self._create_grammar_state(world[i], actions[i], linking_scores[i], entity_types[i]) for i in range(batch_size)] if not self.training: # We add a few things to the outputs that will be returned from `forward` at evaluation # time, for visualization in a demo. outputs['linking_scores'] = linking_scores if feature_scores is not None: outputs['feature_scores'] = feature_scores outputs['similarity_scores'] = question_entity_similarity_max_score return initial_rnn_state, initial_grammar_state
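# Small numeric sketch of the neighbor-index convention used in the functions above: padding
# is -1 because 0 is a legitimate entity index, so abs() is applied before
# util.batched_index_select (padded slots then gather some harmless in-range row that is later
# masked out), and 1 is added before building the mask so that only the -1 padding positions
# become 0.  The values here are made up.
import torch

neighbor_indices = torch.tensor([[[2, -1], [0, 1], [1, -1]]])   # (batch_size=1, num_entities=3, num_neighbors=2)
gather_indices = torch.abs(neighbor_indices)                     # -1 -> 1, an in-range dummy index
neighbor_mask = (neighbor_indices + 1) != 0                      # False exactly at the padded slots
print(gather_indices.tolist())    # [[[2, 1], [0, 1], [1, 1]]]
print(neighbor_mask.tolist())     # [[[True, False], [True, True], [True, False]]]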
def embed_and_encode_ques_contexts(text_field_embedder: TextFieldEmbedder, qencoder: Seq2SeqEncoder, batch_size: int, question: Dict[str, torch.LongTensor], contexts: Dict[str, torch.LongTensor]): """ Embed and Encode question and contexts Parameters: ----------- text_field_embedder: ``TextFieldEmbedder`` qencoder: ``Seq2SeqEncoder`` question: Dict[str, torch.LongTensor] Output of a TextField. Should yield tensors of shape (B, ques_length, D) contexts: Dict[str, torch.LongTensor] Output of a TextField. Should yield tensors of shape (B, num_contexts, ques_length, D) Returns: --------- embedded_questions: List[(ques_length, D)] Batch-sized list of embedded questions from the text_field_embedder encoded_questions: List[(ques_length, D)] Batch-sized list of encoded questions from the qencoder questions_mask: List[(ques_length)] Batch-sized list of questions masks encoded_ques_tensor: Shape: (batch_size, ques_len, D) Output of the qencoder questions_mask_tensor: Shape: (batch_size, ques_length) Questions mask as a tensor ques_encoded_final_state: Shape: (batch_size, D) For each question, the final state of the qencoder embedded_contexts: List[(num_contexts, context_length, D)] Batch-sized list of embedded contexts for each instance from the text_field_embedder contexts_mask: List[(num_contexts, context_length)] Batch-sized list of contexts_mask for each context in the instance """ # Shape: (B, question_length, D) embedded_questions_tensor = text_field_embedder(question) # Shape: (B, question_length) questions_mask_tensor = allenutil.get_text_field_mask(question).float() embedded_questions = [ embedded_questions_tensor[i] for i in range(batch_size) ] questions_mask = [questions_mask_tensor[i] for i in range(batch_size)] # Shape: (B, ques_len, D) encoded_ques_tensor = qencoder(embedded_questions_tensor, questions_mask_tensor) # Shape: (B, D) ques_encoded_final_state = allenutil.get_final_encoder_states( encoded_ques_tensor, questions_mask_tensor, qencoder.is_bidirectional()) encoded_questions = [encoded_ques_tensor[i] for i in range(batch_size)] # # contexts is a (B, num_contexts, context_length, *) tensors # (tokenindexer, indexed_tensor) = next(iter(contexts.items())) # num_contexts = indexed_tensor.size()[1] # # Making a separate batched token_indexer_dict for each context -- [{token_inderxer: (C, T, *)}] # contexts_indices_list: List[Dict[str, torch.LongTensor]] = [{} for _ in range(batch_size)] # for token_indexer_name, token_indices_tensor in contexts.items(): # print(f"{token_indexer_name}: {token_indices_tensor.size()}") # for i in range(batch_size): # contexts_indices_list[i][token_indexer_name] = token_indices_tensor[i, ...] 
# # # Each tensor of shape (num_contexts, context_len, D) # embedded_contexts = [] # contexts_mask = [] # # Shape: (num_contexts, context_length, D) # for i in range(batch_size): # embedded_contexts_i = text_field_embedder(contexts_indices_list[i]) # embedded_contexts.append(embedded_contexts_i) # contexts_mask_i = allenutil.get_text_field_mask(contexts_indices_list[i]).float() # contexts_mask.append(contexts_mask_i) embedded_contexts_tensor = text_field_embedder(contexts, num_wrapping_dims=1) contexts_mask_tensor = allenutil.get_text_field_mask( contexts, num_wrapping_dims=1).float() embedded_contexts = [ embedded_contexts_tensor[i] for i in range(batch_size) ] contexts_mask = [contexts_mask_tensor[i] for i in range(batch_size)] return (embedded_questions, encoded_questions, questions_mask, encoded_ques_tensor, questions_mask_tensor, ques_encoded_final_state, embedded_contexts, contexts_mask)
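# Illustrative call site for embed_and_encode_ques_contexts, assuming an older AllenNLP-style
# setup where TextField tensors are flat dicts like {"tokens": LongTensor}.  The embedder,
# encoder, vocabulary size, and all shapes below are assumptions made for this sketch only.
import torch
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

text_field_embedder = BasicTextFieldEmbedder({"tokens": Embedding(num_embeddings=50, embedding_dim=8)})
qencoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(8, 8, batch_first=True, bidirectional=True))
question = {"tokens": torch.randint(1, 50, (2, 6))}     # (batch_size=2, ques_length=6)
contexts = {"tokens": torch.randint(1, 50, (2, 3, 9))}  # (batch_size=2, num_contexts=3, context_length=9)

(embedded_questions, encoded_questions, questions_mask,
 encoded_ques_tensor, questions_mask_tensor, ques_encoded_final_state,
 embedded_contexts, contexts_mask) = embed_and_encode_ques_contexts(
        text_field_embedder, qencoder, 2, question, contexts)
# ques_encoded_final_state: (2, 16) summary vector per question; the per-instance lists
# (embedded_questions, encoded_questions, ...) line up index-for-index with the batch.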