def __init__(self,
             vocab: Vocabulary,
             encoder: Seq2SeqEncoder,
             entity_encoder: Seq2VecEncoder,
             decoder_beam_search: BeamSearch,
             question_embedder: TextFieldEmbedder,
             input_attention: Attention,
             past_attention: Attention,
             graph_attention: Attention,
             max_decoding_steps: int,
             action_embedding_dim: int,
             enable_gating: bool = False,
             ablation_mode: str = None,
             gnn: bool = True,
             graph_loss_lambda: float = 0.5,
             decoder_use_graph_entities: bool = True,
             decoder_self_attend: bool = True,
             gnn_timesteps: int = 2,
             pruning_gnn_timesteps: int = 2,
             parse_sql_on_decoding: bool = True,
             add_action_bias: bool = False,
             use_neighbor_similarity_for_linking: bool = True,
             dataset_path: str = 'dataset',
             log_path: str = '',
             training_beam_size: int = None,
             decoder_num_layers: int = 1,
             dropout: float = 0.0,
             rule_namespace: str = 'rule_labels') -> None:
    """Build the gated/graph-pruning parser variant.

    Wires up the action embedders, decoder transition function, graph-pruning
    module, and evaluation metrics.  Encoder/embedder construction is largely
    delegated to the superclass ``__init__``.

    NOTE(review): ``ablation_mode`` and ``training_beam_size`` default to
    ``None`` but are annotated ``str``/``int`` — they are effectively
    ``Optional``; left as-is because a ``typing`` import is not visible here.
    """
    # Superclass is expected to set (at least) self._rule_namespace,
    # self._embedding_dim, self._encoder and self._num_entity_types,
    # all of which are read below — TODO confirm against the base class.
    super().__init__(vocab, encoder, entity_encoder, question_embedder, gnn_timesteps, dropout, rule_namespace)
    self.enable_gating = enable_gating
    self.ablation_mode = ablation_mode
    self._log_path = log_path
    self._max_decoding_steps = max_decoding_steps
    self._add_action_bias = add_action_bias
    self._parse_sql_on_decoding = parse_sql_on_decoding
    self._self_attend = decoder_self_attend
    self._decoder_use_graph_entities = decoder_use_graph_entities
    self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
    self._action_padding_index = -1  # the padding value used by IndexField
    # Running averages reported as metrics during training/validation.
    self._exact_match = Average()
    self._sql_evaluator_match = Average()
    self._action_similarity = Average()
    self._beam_hit = Average()
    self._action_embedding_dim = action_embedding_dim
    self._graph_loss_lambda = graph_loss_lambda
    num_actions = vocab.get_vocab_size(self._rule_namespace)
    # With an action bias, one extra dimension is appended to each input
    # action embedding.
    if self._add_action_bias:
        input_action_dim = action_embedding_dim + 1
    else:
        input_action_dim = action_embedding_dim
    self._action_embedder = Embedding(num_embeddings=num_actions,
                                      embedding_dim=input_action_dim)
    self._output_action_embedder = Embedding(num_embeddings=num_actions,
                                             embedding_dim=action_embedding_dim)
    # Projects (BERT-sized) question-embedder output down to the model's
    # internal embedding size.
    self._embedding_projector = torch.nn.Linear(question_embedder.get_output_dim(), self._embedding_dim, bias=False)
    self._bert_embedding_dim = question_embedder.get_output_dim()
    # Decoder attends over encoder states concatenated with entity embeddings.
    encoder_output_dim = self._encoder.get_output_dim() + self._embedding_dim
    self._neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
    # Learned initial decoder inputs for the first decoding step.
    self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
    self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder_output_dim))
    self._first_attended_output = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
    torch.nn.init.normal_(self._first_action_embedding)
    torch.nn.init.normal_(self._first_attended_utterance)
    torch.nn.init.normal_(self._first_attended_output)
    self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)
    self._decoder_num_layers = decoder_num_layers
    self._beam_search = decoder_beam_search
    self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
    self._graph_pruning = GraphPruning(3, self._embedding_dim, encoder.get_output_dim(), dropout,
                                       timesteps=pruning_gnn_timesteps)
    # Transition function: either self-attending over past schema items or
    # the plain linking variant.
    if decoder_self_attend:
        self._transition_function = AttendPastSchemaItemsTransitionFunction(
            encoder_output_dim=encoder_output_dim,
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            past_attention=past_attention,
            enable_gating=self.enable_gating,
            ablation_mode=self.ablation_mode,
            predict_start_type_separately=False,
            add_action_bias=self._add_action_bias,
            dropout=dropout,
            num_layers=self._decoder_num_layers)
    else:
        self._transition_function = LinkingTransitionFunction(
            encoder_output_dim=encoder_output_dim,
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            predict_start_type_separately=False,
            add_action_bias=self._add_action_bias,
            dropout=dropout,
            num_layers=self._decoder_num_layers)
    # Without gating, fall back to a parameter-free dot-product attention.
    if self.enable_gating:
        self._graph_attention = graph_attention
    else:
        self._graph_attention = DotProductAttention()
    self._embedding_sim_attn = CosineMatrixAttention()
    # TODO: Remove hard-coded dirs
    self._evaluate_func = partial(evaluate,
                                  db_dir=os.path.join(dataset_path, 'database'),
                                  table=os.path.join(dataset_path, 'tables.json'),
                                  check_valid=False)
