def test_get_entity_action_logits(self):
    """Verify the mask, logits and type embeddings returned for linked actions.

    Each expected logit is an attention-weighted sum of three per-token scores.
    NOTE(review): those scores presumably come from the linking scores stored on
    ``self.state``, which is built outside this test - confirm against setUp().
    """
    step = WikiTablesDecoderStep(1, 5, SimilarityFunction.from_params(Params({})), 5, 3)
    linked_actions = [[1, 2], [3, 4, 5], [6]]
    # (group_size, num_question_tokens) = (3, 3)
    question_attention = Variable(torch.Tensor([[.2, .8, 0],
                                                [.7, .1, .2],
                                                [.3, .3, .4]]))
    logits, logit_mask, entity_type_embeddings = step._get_entity_action_logits(self.state,
                                                                                linked_actions,
                                                                                question_attention)

    def to_numpy(variable):
        return variable.data.cpu().numpy()

    # Groups have 2, 3 and 1 linked actions respectively, so the mask is ragged.
    assert_almost_equal(to_numpy(logit_mask), [[1, 1, 0], [1, 1, 1], [1, 0, 0]])
    assert tuple(logits.size()) == (3, 3)

    # (group index, action slot) -> expected logit, only for unmasked slots.
    expected_logits = [
            ((0, 0), .4 * .2 + .5 * .8 + .6 * 0),
            ((0, 1), .7 * .2 + .8 * .8 + .9 * 0),
            ((1, 0), -.4 * .7 + -.5 * .1 + -.6 * .2),
            ((1, 1), -.7 * .7 + -.8 * .1 + -.9 * .2),
            ((1, 2), -1.0 * .7 + -1.1 * .1 + -1.2 * .2),
            ((2, 0), 1.0 * .3 + 1.1 * .3 + 1.2 * .4),
    ]
    for position, expected_value in expected_logits:
        assert_almost_equal(to_numpy(logits[position]), expected_value)

    # (group index, action slot) -> row of the entity type embedding matrix.
    type_matrix = to_numpy(step._entity_type_embedding.weight)
    expected_type_rows = [
            ((0, 0), 2),
            ((0, 1), 1),
            ((1, 0), 0),
            ((1, 1), 1),
            ((1, 2), 2),
            ((2, 0), 0),
    ]
    for position, type_row in expected_type_rows:
        assert_almost_equal(to_numpy(entity_type_embeddings[position]), type_matrix[type_row])
def test_get_entity_action_logits(self):
    """Verify the mask, logits and type embeddings returned for linked actions.

    Each expected logit is an attention-weighted sum of three per-token scores.
    NOTE(review): those scores presumably come from the linking scores stored on
    ``self.state``, which is built outside this test - confirm against setUp().
    """
    step = WikiTablesDecoderStep(1, 5, SimilarityFunction.from_params(Params({})), 5, 3)
    linked_actions = [[1, 2], [3, 4, 5], [6]]
    # (group_size, num_question_tokens) = (3, 3)
    question_attention = torch.Tensor([[.2, .8, 0],
                                       [.7, .1, .2],
                                       [.3, .3, .4]])
    logits, logit_mask, entity_type_embeddings = step._get_entity_action_logits(self.state,
                                                                                linked_actions,
                                                                                question_attention)

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    # Groups have 2, 3 and 1 linked actions respectively, so the mask is ragged.
    assert_almost_equal(to_numpy(logit_mask), [[1, 1, 0], [1, 1, 1], [1, 0, 0]])
    assert tuple(logits.size()) == (3, 3)

    # (group index, action slot) -> expected logit, only for unmasked slots.
    expected_logits = [
            ((0, 0), .4 * .2 + .5 * .8 + .6 * 0),
            ((0, 1), .7 * .2 + .8 * .8 + .9 * 0),
            ((1, 0), -.4 * .7 + -.5 * .1 + -.6 * .2),
            ((1, 1), -.7 * .7 + -.8 * .1 + -.9 * .2),
            ((1, 2), -1.0 * .7 + -1.1 * .1 + -1.2 * .2),
            ((2, 0), 1.0 * .3 + 1.1 * .3 + 1.2 * .4),
    ]
    for position, expected_value in expected_logits:
        assert_almost_equal(to_numpy(logits[position]), expected_value)

    # (group index, action slot) -> row of the entity type embedding matrix.
    type_matrix = to_numpy(step._entity_type_embedding.weight)
    expected_type_rows = [
            ((0, 0), 2),
            ((0, 1), 1),
            ((1, 0), 0),
            ((1, 1), 1),
            ((1, 2), 2),
            ((2, 0), 0),
    ]
    for position, type_row in expected_type_rows:
        assert_almost_equal(to_numpy(entity_type_embeddings[position]), type_matrix[type_row])
