def make_state(group_index: int, action: int, new_score: torch.Tensor,
               action_embedding: torch.Tensor) -> GrammarBasedState:
    """Create the successor state for one group element after it takes ``action``.

    Relies on names from the enclosing scope that are not defined in this block:
    ``state``, ``hidden_state``, ``memory_cell``, ``attended_question``,
    ``predicted_action_embedding``, ``batch_action_probs`` and
    ``updated_rnn_state``.

    Args:
        group_index: Position of the element inside the grouped ``state``.
        action: Index of the chosen action in
            ``state.possible_actions[batch_index]``.
        new_score: Score handed to the new state.
        action_embedding: Embedding stored on the new ``RnnStatelet``.

    Returns:
        Whatever ``state.new_state_from_group_index(...)`` returns.
    """
    previous_statelet = state.rnn_state[group_index]
    batch_index = state.batch_indices[group_index]

    decoder_outputs = previous_statelet.decoder_outputs
    # An action is treated as "linked" when field 1 of its production rule is falsy.
    if not state.possible_actions[batch_index][action][1]:
        hidden_row = hidden_state[group_index].unsqueeze(0)
        predicted_row = predicted_action_embedding[group_index].unsqueeze(0)
        if decoder_outputs is not None:
            # Append this step to the running (hidden states, action ids,
            # predicted embeddings) history.
            past_hidden, past_action_ids, past_predicted = decoder_outputs
            decoder_outputs = (
                torch.cat((past_hidden, hidden_row), dim=0),
                past_action_ids + [action],
                torch.cat((past_predicted, predicted_row), dim=0),
            )
        else:
            # First linked action: start the history.
            decoder_outputs = (hidden_row, [action], predicted_row)

    # Locate the probability record that belongs to this group element.
    # NOTE(review): if no record matches, the names bound below stay undefined
    # and the final call raises UnboundLocalError.
    for record in batch_action_probs[batch_index]:
        (record_group, _, record_log_probs, _, record_actions,
         record_lsq, record_lsp, record_gate, _) = record
        if record_group != group_index:
            continue
        considered_actions = record_actions
        probabilities = record_log_probs.exp().cpu()
        considered_lsq = record_lsq
        considered_lsp = record_lsp
        gate_value = record_gate
        break

    new_rnn_state = RnnStatelet(
        hidden_state[group_index],
        memory_cell[group_index],
        action_embedding,
        attended_question[group_index],
        previous_statelet.encoder_outputs,
        previous_statelet.encoder_output_mask,
        decoder_outputs,
    )
    return state.new_state_from_group_index(
        group_index,
        action,
        new_score,
        new_rnn_state,
        considered_actions,
        probabilities,
        updated_rnn_state['attention_weights'],
        considered_lsq,
        considered_lsp,
        gate_value,
    )
def make_state(group_index: int, action: int, new_score: torch.Tensor,
               action_embedding: torch.Tensor) -> GrammarBasedState:
    """Create the successor state for one group element after it takes ``action``.

    Relies on names from the enclosing scope that are not defined in this block:
    ``state``, ``hidden_state``, ``memory_cell``, ``attended_question``,
    ``decoder_outputs``, ``batch_action_probs`` and ``updated_rnn_state``.

    Args:
        group_index: Position of the element inside the grouped ``state``.
        action: Index of the chosen action; also the key into
            ``state.action_entity_mapping[batch_index]``.
        new_score: Score handed to the new state.
        action_embedding: Embedding stored on the new ``RnnStatelet``.

    Returns:
        Whatever ``state.new_state_from_group_index(...)`` returns.
    """
    previous_statelet = state.rnn_state[group_index]
    batch_index = state.batch_indices[group_index]

    # Shift by one so that a mapping value of -1 lands on row 0 (pad).
    entity_row_index = state.action_entity_mapping[batch_index][action] + 1

    # The history of the first group element decides whether this is the first step.
    has_history = bool(state.action_history[0])
    if has_history:
        decoded_item_embeddings = torch.cat(
            (
                previous_statelet.decoded_item_embeddings,
                previous_statelet.item_embeddings[entity_row_index].unsqueeze(0),
            ),
            dim=0,
        )
    else:
        decoded_item_embeddings = (
            previous_statelet.item_embeddings[entity_row_index].unsqueeze(0)
        )

    new_rnn_state = RnnStatelet(
        hidden_state[group_index],
        memory_cell[group_index],
        action_embedding,
        attended_question[group_index],
        previous_statelet.encoder_outputs,
        previous_statelet.encoder_output_mask,
        previous_statelet.item_embeddings,
        decoder_outputs[group_index],
        decoded_item_embeddings,
    )

    # Locate the probability record that belongs to this group element.
    # NOTE(review): if no record matches, `considered_actions` and
    # `probabilities` stay undefined and the final call raises UnboundLocalError.
    for record in batch_action_probs[batch_index]:
        record_group, _, record_log_probs, _, record_actions = record
        if record_group != group_index:
            continue
        considered_actions = record_actions
        probabilities = record_log_probs.exp().cpu()
        break

    return state.new_state_from_group_index(
        group_index,
        action,
        new_score,
        new_rnn_state,
        considered_actions,
        probabilities,
        updated_rnn_state['attention_weights'],
        updated_rnn_state['output_attention_weights'],
    )