def __init__(self,
             vocab: Vocabulary,
             encoder: Seq2SeqEncoder,
             entity_encoder: Seq2VecEncoder,
             decoder_beam_search: BeamSearch,
             question_embedder: TextFieldEmbedder,
             input_attention: Attention,
             past_attention: Attention,
             max_decoding_steps: int,
             action_embedding_dim: int,
             gnn: bool = True,
             decoder_use_graph_entities: bool = True,
             decoder_self_attend: bool = True,
             gnn_timesteps: int = 2,
             parse_sql_on_decoding: bool = True,
             add_action_bias: bool = True,
             use_neighbor_similarity_for_linking: bool = True,
             dataset_path: str = 'dataset',
             training_beam_size: int = None,
             decoder_num_layers: int = 1,
             dropout: float = 0.0,
             rule_namespace: str = 'rule_labels',
             scoring_dev_params: dict = None,
             debug_parsing: bool = False) -> None:
    """Build the GNN-based parser variant.

    Unlike the sibling constructor above, this one initializes everything
    itself (only ``super().__init__(vocab)`` is delegated): question/entity
    encoders, schema-linking parameters, the gated graph convolution, the
    decoder transition function, and evaluation metrics.

    NOTE(review): ``training_beam_size`` and ``scoring_dev_params`` default
    to ``None`` but are annotated ``int``/``dict`` — effectively
    ``Optional``; left as-is because a ``typing`` import is not visible here.
    """
    super().__init__(vocab)
    self.vocab = vocab
    self._encoder = encoder
    self._max_decoding_steps = max_decoding_steps
    # With dropout disabled, use an identity callable so call sites never
    # need to branch.
    if dropout > 0:
        self._dropout = torch.nn.Dropout(p=dropout)
    else:
        self._dropout = lambda x: x
    self._rule_namespace = rule_namespace
    self._question_embedder = question_embedder
    self._add_action_bias = add_action_bias
    self._scoring_dev_params = scoring_dev_params or {}
    self.parse_sql_on_decoding = parse_sql_on_decoding
    self._entity_encoder = TimeDistributed(entity_encoder)
    self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
    self._self_attend = decoder_self_attend
    self._decoder_use_graph_entities = decoder_use_graph_entities
    self._action_padding_index = -1  # the padding value used by IndexField
    # Running averages reported as metrics during training/validation.
    self._exact_match = Average()
    self._sql_evaluator_match = Average()
    self._action_similarity = Average()
    self._acc_single = Average()
    self._acc_multi = Average()
    self._beam_hit = Average()
    self._action_embedding_dim = action_embedding_dim
    num_actions = vocab.get_vocab_size(self._rule_namespace)
    # With an action bias, one extra dimension is appended to each input
    # action embedding.
    if self._add_action_bias:
        input_action_dim = action_embedding_dim + 1
    else:
        input_action_dim = action_embedding_dim
    self._action_embedder = Embedding(num_embeddings=num_actions,
                                      embedding_dim=input_action_dim)
    self._output_action_embedder = Embedding(
        num_embeddings=num_actions, embedding_dim=action_embedding_dim)
    # When the GNN is enabled, entity embeddings are concatenated onto the
    # encoder states the decoder attends over.
    encoder_output_dim = encoder.get_output_dim()
    if gnn:
        encoder_output_dim += action_embedding_dim
    # Learned initial decoder inputs for the first decoding step.
    self._first_action_embedding = torch.nn.Parameter(
        torch.FloatTensor(action_embedding_dim))
    self._first_attended_utterance = torch.nn.Parameter(
        torch.FloatTensor(encoder_output_dim))
    self._first_attended_output = torch.nn.Parameter(
        torch.FloatTensor(action_embedding_dim))
    torch.nn.init.normal_(self._first_action_embedding)
    torch.nn.init.normal_(self._first_attended_utterance)
    torch.nn.init.normal_(self._first_attended_output)
    self._num_entity_types = 9  # fixed number of schema entity-type labels — TODO confirm
    self._embedding_dim = question_embedder.get_output_dim()
    self._entity_type_encoder_embedding = Embedding(
        self._num_entity_types, self._embedding_dim)
    self._entity_type_decoder_embedding = Embedding(
        self._num_entity_types, action_embedding_dim)
    # 16 schema-linking features scored down to a single logit — TODO confirm
    # the feature count against the feature extractor.
    self._linking_params = torch.nn.Linear(16, 1)
    torch.nn.init.uniform_(self._linking_params.weight, 0, 1)
    num_edge_types = 3
    self._gnn = GatedGraphConv(self._embedding_dim, gnn_timesteps,
                               num_edge_types=num_edge_types, dropout=dropout)
    self._decoder_num_layers = decoder_num_layers
    self._beam_search = decoder_beam_search
    self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
    # Transition function: either self-attending over past schema items or
    # the plain linking variant.
    if decoder_self_attend:
        self._transition_function = AttendPastSchemaItemsTransitionFunction(
            encoder_output_dim=encoder_output_dim,
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            past_attention=past_attention,
            predict_start_type_separately=False,
            add_action_bias=self._add_action_bias,
            dropout=dropout,
            num_layers=self._decoder_num_layers)
    else:
        self._transition_function = LinkingTransitionFunction(
            encoder_output_dim=encoder_output_dim,
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            predict_start_type_separately=False,
            add_action_bias=self._add_action_bias,
            dropout=dropout,
            num_layers=self._decoder_num_layers)
    self._ent2ent_ff = FeedForward(action_embedding_dim, 1, action_embedding_dim,
                                   Activation.by_name('relu')())
    self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)
    # TODO: Remove hard-coded dirs
    self._evaluate_func = partial(
        evaluate,
        db_dir=os.path.join(dataset_path, 'database'),
        table=os.path.join(dataset_path, 'tables.json'),
        check_valid=False)
    self.debug_parsing = debug_parsing