def test_get_action_embeddings(self):
    """Verify that embedded actions are looked up per group and padded with a mask."""
    all_embeddings = Variable(torch.rand(5, 4))
    self.state.action_embeddings = all_embeddings
    to_embed = [[0, 4], [1], [2, 3, 4]]
    embedded, embedded_mask = WikiTablesDecoderStep._get_action_embeddings(self.state, to_embed)

    def to_numpy(variable):
        return variable.data.cpu().numpy()

    assert_almost_equal(to_numpy(embedded_mask), [[1, 1, 0], [1, 0, 0], [1, 1, 1]])
    assert tuple(embedded.size()) == (3, 3, 4)

    # Row of `all_embeddings` expected in every (group, slot) position.  The slots that
    # are masked out above are expected to hold row 0.
    expected_row_ids = [[0, 4, 0],
                        [1, 0, 0],
                        [2, 3, 4]]
    for group_index, row_ids in enumerate(expected_row_ids):
        for slot, row_id in enumerate(row_ids):
            assert_almost_equal(to_numpy(embedded[group_index, slot]),
                                to_numpy(all_embeddings[row_id]))
def test_get_actions_to_consider(self):
    """Verify the split of valid actions into embedded and linked actions, plus padding."""
    # pylint: disable=protected-access
    valid_actions_per_group = [{'e': [0, 1, 2, 4]},
                               {'e': [0, 1, 3]},
                               {'e': [2, 3, 4]}]
    for group_index, valid_actions in enumerate(valid_actions_per_group):
        self.state.grammar_state[group_index] = GrammarState(['e'], {}, valid_actions, {}, is_nonterminal)

    # We're making a bunch of the actions linked actions here, pretending that there are
    # only two global actions.  A value of -1 marks an action without a global index.
    self.state.action_indices = {
            (0, 0): 1,
            (0, 1): 0,
            (0, 2): -1,
            (0, 3): -1,
            (0, 4): -1,
            (1, 0): -1,
            (1, 1): 0,
            (1, 2): -1,
            (1, 3): -1,
    }

    actions_considered, actions_to_embed, actions_to_link = \
            WikiTablesDecoderStep._get_actions_to_consider(self.state)

    # These are _global_ action indices.  They come from actions
    # [[(0, 0), (0, 1)], [(1, 1)], []].
    assert actions_to_embed == [[1, 0], [0], []]

    # These are _batch_ action indices with a _global_ action index of -1.  They come
    # from actions [[(0, 2), (0, 4)], [(1, 0), (1, 3)], [(0, 2), (0, 3), (0, 4)]].
    assert actions_to_link == [[2, 4], [0, 3], [2, 3, 4]]

    # These are _batch_ action indices, with padding in between the embedded actions and
    # the linked actions (and after the linked actions, if necessary).
    assert actions_considered == [[0, 1, 2, 4, -1],
                                  [1, -1, 0, 3, -1],
                                  [-1, -1, 2, 3, 4]]
