def test_get_actions_to_consider(self):
    """
    ``_get_actions_to_consider`` should separate embedded actions from linked ones.

    Actions whose global index is ``-1`` are the linked ones in this test; every other action
    is an embedded action.
    """
    # pylint: disable=protected-access
    valid_action_ids = ([0, 1, 2, 4], [0, 1, 3], [2, 3, 4])
    for group_index, action_ids in enumerate(valid_action_ids):
        self.state.grammar_state[group_index] = GrammarState(['e'], {}, {'e': action_ids}, {},
                                                             is_nonterminal)

    # Mark most actions as linked (global index -1), pretending that only two global
    # actions exist.
    self.state.action_indices = {
        (0, 0): 1,
        (0, 1): 0,
        (0, 2): -1,
        (0, 3): -1,
        (0, 4): -1,
        (1, 0): -1,
        (1, 1): 0,
        (1, 2): -1,
        (1, 3): -1,
    }

    result = WikiTablesDecoderStep._get_actions_to_consider(self.state)
    considered, to_embed, to_link = result

    # _Global_ action indices, coming from actions [[(0, 0), (0, 1)], [(1, 1)], []].
    assert to_embed == [[1, 0], [0], []]

    # _Batch_ action indices whose _global_ action index is -1, coming from actions
    # [[(0, 2), (0, 4)], [(1, 0), (1, 3)], [(0, 2), (0, 3), (0, 4)]].
    assert to_link == [[2, 4], [0, 3], [2, 3, 4]]

    # _Batch_ action indices: embedded actions first, then linked actions, each section
    # padded with -1 where necessary.
    assert considered == [[0, 1, 2, 4, -1], [1, -1, 0, 3, -1], [-1, -1, 2, 3, 4]]
def test_get_valid_actions_adds_lambda_productions_only_for_correct_type(self):
    """The ``'s -> x'`` production must not be offered while ``'t'`` is on top of the stack."""
    nonterminal_stack = ['t']
    lambda_stacks = {('s', 'x'): ['t']}
    valid_actions = {'s': [1, 2], 't': [3, 4]}
    context_actions = {'s -> x': 5}
    state = GrammarState(nonterminal_stack, lambda_stacks, valid_actions, context_actions,
                         is_nonterminal)

    # Ask twice: the second call guards against the first one having modified the state.
    for _ in range(2):
        assert state.get_valid_actions() == [3, 4]
def test_get_valid_actions_adds_lambda_productions_only_for_correct_type(self):
    """With ``'t'`` on the stack, only the ``'t'`` actions are valid; ``'s -> x'`` is left out."""
    expected_actions = [3, 4]
    state = GrammarState(
        ['t'],
        {('s', 'x'): ['t']},
        {'s': [1, 2], 't': [3, 4]},
        {'s -> x': 5},
        is_nonterminal,
    )

    assert state.get_valid_actions() == expected_actions
    # Repeating the check makes sure the first call did not accidentally modify the state.
    assert state.get_valid_actions() == expected_actions