def _get_initial_state(
        self,
        utterance: Dict[str, torch.LongTensor],
        worlds: List[SpiderWorld],
        schema: Dict[str, torch.LongTensor],
        actions: List[List[ProductionRule]]) -> GrammarBasedState:
    """Encode the utterance and schema and build the decoder's initial state.

    Args:
        utterance: Indexed question text; must contain a ``'tokens'`` entry.
        worlds: One ``SpiderWorld`` per batch instance.
        schema: Must contain ``'text'`` (indexed entity text) and ``'linking'``
            (linking features).
        actions: Per-instance list of production rules.

    Returns:
        A ``GrammarBasedState`` with one group element per batch instance,
        empty action histories and zero-initialised scores.
    """
    # KAIMARY: TextFieldEmbedder needs a "token" key in the Dict.
    # KAIMARY: shapes noted by the original author:
    #   embedded_schema:        (batch_size, num_entities, max_num_entity_tokens, embedding_dim)
    #   schema_mask:            (batch_size, num_entities, max_num_entity_tokens)
    #   embedded_utterance:     (batch_size, max_utterance_size, embedding_dim)
    #   entity_type_embeddings: (batch_size, num_entities, embedding_dim)
    schema_text = schema['text']
    embedded_schema = self._question_embedder(schema_text, num_wrapping_dims=1)
    schema_mask = util.get_text_field_mask(schema_text, num_wrapping_dims=1).float()
    embedded_utterance = self._question_embedder(utterance)
    utterance_mask = util.get_text_field_mask(utterance).float()

    # The entity count is taken from the worlds rather than from the tensor size.
    batch_size, _, num_entity_tokens, _ = embedded_schema.size()
    num_entities = max(
        len(world.db_context.knowledge_graph.entities) for world in worlds
    )
    num_question_tokens = utterance['tokens'].size(1)

    # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
    # entity's type id.
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the cpu.
    entity_types, entity_type_dict = self._get_type_vector(
        worlds, num_entities, embedded_schema.device)
    entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)

    # Compute entity and question word similarity. We tried using cosine distance here, but
    # because this similarity is the main mechanism that the model can use to push apart logit
    # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
    # output range than [-1, 1].
    flattened_schema = embedded_schema.view(
        batch_size, num_entities * num_entity_tokens, self._embedding_dim)
    token_level_similarity = torch.bmm(
        flattened_schema, embedded_utterance.transpose(1, 2)
    ).view(batch_size, num_entities, num_entity_tokens, num_question_tokens)
    # (batch_size, num_entities, num_question_tokens)
    best_token_similarity = token_level_similarity.max(dim=2)[0]

    # KAIMARY: linking_scores is the entity linking score s(e, i) in Krishnamurthy 2017.
    # schema['linking']: (batch_size, num_entities, num_question_tokens, num_features)
    feature_scores = self._linking_params(schema['linking']).squeeze(3)
    linking_scores = best_token_similarity + feature_scores

    # KAIMARY: the scores s(e, i) are then fed into a softmax layer over all entities e
    # of the same type.
    # (batch_size, num_question_tokens, num_entities)
    linking_probabilities = self._get_linking_probabilities(
        worlds, linking_scores.transpose(1, 2), utterance_mask, entity_type_dict)

    # (batch_size, num_entities, num_neighbors) or None
    neighbor_indices = self._get_neighbor_indices(
        worlds, num_entities, linking_scores.device)

    use_neighbors = (self._use_neighbor_similarity_for_linking
                     and neighbor_indices is not None)
    if not use_neighbors:
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings)
    else:
        # KAIMARY: Seq2VecEncoder gets the hidden state of the last step as the unique output.
        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_schema, schema_mask)
        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        neighbor_vectors = util.batched_index_select(
            encoded_table, torch.abs(neighbor_indices))
        neighbor_mask = util.get_text_field_mask(
            {'ignored': neighbor_indices + 1}, num_wrapping_dims=1).float()
        # Encoder initialized to easily obtain a masked average.
        averaging_encoder = TimeDistributed(
            BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        averaged_neighbors = averaging_encoder(neighbor_vectors, neighbor_mask)
        projected_neighbor_embeddings = self._neighbor_params(
            averaged_neighbors.float())
        # KAIMARY: entity_embeddings is Rv in B Bogin 2019: a learned embedding for the
        # schema item v, based on the type of v and its schema neighbors only.
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(
            entity_type_embeddings + projected_neighbor_embeddings)

    # KAIMARY: link_embedding is Li in B Bogin 2019: an average of entity vectors
    # weighted by the resulting distribution.
    link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)

    # KAIMARY: encoder_input is [Wi, Li] in B Bogin 2019.
    encoder_input = torch.cat([link_embedding, embedded_utterance], dim=2)
    # (batch_size, utterance_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))

    # KAIMARY: max_entities_relevance is ρv = max_i p_link(v | x_i) in B Bogin 2019,
    # the maximum probability of v for any word x_i. It is detached before use.
    entities_relevance = (
        linking_probabilities.max(dim=1)[0].unsqueeze(-1).detach()
    )

    # KAIMARY: graph_initial_embedding is hv(0) in B Bogin 2019: an initial embedding
    # conditioned on the relevance score, then fed into the GNN.
    # (The original author left an open question on why entity_type_embeddings is used here.)
    graph_initial_embedding = entity_type_embeddings * entities_relevance

    encoder_output_dim = self._encoder.get_output_dim()
    entities_graph_encoding = None
    if self._gnn:
        # KAIMARY: entities_graph_encoding is φv in B Bogin 2019: the final
        # representation of each schema item after L steps.
        entities_graph_encoding = self._get_schema_graph_encoding(
            worlds, graph_initial_embedding)
        # KAIMARY: graph_link_embedding is Lφ,i in B Bogin 2019.
        graph_link_embedding = util.weighted_sum(
            entities_graph_encoding, linking_probabilities)
        encoder_outputs = torch.cat(
            (encoder_outputs, graph_link_embedding), dim=-1)
        encoder_output_dim = (
            self._action_embedding_dim + self._encoder.get_output_dim()
        )

    if self._self_attend:
        # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)
        projected_entities = self._ent2ent_ff(entities_graph_encoding)
        linked_actions_linking_scores = torch.bmm(
            projected_entities, projected_entities.transpose(1, 2))
    else:
        linked_actions_linking_scores = [None] * batch_size

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(
        encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
    initial_score = embedded_utterance.data.new_zeros(batch_size)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    batch_range = range(batch_size)
    initial_score_list = [initial_score[b] for b in batch_range]
    encoder_output_list = [encoder_outputs[b] for b in batch_range]
    utterance_mask_list = [utterance_mask[b] for b in batch_range]

    # Every statelet receives the full per-batch lists of encoder outputs and masks.
    initial_rnn_state = [
        RnnStatelet(final_encoder_output[b],
                    memory_cell[b],
                    self._first_action_embedding,
                    self._first_attended_utterance,
                    encoder_output_list,
                    utterance_mask_list)
        for b in batch_range
    ]

    initial_grammar_state = []
    for b in batch_range:
        graph_encoding_b = (
            None if entities_graph_encoding is None else entities_graph_encoding[b]
        )
        initial_grammar_state.append(
            self._create_grammar_state(worlds[b],
                                       actions[b],
                                       linking_scores[b],
                                       linked_actions_linking_scores[b],
                                       entity_types[b],
                                       graph_encoding_b))

    initial_sql_state = [
        SqlState(actions[b], self.parse_sql_on_decoding) for b in batch_range
    ]

    return GrammarBasedState(
        batch_indices=list(batch_range),
        action_history=[[] for _ in batch_range],
        score=initial_score_list,
        rnn_state=initial_rnn_state,
        grammar_state=initial_grammar_state,
        sql_state=initial_sql_state,
        possible_actions=actions,
        action_entity_mapping=[
            world.get_action_entity_mapping() for world in worlds
        ])
def make_state(group_index: int, action: int, new_score: torch.Tensor,
               action_embedding: torch.Tensor) -> GrammarBasedState:
    """Create the successor state for one group element after it takes ``action``.

    This is a closure: ``state``, ``hidden_state``, ``memory_cell``,
    ``attended_question``, ``batch_action_probs`` and ``updated_rnn_state``
    come from the enclosing scope (the original author's notes name it
    ``_construct_next_states``) and are not defined in this block.

    Args:
        group_index: Position of the element inside the grouped ``state``; it is
            mapped to a batch index through ``state.batch_indices``.
        action: Index of the next action in
            ``state.possible_actions[batch_index]``.
        new_score: Accumulated log-probability of the action sequence so far.
            Per the original author's notes it is computed in
            ``_compute_action_probabilities`` as
            ``state.score[group_index] + current_log_probs``, becomes the score
            of the next round, and is also used as the loss value.
        action_embedding: Embedding stored on the new ``RnnStatelet``; per the
            original author's notes it comes from ``output_action_embeddings``
            in ``_compute_action_probabilities``.

    Returns:
        Whatever ``state.new_state_from_group_index(...)`` returns.

    Raises:
        RuntimeError: If ``batch_action_probs[batch_index]`` has no record for
            ``group_index``. Previously this case fell through silently and
            surfaced as an ``UnboundLocalError`` at the return statement.
    """
    batch_index = state.batch_indices[group_index]

    # `decoder_outputs` stays None until the first linked action is taken
    # (the original author notes that the encoder-built initial state has none);
    # after that it is a (stacked hidden states, action ids) pair.
    decoder_outputs = state.rnn_state[group_index].decoder_outputs
    # An action is treated as "linked" when field 1 of its production rule is falsy.
    is_linked_action = not state.possible_actions[batch_index][action][1]
    if is_linked_action:
        # Non-global action, so store the hidden state.
        current_hidden = hidden_state[group_index].unsqueeze(0)
        if decoder_outputs is None:
            decoder_outputs = (current_hidden, [action])
        else:
            # Keep every hidden state seen so far by concatenating along dim 0.
            decoder_outputs_states, decoder_outputs_ids = decoder_outputs
            decoder_outputs = (
                torch.cat((decoder_outputs_states, current_hidden), dim=0),
                decoder_outputs_ids + [action],
            )

    # Build a fresh RnnStatelet that replaces the old one.
    new_rnn_state = RnnStatelet(
        hidden_state[group_index],
        memory_cell[group_index],
        action_embedding,
        attended_question[group_index],
        state.rnn_state[group_index].encoder_outputs,
        state.rnn_state[group_index].encoder_output_mask,
        decoder_outputs,
    )

    # Locate the probability record that belongs to this group element.
    for i, _, current_log_probs, _, actions, lsq, lsp in batch_action_probs[
            batch_index]:
        if i == group_index:
            considered_actions = actions
            probabilities = current_log_probs.exp().cpu()
            considered_lsq = lsq
            considered_lsp = lsp
            break
    else:
        # Without a matching record the values needed below do not exist; fail
        # with an explicit message instead of an UnboundLocalError.
        raise RuntimeError(
            f"make_state: no action-probability record for group_index "
            f"{group_index} in batch_action_probs[{batch_index}]")

    return state.new_state_from_group_index(
        group_index,
        action,
        new_score,
        new_rnn_state,
        considered_actions,
        probabilities,
        updated_rnn_state['attention_weights'],
        considered_lsq,
        considered_lsp,
    )
def make_state(
        group_index: int, action: int, new_score: torch.Tensor,
        action_input_embedding: torch.Tensor,
        action_output_embedding: torch.Tensor) -> GrammarBasedState:
    """Create the successor state for one group element after it takes ``action``.

    Relies on names from the enclosing scope that are not defined in this block:
    ``self`` (for ``A_t`` / ``B_t``), ``state``, ``hidden_state``,
    ``memory_cell``, ``attended_question``, ``batch_action_probs`` and
    ``updated_rnn_state``.

    Args:
        group_index: Position of the element inside the grouped ``state``.
        action: Index of the chosen action in
            ``state.possible_actions[batch_index]``.
        new_score: Score handed to the new state.
        action_input_embedding: Input-side embedding of the action; appended to
            the input-embedding history after adding ``self.A_t(step)``.
        action_output_embedding: Output-side embedding of the action; stored on
            the new ``RnnStatelet`` and appended to the output-embedding history
            after adding ``self.B_t(step)``.

    Returns:
        Whatever ``state.new_state_from_group_index(...)`` returns.
    """
    previous_statelet = state.rnn_state[group_index]
    batch_index = state.batch_indices[group_index]

    decoder_outputs = previous_statelet.decoder_outputs
    # An action is treated as "linked" when field 1 of its production rule is falsy.
    if not state.possible_actions[batch_index][action][1]:
        hidden_row = hidden_state[group_index].unsqueeze(0)
        if decoder_outputs is not None:
            past_hidden, past_action_ids = decoder_outputs
            decoder_outputs = (
                torch.cat((past_hidden, hidden_row), dim=0),
                past_action_ids + [action],
            )
        else:
            decoder_outputs = (hidden_row, [action])

    # Temporal encoding for both input and output action embeddings.
    decoding_step = previous_statelet.decoding_step + 1

    input_history = previous_statelet.decoder_input_action_embeddings
    input_step = torch.tensor(decoding_step,
                              dtype=torch.long,
                              device=action_input_embedding.device)
    input_row = (action_input_embedding + self.A_t(input_step)).unsqueeze(0)
    if input_history is not None:
        decoder_input_action_embeddings = torch.cat(
            (input_history, input_row), dim=0)
    else:
        decoder_input_action_embeddings = input_row

    output_history = previous_statelet.decoder_output_action_embeddings
    output_step = torch.tensor(decoding_step,
                               dtype=torch.long,
                               device=action_output_embedding.device)
    output_row = (action_output_embedding + self.B_t(output_step)).unsqueeze(0)
    if output_history is not None:
        decoder_output_action_embeddings = torch.cat(
            (output_history, output_row), dim=0)
    else:
        decoder_output_action_embeddings = output_row

    new_rnn_state = RnnStatelet(
        hidden_state[group_index],
        memory_cell[group_index],
        action_output_embedding,
        attended_question[group_index],
        previous_statelet.encoder_outputs,
        previous_statelet.encoder_output_mask,
        decoding_step,
        decoder_outputs,
        decoder_input_action_embeddings,
        decoder_output_action_embeddings,
    )

    # Locate the probability record that belongs to this group element.
    # NOTE(review): if no record matches, the names bound below stay undefined
    # and the final call raises UnboundLocalError.
    for record in batch_action_probs[batch_index]:
        (record_group, _, record_log_probs, _, _,
         record_actions, record_lsq, record_lsp) = record
        if record_group != group_index:
            continue
        considered_actions = record_actions
        probabilities = record_log_probs.exp().cpu()
        considered_lsq = record_lsq
        considered_lsp = record_lsp
        break

    return state.new_state_from_group_index(
        group_index,
        action,
        new_score,
        new_rnn_state,
        considered_actions,
        probabilities,
        updated_rnn_state['attention_weights'],
        considered_lsq,
        considered_lsp,
    )
def _get_initial_state(self,
                       utterance: Dict[str, torch.LongTensor],
                       worlds: List[SpiderWorld],
                       schema: Dict[str, torch.LongTensor],
                       actions: List[List[ProductionRule]]) -> GrammarBasedState:
    """Embed the combined utterance/schema input and build the decoder's initial state.

    The utterance and every entity's tokens are embedded together by
    ``self._question_embedder``; the combined embedding is then sliced back
    into a padded utterance tensor and a padded per-entity tensor.

    Side effects:
        * ``utterance['tokens-type-ids']`` is modified in place and
          ``utterance['tokens-offsets']`` is replaced with a version that has a
          zero column prepended.
        * ``self.predicted_relevance_logits`` is set for later loss calculation.

    Args:
        utterance: Indexed input with ``'tokens'``, ``'tokens-type-ids'`` and
            ``'tokens-offsets'`` entries.
        worlds: One ``SpiderWorld`` per batch instance.
        schema: Must contain ``'linking'``; its first two dimensions give the
            batch size and the number of entities.
        actions: Per-instance list of production rules.

    Returns:
        A ``GrammarBasedState`` with one group element per batch instance,
        empty action histories, zero scores and zero initial hidden/memory
        vectors.
    """
    token_ids = utterance['tokens']
    device = token_ids.device
    batch_size, num_entities, _, _ = schema['linking'].size()
    num_entity_tokens = max(
        max(len(tokens) for tokens in world.db_context.entity_tokens)
        for world in worlds)

    # Hack token ids:
    # We have multiple chunks in an utterance but all the entity tokens chunks
    # should have same type.
    type_ids = utterance['tokens-type-ids']
    type_ids[type_ids > 1] = 1
    utterance['tokens-offsets'] = torch.cat(
        (utterance['tokens-offsets'].new_zeros((batch_size, 1)),
         utterance['tokens-offsets']),
        dim=-1)

    u_lens = [len(world.db_context.tokenized_utterance) for world in worlds]
    num_question_tokens = max(u_lens)

    # Reset the type id to 0 wherever the token id equals 102.
    # NOTE(review): 102 is presumably the BERT [SEP] wordpiece id — confirm
    # against the vocabulary in use.
    # One masked assignment replaces the former per-token Python loop; the
    # column slice keeps the write limited to positions that exist in
    # `token_ids`, exactly as the loop did.
    sep_token_id = 102
    type_ids[:, :token_ids.size(1)][token_ids == sep_token_id] = 0

    combined_embedding = self._question_embedder(utterance)

    u_embeddings = []
    u_e_embeddings = []
    for i, world in enumerate(worlds):
        u_len = u_lens[i]
        # Utterance embeddings start at position 1 (position 0 is skipped) and
        # are zero-padded to the longest utterance in the batch.
        u_embedding = combined_embedding[i][1:u_len + 1]
        u_embeddings.append(torch.cat([
            u_embedding,
            torch.zeros((num_question_tokens - u_len, self._bert_embedding_dim),
                        device=device)]))

        # Entity chunks start two positions after the utterance and are
        # separated from each other by one position.
        e_embeddings = []
        pos = u_len + 2
        for entity_tokens in world.db_context.entity_tokens:
            chunk_len = len(entity_tokens)
            e_embeddings.append(torch.cat([
                combined_embedding[i][pos:pos + chunk_len],
                torch.zeros((num_entity_tokens - chunk_len, self._bert_embedding_dim),
                            device=device)]))
            pos += chunk_len + 1
        # Pad the entity list itself with all-zero entities up to `num_entities`.
        u_e_embeddings.append(torch.stack(pad_sequence_to_length(
            e_embeddings,
            desired_length=num_entities,
            default_value=lambda: torch.zeros(
                (num_entity_tokens, self._bert_embedding_dim), device=device))))

    embedded_schema = torch.stack(u_e_embeddings)
    # A schema token counts as padding when its embedding sums to exactly zero.
    schema_mask = (embedded_schema.sum(dim=-1) != 0).float()

    embedded_utterance = torch.stack(u_embeddings)
    utterance_mask = util.get_mask_from_sequence_lengths(
        torch.as_tensor(u_lens, device=device).long(),
        max_length=max(u_lens))

    # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
    # entity's type id.
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the cpu.
    entity_types, entity_type_dict = self._get_type_vector(
        worlds, num_entities, embedded_schema.device)
    entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)

    # Compute entity and question word similarity. We tried using cosine distance here, but
    # because this similarity is the main mechanism that the model can use to push apart logit
    # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
    # output range than [-1, 1].
    embedded_schema = self._entity_encoder(embedded_schema, schema_mask)
    entity_question_similarity = self._embedding_sim_attn(
        embedded_schema, embedded_utterance)

    feature_scores = self._linking_params(schema['linking']).squeeze(3)

    # (batch_size, num_entities, num_question_tokens)
    linking_scores = entity_question_similarity * 10 + feature_scores

    # (batch_size, num_question_tokens, num_entities)
    linking_prob_by_type = self._get_linking_probabilities(
        worlds, linking_scores.transpose(1, 2), utterance_mask, entity_type_dict)

    entity_embeddings = entity_type_embeddings

    # (batch_size, utterance_length, encoder_output_dim)
    encoder_outputs = self._dropout(
        self._encoder(embedded_utterance, utterance_mask))

    # Compute the relevance of each entity with the relevance GNN; the third
    # return value is not used here.
    ent_relevance, ent_relevance_logits, _ = self._graph_pruning(
        worlds, encoder_outputs, entity_embeddings, linking_scores,
        utterance_mask, self._get_graph_adj_lists)

    # Save this for loss calculation.
    self.predicted_relevance_logits = ent_relevance_logits

    # Multiply the embedding with the computed relevance.
    graph_initial_embedding = entity_embeddings * ent_relevance

    encoder_output_dim = self._encoder.get_output_dim()
    if self._gnn:
        entities_graph_encoding = self._get_schema_graph_encoding(
            worlds, graph_initial_embedding)
        graph_link_embedding = util.weighted_sum(
            entities_graph_encoding, linking_prob_by_type)
        encoder_outputs = torch.cat(
            (encoder_outputs, graph_link_embedding), dim=-1)
        encoder_output_dim = (
            self._action_embedding_dim + self._encoder.get_output_dim())
    else:
        entities_graph_encoding = None

    if self._self_attend:
        # One attention call per entity column, stacked and then transposed so
        # that the batch dimension comes first again.
        linked_actions_linking_scores = torch.stack([
            self._graph_attention(entities_graph_encoding[:, i],
                                  entities_graph_encoding)
            for i in range(num_entities)
        ]).transpose(0, 1)
    else:
        linked_actions_linking_scores = [None] * batch_size

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
    memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
    initial_score = embedded_utterance.data.new_zeros(batch_size)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]

    # Every statelet receives the full per-batch lists of encoder outputs and masks.
    initial_rnn_state = [
        RnnStatelet(final_encoder_output[i],
                    memory_cell[i],
                    self._first_action_embedding,
                    self._first_attended_utterance,
                    encoder_output_list,
                    utterance_mask_list)
        for i in range(batch_size)
    ]

    initial_grammar_state = [
        self._create_grammar_state(
            worlds[i],
            actions[i],
            linking_scores[i],
            linked_actions_linking_scores[i],
            entity_types[i],
            entities_graph_encoding[i] if entities_graph_encoding is not None else None)
        for i in range(batch_size)
    ]

    initial_sql_state = [
        SqlState(actions[i], self._parse_sql_on_decoding)
        for i in range(batch_size)
    ]

    initial_state = GrammarBasedState(
        batch_indices=list(range(batch_size)),
        action_history=[[] for _ in range(batch_size)],
        score=initial_score_list,
        rnn_state=initial_rnn_state,
        grammar_state=initial_grammar_state,
        sql_state=initial_sql_state,
        possible_actions=actions,
        action_entity_mapping=[w.get_action_entity_mapping() for w in worlds])

    return initial_state