def test_get_action_embeddings(self):
    """Verify that embedded actions are looked up per group and padded with a mask."""
    all_embeddings = torch.rand(5, 4)
    self.state.action_embeddings = all_embeddings
    self.state.output_action_embeddings = all_embeddings
    self.state.action_biases = torch.rand(5, 1)
    to_embed = [[0, 4], [1], [2, 3, 4]]
    # Only the first and last of the four returned values are checked here.
    embedded, _, _, embedded_mask = WikiTablesDecoderStep._get_action_embeddings(self.state, to_embed)

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    assert_almost_equal(to_numpy(embedded_mask), [[1, 1, 0], [1, 0, 0], [1, 1, 1]])
    assert tuple(embedded.size()) == (3, 3, 4)

    # Row of `all_embeddings` expected in every (group, slot) position.  The slots that
    # are masked out above are expected to hold row 0.
    expected_row_ids = [[0, 4, 0],
                        [1, 0, 0],
                        [2, 3, 4]]
    for group_index, row_ids in enumerate(expected_row_ids):
        for slot, row_id in enumerate(row_ids):
            assert_almost_equal(to_numpy(embedded[group_index, slot]),
                                to_numpy(all_embeddings[row_id]))
def test_get_actions_to_consider_returns_none_if_no_linked_actions(self):
    """Verify that `to_link` is None when every valid action is an embedded action."""
    # pylint: disable=protected-access
    valid_actions_per_group = [{'e': [0, 1, 2, 4]},
                               {'e': [0, 1, 3]},
                               {'e': [2, 3, 4]}]
    for group_index, valid_actions in enumerate(valid_actions_per_group):
        self.state.grammar_state[group_index] = GrammarState(['e'], {}, valid_actions, {}, is_nonterminal)

    actions_considered, actions_to_embed, actions_to_link = \
            WikiTablesDecoderStep._get_actions_to_consider(self.state)

    # These are _global_ action indices.  All of the actions in this case are embedded,
    # so this is just a mapping from the valid actions above to their global ids.
    assert actions_to_embed == [[1, 0, 2, 5], [4, 0, 3], [2, 3, 5]]

    # There are no linked actions (all of them are embedded), so this should be None.
    assert actions_to_link is None

    # These are _batch_ action indices, with padding in between the embedded actions and
    # the linked actions.  Because there are no linked actions, this is basically just
    # the valid_actions for each group element padded with -1s.
    assert actions_considered == [[0, 1, 2, 4],
                                  [0, 1, 3, -1],
                                  [2, 3, 4, -1]]
def __init__(self,
             vocab,
             question_embedder,
             action_embedding_dim,
             encoder,
             entity_encoder,
             decoder_beam_search,
             max_decoding_steps,
             attention,
             mixture_feedforward=None,
             training_beam_size=None,
             use_neighbor_similarity_for_linking=False,
             dropout=0.0,
             num_linking_features=10,
             rule_namespace=u'rule_labels',
             tables_directory=u'/wikitables/'):
    """Set up the parser: base-class state, beam search, MML trainer and decoder step.

    Most arguments are forwarded unchanged to the base-class constructor.  The ones
    consumed here are ``decoder_beam_search`` (stored for later use),
    ``training_beam_size`` (handed to ``MaximumMarginalLikelihood``) and
    ``attention`` / ``mixture_feedforward`` (handed to ``WikiTablesDecoderStep``).
    ``action_embedding_dim`` and ``dropout`` are used both by the base class and by
    the decoder step.
    """
    # Everything the base class needs, passed by keyword.
    base_arguments = dict(vocab=vocab,
                          question_embedder=question_embedder,
                          action_embedding_dim=action_embedding_dim,
                          encoder=encoder,
                          entity_encoder=entity_encoder,
                          max_decoding_steps=max_decoding_steps,
                          use_neighbor_similarity_for_linking=use_neighbor_similarity_for_linking,
                          dropout=dropout,
                          num_linking_features=num_linking_features,
                          rule_namespace=rule_namespace,
                          tables_directory=tables_directory)
    super(WikiTablesMmlSemanticParser, self).__init__(**base_arguments)

    self._beam_search = decoder_beam_search
    self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
    # `_encoder`, `_num_start_types` and `_num_entity_types` are not set in this
    # method; they are expected to exist after the base-class constructor has run.
    self._decoder_step = WikiTablesDecoderStep(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            num_start_types=self._num_start_types,
            num_entity_types=self._num_entity_types,
            mixture_feedforward=mixture_feedforward,
            dropout=dropout)