def test_take_action_gives_correct_next_states_with_non_lambda_productions(self):
    """``take_action`` should update the non-terminal stack for ordinary productions."""
    # take_action() neither reads nor changes these two objects; it only passes them through.
    # Opaque sentinels make that visible in the ``__dict__`` comparison below.
    valid_actions = object()
    action_indices = object()

    # (stack before, production applied, stack after)
    cases = [
        (['s'], 's -> [t, r]', ['r', 't']),
        (['r', 't'], 't -> identity', ['r']),
    ]
    for stack_before, production, stack_after in cases:
        state = GrammarState(stack_before, {}, valid_actions, action_indices, is_nonterminal)
        next_state = state.take_action(production)
        expected_next_state = GrammarState(stack_after, {}, valid_actions, action_indices,
                                           is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__
def _create_grammar_state(self,
                          world: NlvrWorld,
                          possible_actions: List[ProductionRuleArray]) -> GrammarState:
    """
    Build the initial ``GrammarState`` for one world, embedding every valid action.

    For each non-terminal returned by ``world.get_valid_actions()``, its productions are
    mapped to their positions in ``possible_actions``, the third field of those entries is
    concatenated and run through ``self._action_embedder``, and the result is stored under the
    ``'global'`` key as ``(embeddings, embeddings, action_ids)``.  The same embedding tensor is
    used for both of the first two tuple slots.
    """
    grammar = world.get_valid_actions()
    index_by_rule = {rule[0]: position for position, rule in enumerate(possible_actions)}

    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
    for nonterminal, productions in grammar.items():
        # `nonterminal` comes from the grammar and `productions` are all of its valid
        # productions.
        rule_ids = [index_by_rule[production] for production in productions]

        # All actions in NLVR are global actions.
        tensor_id_pairs = [(possible_actions[rule_id][2], rule_id) for rule_id in rule_ids]

        # Embed the global actions in one call.
        rule_tensors, global_ids = zip(*tensor_id_pairs)
        embedded_rules = self._action_embedder(torch.cat(rule_tensors, dim=0))
        translated_valid_actions[nonterminal] = {
            'global': (embedded_rules, embedded_rules, list(global_ids)),
        }

    return GrammarState([START_SYMBOL], {}, translated_valid_actions, {},
                        type_declaration.is_nonterminal)
def test_get_valid_actions_adds_lambda_productions(self):
    """The ``'s -> x'`` entry should be appended to each part of the ``'global'`` actions."""
    global_actions = (torch.Tensor([1, 1]), torch.Tensor([2, 2]), [1, 2])
    lambda_action = (torch.Tensor([5]), torch.Tensor([6]), 5)
    state = GrammarState(['s'],
                         {('s', 'x'): ['s']},
                         {'s': {'global': global_actions}},
                         {'s -> x': lambda_action},
                         is_nonterminal)

    # Two identical rounds: the second one makes sure the first call did not accidentally
    # modify the state.
    for _ in range(2):
        actions = state.get_valid_actions()
        global_entry = actions['global']
        assert_almost_equal(global_entry[0].cpu().numpy(), [1, 1, 5])
        assert_almost_equal(actions['global'][1].cpu().numpy(), [2, 2, 6])
        assert actions['global'][2] == [1, 2, 5]
def test_get_actions_to_consider_returns_none_if_no_linked_actions(self):
    """When every action is embedded, the linked-action output should be ``None``."""
    # pylint: disable=protected-access
    valid_action_ids = ([0, 1, 2, 4], [0, 1, 3], [2, 3, 4])
    for group_index, action_ids in enumerate(valid_action_ids):
        self.state.grammar_state[group_index] = GrammarState(['e'], {}, {'e': action_ids}, {},
                                                             is_nonterminal)

    result = WikiTablesDecoderStep._get_actions_to_consider(self.state)
    considered, to_embed, to_link = result

    # _Global_ action indices.  Every action is embedded here, so this is simply the valid
    # actions above translated to their global ids.
    assert to_embed == [[1, 0, 2, 5], [4, 0, 3], [2, 3, 5]]

    # Nothing is linked (all actions are embedded), hence None.
    assert to_link is None

    # _Batch_ action indices.  Without linked actions this is essentially each group
    # element's valid_actions padded with -1s.
    assert considered == [[0, 1, 2, 4], [0, 1, 3, -1], [2, 3, 4, -1]]
def test_get_valid_actions_uses_top_of_stack(self):
    """``get_valid_actions`` should return the actions listed for the symbol on the stack."""
    # (symbol on the stack, valid-action table, expected result)
    cases = [
        ('s', {'s': [1, 2], 't': [3, 4]}, [1, 2]),
        ('t', {'s': [1, 2], 't': [3, 4]}, [3, 4]),
        ('e', {'s': [1, 2], 't': [3, 4], 'e': []}, []),
    ]
    for symbol, action_table, expected in cases:
        state = GrammarState([symbol], {}, action_table, {}, is_nonterminal)
        assert state.get_valid_actions() == expected
def test_get_valid_actions_uses_top_of_stack(self):
    """The valid actions returned should be those keyed by the symbol on the stack."""

    def valid_actions_for(symbol, action_table):
        # Fresh state per query, with nothing but `symbol` on the non-terminal stack.
        return GrammarState([symbol], {}, action_table, {}, is_nonterminal).get_valid_actions()

    assert valid_actions_for('s', {'s': [1, 2], 't': [3, 4]}) == [1, 2]
    assert valid_actions_for('t', {'s': [1, 2], 't': [3, 4]}) == [3, 4]
    assert valid_actions_for('e', {'s': [1, 2], 't': [3, 4], 'e': []}) == []
def _create_grammar_state(world, possible_actions):
    """
    Build the initial ``GrammarState`` for ``world``.

    Every production string in ``world.get_valid_actions()`` is replaced by its position in
    ``possible_actions``; the string-to-position mapping itself is passed to the
    ``GrammarState`` as well.
    """
    valid_actions = world.get_valid_actions()

    # First field of each possible action is its production string.
    action_mapping = {entry[0]: position for position, entry in enumerate(possible_actions)}

    translated_valid_actions = {
        nonterminal: [action_mapping[production] for production in productions]
        for nonterminal, productions in valid_actions.items()
    }
    return GrammarState([START_SYMBOL], {}, translated_valid_actions, action_mapping,
                        type_declaration.is_nonterminal)
def test_get_valid_actions_uses_top_of_stack(self):
    """``get_valid_actions`` should hand back the entry stored for the symbol on the stack."""
    # Opaque sentinels: they only compare equal to themselves, so each assertion passes only
    # when the very object stored for that symbol is returned.
    s_actions, t_actions, e_actions = object(), object(), object()

    # (symbol on the stack, valid-action table, expected sentinel)
    cases = [
        ('s', {'s': s_actions, 't': t_actions}, s_actions),
        ('t', {'s': s_actions, 't': t_actions}, t_actions),
        ('e', {'s': s_actions, 't': t_actions, 'e': e_actions}, e_actions),
    ]
    for symbol, action_table, expected in cases:
        state = GrammarState([symbol], {}, action_table, {}, is_nonterminal)
        assert state.get_valid_actions() == expected
def _create_grammar_state(world: NlvrWorld,
                          possible_actions: List[ProductionRuleArray]) -> GrammarState:
    """
    Create the starting ``GrammarState`` for an ``NlvrWorld``.

    The world's valid actions (production strings grouped by non-terminal) are translated
    into indices into ``possible_actions``.  The production-string-to-index mapping is also
    given to the ``GrammarState``.
    """
    productions_by_nonterminal = world.get_valid_actions()
    index_of = {rule[0]: rule_index for rule_index, rule in enumerate(possible_actions)}

    translated_valid_actions = {}
    for nonterminal, productions in productions_by_nonterminal.items():
        translated_valid_actions[nonterminal] = list(map(index_of.__getitem__, productions))

    return GrammarState([START_SYMBOL],
                        {},
                        translated_valid_actions,
                        index_of,
                        type_declaration.is_nonterminal)
def test_take_action_gives_correct_next_states_with_non_lambda_productions(self):
    """Ordinary productions should only change the non-terminal stack of the next state."""
    # Sentinels: take_action() just passes these through without reading or changing them.
    valid_actions = object()
    context_actions = object()

    def make_state(stack):
        return GrammarState(stack, {}, valid_actions, context_actions, is_nonterminal)

    state = make_state(['s'])
    next_state = state.take_action('s -> [t, r]')
    assert next_state.__dict__ == make_state(['r', 't']).__dict__

    state = make_state(['r', 't'])
    next_state = state.take_action('t -> identity')
    assert next_state.__dict__ == make_state(['r']).__dict__
def test_take_action_gives_correct_next_states_with_lambda_productions(self):
    """Walk a chain of productions involving lambdas and check the state after each step."""
    # take_action() neither reads nor changes these objects, it only passes them through, so
    # sentinels are enough.
    valid_actions = object()
    action_indices = object()

    # (production applied, expected non-terminal stack, expected lambda stacks)
    steps = [
        ('<s,d> -> [lambda x, d]', ['t', 'd'], {('s', 'x'): ['d']}),
        ('d -> [<s,r>, d]', ['t', 'd', '<s,r>'], {('s', 'x'): ['d', '<s,r>']}),
        ('<s,r> -> [lambda y, r]', ['t', 'd', 'r'], {('s', 'x'): ['d', 'r'], ('s', 'y'): ['r']}),
        ('r -> identity', ['t', 'd'], {('s', 'x'): ['d']}),
        ('d -> x', ['t'], {}),
    ]

    state = GrammarState(['t', '<s,d>'], {}, valid_actions, action_indices, is_nonterminal)
    for production, expected_stack, expected_lambda_stacks in steps:
        next_state = state.take_action(production)
        expected_next_state = GrammarState(expected_stack, expected_lambda_stacks,
                                           valid_actions, action_indices, is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__
        # Continue from the expected state rather than from the one that was produced.
        state = expected_next_state
def setUp(self):
    """Create ``self.state``, a ``WikiTablesDecoderState`` with three group elements."""
    super().setUp()

    batch_indices = [0, 1, 0]
    group_size = len(batch_indices)

    def one_row_per_group_element():
        # Row ``i`` is ``[i, i]``.
        return torch.FloatTensor([[i, i] for i in range(group_size)])

    action_history = [[1], [3, 4], []]
    score = [Variable(torch.FloatTensor([value])) for value in [.1, 1.1, 2.2]]
    hidden_state = one_row_per_group_element()
    memory_cell = one_row_per_group_element()
    previous_action_embedding = one_row_per_group_element()
    attended_question = one_row_per_group_element()
    grammar_state = [GrammarState(['e'], {}, {}, {}, is_nonterminal) for _ in batch_indices]

    self.encoder_outputs = torch.FloatTensor([[1, 2], [3, 4], [5, 6]])
    self.encoder_output_mask = Variable(torch.FloatTensor([[1, 1], [1, 0], [1, 1]]))
    self.action_embeddings = torch.FloatTensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    self.action_indices = {
        (0, 0): 1,
        (0, 1): 0,
        (0, 2): 2,
        (0, 3): 3,
        (0, 4): 5,
        (1, 0): 4,
        (1, 1): 0,
        (1, 2): 2,
        (1, 3): 3,
    }

    first_instance_actions = [
        ('e -> f', False, None),
        ('e -> g', True, None),
        ('e -> h', True, None),
        ('e -> i', True, None),
        ('e -> j', True, None),
    ]
    second_instance_actions = [
        ('e -> q', True, None),
        ('e -> g', True, None),
        ('e -> h', True, None),
        ('e -> i', True, None),
    ]
    self.possible_actions = [first_instance_actions, second_instance_actions]

    # (batch_size, num_entities, num_question_tokens) = (2, 5, 3)
    linking_scores = Variable(torch.Tensor([
        [[.1, .2, .3], [.4, .5, .6], [.7, .8, .9], [1.0, 1.1, 1.2], [1.3, 1.4, 1.5]],
        [[-.1, -.2, -.3], [-.4, -.5, -.6], [-.7, -.8, -.9], [-1.0, -1.1, -1.2],
         [-1.3, -1.4, -1.5]],
    ]))
    flattened_linking_scores = linking_scores.view(2 * 5, 3)

    # Maps (batch_index, action_index) to indices into the flattened linking score tensor,
    # which has shape (batch_size * num_entities, num_question_tokens).
    actions_to_entities = {
        (0, 0): 0,
        (0, 1): 1,
        (0, 2): 2,
        (0, 6): 3,
        (1, 3): 6,
        (1, 4): 7,
        (1, 5): 8,
    }
    entity_types = {
        0: 0,
        1: 2,
        2: 1,
        3: 0,
        4: 0,
        5: 1,
        6: 0,
        7: 1,
        8: 2,
    }

    # One RnnState per group element; all of them share the encoder outputs and mask.
    rnn_state = [
        RnnState(hidden_state[i],
                 memory_cell[i],
                 previous_action_embedding[i],
                 attended_question[i],
                 self.encoder_outputs,
                 self.encoder_output_mask)
        for i in range(group_size)
    ]

    self.state = WikiTablesDecoderState(batch_indices=batch_indices,
                                        action_history=action_history,
                                        score=score,
                                        rnn_state=rnn_state,
                                        grammar_state=grammar_state,
                                        action_embeddings=self.action_embeddings,
                                        action_indices=self.action_indices,
                                        possible_actions=self.possible_actions,
                                        flattened_linking_scores=flattened_linking_scores,
                                        actions_to_entities=actions_to_entities,
                                        entity_types=entity_types)
def test_is_finished_just_uses_nonterminal_stack(self):
    """A state is finished exactly when its non-terminal stack is empty."""
    # (non-terminal stack, whether the state should report being finished)
    cases = [
        (['s'], False),
        ([], True),
    ]
    for stack, expected_finished in cases:
        state = GrammarState(stack, {}, {}, {}, is_nonterminal)
        assert bool(state.is_finished()) is expected_finished
def test_take_action_gives_correct_next_states_with_lambda_productions(self):
    """Apply a sequence of productions with lambdas, checking the full state each time."""
    # Sentinels: take_action() just passes these through without reading or changing them.
    valid_actions = object()
    context_actions = object()

    def make_state(stack, lambda_stacks):
        return GrammarState(stack, lambda_stacks, valid_actions, context_actions,
                            is_nonterminal)

    def check_step(state, production, expected_next_state):
        # Returns the *expected* state so that the next step starts from known-good data.
        next_state = state.take_action(production)
        assert next_state.__dict__ == expected_next_state.__dict__
        return expected_next_state

    state = make_state(['t', '<s,d>'], {})
    state = check_step(state,
                       '<s,d> -> [lambda x, d]',
                       make_state(['t', 'd'], {('s', 'x'): ['d']}))
    state = check_step(state,
                       'd -> [<s,r>, d]',
                       make_state(['t', 'd', '<s,r>'], {('s', 'x'): ['d', '<s,r>']}))
    state = check_step(state,
                       '<s,r> -> [lambda y, r]',
                       make_state(['t', 'd', 'r'], {('s', 'x'): ['d', 'r'], ('s', 'y'): ['r']}))
    state = check_step(state,
                       'r -> identity',
                       make_state(['t', 'd'], {('s', 'x'): ['d']}))
    check_step(state, 'd -> x', make_state(['t'], {}))
def test_take_action_crashes_with_mismatched_types(self):
    """``take_action`` must raise ``AssertionError`` for a production of the wrong type."""
    # Build the state outside the ``pytest.raises`` block so that only ``take_action`` can
    # satisfy the expectation; an AssertionError from the constructor would otherwise make
    # this test pass for the wrong reason.
    state = GrammarState(['s'], {}, {}, {}, is_nonterminal)
    with pytest.raises(AssertionError):
        # 's' is on the stack, but the production expands 't'.
        state.take_action('t -> identity')
def setUp(self):
    """Create ``self.decoder_step`` and a three-element ``GrammarBasedDecoderState``."""
    super().setUp()
    self.decoder_step = BasicTransitionFunction(encoder_output_dim=2,
                                                action_embedding_dim=2,
                                                input_attention=Attention.by_name('dot_product')(),
                                                num_start_types=3,
                                                add_action_bias=False)

    batch_indices = [0, 1, 0]
    group_size = len(batch_indices)

    def one_row_per_group_element():
        # Row ``i`` is ``[i, i]``.
        return torch.FloatTensor([[i, i] for i in range(group_size)])

    action_history = [[1], [3, 4], []]
    score = [torch.FloatTensor([value]) for value in [.1, 1.1, 2.2]]
    hidden_state = one_row_per_group_element()
    memory_cell = one_row_per_group_element()
    previous_action_embedding = one_row_per_group_element()
    attended_question = one_row_per_group_element()

    # Non-terminal -> valid actions, with the valid actions grouped by _type_.  "global"
    # actions come from the global grammar; "linked" actions are instance-specific and are
    # generated based on question attention.  Each action type holds a tuple of
    # (input representation, output representation, action ids).
    e_actions = {
        'global': (torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                   torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
                   [0, 1, 2]),
        'linked': (torch.FloatTensor([[.1, .2, .3], [.4, .5, .6]]),
                   torch.FloatTensor([[3, 3], [4, 4]]),
                   [3, 4]),
    }
    d_actions = {
        'global': (torch.FloatTensor([[0, 0]]),
                   torch.FloatTensor([[-1, -1]]),
                   [0]),
        'linked': (torch.FloatTensor([[-.1, -.2, -.3], [-.4, -.5, -.6], [-.7, -.8, -.9]]),
                   torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
                   [1, 2, 3]),
    }
    valid_actions = {'e': e_actions, 'd': d_actions}

    # One grammar state per group element; all of them share `valid_actions`.
    grammar_state = [
        GrammarState([nonterminal], {}, valid_actions, {}, is_nonterminal)
        for nonterminal in ['e', 'd', 'e']
    ]

    self.encoder_outputs = torch.FloatTensor([[[1, 2], [3, 4], [5, 6]],
                                              [[10, 11], [12, 13], [14, 15]]])
    self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])

    first_instance_actions = [
        ('e -> f', False, None),
        ('e -> g', True, None),
        ('e -> h', True, None),
        ('e -> i', True, None),
        ('e -> j', True, None),
    ]
    second_instance_actions = [
        ('d -> q', True, None),
        ('d -> g', True, None),
        ('d -> h', True, None),
        ('d -> i', True, None),
    ]
    self.possible_actions = [first_instance_actions, second_instance_actions]

    # One RnnState per group element; all of them share the encoder outputs and mask.
    rnn_state = [
        RnnState(hidden_state[i],
                 memory_cell[i],
                 previous_action_embedding[i],
                 attended_question[i],
                 self.encoder_outputs,
                 self.encoder_output_mask)
        for i in range(group_size)
    ]

    self.state = GrammarBasedDecoderState(batch_indices=batch_indices,
                                          action_history=action_history,
                                          score=score,
                                          rnn_state=rnn_state,
                                          grammar_state=grammar_state,
                                          possible_actions=self.possible_actions)
