Code example #1
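This version is written against the pre-0.4 PyTorch API: tensors are wrapped in Variable and raw values are read through .data.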
    def test_get_entity_action_logits(self):
        decoder_step = WikiTablesDecoderStep(1, 5, SimilarityFunction.from_params(Params({})), 5, 3)
        actions_to_link = [[1, 2], [3, 4, 5], [6]]
        # (group_size, num_question_tokens) = (3, 3)
        attention_weights = Variable(torch.Tensor([[.2, .8, 0],
                                                   [.7, .1, .2],
                                                   [.3, .3, .4]]))
        action_logits, mask, type_embeddings = decoder_step._get_entity_action_logits(self.state,
                                                                                      actions_to_link,
                                                                                      attention_weights)
        assert_almost_equal(mask.data.cpu().numpy(), [[1, 1, 0], [1, 1, 1], [1, 0, 0]])

        assert tuple(action_logits.size()) == (3, 3)
        assert_almost_equal(action_logits[0, 0].data.cpu().numpy(), .4 * .2 + .5 * .8 + .6 * 0)
        assert_almost_equal(action_logits[0, 1].data.cpu().numpy(), .7 * .2 + .8 * .8 + .9 * 0)
        assert_almost_equal(action_logits[1, 0].data.cpu().numpy(), -.4 * .7 + -.5 * .1 + -.6 * .2)
        assert_almost_equal(action_logits[1, 1].data.cpu().numpy(), -.7 * .7 + -.8 * .1 + -.9 * .2)
        assert_almost_equal(action_logits[1, 2].data.cpu().numpy(), -1.0 * .7 + -1.1 * .1 + -1.2 * .2)
        assert_almost_equal(action_logits[2, 0].data.cpu().numpy(), 1.0 * .3 + 1.1 * .3 + 1.2 * .4)

        embedding_matrix = decoder_step._entity_type_embedding.weight.data.cpu().numpy()
        assert_almost_equal(type_embeddings[0, 0].data.cpu().numpy(), embedding_matrix[2])
        assert_almost_equal(type_embeddings[0, 1].data.cpu().numpy(), embedding_matrix[1])
        assert_almost_equal(type_embeddings[1, 0].data.cpu().numpy(), embedding_matrix[0])
        assert_almost_equal(type_embeddings[1, 1].data.cpu().numpy(), embedding_matrix[1])
        assert_almost_equal(type_embeddings[1, 2].data.cpu().numpy(), embedding_matrix[2])
        assert_almost_equal(type_embeddings[2, 0].data.cpu().numpy(), embedding_matrix[0])
Code example #2
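The same test updated for PyTorch 0.4+, where Variable is gone and .detach() replaces .data for reading values without tracking gradients.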
    def test_get_entity_action_logits(self):
        decoder_step = WikiTablesDecoderStep(1, 5, SimilarityFunction.from_params(Params({})), 5, 3)
        actions_to_link = [[1, 2], [3, 4, 5], [6]]
        # (group_size, num_question_tokens) = (3, 3)
        attention_weights = torch.Tensor([[.2, .8, 0],
                                          [.7, .1, .2],
                                          [.3, .3, .4]])
        action_logits, mask, type_embeddings = decoder_step._get_entity_action_logits(self.state,
                                                                                      actions_to_link,
                                                                                      attention_weights)
        assert_almost_equal(mask.detach().cpu().numpy(), [[1, 1, 0], [1, 1, 1], [1, 0, 0]])

        assert tuple(action_logits.size()) == (3, 3)
        assert_almost_equal(action_logits[0, 0].detach().cpu().numpy(), .4 * .2 + .5 * .8 + .6 * 0)
        assert_almost_equal(action_logits[0, 1].detach().cpu().numpy(), .7 * .2 + .8 * .8 + .9 * 0)
        assert_almost_equal(action_logits[1, 0].detach().cpu().numpy(), -.4 * .7 + -.5 * .1 + -.6 * .2)
        assert_almost_equal(action_logits[1, 1].detach().cpu().numpy(), -.7 * .7 + -.8 * .1 + -.9 * .2)
        assert_almost_equal(action_logits[1, 2].detach().cpu().numpy(), -1.0 * .7 + -1.1 * .1 + -1.2 * .2)
        assert_almost_equal(action_logits[2, 0].detach().cpu().numpy(), 1.0 * .3 + 1.1 * .3 + 1.2 * .4)

        embedding_matrix = decoder_step._entity_type_embedding.weight.detach().cpu().numpy()
        assert_almost_equal(type_embeddings[0, 0].detach().cpu().numpy(), embedding_matrix[2])
        assert_almost_equal(type_embeddings[0, 1].detach().cpu().numpy(), embedding_matrix[1])
        assert_almost_equal(type_embeddings[1, 0].detach().cpu().numpy(), embedding_matrix[0])
        assert_almost_equal(type_embeddings[1, 1].detach().cpu().numpy(), embedding_matrix[1])
        assert_almost_equal(type_embeddings[1, 2].detach().cpu().numpy(), embedding_matrix[2])
        assert_almost_equal(type_embeddings[2, 0].detach().cpu().numpy(), embedding_matrix[0])
Code example #3
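The Variable-based version of the _get_action_embeddings test; an updated variant with a wider return signature appears as code example #5.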
 def test_get_action_embeddings(self):
     action_embeddings = Variable(torch.rand(5, 4))
     self.state.action_embeddings = action_embeddings
     actions_to_embed = [[0, 4], [1], [2, 3, 4]]
     embeddings, mask = WikiTablesDecoderStep._get_action_embeddings(
         self.state, actions_to_embed)
     assert_almost_equal(mask.data.cpu().numpy(),
                         [[1, 1, 0], [1, 0, 0], [1, 1, 1]])
     assert tuple(embeddings.size()) == (3, 3, 4)
     assert_almost_equal(embeddings[0, 0].data.cpu().numpy(),
                         action_embeddings[0].data.cpu().numpy())
     assert_almost_equal(embeddings[0, 1].data.cpu().numpy(),
                         action_embeddings[4].data.cpu().numpy())
     assert_almost_equal(embeddings[0, 2].data.cpu().numpy(),
                         action_embeddings[0].data.cpu().numpy())
     assert_almost_equal(embeddings[1, 0].data.cpu().numpy(),
                         action_embeddings[1].data.cpu().numpy())
     assert_almost_equal(embeddings[1, 1].data.cpu().numpy(),
                         action_embeddings[0].data.cpu().numpy())
     assert_almost_equal(embeddings[1, 2].data.cpu().numpy(),
                         action_embeddings[0].data.cpu().numpy())
     assert_almost_equal(embeddings[2, 0].data.cpu().numpy(),
                         action_embeddings[2].data.cpu().numpy())
     assert_almost_equal(embeddings[2, 1].data.cpu().numpy(),
                         action_embeddings[3].data.cpu().numpy())
     assert_almost_equal(embeddings[2, 2].data.cpu().numpy(),
                         action_embeddings[4].data.cpu().numpy())
Code example #4
    def test_get_actions_to_consider(self):
        # pylint: disable=protected-access
        valid_actions_1 = {'e': [0, 1, 2, 4]}
        valid_actions_2 = {'e': [0, 1, 3]}
        valid_actions_3 = {'e': [2, 3, 4]}
        self.state.grammar_state[0] = GrammarState(['e'], {}, valid_actions_1, {}, is_nonterminal)
        self.state.grammar_state[1] = GrammarState(['e'], {}, valid_actions_2, {}, is_nonterminal)
        self.state.grammar_state[2] = GrammarState(['e'], {}, valid_actions_3, {}, is_nonterminal)

        # We're making a bunch of the actions linked actions here, pretending that there are only
        # two global actions.
        self.state.action_indices = {
                (0, 0): 1,
                (0, 1): 0,
                (0, 2): -1,
                (0, 3): -1,
                (0, 4): -1,
                (1, 0): -1,
                (1, 1): 0,
                (1, 2): -1,
                (1, 3): -1,
                }

        considered, to_embed, to_link = WikiTablesDecoderStep._get_actions_to_consider(self.state)
        # These are _global_ action indices.  They come from actions [[(0, 0), (0, 1)], [(1, 1)], []].
        expected_to_embed = [[1, 0], [0], []]
        assert to_embed == expected_to_embed
        # These are _batch_ action indices with a _global_ action index of -1.
        # They come from actions [[(0, 2), (0, 4)], [(1, 0), (1, 3)], [(0, 2), (0, 3), (0, 4)]].
        expected_to_link = [[2, 4], [0, 3], [2, 3, 4]]
        assert to_link == expected_to_link
        # These are _batch_ action indices, with padding in between the embedded actions and the
        # linked actions (and after the linked actions, if necessary).
        expected_considered = [[0, 1, 2, 4, -1], [1, -1, 0, 3, -1], [-1, -1, 2, 3, 4]]
        assert considered == expected_considered
Code example #5
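Compared with code example #3, _get_action_embeddings now returns four values (the output embeddings and biases are ignored here), and the state additionally carries output_action_embeddings and action_biases.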
 def test_get_action_embeddings(self):
     action_embeddings = torch.rand(5, 4)
     self.state.action_embeddings = action_embeddings
     self.state.output_action_embeddings = action_embeddings
     self.state.action_biases = torch.rand(5, 1)
     actions_to_embed = [[0, 4], [1], [2, 3, 4]]
     embeddings, _, _, mask = WikiTablesDecoderStep._get_action_embeddings(self.state, actions_to_embed)
     assert_almost_equal(mask.detach().cpu().numpy(), [[1, 1, 0], [1, 0, 0], [1, 1, 1]])
     assert tuple(embeddings.size()) == (3, 3, 4)
     assert_almost_equal(embeddings[0, 0].detach().cpu().numpy(), action_embeddings[0].detach().cpu().numpy())
     assert_almost_equal(embeddings[0, 1].detach().cpu().numpy(), action_embeddings[4].detach().cpu().numpy())
     assert_almost_equal(embeddings[0, 2].detach().cpu().numpy(), action_embeddings[0].detach().cpu().numpy())
     assert_almost_equal(embeddings[1, 0].detach().cpu().numpy(), action_embeddings[1].detach().cpu().numpy())
     assert_almost_equal(embeddings[1, 1].detach().cpu().numpy(), action_embeddings[0].detach().cpu().numpy())
     assert_almost_equal(embeddings[1, 2].detach().cpu().numpy(), action_embeddings[0].detach().cpu().numpy())
     assert_almost_equal(embeddings[2, 0].detach().cpu().numpy(), action_embeddings[2].detach().cpu().numpy())
     assert_almost_equal(embeddings[2, 1].detach().cpu().numpy(), action_embeddings[3].detach().cpu().numpy())
     assert_almost_equal(embeddings[2, 2].detach().cpu().numpy(), action_embeddings[4].detach().cpu().numpy())
Code example #6
 def test_get_actions_to_consider_returns_none_if_no_linked_actions(self):
     # pylint: disable=protected-access
     valid_actions_1 = {'e': [0, 1, 2, 4]}
     valid_actions_2 = {'e': [0, 1, 3]}
     valid_actions_3 = {'e': [2, 3, 4]}
     self.state.grammar_state[0] = GrammarState(['e'], {}, valid_actions_1, {}, is_nonterminal)
     self.state.grammar_state[1] = GrammarState(['e'], {}, valid_actions_2, {}, is_nonterminal)
     self.state.grammar_state[2] = GrammarState(['e'], {}, valid_actions_3, {}, is_nonterminal)
     considered, to_embed, to_link = WikiTablesDecoderStep._get_actions_to_consider(self.state)
     # These are _global_ action indices.  All of the actions in this case are embedded, so this
     # is just a mapping from the valid actions above to their global ids.
     expected_to_embed = [[1, 0, 2, 5], [4, 0, 3], [2, 3, 5]]
     assert to_embed == expected_to_embed
     # There are no linked actions (all of them are embedded), so this should be None.
     assert to_link is None
     # These are _batch_ action indices, with padding in between the embedded actions and the
     # linked actions.  Because there are no linked actions, this is basically just the
     # valid_actions for each group element padded with -1s.
     expected_considered = [[0, 1, 2, 4], [0, 1, 3, -1], [2, 3, 4, -1]]
     assert considered == expected_considered
Code example #7
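A Python 2-compatible variant of the WikiTablesMmlSemanticParser constructor: u-prefixed string defaults, explicit super(...) arguments, and no type annotations.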
 def __init__(self,
              vocab,
              question_embedder,
              action_embedding_dim,
              encoder,
              entity_encoder,
              decoder_beam_search,
              max_decoding_steps,
              attention,
              mixture_feedforward=None,
              training_beam_size=None,
              use_neighbor_similarity_for_linking=False,
              dropout=0.0,
              num_linking_features=10,
              rule_namespace=u'rule_labels',
              tables_directory=u'/wikitables/'):
     use_similarity = use_neighbor_similarity_for_linking
     super(WikiTablesMmlSemanticParser, self).__init__(
         vocab=vocab,
         question_embedder=question_embedder,
         action_embedding_dim=action_embedding_dim,
         encoder=encoder,
         entity_encoder=entity_encoder,
         max_decoding_steps=max_decoding_steps,
         use_neighbor_similarity_for_linking=use_similarity,
         dropout=dropout,
         num_linking_features=num_linking_features,
         rule_namespace=rule_namespace,
         tables_directory=tables_directory)
     self._beam_search = decoder_beam_search
     self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
     self._decoder_step = WikiTablesDecoderStep(
         encoder_output_dim=self._encoder.get_output_dim(),
         action_embedding_dim=action_embedding_dim,
         input_attention=attention,
         num_start_types=self._num_start_types,
         num_entity_types=self._num_entity_types,
         mixture_feedforward=mixture_feedforward,
         dropout=dropout)
Code example #8
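The type-annotated Python 3 version of the same constructor. Note that it takes attention_function: SimilarityFunction, where code example #7 takes a plain attention argument and forwards it as input_attention.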
 def __init__(self,
              vocab: Vocabulary,
              question_embedder: TextFieldEmbedder,
              action_embedding_dim: int,
              encoder: Seq2SeqEncoder,
              entity_encoder: Seq2VecEncoder,
              mixture_feedforward: FeedForward,
              decoder_beam_search: BeamSearch,
              max_decoding_steps: int,
              attention_function: SimilarityFunction,
              training_beam_size: int = None,
              use_neighbor_similarity_for_linking: bool = False,
              dropout: float = 0.0,
              num_linking_features: int = 10,
              rule_namespace: str = 'rule_labels',
              tables_directory: str = '/wikitables/') -> None:
     use_similarity = use_neighbor_similarity_for_linking
     super().__init__(vocab=vocab,
                      question_embedder=question_embedder,
                      action_embedding_dim=action_embedding_dim,
                      encoder=encoder,
                      entity_encoder=entity_encoder,
                      max_decoding_steps=max_decoding_steps,
                      use_neighbor_similarity_for_linking=use_similarity,
                      dropout=dropout,
                      num_linking_features=num_linking_features,
                      rule_namespace=rule_namespace,
                      tables_directory=tables_directory)
     self._beam_search = decoder_beam_search
     self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
     self._decoder_step = WikiTablesDecoderStep(
         encoder_output_dim=self._encoder.get_output_dim(),
         action_embedding_dim=action_embedding_dim,
         attention_function=attention_function,
         num_start_types=self._num_start_types,
         num_entity_types=self._num_entity_types,
         mixture_feedforward=mixture_feedforward,
         dropout=dropout)
Code example #9
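The ERM-trained variant: the decoder trainer is ExpectedRiskMinimization rather than MaximumMarginalLikelihood, unlinked terminal productions are collected for the checklist cost, and weights can optionally be initialized from a trained MML model archive.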
 def __init__(self,
              vocab: Vocabulary,
              question_embedder: TextFieldEmbedder,
              action_embedding_dim: int,
              encoder: Seq2SeqEncoder,
              entity_encoder: Seq2VecEncoder,
              mixture_feedforward: FeedForward,
              input_attention: Attention,
              decoder_beam_size: int,
              decoder_num_finished_states: int,
              max_decoding_steps: int,
              normalize_beam_score_by_length: bool = False,
              checklist_cost_weight: float = 0.6,
              use_neighbor_similarity_for_linking: bool = False,
              dropout: float = 0.0,
              num_linking_features: int = 10,
              rule_namespace: str = 'rule_labels',
              tables_directory: str = '/wikitables/',
              initial_mml_model_file: str = None) -> None:
     use_similarity = use_neighbor_similarity_for_linking
     super().__init__(vocab=vocab,
                      question_embedder=question_embedder,
                      action_embedding_dim=action_embedding_dim,
                      encoder=encoder,
                      entity_encoder=entity_encoder,
                      max_decoding_steps=max_decoding_steps,
                      use_neighbor_similarity_for_linking=use_similarity,
                      dropout=dropout,
                      num_linking_features=num_linking_features,
                      rule_namespace=rule_namespace,
                      tables_directory=tables_directory)
     # Not sure why mypy needs a type annotation for this!
     self._decoder_trainer: ExpectedRiskMinimization = \
             ExpectedRiskMinimization(beam_size=decoder_beam_size,
                                      normalize_by_length=normalize_beam_score_by_length,
                                      max_decoding_steps=self._max_decoding_steps,
                                      max_num_finished_states=decoder_num_finished_states)
     unlinked_terminals_global_indices = []
     global_vocab = self.vocab.get_token_to_index_vocabulary(rule_namespace)
     for production, index in global_vocab.items():
         right_side = production.split(" -> ")[1]
         if right_side in types.COMMON_NAME_MAPPING:
             # This is a terminal production.
             unlinked_terminals_global_indices.append(index)
     self._num_unlinked_terminals = len(unlinked_terminals_global_indices)
     self._decoder_step = WikiTablesDecoderStep(
         encoder_output_dim=self._encoder.get_output_dim(),
         action_embedding_dim=action_embedding_dim,
         input_attention=input_attention,
         num_start_types=self._num_start_types,
         num_entity_types=self._num_entity_types,
         mixture_feedforward=mixture_feedforward,
         dropout=dropout,
         unlinked_terminal_indices=unlinked_terminals_global_indices)
     self._checklist_cost_weight = checklist_cost_weight
     self._agenda_coverage = Average()
     # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
     # copied a trained ERM model from a different machine and the original MML model that was
     # used to initialize it does not exist on the current machine. This may not be the best
     # solution for the problem.
     if initial_mml_model_file is not None:
         if os.path.isfile(initial_mml_model_file):
             archive = load_archive(initial_mml_model_file)
             self._initialize_weights_from_archive(archive)
         else:
             # A model file is passed, but it does not exist. This is expected to happen when
             # you're using a trained ERM model to decode. But it may also happen if the path to
             # the file is really just incorrect. So throwing a warning.
             logger.warning(
                 "MML model file for initializing weights is passed, but does not exist."
                 " This is fine if you're just decoding.")
Code example #10
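The pre-0.4 version of the unconstrained _compute_new_states test; the 0.4+ rewrite appears as code example #13.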
    def test_compute_new_states_with_no_action_constraints(self):
        # pylint: disable=protected-access
        # This test is basically identical to the previous one, but without specifying
        # `allowed_actions`.  This makes sure we get the right behavior at test time.
        log_probs = Variable(torch.FloatTensor([[.1, .9, -.1, .2],
                                                [.3, 1.1, .1, .8],
                                                [.1, .25, .3, .4]]))
        considered_actions = [[0, 1, 2, 3], [0, -1, 3, -1], [0, 2, 4, -1]]
        max_actions = 1
        step_action_embeddings = torch.FloatTensor([[[1, 1], [9, 9], [2, 2], [3, 3]],
                                                    [[4, 4], [9, 9], [3, 3], [9, 9]],
                                                    [[1, 1], [2, 2], [5, 5], [9, 9]]])
        new_hidden_state = torch.FloatTensor([[i + 1, i + 1] for i in range(len(considered_actions))])
        new_memory_cell = torch.FloatTensor([[i + 1, i + 1] for i in range(len(considered_actions))])
        new_attended_question = torch.FloatTensor([[i + 1, i + 1] for i in range(len(considered_actions))])
        new_attention_weights = torch.FloatTensor([[i + 1, i + 1] for i in range(len(considered_actions))])
        new_states = WikiTablesDecoderStep._compute_new_states(self.state,
                                                               log_probs,
                                                               new_hidden_state,
                                                               new_memory_cell,
                                                               step_action_embeddings,
                                                               new_attended_question,
                                                               new_attention_weights,
                                                               considered_actions,
                                                               allowed_actions=None,
                                                               max_actions=max_actions)

        assert len(new_states) == 2
        new_state = new_states[0]
        # For batch instance 0, we should have selected action 1 from group index 0.
        assert new_state.batch_indices == [0]
        assert_almost_equal(new_state.score[0].data.cpu().numpy().tolist(), [.9])
        # These two have values taken from what's defined in setUp() - the prior action history
        # ([1]) and the nonterminals corresponding to the action we picked ('j').
        assert new_state.action_history == [[1, 1]]
        assert new_state.grammar_state[0]._nonterminal_stack == ['g']
        # All of these values come from the objects instantiated directly above.
        assert_almost_equal(new_state.rnn_state[0].hidden_state.cpu().numpy().tolist(), [1, 1])
        assert_almost_equal(new_state.rnn_state[0].memory_cell.cpu().numpy().tolist(), [1, 1])
        assert_almost_equal(new_state.rnn_state[0].previous_action_embedding.cpu().numpy().tolist(), [9, 9])
        assert_almost_equal(new_state.rnn_state[0].attended_input.cpu().numpy().tolist(), [1, 1])
        # And these should just be copied from the prior state.
        assert_almost_equal(new_state.rnn_state[0].encoder_outputs.cpu().numpy(),
                            self.encoder_outputs.cpu().numpy())
        assert_almost_equal(new_state.rnn_state[0].encoder_output_mask.data.cpu().numpy(),
                            self.encoder_output_mask.data.cpu().numpy())
        assert_almost_equal(new_state.action_embeddings.cpu().numpy(),
                            self.action_embeddings.cpu().numpy())
        assert new_state.action_indices == self.action_indices
        assert new_state.possible_actions == self.possible_actions

        new_state = new_states[1]
        # For batch instance 1, we should have selected action 0 from group index 1.
        assert new_state.batch_indices == [1]
        assert_almost_equal(new_state.score[0].data.cpu().numpy().tolist(), [.3])
        # These have values taken from what's defined in setUp() - the prior action history
        # ([3, 4]) and the nonterminals corresponding to the action we picked ('q').
        assert new_state.action_history == [[3, 4, 0]]
        assert new_state.grammar_state[0]._nonterminal_stack == ['q']
        # All of these values come from the objects instantiated directly above.
        assert_almost_equal(new_state.rnn_state[0].hidden_state.cpu().numpy().tolist(), [2, 2])
        assert_almost_equal(new_state.rnn_state[0].memory_cell.cpu().numpy().tolist(), [2, 2])
        assert_almost_equal(new_state.rnn_state[0].previous_action_embedding.cpu().numpy().tolist(), [4, 4])
        assert_almost_equal(new_state.rnn_state[0].attended_input.cpu().numpy().tolist(), [2, 2])
        # And these should just be copied from the prior state.
        assert_almost_equal(new_state.rnn_state[0].encoder_outputs.cpu().numpy(),
                            self.encoder_outputs.cpu().numpy())
        assert_almost_equal(new_state.rnn_state[0].encoder_output_mask.data.cpu().numpy(),
                            self.encoder_output_mask.data.cpu().numpy())
        assert_almost_equal(new_state.action_embeddings.cpu().numpy(),
                            self.action_embeddings.cpu().numpy())
        assert new_state.action_indices == self.action_indices
        assert new_state.possible_actions == self.possible_actions
Code example #11
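The WikiTablesSemanticParser constructor itself, which sets up the question and action embedders, the learned first-step inputs, the entity-type and linking parameters, and the decoder step.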
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 mixture_feedforward: FeedForward,
                 max_decoding_steps: int,
                 attention_function: SimilarityFunction,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        super(WikiTablesSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._max_decoding_steps = max_decoding_steps
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = WikiTablesAccuracy(tables_directory)
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal(self._first_action_embedding)
        torch.nn.init.normal(self._first_attended_question)

        check_dimensions_match(entity_encoder.get_output_dim(), question_embedder.get_output_dim(),
                               "entity word average embedding dim", "question embedding dim")

        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 5  # TODO(mattg): get this in a more principled way somehow?
        self._embedding_dim = question_embedder.get_output_dim()
        self._type_params = torch.nn.Linear(self._num_entity_types, self._embedding_dim)
        self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        if self._use_neighbor_similarity_for_linking:
            self._question_entity_params = torch.nn.Linear(1, 1)
            self._question_neighbor_params = torch.nn.Linear(1, 1)
        else:
            self._question_entity_params = None
            self._question_neighbor_params = None

        self._decoder_step = WikiTablesDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                                   action_embedding_dim=action_embedding_dim,
                                                   attention_function=attention_function,
                                                   num_start_types=self._num_start_types,
                                                   num_entity_types=self._num_entity_types,
                                                   mixture_feedforward=mixture_feedforward,
                                                   dropout=dropout)
Code example #12
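Here allowed_actions constrains which of the considered actions each group element may select, as during training; the unconstrained test-time variant follows.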
    def test_compute_new_states(self):
        # pylint: disable=protected-access
        log_probs = Variable(torch.FloatTensor([[.1, .9, -.1, .2],
                                                [.3, 1.1, .1, .8],
                                                [.1, .25, .3, .4]]))
        considered_actions = [[0, 1, 2, 3], [0, -1, 3, -1], [0, 2, 4, -1]]
        allowed_actions = [{2, 3}, {0}, {4}]
        max_actions = 1
        step_action_embeddings = torch.FloatTensor([[[1, 1], [9, 9], [2, 2], [3, 3]],
                                                    [[4, 4], [9, 9], [3, 3], [9, 9]],
                                                    [[1, 1], [2, 2], [5, 5], [9, 9]]])
        new_hidden_state = torch.FloatTensor([[i + 1, i + 1] for i in range(len(allowed_actions))])
        new_memory_cell = torch.FloatTensor([[i + 1, i + 1] for i in range(len(allowed_actions))])
        new_attended_question = torch.FloatTensor([[i + 1, i + 1] for i in range(len(allowed_actions))])
        new_attention_weights = torch.FloatTensor([[i + 1, i + 1] for i in range(len(allowed_actions))])
        new_states = WikiTablesDecoderStep._compute_new_states(self.state,
                                                               log_probs,
                                                               new_hidden_state,
                                                               new_memory_cell,
                                                               step_action_embeddings,
                                                               new_attended_question,
                                                               new_attention_weights,
                                                               considered_actions,
                                                               allowed_actions,
                                                               max_actions)

        assert len(new_states) == 2
        new_state = new_states[0]
        # For batch instance 0, we should have selected action 4 from group index 2.
        assert new_state.batch_indices == [0]
        # These three have values taken from what's defined in setUp() - the prior action history
        # (empty in this case), the initial score (2.2), and the nonterminals corresponding to the
        # action we picked ('j').
        assert new_state.action_history == [[4]]
        assert_almost_equal(new_state.score[0].data.cpu().numpy().tolist(), [2.2 + .3])
        assert new_state.grammar_state[0]._nonterminal_stack == ['j']
        # All of these values come from the objects instantiated directly above.
        assert_almost_equal(new_state.rnn_state[0].hidden_state.cpu().numpy().tolist(), [3, 3])
        assert_almost_equal(new_state.rnn_state[0].memory_cell.cpu().numpy().tolist(), [3, 3])
        assert_almost_equal(new_state.rnn_state[0].previous_action_embedding.cpu().numpy().tolist(), [5, 5])
        assert_almost_equal(new_state.rnn_state[0].attended_input.cpu().numpy().tolist(), [3, 3])
        # And these should just be copied from the prior state.
        assert_almost_equal(new_state.rnn_state[0].encoder_outputs.cpu().numpy(),
                            self.encoder_outputs.cpu().numpy())
        assert_almost_equal(new_state.rnn_state[0].encoder_output_mask.data.cpu().numpy(),
                            self.encoder_output_mask.data.cpu().numpy())
        assert_almost_equal(new_state.action_embeddings.cpu().numpy(),
                            self.action_embeddings.cpu().numpy())
        assert new_state.action_indices == self.action_indices
        assert new_state.possible_actions == self.possible_actions

        new_state = new_states[1]
        # For batch instance 1, we should have selected action 0 from group index 1.
        assert new_state.batch_indices == [1]
        # These three have values taken from what's defined in setUp() - the prior action history
        # ([3, 4]), the initial score (1.1), and the nonterminals corresponding to the action we
        # picked ('q').
        assert new_state.action_history == [[3, 4, 0]]
        assert_almost_equal(new_state.score[0].data.cpu().numpy().tolist(), [1.1 + .3])
        assert new_state.grammar_state[0]._nonterminal_stack == ['q']
        # All of these values come from the objects instantiated directly above.
        assert_almost_equal(new_state.rnn_state[0].hidden_state.cpu().numpy().tolist(), [2, 2])
        assert_almost_equal(new_state.rnn_state[0].memory_cell.cpu().numpy().tolist(), [2, 2])
        assert_almost_equal(new_state.rnn_state[0].previous_action_embedding.cpu().numpy().tolist(), [4, 4])
        assert_almost_equal(new_state.rnn_state[0].attended_input.cpu().numpy().tolist(), [2, 2])
        # And these should just be copied from the prior state.
        assert_almost_equal(new_state.rnn_state[0].encoder_outputs.cpu().numpy(),
                            self.encoder_outputs.cpu().numpy())
        assert_almost_equal(new_state.rnn_state[0].encoder_output_mask.data.cpu().numpy(),
                            self.encoder_output_mask.data.cpu().numpy())
        assert_almost_equal(new_state.action_embeddings.cpu().numpy(),
                            self.action_embeddings.cpu().numpy())
        assert new_state.action_indices == self.action_indices
        assert new_state.possible_actions == self.possible_actions
Code example #13
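The PyTorch 0.4+ rewrite of code example #10: plain tensors replace Variable, and .detach() replaces .data where values are read without tracking gradients.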
    def test_compute_new_states_with_no_action_constraints(self):
        # pylint: disable=protected-access
        # This test is basically identical to the previous one, but without specifying
        # `allowed_actions`.  This makes sure we get the right behavior at test time.
        log_probs = torch.FloatTensor([[.1, .9, -.1, .2],
                                       [.3, 1.1, .1, .8],
                                       [.1, .25, .3, .4]])
        considered_actions = [[0, 1, 2, 3], [0, -1, 3, -1], [0, 2, 4, -1]]
        max_actions = 1
        step_action_embeddings = torch.FloatTensor([[[1, 1], [9, 9], [2, 2], [3, 3]],
                                                    [[4, 4], [9, 9], [3, 3], [9, 9]],
                                                    [[1, 1], [2, 2], [5, 5], [9, 9]]])
        new_hidden_state = torch.FloatTensor([[i + 1, i + 1] for i in range(len(considered_actions))])
        new_memory_cell = torch.FloatTensor([[i + 1, i + 1] for i in range(len(considered_actions))])
        new_attended_question = torch.FloatTensor([[i + 1, i + 1] for i in range(len(considered_actions))])
        new_attention_weights = torch.FloatTensor([[i + 1, i + 1] for i in range(len(considered_actions))])
        new_states = WikiTablesDecoderStep._compute_new_states(self.state,
                                                               log_probs,
                                                               new_hidden_state,
                                                               new_memory_cell,
                                                               step_action_embeddings,
                                                               new_attended_question,
                                                               new_attention_weights,
                                                               considered_actions,
                                                               allowed_actions=None,
                                                               max_actions=max_actions)

        assert len(new_states) == 2
        new_state = new_states[0]
        # For batch instance 0, we should have selected action 1 from group index 0.
        assert new_state.batch_indices == [0]
        assert_almost_equal(new_state.score[0].detach().cpu().numpy().tolist(), [.9])
        # These two have values taken from what's defined in setUp() - the prior action history
        # ([1]) and the nonterminals corresponding to the action we picked ('j').
        assert new_state.action_history == [[1, 1]]
        assert new_state.grammar_state[0]._nonterminal_stack == ['g']
        # All of these values come from the objects instantiated directly above.
        assert_almost_equal(new_state.rnn_state[0].hidden_state.cpu().numpy().tolist(), [1, 1])
        assert_almost_equal(new_state.rnn_state[0].memory_cell.cpu().numpy().tolist(), [1, 1])
        assert_almost_equal(new_state.rnn_state[0].previous_action_embedding.cpu().numpy().tolist(), [9, 9])
        assert_almost_equal(new_state.rnn_state[0].attended_input.cpu().numpy().tolist(), [1, 1])
        # And these should just be copied from the prior state.
        assert_almost_equal(new_state.rnn_state[0].encoder_outputs.cpu().numpy(),
                            self.encoder_outputs.cpu().numpy())
        assert_almost_equal(new_state.rnn_state[0].encoder_output_mask.detach().cpu().numpy(),
                            self.encoder_output_mask.detach().cpu().numpy())
        assert_almost_equal(new_state.action_embeddings.cpu().numpy(),
                            self.action_embeddings.cpu().numpy())
        assert new_state.action_indices == self.action_indices
        assert new_state.possible_actions == self.possible_actions

        new_state = new_states[1]
        # For batch instance 1, we should have selected action 0 from group index 1.
        assert new_state.batch_indices == [1]
        assert_almost_equal(new_state.score[0].detach().cpu().numpy().tolist(), [.3])
        # These have values taken from what's defined in setUp() - the prior action history
        # ([3, 4]) and the nonterminals corresponding to the action we picked ('q').
        assert new_state.action_history == [[3, 4, 0]]
        assert new_state.grammar_state[0]._nonterminal_stack == ['q']
        # All of these values come from the objects instantiated directly above.
        assert_almost_equal(new_state.rnn_state[0].hidden_state.cpu().numpy().tolist(), [2, 2])
        assert_almost_equal(new_state.rnn_state[0].memory_cell.cpu().numpy().tolist(), [2, 2])
        assert_almost_equal(new_state.rnn_state[0].previous_action_embedding.cpu().numpy().tolist(), [4, 4])
        assert_almost_equal(new_state.rnn_state[0].attended_input.cpu().numpy().tolist(), [2, 2])
        # And these should just be copied from the prior state.
        assert_almost_equal(new_state.rnn_state[0].encoder_outputs.cpu().numpy(),
                            self.encoder_outputs.cpu().numpy())
        assert_almost_equal(new_state.rnn_state[0].encoder_output_mask.detach().cpu().numpy(),
                            self.encoder_output_mask.detach().cpu().numpy())
        assert_almost_equal(new_state.action_embeddings.cpu().numpy(),
                            self.action_embeddings.cpu().numpy())
        assert new_state.action_indices == self.action_indices
        assert new_state.possible_actions == self.possible_actions