def __init__(self,
             vocab: Vocabulary,
             question_embedder: TextFieldEmbedder,
             action_embedding_dim: int,
             encoder: Seq2SeqEncoder,
             entity_encoder: Seq2VecEncoder,
             mixture_feedforward: FeedForward,
             decoder_beam_search: BeamSearch,
             max_decoding_steps: int,
             attention_function: SimilarityFunction,
             training_beam_size: int = None,
             use_neighbor_similarity_for_linking: bool = False,
             dropout: float = 0.0,
             num_linking_features: int = 10,
             rule_namespace: str = 'rule_labels',
             tables_directory: str = '/wikitables/') -> None:
    """Set up the parser: base-class state, beam search, MML trainer and decoder step.

    Most arguments are forwarded unchanged to the base-class constructor.  The ones
    consumed here are ``decoder_beam_search`` (stored for later use),
    ``training_beam_size`` (handed to ``MaximumMarginalLikelihood``) and
    ``attention_function`` / ``mixture_feedforward`` (handed to
    ``WikiTablesDecoderStep``).  ``action_embedding_dim`` and ``dropout`` are used
    both by the base class and by the decoder step.
    """
    super().__init__(
            vocab=vocab,
            question_embedder=question_embedder,
            action_embedding_dim=action_embedding_dim,
            encoder=encoder,
            entity_encoder=entity_encoder,
            max_decoding_steps=max_decoding_steps,
            use_neighbor_similarity_for_linking=use_neighbor_similarity_for_linking,
            dropout=dropout,
            num_linking_features=num_linking_features,
            rule_namespace=rule_namespace,
            tables_directory=tables_directory)

    self._beam_search = decoder_beam_search
    self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
    # `_encoder`, `_num_start_types` and `_num_entity_types` are not set in this
    # method; they are expected to exist after the base-class constructor has run.
    self._decoder_step = WikiTablesDecoderStep(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            attention_function=attention_function,
            num_start_types=self._num_start_types,
            num_entity_types=self._num_entity_types,
            mixture_feedforward=mixture_feedforward,
            dropout=dropout)
def __init__(self,
             vocab: Vocabulary,
             question_embedder: TextFieldEmbedder,
             action_embedding_dim: int,
             encoder: Seq2SeqEncoder,
             entity_encoder: Seq2VecEncoder,
             mixture_feedforward: FeedForward,
             input_attention: Attention,
             decoder_beam_size: int,
             decoder_num_finished_states: int,
             max_decoding_steps: int,
             normalize_beam_score_by_length: bool = False,
             checklist_cost_weight: float = 0.6,
             use_neighbor_similarity_for_linking: bool = False,
             dropout: float = 0.0,
             num_linking_features: int = 10,
             rule_namespace: str = 'rule_labels',
             tables_directory: str = '/wikitables/',
             initial_mml_model_file: str = None) -> None:
    """Set up the parser: base-class state, ERM trainer, decoder step and metrics.

    Besides forwarding the shared arguments to the base-class constructor, this
    collects the global indices of all terminal productions in ``rule_namespace``
    and passes them to the decoder step, and optionally initialises weights from
    the archive at ``initial_mml_model_file`` when that file exists.
    """
    super().__init__(
            vocab=vocab,
            question_embedder=question_embedder,
            action_embedding_dim=action_embedding_dim,
            encoder=encoder,
            entity_encoder=entity_encoder,
            max_decoding_steps=max_decoding_steps,
            use_neighbor_similarity_for_linking=use_neighbor_similarity_for_linking,
            dropout=dropout,
            num_linking_features=num_linking_features,
            rule_namespace=rule_namespace,
            tables_directory=tables_directory)

    # Not sure why mypy needs a type annotation for this!
    self._decoder_trainer: ExpectedRiskMinimization = ExpectedRiskMinimization(
            beam_size=decoder_beam_size,
            normalize_by_length=normalize_beam_score_by_length,
            max_decoding_steps=self._max_decoding_steps,
            max_num_finished_states=decoder_num_finished_states)

    # A production whose right-hand side is a key of `types.COMMON_NAME_MAPPING` is a
    # terminal production; remember the global index of each of those.
    rule_vocabulary = self.vocab.get_token_to_index_vocabulary(rule_namespace)
    terminal_indices = [rule_index
                        for rule, rule_index in rule_vocabulary.items()
                        if rule.split(" -> ")[1] in types.COMMON_NAME_MAPPING]
    self._num_unlinked_terminals = len(terminal_indices)

    self._decoder_step = WikiTablesDecoderStep(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            num_start_types=self._num_start_types,
            num_entity_types=self._num_entity_types,
            mixture_feedforward=mixture_feedforward,
            dropout=dropout,
            unlinked_terminal_indices=terminal_indices)
    self._checklist_cost_weight = checklist_cost_weight
    self._agenda_coverage = Average()

    # TODO (pradeep): Checking whether file exists here to avoid raising an error when
    # we've copied a trained ERM model from a different machine and the original MML
    # model that was used to initialize it does not exist on the current machine.  This
    # may not be the best solution for the problem.
    if initial_mml_model_file is not None and os.path.isfile(initial_mml_model_file):
        self._initialize_weights_from_archive(load_archive(initial_mml_model_file))
    elif initial_mml_model_file is not None:
        # A model file is passed, but it does not exist.  This is expected to happen
        # when you're using a trained ERM model to decode.  But it may also happen if
        # the path to the file is really just incorrect.  So throwing a warning.
        logger.warning(
                "MML model file for initializing weights is passed, but does not exist."
                " This is fine if you're just decoding.")