def __init__(self,
             question_embedder: TextFieldEmbedder,
             input_memory_embedder: TextFieldEmbedder,
             output_memory_embedder: TextFieldEmbedder,
             question_encoder: Seq2SeqEncoder,
             input_memory_encoder: Seq2VecEncoder,
             output_memory_encoder: Seq2VecEncoder,
             decoder_beam_search: BeamSearch,
             input_attention: Attention,
             past_attention: Attention,
             action_embedding_dim: int,
             max_decoding_steps: int,
             nhop: int,
             decoding_nhop: int,
             vocab: Vocabulary,
             dataset_path: str = 'dataset',
             parse_sql_on_decoding: bool = True,
             training_beam_size: int = None,
             add_action_bias: bool = True,
             decoder_self_attend: bool = True,
             decoder_num_layers: int = 1,
             dropout: float = 0.0,
             rule_namespace: str = 'rule_labels') -> None:
    """Build the memory-network parser variant.

    Wires up separate input/output memory embedders and encoders, action
    embedders (input/output use the plain action dim, the main embedder may
    carry a bias dimension), the multi-hop memory attention, the decoder
    transition function, and evaluation metrics.

    NOTE(review): ``training_beam_size`` defaults to ``None`` but is
    annotated ``int`` — effectively ``Optional``; left as-is because a
    ``typing`` import is not visible here.
    """
    super().__init__(vocab)
    self.question_embedder = question_embedder
    self._input_mm_embedder = input_memory_embedder
    self._output_mm_embedder = output_memory_embedder
    self._question_encoder = question_encoder
    self._input_mm_encoder = TimeDistributed(input_memory_encoder)
    self._output_mm_encoder = TimeDistributed(output_memory_encoder)
    self.parse_sql_on_decoding = parse_sql_on_decoding
    self._self_attend = decoder_self_attend
    self._max_decoding_steps = max_decoding_steps
    self._add_action_bias = add_action_bias
    self._rule_namespace = rule_namespace
    num_actions = vocab.get_vocab_size(self._rule_namespace)
    # With an action bias, one extra dimension is appended to each input
    # action embedding.
    if self._add_action_bias:
        input_action_dim = action_embedding_dim + 1
    else:
        input_action_dim = action_embedding_dim
    self._action_embedder = Embedding(num_embeddings=num_actions,
                                      embedding_dim=input_action_dim)
    self._input_action_embedder = Embedding(
        num_embeddings=num_actions, embedding_dim=action_embedding_dim)
    self._output_action_embedder = Embedding(
        num_embeddings=num_actions, embedding_dim=action_embedding_dim)
    self._num_entity_types = 9  # fixed number of schema entity-type labels — TODO confirm
    self._entity_type_decoder_input_embedding = Embedding(
        self._num_entity_types, action_embedding_dim)
    self._entity_type_decoder_output_embedding = Embedding(
        self._num_entity_types, action_embedding_dim)
    # Encoder-side type embedding uses half the (bidirectional) encoder
    # output size.  Was `(int)(... / 2)` — a C-style cast via float
    # division; floor division is the idiomatic, exact-integer equivalent.
    self._entity_type_encoder_embedding = Embedding(
        self._num_entity_types, question_encoder.get_output_dim() // 2)
    self._decoder_num_layers = decoder_num_layers
    self._action_embedding_dim = action_embedding_dim
    self._ent2ent_ff = FeedForward(action_embedding_dim, 1, action_embedding_dim,
                                   Activation.by_name('relu')())
    # With dropout disabled, use an identity callable so call sites never
    # need to branch.
    if dropout > 0:
        self._dropout = torch.nn.Dropout(p=dropout)
    else:
        self._dropout = lambda x: x
    # Learned initial decoder inputs for the first decoding step.
    self._first_action_embedding = torch.nn.Parameter(
        torch.FloatTensor(action_embedding_dim))
    self._first_attended_utterance = torch.nn.Parameter(
        torch.FloatTensor(question_encoder.get_output_dim()))
    torch.nn.init.normal_(self._first_action_embedding)
    torch.nn.init.normal_(self._first_attended_utterance)
    # Transition function: either self-attending over past schema items or
    # the plain linking variant.
    if self._self_attend:
        self._transition_function = AttendPastSchemaItemsTransitionFunction(
            encoder_output_dim=question_encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            past_attention=past_attention,
            decoding_nhop=decoding_nhop,
            predict_start_type_separately=False,
            add_action_bias=self._add_action_bias,
            dropout=dropout,
            num_layers=self._decoder_num_layers)
    else:
        self._transition_function = LinkingTransitionFunction(
            encoder_output_dim=question_encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            predict_start_type_separately=False,
            add_action_bias=self._add_action_bias,
            dropout=dropout,
            num_layers=self._decoder_num_layers)
    # Multi-hop memory attention over the question encoding.
    self._mm_attn = MemAttn(question_encoder.get_output_dim(), nhop)
    self._beam_search = decoder_beam_search
    self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
    self._action_padding_index = -1  # the padding value used by IndexField
    # Running averages reported as metrics during training/validation.
    self._exact_match = Average()
    self._sql_evaluator_match = Average()
    self._action_similarity = Average()
    self._acc_single = Average()
    self._acc_multi = Average()
    self._beam_hit = Average()
    # TODO: Remove hard-coded dirs
    self._evaluate_func = partial(
        evaluate,
        db_dir=os.path.join(dataset_path, 'database'),
        table=os.path.join(dataset_path, 'tables.json'),
        check_valid=False)