def _get_initial_state(self,
                       utterance: Dict[str, torch.LongTensor],
                       worlds: List[SpiderWorld],
                       schema: Dict[str, torch.LongTensor],
                       actions: List[List[ProductionRule]]) -> GrammarBasedState:
    device = utterance['tokens'].device
    batch_size, num_entities, _, _ = schema['linking'].size()
    num_entity_tokens = max([max(len(t) for t in world.db_context.entity_tokens) for world in worlds])

    # Hack the token type ids: the utterance contains multiple chunks, but all the entity token
    # chunks should share the same type id.
    utterance['tokens-type-ids'][utterance['tokens-type-ids'] > 1] = 1
    utterance['tokens-offsets'] = torch.cat((utterance['tokens-offsets'].new_zeros((batch_size, 1)),
                                             utterance['tokens-offsets']), dim=-1)

    u_lens = [len(world.db_context.tokenized_utterance) for world in worlds]
    num_question_tokens = max(u_lens)

    for i in range(0, batch_size):
        for j, t in enumerate(utterance['tokens'][i]):
            if t == 102:  # 102 is the id of BERT's [SEP] token
                utterance['tokens-type-ids'][i, j] = 0

    combined_embedding = self._question_embedder(utterance)

    u_embeddings = []
    u_e_embeddings = []
    embedding_padding_fn = lambda: torch.zeros(self._bert_embedding_dim, device=device)
    for i, world in enumerate(worlds):
        u_len = u_lens[i]
        u_embedding = combined_embedding[i][1:u_len + 1]
        u_embeddings.append(torch.cat([u_embedding,
                                       torch.zeros((num_question_tokens - u_len, self._bert_embedding_dim),
                                                   device=device)]))
        e_embeddings = []
        pos = u_len + 2
        for t in world.db_context.entity_tokens:
            e_embeddings.append(torch.cat([combined_embedding[i][pos:pos + len(t)],
                                           torch.zeros((num_entity_tokens - len(t), self._bert_embedding_dim),
                                                       device=device)]))
            pos += len(t) + 1
        u_e_embeddings.append(torch.stack(pad_sequence_to_length(
            e_embeddings,
            desired_length=num_entities,
            default_value=lambda: torch.zeros((num_entity_tokens, self._bert_embedding_dim), device=device))))

    embedded_schema = torch.stack(u_e_embeddings)
    schema_mask = (embedded_schema.sum(dim=-1) != 0).float()

    embedded_utterance = torch.stack(u_embeddings)
    utterance_mask = util.get_mask_from_sequence_lengths(
        torch.as_tensor(u_lens, device=device).long(), max_length=max(u_lens))

    # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
    # entity's type id.
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the CPU.
    entity_types, entity_type_dict = self._get_type_vector(worlds, num_entities, embedded_schema.device)

    entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)

    # Compute entity and question word similarity. We tried using cosine distance here, but
    # because this similarity is the main mechanism that the model can use to push apart logit
    # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
    # output range than [-1, 1].
    embedded_schema = self._entity_encoder(embedded_schema, schema_mask)
    entity_question_similarity = self._embedding_sim_attn(embedded_schema, embedded_utterance)

    feature_scores = self._linking_params(schema['linking']).squeeze(3)

    # (batch_size, num_entities, num_question_tokens)
    linking_scores = entity_question_similarity * 10 + feature_scores

    # (batch_size, num_question_tokens, num_entities)
    linking_prob_by_type = self._get_linking_probabilities(worlds, linking_scores.transpose(1, 2),
                                                           utterance_mask, entity_type_dict)

    entity_embeddings = entity_type_embeddings

    # (batch_size, utterance_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(embedded_utterance, utterance_mask))

    # Compute the relevance of each entity with the relevance GNN.
    ent_relevance, ent_relevance_logits, ent_to_qst_lnk_probs = self._graph_pruning(
        worlds,
        encoder_outputs,
        entity_embeddings,
        linking_scores,
        utterance_mask,
        self._get_graph_adj_lists)
    # Save this for the loss calculation.
    self.predicted_relevance_logits = ent_relevance_logits

    # Scale the entity embeddings by the computed relevance.
    graph_initial_embedding = entity_embeddings * ent_relevance

    encoder_output_dim = self._encoder.get_output_dim()
    if self._gnn:
        entities_graph_encoding = self._get_schema_graph_encoding(worlds, graph_initial_embedding)
        graph_link_embedding = util.weighted_sum(entities_graph_encoding, linking_prob_by_type)
        encoder_outputs = torch.cat((encoder_outputs, graph_link_embedding), dim=-1)
        encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim()
    else:
        entities_graph_encoding = None

    if self._self_attend:
        linked_actions_linking_scores = torch.stack(
            [self._graph_attention(entities_graph_encoding[:, i], entities_graph_encoding)
             for i in range(0, num_entities)]).transpose(0, 1)
    else:
        linked_actions_linking_scores = [None] * batch_size

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
    memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)

    initial_score = embedded_utterance.data.new_zeros(batch_size)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]

    initial_rnn_state = []
    for i in range(batch_size):
        initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                             memory_cell[i],
                                             self._first_action_embedding,
                                             self._first_attended_utterance,
                                             encoder_output_list,
                                             utterance_mask_list))

    initial_grammar_state = [self._create_grammar_state(worlds[i],
                                                        actions[i],
                                                        linking_scores[i],
                                                        linked_actions_linking_scores[i],
                                                        entity_types[i],
                                                        entities_graph_encoding[i]
                                                        if entities_graph_encoding is not None else None)
                             for i in range(batch_size)]

    initial_sql_state = [SqlState(actions[i], self._parse_sql_on_decoding) for i in range(batch_size)]

    initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=initial_rnn_state,
                                      grammar_state=initial_grammar_state,
                                      sql_state=initial_sql_state,
                                      possible_actions=actions,
                                      action_entity_mapping=[w.get_action_entity_mapping() for w in worlds])

    return initial_state
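
# Illustrative sketch (not part of the model): the slicing loop above assumes the joint BERT
# input is laid out as [CLS] question [SEP] entity_1 [SEP] entity_2 [SEP] ..., so the question
# occupies positions 1..u_len and the first entity token sits at position u_len + 2. The toy
# values below (bert_dim, the entity token lists) are hypothetical and only demonstrate the
# index arithmetic.
import torch

bert_dim = 4
u_len = 3                                           # number of question word-pieces
entity_tokens = [['concert'], ['singer', 'name']]   # hypothetical entity token lists
total_len = 1 + u_len + 1 + sum(len(t) + 1 for t in entity_tokens)  # [CLS] + question + [SEP]s
combined = torch.arange(total_len).unsqueeze(-1).expand(total_len, bert_dim).float()

question = combined[1:u_len + 1]                    # positions 1..u_len
pos = u_len + 2                                     # first entity follows [CLS] question [SEP]
entities = []
for t in entity_tokens:
    entities.append(combined[pos:pos + len(t)])
    pos += len(t) + 1                               # skip the trailing [SEP]

assert question.size(0) == u_len
assert [e.size(0) for e in entities] == [1, 2]
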
def _get_initial_state(self,
                       utterance: Dict[str, torch.LongTensor],
                       worlds: List[SpiderWorld],
                       schema: Dict[str, torch.LongTensor],
                       actions: List[List[ProductionRule]]) -> GrammarBasedState:
    schema_text = schema['text']

    """KAIMARY"""
    # TextFieldEmbedder needs a "token" key in the Dict.
    """
    embedded_schema:        torch.Size([batch_size, num_entities, max_num_entity_tokens, embedding_dim])
    schema_mask:            torch.Size([batch_size, num_entities, max_num_entity_tokens])
    embedded_utterance:     torch.Size([batch_size, max_utterance_size, embedding_dim])
    entity_type_embeddings: torch.Size([batch_size, num_entities, embedding_dim])
    """
    embedded_schema = self._question_embedder(schema_text, num_wrapping_dims=1)
    schema_mask = util.get_text_field_mask(schema_text, num_wrapping_dims=1).float()

    embedded_utterance = self._question_embedder(utterance)
    utterance_mask = util.get_text_field_mask(utterance).float()

    batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size()
    num_entities = max([len(world.db_context.knowledge_graph.entities) for world in worlds])
    num_question_tokens = utterance['tokens'].size(1)

    # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
    # entity's type id.
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the CPU.
    entity_types, entity_type_dict = self._get_type_vector(worlds, num_entities, embedded_schema.device)

    entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)

    # Compute entity and question word similarity. We tried using cosine distance here, but
    # because this similarity is the main mechanism that the model can use to push apart logit
    # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
    # output range than [-1, 1].
    question_entity_similarity = torch.bmm(
        embedded_schema.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim),
        torch.transpose(embedded_utterance, 1, 2))

    question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                 num_entities,
                                                                 num_entity_tokens,
                                                                 num_question_tokens)

    # (batch_size, num_entities, num_question_tokens)
    question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

    """KAIMARY"""
    # Variable: linking_scores
    # The entity linking score s(e, i) from Krishnamurthy et al. 2017.
    # (batch_size, num_entities, num_question_tokens, num_features)
    linking_features = schema['linking']
    linking_scores = question_entity_similarity_max_score
    feature_scores = self._linking_params(linking_features).squeeze(3)
    linking_scores = linking_scores + feature_scores

    """KAIMARY"""
    # Variable: linking_probabilities
    # The scores s(e, i) are then fed into a softmax layer over all entities e of the same type.
    # (batch_size, num_question_tokens, num_entities)
    linking_probabilities = self._get_linking_probabilities(worlds, linking_scores.transpose(1, 2),
                                                            utterance_mask, entity_type_dict)

    # (batch_size, num_entities, num_neighbors) or None
    neighbor_indices = self._get_neighbor_indices(worlds, num_entities, linking_scores.device)

    if self._use_neighbor_similarity_for_linking and neighbor_indices is not None:
        """KAIMARY"""
        # A Seq2VecEncoder returns the hidden state of the last step as its single output.
        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_schema, schema_mask)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())

        """KAIMARY"""
        # Variable: entity_embeddings
        # R_v in Bogin et al. 2019.
        # A learned embedding for schema item v, based only on the type of v and its schema neighbors.
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)
    else:
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings)

    """KAIMARY"""
    # Variable: link_embedding
    # L_i in Bogin et al. 2019.
    # An average of entity vectors weighted by the linking distribution.
    link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)

    """KAIMARY"""
    # Variable: encoder_input
    # [W_i; L_i] in Bogin et al. 2019.
    encoder_input = torch.cat([link_embedding, embedded_utterance], 2)

    # (batch_size, utterance_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))

    """KAIMARY"""
    # Variable: max_entities_relevance
    # ρ_v = max_i p_link(v | x_i) in Bogin et al. 2019.
    # The maximum linking probability of v over all words x_i.
    max_entities_relevance = linking_probabilities.max(dim=1)[0]
    entities_relevance = max_entities_relevance.unsqueeze(-1).detach()

    """KAIMARY"""
    # entity_type_embeddings ???
    # Variable: graph_initial_embedding
    # h_v^(0) in Bogin et al. 2019.
    # An initial embedding conditioned on the relevance score, which is then fed into the GNN.
    graph_initial_embedding = entity_type_embeddings * entities_relevance

    encoder_output_dim = self._encoder.get_output_dim()
    if self._gnn:
        """KAIMARY"""
        # Variable: entities_graph_encoding
        # φ_v in Bogin et al. 2019.
        # The final representation of each schema item after L message-passing steps.
        entities_graph_encoding = self._get_schema_graph_encoding(worlds, graph_initial_embedding)

        """KAIMARY"""
        # Variable: graph_link_embedding
        # L_φ,i in Bogin et al. 2019.
        graph_link_embedding = util.weighted_sum(entities_graph_encoding, linking_probabilities)
        encoder_outputs = torch.cat((encoder_outputs, graph_link_embedding), dim=-1)
        encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim()
    else:
        entities_graph_encoding = None

    if self._self_attend:
        # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)
        entities_ff = self._ent2ent_ff(entities_graph_encoding)
        linked_actions_linking_scores = torch.bmm(entities_ff, entities_ff.transpose(1, 2))
    else:
        linked_actions_linking_scores = [None] * batch_size

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                         utterance_mask,
                                                         self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)

    initial_score = embedded_utterance.data.new_zeros(batch_size)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]

    initial_rnn_state = []
    for i in range(batch_size):
        initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                             memory_cell[i],
                                             self._first_action_embedding,
                                             self._first_attended_utterance,
                                             encoder_output_list,
                                             utterance_mask_list))

    initial_grammar_state = [self._create_grammar_state(worlds[i],
                                                        actions[i],
                                                        linking_scores[i],
                                                        linked_actions_linking_scores[i],
                                                        entity_types[i],
                                                        entities_graph_encoding[i]
                                                        if entities_graph_encoding is not None else None)
                             for i in range(batch_size)]

    initial_sql_state = [SqlState(actions[i], self.parse_sql_on_decoding) for i in range(batch_size)]

    initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=initial_rnn_state,
                                      grammar_state=initial_grammar_state,
                                      sql_state=initial_sql_state,
                                      possible_actions=actions,
                                      action_entity_mapping=[w.get_action_entity_mapping() for w in worlds])

    return initial_state
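
# Illustrative sketch (toy shapes, not the project's _get_linking_probabilities): the per-type
# softmax described in the comments above normalises each question word's linking scores
# separately over the entities that share a type, following Krishnamurthy et al. 2017. The type
# ids below are hypothetical.
import torch

num_question_tokens, num_entities = 2, 4
linking_scores = torch.randn(num_question_tokens, num_entities)
entity_type_ids = torch.tensor([0, 0, 1, 1])        # e.g. two columns and two tables

linking_probabilities = torch.zeros_like(linking_scores)
for type_id in entity_type_ids.unique():
    idx = (entity_type_ids == type_id).nonzero(as_tuple=True)[0]
    linking_probabilities[:, idx] = torch.softmax(linking_scores[:, idx], dim=-1)

# Within every entity type, each question word's probabilities sum to 1.
assert torch.allclose(linking_probabilities[:, entity_type_ids == 0].sum(-1),
                      torch.ones(num_question_tokens))
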
def _get_initial_state(self,
                       utterance: Dict[str, torch.LongTensor],
                       worlds: List[SpiderWorld],
                       schema: Dict[str, torch.LongTensor],
                       valid_actions: List[List[ProductionRule]]) -> GrammarBasedState:
    utterance_mask = util.get_text_field_mask(utterance).float()
    embedded_utterance = self.question_embedder(utterance)
    batch_size, _, _ = embedded_utterance.size()
    encoder_outputs = self._dropout(self._question_encoder(embedded_utterance, utterance_mask))

    schema_text = schema['text']
    input_mm_schema = self._input_mm_embedder(schema_text, num_wrapping_dims=1)
    output_mm_schema = self._output_mm_embedder(schema_text, num_wrapping_dims=1)
    batch_size, num_entities, num_entity_tokens, _ = input_mm_schema.size()
    schema_mask = util.get_text_field_mask(schema_text, num_wrapping_dims=1).float()

    # TODO
    # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
    # entity's type id.
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the CPU.
    entity_types, entity_type_dict = self._get_type_vector(worlds, num_entities, input_mm_schema.device)
    # (batch_size, num_entities, embedding_dim)
    entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)

    # (batch_size, num_entities, embedding_dim)
    # An entity memory representation is the concatenation of two parts:
    # 1. the encoded entity token embeddings
    # 2. the entity type embedding
    K = torch.cat([self._input_mm_encoder(input_mm_schema, schema_mask),
                   entity_type_embeddings], dim=2)
    V = torch.cat([self._output_mm_encoder(output_mm_schema, schema_mask),
                   entity_type_embeddings], dim=2)

    encoder_output_dim = self._question_encoder.get_output_dim()

    # Encode the utterance in the context of the schema, which is stored in external memory.
    encoder_outputs_with_context, attn_weights = self._mm_attn(encoder_outputs, K, V)
    attn_weights = attn_weights.transpose(1, 2)

    final_encoder_output = util.get_final_encoder_states(encoder_outputs_with_context,
                                                         utterance_mask,
                                                         self._question_encoder.is_bidirectional())

    max_entities_relevance = attn_weights.max(dim=2)[0]
    entities_relevance = max_entities_relevance.unsqueeze(-1).detach()

    if self._self_attend:
        entities_ff = self._ent2ent_ff(entity_type_embeddings * entities_relevance)
        linked_actions_linking_scores = torch.bmm(entities_ff, entities_ff.transpose(1, 2))
    else:
        linked_actions_linking_scores = [None] * batch_size

    memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)

    initial_score = embedded_utterance.data.new_zeros(batch_size)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]

    # RnnStatelet is used to keep track of the internal state of a decoder RNN.
    initial_rnn_state = []
    for i in range(batch_size):
        initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                             memory_cell[i],
                                             self._first_action_embedding,
                                             self._first_attended_utterance,
                                             encoder_output_list,
                                             utterance_mask_list))

    initial_grammar_state = [self._create_grammar_state(worlds[i],
                                                        valid_actions[i],
                                                        attn_weights[i],
                                                        linked_actions_linking_scores[i],
                                                        entity_types[i])
                             for i in range(batch_size)]

    initial_sql_state = [SqlState(valid_actions[i], self.parse_sql_on_decoding) for i in range(batch_size)]

    initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=initial_rnn_state,
                                      grammar_state=initial_grammar_state,
                                      sql_state=initial_sql_state,
                                      possible_actions=valid_actions,
                                      action_entity_mapping=[w.get_action_entity_mapping() for w in worlds])

    return initial_state
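
# Illustrative sketch (assumed scaled dot-product form; the project's _mm_attn module may
# combine query and memory differently): queries come from the utterance encoder, while the
# keys K and values V are the per-entity memory representations built above. All shapes are toy.
import torch

batch_size, utt_len, num_entities, dim = 2, 5, 7, 8
queries = torch.randn(batch_size, utt_len, dim)                 # encoder outputs
K = torch.randn(batch_size, num_entities, dim)
V = torch.randn(batch_size, num_entities, dim)

scores = torch.bmm(queries, K.transpose(1, 2)) / dim ** 0.5     # (batch, utt_len, num_entities)
attn_weights = torch.softmax(scores, dim=-1)
context = torch.bmm(attn_weights, V)                            # (batch, utt_len, dim)
outputs_with_context = torch.cat([queries, context], dim=-1)    # one plausible way to fuse them

assert attn_weights.shape == (batch_size, utt_len, num_entities)
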