def test_compute_new_states_with_no_action_constraints(self):
    """Check `_compute_new_states` when `allowed_actions` is None.

    Mirrors `test_compute_new_states`, but without restricting the actions; this is
    the behavior we need at test time.
    """
    # pylint: disable=protected-access
    action_log_probs = Variable(torch.FloatTensor([[.1, .9, -.1, .2],
                                                   [.3, 1.1, .1, .8],
                                                   [.1, .25, .3, .4]]))
    actions_considered = [[0, 1, 2, 3], [0, -1, 3, -1], [0, 2, 4, -1]]
    action_embeddings_for_step = torch.FloatTensor([[[1, 1], [9, 9], [2, 2], [3, 3]],
                                                    [[4, 4], [9, 9], [3, 3], [9, 9]],
                                                    [[1, 1], [2, 2], [5, 5], [9, 9]]])
    group_size = len(actions_considered)

    def numbered_rows():
        # Row i holds [i + 1, i + 1], so a group element is recognisable by its values.
        return torch.FloatTensor([[i + 1, i + 1] for i in range(group_size)])

    hidden_state = numbered_rows()
    memory_cell = numbered_rows()
    attended_question = numbered_rows()
    attention_weights = numbered_rows()

    resulting_states = WikiTablesDecoderStep._compute_new_states(self.state,
                                                                 action_log_probs,
                                                                 hidden_state,
                                                                 memory_cell,
                                                                 action_embeddings_for_step,
                                                                 attended_question,
                                                                 attention_weights,
                                                                 actions_considered,
                                                                 allowed_actions=None,
                                                                 max_actions=1)
    assert len(resulting_states) == 2

    def check_state(state, batch_index, score, history, stack, rnn_values, previous_embedding):
        assert state.batch_indices == [batch_index]
        assert_almost_equal(state.score[0].data.cpu().numpy().tolist(), score)
        # Per the original comments, the prior action history comes from setUp().
        assert state.action_history == [history]
        assert state.grammar_state[0]._nonterminal_stack == stack
        # All of these values come from the tensors instantiated directly above.
        rnn_state = state.rnn_state[0]
        assert_almost_equal(rnn_state.hidden_state.cpu().numpy().tolist(), rnn_values)
        assert_almost_equal(rnn_state.memory_cell.cpu().numpy().tolist(), rnn_values)
        assert_almost_equal(rnn_state.previous_action_embedding.cpu().numpy().tolist(),
                            previous_embedding)
        assert_almost_equal(rnn_state.attended_input.cpu().numpy().tolist(), rnn_values)
        # And these should just be copied from the prior state.
        assert_almost_equal(rnn_state.encoder_outputs.cpu().numpy(),
                            self.encoder_outputs.cpu().numpy())
        assert_almost_equal(rnn_state.encoder_output_mask.data.cpu().numpy(),
                            self.encoder_output_mask.data.cpu().numpy())
        assert_almost_equal(state.action_embeddings.cpu().numpy(),
                            self.action_embeddings.cpu().numpy())
        assert state.action_indices == self.action_indices
        assert state.possible_actions == self.possible_actions

    # For batch instance 0, we should have selected action 1 from group index 0; the
    # nonterminal stack after that action is expected to be ['g'].
    check_state(resulting_states[0], 0, [.9], [1, 1], ['g'], [1, 1], [9, 9])
    # For batch instance 1, we should have selected action 0 from group index 1; the
    # nonterminal stack after that action is expected to be ['q'].
    check_state(resulting_states[1], 1, [.3], [3, 4, 0], ['q'], [2, 2], [4, 4])
