def _prepare_decode_step_input(self,
                               input_indices: torch.LongTensor,
                               decoder_hidden_state: torch.LongTensor = None,
                               encoder_outputs: torch.LongTensor = None,
                               encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
    """
    Given the input indices for the current timestep of the decoder, and all the encoder
    outputs, compute the input at the current timestep.  Note: This method is agnostic to
    whether the indices are gold indices or the predictions made by the decoder at the last
    timestep, so this can be used even if we're doing some kind of scheduled sampling.

    If we're not using attention, the output of this method is just an embedding of the input
    indices.  If we are, the output will be a concatenation of the embedding and an attended
    average of the encoder inputs.

    Parameters
    ----------
    input_indices : torch.LongTensor
        Indices of either the gold inputs to the decoder or the predicted labels from the
        previous timestep.
    decoder_hidden_state : torch.LongTensor, optional (not needed if no attention)
        Output from the decoder at the last time step. Needed only if using attention.
    encoder_outputs : torch.LongTensor, optional (not needed if no attention)
        Encoder outputs from all time steps. Needed only if using attention.
    encoder_outputs_mask : torch.LongTensor, optional (not needed if no attention)
        Masks on encoder outputs. Needed only if using attention.
    """
    # input_indices : (batch_size,) since we are processing these one timestep at a time.
    # (batch_size, target_embedding_dim)
    embedded_input = self._target_embedder(input_indices)
    if self._attention_function:
        # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
        # Ensure the mask is also a FloatTensor, or else the multiplication within the
        # attention will complain.
        encoder_outputs_mask = encoder_outputs_mask.float()
        # (batch_size, input_sequence_length)
        input_weights = self._decoder_attention(decoder_hidden_state,
                                                encoder_outputs,
                                                encoder_outputs_mask)
        # (batch_size, encoder_output_dim)
        attended_input = weighted_sum(encoder_outputs, input_weights)
        # (batch_size, encoder_output_dim + target_embedding_dim)
        return torch.cat((attended_input, embedded_input), -1)
    else:
        return embedded_input
def test_weighted_sum_handles_3d_attention_with_3d_matrix(self):
    batch_size = 1
    length_1 = 5
    length_2 = 2
    embedding_dim = 4
    sentence_array = numpy.random.rand(batch_size, length_2, embedding_dim)
    attention_array = numpy.random.rand(batch_size, length_1, length_2)
    sentence_tensor = torch.from_numpy(sentence_array).float()
    attention_tensor = torch.from_numpy(attention_array).float()
    aggregated_array = util.weighted_sum(sentence_tensor, attention_tensor).data.numpy()
    assert aggregated_array.shape == (batch_size, length_1, embedding_dim)
    for i in range(length_1):
        expected_array = (attention_array[0, i, 0] * sentence_array[0, 0] +
                          attention_array[0, i, 1] * sentence_array[0, 1])
        numpy.testing.assert_almost_equal(aggregated_array[0, i], expected_array, decimal=5)
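# A minimal sketch of what the test above checks, assuming only standard torch ops: for a 3D
# attention tensor and a 3D matrix, the weighted sum is just a batched matrix product.  The
# tensor names mirror the test and are illustrative, not part of any library API.
import torch

attention = torch.rand(1, 5, 2)   # (batch_size, length_1, length_2)
matrix = torch.rand(1, 2, 4)      # (batch_size, length_2, embedding_dim)
weighted = attention.bmm(matrix)  # (batch_size, length_1, embedding_dim)
assert weighted.shape == (1, 5, 4)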
def test_weighted_sum_works_on_simple_input(self):
    batch_size = 1
    sentence_length = 5
    embedding_dim = 4
    sentence_array = numpy.random.rand(batch_size, sentence_length, embedding_dim)
    sentence_tensor = torch.from_numpy(sentence_array).float()
    attention_tensor = torch.FloatTensor([[.3, .4, .1, 0, 1.2]])
    aggregated_array = util.weighted_sum(sentence_tensor, attention_tensor).data.numpy()
    assert aggregated_array.shape == (batch_size, embedding_dim)
    expected_array = (0.3 * sentence_array[0, 0] +
                      0.4 * sentence_array[0, 1] +
                      0.1 * sentence_array[0, 2] +
                      0.0 * sentence_array[0, 3] +
                      1.2 * sentence_array[0, 4])
    numpy.testing.assert_almost_equal(aggregated_array, [expected_array], decimal=5)
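# A sketch of the simple case exercised above, assuming only standard torch ops: a 2D attention
# tensor over a 3D matrix reduces to a single batched matrix product after adding a length-1
# dimension.  Names here are illustrative only.
import torch

matrix = torch.rand(1, 5, 4)                              # (batch_size, sentence_length, embedding_dim)
attention = torch.tensor([[.3, .4, .1, 0., 1.2]])         # (batch_size, sentence_length)
weighted = attention.unsqueeze(1).bmm(matrix).squeeze(1)  # (batch_size, embedding_dim)
assert weighted.shape == (1, 4)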
def forward(self, tokens: torch.Tensor, mask: torch.Tensor):  # pylint: disable=arguments-differ
    batch_size, sequence_length, _ = tokens.size()
    # Shape: (batch_size, sequence_length, sequence_length)
    similarity_matrix = self._matrix_attention(tokens, tokens)

    if self._num_attention_heads > 1:
        # In this case, the similarity matrix actually has shape
        # (batch_size, sequence_length, sequence_length, num_heads).  To make the rest of the
        # logic below easier, we'll permute this to
        # (batch_size, sequence_length, num_heads, sequence_length).
        similarity_matrix = similarity_matrix.permute(0, 1, 3, 2)

    # Shape: (batch_size, sequence_length, [num_heads,] sequence_length)
    intra_sentence_attention = util.last_dim_softmax(similarity_matrix.contiguous(), mask)

    # Shape: (batch_size, sequence_length, projection_dim)
    output_token_representation = self._projection(tokens)

    if self._num_attention_heads > 1:
        # We need to split and permute the output representation to be
        # (batch_size, num_heads, sequence_length, projection_dim / num_heads), so that we can
        # do a proper weighted sum with `intra_sentence_attention`.
        shape = list(output_token_representation.size())
        new_shape = shape[:-1] + [self._num_attention_heads, -1]
        # Shape: (batch_size, sequence_length, num_heads, projection_dim / num_heads)
        output_token_representation = output_token_representation.view(*new_shape)
        # Shape: (batch_size, num_heads, sequence_length, projection_dim / num_heads)
        output_token_representation = output_token_representation.permute(0, 2, 1, 3)

    # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads])
    attended_sentence = util.weighted_sum(output_token_representation,
                                          intra_sentence_attention)

    if self._num_attention_heads > 1:
        # Here we concatenate the weighted representation for each head.  We'll accomplish this
        # just with a resize.
        # Shape: (batch_size, sequence_length, projection_dim)
        attended_sentence = attended_sentence.view(batch_size, sequence_length, -1)

    # Shape: (batch_size, sequence_length, combination_dim)
    combined_tensors = util.combine_tensors(self._combination, [tokens, attended_sentence])
    return self._output_projection(combined_tensors)
def attend_on_question(self,
                       query: torch.Tensor,
                       encoder_outputs: torch.Tensor,
                       encoder_output_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a query (which is typically the decoder hidden state), compute an attention over the
    output of the question encoder, and return a weighted sum of the question representations
    given this attention.  We also return the attention weights themselves.

    This is a simple computation, but we have it as a separate method so that the ``forward``
    method on the main parser module can call it on the initial hidden state, to simplify the
    logic in ``take_step``.
    """
    # (group_size, question_length)
    question_attention_weights = self._input_attention(query,
                                                       encoder_outputs,
                                                       encoder_output_mask)
    # (group_size, encoder_output_dim)
    attended_question = util.weighted_sum(encoder_outputs, question_attention_weights)
    return attended_question, question_attention_weights
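# A standalone sketch of the attend-and-summarise pattern used by ``attend_on_question``, with a
# plain dot-product similarity standing in for ``self._input_attention`` (the actual attention
# module is configured on the parser, so the similarity function here is an assumption made for
# illustration).
import torch

group_size, question_length, encoder_output_dim = 2, 7, 10
query = torch.randn(group_size, encoder_output_dim)
encoder_outputs = torch.randn(group_size, question_length, encoder_output_dim)
encoder_output_mask = torch.ones(group_size, question_length)

# (group_size, question_length): similarity scores, masked before the softmax.
scores = torch.bmm(encoder_outputs, query.unsqueeze(-1)).squeeze(-1)
scores = scores.masked_fill(encoder_output_mask == 0, float("-inf"))
question_attention_weights = torch.softmax(scores, dim=-1)

# (group_size, encoder_output_dim): the weighted sum of question encodings.
attended_question = torch.bmm(question_attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)
assert attended_question.shape == (group_size, encoder_output_dim)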
def test_weighted_sum_handles_higher_order_input(self):
    batch_size = 1
    length_1 = 5
    length_2 = 6
    length_3 = 2
    embedding_dim = 4
    sentence_array = numpy.random.rand(batch_size, length_1, length_2, length_3, embedding_dim)
    attention_array = numpy.random.rand(batch_size, length_1, length_2, length_3)
    sentence_tensor = torch.from_numpy(sentence_array).float()
    attention_tensor = torch.from_numpy(attention_array).float()
    aggregated_array = util.weighted_sum(sentence_tensor, attention_tensor).data.numpy()
    assert aggregated_array.shape == (batch_size, length_1, length_2, embedding_dim)
    expected_array = (attention_array[0, 3, 2, 0] * sentence_array[0, 3, 2, 0] +
                      attention_array[0, 3, 2, 1] * sentence_array[0, 3, 2, 1])
    numpy.testing.assert_almost_equal(aggregated_array[0, 3, 2], expected_array, decimal=5)
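# A sketch of the higher-order case above, assuming only standard torch broadcasting: when the
# attention tensor has the same shape as the matrix minus its final embedding dimension, the
# weighted sum is an elementwise product followed by a sum over the attended dimension.
import torch

matrix = torch.rand(1, 5, 6, 2, 4)      # (..., length_3, embedding_dim)
attention = torch.rand(1, 5, 6, 2)      # (..., length_3)
weighted = (attention.unsqueeze(-1) * matrix).sum(dim=-2)
assert weighted.shape == (1, 5, 6, 4)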
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``.
    label : torch.LongTensor, optional (default = None)
        A variable representing the label for each instance in the batch.

    Returns
    -------
    An output dictionary consisting of:

    class_probabilities : torch.FloatTensor
        A tensor of shape ``(batch_size, num_classes)`` representing a
        distribution over the label classes for each instance.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    text_mask = util.get_text_field_mask(tokens).float()
    # Pop the "elmo" key from ``tokens``: the text field embedder should not embed the ELMo
    # token ids; they are handled separately below.
    elmo_tokens = tokens.pop("elmo", None)
    embedded_text = self._text_field_embedder(tokens)

    # Add the "elmo" key back to ``tokens`` if it was present, since the tests and subsequent
    # training epochs rely on ``tokens`` not being modified during ``forward()``.
    if elmo_tokens is not None:
        tokens["elmo"] = elmo_tokens

    # Create ELMo embeddings if applicable.
    if self._elmo:
        if elmo_tokens is not None:
            elmo_representations = self._elmo(elmo_tokens)["elmo_representations"]
            # Popping from the end of the list is cheaper.
            if self._use_integrator_output_elmo:
                integrator_output_elmo = elmo_representations.pop()
            if self._use_input_elmo:
                input_elmo = elmo_representations.pop()
            assert not elmo_representations
        else:
            raise ConfigurationError(
                    "Model was built to use Elmo, but input text is not tokenized for Elmo.")

    if self._use_input_elmo:
        embedded_text = torch.cat([embedded_text, input_elmo], dim=-1)

    dropped_embedded_text = self._embedding_dropout(embedded_text)
    pre_encoded_text = self._pre_encode_feedforward(dropped_embedded_text)
    encoded_tokens = self._encoder(pre_encoded_text, text_mask)

    # Compute biattention.  This is a special case since the inputs are the same.
    attention_logits = encoded_tokens.bmm(encoded_tokens.permute(0, 2, 1).contiguous())
    attention_weights = util.last_dim_softmax(attention_logits, text_mask)
    encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

    # Build the input to the integrator.
    integrator_input = torch.cat([encoded_tokens,
                                  encoded_tokens - encoded_text,
                                  encoded_tokens * encoded_text], 2)
    integrated_encodings = self._integrator(integrator_input, text_mask)

    # Concatenate ELMo representations to integrated_encodings if specified.
    if self._use_integrator_output_elmo:
        integrated_encodings = torch.cat([integrated_encodings,
                                          integrator_output_elmo], dim=-1)

    # Simple pooling layers.
    max_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, text_mask.unsqueeze(2), -1e7)
    max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
    min_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, text_mask.unsqueeze(2), +1e7)
    min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
    mean_pool = torch.sum(integrated_encodings, 1) / torch.sum(text_mask, 1, keepdim=True)

    # Self-attentive pooling layer.
    # Run through a linear projection.  Shape: (batch_size, sequence_length, 1).
    # Then remove the last dimension to get the proper attention shape
    # (batch_size, sequence_length).
    self_attentive_logits = self._self_attentive_pooling_projection(
            integrated_encodings).squeeze(2)
    self_weights = util.masked_softmax(self_attentive_logits, text_mask)
    self_attentive_pool = util.weighted_sum(integrated_encodings, self_weights)

    pooled_representations = torch.cat([max_pool, min_pool, mean_pool, self_attentive_pool], 1)
    pooled_representations_dropped = self._integrator_dropout(pooled_representations)

    logits = self._output_layer(pooled_representations_dropped)
    class_probabilities = F.softmax(logits, dim=-1)

    output_dict = {'logits': logits, 'class_probabilities': class_probabilities}
    if label is not None:
        loss = self.loss(logits, label)
        for metric in self.metrics.values():
            metric(logits, label)
        output_dict["loss"] = loss

    return output_dict
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # Both of shape (batch_size, num_spans, 1)
    span_starts, span_ends = span_indices.split(1, dim=-1)

    # Shape (batch_size, num_spans, 1)
    # These span widths are off by 1, because the span ends are `inclusive`.
    span_widths = span_ends - span_starts

    # We need to know the maximum span width so we can
    # generate indices to extract the spans from the sequence tensor.
    # These indices will then get masked below, such that if the length
    # of a given span is smaller than the max, the rest of the values
    # are masked.
    max_batch_span_width = span_widths.max().item() + 1

    # Shape: (batch_size, sequence_length, 1)
    global_attention_logits = self._global_attention(sequence_tensor)

    # Shape: (1, 1, max_batch_span_width)
    max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                   util.get_device_of(sequence_tensor)).view(1, 1, -1)
    # Shape: (batch_size, num_spans, max_batch_span_width)
    # This is a broadcasted comparison - for each span we are considering,
    # we are creating a range vector of size max_span_width, but masking values
    # which are greater than the actual length of the span.
    #
    # We're using <= here (and for the mask below) because the span ends are
    # inclusive, so we want to include indices which are equal to span_widths rather
    # than using it as a non-inclusive upper bound.
    span_mask = (max_span_range_indices <= span_widths).float()
    raw_span_indices = span_ends - max_span_range_indices
    # We also don't want to include span indices which are less than zero,
    # which happens because some spans near the beginning of the sequence
    # have an end index < max_batch_span_width, so we add this to the mask here.
    span_mask = span_mask * (raw_span_indices >= 0).float()
    span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

    # Shape: (batch_size * num_spans * max_batch_span_width)
    flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)

    # Shape: (batch_size, num_spans, max_batch_span_width)
    span_attention_logits = util.batched_index_select(global_attention_logits,
                                                      span_indices,
                                                      flat_span_indices).squeeze(-1)
    # Shape: (batch_size, num_spans, max_batch_span_width)
    span_attention_weights = util.last_dim_softmax(span_attention_logits, span_mask)

    # Do a weighted sum of the embedded spans with
    # respect to the normalised attention distributions.
    # Shape: (batch_size, num_spans, embedding_dim)
    attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

    if span_indices_mask is not None:
        # Above we were masking the widths of spans with respect to the max
        # span width in the batch. Here we are masking the spans which were
        # originally passed in as padding.
        return attended_text_embeddings * span_indices_mask.unsqueeze(-1).float()

    return attended_text_embeddings
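# A tiny worked example of the span index arithmetic above (values chosen for illustration, not
# taken from any dataset): one span with start=2 and end=4 (inclusive), with a
# max_batch_span_width of 4.
import torch

span_start = torch.tensor([[[2]]])
span_end = torch.tensor([[[4]]])
span_width = span_end - span_start                          # 2 (off by one; ends are inclusive)
max_span_range_indices = torch.arange(4).view(1, 1, -1)     # [0, 1, 2, 3]
span_mask = (max_span_range_indices <= span_width).float()  # [1, 1, 1, 0]
raw_span_indices = span_end - max_span_range_indices        # [4, 3, 2, 1]
span_mask = span_mask * (raw_span_indices >= 0).float()     # still [1, 1, 1, 0]
span_indices = torch.relu(raw_span_indices.float()).long()  # [4, 3, 2, 1]
# Only sequence positions 4, 3 and 2 contribute to this span's attention; position 1 is masked.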
def forward(self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        From a ``TextField``
    hypothesis : Dict[str, torch.LongTensor]
        From a ``TextField``
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        Metadata containing the original tokenization of the premise and
        hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
        probabilities of the entailment label.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
        entailment label.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    embedded_premise = self._text_field_embedder(premise)
    embedded_hypothesis = self._text_field_embedder(hypothesis)
    premise_mask = get_text_field_mask(premise).float()
    hypothesis_mask = get_text_field_mask(hypothesis).float()

    if self._premise_encoder:
        embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
    if self._hypothesis_encoder:
        embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask)

    projected_premise = self._attend_feedforward(embedded_premise)
    projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
    # Shape: (batch_size, premise_length, hypothesis_length)
    similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis)

    # Shape: (batch_size, premise_length, hypothesis_length)
    p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
    # Shape: (batch_size, premise_length, embedding_dim)
    attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

    # Shape: (batch_size, hypothesis_length, premise_length)
    h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
    # Shape: (batch_size, hypothesis_length, embedding_dim)
    attended_premise = weighted_sum(embedded_premise, h2p_attention)

    premise_compare_input = torch.cat([embedded_premise, attended_hypothesis], dim=-1)
    hypothesis_compare_input = torch.cat([embedded_hypothesis, attended_premise], dim=-1)

    compared_premise = self._compare_feedforward(premise_compare_input)
    compared_premise = compared_premise * premise_mask.unsqueeze(-1)
    # Shape: (batch_size, compare_dim)
    compared_premise = compared_premise.sum(dim=1)

    compared_hypothesis = self._compare_feedforward(hypothesis_compare_input)
    compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(-1)
    # Shape: (batch_size, compare_dim)
    compared_hypothesis = compared_hypothesis.sum(dim=1)

    aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1)
    label_logits = self._aggregate_feedforward(aggregate_input)
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {"label_logits": label_logits,
                   "label_probs": label_probs,
                   "h2p_attention": h2p_attention,
                   "p2h_attention": p2h_attention}

    if label is not None:
        loss = self._loss(label_logits, label.long().view(-1))
        self._accuracy(label_logits, label)
        output_dict["loss"] = loss

    if metadata is not None:
        output_dict["premise_tokens"] = [x["premise_tokens"] for x in metadata]
        output_dict["hypothesis_tokens"] = [x["hypothesis_tokens"] for x in metadata]

    return output_dict
def _get_initial_state_and_scores(self,
                                  question: Dict[str, torch.LongTensor],
                                  table: Dict[str, torch.LongTensor],
                                  world: List[WikiTablesWorld],
                                  actions: List[List[ProductionRuleArray]],
                                  example_lisp_string: List[str] = None,
                                  add_world_to_initial_state: bool = False,
                                  checklist_states: List[ChecklistState] = None) -> Dict:
    """
    Does initial preparation and creates an initial state for both the semantic parsers.  Note
    that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to
    pass it.
    """
    table_text = table['text']
    # (batch_size, question_length, embedding_dim)
    embedded_question = self._question_embedder(question)
    question_mask = util.get_text_field_mask(question).float()
    # (batch_size, num_entities, num_entity_tokens, embedding_dim)
    embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
    table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

    batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
    num_question_tokens = embedded_question.size(1)

    # (batch_size, num_entities, embedding_dim)
    encoded_table = self._entity_encoder(embedded_table, table_mask)
    # (batch_size, num_entities, num_neighbors)
    neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

    # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
    # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
    # be added for the mask since that method expects 0 for padding.
    # (batch_size, num_entities, num_neighbors, embedding_dim)
    embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

    neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                             num_wrapping_dims=1).float()

    # Encoder initialized to easily obtain a masked average.
    neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
    # (batch_size, num_entities, embedding_dim)
    embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

    # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the cpu.
    entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

    entity_type_embeddings = self._type_params(entity_types.float())
    projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
    # (batch_size, num_entities, embedding_dim)
    entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)

    # Compute entity and question word similarity.  We tried using cosine distance here, but
    # because this similarity is the main mechanism that the model can use to push apart logit
    # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
    # output range than [-1, 1].
    question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                               num_entities * num_entity_tokens,
                                                               self._embedding_dim),
                                           torch.transpose(embedded_question, 1, 2))

    question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                 num_entities,
                                                                 num_entity_tokens,
                                                                 num_question_tokens)

    # (batch_size, num_entities, num_question_tokens)
    question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

    # (batch_size, num_entities, num_question_tokens, num_features)
    linking_features = table['linking']

    linking_scores = question_entity_similarity_max_score

    if self._use_neighbor_similarity_for_linking:
        # The linking score is computed as a linear projection of two terms. The first is the
        # maximum similarity score over the entity's words and the question token. The second
        # is the maximum similarity over the words in the entity's neighbors and the question
        # token.
        #
        # The second term, projected_question_neighbor_similarity, is useful when a column
        # needs to be selected. For example, the question token might have no similarity with
        # the column name, but is similar with the cells in the column.
        #
        # Note that projected_question_neighbor_similarity is intended to capture the same
        # information as the related_column feature.
        #
        # Also note that this block needs to be _before_ the `linking_params` block, because
        # we're overwriting `linking_scores`, not adding to it.

        # (batch_size, num_entities, num_neighbors, num_question_tokens)
        question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                 torch.abs(neighbor_indices))
        # (batch_size, num_entities, num_question_tokens)
        question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
        projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
        projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
        linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

    feature_scores = None
    if self._linking_params is not None:
        feature_scores = self._linking_params(linking_features).squeeze(3)
        linking_scores = linking_scores + feature_scores

    # (batch_size, num_question_tokens, num_entities)
    linking_probabilities = self._get_linking_probabilities(world,
                                                            linking_scores.transpose(1, 2),
                                                            question_mask,
                                                            entity_type_dict)

    # (batch_size, num_question_tokens, embedding_dim)
    link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
    encoder_input = torch.cat([link_embedding, embedded_question], 2)

    # (batch_size, question_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                         question_mask,
                                                         self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
    initial_score = embedded_question.data.new_zeros(batch_size)

    action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(actions)

    _, num_entities, num_question_tokens = linking_scores.size()
    flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                 world,
                                                                                 actions)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list.  For instance, the encoder outputs have shape
    # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
    # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    question_mask_list = [question_mask[i] for i in range(batch_size)]
    initial_rnn_state = []
    for i in range(batch_size):
        initial_rnn_state.append(RnnState(final_encoder_output[i],
                                          memory_cell[i],
                                          self._first_action_embedding,
                                          self._first_attended_question,
                                          encoder_output_list,
                                          question_mask_list))
    initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                             for i in range(batch_size)]
    initial_state_world = world if add_world_to_initial_state else None
    initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                           action_history=[[] for _ in range(batch_size)],
                                           score=initial_score_list,
                                           rnn_state=initial_rnn_state,
                                           grammar_state=initial_grammar_state,
                                           action_embeddings=action_embeddings,
                                           output_action_embeddings=output_action_embeddings,
                                           action_biases=action_biases,
                                           action_indices=action_indices,
                                           possible_actions=actions,
                                           flattened_linking_scores=flattened_linking_scores,
                                           actions_to_entities=actions_to_entities,
                                           entity_types=entity_type_dict,
                                           world=initial_state_world,
                                           example_lisp_string=example_lisp_string,
                                           checklist_state=checklist_states,
                                           debug_info=None)
    return {"initial_state": initial_state,
            "linking_scores": linking_scores,
            "feature_scores": feature_scores,
            "similarity_scores": question_entity_similarity_max_score}
def forward(self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Parameters
    ----------
    inputs : ``torch.FloatTensor``, required.
        A tensor of shape (batch_size, timesteps, input_dim)
    mask : ``torch.FloatTensor``, optional (default = None).
        A tensor of shape (batch_size, timesteps).

    Returns
    -------
    A tensor of shape (batch_size, timesteps, output_projection_dim),
    where output_projection_dim = input_dim by default.
    """
    num_heads = self._num_heads

    batch_size, timesteps, _ = inputs.size()
    if mask is None:
        mask = inputs.new_ones(batch_size, timesteps)

    # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
    combined_projection = self._combined_projection(inputs)
    # Split by attention dim - if values_dim > attention_dim, we will get more
    # than 3 elements returned. All of the rest are the values vector, so we
    # just concatenate them back together again below.
    queries, keys, *values = combined_projection.split(self._attention_dim, -1)
    queries = queries.contiguous()
    keys = keys.contiguous()
    values = torch.cat(values, -1).contiguous()

    # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
    values_per_head = values.view(batch_size, timesteps, num_heads, int(self._values_dim / num_heads))
    values_per_head = values_per_head.transpose(1, 2).contiguous()
    values_per_head = values_per_head.view(batch_size * num_heads, timesteps, int(self._values_dim / num_heads))

    # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
    queries_per_head = queries.view(batch_size, timesteps, num_heads, int(self._attention_dim / num_heads))
    queries_per_head = queries_per_head.transpose(1, 2).contiguous()
    queries_per_head = queries_per_head.view(batch_size * num_heads, timesteps, int(self._attention_dim / num_heads))

    # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
    keys_per_head = keys.view(batch_size, timesteps, num_heads, int(self._attention_dim / num_heads))
    keys_per_head = keys_per_head.transpose(1, 2).contiguous()
    keys_per_head = keys_per_head.view(batch_size * num_heads, timesteps, int(self._attention_dim / num_heads))

    # Shape (num_heads * batch_size, timesteps, timesteps)
    scaled_similarities = torch.bmm(queries_per_head, keys_per_head.transpose(1, 2)) / self._scale

    # Shape (num_heads * batch_size, timesteps, timesteps)
    # Normalise the distributions, using the same mask for all heads.
    attention = last_dim_softmax(scaled_similarities,
                                 mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps))
    attention = self._attention_dropout(attention)

    # Take a weighted sum of the values with respect to the attention
    # distributions for each element in the num_heads * batch_size dimension.
    # shape (num_heads * batch_size, timesteps, values_dim / num_heads)
    outputs = weighted_sum(values_per_head, attention)

    # Reshape back to original shape (batch_size, timesteps, values_dim).
    # shape (batch_size, num_heads, timesteps, values_dim / num_heads)
    outputs = outputs.view(batch_size, num_heads, timesteps, int(self._values_dim / num_heads))
    # shape (batch_size, timesteps, num_heads, values_dim / num_heads)
    outputs = outputs.transpose(1, 2).contiguous()
    # shape (batch_size, timesteps, values_dim)
    outputs = outputs.view(batch_size, timesteps, self._values_dim)

    # Project back to original input size.
    # shape (batch_size, timesteps, input_size)
    outputs = self._output_projection(outputs)
    return outputs
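# A shape-only sketch of the per-head split used above, with illustrative numbers
# (batch_size=2, timesteps=5, num_heads=4, values_dim=8); not tied to any particular module.
import torch

values = torch.randn(2, 5, 8)
per_head = values.view(2, 5, 4, 8 // 4).transpose(1, 2).contiguous().view(2 * 4, 5, 8 // 4)
# (num_heads * batch_size, timesteps, values_dim / num_heads)
assert per_head.shape == (8, 5, 2)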
def forward(self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
           ) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        From a ``TextField``
    hypothesis : Dict[str, torch.LongTensor]
        From a ``TextField``
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        Metadata containing the original tokenization of the premise and
        hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
        probabilities of the entailment label.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
        entailment label.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    embedded_premise = self._text_field_embedder(premise)
    embedded_hypothesis = self._text_field_embedder(hypothesis)
    premise_mask = get_text_field_mask(premise).float()
    hypothesis_mask = get_text_field_mask(hypothesis).float()

    # Apply dropout to the LSTM inputs.
    if self.rnn_input_dropout:
        embedded_premise = self.rnn_input_dropout(embedded_premise)
        embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

    # Encode the premise and hypothesis.
    encoded_premise = self._encoder(embedded_premise, premise_mask)
    encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

    # Shape: (batch_size, premise_length, hypothesis_length)
    similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

    # Shape: (batch_size, premise_length, hypothesis_length)
    p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
    # Shape: (batch_size, premise_length, embedding_dim)
    attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

    # Shape: (batch_size, hypothesis_length, premise_length)
    h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
    # Shape: (batch_size, hypothesis_length, embedding_dim)
    attended_premise = weighted_sum(encoded_premise, h2p_attention)

    # The "enhancement" layer.
    premise_enhanced = torch.cat([encoded_premise, attended_hypothesis,
                                  encoded_premise - attended_hypothesis,
                                  encoded_premise * attended_hypothesis],
                                 dim=-1)
    hypothesis_enhanced = torch.cat([encoded_hypothesis, attended_premise,
                                     encoded_hypothesis - attended_premise,
                                     encoded_hypothesis * attended_premise],
                                    dim=-1)

    # The projection layer down to the model dimension.  Dropout is not applied before
    # projection.
    projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
    projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

    # Run the inference layer.
    if self.rnn_input_dropout:
        projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
        projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
    v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
    v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

    # The pooling layer -- max and avg pooling.
    # (batch_size, model_dim)
    v_a_max, _ = replace_masked_values(v_ai, premise_mask.unsqueeze(-1), -1e7).max(dim=1)
    v_b_max, _ = replace_masked_values(v_bi, hypothesis_mask.unsqueeze(-1), -1e7).max(dim=1)

    v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(premise_mask, 1, keepdim=True)
    v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(hypothesis_mask, 1, keepdim=True)

    # Now concatenate.
    # (batch_size, model_dim * 2 * 4)
    v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

    # The final MLP -- apply dropout to the input, and the MLP applies to the output & hidden.
    if self.dropout:
        v_all = self.dropout(v_all)

    output_hidden = self._output_feedforward(v_all)
    label_logits = self._output_logit(output_hidden)
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {"label_logits": label_logits, "label_probs": label_probs}

    if label is not None:
        loss = self._loss(label_logits, label.long().view(-1))
        self._accuracy(label_logits, label)
        output_dict["loss"] = loss

    return output_dict
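# A minimal sketch of the masked max and average pooling used above, with plain torch ops and
# illustrative shapes (batch_size=2, length=4, model_dim=3); the -1e7 fill mirrors the
# replace_masked_values call in the model.
import torch

v = torch.randn(2, 4, 3)
mask = torch.tensor([[1., 1., 1., 0.], [1., 1., 0., 0.]])
v_max, _ = v.masked_fill(mask.unsqueeze(-1) == 0, -1e7).max(dim=1)
v_avg = (v * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)
assert v_max.shape == v_avg.shape == (2, 3)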
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer within the passage.  This is an `inclusive` token
        index.  If this is given, we will compute a loss that gets included in the output
        dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer within the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question ID, original passage text, and token
        offsets into the passage for each instance in the batch.  We use this for computing
        official metrics using the official SQuAD evaluation script.  The length of this list
        should be the batch size, and each dictionary should have the keys ``id``,
        ``original_passage``, and ``token_offsets``.  If you only want the best span string and
        don't care about official metrics, you can omit the ``id`` key.

    Returns
    -------
    An output dictionary consisting of:

    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position.
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
        and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str : List[str]
        If sufficient metadata was provided for the instances in the batch, we also return the
        string from the original passage that the model thinks is the best answer to the
        question.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                passage_length,
                                                                                encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([encoded_passage,
                                      passage_question_vectors,
                                      encoded_passage * passage_question_vectors,
                                      encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                               passage_length,
                                                                               modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([final_merged_passage,
                                         modeled_passage,
                                         tiled_start_representation,
                                         modeled_passage * tiled_start_representation],
                                        dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                            passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
            }

    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
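# A small illustration of how a predicted token span is mapped back to a passage string using
# character offsets, as in the metadata loop above; the passage and offsets here are made up.
passage_str = "The cat sat on the mat"
offsets = [(0, 3), (4, 7), (8, 11), (12, 14), (15, 18), (19, 22)]  # (start, end) per token
predicted_span = (1, 2)  # inclusive token indices covering "cat sat"
best_span_string = passage_str[offsets[predicted_span[0]][0]:offsets[predicted_span[1]][1]]
assert best_span_string == "cat sat"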