Example #1

    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 question_embedder: TextFieldEmbedder,
                 input_attention: Attention,
                 past_attention: Attention,
                 graph_attention: Attention,
                 max_decoding_steps: int,
                 action_embedding_dim: int,
                 enable_gating: bool = False,
                 ablation_mode: Optional[str] = None,
                 gnn: bool = True,
                 graph_loss_lambda: float = 0.5,
                 decoder_use_graph_entities: bool = True,
                 decoder_self_attend: bool = True,
                 gnn_timesteps: int = 2,
                 pruning_gnn_timesteps: int = 2,
                 parse_sql_on_decoding: bool = True,
                 add_action_bias: bool = False,
                 use_neighbor_similarity_for_linking: bool = True,
                 dataset_path: str = 'dataset',
                 log_path: str = '',
                 training_beam_size: Optional[int] = None,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels') -> None:
        super().__init__(vocab, encoder, entity_encoder, question_embedder, gnn_timesteps, dropout, rule_namespace)

        self.enable_gating = enable_gating
        self.ablation_mode = ablation_mode
        self._log_path = log_path
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias

        self._parse_sql_on_decoding = parse_sql_on_decoding
        self._self_attend = decoder_self_attend
        self._decoder_use_graph_entities = decoder_use_graph_entities
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking

        self._action_padding_index = -1  # the padding value used by IndexField

        self._exact_match = Average()
        self._sql_evaluator_match = Average()
        self._action_similarity = Average()
        self._beam_hit = Average()

        self._action_embedding_dim = action_embedding_dim

        self._graph_loss_lambda = graph_loss_lambda

        num_actions = vocab.get_vocab_size(self._rule_namespace)
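        # when enabled, one extra input dimension carries a learned per-action bias feature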
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        self._embedding_projector = torch.nn.Linear(question_embedder.get_output_dim(), self._embedding_dim, bias=False)
        self._bert_embedding_dim = question_embedder.get_output_dim()
        encoder_output_dim = self._encoder.get_output_dim() + self._embedding_dim

        self._neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))

        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder_output_dim))
        self._first_attended_output = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)
        torch.nn.init.normal_(self._first_attended_output)

        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)

        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
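        # train by maximizing the marginal likelihood of the gold action sequences,
        # optionally restricted to a beam of size training_beam_size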
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)

        self._graph_pruning = GraphPruning(3, self._embedding_dim, encoder.get_output_dim(), dropout,
                                           timesteps=pruning_gnn_timesteps)

        if decoder_self_attend:
            self._transition_function = AttendPastSchemaItemsTransitionFunction(encoder_output_dim=encoder_output_dim,
                                                                                action_embedding_dim=action_embedding_dim,
                                                                                input_attention=input_attention,
                                                                                past_attention=past_attention,
                                                                                enable_gating=self.enable_gating,
                                                                                ablation_mode=self.ablation_mode,
                                                                                predict_start_type_separately=False,
                                                                                add_action_bias=self._add_action_bias,
                                                                                dropout=dropout,
                                                                                num_layers=self._decoder_num_layers)
        else:
            self._transition_function = LinkingTransitionFunction(encoder_output_dim=encoder_output_dim,
                                                                  action_embedding_dim=action_embedding_dim,
                                                                  input_attention=input_attention,
                                                                  predict_start_type_separately=False,
                                                                  add_action_bias=self._add_action_bias,
                                                                  dropout=dropout,
                                                                  num_layers=self._decoder_num_layers)

        if self.enable_gating:
            self._graph_attention = graph_attention
        else:
            self._graph_attention = DotProductAttention()

        self._embedding_sim_attn = CosineMatrixAttention()

        # TODO: Remove hard-coded dirs
        self._evaluate_func = partial(evaluate,
                                      db_dir=os.path.join(dataset_path, 'database'),
                                      table=os.path.join(dataset_path, 'tables.json'),
                                      check_valid=False)
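
The `add_action_bias` branch above reserves one extra input dimension per action so a learned bias can travel with each input-side action embedding, while the output-side embedder used for scoring stays at the plain dimension. A minimal standalone sketch of that pattern (class and variable names here are illustrative, not from the original repository):

import torch
from torch import nn

class ActionEmbedderSketch(nn.Module):
    """Input-side action embeddings carry one extra 'bias' dimension;
    output-side embeddings, used for scoring, do not."""

    def __init__(self, num_actions: int, action_embedding_dim: int,
                 add_action_bias: bool = True) -> None:
        super().__init__()
        input_dim = action_embedding_dim + 1 if add_action_bias else action_embedding_dim
        self.input_embedder = nn.Embedding(num_actions, input_dim)
        self.output_embedder = nn.Embedding(num_actions, action_embedding_dim)

    def forward(self, action_ids: torch.Tensor):
        # Returns (input embedding with bias slot, output embedding without it).
        return self.input_embedder(action_ids), self.output_embedder(action_ids)
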
Example #2
    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 question_embedder: TextFieldEmbedder,
                 input_attention: Attention,
                 past_attention: Attention,
                 max_decoding_steps: int,
                 action_embedding_dim: int,
                 gnn: bool = True,
                 decoder_use_graph_entities: bool = True,
                 decoder_self_attend: bool = True,
                 gnn_timesteps: int = 2,
                 parse_sql_on_decoding: bool = True,
                 add_action_bias: bool = True,
                 use_neighbor_similarity_for_linking: bool = True,
                 dataset_path: str = 'dataset',
                 training_beam_size: Optional[int] = None,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels',
                 scoring_dev_params: Optional[dict] = None,
                 debug_parsing: bool = False) -> None:
        super().__init__(vocab)
        self.vocab = vocab
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # no-op fallback so self._dropout can be applied unconditionally
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._question_embedder = question_embedder
        self._add_action_bias = add_action_bias
        self._scoring_dev_params = scoring_dev_params or {}
        self.parse_sql_on_decoding = parse_sql_on_decoding
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        self._self_attend = decoder_self_attend
        self._decoder_use_graph_entities = decoder_use_graph_entities

        self._action_padding_index = -1  # the padding value used by IndexField

        self._exact_match = Average()
        self._sql_evaluator_match = Average()
        self._action_similarity = Average()
        self._acc_single = Average()
        self._acc_multi = Average()
        self._beam_hit = Average()

        self._action_embedding_dim = action_embedding_dim

        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        encoder_output_dim = encoder.get_output_dim()
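        # with the GNN enabled, the decoder attends over encoder states concatenated
        # with graph-derived entity vectors of size action_embedding_dim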
        if gnn:
            encoder_output_dim += action_embedding_dim

        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(encoder_output_dim))
        self._first_attended_output = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)
        torch.nn.init.normal_(self._first_attended_output)

        # hard-coded number of schema entity types backing the type embeddings below
        self._num_entity_types = 9
        self._embedding_dim = question_embedder.get_output_dim()

        self._entity_type_encoder_embedding = Embedding(
            self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        # project the 16-dimensional question-to-schema linking feature vector to a single score
        self._linking_params = torch.nn.Linear(16, 1)
        torch.nn.init.uniform_(self._linking_params.weight, 0, 1)

        num_edge_types = 3
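        # gated graph convolutions propagate information over the schema graph;
        # the three edge types typically encode column-table membership and
        # foreign-key relations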
        self._gnn = GatedGraphConv(self._embedding_dim,
                                   gnn_timesteps,
                                   num_edge_types=num_edge_types,
                                   dropout=dropout)

        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)

        if decoder_self_attend:
            self._transition_function = AttendPastSchemaItemsTransitionFunction(
                encoder_output_dim=encoder_output_dim,
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                past_attention=past_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)
        else:
            self._transition_function = LinkingTransitionFunction(
                encoder_output_dim=encoder_output_dim,
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)

        self._ent2ent_ff = FeedForward(action_embedding_dim, 1,
                                       action_embedding_dim,
                                       Activation.by_name('relu')())

        self._neighbor_params = torch.nn.Linear(self._embedding_dim,
                                                self._embedding_dim)

        # TODO: Remove hard-coded dirs
        self._evaluate_func = partial(
            evaluate,
            db_dir=os.path.join(dataset_path, 'database'),
            table=os.path.join(dataset_path, 'tables.json'),
            check_valid=False)

        self.debug_parsing = debug_parsing
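
Each of these constructors also learns the decoder's step-0 inputs directly: since there is no previous action or previously attended utterance at the first timestep, the `_first_*` vectors are trainable parameters initialized from a standard normal. A standalone sketch of just that piece (names are illustrative):

import torch

class InitialDecoderState(torch.nn.Module):
    """Learned stand-ins for the 'previous action embedding' and the
    'previously attended utterance' at the first decoding step."""

    def __init__(self, action_embedding_dim: int, encoder_output_dim: int) -> None:
        super().__init__()
        self.first_action_embedding = torch.nn.Parameter(torch.empty(action_embedding_dim))
        self.first_attended_utterance = torch.nn.Parameter(torch.empty(encoder_output_dim))
        torch.nn.init.normal_(self.first_action_embedding)
        torch.nn.init.normal_(self.first_attended_utterance)
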
Example #3
    def __init__(self,
                 question_embedder: TextFieldEmbedder,
                 input_memory_embedder: TextFieldEmbedder,
                 output_memory_embedder: TextFieldEmbedder,
                 question_encoder: Seq2SeqEncoder,
                 input_memory_encoder: Seq2VecEncoder,
                 output_memory_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 input_attention: Attention,
                 past_attention: Attention,
                 action_embedding_dim: int,
                 max_decoding_steps: int,
                 nhop: int,
                 decoding_nhop: int,
                 vocab: Vocabulary,
                 dataset_path: str = 'dataset',
                 parse_sql_on_decoding: bool = True,
                 training_beam_size: Optional[int] = None,
                 add_action_bias: bool = True,
                 decoder_self_attend: bool = True,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels') -> None:
        super().__init__(vocab)

        self.question_embedder = question_embedder
        self._input_mm_embedder = input_memory_embedder
        self._output_mm_embedder = output_memory_embedder
        self._question_encoder = question_encoder
        self._input_mm_encoder = TimeDistributed(input_memory_encoder)
        self._output_mm_encoder = TimeDistributed(output_memory_encoder)

        self.parse_sql_on_decoding = parse_sql_on_decoding
        self._self_attend = decoder_self_attend
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._rule_namespace = rule_namespace
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._input_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)
        self._output_action_embedder = Embedding(
            num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        self._num_entity_types = 9
        self._entity_type_decoder_input_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)
        self._entity_type_decoder_output_embedding = Embedding(
            self._num_entity_types, action_embedding_dim)

        self._entity_type_encoder_embedding = Embedding(
            self._num_entity_types,
            question_encoder.get_output_dim() // 2)

        self._decoder_num_layers = decoder_num_layers
        self._action_embedding_dim = action_embedding_dim

        self._ent2ent_ff = FeedForward(action_embedding_dim, 1,
                                       action_embedding_dim,
                                       Activation.by_name('relu')())

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        self._first_action_embedding = torch.nn.Parameter(
            torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(
            torch.FloatTensor(question_encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        if self._self_attend:
            self._transition_function = AttendPastSchemaItemsTransitionFunction(
                encoder_output_dim=question_encoder.get_output_dim(),
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                past_attention=past_attention,
                decoding_nhop=decoding_nhop,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)
        else:
            self._transition_function = LinkingTransitionFunction(
                encoder_output_dim=question_encoder.get_output_dim(),
                action_embedding_dim=action_embedding_dim,
                input_attention=input_attention,
                predict_start_type_separately=False,
                add_action_bias=self._add_action_bias,
                dropout=dropout,
                num_layers=self._decoder_num_layers)

        self._mm_attn = MemAttn(question_encoder.get_output_dim(), nhop)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)

        self._action_padding_index = -1  # the padding value used by IndexField

        self._exact_match = Average()
        self._sql_evaluator_match = Average()
        self._action_similarity = Average()
        self._acc_single = Average()
        self._acc_multi = Average()
        self._beam_hit = Average()

        # TODO: Remove hard-coded dirs
        self._evaluate_func = partial(
            evaluate,
            db_dir=os.path.join(dataset_path, 'database'),
            table=os.path.join(dataset_path, 'tables.json'),
            check_valid=False)
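
All three examples close by freezing the dataset-specific arguments of an external `evaluate` function with `functools.partial`, so the model can later call the evaluator with per-example arguments only. A sketch of the pattern; the `evaluate` signature shown is an assumption for illustration, not the real Spider evaluator's interface:

import os
from functools import partial

def evaluate(gold_sql, predicted_sql, db_id,
             db_dir=None, table=None, check_valid=True):
    """Hypothetical stand-in: the real evaluator compares the gold and
    predicted queries against the databases under db_dir, using the
    schema definitions in the tables file."""
    raise NotImplementedError

dataset_path = 'dataset'
evaluate_func = partial(evaluate,
                        db_dir=os.path.join(dataset_path, 'database'),
                        table=os.path.join(dataset_path, 'tables.json'),
                        check_valid=False)
# evaluate_func(gold, pred, db_id) now supplies only the per-example inputs.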