def __init__(self,
             vocab: Vocabulary,
             question_embedder: TextFieldEmbedder,
             action_embedding_dim: int,
             encoder: Seq2SeqEncoder,
             entity_encoder: Seq2VecEncoder,
             mixture_feedforward: FeedForward,
             max_decoding_steps: int,
             attention_function: SimilarityFunction,
             use_neighbor_similarity_for_linking: bool = False,
             dropout: float = 0.0,
             num_linking_features: int = 10,
             rule_namespace: str = 'rule_labels',
             tables_directory: str = '/wikitables/') -> None:
    """Build the embedders, linking layers, metrics and decoder step of the parser.

    Parameters
    ----------
    vocab : passed to the base-class constructor; also used to size the action
        embeddings from the number of entries in ``rule_namespace``.
    question_embedder : stored; its output dim sizes the type / neighbor layers.
    action_embedding_dim : width of the input and output action embeddings.
    encoder : stored; its output dim sizes the first attended question vector.
    entity_encoder : wrapped in ``TimeDistributed``; its output dim must match the
        question embedder's output dim.
    mixture_feedforward, attention_function : handed to ``WikiTablesDecoderStep``.
    max_decoding_steps : stored for use outside this method.
    use_neighbor_similarity_for_linking : when true, two extra ``Linear(1, 1)``
        layers are created; otherwise they are ``None``.
    dropout : a ``Dropout`` module is created when > 0, else an identity function.
    num_linking_features : size of the linking feature layer; no layer when <= 0.
    rule_namespace : vocabulary namespace of the production rules.
    tables_directory : handed to ``WikiTablesAccuracy``.
    """
    super(WikiTablesSemanticParser, self).__init__(vocab)
    self._question_embedder = question_embedder
    self._encoder = encoder
    self._entity_encoder = TimeDistributed(entity_encoder)
    self._max_decoding_steps = max_decoding_steps
    self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
    if dropout > 0:
        self._dropout = torch.nn.Dropout(p=dropout)
    else:
        # Identity stand-in so callers can apply `self._dropout` unconditionally.
        self._dropout = lambda x: x
    self._rule_namespace = rule_namespace
    self._denotation_accuracy = WikiTablesAccuracy(tables_directory)
    self._action_sequence_accuracy = Average()
    self._has_logical_form = Average()

    self._action_padding_index = -1  # the padding value used by IndexField
    num_actions = vocab.get_vocab_size(self._rule_namespace)
    self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
    self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
    self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1)

    # This is what we pass as input in the first step of decoding, when we don't have a
    # previous action, or a previous question attention.
    self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
    self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
    # `torch.FloatTensor(n)` is uninitialised memory, so both parameters must be filled
    # before use.  `normal_` is the in-place initialiser; the un-suffixed `normal` is a
    # deprecated alias for it.
    torch.nn.init.normal_(self._first_action_embedding)
    torch.nn.init.normal_(self._first_attended_question)

    check_dimensions_match(entity_encoder.get_output_dim(), question_embedder.get_output_dim(),
                           "entity word average embedding dim", "question embedding dim")

    self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
    self._num_start_types = 5  # TODO(mattg): get this in a more principled way somehow?
    self._embedding_dim = question_embedder.get_output_dim()
    self._type_params = torch.nn.Linear(self._num_entity_types, self._embedding_dim)
    self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)

    if num_linking_features > 0:
        self._linking_params = torch.nn.Linear(num_linking_features, 1)
    else:
        self._linking_params = None

    if self._use_neighbor_similarity_for_linking:
        self._question_entity_params = torch.nn.Linear(1, 1)
        self._question_neighbor_params = torch.nn.Linear(1, 1)
    else:
        self._question_entity_params = None
        self._question_neighbor_params = None

    self._decoder_step = WikiTablesDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                               action_embedding_dim=action_embedding_dim,
                                               attention_function=attention_function,
                                               num_start_types=self._num_start_types,
                                               num_entity_types=self._num_entity_types,
                                               mixture_feedforward=mixture_feedforward,
                                               dropout=dropout)
