Example #1
    def test_get_entity_action_logits(self):
        # Build a decoder step with small, arbitrary dimensions for testing.
        decoder_step = WikiTablesDecoderStep(1, 5, SimilarityFunction.from_params(Params({})), 5, 3)
        actions_to_link = [[1, 2], [3, 4, 5], [6]]
        # (group_size, num_question_tokens) = (3, 3)
        attention_weights = Variable(torch.Tensor([[.2, .8, 0],
                                                   [.7, .1, .2],
                                                   [.3, .3, .4]]))
        action_logits, mask, type_embeddings = decoder_step._get_entity_action_logits(self.state,
                                                                                      actions_to_link,
                                                                                      attention_weights)
        # The ragged actions_to_link lists are padded to the longest length (3),
        # and the mask marks which entries are real actions.
        assert_almost_equal(mask.data.cpu().numpy(), [[1, 1, 0], [1, 1, 1], [1, 0, 0]])

        assert tuple(action_logits.size()) == (3, 3)
        # Each expected value below is the attention-weighted sum of one entity's
        # per-token linking scores (set up in self.state by the test fixture).
        assert_almost_equal(action_logits[0, 0].data.cpu().numpy(), .4 * .2 + .5 * .8 + .6 * 0)
        assert_almost_equal(action_logits[0, 1].data.cpu().numpy(), .7 * .2 + .8 * .8 + .9 * 0)
        assert_almost_equal(action_logits[1, 0].data.cpu().numpy(), -.4 * .7 + -.5 * .1 + -.6 * .2)
        assert_almost_equal(action_logits[1, 1].data.cpu().numpy(), -.7 * .7 + -.8 * .1 + -.9 * .2)
        assert_almost_equal(action_logits[1, 2].data.cpu().numpy(), -1.0 * .7 + -1.1 * .1 + -1.2 * .2)
        assert_almost_equal(action_logits[2, 0].data.cpu().numpy(), 1.0 * .3 + 1.1 * .3 + 1.2 * .4)

        embedding_matrix = decoder_step._entity_type_embedding.weight.data.cpu().numpy()
        assert_almost_equal(type_embeddings[0, 0].data.cpu().numpy(), embedding_matrix[2])
        assert_almost_equal(type_embeddings[0, 1].data.cpu().numpy(), embedding_matrix[1])
        assert_almost_equal(type_embeddings[1, 0].data.cpu().numpy(), embedding_matrix[0])
        assert_almost_equal(type_embeddings[1, 1].data.cpu().numpy(), embedding_matrix[1])
        assert_almost_equal(type_embeddings[1, 2].data.cpu().numpy(), embedding_matrix[2])
        assert_almost_equal(type_embeddings[2, 0].data.cpu().numpy(), embedding_matrix[0])
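
For reference, the expected values in these asserts are attention-weighted sums of
each entity's per-token linking scores. A minimal NumPy sketch of the first group's
check; the linking-score values are read off the asserts above and stand in for
what self.state holds:

    import numpy as np

    attention_weights = np.array([.2, .8, 0.])           # first group's question attention
    linking_scores = np.array([[.4, .5, .6],             # entity for action 1 vs. each token
                               [.7, .8, .9]])            # entity for action 2 vs. each token

    action_logits = linking_scores @ attention_weights   # shape: (num_actions,)
    print(action_logits)                                 # [0.48 0.78], matching the [0, 0] and [0, 1] asserts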

Example #2
 def __init__(self,
              vocab,
              question_embedder,
              action_embedding_dim,
              encoder,
              entity_encoder,
              decoder_beam_search,
              max_decoding_steps,
              attention,
              mixture_feedforward=None,
              training_beam_size=None,
              use_neighbor_similarity_for_linking=False,
              dropout=0.0,
              num_linking_features=10,
              rule_namespace=u'rule_labels',
              tables_directory=u'/wikitables/'):
     use_similarity = use_neighbor_similarity_for_linking
     super(WikiTablesMmlSemanticParser, self).__init__(
         vocab=vocab,
         question_embedder=question_embedder,
         action_embedding_dim=action_embedding_dim,
         encoder=encoder,
         entity_encoder=entity_encoder,
         max_decoding_steps=max_decoding_steps,
         use_neighbor_similarity_for_linking=use_similarity,
         dropout=dropout,
         num_linking_features=num_linking_features,
         rule_namespace=rule_namespace,
         tables_directory=tables_directory)
     self._beam_search = decoder_beam_search
     self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
     self._decoder_step = WikiTablesDecoderStep(
         encoder_output_dim=self._encoder.get_output_dim(),
         action_embedding_dim=action_embedding_dim,
         input_attention=attention,
         num_start_types=self._num_start_types,
         num_entity_types=self._num_entity_types,
         mixture_feedforward=mixture_feedforward,
         dropout=dropout)
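
The attention module wired into WikiTablesDecoderStep above produces, at each
decoding step, a normalized distribution over question tokens (the attention_weights
tensor exercised in Example #1). A generic dot-product sketch with illustrative
shapes, not necessarily what the module passed in here computes:

    import torch

    group_size, num_tokens, encoder_output_dim = 3, 6, 10
    hidden = torch.rand(group_size, encoder_output_dim)             # decoder hidden state
    encoder_outputs = torch.rand(group_size, num_tokens, encoder_output_dim)

    scores = torch.bmm(encoder_outputs, hidden.unsqueeze(-1)).squeeze(-1)
    attention_weights = torch.softmax(scores, dim=-1)               # rows sum to 1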

Example #3
 def __init__(self,
              vocab: Vocabulary,
              question_embedder: TextFieldEmbedder,
              action_embedding_dim: int,
              encoder: Seq2SeqEncoder,
              entity_encoder: Seq2VecEncoder,
              mixture_feedforward: FeedForward,
              decoder_beam_search: BeamSearch,
              max_decoding_steps: int,
              attention_function: SimilarityFunction,
              training_beam_size: int = None,
              use_neighbor_similarity_for_linking: bool = False,
              dropout: float = 0.0,
              num_linking_features: int = 10,
              rule_namespace: str = 'rule_labels',
              tables_directory: str = '/wikitables/') -> None:
     use_similarity = use_neighbor_similarity_for_linking
     super().__init__(vocab=vocab,
                      question_embedder=question_embedder,
                      action_embedding_dim=action_embedding_dim,
                      encoder=encoder,
                      entity_encoder=entity_encoder,
                      max_decoding_steps=max_decoding_steps,
                      use_neighbor_similarity_for_linking=use_similarity,
                      dropout=dropout,
                      num_linking_features=num_linking_features,
                      rule_namespace=rule_namespace,
                      tables_directory=tables_directory)
     self._beam_search = decoder_beam_search
     self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
     self._decoder_step = WikiTablesDecoderStep(
         encoder_output_dim=self._encoder.get_output_dim(),
         action_embedding_dim=action_embedding_dim,
         attention_function=attention_function,
         num_start_types=self._num_start_types,
         num_entity_types=self._num_entity_types,
         mixture_feedforward=mixture_feedforward,
         dropout=dropout)

Example #4
 def __init__(self,
              vocab: Vocabulary,
              question_embedder: TextFieldEmbedder,
              action_embedding_dim: int,
              encoder: Seq2SeqEncoder,
              entity_encoder: Seq2VecEncoder,
              mixture_feedforward: FeedForward,
              input_attention: Attention,
              decoder_beam_size: int,
              decoder_num_finished_states: int,
              max_decoding_steps: int,
              normalize_beam_score_by_length: bool = False,
              checklist_cost_weight: float = 0.6,
              use_neighbor_similarity_for_linking: bool = False,
              dropout: float = 0.0,
              num_linking_features: int = 10,
              rule_namespace: str = 'rule_labels',
              tables_directory: str = '/wikitables/',
              initial_mml_model_file: str = None) -> None:
     use_similarity = use_neighbor_similarity_for_linking
     super().__init__(vocab=vocab,
                      question_embedder=question_embedder,
                      action_embedding_dim=action_embedding_dim,
                      encoder=encoder,
                      entity_encoder=entity_encoder,
                      max_decoding_steps=max_decoding_steps,
                      use_neighbor_similarity_for_linking=use_similarity,
                      dropout=dropout,
                      num_linking_features=num_linking_features,
                      rule_namespace=rule_namespace,
                      tables_directory=tables_directory)
     # Not sure why mypy needs a type annotation for this!
     self._decoder_trainer: ExpectedRiskMinimization = \
             ExpectedRiskMinimization(beam_size=decoder_beam_size,
                                      normalize_by_length=normalize_beam_score_by_length,
                                      max_decoding_steps=self._max_decoding_steps,
                                      max_num_finished_states=decoder_num_finished_states)
     unlinked_terminals_global_indices = []
     global_vocab = self.vocab.get_token_to_index_vocabulary(rule_namespace)
     for production, index in global_vocab.items():
         right_side = production.split(" -> ")[1]
         if right_side in types.COMMON_NAME_MAPPING:
             # This is a terminal production.
             unlinked_terminals_global_indices.append(index)
     self._num_unlinked_terminals = len(unlinked_terminals_global_indices)
     self._decoder_step = WikiTablesDecoderStep(
         encoder_output_dim=self._encoder.get_output_dim(),
         action_embedding_dim=action_embedding_dim,
         input_attention=input_attention,
         num_start_types=self._num_start_types,
         num_entity_types=self._num_entity_types,
         mixture_feedforward=mixture_feedforward,
         dropout=dropout,
         unlinked_terminal_indices=unlinked_terminals_global_indices)
     self._checklist_cost_weight = checklist_cost_weight
     self._agenda_coverage = Average()
     # TODO (pradeep): We check whether the file exists to avoid raising an error when a
     # trained ERM model has been copied from a different machine and the original MML
     # model used to initialize it does not exist on the current machine. This may not be
     # the best solution to the problem.
     if initial_mml_model_file is not None:
         if os.path.isfile(initial_mml_model_file):
             archive = load_archive(initial_mml_model_file)
             self._initialize_weights_from_archive(archive)
         else:
             # A model file was passed, but it does not exist. This is expected when using a
             # trained ERM model to decode, but it may also mean the path is simply
             # incorrect, so we log a warning.
             logger.warning(
                 "MML model file for initializing weights is passed, but does not exist."
                 " This is fine if you're just decoding.")
Example #5
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 mixture_feedforward: FeedForward,
                 max_decoding_steps: int,
                 attention_function: SimilarityFunction,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        super(WikiTablesSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._max_decoding_steps = max_decoding_steps
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        # Fall back to an identity function when dropout is disabled, so the model
        # can apply self._dropout unconditionally.
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = WikiTablesAccuracy(tables_directory)
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        # Separate embeddings for actions used as decoder input and for actions
        # being scored as outputs, plus a learned per-action bias.
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        check_dimensions_match(entity_encoder.get_output_dim(), question_embedder.get_output_dim(),
                               "entity word average embedding dim", "question embedding dim")

        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 5  # TODO(mattg): get this in a more principled way somehow?
        self._embedding_dim = question_embedder.get_output_dim()
        self._type_params = torch.nn.Linear(self._num_entity_types, self._embedding_dim)
        self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        if self._use_neighbor_similarity_for_linking:
            self._question_entity_params = torch.nn.Linear(1, 1)
            self._question_neighbor_params = torch.nn.Linear(1, 1)
        else:
            self._question_entity_params = None
            self._question_neighbor_params = None

        self._decoder_step = WikiTablesDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                                   action_embedding_dim=action_embedding_dim,
                                                   attention_function=attention_function,
                                                   num_start_types=self._num_start_types,
                                                   num_entity_types=self._num_entity_types,
                                                   mixture_feedforward=mixture_feedforward,
                                                   dropout=dropout)
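
When num_linking_features > 0, the _linking_params layer defined above maps each
(entity, question-token) feature vector to a scalar linking score. A hedged sketch
with illustrative shapes:

    import torch

    num_entities, num_tokens, num_linking_features = 4, 6, 10
    linking_params = torch.nn.Linear(num_linking_features, 1)

    features = torch.rand(num_entities, num_tokens, num_linking_features)
    linking_scores = linking_params(features).squeeze(-1)
    print(linking_scores.shape)                                     # torch.Size([4, 6])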