def test_get_valid_actions_adds_lambda_productions(self):
    """With ``'s'`` on the stack, the ``'s -> x'`` action (5) joins the regular ones."""
    nonterminal_stack = [u's']
    lambda_stacks = {(u's', u'x'): [u's']}
    valid_actions = {u's': [1, 2]}
    context_actions = {u's -> x': 5}
    state = GrammarState(nonterminal_stack, lambda_stacks, valid_actions, context_actions,
                         is_nonterminal)

    # Ask twice: the second call guards against the first one having modified the state.
    for _ in range(2):
        assert state.get_valid_actions() == [1, 2, 5]
def _create_grammar_state(self,
                          world: WikiTablesWorld,
                          possible_actions: List[ProductionRuleArray],
                          linking_scores: torch.Tensor,
                          entity_types: torch.Tensor) -> GrammarState:
    """
    This method creates the GrammarState object that's used for decoding.  Part of creating
    that is creating the `valid_actions` dictionary, which contains embedded representations
    of all of the valid actions.  So, we create that here as well.

    The inputs to this method are for a `single instance in the batch`; none of the tensors we
    create here are batched.  We grab the global action ids from the input
    ``ProductionRuleArrays``, and we use those to embed the valid actions for every
    non-terminal type.  We use the input ``linking_scores`` for non-global actions.

    A non-terminal only gets a ``'global'`` (respectively ``'linked'``) entry when it has at
    least one global (respectively linked) production.

    Parameters
    ----------
    world : ``WikiTablesWorld``
        From the input to ``forward`` for a single batch instance.
    possible_actions : ``List[ProductionRuleArray]``
        From the input to ``forward`` for a single batch instance.
    linking_scores : ``torch.Tensor``
        Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
        dimension).
    entity_types : ``torch.Tensor``
        Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).

    Returns
    -------
    ``GrammarState``
        A state whose stack holds only ``START_SYMBOL``, together with the translated valid
        actions and the context-specific (``" -> x"``) actions built here.
    """
    # Production string -> index into `possible_actions`.
    action_map = {action[0]: action_index
                  for action_index, action in enumerate(possible_actions)}
    # Entity -> index into the table graph's entity list.
    entity_map = {entity: entity_index
                  for entity_index, entity in enumerate(world.table_graph.entities)}

    valid_actions = world.get_valid_actions()
    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
    for key, action_strings in valid_actions.items():
        translated_valid_actions[key] = {}
        # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
        # productions of that non-terminal.  We'll first split those productions by global vs.
        # linked action.
        action_indices = [action_map[action_string] for action_string in action_strings]
        global_actions = []
        linked_actions = []
        for action_index in action_indices:
            production_rule_array = possible_actions[action_index]
            if production_rule_array[1]:
                global_actions.append((production_rule_array[2], action_index))
            else:
                linked_actions.append((production_rule_array[0], action_index))

        # Then we get the embedded representations of the global actions, if there are any.
        # Without this guard, a non-terminal with only linked productions would make
        # `zip(*global_actions)` yield nothing and the tuple unpacking raise a ValueError.
        if global_actions:
            global_action_tensors, global_action_ids = zip(*global_actions)
            global_action_tensor = torch.cat(global_action_tensors, dim=0)
            global_input_embeddings = self._action_embedder(global_action_tensor)
            if self._add_action_bias:
                global_action_biases = self._action_biases(global_action_tensor)
                global_input_embeddings = torch.cat([global_input_embeddings,
                                                     global_action_biases],
                                                    dim=-1)
            global_output_embeddings = self._output_action_embedder(global_action_tensor)
            translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                       global_output_embeddings,
                                                       list(global_action_ids))

        # Then the representations of the linked actions.
        if linked_actions:
            linked_rules, linked_action_ids = zip(*linked_actions)
            # The right-hand side of a linked rule is the entity it produces.
            entities = [rule.split(' -> ')[1] for rule in linked_rules]
            entity_ids = [entity_map[entity] for entity in entities]
            # (num_linked_actions, num_question_tokens)
            entity_linking_scores = linking_scores[entity_ids]
            # (num_linked_actions,)
            entity_type_tensor = entity_types[entity_ids]
            # (num_linked_actions, entity_type_embedding_dim)
            entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor)
            translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                       entity_type_embeddings,
                                                       list(linked_action_ids))

    # Lastly, we need to also create embedded representations of context-specific actions.  In
    # this case, those are only variable productions, like "r -> x".  Note that our language
    # only permits one lambda at a time, so we don't need to worry about how nested lambdas
    # might impact this.
    context_actions = {}
    for action_id, action in enumerate(possible_actions):
        if action[0].endswith(" -> x"):
            input_embedding = self._action_embedder(action[2])
            if self._add_action_bias:
                input_bias = self._action_biases(action[2])
                input_embedding = torch.cat([input_embedding, input_bias], dim=-1)
            output_embedding = self._output_action_embedder(action[2])
            context_actions[action[0]] = (input_embedding, output_embedding, action_id)

    return GrammarState([START_SYMBOL],
                        {},
                        translated_valid_actions,
                        context_actions,
                        type_declaration.is_nonterminal)