def test_compute_new_states(self):
    """Check `_compute_new_states` when each group element has a set of allowed actions."""
    # pylint: disable=protected-access
    action_log_probs = Variable(torch.FloatTensor([[.1, .9, -.1, .2],
                                                   [.3, 1.1, .1, .8],
                                                   [.1, .25, .3, .4]]))
    actions_considered = [[0, 1, 2, 3], [0, -1, 3, -1], [0, 2, 4, -1]]
    actions_allowed = [{2, 3}, {0}, {4}]
    action_embeddings_for_step = torch.FloatTensor([[[1, 1], [9, 9], [2, 2], [3, 3]],
                                                    [[4, 4], [9, 9], [3, 3], [9, 9]],
                                                    [[1, 1], [2, 2], [5, 5], [9, 9]]])
    group_size = len(actions_allowed)

    def numbered_rows():
        # Row i holds [i + 1, i + 1], so a group element is recognisable by its values.
        return torch.FloatTensor([[i + 1, i + 1] for i in range(group_size)])

    hidden_state = numbered_rows()
    memory_cell = numbered_rows()
    attended_question = numbered_rows()
    attention_weights = numbered_rows()

    resulting_states = WikiTablesDecoderStep._compute_new_states(self.state,
                                                                 action_log_probs,
                                                                 hidden_state,
                                                                 memory_cell,
                                                                 action_embeddings_for_step,
                                                                 attended_question,
                                                                 attention_weights,
                                                                 actions_considered,
                                                                 actions_allowed,
                                                                 1)
    assert len(resulting_states) == 2

    def check_state(state, batch_index, history, score, stack, rnn_values, previous_embedding):
        assert state.batch_indices == [batch_index]
        # Per the original comments, the prior action history and the initial score
        # come from what's defined in setUp().
        assert state.action_history == [history]
        assert_almost_equal(state.score[0].data.cpu().numpy().tolist(), score)
        assert state.grammar_state[0]._nonterminal_stack == stack
        # All of these values come from the tensors instantiated directly above.
        rnn_state = state.rnn_state[0]
        assert_almost_equal(rnn_state.hidden_state.cpu().numpy().tolist(), rnn_values)
        assert_almost_equal(rnn_state.memory_cell.cpu().numpy().tolist(), rnn_values)
        assert_almost_equal(rnn_state.previous_action_embedding.cpu().numpy().tolist(),
                            previous_embedding)
        assert_almost_equal(rnn_state.attended_input.cpu().numpy().tolist(), rnn_values)
        # And these should just be copied from the prior state.
        assert_almost_equal(rnn_state.encoder_outputs.cpu().numpy(),
                            self.encoder_outputs.cpu().numpy())
        assert_almost_equal(rnn_state.encoder_output_mask.data.cpu().numpy(),
                            self.encoder_output_mask.data.cpu().numpy())
        assert_almost_equal(state.action_embeddings.cpu().numpy(),
                            self.action_embeddings.cpu().numpy())
        assert state.action_indices == self.action_indices
        assert state.possible_actions == self.possible_actions

    # For batch instance 0, we should have selected action 4 from group index 2: empty
    # prior history, initial score 2.2, nonterminal stack ['j'] after the action.
    check_state(resulting_states[0], 0, [4], [2.2 + .3], ['j'], [3, 3], [5, 5])
    # For batch instance 1, we should have selected action 0 from group index 1: prior
    # history [3, 4], initial score 1.1, nonterminal stack ['q'] after the action.
    check_state(resulting_states[1], 1, [3, 4, 0], [1.1 + .3], ['q'], [2, 2], [4, 4])
def test_compute_new_states_with_no_action_constraints(self):
    """Check `_compute_new_states` when `allowed_actions` is None.

    Mirrors `test_compute_new_states`, but without restricting the actions; this is
    the behavior we need at test time.
    """
    # pylint: disable=protected-access
    action_log_probs = torch.FloatTensor([[.1, .9, -.1, .2],
                                          [.3, 1.1, .1, .8],
                                          [.1, .25, .3, .4]])
    actions_considered = [[0, 1, 2, 3], [0, -1, 3, -1], [0, 2, 4, -1]]
    action_embeddings_for_step = torch.FloatTensor([[[1, 1], [9, 9], [2, 2], [3, 3]],
                                                    [[4, 4], [9, 9], [3, 3], [9, 9]],
                                                    [[1, 1], [2, 2], [5, 5], [9, 9]]])
    group_size = len(actions_considered)

    def numbered_rows():
        # Row i holds [i + 1, i + 1], so a group element is recognisable by its values.
        return torch.FloatTensor([[i + 1, i + 1] for i in range(group_size)])

    hidden_state = numbered_rows()
    memory_cell = numbered_rows()
    attended_question = numbered_rows()
    attention_weights = numbered_rows()

    resulting_states = WikiTablesDecoderStep._compute_new_states(self.state,
                                                                 action_log_probs,
                                                                 hidden_state,
                                                                 memory_cell,
                                                                 action_embeddings_for_step,
                                                                 attended_question,
                                                                 attention_weights,
                                                                 actions_considered,
                                                                 allowed_actions=None,
                                                                 max_actions=1)
    assert len(resulting_states) == 2

    def check_state(state, batch_index, score, history, stack, rnn_values, previous_embedding):
        assert state.batch_indices == [batch_index]
        assert_almost_equal(state.score[0].detach().cpu().numpy().tolist(), score)
        # Per the original comments, the prior action history comes from setUp().
        assert state.action_history == [history]
        assert state.grammar_state[0]._nonterminal_stack == stack
        # All of these values come from the tensors instantiated directly above.
        rnn_state = state.rnn_state[0]
        assert_almost_equal(rnn_state.hidden_state.cpu().numpy().tolist(), rnn_values)
        assert_almost_equal(rnn_state.memory_cell.cpu().numpy().tolist(), rnn_values)
        assert_almost_equal(rnn_state.previous_action_embedding.cpu().numpy().tolist(),
                            previous_embedding)
        assert_almost_equal(rnn_state.attended_input.cpu().numpy().tolist(), rnn_values)
        # And these should just be copied from the prior state.
        assert_almost_equal(rnn_state.encoder_outputs.cpu().numpy(),
                            self.encoder_outputs.cpu().numpy())
        assert_almost_equal(rnn_state.encoder_output_mask.detach().cpu().numpy(),
                            self.encoder_output_mask.detach().cpu().numpy())
        assert_almost_equal(state.action_embeddings.cpu().numpy(),
                            self.action_embeddings.cpu().numpy())
        assert state.action_indices == self.action_indices
        assert state.possible_actions == self.possible_actions

    # For batch instance 0, we should have selected action 1 from group index 0; the
    # nonterminal stack after that action is expected to be ['g'].
    check_state(resulting_states[0], 0, [.9], [1, 1], ['g'], [1, 1], [9, 9])
    # For batch instance 1, we should have selected action 0 from group index 1; the
    # nonterminal stack after that action is expected to be ['q'].
    check_state(resulting_states[1], 1, [.3], [3, 4, 0], ['q'], [2, 2], [4, 4])