class STSBTask(Task):
    '''Task class for the Sentence Textual Similarity Benchmark.'''

    def __init__(self, path, max_seq_len, name="sts_benchmark"):
        '''Initialize the task and load its data.'''
        super(STSBTask, self).__init__(name, 1)
        self.categorical = 0
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.scorer1 = Average()
        self.scorer2 = Average()
        self.load_data(path, max_seq_len)

    def load_data(self, path, max_seq_len):
        '''Load the train, dev, and test splits.'''
        tr_data = load_tsv(os.path.join(path, 'train.tsv'), max_seq_len, skip_rows=1,
                           s1_idx=7, s2_idx=8, targ_idx=9,
                           targ_fn=lambda x: float(x) / 5)
        val_data = load_tsv(os.path.join(path, 'dev.tsv'), max_seq_len, skip_rows=1,
                            s1_idx=7, s2_idx=8, targ_idx=9,
                            targ_fn=lambda x: float(x) / 5)
        te_data = load_tsv(os.path.join(path, 'test.tsv'), max_seq_len,
                           s1_idx=7, s2_idx=8, targ_idx=None, idx_idx=0, skip_rows=1)
        self.train_data_text = tr_data
        self.val_data_text = val_data
        self.test_data_text = te_data
        log.info("\tFinished loading STS Benchmark data.")

    def get_metrics(self, reset=False):
        # NB: scorer1 is reported as 'accuracy' so that it matches the
        # "%s_accuracy" validation metric name set in __init__.
        return {'accuracy': self.scorer1.get_metric(reset),
                'spearmanr': self.scorer2.get_metric(reset)}
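# A minimal, standalone sketch (not part of the task code above) of the metric
# pattern STSBTask relies on: each scorer is an allennlp `Average`, so a
# per-batch correlation value is pushed in with __call__ and read back with
# get_metric().  The choice of Pearson for scorer1 and the batch values below
# are invented for illustration; only the 'spearmanr' key confirms scorer2.
from allennlp.training.metrics import Average
from scipy.stats import pearsonr, spearmanr

scorer1, scorer2 = Average(), Average()
for preds, golds in [([0.1, 0.8, 0.4], [0.0, 1.0, 0.6]),
                     ([0.9, 0.2, 0.5], [1.0, 0.4, 0.5])]:
    scorer1(pearsonr(preds, golds)[0])    # accumulate per-batch Pearson r
    scorer2(spearmanr(preds, golds)[0])   # accumulate per-batch Spearman rho
print({'accuracy': scorer1.get_metric(reset=True),
       'spearmanr': scorer2.get_metric(reset=True)})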
class CoLATask(Task):
    '''Task class for the Warstadt acceptability task (CoLA).'''

    def __init__(self, path, max_seq_len, name="acceptability"):
        '''Initialize the task and load its data.'''
        super(CoLATask, self).__init__(name, 2)
        self.pair_input = 0
        self.load_data(path, max_seq_len)
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.scorer1 = Average()
        self.scorer2 = CategoricalAccuracy()

    def load_data(self, path, max_seq_len):
        '''Load the data.'''
        tr_data = load_tsv(os.path.join(path, "train.tsv"), max_seq_len,
                           s1_idx=3, s2_idx=None, targ_idx=1)
        val_data = load_tsv(os.path.join(path, "dev.tsv"), max_seq_len,
                            s1_idx=3, s2_idx=None, targ_idx=1)
        te_data = load_tsv(os.path.join(path, 'test.tsv'), max_seq_len,
                           s1_idx=1, s2_idx=None, targ_idx=None, idx_idx=0, skip_rows=1)
        self.train_data_text = tr_data
        self.val_data_text = val_data
        self.test_data_text = te_data
        log.info("\tFinished loading CoLA.")

    def get_metrics(self, reset=False):
        # NB: scorer1 is reported as 'accuracy' so that it matches the
        # "%s_accuracy" validation metric name set in __init__.
        return {'accuracy': self.scorer1.get_metric(reset),
                'acc': self.scorer2.get_metric(reset)}
class Seq2SeqClassifier(Model):
    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary,
                 hidden_dimension: int,
                 bs: int) -> None:
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.bs = bs
        self.hidden_dim = hidden_dimension
        self.vocab = vocab
        self.tasks_vocabulary = {"default": vocab}
        self.current_task = "default"
        self.num_task = 0
        self.classification_layers = torch.nn.ModuleList(
            [torch.nn.Linear(in_features=self.hidden_dim,
                             out_features=self.vocab.get_vocab_size('labels'))])
        self.task2id = {"default": 0}
        self.hidden2tag = self.classification_layers[self.task2id["default"]]
        self.accuracy = CategoricalAccuracy()
        self.loss_function = torch.nn.CrossEntropyLoss()
        self.average = Average()
        self.activations = []
        self.labels = []

    def add_task(self, task_tag: str, vocab: Vocabulary):
        self.classification_layers.append(
            torch.nn.Linear(in_features=self.hidden_dim,
                            out_features=vocab.get_vocab_size('labels')))
        self.num_task = self.num_task + 1
        self.task2id[task_tag] = self.num_task
        self.tasks_vocabulary[task_tag] = vocab

    def set_task(self, task_tag: str):
        self.hidden2tag = self.classification_layers[self.task2id[task_tag]]
        self.current_task = task_tag
        self.vocab = self.tasks_vocabulary[task_tag]

    def forward(self,
                tokens: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        encoder_out = self.encoder(embeddings, mask)
        # Max-pool over the time dimension before classifying.
        tag_logits = self.hidden2tag(
            torch.nn.functional.adaptive_max_pool1d(
                encoder_out.permute(0, 2, 1), (1,)).view(-1, self.hidden_dim))
        output = {'logits': tag_logits}
        self.activations = encoder_out
        self.labels = label
        if label is not None:
            _, preds = tag_logits.max(dim=1)
            self.average(matthews_corrcoef(label.data.cpu().numpy(),
                                           preds.data.cpu().numpy()))
            self.accuracy(tag_logits, label)
            output["loss"] = self.loss_function(tag_logits, label)
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset),
                "average": self.average.get_metric(reset)}

    def get_activations(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.activations, self.labels
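# Hedged usage sketch for the multi-task plumbing above (add_task / set_task).
# Everything here is illustrative: the vocabularies, dimensions, and task tags
# are invented, and the embedder/encoder are just the simplest AllenNLP
# components that satisfy the constructor.
import torch
from allennlp.data import Vocabulary
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

vocab_a = Vocabulary()
vocab_a.add_tokens_to_namespace(["0", "1"], namespace="labels")
vocab_b = Vocabulary()
vocab_b.add_tokens_to_namespace(["ENTY", "LOC", "NUM"], namespace="labels")

embedder = BasicTextFieldEmbedder(
    {"tokens": Embedding(num_embeddings=100, embedding_dim=16)})
encoder = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(16, 32, batch_first=True, bidirectional=True))

model = Seq2SeqClassifier(embedder, encoder, vocab_a,
                          hidden_dimension=64, bs=8)  # 64 = 2 * 32 (bidirectional)
model.add_task("trec", vocab_b)   # adds a second classification head
model.set_task("trec")            # routes hidden2tag to that head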
def perplexity(lm_path, sample_file):
    import kenlm
    model = kenlm.LanguageModel(lm_path)
    ppl = Average()
    num_lines = file_utils.get_num_lines(sample_file)
    with open(sample_file) as sf:
        with tqdm(sf, total=num_lines, desc='Computing PPL') as pbar:
            for sentence in pbar:
                ppl(model.perplexity(sentence))
                pbar.set_postfix({'PPL': ppl.get_metric()})
    logger.info(f'PPL for file {sample_file} = {ppl.get_metric()}')
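# Hypothetical invocation of the helper above; both paths are placeholders.
# 'model.arpa' would be a trained KenLM language model and 'samples.txt' a
# text file with one sentence per line.
perplexity('model.arpa', 'samples.txt')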
class HatefulMemeModel(Model):
    def __init__(self, vocab: Vocabulary, text_model_name: str):
        super().__init__(vocab)
        self._text_model = BertForSequenceClassification.from_pretrained(text_model_name)
        self._num_labels = vocab.get_vocab_size()
        self._accuracy = Average()
        self._auc = Auc()
        self._softmax = torch.nn.Softmax(dim=1)

    def forward(
        self,
        source_tokens: TextFieldTensors,
        box_features: Optional[Tensor] = None,
        box_coordinates: Optional[Tensor] = None,
        box_mask: Optional[Tensor] = None,
        label: Optional[Tensor] = None,
        metadata: Optional[Dict] = None,
    ) -> Dict[str, torch.Tensor]:
        input_ids = source_tokens["tokens"]["token_ids"]
        input_mask = source_tokens["tokens"]["mask"]
        token_type_ids = source_tokens["tokens"]["type_ids"]
        outputs = self._text_model(
            input_ids=input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
            labels=label,
        )
        if label is not None:
            predictions = torch.argmax(self._softmax(outputs.logits), dim=-1)
            # Accuracy is tracked with an `Average` over per-example 0/1 correctness.
            for index in range(predictions.shape[0]):
                correct = float(predictions[index] == label[index])
                self._accuracy(int(correct))
            # NOTE: AUC is usually computed on scores (e.g. the positive-class
            # probability) rather than argmax'd class predictions, as is done here.
            self._auc(predictions, label)
        return outputs

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics["accuracy"] = self._accuracy.get_metric(reset=reset)
            metrics["auc"] = self._auc.get_metric(reset=reset)
        return metrics
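# Standalone sketch of the two metrics used above.  Note that allennlp's `Auc`
# is normally fed real-valued scores (e.g. the positive-class probability)
# rather than hard predictions; the tensors below are invented.
import torch
from allennlp.training.metrics import Average, Auc

accuracy, auc = Average(), Auc()
probs = torch.tensor([0.9, 0.2, 0.7, 0.4])   # P(label == 1)
labels = torch.tensor([1, 0, 1, 1])
preds = (probs > 0.5).long()
for p, g in zip(preds, labels):
    accuracy(int(p == g))          # same per-example 0/1 trick as forward()
auc(probs, labels)                 # scores, not argmax'd classes
print(accuracy.get_metric(reset=True), auc.get_metric(reset=True))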
class NamesClassifier(Model):
    """
    NamesClassifier takes in a name and a label.  It performs forward passes
    and updates metrics such as loss and accuracy.
    """

    def __init__(self,
                 char_embeddings: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)
        # Initialize the embedding vector.
        self.char_embeddings = char_embeddings
        # Initialize the encoder.
        self.encoder = encoder
        # Initialize the hidden-to-tag layer; it outputs a score for each label.
        self.hidden2tag = torch.nn.Linear(
            in_features=encoder.get_output_dim(),
            out_features=vocab.get_vocab_size('labels'))
        # Initialize the average metric.
        self.accuracy = Average()
        # LogSoftmax is faster and has better numerical properties than Softmax.
        self.m = LogSoftmax(dim=-1)
        # The negative log likelihood loss, useful for training a
        # classification problem with `C` classes.
        self.loss = NLLLoss()

    @overrides
    def forward(self,
                name: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        # Create a mask so that specific (padding) indices are ignored.
        mask = get_text_field_mask(name)
        # Create embeddings for the given name.
        embeddings = self.char_embeddings(name)
        # Encode the embeddings with the mask.
        encoder_out = self.encoder(embeddings, mask)
        # Calculate the logit scores.
        tag_logits = self.hidden2tag(encoder_out)
        # Update the metrics and return the output.
        output = {"tag_logits": tag_logits}
        if label is not None:
            output["loss"] = self.loss(self.m(tag_logits), label)
            prediction = tag_logits.max(1)[1]
            self.accuracy(prediction.eq(label).double().mean())
        return output

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        # Simply return accuracy after each pass.
        return {"accuracy": float(self.accuracy.get_metric(reset))}
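# Hedged construction sketch for the classifier above; the label set,
# vocabulary size, and dimensions are all invented.
import torch
from allennlp.data import Vocabulary
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper

names_vocab = Vocabulary()
names_vocab.add_tokens_to_namespace(["English", "French", "Japanese"],
                                    namespace="labels")
char_embeddings = BasicTextFieldEmbedder(
    {"tokens": Embedding(num_embeddings=60, embedding_dim=8)})
names_encoder = PytorchSeq2VecWrapper(torch.nn.LSTM(8, 32, batch_first=True))
names_model = NamesClassifier(char_embeddings, names_encoder, names_vocab)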
class MajorityClassifier(Model):
    def __init__(self, vocab: Vocabulary) -> None:
        super().__init__(vocab)
        self.vocab = vocab
        self.current_task = "default"
        self.tasks_vocabulary = {"default": vocab}
        # Dummy layer: the majority classifier learns nothing, but this gives
        # the model at least one parameter.
        self.classification_layers = torch.nn.ModuleList(
            [torch.nn.Linear(in_features=10,
                             out_features=self.vocab.get_vocab_size('labels'))])
        self.loss_function = torch.nn.CrossEntropyLoss()
        self.accuracy = CategoricalAccuracy()
        self.average = Average()

    def add_task(self, task_tag: str, vocab: Vocabulary):
        self.tasks_vocabulary[task_tag] = vocab

    def set_task(self, task_tag: str):
        self.current_task = task_tag
        self.vocab = self.tasks_vocabulary[task_tag]

    def forward(self,
                tokens: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        # Build one-hot "logits" that always select the majority label for the
        # current task, repeated for every example in the batch.
        logi = np.zeros(2)
        if self.current_task == "cola":
            logi = np.zeros(2)
            logi[self.vocab.get_token_index("1", "labels")] = 1
        elif self.current_task == "trec":
            logi = np.zeros(6)
            logi[self.vocab.get_token_index("ENTY", "labels")] = 1
        elif self.current_task == "sst":
            logi = np.zeros(5)
            logi[self.vocab.get_token_index("3", "labels")] = 1
        elif self.current_task == "subjectivity":
            logi = np.zeros(2)
            logi[self.vocab.get_token_index("SUBJECTIVE", "labels")] = 1
        # NB: the original repeated per-branch and passed label.size() (a
        # torch.Size) instead of label.size(0) in three of the branches.
        logi = torch.Tensor(np.repeat([logi], label.size(0), 0)).to(label.device)
        output = {}
        if label is not None:
            _, preds = logi.max(dim=1)
            self.average(matthews_corrcoef(label.data.cpu().numpy(),
                                           preds.data.cpu().numpy()))
            self.accuracy(logi, label)
            output["loss"] = torch.tensor(0.0)
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset),
                "average": self.average.get_metric(reset)}
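# Illustration of the MCC bookkeeping in forward(): sklearn's
# matthews_corrcoef on a batch of majority-class predictions, accumulated with
# an allennlp Average.  Data is made up.
import numpy as np
from sklearn.metrics import matthews_corrcoef
from allennlp.training.metrics import Average

average = Average()
labels = np.array([1, 0, 1, 1, 0])
preds = np.ones_like(labels)          # "always predict 1", as in the cola branch
average(matthews_corrcoef(labels, preds))
print(average.get_metric())           # 0.0: constant predictions carry no signal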
def perplexity(lm_path, pickle_path):
    import kenlm
    model = kenlm.LanguageModel(lm_path)
    ppl = Average()
    with open(pickle_path, 'rb') as sf:
        dialog_dict = pickle.load(sf)
    num_responses = len(dialog_dict) - 2
    responses_list = []
    for rno in range(num_responses):
        responses = list(dialog_dict[f'response_{rno + 1}'])
        responses_list.append(responses)
    flattened_responses = [response
                           for responses in responses_list
                           for response in responses]
    with tqdm(flattened_responses, desc='Computing PPL') as pbar:
        for sentence in pbar:
            ppl(model.perplexity(sentence))
            pbar.set_postfix({'PPL': ppl.get_metric()})
    logger.info(f'PPL for file {pickle_path} = {ppl.get_metric()}')
class Decoder(torch.nn.Module, Registrable):
    """``Decoder`` class is a wrapper for different decoders."""

    def __init__(self, vocab: Vocabulary):
        super().__init__()  # type: ignore
        self.vocab = vocab
        self._nll = Average()
        self._ppl = WordPPL()
        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        self._pad_index = self.vocab.get_token_index(
            self.vocab._padding_token)  # noqa: WPS437
        self._bleu = BLEU(exclude_indices={
            self._pad_index, self._end_index, self._start_index,
        })

    def forward(
        self,
        encoder_outs: Dict[str, Any],
        target_tokens: Dict[str, torch.LongTensor],
    ) -> Dict[str, torch.Tensor]:
        """Run the module's forward function."""
        raise NotImplementedError

    def post_process(
        self,
        output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Post process after decoding."""
        raise NotImplementedError

    def get_metrics(self, reset: bool = False):
        """Collect all available metrics."""
        all_metrics: Dict[str, float] = {}
        if not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
            all_metrics.update({'nll': float(self._nll.get_metric(reset=reset))})
            all_metrics.update({'_ppl': float(self._ppl.get_metric(reset=reset))})
        return all_metrics
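# Sketch of plugging a concrete decoder into the Registrable wrapper above.
# The registered name and the do-nothing forward are invented; a real decoder
# would produce logits and update self._nll / self._ppl / self._bleu.
import torch

@Decoder.register("null_decoder")
class NullDecoder(Decoder):
    def forward(self, encoder_outs, target_tokens):
        loss = torch.tensor(0.0, requires_grad=True)
        self._nll(loss.item())  # keep the nll metric populated
        return {"loss": loss}

# `Decoder.by_name("null_decoder")` now resolves to NullDecoder.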
class MyModel(Model):
    def __init__(
        self,
        text_field_embedder: TextFieldEmbedder,
        vocab: Vocabulary,
        sym_size: int,  # added: `sym_size` was used below but never defined
        seq2vec_encoder: Seq2VecEncoder = None,
        dropout: float = None,
        regularizer: RegularizerApplicator = None,
    ):
        super().__init__(vocab, regularizer)
        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = None
        self.sym_size = sym_size
        self.embeddings = text_field_embedder
        self.vec_encoder = seq2vec_encoder
        self.hidden_dim = self.vec_encoder.get_output_dim()
        self.linear_class = torch.nn.Linear(self.hidden_dim, self.sym_size)
        self.dim = [12, 62, 4, 40, 62]
        self.true_list = [Average() for i in range(5)]
        self.pre_total = [Average() for i in range(5)]
        self.pre_true = [Average() for i in range(5)]
        self.total_pre = Average()
        self.total_true = Average()
        self.total_pre_true = Average()
        self.total_future_true = Average()
        self.macro_f = MacroF(self.sym_size)
        self.turn_acc = Average()
        self.future_acc = Average()

    def forward(self, text, label, **args):
        bs = label.size(0)
        embeddings = self.embeddings(text)  # bs * seq_len * embedding
        mask = get_text_field_mask(text)  # bs * sen_num * sen_len
        seq_hidden = self.vec_encoder(embeddings, mask)  # bs, embedding
        if self._dropout is not None:  # the dropout module was previously unused
            seq_hidden = self._dropout(seq_hidden)
        # Shape: (batch_size, num_labels)
        topic_probs = torch.sigmoid(self.linear_class(seq_hidden))
        # Up-weight positive labels in the loss.
        topic_weight = torch.ones_like(label) + label * 4
        loss = F.binary_cross_entropy(topic_probs, label.float(), topic_weight.float())
        output_dict = {
            'loss': loss,
            'probs': topic_probs,
            'last_hidden': seq_hidden,
        }
        total_pre_list = []
        total_true_list = []
        total_pre_true_list = []
        pre_index = (topic_probs > 0.5).long()
        total_pre = torch.sum(pre_index)
        total_true = torch.sum(label)
        mask_index = (label == 1).long()
        self.macro_f(pre_index.cpu(), label.cpu())
        true_positive = (pre_index == label).long() * mask_index
        st = 0
        for i in range(5):
            total_pre_list.append(torch.sum(pre_index[:, st:st + self.dim[i]]))
            total_true_list.append(torch.sum(label[:, st:st + self.dim[i]]))
            total_pre_true_list.append(
                torch.sum(true_positive[:, st:st + self.dim[i]]))
            st += self.dim[i]
        turn_true_num = (torch.sum(true_positive, 1) == torch.sum(mask_index, 1)).long()
        self.turn_acc(torch.sum(turn_true_num).item() / bs)
        pre_true = torch.sum(true_positive)
        self.total_pre(total_pre.float().item())
        self.total_true(total_true.float().item())
        self.total_pre_true(pre_true.float().item())
        self.total_future_true(
            torch.sum((pre_index == args['future']).long()
                      * (args['future'] == 1).long()).item())
        for i in range(5):
            self.pre_total[i](total_pre_list[i].float().item())
            self.pre_true[i](total_pre_true_list[i].float().item())
            self.true_list[i](total_true_list[i].float().item())
        return output_dict

    def get_metrics(self, reset=False):
        metrics = {}
        total_pre = self.total_pre.get_metric(reset=reset)
        total_pre_true = self.total_pre_true.get_metric(reset=reset)
        total_true = self.total_true.get_metric(reset=reset)
        total_future_true = self.total_future_true.get_metric(reset=reset)
        for i in range(5):
            pre_i = self.pre_total[i].get_metric(reset=reset)
            pre_true_i = self.pre_true[i].get_metric(reset=reset)
            true_i = self.true_list[i].get_metric(reset=reset)
            acc_i, rec_i, f_i = 0., 0., 0.
            if pre_i > 0:
                acc_i = pre_true_i / pre_i
            if true_i > 0:
                rec_i = pre_true_i / true_i
            if acc_i + rec_i > 0:
                f_i = 2 * acc_i * rec_i / (acc_i + rec_i)
            metrics['f1' + str(i)] = f_i
            metrics['rc' + str(i)] = rec_i
            metrics['ac' + str(i)] = acc_i
        acc, rec, f1, facc = 0., 0., 0., 0.
        if total_pre > 0:
            acc = total_pre_true / total_pre
            facc = total_future_true / total_pre
        if total_true > 0:
            rec = total_pre_true / total_true
        if acc + rec > 0:
            f1 = 2 * acc * rec / (acc + rec)
        metrics['acc'] = acc
        metrics['rec'] = rec
        metrics['f1'] = f1
        metrics['macro_f1'] = self.macro_f.get_metric(reset=reset)
        metrics['turn_acc'] = self.turn_acc.get_metric(reset=reset)
        metrics['future_acc'] = facc
        return metrics
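# Standalone arithmetic check of the F1 computed in get_metrics above: with
# averaged per-batch counts pre=4, true=5, pre_true=3, precision is 3/4,
# recall 3/5, and F1 = 2pr / (p + r).  Dividing the averaged counts gives the
# same ratio as dividing the totals, since both are averaged over the same
# number of batches.  Numbers are invented.
pre, true, pre_true = 4.0, 5.0, 3.0
p = pre_true / pre if pre > 0 else 0.0          # 0.75
r = pre_true / true if true > 0 else 0.0        # 0.6
f1 = 2 * p * r / (p + r) if p + r > 0 else 0.0  # 0.666...
print(p, r, f1)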
class WikiTablesSemanticParser(Model):
    """
    A ``WikiTablesSemanticParser`` is a :class:`Model` which takes as input a table and a
    question, and produces a logical form that answers the question when executed over the table.
    The logical form is generated by a `type-constrained`, `transition-based` parser.  This is an
    abstract class that defines most of the functionality related to the transition-based parser.
    It does not contain the implementation for actually training the parser.  You may want to
    train it using a learning-to-search algorithm, in which case you will want to use
    ``WikiTablesErmSemanticParser``, or if you have a set of approximate logical forms that give
    the correct denotation, you will want to use ``WikiTablesMmlSemanticParser``.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    entity_encoder : ``Seq2VecEncoder``
        The encoder to use for averaging the words of an entity.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should
        take?  This only applies at evaluation time, not during training.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the
        `neighbors` of an entity as a component of the linking scores.  This is meant to capture
        the same kind of information as the ``related_column`` feature.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders
        (pytorch LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  The default of 10 here matches the default in the
        ``KnowledgeGraphField``, which is to use all ten defined features.  If this is 0, another
        term will be added to the linking score.  This term contains the maximum similarity value
        from the entity's neighbors and the question.
    rule_namespace : ``str``, optional (default='rule_labels')
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    tables_directory : ``str``, optional (default='/wikitables/')
        The directory to find tables when evaluating logical forms.  We rely on a call to SEMPRE
        to evaluate logical forms, and SEMPRE needs to read the table from disk itself.  This
        tells SEMPRE where to find the tables.
    """
    # pylint: disable=abstract-method
    def __init__(self,
                 vocab: Vocabulary,
                 question_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 use_neighbor_similarity_for_linking: bool = False,
                 dropout: float = 0.0,
                 num_linking_features: int = 10,
                 rule_namespace: str = 'rule_labels',
                 tables_directory: str = '/wikitables/') -> None:
        super(WikiTablesSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._max_decoding_steps = max_decoding_steps
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = WikiTablesAccuracy(tables_directory)
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=action_embedding_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions,
                                                 embedding_dim=action_embedding_dim)
        self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim()))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        check_dimensions_match(entity_encoder.get_output_dim(), question_embedder.get_output_dim(),
                               "entity word average embedding dim", "question embedding dim")

        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 5  # TODO(mattg): get this in a more principled way somehow?
        self._embedding_dim = question_embedder.get_output_dim()
        self._type_params = torch.nn.Linear(self._num_entity_types, self._embedding_dim)
        self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)

        if num_linking_features > 0:
            self._linking_params = torch.nn.Linear(num_linking_features, 1)
        else:
            self._linking_params = None

        if self._use_neighbor_similarity_for_linking:
            self._question_entity_params = torch.nn.Linear(1, 1)
            self._question_neighbor_params = torch.nn.Linear(1, 1)
        else:
            self._question_entity_params = None
            self._question_neighbor_params = None

    def _get_initial_state_and_scores(self,
                                      question: Dict[str, torch.LongTensor],
                                      table: Dict[str, torch.LongTensor],
                                      world: List[WikiTablesWorld],
                                      actions: List[List[ProductionRuleArray]],
                                      example_lisp_string: List[str] = None,
                                      add_world_to_initial_state: bool = False,
                                      checklist_states: List[ChecklistState] = None) -> Dict:
        """
        Does initial preparation and creates an initial state for both the semantic parsers.
        Note that the checklist state is optional, and the ``WikiTablesMmlParser`` is not
        expected to pass it.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms.  The first is the
            # maximum similarity score over the entity's words and the question token.  The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected.  For example, the question token might have no similarity with
            # the column name, but is similar to the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
        initial_score = embedded_question.data.new_zeros(batch_size)

        action_embeddings, output_action_embeddings, action_biases, action_indices = \
                self._embed_actions(actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                     world,
                                                                                     actions)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnState(final_encoder_output[i],
                                              memory_cell[i],
                                              self._first_action_embedding,
                                              self._first_attended_question,
                                              encoder_output_list,
                                              question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                                 for i in range(batch_size)]
        initial_state_world = world if add_world_to_initial_state else None
        initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                               action_history=[[] for _ in range(batch_size)],
                                               score=initial_score_list,
                                               rnn_state=initial_rnn_state,
                                               grammar_state=initial_grammar_state,
                                               action_embeddings=action_embeddings,
                                               output_action_embeddings=output_action_embeddings,
                                               action_biases=action_biases,
                                               action_indices=action_indices,
                                               possible_actions=actions,
                                               flattened_linking_scores=flattened_linking_scores,
                                               actions_to_entities=actions_to_entities,
                                               entity_types=entity_type_dict,
                                               world=initial_state_world,
                                               example_lisp_string=example_lisp_string,
                                               checklist_state=checklist_states,
                                               debug_info=None)
        return {"initial_state": initial_state,
                "linking_scores": linking_scores,
                "feature_scores": feature_scores,
                "similarity_scores": question_entity_similarity_max_score}

    @staticmethod
    def _get_neighbor_indices(worlds: List[WikiTablesWorld],
                              num_entities: int,
                              tensor: torch.Tensor) -> torch.LongTensor:
        """
        This method returns the indices of each entity's neighbors.  A tensor is accepted as a
        parameter for copying purposes.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``.  It is
        padded with -1 instead of 0, since 0 is a valid neighbor index.
        """
        num_neighbors = 0
        for world in worlds:
            for entity in world.table_graph.entities:
                if len(world.table_graph.neighbors[entity]) > num_neighbors:
                    num_neighbors = len(world.table_graph.neighbors[entity])

        batch_neighbors = []
        for world in worlds:
            # Each batch instance has its own world, which has a corresponding table.
            entities = world.table_graph.entities
            entity2index = {entity: i for i, entity in enumerate(entities)}
            entity2neighbors = world.table_graph.neighbors
            neighbor_indexes = []
            for entity in entities:
                entity_neighbors = [entity2index[n] for n in entity2neighbors[entity]]
                # Pad with -1 instead of 0, since 0 represents a neighbor index.
                padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1)
                neighbor_indexes.append(padded)
            neighbor_indexes = pad_sequence_to_length(neighbor_indexes,
                                                      num_entities,
                                                      lambda: [-1] * num_neighbors)
            batch_neighbors.append(neighbor_indexes)
        return tensor.new_tensor(batch_neighbors, dtype=torch.long)

    @staticmethod
    def _get_type_vector(worlds: List[WikiTablesWorld],
                         num_entities: int,
                         tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces the one hot encoding for each entity's type.  In addition, a map from a flattened
        entity index to type is returned to combine entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_types)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                one_hot_vectors = [[1, 0, 0, 0],
                                   [0, 1, 0, 0],
                                   [0, 0, 1, 0],
                                   [0, 0, 0, 1]]
                # We need numbers to be first, then cells, then parts, then row, because our
                # entities are going to be sorted.  We do a split by type and then a merge later,
                # and it relies on this sorting.
                if entity.startswith('fb:cell'):
                    entity_type = 1
                elif entity.startswith('fb:part'):
                    entity_type = 2
                elif entity.startswith('fb:row'):
                    entity_type = 3
                else:
                    entity_type = 0
                types.append(one_hot_vectors[entity_type])

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: [0, 0, 0, 0])
            batch_types.append(padded)
        return tensor.new_tensor(batch_types), entity_types

    def _get_linking_probabilities(self,
                                   worlds: List[WikiTablesWorld],
                                   linking_scores: torch.FloatTensor,
                                   question_mask: torch.LongTensor,
                                   entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type.  The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case
                # where a word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities + entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's
                # an extra "null" entity per type, also, so we have `num_entities_per_type + 1`.
                # We're selecting from a (num_question_tokens, num_entities) linking tensor on
                # _dimension 1_, so we get back something of shape (num_question_tokens,) for each
                # index we're selecting.  All of the selected indices together then make a tensor
                # of shape (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices, dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(1, indices)

                # We used index 0 for the null entity, so this will actually have some values in
                # it.  But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(num_question_tokens,
                                                 num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        We track three metrics here:

            1. dpd_acc, which is the percentage of the time that our best output action sequence
            is in the set of action sequences provided by DPD.  This is an easy-to-compute lower
            bound on denotation accuracy for the set of examples where we actually have DPD
            output.  We only score dpd_acc on that subset.

            2. denotation_acc, which is the percentage of examples where we get the correct
            denotation.  This is the typical "accuracy" metric, and it is what you should usually
            report in an experimental result.  You need to be careful, though, that you're
            computing this on the full data, and not just the subset that has DPD output (make
            sure you pass "keep_if_no_dpd=True" to the dataset reader, which we do for validation
            data, but not training data).

            3. lf_percent, which is the percentage of time that decoding actually produces a
            finished logical form.  We might not produce a valid logical form if the decoder gets
            into a repetitive loop, or we're trying to produce a super long logical form and run
            out of time steps, or something.
        """
        return {
                'dpd_acc': self._action_sequence_accuracy.get_metric(reset),
                'denotation_acc': self._denotation_accuracy.get_metric(reset),
                'lf_percent': self._has_logical_form.get_metric(reset),
                }

    @staticmethod
    def _create_grammar_state(world: WikiTablesWorld,
                              possible_actions: List[ProductionRuleArray]) -> GrammarState:
        valid_actions = world.get_valid_actions()
        action_mapping = {}
        for i, action in enumerate(possible_actions):
            action_string = action[0]
            action_mapping[action_string] = i
        translated_valid_actions = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = [action_mapping[action_string]
                                             for action_string in action_strings]
        return GrammarState([START_SYMBOL],
                            {},
                            translated_valid_actions,
                            action_mapping,
                            type_declaration.is_nonterminal)

    def _embed_actions(self, actions: List[List[ProductionRuleArray]]) -> Tuple[torch.Tensor,
                                                                                torch.Tensor,
                                                                                torch.Tensor,
                                                                                Dict[Tuple[int, int], int]]:
        """
        Given all of the possible actions for all batch instances, produce an embedding for them.
        There will be significant overlap in this list, as the production rules from the grammar
        are shared across all batch instances.  Our returned tensor has an embedding for each
        `unique` action, so we also need to return a mapping from the original
        ``(batch_index, action_index)`` to our new ``global_action_index``, so that we can get
        the right action embedding during decoding.

        Returns
        -------
        action_embeddings : ``torch.Tensor``
            Has shape ``(num_unique_actions, action_embedding_dim)``.
        output_action_embeddings : ``torch.Tensor``
            Has shape ``(num_unique_actions, action_embedding_dim)``.
        action_biases : ``torch.Tensor``
            Has shape ``(num_unique_actions, 1)``.
        action_map : ``Dict[Tuple[int, int], int]``
            Maps ``(batch_index, action_index)`` in the input action list to ``action_index`` in
            the ``action_embeddings`` tensor.  All non-embeddable actions get mapped to `-1` here.
        """
        # TODO(mattg): This whole action pipeline might be a whole lot more complicated than it
        # needs to be.  We used to embed actions differently (using some crazy ideas about
        # embedding the LHS and RHS separately); we could probably get away with simplifying
        # things further now that we're just doing a simple embedding for global actions.  But
        # I'm leaving it like this for now to have a minimal change to go from the LHS/RHS
        # embedding to a single action embedding.
        embedded_actions = self._action_embedder.weight
        output_embedded_actions = self._output_action_embedder.weight
        action_biases = self._action_biases.weight

        # Now we just need to make a map from `(batch_index, action_index)` to
        # `global_action_index`.  global_action_ids has the list of all unique actions; here
        # we're going over all of the actions for each batch instance so we can map them to the
        # global action ids.
        action_vocab = self.vocab.get_token_to_index_vocabulary(self._rule_namespace)
        action_map: Dict[Tuple[int, int], int] = {}
        for batch_index, instance_actions in enumerate(actions):
            for action_index, action in enumerate(instance_actions):
                if not action[0]:
                    # This rule is padding.
                    continue
                global_action_id = action_vocab.get(action[0], -1)
                action_map[(batch_index, action_index)] = global_action_id
        return embedded_actions, output_embedded_actions, action_biases, action_map

    @staticmethod
    def _map_entity_productions(linking_scores: torch.FloatTensor,
                                worlds: List[WikiTablesWorld],
                                actions: List[List[ProductionRuleArray]]) -> Tuple[torch.Tensor,
                                                                                   Dict[Tuple[int, int], int]]:
        """
        Constructs a map from ``(batch_index, action_index)`` to ``(batch_index * entity_index)``.
        That is, some actions correspond to terminal productions of entities from our table.  We
        need to find those actions and map them to their corresponding entity indices, where the
        entity index is its position in the list of entities returned by the ``world``.  This
        list is what defines the second dimension of the ``linking_scores`` tensor, so we can use
        this index to look up linking scores for each action in that tensor.

        For easier processing later, the mapping that we return is `flattened` - we really want
        to map ``(batch_index, action_index)`` to ``(batch_index, entity_index)``, but we are
        going to have to use the result of this mapping to do ``index_selects`` on the
        ``linking_scores`` tensor.  You can't do ``index_select`` with tuples, so we flatten
        ``linking_scores`` to have shape ``(batch_size * num_entities, num_question_tokens)``,
        and return shifted indices into this flattened tensor.

        Parameters
        ----------
        linking_scores : ``torch.Tensor``
            A tensor representing linking scores between each table entity and each question
            token.  Has shape ``(batch_size, num_entities, num_question_tokens)``.
        worlds : ``List[WikiTablesWorld]``
            The ``World`` for each batch instance.  The ``World`` contains a reference to the
            ``TableKnowledgeGraph`` that defines the set of entities in the linking.
        actions : ``List[List[ProductionRuleArray]]``
            The list of possible actions for each batch instance.  Our action indices are defined
            in terms of this list, so we'll find entity productions in this list and map them to
            entity indices from the entity list we get from the ``World``.

        Returns
        -------
        flattened_linking_scores : ``torch.Tensor``
            A flattened version of ``linking_scores``, with shape
            ``(batch_size * num_entities, num_question_tokens)``.
        actions_to_entities : ``Dict[Tuple[int, int], int]``
            A mapping from ``(batch_index, action_index)`` to ``(batch_size * num_entities)``,
            representing which action indices correspond to which entity indices in the returned
            ``flattened_linking_scores`` tensor.
        """
        batch_size, num_entities, num_question_tokens = linking_scores.size()
        entity_map: Dict[Tuple[int, str], int] = {}
        for batch_index, world in enumerate(worlds):
            for entity_index, entity in enumerate(world.table_graph.entities):
                entity_map[(batch_index, entity)] = batch_index * num_entities + entity_index
        actions_to_entities: Dict[Tuple[int, int], int] = {}
        for batch_index, action_list in enumerate(actions):
            for action_index, action in enumerate(action_list):
                if not action[0]:
                    # This action is padding.
                    continue
                _, production = action[0].split(' -> ')
                entity_index = entity_map.get((batch_index, production), None)
                if entity_index is not None:
                    actions_to_entities[(batch_index, action_index)] = entity_index
        flattened_linking_scores = linking_scores.view(batch_size * num_entities,
                                                       num_question_tokens)
        return flattened_linking_scores, actions_to_entities

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at
        test time, to finalize predictions.  This is (confusingly) a separate notion from the
        "decoder" in "encoder/decoder", where that decoder logic lives in
        ``WikiTablesDecoderStep``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        action_mapping = output_dict['action_mapping']
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions,
                                                                          debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['question_attention'] = action_debug_info.get('question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
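# Tiny standalone check of the _action_history_match trick above: eq() rows
# are all-ones exactly where the prediction prefix-matches a target row; min
# over dim 1 detects that, and max over rows asks "did any target match?".
# The tensors are invented; the cast to long avoids bool-reduction issues in
# some torch versions.
import torch

predicted = [3, 1, 4]
targets = torch.tensor([[3, 1, 4, 0], [2, 7, 1, 8]])
predicted_tensor = targets.new_tensor(predicted)
trimmed = targets[:, :len(predicted)]
match = torch.max(torch.min(trimmed.eq(predicted_tensor).long(), dim=1)[0]).item()
print(match)  # 1: the first target row starts with [3, 1, 4]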
class KglmDisc(Model): """ Knowledge graph language model discriminator (for importance sampling). Parameters ---------- vocab : ``Vocabulary`` The model vocabulary. """ def __init__(self, vocab: Vocabulary, token_embedder: TextFieldEmbedder, entity_embedder: TextFieldEmbedder, relation_embedder: TextFieldEmbedder, knowledge_graph_path: str, use_shortlist: bool, hidden_size: int, num_layers: int, cutoff: int = 30, tie_weights: bool = False, dropout: float = 0.4, dropouth: float = 0.3, dropouti: float = 0.65, dropoute: float = 0.1, wdrop: float = 0.5, alpha: float = 2.0, beta: float = 1.0, initializer: InitializerApplicator = InitializerApplicator()) -> None: super(KglmDisc, self).__init__(vocab) # We extract the `Embedding` layers from the `TokenEmbedders` to apply dropout later on. # pylint: disable=protected-access self._token_embedder = token_embedder._token_embedders['tokens'] self._entity_embedder = entity_embedder._token_embedders['entity_ids'] self._relation_embedder = relation_embedder._token_embedders['relations'] self._recent_entities = RecentEntities(cutoff=cutoff) self._knowledge_graph_lookup = KnowledgeGraphLookup(knowledge_graph_path, vocab=vocab) self._use_shortlist = use_shortlist self._hidden_size = hidden_size self._num_layers = num_layers self._cutoff = cutoff self._tie_weights = tie_weights # Dropout self._locked_dropout = LockedDropout() self._dropout = dropout self._dropouth = dropouth self._dropouti = dropouti self._dropoute = dropoute self._wdrop = wdrop # Regularization strength self._alpha = alpha self._beta = beta # RNN Encoders. entity_embedding_dim = entity_embedder.get_output_dim() token_embedding_dim = token_embedder.get_output_dim() self.entity_embedding_dim = entity_embedding_dim self.token_embedding_dim = token_embedding_dim rnns: List[torch.nn.Module] = [] for i in range(num_layers): if i == 0: input_size = token_embedding_dim else: input_size = hidden_size if i == num_layers - 1: output_size = token_embedding_dim + 2 * entity_embedding_dim else: output_size = hidden_size rnns.append(torch.nn.LSTM(input_size, output_size, batch_first=True)) rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in rnns] self.rnns = torch.nn.ModuleList(rnns) # Various linear transformations. self._fc_mention_type = torch.nn.Linear( in_features=token_embedding_dim, out_features=4) if not use_shortlist: self._fc_new_entity = torch.nn.Linear( in_features=entity_embedding_dim, out_features=vocab.get_vocab_size('entity_ids')) if tie_weights: self._fc_new_entity.weight = self._entity_embedder.weight self._state: Optional[Dict[str, Any]] = None # Metrics self._unk_index = vocab.get_token_index(DEFAULT_OOV_TOKEN) self._unk_penalty = math.log(vocab.get_vocab_size('tokens_unk')) self._avg_mention_type_loss = Average() self._avg_new_entity_loss = Average() self._avg_knowledge_graph_entity_loss = Average() self._new_mention_f1 = F1Measure(positive_label=1) self._kg_mention_f1 = F1Measure(positive_label=2) self._new_entity_accuracy = CategoricalAccuracy() self._new_entity_accuracy20 = CategoricalAccuracy(top_k=20) self._parent_ppl = Ppl() self._relation_ppl = Ppl() initializer(self) def sample(self, source: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor], reset: torch.Tensor, metadata: Dict[str, Any], alias_copy_inds: torch.Tensor, shortlist: Dict[str, torch.Tensor] = None, **kwargs) -> Dict[str, Any]: # **kwargs intended to eat the other fields if they are provided. """ Sampling annotations for the generative model. 
Note that unlike forward, this function expects inputs from a **generative** dataset reader, not a **discriminative** one. """ # Tensorize the alias_database - this will only perform the operation once. alias_database = metadata[0]['alias_database'] alias_database.tensorize(vocab=self.vocab) # Reset the model if needed if reset.any() and (self._state is not None): for layer in range(self._num_layers): h, c = self._state['layer_%i' % layer] h[:, reset, :] = torch.zeros_like(h[:, reset, :]) c[:, reset, :] = torch.zeros_like(c[:, reset, :]) self._state['layer_%i' % layer] = (h, c) self._recent_entities.reset(reset) logp = 0.0 mask = get_text_field_mask(target).byte() # We encode the target tokens (**not** source) since the discriminative model makes # predictions on the current token, but the generative model expects labels for the # **next** (e.g. target) token! encoded, *_ = self._encode_source(target['tokens']) splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2 encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1) # Compute new mention logits mention_logits = self._fc_mention_type(encoded_token) mention_probs = F.softmax(mention_logits, dim=-1) mention_type = parallel_sample(mention_probs) mention_logp = mention_probs.gather(-1, mention_type.unsqueeze(-1)).log() mention_logp[~mask] = 0 mention_logp = mention_logp.sum() # Compute entity logits new_entity_mask = mention_type.eq(1) new_entity_logits = self._new_entity_logits(encoded_head + encoded_relation, shortlist) if self._use_shortlist: # If using shortlist, then samples are indexed w.r.t the shortlist and entity_ids must be looked up shortlist_mask = get_text_field_mask(shortlist) new_entity_probs = masked_softmax(new_entity_logits, shortlist_mask) shortlist_inds = torch.zeros_like(mention_type) # Some sequences may be full of padding in which case the shortlist # is empty not_just_padding = shortlist_mask.byte().any(-1) shortlist_inds[not_just_padding] = parallel_sample(new_entity_probs[not_just_padding]) shortlist_inds[~new_entity_mask] = 0 _new_entity_logp = new_entity_probs.gather(-1, shortlist_inds.unsqueeze(-1)).log() new_entity_samples = shortlist['entity_ids'].gather(1, shortlist_inds) else: new_entity_logits = new_entity_logits # If not using shortlist, then samples are indexed w.r.t to the global vocab new_entity_probs = F.softmax(new_entity_logits, dim=-1) new_entity_samples = parallel_sample(new_entity_probs) _new_entity_logp = new_entity_probs.gather(-1, new_entity_samples.unsqueeze(-1)).log() shortlist_inds = None # Zero out masked tokens and non-new entity predictions _new_entity_logp[~mask] = 0 _new_entity_logp[~new_entity_mask] = 0 new_entity_logp = _new_entity_logp.sum() # Start filling in the entity ids entity_ids = torch.zeros_like(target['tokens']) entity_ids[new_entity_mask] = new_entity_samples[new_entity_mask] # ...UGH we also need the raw ids - remapping time raw_entity_ids = torch.zeros_like(target['tokens']) for *index, entity_id in nested_enumerate(entity_ids.tolist()): token = self.vocab.get_token_from_index(entity_id, 'entity_ids') raw_entity_id = self.vocab.get_token_index(token, 'raw_entity_ids') raw_entity_ids[tuple(index)] = raw_entity_id # Derived mentions need to be computed sequentially. 
parent_ids = torch.zeros_like(target['tokens']).unsqueeze(-1) derived_entity_mask = mention_type.eq(2) derived_entity_logp = 0.0 sequence_length = target['tokens'].shape[1] for i in range(sequence_length): current_mask = derived_entity_mask[:, i] & mask[:, i] # ------------------- SAMPLE PARENTS --------------------- # Update recent entities with **current** entity only current_entity_id = entity_ids[:, i].unsqueeze(1) candidate_ids, candidate_mask = self._recent_entities(current_entity_id) # If no mentions are derived, there is no point continuing after entities have been updated. if not current_mask.any(): continue # Otherwise we proceed candidate_embeddings = self._entity_embedder(candidate_ids) # Compute logits w.r.t **current** hidden state only current_head_encoding = encoded_head[:, i].unsqueeze(1) selection_logits = torch.bmm(current_head_encoding, candidate_embeddings.transpose(1, 2)) selection_probs = masked_softmax(selection_logits, candidate_mask) # Only sample if there is at least one viable candidate (e.g. if a sampling distribution # has no probability mass we cannot sample from it). Return zero as the parent for # non-viable distributions. viable_candidate_mask = candidate_mask.any(-1).squeeze() _parent_ids = torch.zeros_like(current_entity_id) parent_logp = torch.zeros_like(current_entity_id, dtype=torch.float32) if viable_candidate_mask.any(): viable_candidate_ids = candidate_ids[viable_candidate_mask] viable_candidate_probs = selection_probs[viable_candidate_mask] viable_parent_samples = parallel_sample(viable_candidate_probs) viable_logp = viable_candidate_probs.gather(-1, viable_parent_samples.unsqueeze(-1)).log() viable_parent_ids = viable_candidate_ids.gather(-1, viable_parent_samples) _parent_ids[viable_candidate_mask] = viable_parent_ids parent_logp[viable_candidate_mask] = viable_logp.squeeze(-1) parent_ids[current_mask, i] = _parent_ids[current_mask] # TODO: Double-check derived_entity_logp += parent_logp[current_mask].sum() # ---------------------- SAMPLE RELATION ----------------------------- # Lookup sampled parent ids in the knowledge graph indices, parent_ids_list, relations_list, tail_ids_list = self._knowledge_graph_lookup(_parent_ids) relation_embeddings = [self._relation_embedder(r) for r in relations_list] # Sample tail ids current_relation_encoding = encoded_relation[:, i].unsqueeze(1) _raw_tail_ids = torch.zeros_like(_parent_ids).squeeze(-1) _tail_ids = torch.zeros_like(_parent_ids).squeeze(-1) for index, relation_embedding, tail_id_lookup in zip(indices, relation_embeddings, tail_ids_list): # Compute the score for each relation w.r.t the current encoding. NOTE: In the loss # code index has a slice. We don't need that here since there is always a # **single** parent. logits = torch.mv(relation_embedding, current_relation_encoding[index]) # Convert to probability tail_probs = F.softmax(logits, dim=-1) # Sample tail_sample = torch.multinomial(tail_probs, 1) # Get logp. Ignoring the current_mask here is **super** dodgy, but since we forced # null parents to zero we shouldn't be accumulating probabilities for unused predictions. 
tail_logp = tail_probs.gather(-1, tail_sample).log() derived_entity_logp += tail_logp.sum() # Sum is redundant, just need it to make logp a scalar # Map back to raw id raw_tail_id = tail_id_lookup[tail_sample] # Convert raw id to id tail_id_string = self.vocab.get_token_from_index(raw_tail_id.item(), 'raw_entity_ids') tail_id = self.vocab.get_token_index(tail_id_string, 'entity_ids') _raw_tail_ids[index[:-1]] = raw_tail_id _tail_ids[index[:-1]] = tail_id raw_entity_ids[current_mask, i] = _raw_tail_ids[current_mask] # TODO: Double-check entity_ids[current_mask, i] = _tail_ids[current_mask] # TODO: Double-check self._recent_entities.insert(_tail_ids, current_mask) # --------------------- CONTINUE MENTIONS --------------------------------------- continue_mask = mention_type[:, i].eq(3) & mask[:, i] if not current_mask.any() or i == 0: continue raw_entity_ids[continue_mask, i] = raw_entity_ids[continue_mask, i-1] entity_ids[continue_mask, i] = entity_ids[continue_mask, i-1] entity_ids[continue_mask, i] = entity_ids[continue_mask, i-1] parent_ids[continue_mask, i] = parent_ids[continue_mask, i-1] if self._use_shortlist: shortlist_inds[continue_mask, i] = shortlist_inds[continue_mask, i-1] alias_copy_inds[continue_mask, i] = alias_copy_inds[continue_mask, i-1] # Lastly, because entities won't always match the true entity ids, # we need to zero out any alias copy ids that won't be valid. if 'raw_entity_ids' in kwargs: true_raw_entity_ids = kwargs['raw_entity_ids']['raw_entity_ids'] invalid_id_mask = ~true_raw_entity_ids.eq(raw_entity_ids) alias_copy_inds[invalid_id_mask] = 0 # Pass denotes fields that are passed directly from input to output. sample = { 'source': source, # Pass 'target': target, # Pass 'reset': reset, # Pass 'metadata': metadata, # Pass 'mention_type': mention_type, 'raw_entity_ids': {'raw_entity_ids': raw_entity_ids}, 'entity_ids': {'entity_ids': entity_ids}, 'parent_ids': {'entity_ids': parent_ids}, 'relations': {'relations': None}, # We aren't using them - eventually should remove entirely 'shortlist': shortlist, # Pass 'shortlist_inds': shortlist_inds, 'alias_copy_inds': alias_copy_inds } logp = mention_logp + new_entity_logp + derived_entity_logp return {'sample': sample, 'logp': logp} @overrides def forward(self, # pylint: disable=arguments-differ source: Dict[str, torch.Tensor], reset: torch.Tensor, metadata: List[Dict[str, Any]], mention_type: torch.Tensor = None, raw_entity_ids: Dict[str, torch.Tensor] = None, entity_ids: Dict[str, torch.Tensor] = None, parent_ids: Dict[str, torch.Tensor] = None, relations: Dict[str, torch.Tensor] = None, shortlist: Dict[str, torch.Tensor] = None, shortlist_inds: torch.Tensor = None) -> Dict[str, torch.Tensor]: # Tensorize the alias_database - this will only perform the operation once. 
        alias_database = metadata[0]['alias_database']
        alias_database.tensorize(vocab=self.vocab)

        # Reset the model if needed
        if reset.any() and (self._state is not None):
            for layer in range(self._num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)
        self._recent_entities.reset(reset)

        if entity_ids is not None:
            output_dict = self._forward_loop(
                source=source,
                alias_database=alias_database,
                mention_type=mention_type,
                raw_entity_ids=raw_entity_ids,
                entity_ids=entity_ids,
                parent_ids=parent_ids,
                relations=relations,
                shortlist=shortlist,
                shortlist_inds=shortlist_inds)
        else:
            # TODO: Figure out what we want here - probably to do some kind of inference on
            # entities / mention types.
            output_dict = {}

        return output_dict

    def _encode_source(self, source: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Extract and embed source tokens.
        source_embeddings = embedded_dropout(
            embed=self._token_embedder,
            words=source,
            dropout=self._dropoute if self.training else 0)
        source_embeddings = self._locked_dropout(source_embeddings, self._dropouti)

        # Encode.
        current_input = source_embeddings
        hidden_states = []
        for layer, rnn in enumerate(self.rnns):
            # Retrieve previous hidden state for layer.
            if self._state is not None:
                prev_hidden = self._state['layer_%i' % layer]
            else:
                prev_hidden = None
            # Forward-pass.
            output, hidden = rnn(current_input, prev_hidden)
            output = output.contiguous()
            # Update hidden state for layer.
            hidden = tuple(h.detach() for h in hidden)
            hidden_states.append(hidden)
            # Apply dropout.
            if layer == self._num_layers - 1:
                dropped_output = self._locked_dropout(output, self._dropout)
            else:
                dropped_output = self._locked_dropout(output, self._dropouth)
            current_input = dropped_output
        encoded = current_input

        alpha_loss = dropped_output.pow(2).mean()
        beta_loss = (output[:, 1:] - output[:, :-1]).pow(2).mean()

        # Update state.
        self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_states)}

        return encoded, alpha_loss, beta_loss

    def _mention_type_loss(self,
                           encoded: torch.Tensor,
                           mention_type: torch.Tensor,
                           mask: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss for predicting whether or not the next token will be part of an
        entity mention.
        """
        logits = self._fc_mention_type(encoded)
        mention_type_loss = sequence_cross_entropy_with_logits(logits, mention_type, mask,
                                                               average='token')
        # if not self.training:
        self._new_mention_f1(predictions=logits, gold_labels=mention_type, mask=mask)
        self._kg_mention_f1(predictions=logits, gold_labels=mention_type, mask=mask)

        return mention_type_loss

    def _new_entity_logits(self,
                           encoded: torch.Tensor,
                           shortlist: torch.Tensor) -> torch.Tensor:
        if self._use_shortlist:
            # Embed the shortlist entries
            shortlist_embeddings = embedded_dropout(
                embed=self._entity_embedder,
                words=shortlist['entity_ids'],
                dropout=self._dropoute if self.training else 0)
            # Compute logits using inner product between the predicted entity embedding and the
            # embeddings of entities in the shortlist
            encodings = self._locked_dropout(encoded, self._dropout)
            logits = torch.bmm(encodings, shortlist_embeddings.transpose(1, 2))
        else:
            logits = self._fc_new_entity(encoded)
        return logits

    def _new_entity_loss(self,
                         encoded: torch.Tensor,
                         target_inds: torch.Tensor,
                         shortlist: torch.Tensor,
                         target_mask: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ==========
        target_inds : ``torch.Tensor``
            Either the shortlist inds if using shortlist, otherwise the target entity ids.
""" logits = self._new_entity_logits(encoded, shortlist) if self._use_shortlist: # Take masked softmax to get log probabilties and gather the targets. shortlist_mask = get_text_field_mask(shortlist) log_probs = masked_log_softmax(logits, shortlist_mask) else: logits = logits log_probs = F.log_softmax(logits, dim=-1) num_categories = log_probs.shape[-1] log_probs = log_probs.view(-1, num_categories) target_inds = target_inds.view(-1) target_log_probs = torch.gather(log_probs, -1, target_inds.unsqueeze(-1)).squeeze(-1) mask = ~target_inds.eq(0) target_log_probs[~mask] = 0 if mask.any(): self._new_entity_accuracy(predictions=log_probs[mask], gold_labels=target_inds[mask]) self._new_entity_accuracy20(predictions=log_probs[mask], gold_labels=target_inds[mask]) return -target_log_probs.sum() / (target_mask.sum() + 1e-13) def _parent_log_probs(self, encoded_head: torch.Tensor, entity_ids: torch.Tensor, parent_ids: torch.Tensor) -> torch.Tensor: # Lookup recent entities (which are candidates for parents) and get their embeddings. candidate_ids, candidate_mask = self._recent_entities(entity_ids) logger.debug('Candidate ids shape: %s', candidate_ids.shape) candidate_embeddings = embedded_dropout(self._entity_embedder, words=candidate_ids, dropout=self._dropoute if self.training else 0) # Logits are computed using a general bilinear form that measures the similarity between # the projected hidden state and the embeddings of candidate entities encoded = self._locked_dropout(encoded_head, self._dropout) selection_logits = torch.bmm(encoded, candidate_embeddings.transpose(1, 2)) # Get log probabilities using masked softmax (need to double check mask works properly). # shape: (batch_size, sequence_length, num_candidates) log_probs = masked_log_softmax(selection_logits, candidate_mask) # Now for the tricky part. We need to convert the parent ids to a mask that selects the # relevant probabilities from log_probs. To do this we need to align the candidates with # the parent ids, which can be achieved by an element-wise equality comparison. We also # need to ensure that null parents are not selected. # shape: (batch_size, sequence_length, num_parents, 1) _parent_ids = parent_ids.unsqueeze(-1) batch_size, num_candidates = candidate_ids.shape # shape: (batch_size, 1, 1, num_candidates) _candidate_ids = candidate_ids.view(batch_size, 1, 1, num_candidates) # shape: (batch_size, sequence_length, num_parents, num_candidates) is_parent = _parent_ids.eq(_candidate_ids) # shape: (batch_size, 1, 1, num_candidates) non_null = ~_candidate_ids.eq(0) # Since multiplication is addition in log-space, we can apply mask by adding its log (+ # some small constant for numerical stability). mask = is_parent & non_null masked_log_probs = log_probs.unsqueeze(2) + (mask.float() + 1e-45).log() logger.debug('Masked log probs shape: %s', masked_log_probs.shape) # Lastly, we need to get rid of the num_candidates dimension. The easy way to do this would # be to marginalize it out. However, since our data is sparse (the last two dims are # essentially a delta function) this would add a lot of unneccesary terms to the computation graph. # To get around this we are going to try to use a gather. 
        _, index = torch.max(mask, dim=-1, keepdim=True)
        target_log_probs = torch.gather(masked_log_probs, dim=-1, index=index).squeeze(-1)

        return target_log_probs

    def _relation_log_probs(self,
                            encoded_relation: torch.Tensor,
                            raw_entity_ids: torch.Tensor,
                            parent_ids: torch.Tensor) -> torch.Tensor:
        # Lookup edges out of parents
        indices, parent_ids_list, relations_list, tail_ids_list = self._knowledge_graph_lookup(parent_ids)

        # Embed relations
        relation_embeddings = [self._relation_embedder(r) for r in relations_list]

        # Logits are computed using a general bilinear form that measures the similarity between
        # the projected hidden state and the embeddings of relations
        encoded = self._locked_dropout(encoded_relation, self._dropout)

        # This is a little funky, but to avoid massive amounts of padding we are going to just
        # iterate over the relation and tail_id vectors one-by-one.
        # shape: (batch_size, sequence_length, num_parents)
        target_log_probs = encoded.new_empty(*parent_ids.shape).fill_(math.log(1e-45))
        for index, parent_id, relation_embedding, tail_id in zip(indices, parent_ids_list,
                                                                 relation_embeddings, tail_ids_list):
            # First we compute the score for each relation w.r.t the current encoding, and
            # convert the scores to log-probabilities
            logits = torch.mv(relation_embedding, encoded[index[:-1]])
            logger.debug('Relation logits shape: %s', logits.shape)
            log_probs = F.log_softmax(logits, dim=-1)

            # Next we gather the log probs for edges with the correct tail entity and sum them up
            target_id = raw_entity_ids[index[:-1]]
            mask = tail_id.eq(target_id)
            relevant_log_probs = log_probs.masked_select(mask)
            target_log_prob = torch.logsumexp(relevant_log_probs, dim=0)
            target_log_probs[index] = target_log_prob

        return target_log_probs

    def _knowledge_graph_entity_loss(self,
                                     encoded_head: torch.Tensor,
                                     encoded_relation: torch.Tensor,
                                     raw_entity_ids: torch.Tensor,
                                     entity_ids: torch.Tensor,
                                     parent_ids: torch.Tensor,
                                     target_mask: torch.Tensor) -> torch.Tensor:
        # First get the log probabilities of the parents and relations that lead to the current
        # entity.
        parent_log_probs = self._parent_log_probs(encoded_head, entity_ids, parent_ids)
        relation_log_probs = self._relation_log_probs(encoded_relation, raw_entity_ids, parent_ids)
        # Next take their product + marginalize
        combined_log_probs = parent_log_probs + relation_log_probs
        target_log_probs = torch.logsumexp(combined_log_probs, dim=-1)
        # Zero out any non-kg predictions
        mask = ~parent_ids.eq(0).all(dim=-1)
        target_log_probs = target_log_probs * mask.float()
        # If validating, measure ppl of the predictions:
        # if not self.training:
        self._parent_ppl(-torch.logsumexp(parent_log_probs, dim=-1)[mask].sum(), mask.float().sum())
        self._relation_ppl(-torch.logsumexp(relation_log_probs, dim=-1)[mask].sum(), mask.float().sum())
        # Lastly return the tokenwise average loss
        return -target_log_probs.sum() / (target_mask.sum() + 1e-13)

    def _forward_loop(self,
                      source: Dict[str, torch.Tensor],
                      alias_database: AliasDatabase,
                      mention_type: torch.Tensor,
                      raw_entity_ids: Dict[str, torch.Tensor],
                      entity_ids: Dict[str, torch.Tensor],
                      parent_ids: Dict[str, torch.Tensor],
                      relations: Dict[str, torch.Tensor],
                      shortlist: Dict[str, torch.Tensor],
                      shortlist_inds: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Get the token mask and extract indexed text fields.
# shape: (batch_size, sequence_length) target_mask = get_text_field_mask(source) source = source['tokens'] raw_entity_ids = raw_entity_ids['raw_entity_ids'] entity_ids = entity_ids['entity_ids'] parent_ids = parent_ids['entity_ids'] relations = relations['relations'] logger.debug('Source & Target shape: %s', source.shape) logger.debug('Entity ids shape: %s', entity_ids.shape) logger.debug('Relations & Parent ids shape: %s', relations.shape) logger.debug('Shortlist shape: %s', shortlist['entity_ids'].shape) # Embed source tokens. # shape: (batch_size, sequence_length, embedding_dim) encoded, alpha_loss, beta_loss = self._encode_source(source) splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2 encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1) # Predict whether or not the next token will be an entity mention, and if so which type. mention_type_loss = self._mention_type_loss(encoded_token, mention_type, target_mask) self._avg_mention_type_loss(float(mention_type_loss)) # For new mentions, predict which entity (among those in the supplied shortlist) will be # mentioned. if self._use_shortlist: new_entity_loss = self._new_entity_loss(encoded_head + encoded_relation, shortlist_inds, shortlist, target_mask) else: new_entity_loss = self._new_entity_loss(encoded_head + encoded_relation, entity_ids, None, target_mask) self._avg_new_entity_loss(float(new_entity_loss)) # For derived mentions, first predict which parent(s) to expand... knowledge_graph_entity_loss = self._knowledge_graph_entity_loss(encoded_head, encoded_relation, raw_entity_ids, entity_ids, parent_ids, target_mask) self._avg_knowledge_graph_entity_loss(float(knowledge_graph_entity_loss)) # Compute total loss loss = mention_type_loss + new_entity_loss + knowledge_graph_entity_loss # Activation regularization if self._alpha: loss = loss + self._alpha * alpha_loss # Temporal activation regularization (slowness) if self._beta: loss = loss + self._beta * beta_loss return {'loss': loss} @overrides def train(self, mode=True): # TODO: This is a temporary hack to ensure that the internal state resets when the model # switches from training to evaluation. The complication arises from potentially differing # batch sizes (e.g. the `reset` tensor will not be the right size). # In future implementations this should be handled more robustly. super().train(mode) self._state = None @overrides def eval(self): # TODO: See train. super().eval() self._state = None def get_metrics(self, reset: bool = False) -> Dict[str, float]: out = { 'type': self._avg_mention_type_loss.get_metric(reset), 'new': self._avg_new_entity_loss.get_metric(reset), 'kg': self._avg_knowledge_graph_entity_loss.get_metric(reset), } # if not self.training: p, r, f = self._new_mention_f1.get_metric(reset) out['new_p'] = p out['new_r'] = r out['new_f1'] = f p, r, f = self._kg_mention_f1.get_metric(reset) out['kg_p'] = p out['kg_r'] = r out['kg_f1'] = f out['new_ent_acc'] = self._new_entity_accuracy.get_metric(reset) out['new_ent_acc_20'] = self._new_entity_accuracy20.get_metric(reset) out['parent_ppl'] = self._parent_ppl.get_metric(reset) out['relation_ppl'] = self._relation_ppl.get_metric(reset) return out
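# ---------------------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): a minimal, self-contained demo of the
# log-space masking trick used in `_parent_log_probs`. The tensor values and the function
# name `masked_gather_demo` are invented for illustration; only `torch` is assumed.
#
# Multiplying a probability by a {0, 1} mask is the same as adding the mask's log to the
# log-probability (log 0 is approximated by log 1e-45 for numerical stability), so a
# gather at the argmax of the mask recovers the target's log-prob without marginalizing
# over the candidate dimension.
# ---------------------------------------------------------------------------------------
import torch


def masked_gather_demo():
    # Toy log-probs over 4 candidates for a single prediction.
    log_probs = torch.log(torch.tensor([0.1, 0.2, 0.3, 0.4]))
    # Mask selecting the single "true parent" candidate (index 2).
    mask = torch.tensor([0., 0., 1., 0.])
    # Masked-out entries are pushed to ~ -inf in log space...
    masked_log_probs = log_probs + (mask + 1e-45).log()
    # ...so gathering at the mask's argmax reads off the target log-prob directly.
    _, index = torch.max(mask, dim=-1, keepdim=True)
    target_log_prob = masked_log_probs.gather(-1, index)
    print(target_log_prob.item())  # log(0.3) ~= -1.204


if __name__ == '__main__':
    masked_gather_demo()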
class WikiTablesSemanticParser(Model):
    u"""
    A ``WikiTablesSemanticParser`` is a :class:`Model` which takes as input a table and a
    question, and produces a logical form that answers the question when executed over the
    table. The logical form is generated by a `type-constrained`, `transition-based` parser.
    This is an abstract class that defines most of the functionality related to the
    transition-based parser. It does not contain the implementation for actually training the
    parser. You may want to train it using a learning-to-search algorithm, in which case you
    will want to use ``WikiTablesErmSemanticParser``, or if you have a set of approximate
    logical forms that give the correct denotation, you will want to use
    ``WikiTablesMmlSemanticParser``.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    entity_encoder : ``Seq2VecEncoder``
        The encoder used for averaging the words of an entity.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should
        take? This only applies at evaluation time, not during training.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the
        `neighbors` of an entity as a component of the linking scores. This is meant to capture
        the same kind of information as the ``related_column`` feature.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders
        (pytorch LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know
        how many there are. The default of 10 here matches the default in the
        ``KnowledgeGraphField``, which is to use all ten defined features. If this is 0, another
        term will be added to the linking score. This term contains the maximum similarity value
        from the entity's neighbors and the question.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules. The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    tables_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms. We rely on a call to SEMPRE
        to evaluate logical forms, and SEMPRE needs to read the table from disk itself. This
        tells SEMPRE where to find the tables.
""" # pylint: disable=abstract-method def __init__(self, vocab, question_embedder, action_embedding_dim, encoder, entity_encoder, max_decoding_steps, use_neighbor_similarity_for_linking=False, dropout=0.0, num_linking_features=10, rule_namespace=u'rule_labels', tables_directory=u'/wikitables/'): super(WikiTablesSemanticParser, self).__init__(vocab) self._question_embedder = question_embedder self._encoder = encoder self._entity_encoder = TimeDistributed(entity_encoder) self._max_decoding_steps = max_decoding_steps self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._rule_namespace = rule_namespace self._denotation_accuracy = WikiTablesAccuracy(tables_directory) self._action_sequence_accuracy = Average() self._has_logical_form = Average() self._action_padding_index = -1 # the padding value used by IndexField num_actions = vocab.get_vocab_size(self._rule_namespace) self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim) self._output_action_embedder = Embedding( num_embeddings=num_actions, embedding_dim=action_embedding_dim) self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1) # This is what we pass as input in the first step of decoding, when we don't have a # previous action, or a previous question attention. self._first_action_embedding = torch.nn.Parameter( torch.FloatTensor(action_embedding_dim)) self._first_attended_question = torch.nn.Parameter( torch.FloatTensor(encoder.get_output_dim())) torch.nn.init.normal_(self._first_action_embedding) torch.nn.init.normal_(self._first_attended_question) check_dimensions_match(entity_encoder.get_output_dim(), question_embedder.get_output_dim(), u"entity word average embedding dim", u"question embedding dim") self._num_entity_types = 4 # TODO(mattg): get this in a more principled way somehow? self._num_start_types = 5 # TODO(mattg): get this in a more principled way somehow? self._embedding_dim = question_embedder.get_output_dim() self._type_params = torch.nn.Linear(self._num_entity_types, self._embedding_dim) self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim) if num_linking_features > 0: self._linking_params = torch.nn.Linear(num_linking_features, 1) else: self._linking_params = None if self._use_neighbor_similarity_for_linking: self._question_entity_params = torch.nn.Linear(1, 1) self._question_neighbor_params = torch.nn.Linear(1, 1) else: self._question_entity_params = None self._question_neighbor_params = None def _get_initial_state_and_scores(self, question, table, world, actions, example_lisp_string=None, add_world_to_initial_state=False, checklist_states=None): u""" Does initial preparation and creates an intiial state for both the semantic parsers. Note that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to pass it. 
""" table_text = table[u'text'] # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float() batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, embedding_dim) encoded_table = self._entity_encoder(embedded_table, table_mask) # (batch_size, num_entities, num_neighbors) neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table) # Neighbor_indices is padded with -1 since 0 is a potential neighbor index. # Thus, the absolute value needs to be taken in the index_select, and 1 needs to # be added for the mask since that method expects 0 for padding. # (batch_size, num_entities, num_neighbors, embedding_dim) embedded_neighbors = util.batched_index_select( encoded_table, torch.abs(neighbor_indices)) neighbor_mask = util.get_text_field_mask( { u'ignored': neighbor_indices + 1 }, num_wrapping_dims=1).float() # Encoder initialized to easily obtain a masked average. neighbor_encoder = TimeDistributed( BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True)) # (batch_size, num_entities, embedding_dim) embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask) # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types) # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector( world, num_entities, encoded_table) entity_type_embeddings = self._type_params(entity_types.float()) projected_neighbor_embeddings = self._neighbor_params( embedded_neighbors.float()) # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings) # Compute entity and question word similarity. We tried using cosine distance here, but # because this similarity is the main mechanism that the model can use to push apart logit # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger # output range than [-1, 1]. question_entity_similarity = torch.bmm( embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2)) question_entity_similarity = question_entity_similarity.view( batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max( question_entity_similarity, 2) # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table[u'linking'] linking_scores = question_entity_similarity_max_score if self._use_neighbor_similarity_for_linking: # The linking score is computed as a linear projection of two terms. The first is the # maximum similarity score over the entity's words and the question token. The second # is the maximum similarity over the words in the entity's neighbors and the question # token. # # The second term, projected_question_neighbor_similarity, is useful when a column # needs to be selected. 
        # For example, the question token might have no similarity with the column name, but is
        # similar to the cells in the column.
        #
        # Note that projected_question_neighbor_similarity is intended to capture the same
        # information as the related_column feature.
        #
        # Also note that this block needs to be _before_ the `linking_params` block, because
        # we're overwriting `linking_scores`, not adding to it.

        # (batch_size, num_entities, num_neighbors, num_question_tokens)
        question_neighbor_similarity = util.batched_index_select(
            question_entity_similarity_max_score, torch.abs(neighbor_indices))
        # (batch_size, num_entities, num_question_tokens)
        question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
        projected_question_entity_similarity = self._question_entity_params(
            question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
        projected_question_neighbor_similarity = self._question_neighbor_params(
            question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
        linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

        initial_score = embedded_question.data.new_zeros(batch_size)

        action_embeddings, output_action_embeddings, action_biases, action_indices = \
            self._embed_actions(actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(
            linking_scores, world, actions)

        # To make grouping states together in the decoder easier, we convert the batch dimension
        # in all of our tensors into an outer list. For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a
        # list of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.
        # Then we won't have to do any index selects, or anything, we'll just do some
        # `torch.cat()`s.
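        # E.g. (hypothetical shapes): encoder_outputs of shape (2, 10, 200) becomes a list of
        # two tensors, each of shape (10, 200).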
initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append( RnnState(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list)) initial_grammar_state = [ self._create_grammar_state(world[i], actions[i]) for i in range(batch_size) ] initial_state_world = world if add_world_to_initial_state else None initial_state = WikiTablesDecoderState( batch_indices=range(batch_size), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, action_embeddings=action_embeddings, output_action_embeddings=output_action_embeddings, action_biases=action_biases, action_indices=action_indices, possible_actions=actions, flattened_linking_scores=flattened_linking_scores, actions_to_entities=actions_to_entities, entity_types=entity_type_dict, world=initial_state_world, example_lisp_string=example_lisp_string, checklist_state=checklist_states, debug_info=None) return { u"initial_state": initial_state, u"linking_scores": linking_scores, u"feature_scores": feature_scores, u"similarity_scores": question_entity_similarity_max_score } @staticmethod def _get_neighbor_indices(worlds, num_entities, tensor): u""" This method returns the indices of each entity's neighbors. A tensor is accepted as a parameter for copying purposes. Parameters ---------- worlds : ``List[WikiTablesWorld]`` num_entities : ``int`` tensor : ``torch.Tensor`` Used for copying the constructed list onto the right device. Returns ------- A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded with -1 instead of 0, since 0 is a valid neighbor index. """ num_neighbors = 0 for world in worlds: for entity in world.table_graph.entities: if len(world.table_graph.neighbors[entity]) > num_neighbors: num_neighbors = len(world.table_graph.neighbors[entity]) batch_neighbors = [] for world in worlds: # Each batch instance has its own world, which has a corresponding table. entities = world.table_graph.entities entity2index = dict( (entity, i) for i, entity in enumerate(entities)) entity2neighbors = world.table_graph.neighbors neighbor_indexes = [] for entity in entities: entity_neighbors = [ entity2index[n] for n in entity2neighbors[entity] ] # Pad with -1 instead of 0, since 0 represents a neighbor index. padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1) neighbor_indexes.append(padded) neighbor_indexes = pad_sequence_to_length( neighbor_indexes, num_entities, lambda: [-1] * num_neighbors) batch_neighbors.append(neighbor_indexes) return tensor.new_tensor(batch_neighbors, dtype=torch.long) @staticmethod def _get_type_vector(worlds, num_entities, tensor): u""" Produces the one hot encoding for each entity's type. In addition, a map from a flattened entity index to type is returned to combine entity type operations into one method. Parameters ---------- worlds : ``List[WikiTablesWorld]`` num_entities : ``int`` tensor : ``torch.Tensor`` Used for copying the constructed list onto the right device. Returns ------- A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_types)``. entity_types : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. 
""" entity_types = {} batch_types = [] for batch_index, world in enumerate(worlds): types = [] for entity_index, entity in enumerate(world.table_graph.entities): one_hot_vectors = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] # We need numbers to be first, then cells, then parts, then row, because our # entities are going to be sorted. We do a split by type and then a merge later, # and it relies on this sorting. if entity.startswith(u'fb:cell'): entity_type = 1 elif entity.startswith(u'fb:part'): entity_type = 2 elif entity.startswith(u'fb:row'): entity_type = 3 else: entity_type = 0 types.append(one_hot_vectors[entity_type]) # For easier lookups later, we're actually using a _flattened_ version # of (batch_index, entity_index) for the key, because this is how the # linking scores are stored. flattened_entity_index = batch_index * num_entities + entity_index entity_types[flattened_entity_index] = entity_type padded = pad_sequence_to_length(types, num_entities, lambda: [0, 0, 0, 0]) batch_types.append(padded) return tensor.new_tensor(batch_types), entity_types def _get_linking_probabilities(self, worlds, linking_scores, question_mask, entity_type_dict): u""" Produces the probability of an entity given a question word and type. The logic below separates the entities by type since the softmax normalization term sums over entities of a single type. Parameters ---------- worlds : ``List[WikiTablesWorld]`` linking_scores : ``torch.FloatTensor`` Has shape (batch_size, num_question_tokens, num_entities). question_mask: ``torch.LongTensor`` Has shape (batch_size, num_question_tokens). entity_type_dict : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. Returns ------- batch_probabilities : ``torch.FloatTensor`` Has shape ``(batch_size, num_question_tokens, num_entities)``. Contains all the probabilities for an entity given a question word. """ _, num_question_tokens, num_entities = linking_scores.size() batch_probabilities = [] for batch_index, world in enumerate(worlds): all_probabilities = [] num_entities_in_instance = 0 # NOTE: The way that we're doing this here relies on the fact that entities are # implicitly sorted by their types when we sort them by name, and that numbers come # before "fb:cell", and "fb:cell" comes before "fb:row". This is not a great # assumption, and could easily break later, but it should work for now. for type_index in range(self._num_entity_types): # This index of 0 is for the null entity for each type, representing the case where a # word doesn't link to any entity. entity_indices = [0] entities = world.table_graph.entities for entity_index, _ in enumerate(entities): if entity_type_dict[batch_index * num_entities + entity_index] == type_index: entity_indices.append(entity_index) if len(entity_indices) == 1: # No entities of this type; move along... continue # We're subtracting one here because of the null entity we added above. num_entities_in_instance += len(entity_indices) - 1 # We separate the scores by type, since normalization is done per type. There's an # extra "null" entity per type, also, so we have `num_entities_per_type + 1`. We're # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_, # so we get back something of shape (num_question_tokens,) for each index we're # selecting. All of the selected indices together then make a tensor of shape # (num_question_tokens, num_entities_per_type + 1). 
indices = linking_scores.new_tensor(entity_indices, dtype=torch.long) entity_scores = linking_scores[batch_index].index_select( 1, indices) # We used index 0 for the null entity, so this will actually have some values in it. # But we want the null entity's score to be 0, so we set that here. entity_scores[:, 0] = 0 # No need for a mask here, as this is done per batch instance, with no padding. type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1) all_probabilities.append(type_probabilities[:, 1:]) # We need to add padding here if we don't have the right number of entities. if num_entities_in_instance != num_entities: zeros = linking_scores.new_zeros( num_question_tokens, num_entities - num_entities_in_instance) all_probabilities.append(zeros) # (num_question_tokens, num_entities) probabilities = torch.cat(all_probabilities, dim=1) batch_probabilities.append(probabilities) batch_probabilities = torch.stack(batch_probabilities, dim=0) return batch_probabilities * question_mask.unsqueeze(-1).float() @staticmethod def _action_history_match(predicted, targets): # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something. # Check if target is big enough to cover prediction (including start/end symbols) if len(predicted) > targets.size(1): return 0 predicted_tensor = targets.new_tensor(predicted) targets_trimmed = targets[:, :len(predicted)] # Return 1 if the predicted sequence is anywhere in the list of targets. return torch.max( torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item() #overrides def get_metrics(self, reset=False): u""" We track three metrics here: 1. dpd_acc, which is the percentage of the time that our best output action sequence is in the set of action sequences provided by DPD. This is an easy-to-compute lower bound on denotation accuracy for the set of examples where we actually have DPD output. We only score dpd_acc on that subset. 2. denotation_acc, which is the percentage of examples where we get the correct denotation. This is the typical "accuracy" metric, and it is what you should usually report in an experimental result. You need to be careful, though, that you're computing this on the full data, and not just the subset that has DPD output (make sure you pass "keep_if_no_dpd=True" to the dataset reader, which we do for validation data, but not training data). 3. lf_percent, which is the percentage of time that decoding actually produces a finished logical form. We might not produce a valid logical form if the decoder gets into a repetitive loop, or we're trying to produce a super long logical form and run out of time steps, or something. """ return { u'dpd_acc': self._action_sequence_accuracy.get_metric(reset), u'denotation_acc': self._denotation_accuracy.get_metric(reset), u'lf_percent': self._has_logical_form.get_metric(reset), } @staticmethod def _create_grammar_state(world, possible_actions): valid_actions = world.get_valid_actions() action_mapping = {} for i, action in enumerate(possible_actions): action_string = action[0] action_mapping[action_string] = i translated_valid_actions = {} for key, action_strings in list(valid_actions.items()): translated_valid_actions[key] = [ action_mapping[action_string] for action_string in action_strings ] return GrammarState([START_SYMBOL], {}, translated_valid_actions, action_mapping, type_declaration.is_nonterminal) def _embed_actions(self, actions): u""" Given all of the possible actions for all batch instances, produce an embedding for them. 
There will be significant overlap in this list, as the production rules from the grammar are shared across all batch instances. Our returned tensor has an embedding for each `unique` action, so we also need to return a mapping from the original ``(batch_index, action_index)`` to our new ``global_action_index``, so that we can get the right action embedding during decoding. Returns ------- action_embeddings : ``torch.Tensor`` Has shape ``(num_unique_actions, action_embedding_dim)``. output_action_embeddings : ``torch.Tensor`` Has shape ``(num_unique_actions, action_embedding_dim)``. action_biases : ``torch.Tensor`` Has shape ``(num_unique_actions, 1)``. action_map : ``Dict[Tuple[int, int], int]`` Maps ``(batch_index, action_index)`` in the input action list to ``action_index`` in the ``action_embeddings`` tensor. All non-embeddable actions get mapped to `-1` here. """ # TODO(mattg): This whole action pipeline might be a whole lot more complicated than it # needs to be. We used to embed actions differently (using some crazy ideas about # embedding the LHS and RHS separately); we could probably get away with simplifying things # further now that we're just doing a simple embedding for global actions. But I'm leaving # it like this for now to have a minimal change to go from the LHS/RHS embedding to a # single action embedding. embedded_actions = self._action_embedder.weight output_embedded_actions = self._output_action_embedder.weight action_biases = self._action_biases.weight # Now we just need to make a map from `(batch_index, action_index)` to # `global_action_index`. global_action_ids has the list of all unique actions; here we're # going over all of the actions for each batch instance so we can map them to the global # action ids. action_vocab = self.vocab.get_token_to_index_vocabulary( self._rule_namespace) action_map = {} for batch_index, instance_actions in enumerate(actions): for action_index, action in enumerate(instance_actions): if not action[0]: # This rule is padding. continue global_action_id = action_vocab.get(action[0], -1) action_map[(batch_index, action_index)] = global_action_id return embedded_actions, output_embedded_actions, action_biases, action_map @staticmethod def _map_entity_productions(linking_scores, worlds, actions): u""" Constructs a map from ``(batch_index, action_index)`` to ``(batch_index * entity_index)``. That is, some actions correspond to terminal productions of entities from our table. We need to find those actions and map them to their corresponding entity indices, where the entity index is its position in the list of entities returned by the ``world``. This list is what defines the second dimension of the ``linking_scores`` tensor, so we can use this index to look up linking scores for each action in that tensor. For easier processing later, the mapping that we return is `flattened` - we really want to map ``(batch_index, action_index)`` to ``(batch_index, entity_index)``, but we are going to have to use the result of this mapping to do ``index_selects`` on the ``linking_scores`` tensor. You can't do ``index_select`` with tuples, so we flatten ``linking_scores`` to have shape ``(batch_size * num_entities, num_question_tokens)``, and return shifted indices into this flattened tensor. Parameters ---------- linking_scores : ``torch.Tensor`` A tensor representing linking scores between each table entity and each question token. Has shape ``(batch_size, num_entities, num_question_tokens)``. 
worlds : ``List[WikiTablesWorld]`` The ``World`` for each batch instance. The ``World`` contains a reference to the ``TableKnowledgeGraph`` that defines the set of entities in the linking. actions : ``List[List[ProductionRuleArray]]`` The list of possible actions for each batch instance. Our action indices are defined in terms of this list, so we'll find entity productions in this list and map them to entity indices from the entity list we get from the ``World``. Returns ------- flattened_linking_scores : ``torch.Tensor`` A flattened version of ``linking_scores``, with shape ``(batch_size * num_entities, num_question_tokens)``. actions_to_entities : ``Dict[Tuple[int, int], int]`` A mapping from ``(batch_index, action_index)`` to ``(batch_size * num_entities)``, representing which action indices correspond to which entity indices in the returned ``flattened_linking_scores`` tensor. """ batch_size, num_entities, num_question_tokens = linking_scores.size() entity_map = {} for batch_index, world in enumerate(worlds): for entity_index, entity in enumerate(world.table_graph.entities): entity_map[( batch_index, entity)] = batch_index * num_entities + entity_index actions_to_entities = {} for batch_index, action_list in enumerate(actions): for action_index, action in enumerate(action_list): if not action[0]: # This action is padding. continue _, production = action[0].split(u' -> ') entity_index = entity_map.get((batch_index, production), None) if entity_index is not None: actions_to_entities[(batch_index, action_index)] = entity_index flattened_linking_scores = linking_scores.view( batch_size * num_entities, num_question_tokens) return flattened_linking_scores, actions_to_entities #overrides def decode(self, output_dict): u""" This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a separate notion from the "decoder" in "encoder/decoder", where that decoder logic lives in ``WikiTablesDecoderStep``. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. """ action_mapping = output_dict[u'action_mapping'] best_actions = output_dict[u"best_action_sequence"] debug_infos = output_dict[u'debug_info'] batch_action_info = [] for batch_index, (predicted_actions, debug_info) in enumerate( izip(best_actions, debug_infos)): instance_action_info = [] for predicted_action, action_debug_info in izip( predicted_actions, debug_info): action_info = {} action_info[u'predicted_action'] = predicted_action considered_actions = action_debug_info[u'considered_actions'] probabilities = action_debug_info[u'probabilities'] actions = [] for action, probability in izip(considered_actions, probabilities): if action != -1: actions.append((action_mapping[(batch_index, action)], probability)) actions.sort() considered_actions, probabilities = izip(*actions) action_info[u'considered_actions'] = considered_actions action_info[u'action_probabilities'] = probabilities action_info[u'question_attention'] = action_debug_info.get( u'question_attention', []) instance_action_info.append(action_info) batch_action_info.append(instance_action_info) output_dict[u"predicted_actions"] = batch_action_info return output_dict
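# ---------------------------------------------------------------------------------------
# Illustrative sketch (not part of the parser above): a toy demo of the flattened-index
# trick used by `_map_entity_productions`. All shapes and values are invented, and the
# function name `flattened_linking_demo` is ours; only `torch` is assumed.
#
# `index_select` cannot take ``(batch_index, entity_index)`` tuples, so the linking scores
# are flattened and each pair is mapped to ``batch_index * num_entities + entity_index``.
# ---------------------------------------------------------------------------------------
import torch


def flattened_linking_demo():
    batch_size, num_entities, num_question_tokens = 2, 3, 4
    linking_scores = torch.arange(batch_size * num_entities * num_question_tokens,
                                  dtype=torch.float).view(batch_size, num_entities,
                                                          num_question_tokens)
    # Flatten the first two dimensions.
    flattened = linking_scores.view(batch_size * num_entities, num_question_tokens)
    batch_index, entity_index = 1, 2
    flat_index = batch_index * num_entities + entity_index
    row = flattened.index_select(0, torch.tensor([flat_index]))
    # Identical to indexing the original tensor with the tuple directly.
    assert torch.equal(row.squeeze(0), linking_scores[batch_index, entity_index])


if __name__ == '__main__':
    flattened_linking_demo()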
class WikiTablesErmSemanticParser(WikiTablesSemanticParser):
    """
    A ``WikiTablesErmSemanticParser`` is a :class:`WikiTablesSemanticParser` that learns to
    search for logical forms that yield the correct denotations.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions. Passed to super class.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings. Passed to super class.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question. Passed to super class.
    entity_encoder : ``Seq2VecEncoder``
        The encoder used for averaging the words of an entity. Passed to super class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query. Passed to the transition function.
    decoder_beam_size : ``int``
        Beam size to be used by the ExpectedRiskMinimization algorithm.
    decoder_num_finished_states : ``int``
        Number of finished states for which costs will be computed by the
        ExpectedRiskMinimization algorithm.
    max_decoding_steps : ``int``
        Maximum number of steps the decoder should take before giving up. Used both during
        training and evaluation. Passed to super class.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding. Passed to super class.
    normalize_beam_score_by_length : ``bool``, optional (default=False)
        Should we normalize the log-probabilities by length before renormalizing the beam? This
        was shown to work better for NMT by Edunov et al., but that may not be the case for
        semantic parsing.
    checklist_cost_weight : ``float``, optional (default=0.6)
        Mixture weight (0-1) for combining coverage cost and denotation cost. As this increases,
        we weigh the coverage cost higher, with a value of 1.0 meaning that we do not care about
        denotation accuracy.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the
        `neighbors` of an entity as a component of the linking scores. This is meant to capture
        the same kind of information as the ``related_column`` feature. Passed to super class.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders
        (pytorch LSTMs do not apply dropout to their last layer). Passed to super class.
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know
        how many there are. The default of 10 here matches the default in the
        ``KnowledgeGraphField``, which is to use all ten defined features. If this is 0, another
        term will be added to the linking score. This term contains the maximum similarity value
        from the entity's neighbors and the question. Passed to super class.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules. The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this. Passed to
        super class.
    mml_model_file : ``str``, optional (default=None)
        If you want to initialize this model using weights from another model trained using MML,
        pass the path to the ``model.tar.gz`` file of that model here.
""" def __init__(self, vocab: Vocabulary, question_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, entity_encoder: Seq2VecEncoder, attention: Attention, decoder_beam_size: int, decoder_num_finished_states: int, max_decoding_steps: int, mixture_feedforward: FeedForward = None, add_action_bias: bool = True, normalize_beam_score_by_length: bool = False, checklist_cost_weight: float = 0.6, use_neighbor_similarity_for_linking: bool = False, dropout: float = 0.0, num_linking_features: int = 10, rule_namespace: str = 'rule_labels', mml_model_file: str = None) -> None: use_similarity = use_neighbor_similarity_for_linking super().__init__(vocab=vocab, question_embedder=question_embedder, action_embedding_dim=action_embedding_dim, encoder=encoder, entity_encoder=entity_encoder, max_decoding_steps=max_decoding_steps, add_action_bias=add_action_bias, use_neighbor_similarity_for_linking=use_similarity, dropout=dropout, num_linking_features=num_linking_features, rule_namespace=rule_namespace) # Not sure why mypy needs a type annotation for this! self._decoder_trainer: ExpectedRiskMinimization = \ ExpectedRiskMinimization(beam_size=decoder_beam_size, normalize_by_length=normalize_beam_score_by_length, max_decoding_steps=self._max_decoding_steps, max_num_finished_states=decoder_num_finished_states) self._decoder_step = LinkingCoverageTransitionFunction( encoder_output_dim=self._encoder.get_output_dim(), action_embedding_dim=action_embedding_dim, input_attention=attention, add_action_bias=self._add_action_bias, mixture_feedforward=mixture_feedforward, dropout=dropout) self._checklist_cost_weight = checklist_cost_weight self._agenda_coverage = Average() # We don't need a separate beam search since the trainer does that already. But we're defining one just to # be able to use interactive beam search (a functionality that's only implemented in the ``BeamSearch`` # class) in the demo. We'll use this only at test time. self._beam_search: BeamSearch = BeamSearch(beam_size=decoder_beam_size) # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've # copied a trained ERM model from a different machine and the original MML model that was # used to initialize it does not exist on the current machine. This may not be the best # solution for the problem. if mml_model_file is not None: if os.path.isfile(mml_model_file): archive = load_archive(mml_model_file) self._initialize_weights_from_archive(archive) else: # A model file is passed, but it does not exist. This is expected to happen when # you're using a trained ERM model to decode. But it may also happen if the path to # the file is really just incorrect. So throwing a warning. logger.warning( "MML model file for initializing weights is passed, but does not exist." 
" This is fine if you're just decoding.") def _initialize_weights_from_archive(self, archive: Archive) -> None: logger.info("Initializing weights from MML model.") model_parameters = dict(self.named_parameters()) archived_parameters = dict(archive.model.named_parameters()) question_embedder_weight = "_question_embedder.token_embedder_tokens.weight" if question_embedder_weight not in archived_parameters or \ question_embedder_weight not in model_parameters: raise RuntimeError( "When initializing model weights from an MML model, we need " "the question embedder to be a TokenEmbedder using namespace called " "tokens.") for name, weights in archived_parameters.items(): if name in model_parameters: if name == question_embedder_weight: # The shapes of embedding weights will most likely differ between the two models # because the vocabularies will most likely be different. We will get a mapping # of indices from this model's token indices to the archived model's and copy # the tensor accordingly. vocab_index_mapping = self._get_vocab_index_mapping( archive.model.vocab) archived_embedding_weights = weights.data new_weights = model_parameters[name].data.clone() for index, archived_index in vocab_index_mapping: new_weights[index] = archived_embedding_weights[ archived_index] logger.info("Copied embeddings of %d out of %d tokens", len(vocab_index_mapping), new_weights.size()[0]) else: new_weights = weights.data logger.info("Copying parameter %s", name) model_parameters[name].data.copy_(new_weights) def _get_vocab_index_mapping( self, archived_vocab: Vocabulary) -> List[Tuple[int, int]]: vocab_index_mapping: List[Tuple[int, int]] = [] for index in range(self.vocab.get_vocab_size(namespace='tokens')): token = self.vocab.get_token_from_index(index=index, namespace='tokens') archived_token_index = archived_vocab.get_token_index( token, namespace='tokens') # Checking if we got the UNK token index, because we don't want all new token # representations initialized to UNK token's representation. We do that by checking if # the two tokens are the same. They will not be if the token at the archived index is # UNK. if archived_vocab.get_token_from_index( archived_token_index, namespace="tokens") == token: vocab_index_mapping.append((index, archived_token_index)) return vocab_index_mapping @overrides def forward( self, # type: ignore question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[WikiTablesLanguage], actions: List[List[ProductionRule]], agenda: torch.LongTensor, target_values: List[List[str]] = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] The output of ``TextField.as_array()`` applied on the question ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. table : ``Dict[str, torch.LongTensor]`` The output of ``KnowledgeGraphField.as_array()`` applied on the table ``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to get embeddings for each entity. world : ``List[WikiTablesLanguage]`` We use a ``MetadataField`` to get the ``WikiTablesLanguage`` object for each input instance. 
            Because of how ``MetadataField`` works, this gets passed to us as a
            ``List[WikiTablesLanguage]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``world`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and
            use the embeddings to determine which action to take at each timestep in the
            decoder.
        agenda : ``torch.LongTensor``
            Agenda vectors that the checklist vectors will be compared against to compute the
            checklist cost.
        target_values : ``List[List[str]]``, optional (default = None)
            For each instance, a list of target values taken from the example lisp string. We
            pass this list to the evaluator along with logical forms to compute denotation
            accuracy.
        metadata : ``List[Dict[str, Any]]``, optional (default = None)
            Metadata containing the original tokenized question within a 'question_tokens'
            field.
        """
        batch_size = list(question.values())[0].size(0)
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = [agenda[i] for i in range(batch_size)]
        checklist_states = []
        all_terminal_productions = [set(instance_world.terminal_productions.values())
                                    for instance_world in world]
        max_num_terminals = max([len(terminals) for terminals in all_terminal_productions])
        for instance_actions, instance_agenda, terminal_productions in zip(actions,
                                                                           agenda_list,
                                                                           all_terminal_productions):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions,
                                                      terminal_productions,
                                                      max_num_terminals)
            checklist_target, terminal_actions, checklist_mask = checklist_info
            initial_checklist = checklist_target.new_zeros(checklist_target.size())
            checklist_states.append(ChecklistStatelet(terminal_actions=terminal_actions,
                                                      checklist_target=checklist_target,
                                                      checklist_mask=checklist_mask,
                                                      checklist=initial_checklist))
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                           table,
                                                                           world,
                                                                           actions,
                                                                           outputs)
        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = CoverageState(batch_indices=list(range(batch_size)),  # type: ignore
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=rnn_state,
                                      grammar_state=grammar_state,
                                      checklist_state=checklist_states,
                                      possible_actions=actions,
                                      extras=target_values,
                                      debug_info=None)
        if target_values is not None:
            logger.debug('Target values: %s', target_values)
            trainer_outputs = self._decoder_trainer.decode(initial_state,  # type: ignore
                                                           self._decoder_step,
                                                           partial(self._get_state_cost, world))
            outputs.update(trainer_outputs)
        else:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            batch_size = len(actions)
            agenda_indices = [actions_[:, 0].cpu().data for actions_ in agenda]
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            best_final_states = self._beam_search.search(self._max_decoding_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            for i in range(batch_size):
                in_agenda_ratio = 0.0
                # Decoding may not have terminated with any completed logical forms, if
                # `num_steps` isn't long enough (or if the model is not trained enough and gets
                # into an infinite action loop).
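                # As assumed here, `best_final_states` maps each batch index to its finished
                # states sorted by score, so `[i][0]` below takes the highest scoring one.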
                if i in best_final_states:
                    action_sequence = best_final_states[i][0].action_history[0]
                    action_strings = [action_mapping[(i, action_index)]
                                      for action_index in action_sequence]
                    instance_possible_actions = actions[i]
                    agenda_actions = []
                    for rule_id in agenda_indices[i]:
                        rule_id = int(rule_id)
                        if rule_id == -1:
                            continue
                        action_string = instance_possible_actions[rule_id][0]
                        agenda_actions.append(action_string)
                    actions_in_agenda = [action in action_strings for action in agenda_actions]
                    if actions_in_agenda:
                        # Note: This means that when there are no actions on the agenda, agenda
                        # coverage will be 0, not 1.
                        in_agenda_ratio = sum(actions_in_agenda) / len(actions_in_agenda)
                self._agenda_coverage(in_agenda_ratio)

            self._compute_validation_outputs(actions,
                                             best_final_states,
                                             world,
                                             target_values,
                                             metadata,
                                             outputs)
        return outputs

    @staticmethod
    def _get_checklist_info(agenda: torch.LongTensor,
                            all_actions: List[ProductionRule],
                            terminal_productions: Set[str],
                            max_num_terminals: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Takes an agenda, a list of all actions, a set of terminal productions in the
        corresponding world, and a length to pad the checklist vectors to, and returns a target
        checklist against which the checklist at each state will be compared to compute a loss,
        indices of ``terminal_actions``, and a ``checklist_mask`` that indicates which of the
        terminal actions are relevant for checklist loss computation.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRule]``
            All actions for one instance.
        ``terminal_productions`` : ``Set[str]``
            String representations of terminal productions in the corresponding world.
        ``max_num_terminals`` : ``int``
            Length to which the checklist vectors will be padded. This is the maximum number of
            terminal productions among all the worlds in the batch.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = {int(x) for x in agenda.squeeze(0).detach().cpu().numpy()}
        # We want to return checklist target and terminal actions that are column vectors to
        # make computing softmax over the difference between checklist and target easier.
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRule, a tuple where the first item is the production
            # rule string.
            if action[0] in terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        while len(target_checklist_list) < max_num_terminals:
            target_checklist_list.append([0])
            terminal_indices.append([-1])
        # (max_num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (max_num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
        checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask

    def _get_state_cost(self, worlds: List[WikiTablesLanguage], state: CoverageState) -> torch.Tensor:
        if not state.is_finished():
            raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
        world = worlds[state.batch_indices[0]]

        # Our checklist cost is a sum of squared error from where we want to be, making sure we
        # take into account the mask. We clamp the lower limit of the balance at 0 to avoid
        # penalizing agenda actions produced multiple times.
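        # Worked example (hypothetical numbers): a target of [1, 1, 0] with counts [2, 0, 1]
        # gives a balance of [-1, 1, -1], clamped to [0, 1, 0], so the squared cost below is 1 -
        # only the still-missing agenda action is penalized.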
    checklist_balance = torch.clamp(state.checklist_state[0].get_balance(), min=0.0)
    checklist_cost = torch.sum(checklist_balance ** 2)
    # This is the number of items on the agenda that we want to see in the decoded sequence.
    # We use this as the denotation cost if the path is incorrect.
    denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
    checklist_cost = self._checklist_cost_weight * checklist_cost
    action_history = state.action_history[0]
    batch_index = state.batch_indices[0]
    action_strings = [state.possible_actions[batch_index][i][0] for i in action_history]
    target_values = state.extras[batch_index]
    # Silence the executor's logging while evaluating candidate action sequences.
    executor_logger = \
        logging.getLogger('allennlp.semparse.domain_languages.wikitables_language')
    executor_logger.setLevel(logging.ERROR)
    evaluation = world.evaluate_action_sequence(action_strings, target_values)
    if evaluation:
        cost = checklist_cost
    else:
        cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
    return cost

@overrides
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
    """
    The base class returns a dict with dpd accuracy, denotation accuracy, and logical form
    percentage metrics. We add the agenda coverage metric here.
    """
    metrics = super().get_metrics(reset)
    metrics["agenda_coverage"] = self._agenda_coverage.get_metric(reset)
    return metrics
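
The following standalone snippet is a minimal sketch, not part of the model, illustrating how ``_get_checklist_info`` above pads and masks its checklist tensors; the toy ``actions`` list, agenda, and terminal production set are all hypothetical, and plain strings stand in for ``ProductionRule`` tuples.

import torch

def toy_checklist_info(agenda, all_actions, terminal_productions, max_num_terminals):
    # Mirrors _get_checklist_info, with plain strings standing in for ProductionRules.
    terminal_indices = []
    target_checklist_list = []
    agenda_indices_set = {int(x) for x in agenda.view(-1)}
    for index, action in enumerate(all_actions):
        if action in terminal_productions:
            terminal_indices.append([index])
            # Terminals on the agenda get a target count of 1; other terminals get 0.
            target_checklist_list.append([1 if index in agenda_indices_set else 0])
    # Pad to the batch-wide maximum number of terminals with -1 / 0 entries.
    while len(target_checklist_list) < max_num_terminals:
        target_checklist_list.append([0])
        terminal_indices.append([-1])
    terminal_actions = torch.tensor(terminal_indices)
    target_checklist = torch.tensor(target_checklist_list, dtype=torch.float)
    # Only terminals that appear on the agenda contribute to the checklist loss.
    checklist_mask = (target_checklist != 0).float()
    return target_checklist, terminal_actions, checklist_mask

# Actions 1 and 3 are terminals; only action 3 is on the agenda.
actions = ["s -> [t, t]", "t -> foo", "s -> t", "t -> bar"]
agenda = torch.tensor([[3]])
target, terminals, mask = toy_checklist_info(agenda, actions, {"t -> foo", "t -> bar"}, 3)
# target: [[0.], [1.], [0.]]   terminals: [[1], [3], [-1]]   mask: [[0.], [1.], [0.]]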
class Seq2SeqClaimRank(Model):
    """
    A ``Seq2SeqClaimRank`` model. This model is intended to be trained with a multi-instance
    learning objective that simultaneously tries to:

    - Decode the given post modifier (e.g. the ``target`` sequence).
    - Ensure that the model is attending to the proper claims during decoding (which are
      identified by the ``labels`` variable).

    The basic architecture is a seq2seq model with attention where the input sequence is the
    source sentence (without post-modifier), and the output sequence is the post-modifier. The
    main difference is that instead of performing attention over the input sequence, attention
    is performed over a collection of claims.

    Parameters
    ----------
    text_field_embedder : ``TextFieldEmbedder``
        Embeds words in the source sentence / claims.
    sentence_encoder : ``Seq2VecEncoder``
        Encodes the entire source sentence into a single vector.
    claim_encoder : ``Seq2SeqEncoder``
        Encodes each claim at the word level.
    attention : ``Attention``
        Type of attention mechanism used. WARNING: Do not normalize attention scores, and make
        sure to use a sigmoid activation. Otherwise the claim ranking loss will not work
        properly!
    max_steps : ``int``
        Maximum number of decoding steps. Default: 100 (same as ONMT).
    beam_size : ``int``
        Beam size used during evaluation. Default: 5 (same as ONMT).
    beta : ``float``
        Weight of the attention loss term.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sentence_encoder: Seq2VecEncoder,
                 claim_encoder: Seq2SeqEncoder,
                 attention: Attention,
                 max_steps: int = 100,
                 beam_size: int = 5,
                 beta: float = 1.0) -> None:
        super(Seq2SeqClaimRank, self).__init__(vocab)
        self.text_field_embedder = text_field_embedder
        self.sentence_encoder = sentence_encoder
        self.claim_encoder = TimeDistributed(claim_encoder)  # Handles additional sequence dim
        self.claim_encoder_dim = claim_encoder.get_output_dim()
        self.attention = attention
        self.decoder_embedding_dim = text_field_embedder.get_output_dim()
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.beta = beta
        # Since we are using the sentence encoding as the initial hidden state to the decoder,
        # the decoder hidden dim must match the sentence encoder hidden dim.
        self.decoder_output_dim = sentence_encoder.get_output_dim()
        self.decoder_0_cell = torch.nn.LSTMCell(self.decoder_embedding_dim + self.claim_encoder_dim,
                                                self.decoder_output_dim)
        self.decoder_1_cell = torch.nn.LSTMCell(self.decoder_output_dim, self.decoder_output_dim)
        # When projecting out we will use attention to combine claim embeddings into a single
        # context embedding; this will be concatenated with the decoder cell output before being
        # fed to the projection layer.
Hence the expected input size is: # decoder output dim + claim encoder output dim projection_input_dim = self.decoder_output_dim + self.claim_encoder_dim self.output_projection_layer = torch.nn.Linear(projection_input_dim, vocab.get_vocab_size()) self._start_index = self.vocab.get_token_index('<s>') self._end_index = self.vocab.get_token_index('</s>') self.beam_search = BeamSearch(self._end_index, max_steps=max_steps, beam_size=beam_size) pad_index = vocab.get_token_index(vocab._padding_token) self.bleu = BLEU(exclude_indices={pad_index, self._start_index, self._end_index}) self.avg_reconstruction_loss = Average() self.avg_claim_scoring_loss = Average() def take_step(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take a decoding step. This is called by the beam search class. Parameters ---------- last_predictions : ``torch.Tensor`` A tensor of shape ``(group_size,)``, which gives the indices of the predictions during the last time step. state : ``Dict[str, torch.Tensor]`` A dictionary of tensors that contain the current state information needed to predict the next step, which includes the encoder outputs, the source mask, and the decoder hidden state and context. Each of these tensors has shape ``(group_size, *)``, where ``*`` can be any other number of dimensions. Returns ------- Tuple[torch.Tensor, Dict[str, torch.Tensor]] A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` is a tensor of shape ``(group_size, num_classes)`` containing the predicted log probability of each class for the next step, for each item in the group, while ``updated_state`` is a dictionary of tensors containing the encoder outputs, source mask, and updated decoder hidden state and context. Notes ----- We treat the inputs as a batch, even though ``group_size`` is not necessarily equal to ``batch_size``, since the group may contain multiple states for each source sentence in the batch. """ output_projections, _, state = self._prepare_output_projections(last_predictions, state) class_log_probabilities = F.log_softmax(output_projections, dim=-1) return class_log_probabilities, state @overrides def forward(self, inputs: Dict[str, torch.LongTensor], claims: Dict[str, torch.LongTensor], targets: Dict[str, torch.LongTensor] = None, labels: torch.Tensor = None) -> torch.Tensor: """Forward pass of the model + decoder logic. Parameters ---------- inputs : ``Dict[str, torch.LongTensor]`` Output of `TextField.as_array()` from the `input` field. claims : ``Dict[str, torch.LongTensor]`` Output of `ListField.as_array()` from the `claims` field. targets : ``Dict[str, torch.LongTensor]`` Output of `TextField.as_array()` from the `target` field. Only expected during training and validation. labels : ``torch.Tensor`` Output of `LabelField.as_array()` from the `labels` field, indicating which claims were used. Only expected during training and validation. Returns ------- Dict[str, torch.Tensor] Dictionary containing loss tensor and decoder outputs. """ # Obtain an encoding for each input sentence (e.g. the contexts) input_mask = util.get_text_field_mask(inputs) input_word_embeddings = self.text_field_embedder(inputs) input_encodings = self.sentence_encoder(input_word_embeddings, input_mask) # Next we encode claims. Note that here we have two additional sequence dimensions (since # there are multiple claims per instance, and we want to apply attention at the word # level). 
        # To deal with this we need to set `num_wrapping_dims=1` for the embedder, and make
        # the claim encoder TimeDistributed.
        claim_mask = util.get_text_field_mask(claims, num_wrapping_dims=1)
        claim_word_embeddings = self.text_field_embedder(claims, num_wrapping_dims=1)
        claim_encodings = self.claim_encoder(claim_word_embeddings, claim_mask)

        # Package the encoder outputs into a state dictionary.
        state = {
            'input_mask': input_mask,
            'input_encodings': input_encodings,
            'claim_mask': claim_mask,
            'claim_encodings': claim_encodings
        }

        # If ``target`` (the post-modifier) and ``labels`` (indicator of which claims are used)
        # are provided then we use them to compute loss.
        if (targets is not None) and (labels is not None):
            state = self._init_decoder_state(state)
            output_dict = self._forward_loop(state, targets, labels)
        else:
            output_dict = {}

        # If the model is not training, then we perform beam search for decoding to obtain
        # higher quality outputs.
        if not self.training:
            # Perform beam search
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            # Compute BLEU. Guard against ``targets`` being absent, which is the case at
            # prediction time.
            if targets is not None:
                top_k_predictions = output_dict['predictions']
                best_predictions = top_k_predictions[:, 0, :]
                self.bleu(best_predictions, targets['tokens'])

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.
        """
        predicted_indices = output_dict['predictions']
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch,
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices up to the first end_symbol.
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self.vocab.get_token_from_index(x) for x in indices]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Adds fields to the state required to initialize the decoder."""
        batch_size = state['input_mask'].shape[0]
        # The first decoder layer is initialized with zeros (trying to approximate the layer
        # structure in OpenNMT).
        state['decoder_0_h'] = state['input_encodings'].new_zeros(batch_size, self.decoder_output_dim)
        state['decoder_0_c'] = state['input_encodings'].new_zeros(batch_size, self.decoder_output_dim)
        # Initialize LSTM hidden state (e.g. h_0) with output of the sentence encoder.
        state['decoder_1_h'] = state['input_encodings']
        # Initialize LSTM context state (e.g. c_0) with zeros.
        state['decoder_1_c'] = state['input_encodings'].new_zeros(batch_size, self.decoder_output_dim)
        # Initialize previous context.
        state['prev_context'] = state['input_encodings'].new_zeros(batch_size, self.claim_encoder_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      targets: Dict[str, torch.Tensor],
                      labels: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute loss using greedy decoding."""
        batch_size = state['input_mask'].shape[0]
        target_tokens = targets['tokens']
        num_decoding_steps = target_tokens.shape[1] - 1

        # Greedy decoding phase
        output_logit_list = []
        attention_logit_list = []
        for timestep in range(num_decoding_steps):
            # Feed the target sequence as input (i.e., teacher forcing).
            decoder_input = target_tokens[:, timestep]
            output_logits, attention_logits, state = self._prepare_output_projections(decoder_input, state)
            # Store output and attention logits
            output_logit_list.append(output_logits.unsqueeze(1))
            attention_logit_list.append(attention_logits.unsqueeze(1))

        # Compute reconstruction loss
        output_logit_tensor = torch.cat(output_logit_list, dim=1)
        relevant_target_tokens = target_tokens[:, 1:].contiguous()
        target_mask = util.get_text_field_mask(targets)[:, 1:].contiguous()
        reconstruction_loss = util.sequence_cross_entropy_with_logits(output_logit_tensor,
                                                                      relevant_target_tokens,
                                                                      target_mask)

        # Compute claim scoring loss. A loss is computed between **each** attention vector and
        # the true label. In order for that to work we need to:
        #   a. Tile the source labels (so that they are copied for each word)
        #   b. Mask out padding tokens - this requires taking the outer product of the target
        #      mask and the claim mask
        attention_logit_tensor = torch.cat(attention_logit_list, dim=1)
        claim_level_mask = (state['claim_mask'].sum(-1) > 0).long()
        attention_mask = target_mask.unsqueeze(-1) * claim_level_mask.unsqueeze(1)
        labels = labels.unsqueeze(1).repeat(1, num_decoding_steps, 1).float()
        claim_scoring_loss = F.binary_cross_entropy_with_logits(attention_logit_tensor, labels,
                                                                reduction='none')
        claim_scoring_loss *= attention_mask.float()  # Apply mask

        # We want to apply 'batch' reduction (as is done in `sequence_cross_entropy_with_logits`),
        # which entails averaging over each dimension.
        denom = attention_mask
        for i in range(3):
            denom = denom.sum(-1)
            claim_scoring_loss = claim_scoring_loss.sum(-1) / (denom.float() + 1e-13)
            denom = (denom > 0)

        total_loss = reconstruction_loss + self.beta * claim_scoring_loss

        # Update metrics
        self.avg_reconstruction_loss(reconstruction_loss)
        self.avg_claim_scoring_loss(claim_scoring_loss)

        output_dict = {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "claim_scoring_loss": claim_scoring_loss,
            "attention_logits": attention_logit_tensor
        }
        return output_dict

    def _prepare_output_projections(self,
                                    decoder_input: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        # Embed decoder input
        decoder_word_embeddings = self.text_field_embedder({'tokens': decoder_input})
        # Concat with previous context
        concat = torch.cat((decoder_word_embeddings, state['prev_context']), dim=-1)
        # Run forward pass of decoder RNN
        decoder_0_h, decoder_0_c = self.decoder_0_cell(concat, (state['decoder_0_h'], state['decoder_0_c']))
        decoder_1_h, decoder_1_c = self.decoder_1_cell(decoder_0_h, (state['decoder_1_h'], state['decoder_1_c']))
        state['decoder_0_h'] = decoder_0_h
        state['decoder_0_c'] = decoder_0_c
        state['decoder_1_h'] = decoder_1_h
        state['decoder_1_c'] = decoder_1_c
        # Compute attention and get context embedding. We get an attention score for each word
        # in each claim.
        # Then we sum up scores to get a claim-level score (so we can use overlap as
        # supervision).
        claim_encodings = state['claim_encodings']
        claim_mask = state['claim_mask']
        batch_size, n_claims, claim_length, dim = claim_encodings.shape
        flattened_claim_encodings = claim_encodings.view(batch_size, -1, dim)
        flattened_claim_mask = claim_mask.view(batch_size, -1)
        flattened_attention_logits = self.attention(decoder_1_h, flattened_claim_encodings,
                                                    flattened_claim_mask)
        attention_logits = flattened_attention_logits.view(batch_size, n_claims, claim_length)

        # Now get claim-level encodings by summing word-level attention.
        word_level_attention = util.masked_softmax(attention_logits, claim_mask)
        claim_encodings = util.weighted_sum(claim_encodings, word_level_attention)

        # If not training, get the max attention word to replace unks with.
        if not self.training:
            max_word = word_level_attention.argmax(dim=-1, keepdim=True)
            gathered = word_level_attention.gather(dim=-1, index=max_word)
            # Squeeze only the last dimension so that this also works for batch size 1.
            max_claim = gathered.squeeze(-1).argmax(dim=-1, keepdim=True)
            max_word = max_word.squeeze(-1).gather(dim=1, index=max_claim)
            select_idx = torch.cat((max_claim, max_word), dim=-1)
        else:
            select_idx = None

        # We compute our context directly from the claim word embeddings
        claim_mask = (claim_mask.sum(-1) > 0).float()
        attention_logits = attention_logits.sum(-1)
        attention_weights = torch.sigmoid(attention_logits) * claim_mask
        normalized_attention_weights = attention_weights / (attention_weights.sum(-1, True) + 1e-13)
        context_embedding = util.weighted_sum(claim_encodings, normalized_attention_weights)
        state['prev_context'] = context_embedding

        # Concatenate RNN output w/ context vector and feed through the final hidden layer
        projection_input = torch.cat((decoder_1_h, context_embedding), dim=-1)
        output_logits = self.output_projection_layer(projection_input)

        return output_logits, attention_logits, state

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make the forward pass during prediction using a beam search."""
        batch_size = state['input_mask'].size()[0]
        start_predictions = state['input_mask'].new_full((batch_size,),
                                                         fill_value=self._start_index)
        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self.beam_search.search(
            start_predictions, state, self.take_step)
        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {
            'recon': self.avg_reconstruction_loss.get_metric(reset=reset).data.item(),
            'claim': self.avg_claim_scoring_loss.get_metric(reset=reset).data.item()
        }
        # Only update BLEU score during validation and evaluation
        if not self.training:
            all_metrics.update(self.bleu.get_metric(reset=reset))
        return all_metrics
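
As a reference for the two-level attention in ``_prepare_output_projections``, here is a minimal self-contained sketch with random toy tensors (all shapes hypothetical; ``torch.einsum`` stands in for ``util.weighted_sum``): word-level logits are softmax-normalized within each claim to pool its word encodings, then summed into claim-level logits whose sigmoids act as independently supervisable claim weights.

import torch

batch_size, n_claims, claim_length, dim = 2, 3, 4, 8
claim_encodings = torch.randn(batch_size, n_claims, claim_length, dim)
claim_mask = torch.ones(batch_size, n_claims, claim_length)
attention_logits = torch.randn(batch_size, n_claims, claim_length)

# Word-level attention pools each claim's word encodings into one vector per claim.
word_attention = torch.softmax(attention_logits.masked_fill(claim_mask == 0, -1e32), dim=-1)
pooled_claims = torch.einsum('bcw,bcwd->bcd', word_attention, claim_encodings)

# Claim-level scores: sum the word logits, gate with a sigmoid (so the claim ranking
# loss can supervise each claim independently), and renormalize to get mixing weights.
claim_logits = attention_logits.sum(-1)                              # (batch, n_claims)
claim_weights = torch.sigmoid(claim_logits) * (claim_mask.sum(-1) > 0).float()
normalized = claim_weights / (claim_weights.sum(-1, keepdim=True) + 1e-13)
context = torch.einsum('bc,bcd->bd', normalized, pooled_claims)      # (batch, dim)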
class WikiTablesSemanticParser(Model):
    """
    A ``WikiTablesSemanticParser`` is a :class:`Model` which takes as input a table and a
    question, and produces a logical form that answers the question when executed over the
    table. The logical form is generated by a `type-constrained`, `transition-based` parser.
    This is an abstract class that defines most of the functionality related to the
    transition-based parser. It does not contain the implementation for actually training the
    parser. You may want to train it using a learning-to-search algorithm, in which case you
    will want to use ``WikiTablesErmSemanticParser``, or if you have a set of approximate
    logical forms that give the correct denotation, you will want to use
    ``WikiTablesMmlSemanticParser``.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    entity_encoder : ``Seq2VecEncoder``
        The encoder to use for averaging the words of an entity.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should
        take? This only applies at evaluation time, not during training.
    add_action_bias : ``bool``, optional (default=True)
        If ``True``, we will learn a bias weight for each action that gets used when predicting
        that action, in addition to its embedding.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the
        `neighbors` of an entity as a component of the linking scores. This is meant to capture
        the same kind of information as the ``related_column`` feature.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders
        (pytorch LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know
        how many there are. The default of 10 here matches the default in the
        ``KnowledgeGraphField``, which is to use all ten defined features. If this is 0,
        another term will be added to the linking score. This term contains the maximum
        similarity value from the entity's neighbors and the question.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules. The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
""" def __init__( self, vocab: Vocabulary, question_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, entity_encoder: Seq2VecEncoder, max_decoding_steps: int, add_action_bias: bool = True, use_neighbor_similarity_for_linking: bool = False, dropout: float = 0.0, num_linking_features: int = 10, rule_namespace: str = "rule_labels", ) -> None: super().__init__(vocab) self._question_embedder = question_embedder self._encoder = encoder self._entity_encoder = TimeDistributed(entity_encoder) self._max_decoding_steps = max_decoding_steps self._add_action_bias = add_action_bias self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._rule_namespace = rule_namespace self._denotation_accuracy = Average() self._action_sequence_accuracy = Average() self._has_logical_form = Average() self._action_padding_index = -1 # the padding value used by IndexField num_actions = vocab.get_vocab_size(self._rule_namespace) if self._add_action_bias: self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1) self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim) self._output_action_embedder = Embedding( num_embeddings=num_actions, embedding_dim=action_embedding_dim) # This is what we pass as input in the first step of decoding, when we don't have a # previous action, or a previous question attention. self._first_action_embedding = torch.nn.Parameter( torch.FloatTensor(action_embedding_dim)) self._first_attended_question = torch.nn.Parameter( torch.FloatTensor(encoder.get_output_dim())) torch.nn.init.normal_(self._first_action_embedding) torch.nn.init.normal_(self._first_attended_question) check_dimensions_match( entity_encoder.get_output_dim(), question_embedder.get_output_dim(), "entity word average embedding dim", "question embedding dim", ) self._num_entity_types = 5 # TODO(mattg): get this in a more principled way somehow? self._embedding_dim = question_embedder.get_output_dim() self._entity_type_encoder_embedding = Embedding( num_embeddings=self._num_entity_types, embedding_dim=self._embedding_dim) self._entity_type_decoder_embedding = Embedding( num_embeddings=self._num_entity_types, embedding_dim=action_embedding_dim) self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim) if num_linking_features > 0: self._linking_params = torch.nn.Linear(num_linking_features, 1) else: self._linking_params = None if self._use_neighbor_similarity_for_linking: self._question_entity_params = torch.nn.Linear(1, 1) self._question_neighbor_params = torch.nn.Linear(1, 1) else: self._question_entity_params = None self._question_neighbor_params = None def _get_initial_rnn_and_grammar_state( self, question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[WikiTablesLanguage], actions: List[List[ProductionRuleArray]], outputs: Dict[str, Any], ) -> Tuple[List[RnnStatelet], List[GrammarStatelet]]: """ Encodes the question and table, computes a linking between the two, and constructs an initial RnnStatelet and GrammarStatelet for each batch instance to pass to the decoder. We take ``outputs`` as a parameter here and `modify` it, adding things that we want to visualize in a demo. 
""" table_text = table["text"] # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question) # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1) batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, embedding_dim) encoded_table = self._entity_encoder(embedded_table, table_mask) # entity_types: tensor with shape (batch_size, num_entities), where each entry is the # entity's type id. # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector( world, num_entities, encoded_table) entity_type_embeddings = self._entity_type_encoder_embedding( entity_types) # (batch_size, num_entities, num_neighbors) or None neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table) if neighbor_indices is not None: # Neighbor_indices is padded with -1 since 0 is a potential neighbor index. # Thus, the absolute value needs to be taken in the index_select, and 1 needs to # be added for the mask since that method expects 0 for padding. # (batch_size, num_entities, num_neighbors, embedding_dim) embedded_neighbors = util.batched_index_select( encoded_table, torch.abs(neighbor_indices)) neighbor_mask = util.get_text_field_mask( { "ignored": { "ignored": neighbor_indices + 1 } }, num_wrapping_dims=1).float() # Encoder initialized to easily obtain a masked average. neighbor_encoder = TimeDistributed( BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True)) # (batch_size, num_entities, embedding_dim) embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask) projected_neighbor_embeddings = self._neighbor_params( embedded_neighbors.float()) # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings) else: # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.tanh(entity_type_embeddings) # Compute entity and question word similarity. We tried using cosine distance here, but # because this similarity is the main mechanism that the model can use to push apart logit # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger # output range than [-1, 1]. question_entity_similarity = torch.bmm( embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2), ) question_entity_similarity = question_entity_similarity.view( batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max( question_entity_similarity, 2) # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table["linking"] linking_scores = question_entity_similarity_max_score if self._use_neighbor_similarity_for_linking: # The linking score is computed as a linear projection of two terms. The first is the # maximum similarity score over the entity's words and the question token. 
The second # is the maximum similarity over the words in the entity's neighbors and the question # token. # # The second term, projected_question_neighbor_similarity, is useful when a column # needs to be selected. For example, the question token might have no similarity with # the column name, but is similar with the cells in the column. # # Note that projected_question_neighbor_similarity is intended to capture the same # information as the related_column feature. # # Also note that this block needs to be _before_ the `linking_params` block, because # we're overwriting `linking_scores`, not adding to it. # (batch_size, num_entities, num_neighbors, num_question_tokens) question_neighbor_similarity = util.batched_index_select( question_entity_similarity_max_score, torch.abs(neighbor_indices)) # (batch_size, num_entities, num_question_tokens) question_neighbor_similarity_max_score, _ = torch.max( question_neighbor_similarity, 2) projected_question_entity_similarity = self._question_entity_params( question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1) projected_question_neighbor_similarity = self._question_neighbor_params( question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze( -1) linking_scores = (projected_question_entity_similarity + projected_question_neighbor_similarity) feature_scores = None if self._linking_params is not None: feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities( world, linking_scores.transpose(1, 2), question_mask, entity_type_dict) # (batch_size, num_question_tokens, embedding_dim) link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities) encoder_input = torch.cat([link_embedding, embedded_question], 2) # (batch_size, question_length, encoder_output_dim) encoder_outputs = self._dropout( self._encoder(encoder_input, question_mask)) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states( encoder_outputs, question_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append( RnnStatelet( final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list, )) initial_grammar_state = [ self._create_grammar_state(world[i], actions[i], linking_scores[i], entity_types[i]) for i in range(batch_size) ] if not self.training: # We add a few things to the outputs that will be returned from `forward` at evaluation # time, for visualization in a demo. 
outputs["linking_scores"] = linking_scores if feature_scores is not None: outputs["feature_scores"] = feature_scores outputs["similarity_scores"] = question_entity_similarity_max_score return initial_rnn_state, initial_grammar_state @staticmethod def _get_neighbor_indices(worlds: List[WikiTablesLanguage], num_entities: int, tensor: torch.Tensor) -> torch.LongTensor: """ This method returns the indices of each entity's neighbors. A tensor is accepted as a parameter for copying purposes. Parameters ---------- worlds : ``List[WikiTablesLanguage]`` num_entities : ``int`` tensor : ``torch.Tensor`` Used for copying the constructed list onto the right device. Returns ------- A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded with -1 instead of 0, since 0 is a valid neighbor index. If all the entities in the batch have no neighbors, None will be returned. """ num_neighbors = 0 for world in worlds: for entity in world.table_graph.entities: if len(world.table_graph.neighbors[entity]) > num_neighbors: num_neighbors = len(world.table_graph.neighbors[entity]) batch_neighbors = [] no_entities_have_neighbors = True for world in worlds: # Each batch instance has its own world, which has a corresponding table. entities = world.table_graph.entities entity2index = {entity: i for i, entity in enumerate(entities)} entity2neighbors = world.table_graph.neighbors neighbor_indexes = [] for entity in entities: entity_neighbors = [ entity2index[n] for n in entity2neighbors[entity] ] if entity_neighbors: no_entities_have_neighbors = False # Pad with -1 instead of 0, since 0 represents a neighbor index. padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1) neighbor_indexes.append(padded) neighbor_indexes = pad_sequence_to_length( neighbor_indexes, num_entities, lambda: [-1] * num_neighbors) batch_neighbors.append(neighbor_indexes) # It is possible that none of the entities has any neighbors, since our definition of the # knowledge graph allows it when no entities or numbers were extracted from the question. if no_entities_have_neighbors: return None return tensor.new_tensor(batch_neighbors, dtype=torch.long) @staticmethod def _get_type_vector( worlds: List[WikiTablesLanguage], num_entities: int, tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]: """ Produces a tensor with shape ``(batch_size, num_entities)`` that encodes each entity's type. In addition, a map from a flattened entity index to type is returned to combine entity type operations into one method. Parameters ---------- worlds : ``List[WikiTablesLanguage]`` num_entities : ``int`` tensor : ``torch.Tensor`` Used for copying the constructed list onto the right device. Returns ------- A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``. entity_types : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. """ entity_types = {} batch_types = [] for batch_index, world in enumerate(worlds): types = [] for entity_index, entity in enumerate(world.table_graph.entities): # We need numbers to be first, then date columns, then number columns, strings, and # string columns, in that order, because our entities are going to be sorted. We do # a split by type and then a merge later, and it relies on this sorting. 
if entity.startswith("date_column:"): entity_type = 1 elif entity.startswith("number_column:"): entity_type = 2 elif entity.startswith("string:"): entity_type = 3 elif entity.startswith("string_column:"): entity_type = 4 else: entity_type = 0 types.append(entity_type) # For easier lookups later, we're actually using a _flattened_ version # of (batch_index, entity_index) for the key, because this is how the # linking scores are stored. flattened_entity_index = batch_index * num_entities + entity_index entity_types[flattened_entity_index] = entity_type padded = pad_sequence_to_length(types, num_entities, lambda: 0) batch_types.append(padded) return tensor.new_tensor(batch_types, dtype=torch.long), entity_types def _get_linking_probabilities( self, worlds: List[WikiTablesLanguage], linking_scores: torch.FloatTensor, question_mask: torch.LongTensor, entity_type_dict: Dict[int, int], ) -> torch.FloatTensor: """ Produces the probability of an entity given a question word and type. The logic below separates the entities by type since the softmax normalization term sums over entities of a single type. Parameters ---------- worlds : ``List[WikiTablesLanguage]`` linking_scores : ``torch.FloatTensor`` Has shape (batch_size, num_question_tokens, num_entities). question_mask: ``torch.LongTensor`` Has shape (batch_size, num_question_tokens). entity_type_dict : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. Returns ------- batch_probabilities : ``torch.FloatTensor`` Has shape ``(batch_size, num_question_tokens, num_entities)``. Contains all the probabilities for an entity given a question word. """ _, num_question_tokens, num_entities = linking_scores.size() batch_probabilities = [] for batch_index, world in enumerate(worlds): all_probabilities = [] num_entities_in_instance = 0 # NOTE: The way that we're doing this here relies on the fact that entities are # implicitly sorted by their types when we sort them by name, and that numbers come # before "date_column:", followed by "number_column:", "string:", and "string_column:". # This is not a great assumption, and could easily break later, but it should work for now. for type_index in range(self._num_entity_types): # This index of 0 is for the null entity for each type, representing the case where a # word doesn't link to any entity. entity_indices = [0] entities = world.table_graph.entities for entity_index, _ in enumerate(entities): if entity_type_dict[batch_index * num_entities + entity_index] == type_index: entity_indices.append(entity_index) if len(entity_indices) == 1: # No entities of this type; move along... continue # We're subtracting one here because of the null entity we added above. num_entities_in_instance += len(entity_indices) - 1 # We separate the scores by type, since normalization is done per type. There's an # extra "null" entity per type, also, so we have `num_entities_per_type + 1`. We're # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_, # so we get back something of shape (num_question_tokens,) for each index we're # selecting. All of the selected indices together then make a tensor of shape # (num_question_tokens, num_entities_per_type + 1). indices = linking_scores.new_tensor(entity_indices, dtype=torch.long) entity_scores = linking_scores[batch_index].index_select( 1, indices) # We used index 0 for the null entity, so this will actually have some values in it. # But we want the null entity's score to be 0, so we set that here. 
entity_scores[:, 0] = 0 # No need for a mask here, as this is done per batch instance, with no padding. type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1) all_probabilities.append(type_probabilities[:, 1:]) # We need to add padding here if we don't have the right number of entities. if num_entities_in_instance != num_entities: zeros = linking_scores.new_zeros( num_question_tokens, num_entities - num_entities_in_instance) all_probabilities.append(zeros) # (num_question_tokens, num_entities) probabilities = torch.cat(all_probabilities, dim=1) batch_probabilities.append(probabilities) batch_probabilities = torch.stack(batch_probabilities, dim=0) return batch_probabilities * question_mask.unsqueeze(-1).float() @staticmethod def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int: # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something. # Check if target is big enough to cover prediction (including start/end symbols) if len(predicted) > targets.size(1): return 0 predicted_tensor = targets.new_tensor(predicted) targets_trimmed = targets[:, :len(predicted)] # Return 1 if the predicted sequence is anywhere in the list of targets. return torch.max( torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item() @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: """ We track three metrics here: 1. lf_retrieval_acc, which is the percentage of the time that our best output action sequence is in the set of action sequences provided by offline search. This is an easy-to-compute lower bound on denotation accuracy for the set of examples where we actually have offline output. We only score lf_retrieval_acc on that subset. 2. denotation_acc, which is the percentage of examples where we get the correct denotation. This is the typical "accuracy" metric, and it is what you should usually report in an experimental result. You need to be careful, though, that you're computing this on the full data, and not just the subset that has DPD output (make sure you pass "keep_if_no_dpd=True" to the dataset reader, which we do for validation data, but not training data). 3. lf_percent, which is the percentage of time that decoding actually produces a finished logical form. We might not produce a valid logical form if the decoder gets into a repetitive loop, or we're trying to produce a super long logical form and run out of time steps, or something. """ return { "lf_retrieval_acc": self._action_sequence_accuracy.get_metric(reset), "denotation_acc": self._denotation_accuracy.get_metric(reset), "lf_percent": self._has_logical_form.get_metric(reset), } def _create_grammar_state( self, world: WikiTablesLanguage, possible_actions: List[ProductionRuleArray], linking_scores: torch.Tensor, entity_types: torch.Tensor, ) -> GrammarStatelet: """ This method creates the GrammarStatelet object that's used for decoding. Part of creating that is creating the `valid_actions` dictionary, which contains embedded representations of all of the valid actions. So, we create that here as well. The way we represent the valid expansions is a little complicated: we use a dictionary of `action types`, where the key is the action type (like "global", "linked", or whatever your model is expecting), and the value is a tuple representing all actions of that type. The tuple is (input tensor, output tensor, action id). The input tensor has the representation that is used when `selecting` actions, for all actions of this type. 
The output tensor has the representation that is used when feeding the action to the next step of the decoder (this could just be the same as the input tensor). The action ids are a list of indices into the main action list for each batch instance. The inputs to this method are for a `single instance in the batch`; none of the tensors we create here are batched. We grab the global action ids from the input ``ProductionRuleArrays``, and we use those to embed the valid actions for every non-terminal type. We use the input ``linking_scores`` for non-global actions. Parameters ---------- world : ``WikiTablesLanguage`` From the input to ``forward`` for a single batch instance. possible_actions : ``List[ProductionRuleArray]`` From the input to ``forward`` for a single batch instance. linking_scores : ``torch.Tensor`` Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch dimension). entity_types : ``torch.Tensor`` Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension). """ # TODO(mattg): Move the "valid_actions" construction to another method. action_map = {} for action_index, action in enumerate(possible_actions): action_string = action[0] action_map[action_string] = action_index entity_map = {} for entity_index, entity in enumerate(world.table_graph.entities): entity_map[entity] = entity_index valid_actions = world.get_nonterminal_productions() translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {} for key, action_strings in valid_actions.items(): translated_valid_actions[key] = {} # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid # productions of that non-terminal. We'll first split those productions by global vs. # linked action. action_indices = [ action_map[action_string] for action_string in action_strings ] production_rule_arrays = [(possible_actions[index], index) for index in action_indices] global_actions = [] linked_actions = [] for production_rule_array, action_index in production_rule_arrays: if production_rule_array[1]: global_actions.append( (production_rule_array[2], action_index)) else: linked_actions.append( (production_rule_array[0], action_index)) # Then we get the embedded representations of the global actions if any. if global_actions: global_action_tensors, global_action_ids = zip(*global_actions) global_action_tensor = torch.cat(global_action_tensors, dim=0) global_input_embeddings = self._action_embedder( global_action_tensor) if self._add_action_bias: global_action_biases = self._action_biases( global_action_tensor) global_input_embeddings = torch.cat( [global_input_embeddings, global_action_biases], dim=-1) global_output_embeddings = self._output_action_embedder( global_action_tensor) translated_valid_actions[key]["global"] = ( global_input_embeddings, global_output_embeddings, list(global_action_ids), ) # Then the representations of the linked actions. 
if linked_actions: linked_rules, linked_action_ids = zip(*linked_actions) entities = [rule.split(" -> ")[1] for rule in linked_rules] entity_ids = [entity_map[entity] for entity in entities] # (num_linked_actions, num_question_tokens) entity_linking_scores = linking_scores[entity_ids] # (num_linked_actions,) entity_type_tensor = entity_types[entity_ids] # (num_linked_actions, entity_type_embedding_dim) entity_type_embeddings = self._entity_type_decoder_embedding( entity_type_tensor) translated_valid_actions[key]["linked"] = ( entity_linking_scores, entity_type_embeddings, list(linked_action_ids), ) return GrammarStatelet([START_SYMBOL], translated_valid_actions, world.is_nonterminal) def _compute_validation_outputs( self, actions: List[List[ProductionRuleArray]], best_final_states: Mapping[int, Sequence[GrammarBasedState]], world: List[WikiTablesLanguage], target_list: List[List[str]], metadata: List[Dict[str, Any]], outputs: Dict[str, Any], ) -> None: """ Does common things for validation time: computing logical form accuracy (which is expensive and unnecessary during training), adding visualization info to the output dictionary, etc. This doesn't return anything; instead it `modifies` the given ``outputs`` dictionary, and calls metrics on ``self``. """ batch_size = len(actions) action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] outputs["action_mapping"] = action_mapping outputs["best_action_sequence"] = [] outputs["debug_info"] = [] outputs["entities"] = [] outputs["logical_form"] = [] outputs["answer"] = [] for i in range(batch_size): # Decoding may not have terminated with any completed logical forms, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). 
            outputs["logical_form"].append([])
            if i in best_final_states:
                all_action_indices = [best_final_states[i][j].action_history[0]
                                      for j in range(len(best_final_states[i]))]
                found_denotation = False
                for action_indices in all_action_indices:
                    action_strings = [action_mapping[(i, action_index)]
                                      for action_index in action_indices]
                    has_logical_form = False
                    try:
                        logical_form = world[i].action_sequence_to_logical_form(action_strings)
                        has_logical_form = True
                    except ParsingError:
                        logical_form = "Error producing logical form"
                    if target_list is not None:
                        denotation_correct = world[i].evaluate_logical_form(logical_form,
                                                                            target_list[i])
                    else:
                        denotation_correct = False
                    if not found_denotation:
                        try:
                            denotation = world[i].execute(logical_form)
                            if denotation:
                                outputs["answer"].append(denotation)
                                found_denotation = True
                        except ExecutionError:
                            pass
                        if found_denotation:
                            if has_logical_form:
                                self._has_logical_form(1.0)
                            else:
                                self._has_logical_form(0.0)
                            if target_list:
                                self._denotation_accuracy(1.0 if denotation_correct else 0.0)
                            outputs["best_action_sequence"].append(action_strings)
                    outputs["logical_form"][-1].append(logical_form)
                if not found_denotation:
                    outputs["answer"].append(None)
                    self._denotation_accuracy(0.0)
                outputs["debug_info"].append(best_final_states[i][0].debug_info[0])  # type: ignore
                outputs["entities"].append(world[i].table_graph.entities)
            else:
                self._has_logical_form(0.0)
                self._denotation_accuracy(0.0)
        if metadata is not None:
            outputs["question_tokens"] = [x["question_tokens"] for x in metadata]

    @overrides
    def make_output_human_readable(self,
                                   output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.make_output_human_readable``, which gets called after
        ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a
        separate notion from the "decoder" in "encoder/decoder", where that decoder logic lives
        in the ``TransitionFunction``.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the
        ``output_dict``.
        """
        action_mapping = output_dict["action_mapping"]
        best_actions = output_dict["best_action_sequence"]
        debug_infos = output_dict["debug_info"]
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions,
                                                                          debug_infos)):
            instance_action_info = []
            for predicted_action, action_debug_info in zip(predicted_actions, debug_info):
                action_info = {}
                action_info["predicted_action"] = predicted_action
                considered_actions = action_debug_info["considered_actions"]
                probabilities = action_debug_info["probabilities"]
                actions = []
                for action, probability in zip(considered_actions, probabilities):
                    if action != -1:
                        actions.append((action_mapping[(batch_index, action)], probability))
                actions.sort()
                considered_actions, probabilities = zip(*actions)
                action_info["considered_actions"] = considered_actions
                action_info["action_probabilities"] = probabilities
                action_info["question_attention"] = action_debug_info.get("question_attention", [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        return output_dict
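
A toy, self-contained illustration of the per-type linking softmax in ``_get_linking_probabilities`` above (the scores and the two-type layout are hypothetical): each type's entity scores are normalized together with an extra null entity whose score is fixed at 0, and the null column is dropped afterwards.

import torch

num_question_tokens = 2
# Hypothetical linking scores for one instance: three entities, where entities 0 and 1
# share one type and entity 2 has another type.
linking_scores = torch.tensor([[1.0, 2.0, 0.5],
                               [0.0, 1.0, 3.0]])   # (num_question_tokens, num_entities)
type_to_entities = {0: [0, 1], 1: [2]}

per_type_probs = []
for type_index, entity_indices in type_to_entities.items():
    entity_scores = linking_scores[:, entity_indices]
    # Prepend the null entity's fixed score of 0, normalize within the type, then drop
    # the null column; its probability mass models "this word links to nothing".
    null_scores = entity_scores.new_zeros(num_question_tokens, 1)
    type_probs = torch.softmax(torch.cat([null_scores, entity_scores], dim=1), dim=1)
    per_type_probs.append(type_probs[:, 1:])

probabilities = torch.cat(per_type_probs, dim=1)   # (num_question_tokens, num_entities)
# Each type's probabilities sum to less than 1; the remainder belongs to the null entity.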
class NlvrCoverageSemanticParser(NlvrSemanticParser):
    """
    ``NlvrCoverageSemanticParser`` is an ``NlvrSemanticParser`` that gets around the problem of
    lack of annotated logical forms by maximizing coverage of the output sequences over a
    prespecified agenda. In addition to the signal from coverage, we also compute the
    denotations given by the logical forms and define a hybrid cost based on coverage and
    denotation errors. The training process then minimizes the expected value of this cost over
    an approximate set of logical forms produced by the parser, obtained by performing beam
    search.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query. Passed to the DecoderStep.
    beam_size : ``int``
        Beam size for the beam search used during training.
    max_num_finished_states : ``int``, optional (default=None)
        Maximum number of finished states the trainer should compute costs for.
    normalize_beam_score_by_length : ``bool``, optional (default=False)
        Should the log probabilities be normalized by length before renormalizing them? Edunov
        et al. do this in their work, but we found that not doing it works better. It's
        possible they did this because their task is NMT, and longer decoded sequences are not
        necessarily worse, and shouldn't be penalized, while we will mostly want to penalize
        longer logical forms.
    max_decoding_steps : ``int``
        Maximum number of steps for the beam search during training.
    dropout : ``float``, optional (default=0.0)
        Probability of dropout to apply on encoder outputs, decoder outputs and predicted
        actions.
    checklist_cost_weight : ``float``, optional (default=0.6)
        Mixture weight (0-1) for combining coverage cost and denotation cost. As this
        increases, we weigh the coverage cost higher, with a value of 1.0 meaning that we do
        not care about denotation accuracy.
    dynamic_cost_weight : ``Dict[str, Union[int, float]]``, optional (default=None)
        A dict containing keys ``wait_num_epochs`` and ``rate`` indicating the number of epochs
        after which we should start decreasing the weight on checklist cost in favor of
        denotation cost, and the rate at which we should do it. We will decrease the weight in
        the following way - ``checklist_cost_weight = checklist_cost_weight - rate *
        checklist_cost_weight`` starting at the appropriate epoch. The weight will remain
        constant if this is not provided.
    penalize_non_agenda_actions : ``bool``, optional (default=False)
        Should we penalize the model for producing terminal actions that are outside the
        agenda?
    initial_mml_model_file : ``str``, optional (default=None)
        If you want to initialize this model using weights from another model trained using
        MML, pass the path to the ``model.tar.gz`` file of that model here.
""" def __init__(self, vocab: Vocabulary, sentence_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, attention: Attention, beam_size: int, max_decoding_steps: int, max_num_finished_states: int = None, dropout: float = 0.0, normalize_beam_score_by_length: bool = False, checklist_cost_weight: float = 0.6, dynamic_cost_weight: Dict[str, Union[int, float]] = None, penalize_non_agenda_actions: bool = False, initial_mml_model_file: str = None) -> None: super(NlvrCoverageSemanticParser, self).__init__(vocab=vocab, sentence_embedder=sentence_embedder, action_embedding_dim=action_embedding_dim, encoder=encoder, dropout=dropout) self._agenda_coverage = Average() self._decoder_trainer: DecoderTrainer[Callable[[NlvrDecoderState], torch.Tensor]] = \ ExpectedRiskMinimization(beam_size=beam_size, normalize_by_length=normalize_beam_score_by_length, max_decoding_steps=max_decoding_steps, max_num_finished_states=max_num_finished_states) # Instantiating an empty NlvrWorld just to get the number of terminals. self._terminal_productions = set(NlvrWorld([]).terminal_productions.values()) self._decoder_step = NlvrDecoderStep(encoder_output_dim=self._encoder.get_output_dim(), action_embedding_dim=action_embedding_dim, input_attention=attention, dropout=dropout, use_coverage=True) self._checklist_cost_weight = checklist_cost_weight self._dynamic_cost_wait_epochs = None self._dynamic_cost_rate = None if dynamic_cost_weight: self._dynamic_cost_wait_epochs = dynamic_cost_weight["wait_num_epochs"] self._dynamic_cost_rate = dynamic_cost_weight["rate"] self._penalize_non_agenda_actions = penalize_non_agenda_actions self._last_epoch_in_forward: int = None # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've # copied a trained ERM model from a different machine and the original MML model that was # used to initialize it does not exist on the current machine. This may not be the best # solution for the problem. if initial_mml_model_file is not None: if os.path.isfile(initial_mml_model_file): archive = load_archive(initial_mml_model_file) self._initialize_weights_from_archive(archive) else: # A model file is passed, but it does not exist. This is expected to happen when # you're using a trained ERM model to decode. But it may also happen if the path to # the file is really just incorrect. So throwing a warning. logger.warning("MML model file for initializing weights is passed, but does not exist." " This is fine if you're just decoding.") def _initialize_weights_from_archive(self, archive: Archive) -> None: logger.info("Initializing weights from MML model.") model_parameters = dict(self.named_parameters()) archived_parameters = dict(archive.model.named_parameters()) sentence_embedder_weight = "_sentence_embedder.token_embedder_tokens.weight" if sentence_embedder_weight not in archived_parameters or \ sentence_embedder_weight not in model_parameters: raise RuntimeError("When initializing model weights from an MML model, we need " "the sentence embedder to be a TokenEmbedder using namespace called " "tokens.") for name, weights in archived_parameters.items(): if name in model_parameters: if name == "_sentence_embedder.token_embedder_tokens.weight": # The shapes of embedding weights will most likely differ between the two models # because the vocabularies will most likely be different. We will get a mapping # of indices from this model's token indices to the archived model's and copy # the tensor accordingly. 
vocab_index_mapping = self._get_vocab_index_mapping(archive.model.vocab) archived_embedding_weights = weights.data new_weights = model_parameters[name].data.clone() for index, archived_index in vocab_index_mapping: new_weights[index] = archived_embedding_weights[archived_index] logger.info("Copied embeddings of %d out of %d tokens", len(vocab_index_mapping), new_weights.size()[0]) else: new_weights = weights.data logger.info("Copying parameter %s", name) model_parameters[name].data.copy_(new_weights) def _get_vocab_index_mapping(self, archived_vocab: Vocabulary) -> List[Tuple[int, int]]: vocab_index_mapping: List[Tuple[int, int]] = [] for index in range(self.vocab.get_vocab_size(namespace='tokens')): token = self.vocab.get_token_from_index(index=index, namespace='tokens') archived_token_index = archived_vocab.get_token_index(token, namespace='tokens') # Checking if we got the UNK token index, because we don't want all new token # representations initialized to UNK token's representation. We do that by checking if # the two tokens are the same. They will not be if the token at the archived index is # UNK. if archived_vocab.get_token_from_index(archived_token_index, namespace="tokens") == token: vocab_index_mapping.append((index, archived_token_index)) return vocab_index_mapping @overrides def forward(self, # type: ignore sentence: Dict[str, torch.LongTensor], worlds: List[List[NlvrWorld]], actions: List[List[ProductionRuleArray]], agenda: torch.LongTensor, identifier: List[str] = None, labels: torch.LongTensor = None, epoch_num: List[int] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Decoder logic for producing type constrained target sequences that maximize coverage of their respective agendas, and minimize a denotation based loss. """ # We look at the epoch number and adjust the checklist cost weight if needed here. instance_epoch_num = epoch_num[0] if epoch_num is not None else None if self._dynamic_cost_rate is not None: if self.training and instance_epoch_num is None: raise RuntimeError("If you want a dynamic cost weight, use the " "EpochTrackingBucketIterator!") if instance_epoch_num != self._last_epoch_in_forward: if instance_epoch_num >= self._dynamic_cost_wait_epochs: decrement = self._checklist_cost_weight * self._dynamic_cost_rate self._checklist_cost_weight -= decrement logger.info("Checklist cost weight is now %f", self._checklist_cost_weight) self._last_epoch_in_forward = instance_epoch_num batch_size = len(worlds) action_embeddings, action_indices = self._embed_actions(actions) initial_rnn_state = self._get_initial_rnn_state(sentence) initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float) for i in range(batch_size)] # TODO (pradeep): Assuming all worlds give the same set of valid actions. 
initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i]) for i in range(batch_size)] label_strings = self._get_label_strings(labels) if labels is not None else None # Each instance's agenda is of size (agenda_size, 1) agenda_list = [agenda[i] for i in range(batch_size)] initial_checklist_states = [] for instance_actions, instance_agenda in zip(actions, agenda_list): checklist_info = self._get_checklist_info(instance_agenda, instance_actions) checklist_target, terminal_actions, checklist_mask = checklist_info initial_checklist = checklist_target.new_zeros(checklist_target.size()) initial_checklist_states.append(ChecklistState(terminal_actions=terminal_actions, checklist_target=checklist_target, checklist_mask=checklist_mask, checklist=initial_checklist)) initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, action_embeddings=action_embeddings, action_indices=action_indices, possible_actions=actions, worlds=worlds, label_strings=label_strings, checklist_state=initial_checklist_states) agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list] outputs = self._decoder_trainer.decode(initial_state, self._decoder_step, self._get_state_cost) if identifier is not None: outputs['identifier'] = identifier best_action_sequences = outputs['best_action_sequences'] batch_action_strings = self._get_action_strings(actions, best_action_sequences) batch_denotations = self._get_denotations(batch_action_strings, worlds) if labels is not None: # We're either training or validating. self._update_metrics(action_strings=batch_action_strings, worlds=worlds, label_strings=label_strings, possible_actions=actions, agenda_data=agenda_data) else: # We're testing. outputs["best_action_strings"] = batch_action_strings outputs["denotations"] = batch_denotations return outputs def _get_checklist_info(self, agenda: torch.LongTensor, all_actions: List[ProductionRuleArray]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Takes an agenda and a list of all actions and returns a target checklist against which the checklist at each state will be compared to compute a loss, indices of ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal actions are relevant for checklist loss computation. If ``self._penalize_non_agenda_actions`` is set to ``True``, ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to ``False``, indices of all terminals that are not in the agenda will be masked. Parameters ---------- ``agenda`` : ``torch.LongTensor`` Agenda of one instance of size ``(agenda_size, 1)``. ``all_actions`` : ``List[ProductionRuleArray]`` All actions for one instance. """ terminal_indices = [] target_checklist_list = [] agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()]) for index, action in enumerate(all_actions): # Each action is a ProductionRuleArray, a tuple where the first item is the production # rule string. if action[0] in self._terminal_productions: terminal_indices.append([index]) if index in agenda_indices_set: target_checklist_list.append([1]) else: target_checklist_list.append([0]) # We want to return checklist target and terminal actions that are column vectors to make # computing softmax over the difference between checklist and target easier. 
# (num_terminals, 1) terminal_actions = agenda.new_tensor(terminal_indices) # (num_terminals, 1) target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float) if self._penalize_non_agenda_actions: # All terminal actions are relevant checklist_mask = torch.ones_like(target_checklist) else: checklist_mask = (target_checklist != 0).float() return target_checklist, terminal_actions, checklist_mask def _update_metrics(self, action_strings: List[List[List[str]]], worlds: List[List[NlvrWorld]], label_strings: List[List[str]], possible_actions: List[List[ProductionRuleArray]], agenda_data: List[List[int]]) -> None: # TODO(pradeep): Move this to the base class. # TODO(pradeep): action_strings contains k-best lists. This method only uses the top decoded # sequence currently. Maybe define top-k metrics? batch_size = len(worlds) for i in range(batch_size): # Using only the top decoded sequence per instance. instance_action_strings = action_strings[i][0] if action_strings[i] else [] sequence_is_correct = [False] in_agenda_ratio = 0.0 instance_possible_actions = possible_actions[i] if instance_action_strings: terminal_agenda_actions = [] for rule_id in agenda_data[i]: if rule_id == -1: continue action_string = instance_possible_actions[rule_id][0] right_side = action_string.split(" -> ")[1] if right_side.isdigit() or ('[' not in right_side and len(right_side) > 1): terminal_agenda_actions.append(action_string) actions_in_agenda = [action in instance_action_strings for action in terminal_agenda_actions] if actions_in_agenda: # Guard against an empty agenda, which would otherwise cause a division by zero. in_agenda_ratio = sum(actions_in_agenda) / len(actions_in_agenda) instance_label_strings = label_strings[i] instance_worlds = worlds[i] sequence_is_correct = self._check_denotation(instance_action_strings, instance_label_strings, instance_worlds) for correct_in_world in sequence_is_correct: self._denotation_accuracy(1 if correct_in_world else 0) self._consistency(1 if all(sequence_is_correct) else 0) self._agenda_coverage(in_agenda_ratio) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: return { 'denotation_accuracy': self._denotation_accuracy.get_metric(reset), 'consistency': self._consistency.get_metric(reset), 'agenda_coverage': self._agenda_coverage.get_metric(reset) } def _get_state_cost(self, state: NlvrDecoderState) -> torch.Tensor: """ Return the cost of a finished state. Since it is a finished state, the group size will be 1, and hence we'll return just one cost. """ if not state.is_finished(): raise RuntimeError("_get_state_cost() is not defined for unfinished states!") # Our checklist cost is a sum of squared error from where we want to be, making sure we # take into account the mask. checklist_balance = state.checklist_state[0].get_balance() checklist_cost = torch.sum(checklist_balance ** 2) # This is the number of items on the agenda that we want to see in the decoded sequence. # We use this as the denotation cost if the path is incorrect. # Note: If we are penalizing the model for producing non-agenda actions, this is not the # upper limit on the checklist cost. That would be the number of terminal actions. denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float()) checklist_cost = self._checklist_cost_weight * checklist_cost # TODO (pradeep): The denotation based cost below is strict. Maybe define a cost based on # how many worlds the logical form is correct in? # label_strings being None happens when we are testing. We do not care about the cost then. # TODO (pradeep): Make this cleaner. 
if state.label_strings is None or all(self._check_state_denotations(state)): cost = checklist_cost else: cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost return cost def _get_state_info(self, state: NlvrDecoderState) -> Dict[str, List]: """ This method is here for debugging purposes, in case you want to look at what the model is learning. It may be inefficient to call it while training the model on real data. """ if len(state.batch_indices) == 1 and state.is_finished(): costs = [float(self._get_state_cost(state).detach().cpu().numpy())] else: costs = [] model_scores = [float(score.detach().cpu().numpy()) for score in state.score] all_actions = state.possible_actions[0] action_sequences = [[self._get_action_string(all_actions[action]) for action in history] for history in state.action_history] agenda_sequences = [] all_agenda_indices = [] for agenda, checklist_target in zip(state.terminal_actions, state.checklist_target): agenda_indices = [] for action, is_wanted in zip(agenda, checklist_target): action_int = int(action.detach().cpu().numpy()) is_wanted_int = int(is_wanted.detach().cpu().numpy()) if is_wanted_int != 0: agenda_indices.append(action_int) agenda_sequences.append([self._get_action_string(all_actions[action]) for action in agenda_indices]) all_agenda_indices.append(agenda_indices) return {"agenda": agenda_sequences, "agenda_indices": all_agenda_indices, "history": action_sequences, "history_indices": state.action_history, "costs": costs, "scores": model_scores}
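# ---------------------------------------------------------------------------
# A minimal, self-contained sketch (not part of the parser above) of the
# embedding-copy strategy in ``_initialize_weights_from_archive`` /
# ``_get_vocab_index_mapping``: rows are copied only for tokens whose surface
# form exists in both vocabularies, so tokens the archived model never saw
# keep their fresh initialization instead of all collapsing onto the archived
# UNK row. The toy vocabularies and dimensions below are hypothetical.
import torch

def copy_shared_embedding_rows(new_vocab, old_vocab, new_weights, old_weights):
    """Copy embedding rows for tokens that appear in both vocabularies."""
    copied = 0
    for token, new_index in new_vocab.items():
        old_index = old_vocab.get(token)
        if old_index is not None:
            new_weights[new_index] = old_weights[old_index]
            copied += 1
    return copied

old_vocab = {"@@UNKNOWN@@": 0, "circle": 1, "square": 2}
new_vocab = {"@@UNKNOWN@@": 0, "square": 1, "triangle": 2}
old_weights = torch.randn(3, 4)
new_weights = torch.randn(3, 4)
copied = copy_shared_embedding_rows(new_vocab, old_vocab, new_weights, old_weights)
assert copied == 2                                  # UNK and "square"
assert torch.equal(new_weights[1], old_weights[2])  # "square" row was copied
# ---------------------------------------------------------------------------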
class Text2SqlParser(Model): """ Parameters ---------- vocab : ``Vocabulary`` utterance_embedder : ``TextFieldEmbedder`` Embedder for utterances. action_embedding_dim : ``int`` Dimension to use for action embeddings. encoder : ``Seq2SeqEncoder`` The encoder to use for the input utterance. decoder_beam_search : ``BeamSearch`` Beam search used to retrieve best sequences after training. max_decoding_steps : ``int`` When we're decoding with a beam search, what's the maximum number of steps we should take? This only applies at evaluation time, not during training. input_attention: ``Attention`` We compute an attention over the input utterance at each step of the decoder, using the decoder hidden state as the query. Passed to the transition function. add_action_bias : ``bool``, optional (default=True) If ``True``, we will learn a bias weight for each action that gets used when predicting that action, in addition to its embedding. dropout : ``float``, optional (default=0) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). """ def __init__(self, vocab: Vocabulary, utterance_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, decoder_beam_search: BeamSearch, max_decoding_steps: int, input_attention: Attention, add_action_bias: bool = True, dropout: float = 0.0, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super().__init__(vocab, regularizer) self._utterance_embedder = utterance_embedder self._encoder = encoder self._max_decoding_steps = max_decoding_steps self._add_action_bias = add_action_bias self._dropout = torch.nn.Dropout(p=dropout) self._exact_match = Average() self._valid_sql_query = Average() self._action_similarity = Average() self._denotation_accuracy = Average() # the padding value used by IndexField self._action_padding_index = -1 num_actions = vocab.get_vocab_size("rule_labels") input_action_dim = action_embedding_dim if self._add_action_bias: input_action_dim += 1 self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim) self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim) # This is what we pass as input in the first step of decoding, when we don't have a # previous action, or a previous utterance attention. self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim)) self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim())) torch.nn.init.normal_(self._first_action_embedding) torch.nn.init.normal_(self._first_attended_utterance) self._beam_search = decoder_beam_search self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1) self._transition_function = BasicTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(), action_embedding_dim=action_embedding_dim, input_attention=input_attention, predict_start_type_separately=False, add_action_bias=self._add_action_bias, dropout=dropout) initializer(self) @overrides def forward(self, # type: ignore tokens: Dict[str, torch.LongTensor], valid_actions: List[List[ProductionRule]], action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference, if we're not. 
Parameters ---------- tokens : Dict[str, torch.LongTensor] The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. valid_actions : ``List[List[ProductionRule]]`` A list of all possible actions for each ``World`` in the batch, indexed into a ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and use the embeddings to determine which action to take at each timestep in the decoder. action_sequence : ``torch.LongTensor``, optional (default=None) The correct action sequence, where each action is an index into the list of possible actions. This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the trailing dimension. """ embedded_utterance = self._utterance_embedder(tokens) mask = util.get_text_field_mask(tokens).float() batch_size = embedded_utterance.size(0) # (batch_size, num_tokens, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask)) initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions) if action_sequence is not None: # Remove the trailing dimension (from ListField[ListField[IndexField]]). action_sequence = action_sequence.squeeze(-1) target_mask = action_sequence != self._action_padding_index else: target_mask = None outputs: Dict[str, Any] = {} if action_sequence is not None: # action_sequence is of shape (batch_size, 1, sequence_length) # here after we unsqueeze it for the MML trainer. loss_output = self._decoder_trainer.decode(initial_state, self._transition_function, (action_sequence.unsqueeze(1), target_mask.unsqueeze(1))) outputs.update(loss_output) if not self.training: action_mapping = [] for batch_actions in valid_actions: batch_action_mapping = {} for action_index, action in enumerate(batch_actions): batch_action_mapping[action_index] = action[0] action_mapping.append(batch_action_mapping) outputs['action_mapping'] = action_mapping # This tells the state to start keeping track of debug info, which we'll pass along in # our output dictionary. initial_state.debug_info = [[] for _ in range(batch_size)] best_final_states = self._beam_search.search(self._max_decoding_steps, initial_state, self._transition_function, keep_final_unfinished_states=True) outputs['best_action_sequence'] = [] outputs['debug_info'] = [] outputs['predicted_sql_query'] = [] outputs['sql_queries'] = [] for i in range(batch_size): # Decoding may not have terminated with any completed valid SQL queries, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). if i not in best_final_states: self._exact_match(0) self._denotation_accuracy(0) self._valid_sql_query(0) self._action_similarity(0) outputs['predicted_sql_query'].append('') continue best_action_indices = best_final_states[i][0].action_history[0] action_strings = [action_mapping[i][action_index] for action_index in best_action_indices] predicted_sql_query = action_sequence_to_sql(action_strings) if action_sequence is not None: # Use a Tensor, not a Variable, to avoid a memory leak. 
targets = action_sequence[i].data sequence_in_targets = self._action_history_match(best_action_indices, targets) self._exact_match(sequence_in_targets) # SequenceMatcher hashes its elements, so compare plain Python ints rather than 0-dim tensors. similarity = difflib.SequenceMatcher(None, best_action_indices, targets.tolist()) self._action_similarity(similarity.ratio()) outputs['best_action_sequence'].append(action_strings) outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True)) outputs['debug_info'].append(best_final_states[i][0].debug_info[0]) # type: ignore return outputs def _get_initial_state(self, encoder_outputs: torch.Tensor, mask: torch.Tensor, actions: List[List[ProductionRule]]) -> GrammarBasedState: batch_size = encoder_outputs.size(0) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states(encoder_outputs, mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) initial_score = encoder_outputs.data.new_zeros(batch_size) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] utterance_mask_list = [mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append(RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_utterance, encoder_output_list, utterance_mask_list)) initial_grammar_state = [self._create_grammar_state(actions[i]) for i in range(batch_size)] initial_state = GrammarBasedState(batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, possible_actions=actions, debug_info=None) return initial_state @staticmethod def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int: # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something. # Check if target is big enough to cover prediction (including start/end symbols) if len(predicted) > targets.size(0): return 0 predicted_tensor = targets.new_tensor(predicted) targets_trimmed = targets[:len(predicted)] # Return 1 if the predicted sequence is anywhere in the list of targets. return predicted_tensor.equal(targets_trimmed) @staticmethod def is_nonterminal(token: str): if token[0] == '"' and token[-1] == '"': return False return True @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: """ We track four metrics here: 1. exact_match, which is the percentage of the time that our best output action sequence matches the SQL query exactly. 2. denotation_acc, which is the percentage of examples where we get the correct denotation. This is the typical "accuracy" metric, and it is what you should usually report in an experimental result. You need to be careful, though, that you're computing this on the full data, and not just the subset that can be parsed. 
(make sure you pass "keep_if_unparseable=True" to the dataset reader, which we do for validation data, but not training data). 3. valid_sql_query, which is the percentage of time that decoding actually produces a valid SQL query. We might not produce a valid SQL query if the decoder gets into a repetitive loop, or we're trying to produce a super long SQL query and run out of time steps, or something. 4. action_similarity, which is how similar the action sequence predicted is to the actual action sequence. This is basically a soft measure of exact_match. """ validation_correct = self._exact_match._total_value # pylint: disable=protected-access validation_total = self._exact_match._count # pylint: disable=protected-access return { '_exact_match_count': validation_correct, '_example_count': validation_total, 'exact_match': self._exact_match.get_metric(reset), 'denotation_acc': self._denotation_accuracy.get_metric(reset), 'valid_sql_query': self._valid_sql_query.get_metric(reset), 'action_similarity': self._action_similarity.get_metric(reset) } def _create_grammar_state(self, possible_actions: List[ProductionRule]) -> GrammarStatelet: """ This method creates the GrammarStatelet object that's used for decoding. Part of creating that is creating the `valid_actions` dictionary, which contains embedded representations of all of the valid actions. So, we create that here as well. The inputs to this method are for a `single instance in the batch`; none of the tensors we create here are batched. We grab the global action ids from the input ``ProductionRules``, and we use those to embed the valid actions for every non-terminal type. We use the input ``linking_scores`` for non-global actions. Parameters ---------- possible_actions : ``List[ProductionRule]`` From the input to ``forward`` for a single batch instance. """ device = util.get_device_of(self._action_embedder.weight) # TODO(Mark): This type is pure \(- . ^)/ translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {} actions_grouped_by_nonterminal: Dict[str, List[Tuple[ProductionRule, int]]] = defaultdict(list) for i, action in enumerate(possible_actions): if action.rule == "": continue if action.is_global_rule: actions_grouped_by_nonterminal[action.nonterminal].append((action, i)) else: raise ValueError("The sql parser doesn't support non-global actions yet.") for key, production_rule_arrays in actions_grouped_by_nonterminal.items(): translated_valid_actions[key] = {} # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid # productions of that non-terminal. We'll first split those productions by global vs. # linked action. 
global_actions = [] for production_rule_array, action_index in production_rule_arrays: global_actions.append((production_rule_array.rule_id, action_index)) if global_actions: global_action_tensors, global_action_ids = zip(*global_actions) global_action_tensor = torch.cat(global_action_tensors, dim=0).long() if device >= 0: global_action_tensor = global_action_tensor.to(device) global_input_embeddings = self._action_embedder(global_action_tensor) global_output_embeddings = self._output_action_embedder(global_action_tensor) translated_valid_actions[key]['global'] = (global_input_embeddings, global_output_embeddings, list(global_action_ids)) return GrammarStatelet(['statement'], translated_valid_actions, self.is_nonterminal, reverse_productions=True) @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a separate notion from the "decoder" in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``. """ action_mapping = output_dict['action_mapping'] best_actions = output_dict["best_action_sequence"] debug_infos = output_dict['debug_info'] batch_action_info = [] for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)): instance_action_info = [] for predicted_action, action_debug_info in zip(predicted_actions, debug_info): action_info = {} action_info['predicted_action'] = predicted_action considered_actions = action_debug_info['considered_actions'] probabilities = action_debug_info['probabilities'] actions = [] for action, probability in zip(considered_actions, probabilities): if action != -1: actions.append((action_mapping[batch_index][action], probability)) actions.sort() considered_actions, probabilities = zip(*actions) action_info['considered_actions'] = considered_actions action_info['action_probabilities'] = probabilities action_info['utterance_attention'] = action_debug_info.get('question_attention', []) instance_action_info.append(action_info) batch_action_info.append(instance_action_info) output_dict["predicted_actions"] = batch_action_info return output_dict
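# ---------------------------------------------------------------------------
# A standalone sketch (hypothetical action indices) of the two sequence
# metrics tracked by the parser above: ``exact_match`` requires the
# prediction to equal the padded target's prefix, while ``action_similarity``
# uses ``difflib.SequenceMatcher`` as a soft version of exact match.
import difflib
import torch

predicted = [12, 4, 4, 9]
targets = torch.tensor([12, 4, 9, -1])  # -1 is the IndexField padding value

# Exact match, mirroring ``_action_history_match``: compare against the
# target prefix of the same length; a prediction longer than the target
# can never match.
exact = int(len(predicted) <= targets.size(0)
            and targets.new_tensor(predicted).equal(targets[:len(predicted)]))

# Soft match: 2 * matches / total length; here 2 * 3 / 8 = 0.75.
similarity = difflib.SequenceMatcher(None, predicted, targets.tolist()).ratio()
assert exact == 0 and abs(similarity - 0.75) < 1e-6
# ---------------------------------------------------------------------------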
class STS14Task(Task): ''' Task class for Sentence Textual Similarity 14. Training data is STS12 and STS13 data, as provided in the dataset. ''' def __init__(self, path, max_seq_len, name="sts14"): ''' ''' super(STS14Task, self).__init__(name, 1) self.name = name self.pair_input = 1 self.categorical = 0 self.val_metric = "%s_accuracy" % self.name self.val_metric_decreases = False self.scorer = Average() self.load_data(path, max_seq_len) def load_data(self, path, max_seq_len): ''' Process the dataset located at path. TODO: preprocess and store data so we don't have to wait? Args: - path (str): path to data ''' def load_year_split(path): sents1, sents2, targs = [], [], [] input_files = glob.glob('%s/STS.input.*.txt' % path) targ_files = glob.glob('%s/STS.gs.*.txt' % path) input_files.sort() targ_files.sort() for inp, targ in zip(input_files, targ_files): topic_sents1, topic_sents2, topic_targs = \ load_file(path, inp, targ) sents1 += topic_sents1 sents2 += topic_sents2 targs += topic_targs assert len(sents1) == len(sents2) == len(targs) return sents1, sents2, targs def load_file(path, inp, targ): sents1, sents2, targs = [], [], [] with open(inp) as fh, open(targ) as gh: for raw_sents, raw_targ in zip(fh, gh): raw_sents = raw_sents.split('\t') sent1 = process_sentence(raw_sents[0], max_seq_len) sent2 = process_sentence(raw_sents[1], max_seq_len) if not sent1 or not sent2: continue sents1.append(sent1) sents2.append(sent2) targs.append(float(raw_targ) / 5) # rescale for cosine return sents1, sents2, targs sort_data = lambda s1, s2, t: \ sorted(zip(s1, s2, t), key=lambda x: (len(x[0]), len(x[1]))) unpack = lambda x: [l for l in map(list, zip(*x))] sts2topics = { 12: ['MSRpar', 'MSRvid', 'SMTeuroparl', 'surprise.OnWN', \ 'surprise.SMTnews'], 13: ['FNWN', 'headlines', 'OnWN'], 14: ['deft-forum', 'deft-news', 'headlines', 'images', \ 'OnWN', 'tweet-news'] } sents1, sents2, targs = [], [], [] train_dirs = ['STS2012-train', 'STS2012-test', 'STS2013-test'] for train_dir in train_dirs: res = load_year_split(path + train_dir + '/') sents1 += res[0] sents2 += res[1] targs += res[2] data = [(s1, s2, t) for s1, s2, t in zip(sents1, sents2, targs)] random.shuffle(data) sents1, sents2, targs = unpack(data) split_pt = int(.8 * len(sents1)) tr_data = sort_data(sents1[:split_pt], sents2[:split_pt], targs[:split_pt]) val_data = sort_data(sents1[split_pt:], sents2[split_pt:], targs[split_pt:]) te_data = sort_data(*load_year_split(path)) self.train_data_text = unpack(tr_data) self.val_data_text = unpack(val_data) self.test_data_text = unpack(te_data) log.info("\tFinished loading STS14 data.") def get_metrics(self, reset=False): return {'accuracy': self.scorer.get_metric(reset)}
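# ---------------------------------------------------------------------------
# A small sketch (made-up rows) of two conventions in ``load_data`` above:
# gold STS scores on a 0-5 scale are rescaled to [0, 1] so they can be
# compared directly against a cosine similarity, and paired sentences are
# sorted by length so batches group examples of similar size.
raw_score = "4.2"
target = float(raw_score) / 5  # 0.84, on the same scale as cosine similarity

sents1 = [["a", "b", "c"], ["a"]]
sents2 = [["x"], ["x", "y"]]
targs = [0.4, 0.9]
rows = sorted(zip(sents1, sents2, targs), key=lambda r: (len(r[0]), len(r[1])))
sents1, sents2, targs = [list(col) for col in zip(*rows)]  # the `unpack` idiom
assert sents1 == [["a"], ["a", "b", "c"]] and targs == [0.9, 0.4]
# ---------------------------------------------------------------------------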
class AtisSemanticParser(Model): """ Parameters ---------- vocab : ``Vocabulary`` utterance_embedder : ``TextFieldEmbedder`` Embedder for utterances. action_embedding_dim : ``int`` Dimension to use for action embeddings. encoder : ``Seq2SeqEncoder`` The encoder to use for the input utterance. decoder_beam_search : ``BeamSearch`` Beam search used to retrieve best sequences after training. max_decoding_steps : ``int`` When we're decoding with a beam search, what's the maximum number of steps we should take? This only applies at evaluation time, not during training. input_attention: ``Attention`` We compute an attention over the input utterance at each step of the decoder, using the decoder hidden state as the query. Passed to the transition function. add_action_bias : ``bool``, optional (default=True) If ``True``, we will learn a bias weight for each action that gets used when predicting that action, in addition to its embedding. dropout : ``float``, optional (default=0) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). rule_namespace : ``str``, optional (default=rule_labels) The vocabulary namespace to use for production rules. The default corresponds to the default used in the dataset reader, so you likely don't need to modify this. database_file: ``str``, optional (default=/atis/atis.db) The path of the SQLite database when evaluating SQL queries. SQLite is disk based, so we need the file location to connect to it. """ def __init__(self, vocab: Vocabulary, utterance_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, decoder_beam_search: BeamSearch, max_decoding_steps: int, input_attention: Attention, add_action_bias: bool = True, training_beam_size: int = None, decoder_num_layers: int = 1, dropout: float = 0.0, rule_namespace: str = 'rule_labels', database_file='/atis/atis.db') -> None: # Atis semantic parser init super().__init__(vocab) self._utterance_embedder = utterance_embedder self._encoder = encoder self._max_decoding_steps = max_decoding_steps self._add_action_bias = add_action_bias if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._rule_namespace = rule_namespace self._exact_match = Average() self._valid_sql_query = Average() self._action_similarity = Average() self._denotation_accuracy = Average() self._executor = SqlExecutor(database_file) self._action_padding_index = -1 # the padding value used by IndexField num_actions = vocab.get_vocab_size(self._rule_namespace) if self._add_action_bias: input_action_dim = action_embedding_dim + 1 else: input_action_dim = action_embedding_dim self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim) self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim) # This is what we pass as input in the first step of decoding, when we don't have a # previous action, or a previous utterance attention. self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim)) self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim())) torch.nn.init.normal_(self._first_action_embedding) torch.nn.init.normal_(self._first_attended_utterance) self._num_entity_types = 2 # TODO(kevin): get this in a more principled way somehow? 
self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim) self._decoder_num_layers = decoder_num_layers self._beam_search = decoder_beam_search self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size) self._transition_function = LinkingTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(), action_embedding_dim=action_embedding_dim, input_attention=input_attention, predict_start_type_separately=False, add_action_bias=self._add_action_bias, dropout=dropout, num_layers=self._decoder_num_layers) @overrides def forward(self, # type: ignore utterance: Dict[str, torch.LongTensor], world: List[AtisWorld], actions: List[List[ProductionRule]], linking_scores: torch.Tensor, target_action_sequence: torch.LongTensor = None, sql_queries: List[List[str]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference, if we're not. Parameters ---------- utterance : Dict[str, torch.LongTensor] The output of ``TextField.as_array()`` applied on the utterance ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. world : ``List[AtisWorld]`` We use a ``MetadataField`` to get the ``World`` for each input instance. Because of how ``MetadataField`` works, this gets passed to us as a ``List[AtisWorld]``. actions : ``List[List[ProductionRule]]`` A list of all possible actions for each ``World`` in the batch, indexed into a ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and use the embeddings to determine which action to take at each timestep in the decoder. linking_scores: ``torch.Tensor`` A matrix linking the utterance tokens and the entities. This binary matrix is deterministically generated, where each entry indicates whether a token generated an entity. This tensor has shape ``(batch_size, num_entities, num_utterance_tokens)``. target_action_sequence : torch.Tensor, optional (default=None) The correct action sequence, where each action is an index into the list of possible actions. This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the trailing dimension. sql_queries : List[List[str]], optional (default=None) A list of the SQL queries that are given during training or validation. """ initial_state = self._get_initial_state(utterance, world, actions, linking_scores) batch_size = linking_scores.shape[0] if target_action_sequence is not None: # Remove the trailing dimension (from ListField[ListField[IndexField]]). target_action_sequence = target_action_sequence.squeeze(-1) target_mask = target_action_sequence != self._action_padding_index else: target_mask = None if self.training: # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we unsqueeze it for # the MML trainer. return self._decoder_trainer.decode(initial_state, self._transition_function, (target_action_sequence.unsqueeze(1), target_mask.unsqueeze(1))) else: # TODO(kevin) Move some of this functionality to a separate method for computing validation outputs. 
action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] outputs: Dict[str, Any] = {'action_mapping': action_mapping} outputs['linking_scores'] = linking_scores if target_action_sequence is not None: outputs['loss'] = self._decoder_trainer.decode(initial_state, self._transition_function, (target_action_sequence.unsqueeze(1), target_mask.unsqueeze(1)))['loss'] num_steps = self._max_decoding_steps # This tells the state to start keeping track of debug info, which we'll pass along in # our output dictionary. initial_state.debug_info = [[] for _ in range(batch_size)] best_final_states = self._beam_search.search(num_steps, initial_state, self._transition_function, keep_final_unfinished_states=False) outputs['best_action_sequence'] = [] outputs['debug_info'] = [] outputs['entities'] = [] outputs['predicted_sql_query'] = [] outputs['sql_queries'] = [] outputs['utterance'] = [] outputs['tokenized_utterance'] = [] for i in range(batch_size): # Decoding may not have terminated with any completed valid SQL queries, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). if i not in best_final_states: self._exact_match(0) self._denotation_accuracy(0) self._valid_sql_query(0) self._action_similarity(0) outputs['predicted_sql_query'].append('') continue best_action_indices = best_final_states[i][0].action_history[0] action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices] predicted_sql_query = action_sequence_to_sql(action_strings) if target_action_sequence is not None: # Use a Tensor, not a Variable, to avoid a memory leak. targets = target_action_sequence[i].data sequence_in_targets = self._action_history_match(best_action_indices, targets) self._exact_match(sequence_in_targets) # SequenceMatcher hashes its elements, so compare plain Python ints rather than 0-dim tensors. similarity = difflib.SequenceMatcher(None, best_action_indices, targets.tolist()) self._action_similarity(similarity.ratio()) if sql_queries and sql_queries[i]: denotation_correct = self._executor.evaluate_sql_query(predicted_sql_query, sql_queries[i]) self._denotation_accuracy(denotation_correct) outputs['sql_queries'].append(sql_queries[i]) outputs['utterance'].append(world[i].utterances[-1]) outputs['tokenized_utterance'].append([token.text for token in world[i].tokenized_utterances[-1]]) outputs['entities'].append(world[i].entities) outputs['best_action_sequence'].append(action_strings) outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True)) outputs['debug_info'].append(best_final_states[i][0].debug_info[0]) # type: ignore return outputs def _get_initial_state(self, utterance: Dict[str, torch.LongTensor], worlds: List[AtisWorld], actions: List[List[ProductionRule]], linking_scores: torch.Tensor) -> GrammarBasedState: embedded_utterance = self._utterance_embedder(utterance) utterance_mask = util.get_text_field_mask(utterance).float() batch_size = embedded_utterance.size(0) num_entities = max([len(world.entities) for world in worlds]) # entity_types: tensor with shape (batch_size, num_entities) entity_types, _ = self._get_type_vector(worlds, num_entities, embedded_utterance) # (batch_size, num_utterance_tokens, embedding_dim) encoder_input = embedded_utterance # (batch_size, utterance_length, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask)) # This will be our initial hidden state and memory cell for the decoder LSTM. 
final_encoder_output = util.get_final_encoder_states(encoder_outputs, utterance_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) initial_score = embedded_utterance.data.new_zeros(batch_size) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] utterance_mask_list = [utterance_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): if self._decoder_num_layers > 1: initial_rnn_state.append(RnnStatelet(final_encoder_output[i].repeat(self._decoder_num_layers, 1), memory_cell[i].repeat(self._decoder_num_layers, 1), self._first_action_embedding, self._first_attended_utterance, encoder_output_list, utterance_mask_list)) else: initial_rnn_state.append(RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_utterance, encoder_output_list, utterance_mask_list)) initial_grammar_state = [self._create_grammar_state(worlds[i], actions[i], linking_scores[i], entity_types[i]) for i in range(batch_size)] initial_state = GrammarBasedState(batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, possible_actions=actions, debug_info=None) return initial_state @staticmethod def _get_type_vector(worlds: List[AtisWorld], num_entities: int, tensor: torch.Tensor = None) -> Tuple[torch.LongTensor, Dict[int, int]]: """ Produces the encoding for each entity's type. In addition, a map from a flattened entity index to type is returned to combine entity type operations into one method. Parameters ---------- worlds : ``List[AtisWorld]`` num_entities : ``int`` tensor : ``torch.Tensor`` Used for copying the constructed list onto the right device. Returns ------- A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``. entity_types : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. """ entity_types = {} batch_types = [] for batch_index, world in enumerate(worlds): types = [] entities = [('number', entity) if any([entity.startswith(numeric_nonterminal) for numeric_nonterminal in NUMERIC_NONTERMINALS]) else ('string', entity) for entity in world.entities] for entity_index, entity in enumerate(entities): # We need numbers to be first, then strings, since our entities are going to be # sorted. We do a split by type and then a merge later, and it relies on this sorting. if entity[0] == 'number': entity_type = 1 else: entity_type = 0 types.append(entity_type) # For easier lookups later, we're actually using a _flattened_ version # of (batch_index, entity_index) for the key, because this is how the # linking scores are stored. 
flattened_entity_index = batch_index * num_entities + entity_index entity_types[flattened_entity_index] = entity_type padded = pad_sequence_to_length(types, num_entities, lambda: 0) batch_types.append(padded) return tensor.new_tensor(batch_types, dtype=torch.long), entity_types @staticmethod def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int: # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something. # Check if target is big enough to cover prediction (including start/end symbols) if len(predicted) > targets.size(0): return 0 predicted_tensor = targets.new_tensor(predicted) targets_trimmed = targets[:len(predicted)] # Return 1 if the predicted sequence is anywhere in the list of targets. return predicted_tensor.equal(targets_trimmed) @staticmethod def is_nonterminal(token: str): if token[0] == '"' and token[-1] == '"': return False return True @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: """ We track four metrics here: 1. exact_match, which is the percentage of the time that our best output action sequence matches the SQL query exactly. 2. denotation_acc, which is the percentage of examples where we get the correct denotation. This is the typical "accuracy" metric, and it is what you should usually report in an experimental result. You need to be careful, though, that you're computing this on the full data, and not just the subset that can be parsed. (make sure you pass "keep_if_unparseable=True" to the dataset reader, which we do for validation data, but not training data). 3. valid_sql_query, which is the percentage of time that decoding actually produces a valid SQL query. We might not produce a valid SQL query if the decoder gets into a repetitive loop, or we're trying to produce a super long SQL query and run out of time steps, or something. 4. action_similarity, which is how similar the action sequence predicted is to the actual action sequence. This is basically a soft measure of exact_match. """ return { 'exact_match': self._exact_match.get_metric(reset), 'denotation_acc': self._denotation_accuracy.get_metric(reset), 'valid_sql_query': self._valid_sql_query.get_metric(reset), 'action_similarity': self._action_similarity.get_metric(reset) } def _create_grammar_state(self, world: AtisWorld, possible_actions: List[ProductionRule], linking_scores: torch.Tensor, entity_types: torch.Tensor) -> GrammarStatelet: """ This method creates the GrammarStatelet object that's used for decoding. Part of creating that is creating the `valid_actions` dictionary, which contains embedded representations of all of the valid actions. So, we create that here as well. The inputs to this method are for a `single instance in the batch`; none of the tensors we create here are batched. We grab the global action ids from the input ``ProductionRules``, and we use those to embed the valid actions for every non-terminal type. We use the input ``linking_scores`` for non-global actions. Parameters ---------- world : ``AtisWorld`` From the input to ``forward`` for a single batch instance. possible_actions : ``List[ProductionRule]`` From the input to ``forward`` for a single batch instance. linking_scores : ``torch.Tensor`` Assumed to have shape ``(num_entities, num_utterance_tokens)`` (i.e., there is no batch dimension). entity_types : ``torch.Tensor`` Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension). 
""" action_map = {} for action_index, action in enumerate(possible_actions): action_string = action[0] action_map[action_string] = action_index valid_actions = world.valid_actions entity_map = {} entities = world.entities for entity_index, entity in enumerate(entities): entity_map[entity] = entity_index translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {} for key, action_strings in valid_actions.items(): translated_valid_actions[key] = {} # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid # productions of that non-terminal. We'll first split those productions by global vs. # linked action. action_indices = [action_map[action_string] for action_string in action_strings] production_rule_arrays = [(possible_actions[index], index) for index in action_indices] global_actions = [] linked_actions = [] for production_rule_array, action_index in production_rule_arrays: if production_rule_array[1]: global_actions.append((production_rule_array[2], action_index)) else: linked_actions.append((production_rule_array[0], action_index)) if global_actions: global_action_tensors, global_action_ids = zip(*global_actions) global_action_tensor = torch.cat(global_action_tensors, dim=0).to(entity_types.device).long() global_input_embeddings = self._action_embedder(global_action_tensor) global_output_embeddings = self._output_action_embedder(global_action_tensor) translated_valid_actions[key]['global'] = (global_input_embeddings, global_output_embeddings, list(global_action_ids)) if linked_actions: linked_rules, linked_action_ids = zip(*linked_actions) entities = linked_rules entity_ids = [entity_map[entity] for entity in entities] entity_linking_scores = linking_scores[entity_ids] entity_type_tensor = entity_types[entity_ids] entity_type_embeddings = (self._entity_type_decoder_embedding(entity_type_tensor) .to(entity_types.device) .float()) translated_valid_actions[key]['linked'] = (entity_linking_scores, entity_type_embeddings, list(linked_action_ids)) return GrammarStatelet(['statement'], translated_valid_actions, self.is_nonterminal) @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a separate notion from the "decoder" in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``. 
""" action_mapping = output_dict['action_mapping'] best_actions = output_dict["best_action_sequence"] debug_infos = output_dict['debug_info'] batch_action_info = [] for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)): instance_action_info = [] for predicted_action, action_debug_info in zip(predicted_actions, debug_info): action_info = {} action_info['predicted_action'] = predicted_action considered_actions = action_debug_info['considered_actions'] probabilities = action_debug_info['probabilities'] actions = [] for action, probability in zip(considered_actions, probabilities): if action != -1: actions.append((action_mapping[(batch_index, action)], probability)) actions.sort() considered_actions, probabilities = zip(*actions) action_info['considered_actions'] = considered_actions action_info['action_probabilities'] = probabilities action_info['utterance_attention'] = action_debug_info.get('question_attention', []) instance_action_info.append(action_info) batch_action_info.append(instance_action_info) output_dict["predicted_actions"] = batch_action_info return output_dict
class WikiTablesErmSemanticParser(WikiTablesSemanticParser): """ A ``WikiTablesErmSemanticParser`` is a :class:`WikiTablesSemanticParser` that learns to search for logical forms that yield the correct denotations. Parameters ---------- vocab : ``Vocabulary`` question_embedder : ``TextFieldEmbedder`` Embedder for questions. Passed to super class. action_embedding_dim : ``int`` Dimension to use for action embeddings. Passed to super class. encoder : ``Seq2SeqEncoder`` The encoder to use for the input question. Passed to super class. entity_encoder : ``Seq2VecEncoder`` The encoder used for averaging the words of an entity. Passed to super class. attention : ``Attention`` We compute an attention over the input question at each step of the decoder, using the decoder hidden state as the query. Passed to the transition function. decoder_beam_size : ``int`` Beam size to be used by the ExpectedRiskMinimization algorithm. decoder_num_finished_states : ``int`` Number of finished states for which costs will be computed by the ExpectedRiskMinimization algorithm. max_decoding_steps : ``int`` Maximum number of steps the decoder should take before giving up. Used both during training and evaluation. Passed to super class. add_action_bias : ``bool``, optional (default=True) If ``True``, we will learn a bias weight for each action that gets used when predicting that action, in addition to its embedding. Passed to super class. normalize_beam_score_by_length : ``bool``, optional (default=False) Should we normalize the log-probabilities by length before renormalizing the beam? This was shown to work better for NMT by Edunov et al., but that may not be the case for semantic parsing. checklist_cost_weight : ``float``, optional (default=0.6) Mixture weight (0-1) for combining coverage cost and denotation cost. As this increases, we weigh the coverage cost higher, with a value of 1.0 meaning that we do not care about denotation accuracy. use_neighbor_similarity_for_linking : ``bool``, optional (default=False) If ``True``, we will compute a max similarity between a question token and the `neighbors` of an entity as a component of the linking scores. This is meant to capture the same kind of information as the ``related_column`` feature. Passed to super class. dropout : ``float``, optional (default=0) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). Passed to super class. num_linking_features : ``int``, optional (default=10) We need to construct a parameter vector for the linking features, so we need to know how many there are. The default of 10 here matches the default in the ``KnowledgeGraphField``, which is to use all ten defined features. If this is 0, another term will be added to the linking score. This term contains the maximum similarity value from the entity's neighbors and the question. Passed to super class. rule_namespace : ``str``, optional (default=rule_labels) The vocabulary namespace to use for production rules. The default corresponds to the default used in the dataset reader, so you likely don't need to modify this. Passed to super class. tables_directory : ``str``, optional (default=/wikitables/) The directory to find tables when evaluating logical forms. We rely on a call to SEMPRE to evaluate logical forms, and SEMPRE needs to read the table from disk itself. This tells SEMPRE where to find the tables. Passed to super class. 
mml_model_file : ``str``, optional (default=None) If you want to initialize this model using weights from another model trained using MML, pass the path to the ``model.tar.gz`` file of that model here. """ def __init__(self, vocab: Vocabulary, question_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, entity_encoder: Seq2VecEncoder, attention: Attention, decoder_beam_size: int, decoder_num_finished_states: int, max_decoding_steps: int, mixture_feedforward: FeedForward = None, add_action_bias: bool = True, normalize_beam_score_by_length: bool = False, checklist_cost_weight: float = 0.6, use_neighbor_similarity_for_linking: bool = False, dropout: float = 0.0, num_linking_features: int = 10, rule_namespace: str = 'rule_labels', tables_directory: str = '/wikitables/', mml_model_file: str = None) -> None: use_similarity = use_neighbor_similarity_for_linking super().__init__(vocab=vocab, question_embedder=question_embedder, action_embedding_dim=action_embedding_dim, encoder=encoder, entity_encoder=entity_encoder, max_decoding_steps=max_decoding_steps, add_action_bias=add_action_bias, use_neighbor_similarity_for_linking=use_similarity, dropout=dropout, num_linking_features=num_linking_features, rule_namespace=rule_namespace, tables_directory=tables_directory) # Not sure why mypy needs a type annotation for this! self._decoder_trainer: ExpectedRiskMinimization = \ ExpectedRiskMinimization(beam_size=decoder_beam_size, normalize_by_length=normalize_beam_score_by_length, max_decoding_steps=self._max_decoding_steps, max_num_finished_states=decoder_num_finished_states) unlinked_terminals_global_indices = [] global_vocab = self.vocab.get_token_to_index_vocabulary(rule_namespace) for production, index in global_vocab.items(): right_side = production.split(" -> ")[1] if right_side in types.COMMON_NAME_MAPPING: # This is a terminal production. unlinked_terminals_global_indices.append(index) self._num_unlinked_terminals = len(unlinked_terminals_global_indices) self._decoder_step = LinkingCoverageTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(), action_embedding_dim=action_embedding_dim, input_attention=attention, num_start_types=self._num_start_types, predict_start_type_separately=True, add_action_bias=self._add_action_bias, mixture_feedforward=mixture_feedforward, dropout=dropout) self._checklist_cost_weight = checklist_cost_weight self._agenda_coverage = Average() # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've # copied a trained ERM model from a different machine and the original MML model that was # used to initialize it does not exist on the current machine. This may not be the best # solution for the problem. if mml_model_file is not None: if os.path.isfile(mml_model_file): archive = load_archive(mml_model_file) self._initialize_weights_from_archive(archive) else: # A model file is passed, but it does not exist. This is expected to happen when # you're using a trained ERM model to decode. But it may also happen if the path to # the file is really just incorrect. So throwing a warning. logger.warning("MML model file for initializing weights is passed, but does not exist." 
" This is fine if you're just decoding.") def _initialize_weights_from_archive(self, archive: Archive) -> None: logger.info("Initializing weights from MML model.") model_parameters = dict(self.named_parameters()) archived_parameters = dict(archive.model.named_parameters()) question_embedder_weight = "_question_embedder.token_embedder_tokens.weight" if question_embedder_weight not in archived_parameters or \ question_embedder_weight not in model_parameters: raise RuntimeError("When initializing model weights from an MML model, we need " "the question embedder to be a TokenEmbedder using namespace called " "tokens.") for name, weights in archived_parameters.items(): if name in model_parameters: if name == question_embedder_weight: # The shapes of embedding weights will most likely differ between the two models # because the vocabularies will most likely be different. We will get a mapping # of indices from this model's token indices to the archived model's and copy # the tensor accordingly. vocab_index_mapping = self._get_vocab_index_mapping(archive.model.vocab) archived_embedding_weights = weights.data new_weights = model_parameters[name].data.clone() for index, archived_index in vocab_index_mapping: new_weights[index] = archived_embedding_weights[archived_index] logger.info("Copied embeddings of %d out of %d tokens", len(vocab_index_mapping), new_weights.size()[0]) else: new_weights = weights.data logger.info("Copying parameter %s", name) model_parameters[name].data.copy_(new_weights) def _get_vocab_index_mapping(self, archived_vocab: Vocabulary) -> List[Tuple[int, int]]: vocab_index_mapping: List[Tuple[int, int]] = [] for index in range(self.vocab.get_vocab_size(namespace='tokens')): token = self.vocab.get_token_from_index(index=index, namespace='tokens') archived_token_index = archived_vocab.get_token_index(token, namespace='tokens') # Checking if we got the UNK token index, because we don't want all new token # representations initialized to UNK token's representation. We do that by checking if # the two tokens are the same. They will not be if the token at the archived index is # UNK. if archived_vocab.get_token_from_index(archived_token_index, namespace="tokens") == token: vocab_index_mapping.append((index, archived_token_index)) return vocab_index_mapping @overrides def forward(self, # type: ignore question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[WikiTablesWorld], actions: List[List[ProductionRule]], agenda: torch.LongTensor, example_lisp_string: List[str], metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] The output of ``TextField.as_array()`` applied on the question ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. table : ``Dict[str, torch.LongTensor]`` The output of ``KnowledgeGraphField.as_array()`` applied on the table ``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to get embeddings for each entity. world : ``List[WikiTablesWorld]`` We use a ``MetadataField`` to get the ``World`` for each input instance. 
Because of how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``, actions : ``List[List[ProductionRule]]`` A list of all possible actions for each ``World`` in the batch, indexed into a ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and use the embeddings to determine which action to take at each timestep in the decoder. example_lisp_string : ``List[str]`` The example (lisp-formatted) string corresponding to the given input. This comes directly from the ``.examples`` file provided with the dataset. We pass this to SEMPRE when evaluating denotation accuracy; it is otherwise unused. metadata : ``List[Dict[str, Any]]``, optional, (default = None) Metadata containing the original tokenized question within a 'question_tokens' key. """ batch_size = list(question.values())[0].size(0) # Each instance's agenda is of size (agenda_size, 1) agenda_list = [agenda[i] for i in range(batch_size)] checklist_states = [] all_terminal_productions = [set(instance_world.terminal_productions.values()) for instance_world in world] max_num_terminals = max([len(terminals) for terminals in all_terminal_productions]) for instance_actions, instance_agenda, terminal_productions in zip(actions, agenda_list, all_terminal_productions): checklist_info = self._get_checklist_info(instance_agenda, instance_actions, terminal_productions, max_num_terminals) checklist_target, terminal_actions, checklist_mask = checklist_info initial_checklist = checklist_target.new_zeros(checklist_target.size()) checklist_states.append(ChecklistStatelet(terminal_actions=terminal_actions, checklist_target=checklist_target, checklist_mask=checklist_mask, checklist=initial_checklist)) outputs: Dict[str, Any] = {} rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question, table, world, actions, outputs) batch_size = len(rnn_state) initial_score = rnn_state[0].hidden_state.new_zeros(batch_size) initial_score_list = [initial_score[i] for i in range(batch_size)] initial_state = CoverageState(batch_indices=list(range(batch_size)), # type: ignore action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=rnn_state, grammar_state=grammar_state, checklist_state=checklist_states, possible_actions=actions, extras=example_lisp_string, debug_info=None) if not self.training: initial_state.debug_info = [[] for _ in range(batch_size)] outputs = self._decoder_trainer.decode(initial_state, # type: ignore self._decoder_step, partial(self._get_state_cost, world)) best_final_states = outputs['best_final_states'] if not self.training: batch_size = len(actions) agenda_indices = [actions_[:, 0].cpu().data for actions_ in agenda] action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] for i in range(batch_size): in_agenda_ratio = 0.0 # Decoding may not have terminated with any completed logical forms, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). 
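# In that case the instance is skipped and its agenda coverage stays at the
# initial 0.0. For example, if 3 of the 4 agenda actions appear in the decoded
# action sequence, in_agenda_ratio below becomes 3 / 4 = 0.75.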
if i in best_final_states: action_sequence = best_final_states[i][0].action_history[0] action_strings = [action_mapping[(i, action_index)] for action_index in action_sequence] instance_possible_actions = actions[i] agenda_actions = [] for rule_id in agenda_indices[i]: rule_id = int(rule_id) if rule_id == -1: continue action_string = instance_possible_actions[rule_id][0] agenda_actions.append(action_string) actions_in_agenda = [action in action_strings for action in agenda_actions] if actions_in_agenda: # Note: This means that when there are no actions on agenda, agenda coverage # will be 0, not 1. in_agenda_ratio = sum(actions_in_agenda) / len(actions_in_agenda) self._agenda_coverage(in_agenda_ratio) self._compute_validation_outputs(actions, best_final_states, world, example_lisp_string, metadata, outputs) return outputs @staticmethod def _get_checklist_info(agenda: torch.LongTensor, all_actions: List[ProductionRule], terminal_productions: Set[str], max_num_terminals: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Takes an agenda, a list of all actions, a set of terminal productions in the corresponding world, and a length to pad the checklist vectors to, and returns a target checklist against which the checklist at each state will be compared to compute a loss, indices of ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal actions are relevant for checklist loss computation. Parameters ---------- ``agenda`` : ``torch.LongTensor`` Agenda of one instance of size ``(agenda_size, 1)``. ``all_actions`` : ``List[ProductionRule]`` All actions for one instance. ``terminal_productions`` : ``Set[str]`` String representations of terminal productions in the corresponding world. ``max_num_terminals`` : ``int`` Length to which the checklist vectors will be padded till. This is the max number of terminal productions in all the worlds in the batch. """ terminal_indices = [] target_checklist_list = [] agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()]) # We want to return checklist target and terminal actions that are column vectors to make # computing softmax over the difference between checklist and target easier. for index, action in enumerate(all_actions): # Each action is a ProductionRule, a tuple where the first item is the production # rule string. if action[0] in terminal_productions: terminal_indices.append([index]) if index in agenda_indices_set: target_checklist_list.append([1]) else: target_checklist_list.append([0]) while len(target_checklist_list) < max_num_terminals: target_checklist_list.append([0]) terminal_indices.append([-1]) # (max_num_terminals, 1) terminal_actions = agenda.new_tensor(terminal_indices) # (max_num_terminals, 1) target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float) checklist_mask = (target_checklist != 0).float() return target_checklist, terminal_actions, checklist_mask def _get_state_cost(self, worlds: List[WikiTablesWorld], state: CoverageState) -> torch.Tensor: if not state.is_finished(): raise RuntimeError("_get_state_cost() is not defined for unfinished states!") world = worlds[state.batch_indices[0]] # Our checklist cost is a sum of squared error from where we want to be, making sure we # take into account the mask. We clamp the lower limit of the balance at 0 to avoid # penalizing agenda actions produced multiple times. 
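# A small worked example (assuming get_balance() returns checklist_target - checklist):
# with target [1, 0, 1] and a decoded checklist [2, 0, 0], the raw balance is
# [-1, 0, 1]; clamping gives [0, 0, 1], so only the still-missing agenda action
# contributes to the squared cost below.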
checklist_balance = torch.clamp(state.checklist_state[0].get_balance(), min=0.0) checklist_cost = torch.sum((checklist_balance) ** 2) # This is the number of items on the agenda that we want to see in the decoded sequence. # We use this as the denotation cost if the path is incorrect. denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float()) checklist_cost = self._checklist_cost_weight * checklist_cost action_history = state.action_history[0] batch_index = state.batch_indices[0] action_strings = [state.possible_actions[batch_index][i][0] for i in action_history] logical_form = world.get_logical_form(action_strings) lisp_string = state.extras[batch_index] if self._executor.evaluate_logical_form(logical_form, lisp_string): cost = checklist_cost else: cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost return cost @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: """ The base class returns a dict with dpd accuracy, denotation accuracy, and logical form percentage metrics. We add the agenda coverage metric here. """ metrics = super().get_metrics(reset) metrics["agenda_coverage"] = self._agenda_coverage.get_metric(reset) return metrics
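# --- Illustrative sketch (not part of the original model) ---
# A minimal, self-contained rendition of how a checklist target is built from an
# agenda, mirroring _get_checklist_info above. `toy_checklist_info` and its toy
# index lists are hypothetical stand-ins; the real code operates on ProductionRule
# tuples and agenda tensors.
import torch

def toy_checklist_info(agenda_indices, terminal_action_indices, max_num_terminals):
    """For each terminal action: target 1 if it is on the agenda, else 0; pad to max_num_terminals."""
    agenda = set(agenda_indices)
    terminal_indices = [[index] for index in terminal_action_indices]
    target = [[1] if index in agenda else [0] for index in terminal_action_indices]
    while len(target) < max_num_terminals:  # pad, as _get_checklist_info does
        target.append([0])
        terminal_indices.append([-1])
    target = torch.tensor(target, dtype=torch.float)
    mask = (target != 0).float()  # only agenda terminals enter the checklist loss
    return target, torch.tensor(terminal_indices), mask

# e.g. terminals at action indices [2, 5, 7] with agenda [5], padded to 4 terminals:
# target = [[0.], [1.], [0.], [0.]] and the mask selects only the agenda terminal.
target, terminal_actions, mask = toy_checklist_info([5], [2, 5, 7], 4)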
class SpiderParser(Model): def __init__(self, vocab: Vocabulary, encoder: Seq2SeqEncoder, entity_encoder: Seq2VecEncoder, decoder_beam_search: BeamSearch, question_embedder: TextFieldEmbedder, input_attention: Attention, past_attention: Attention, max_decoding_steps: int, action_embedding_dim: int, gnn: bool = True, decoder_use_graph_entities: bool = True, decoder_self_attend: bool = True, gnn_timesteps: int = 2, parse_sql_on_decoding: bool = True, add_action_bias: bool = True, use_neighbor_similarity_for_linking: bool = True, dataset_path: str = 'dataset', training_beam_size: int = None, decoder_num_layers: int = 1, dropout: float = 0.0, rule_namespace: str = 'rule_labels', scoring_dev_params: dict = None, debug_parsing: bool = False) -> None: super().__init__(vocab) self.vocab = vocab self._encoder = encoder self._max_decoding_steps = max_decoding_steps if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._rule_namespace = rule_namespace self._question_embedder = question_embedder self._add_action_bias = add_action_bias self._scoring_dev_params = scoring_dev_params or {} self.parse_sql_on_decoding = parse_sql_on_decoding self._entity_encoder = TimeDistributed(entity_encoder) self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking self._self_attend = decoder_self_attend self._decoder_use_graph_entities = decoder_use_graph_entities self._action_padding_index = -1 # the padding value used by IndexField self._exact_match = Average() self._sql_evaluator_match = Average() self._action_similarity = Average() self._acc_single = Average() self._acc_multi = Average() self._beam_hit = Average() self._action_embedding_dim = action_embedding_dim num_actions = vocab.get_vocab_size(self._rule_namespace) if self._add_action_bias: input_action_dim = action_embedding_dim + 1 else: input_action_dim = action_embedding_dim self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim) self._output_action_embedder = Embedding( num_embeddings=num_actions, embedding_dim=action_embedding_dim) encoder_output_dim = encoder.get_output_dim() if gnn: encoder_output_dim += action_embedding_dim self._first_action_embedding = torch.nn.Parameter( torch.FloatTensor(action_embedding_dim)) self._first_attended_utterance = torch.nn.Parameter( torch.FloatTensor(encoder_output_dim)) self._first_attended_output = torch.nn.Parameter( torch.FloatTensor(action_embedding_dim)) torch.nn.init.normal_(self._first_action_embedding) torch.nn.init.normal_(self._first_attended_utterance) torch.nn.init.normal_(self._first_attended_output) self._num_entity_types = 9 self._embedding_dim = question_embedder.get_output_dim() self._entity_type_encoder_embedding = Embedding( self._num_entity_types, self._embedding_dim) self._entity_type_decoder_embedding = Embedding( self._num_entity_types, action_embedding_dim) self._linking_params = torch.nn.Linear(16, 1) torch.nn.init.uniform_(self._linking_params.weight, 0, 1) num_edge_types = 3 self._gnn = GatedGraphConv(self._embedding_dim, gnn_timesteps, num_edge_types=num_edge_types, dropout=dropout) self._decoder_num_layers = decoder_num_layers self._beam_search = decoder_beam_search self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size) if decoder_self_attend: self._transition_function = AttendPastSchemaItemsTransitionFunction( encoder_output_dim=encoder_output_dim, action_embedding_dim=action_embedding_dim, input_attention=input_attention, past_attention=past_attention, 
predict_start_type_separately=False, add_action_bias=self._add_action_bias, dropout=dropout, num_layers=self._decoder_num_layers) else: self._transition_function = LinkingTransitionFunction( encoder_output_dim=encoder_output_dim, action_embedding_dim=action_embedding_dim, input_attention=input_attention, predict_start_type_separately=False, add_action_bias=self._add_action_bias, dropout=dropout, num_layers=self._decoder_num_layers) self._ent2ent_ff = FeedForward(action_embedding_dim, 1, action_embedding_dim, Activation.by_name('relu')()) self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim) # TODO: Remove hard-coded dirs self._evaluate_func = partial( evaluate, db_dir=os.path.join(dataset_path, 'database'), table=os.path.join(dataset_path, 'tables.json'), check_valid=False) self.debug_parsing = debug_parsing @overrides def forward( self, # type: ignore utterance: Dict[str, torch.LongTensor], valid_actions: List[List[ProductionRule]], world: List[SpiderWorld], schema: Dict[str, torch.LongTensor], action_sequence: torch.LongTensor = None ) -> Dict[str, torch.Tensor]: batch_size = len(world) device = utterance['tokens'].device initial_state = self._get_initial_state(utterance, world, schema, valid_actions) if action_sequence is not None: # Remove the trailing dimension (from ListField[ListField[IndexField]]). action_sequence = action_sequence.squeeze(-1) action_mask = action_sequence != self._action_padding_index else: action_mask = None if self.training: decode_output = self._decoder_trainer.decode( initial_state, self._transition_function, (action_sequence.unsqueeze(1), action_mask.unsqueeze(1))) return {'loss': decode_output['loss']} else: loss = torch.tensor([0]).float().to(device) if action_sequence is not None and action_sequence.size(1) > 1: try: loss = self._decoder_trainer.decode( initial_state, self._transition_function, (action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))['loss'] except ZeroDivisionError: # reached a dead-end during beam search pass outputs: Dict[str, Any] = {'loss': loss} num_steps = self._max_decoding_steps # This tells the state to start keeping track of debug info, which we'll pass along in # our output dictionary. 
initial_state.debug_info = [[] for _ in range(batch_size)]
best_final_states = self._beam_search.search(num_steps,
                                             initial_state,
                                             self._transition_function,
                                             keep_final_unfinished_states=False)
self._compute_validation_outputs(valid_actions,
                                 best_final_states,
                                 world,
                                 action_sequence,
                                 outputs)
return outputs

def _get_initial_state(self,
                       utterance: Dict[str, torch.LongTensor],
                       worlds: List[SpiderWorld],
                       schema: Dict[str, torch.LongTensor],
                       actions: List[List[ProductionRule]]) -> GrammarBasedState:
    schema_text = schema['text']
    # The TextFieldEmbedder needs a "token" key in the Dict.
    # Shapes:
    #   embedded_schema: (batch_size, num_entities, max_num_entity_tokens, embedding_dim)
    #   schema_mask: (batch_size, num_entities, max_num_entity_tokens)
    #   embedded_utterance: (batch_size, max_utterance_size, embedding_dim)
    #   entity_type_embeddings: (batch_size, num_entities, embedding_dim)
    embedded_schema = self._question_embedder(schema_text, num_wrapping_dims=1)
    schema_mask = util.get_text_field_mask(schema_text, num_wrapping_dims=1).float()
    embedded_utterance = self._question_embedder(utterance)
    utterance_mask = util.get_text_field_mask(utterance).float()

    batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size()
    num_entities = max([len(world.db_context.knowledge_graph.entities) for world in worlds])
    num_question_tokens = utterance['tokens'].size(1)

    # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
    # entity's type id.
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the cpu.
    entity_types, entity_type_dict = self._get_type_vector(worlds, num_entities, embedded_schema.device)
    entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)

    # Compute entity and question word similarity. We tried using cosine distance here, but
    # because this similarity is the main mechanism that the model can use to push apart logit
    # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
    # output range than [-1, 1].
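    # Concretely, the torch.bmm below takes dot products between every entity token
    # embedding and every question token embedding:
    # (batch_size, num_entities * num_entity_tokens, embedding_dim) x
    # (batch_size, embedding_dim, num_question_tokens)
    #   -> (batch_size, num_entities * num_entity_tokens, num_question_tokens).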
    question_entity_similarity = torch.bmm(embedded_schema.view(batch_size,
                                                                num_entities * num_entity_tokens,
                                                                self._embedding_dim),
                                           torch.transpose(embedded_utterance, 1, 2))
    question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                 num_entities,
                                                                 num_entity_tokens,
                                                                 num_question_tokens)
    # (batch_size, num_entities, num_question_tokens)
    question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

    # Variable: linking_scores
    # The entity linking score s(e, i) in Krishnamurthy et al. (2017).
    # linking_features: (batch_size, num_entities, num_question_tokens, num_features)
    linking_features = schema['linking']
    linking_scores = question_entity_similarity_max_score
    feature_scores = self._linking_params(linking_features).squeeze(3)
    linking_scores = linking_scores + feature_scores

    # Variable: linking_probabilities
    # The scores s(e, i) are then fed into a softmax layer over all entities e of the same type.
    # (batch_size, num_question_tokens, num_entities)
    linking_probabilities = self._get_linking_probabilities(worlds,
                                                            linking_scores.transpose(1, 2),
                                                            utterance_mask,
                                                            entity_type_dict)

    # (batch_size, num_entities, num_neighbors) or None
    neighbor_indices = self._get_neighbor_indices(worlds, num_entities, linking_scores.device)

    if self._use_neighbor_similarity_for_linking and neighbor_indices is not None:
        # The Seq2VecEncoder returns the hidden state of its last step as its single output.
        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_schema, schema_mask)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))
        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
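        # The BagOfEmbeddingsEncoder below (with averaged=True) simply averages the
        # masked neighbor embeddings for each entity.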
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())

        # Variable: entity_embeddings
        # Rv in Bogin et al. (2019).
        # A learned embedding for the schema item v, which bases the embedding on the
        # type of v and its schema neighbors only.
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)
    else:
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings)

    # Variable: link_embedding
    # Li in Bogin et al. (2019).
    # An average of entity vectors weighted by the resulting distribution.
    link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)

    # Variable: encoder_input
    # [Wi, Li] in Bogin et al. (2019).
    encoder_input = torch.cat([link_embedding, embedded_utterance], 2)

    # (batch_size, utterance_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))

    # Variable: max_entities_relevance
    # ρ_v = max_i p_link(v | x_i) in Bogin et al. (2019).
    # The maximum probability of v for any word x_i.
    max_entities_relevance = linking_probabilities.max(dim=1)[0]
    entities_relevance = max_entities_relevance.unsqueeze(-1).detach()

    # Variable: graph_initial_embedding
    # h_v(0) in Bogin et al. (2019).
    # An initial embedding conditioned on the relevance score, which is then fed into the GNN.
    graph_initial_embedding = entity_type_embeddings * entities_relevance

    encoder_output_dim = self._encoder.get_output_dim()
    if self._gnn:
        # Variable: entities_graph_encoding
        # φv in Bogin et al. (2019).
        # The final representation of each schema item after L steps.
        entities_graph_encoding = self._get_schema_graph_encoding(worlds, graph_initial_embedding)

        # Variable: graph_link_embedding
        # Lφ,i in Bogin et al. (2019).
        graph_link_embedding = util.weighted_sum(entities_graph_encoding, linking_probabilities)
        encoder_outputs = torch.cat((encoder_outputs, graph_link_embedding), dim=-1)
        encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim()
    else:
        entities_graph_encoding = None

    if self._self_attend:
        # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)
        entities_ff = self._ent2ent_ff(entities_graph_encoding)
        linked_actions_linking_scores = torch.bmm(entities_ff, entities_ff.transpose(1, 2))
    else:
        linked_actions_linking_scores = [None] * batch_size

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                         utterance_mask,
                                                         self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
    initial_score = embedded_utterance.data.new_zeros(batch_size)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] utterance_mask_list = [utterance_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append( RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_utterance, encoder_output_list, utterance_mask_list)) initial_grammar_state = [ self._create_grammar_state( worlds[i], actions[i], linking_scores[i], linked_actions_linking_scores[i], entity_types[i], entities_graph_encoding[i] if entities_graph_encoding is not None else None) for i in range(batch_size) ] initial_sql_state = [ SqlState(actions[i], self.parse_sql_on_decoding) for i in range(batch_size) ] initial_state = GrammarBasedState( batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, sql_state=initial_sql_state, possible_actions=actions, action_entity_mapping=[ w.get_action_entity_mapping() for w in worlds ]) return initial_state @staticmethod def _get_neighbor_indices(worlds: List[SpiderWorld], num_entities: int, device: torch.device) -> torch.LongTensor: """ This method returns the indices of each entity's neighbors. A tensor is accepted as a parameter for copying purposes. Parameters ---------- worlds : ``List[SpiderWorld]`` num_entities : ``int`` tensor : ``torch.Tensor`` Used for copying the constructed list onto the right device. Returns ------- A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded with -1 instead of 0, since 0 is a valid neighbor index. If all the entities in the batch have no neighbors, None will be returned. """ num_neighbors = 0 for world in worlds: for entity in world.db_context.knowledge_graph.entities: if len(world.db_context.knowledge_graph.neighbors[entity] ) > num_neighbors: num_neighbors = len( world.db_context.knowledge_graph.neighbors[entity]) batch_neighbors = [] no_entities_have_neighbors = True for world in worlds: # Each batch instance has its own world, which has a corresponding table. entities = world.db_context.knowledge_graph.entities entity2index = {entity: i for i, entity in enumerate(entities)} entity2neighbors = world.db_context.knowledge_graph.neighbors neighbor_indexes = [] for entity in entities: entity_neighbors = [ entity2index[n] for n in entity2neighbors[entity] ] if entity_neighbors: no_entities_have_neighbors = False # Pad with -1 instead of 0, since 0 represents a neighbor index. padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1) neighbor_indexes.append(padded) neighbor_indexes = pad_sequence_to_length( neighbor_indexes, num_entities, lambda: [-1] * num_neighbors) batch_neighbors.append(neighbor_indexes) # It is possible that none of the entities has any neighbors, since our definition of the # knowledge graph allows it when no entities or numbers were extracted from the question. 
if no_entities_have_neighbors: return None return torch.tensor(batch_neighbors, device=device, dtype=torch.long) def _get_schema_graph_encoding( self, worlds: List[SpiderWorld], initial_graph_embeddings: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: max_num_entities = max([ len(world.db_context.knowledge_graph.entities) for world in worlds ]) batch_size = initial_graph_embeddings.size(0) graph_data_list = [] for batch_index, world in enumerate(worlds): x = initial_graph_embeddings[batch_index] adj_list = self._get_graph_adj_lists( initial_graph_embeddings.device, world, initial_graph_embeddings.size(1) - 1) graph_data = Data(x) for i, l in enumerate(adj_list): graph_data[f'edge_index_{i}'] = l graph_data_list.append(graph_data) batch = Batch.from_data_list(graph_data_list) gnn_output = self._gnn(batch.x, [ batch[f'edge_index_{i}'] for i in range(self._gnn.num_edge_types) ]) num_nodes = max_num_entities gnn_output = gnn_output.view(batch_size, num_nodes, -1) # entities_encodings = gnn_output entities_encodings = gnn_output[:, :max_num_entities] # global_node_encodings = gnn_output[:, max_num_entities] return entities_encodings @staticmethod def _get_graph_adj_lists(device, world, global_entity_id, global_node=False): entity_mapping = {} for i, entity in enumerate(world.db_context.knowledge_graph.entities): entity_mapping[entity] = i entity_mapping['_global_'] = global_entity_id adj_list_own = [] # column--table adj_list_link = [] # table->table / foreign->primary adj_list_linked = [] # table<-table / foreign<-primary adj_list_global = [] # node->global # TODO: Prepare in advance? for key, neighbors in world.db_context.knowledge_graph.neighbors.items( ): idx_source = entity_mapping[key] for n_key in neighbors: idx_target = entity_mapping[n_key] if n_key.startswith("table") or key.startswith("table"): adj_list_own.append((idx_source, idx_target)) elif n_key.startswith("string") or key.startswith("string"): adj_list_own.append((idx_source, idx_target)) elif key.startswith("column:foreign"): adj_list_link.append((idx_source, idx_target)) src_table_key = f"table:{key.split(':')[2]}" tgt_table_key = f"table:{n_key.split(':')[2]}" idx_source_table = entity_mapping[src_table_key] idx_target_table = entity_mapping[tgt_table_key] adj_list_link.append((idx_source_table, idx_target_table)) elif n_key.startswith("column:foreign"): adj_list_linked.append((idx_source, idx_target)) src_table_key = f"table:{key.split(':')[2]}" tgt_table_key = f"table:{n_key.split(':')[2]}" idx_source_table = entity_mapping[src_table_key] idx_target_table = entity_mapping[tgt_table_key] adj_list_linked.append( (idx_source_table, idx_target_table)) else: assert False adj_list_global.append((idx_source, entity_mapping['_global_'])) all_adj_types = [adj_list_own, adj_list_link, adj_list_linked] if global_node: all_adj_types.append(adj_list_global) return [ torch.tensor(l, device=device, dtype=torch.long).transpose(0, 1) if l else torch.tensor(l, device=device, dtype=torch.long) for l in all_adj_types ] def _create_grammar_state( self, world: SpiderWorld, possible_actions: List[ProductionRule], linking_scores: torch.Tensor, linked_actions_linking_scores: torch.Tensor, entity_types: torch.Tensor, entity_graph_encoding: torch.Tensor) -> GrammarStatelet: action_map = {} for action_index, action in enumerate(possible_actions): action_string = action[0] action_map[action_string] = action_index valid_actions = world.valid_actions entity_map = {} entities = world.entities_names for entity_index, entity in enumerate(entities): 
entity_map[entity] = entity_index translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {} for key, action_strings in valid_actions.items(): translated_valid_actions[key] = {} # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid # productions of that non-terminal. We'll first split those productions by global vs. # linked action. action_indices = [ action_map[action_string] for action_string in action_strings ] production_rule_arrays = [(possible_actions[index], index) for index in action_indices] global_actions = [] linked_actions = [] for production_rule_array, action_index in production_rule_arrays: if production_rule_array[1]: global_actions.append( (production_rule_array[2], action_index)) else: linked_actions.append( (production_rule_array[0], action_index)) if global_actions: global_action_tensors, global_action_ids = zip(*global_actions) global_action_tensor = torch.cat( global_action_tensors, dim=0).to(global_action_tensors[0].device).long() global_input_embeddings = self._action_embedder( global_action_tensor) global_output_embeddings = self._output_action_embedder( global_action_tensor) translated_valid_actions[key]['global'] = ( global_input_embeddings, global_output_embeddings, list(global_action_ids)) if linked_actions: linked_rules, linked_action_ids = zip(*linked_actions) entities = [ rule.split(' -> ')[1].strip('[]\"') for rule in linked_rules ] entity_ids = [entity_map[entity] for entity in entities] entity_linking_scores = linking_scores[entity_ids] if linked_actions_linking_scores is not None: entity_action_linking_scores = linked_actions_linking_scores[ entity_ids] if not self._decoder_use_graph_entities: entity_type_tensor = entity_types[entity_ids] entity_type_embeddings = ( self._entity_type_decoder_embedding( entity_type_tensor).to( entity_types.device).float()) else: entity_type_embeddings = entity_graph_encoding.index_select( dim=0, index=torch.tensor( entity_ids, device=entity_graph_encoding.device)) if self._self_attend: translated_valid_actions[key]['linked'] = ( entity_linking_scores, entity_type_embeddings, list(linked_action_ids), entity_action_linking_scores) else: translated_valid_actions[key]['linked'] = ( entity_linking_scores, entity_type_embeddings, list(linked_action_ids)) return GrammarStatelet(['statement'], translated_valid_actions, self.is_nonterminal) @staticmethod def is_nonterminal(token: str): if token[0] == '"' and token[-1] == '"': return False return True def _get_linking_probabilities( self, worlds: List[SpiderWorld], linking_scores: torch.FloatTensor, question_mask: torch.LongTensor, entity_type_dict: Dict[int, int]) -> torch.FloatTensor: """ Produces the probability of an entity given a question word and type. The logic below separates the entities by type since the softmax normalization term sums over entities of a single type. Parameters ---------- worlds : ``List[WikiTablesWorld]`` linking_scores : ``torch.FloatTensor`` Has shape (batch_size, num_question_tokens, num_entities). question_mask: ``torch.LongTensor`` Has shape (batch_size, num_question_tokens). entity_type_dict : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. Returns ------- batch_probabilities : ``torch.FloatTensor`` Has shape ``(batch_size, num_question_tokens, num_entities)``. Contains all the probabilities for an entity given a question word. 
""" _, num_question_tokens, num_entities = linking_scores.size() batch_probabilities = [] for batch_index, world in enumerate(worlds): all_probabilities = [] num_entities_in_instance = 0 # NOTE: The way that we're doing this here relies on the fact that entities are # implicitly sorted by their types when we sort them by name, and that numbers come # before "date_column:", followed by "number_column:", "string:", and "string_column:". # This is not a great assumption, and could easily break later, but it should work for now. for type_index in range(self._num_entity_types): # This index of 0 is for the null entity for each type, representing the case where a # word doesn't link to any entity. entity_indices = [0] entities = world.db_context.knowledge_graph.entities for entity_index, _ in enumerate(entities): if entity_type_dict[batch_index * num_entities + entity_index] == type_index: entity_indices.append(entity_index) if len(entity_indices) == 1: # No entities of this type; move along... continue # We're subtracting one here because of the null entity we added above. num_entities_in_instance += len(entity_indices) - 1 # We separate the scores by type, since normalization is done per type. There's an # extra "null" entity per type, also, so we have `num_entities_per_type + 1`. We're # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_, # so we get back something of shape (num_question_tokens,) for each index we're # selecting. All of the selected indices together then make a tensor of shape # (num_question_tokens, num_entities_per_type + 1). indices = linking_scores.new_tensor(entity_indices, dtype=torch.long) entity_scores = linking_scores[batch_index].index_select( 1, indices) # We used index 0 for the null entity, so this will actually have some values in it. # But we want the null entity's score to be 0, so we set that here. entity_scores[:, 0] = 0 # No need for a mask here, as this is done per batch instance, with no padding. type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1) all_probabilities.append(type_probabilities[:, 1:]) # We need to add padding here if we don't have the right number of entities. if num_entities_in_instance != num_entities: zeros = linking_scores.new_zeros( num_question_tokens, num_entities - num_entities_in_instance) all_probabilities.append(zeros) # (num_question_tokens, num_entities) probabilities = torch.cat(all_probabilities, dim=1) batch_probabilities.append(probabilities) batch_probabilities = torch.stack(batch_probabilities, dim=0) return batch_probabilities * question_mask.unsqueeze(-1).float() @staticmethod def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int: # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something. # Check if target is big enough to cover prediction (including start/end symbols) if len(predicted) > targets.size(0): return 0 predicted_tensor = targets.new_tensor(predicted) targets_trimmed = targets[:len(predicted)] # Return 1 if the predicted sequence is anywhere in the list of targets. 
    return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=0)[0]).item()

@staticmethod
def _query_difficulty(targets: torch.LongTensor, action_mapping, batch_index):
    number_tables = len([action_mapping[(batch_index, int(a))]
                         for a in targets
                         if a >= 0 and action_mapping[(batch_index, int(a))].startswith('table_name')])
    return number_tables > 1

@overrides
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
    return {
        '_match/exact_match': self._exact_match.get_metric(reset),
        'sql_match': self._sql_evaluator_match.get_metric(reset),
        '_others/action_similarity': self._action_similarity.get_metric(reset),
        '_match/match_single': self._acc_single.get_metric(reset),
        '_match/match_hard': self._acc_multi.get_metric(reset),
        'beam_hit': self._beam_hit.get_metric(reset)
    }

@staticmethod
def _get_type_vector(worlds: List[SpiderWorld],
                     num_entities: int,
                     device) -> Tuple[torch.LongTensor, Dict[int, int]]:
    """
    Produces the encoding for each entity's type. In addition, a map from a flattened entity
    index to type is returned, to combine entity type operations into one method.

    Parameters
    ----------
    worlds : ``List[SpiderWorld]``
    num_entities : ``int``
    device : ``torch.device``
        The device on which the constructed tensor will be placed.

    Returns
    -------
    A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``, where each entry is the
    entity's type id.
    entity_types : ``Dict[int, int]``
        This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
    """
    entity_types = {}
    batch_types = []

    column_type_ids = ['boolean', 'foreign', 'number', 'others', 'primary', 'text', 'time']

    for batch_index, world in enumerate(worlds):
        types = []

        for entity_index, entity in enumerate(world.db_context.knowledge_graph.entities):
            parts = entity.split(':')
            entity_main_type = parts[0]

            if entity_main_type == 'column':
                column_type = parts[1]
                entity_type = column_type_ids.index(column_type)
            elif entity_main_type == 'string':
                # cell value
                entity_type = len(column_type_ids)
            elif entity_main_type == 'table':
                entity_type = len(column_type_ids) + 1
            else:
                raise Exception("Unknown entity")
            types.append(entity_type)

            # For easier lookups later, we're actually using a _flattened_ version
            # of (batch_index, entity_index) for the key, because this is how the
            # linking scores are stored.
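            # e.g. with num_entities = 10, entity 3 of batch instance 2 gets key 2 * 10 + 3 = 23.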
flattened_entity_index = batch_index * num_entities + entity_index entity_types[flattened_entity_index] = entity_type padded = pad_sequence_to_length(types, num_entities, lambda: 0) batch_types.append(padded) return torch.tensor(batch_types, dtype=torch.long, device=device), entity_types def _compute_validation_outputs(self, actions: List[List[ProductionRuleArray]], best_final_states: Mapping[ int, Sequence[GrammarBasedState]], world: List[SpiderWorld], target_list: List[List[str]], outputs: Dict[str, Any]) -> None: batch_size = len(actions) outputs['predicted_sql_query'] = [] action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] for i in range(batch_size): # gold sql exactly as given original_gold_sql_query = ' '.join( world[i].get_query_without_table_hints()) if i not in best_final_states: self._exact_match(0) self._action_similarity(0) self._sql_evaluator_match(0) self._acc_multi(0) self._acc_single(0) outputs['predicted_sql_query'].append('') continue best_action_indices = best_final_states[i][0].action_history[0] action_strings = [ action_mapping[(i, action_index)] for action_index in best_action_indices ] predicted_sql_query = action_sequence_to_sql(action_strings, add_table_names=True) outputs['predicted_sql_query'].append( sqlparse.format(predicted_sql_query, reindent=False)) if target_list is not None: targets = target_list[i].data sequence_in_targets = self._action_history_match( best_action_indices, targets) self._exact_match(sequence_in_targets) sql_evaluator_match = self._evaluate_func( original_gold_sql_query, predicted_sql_query, world[i].db_id) self._sql_evaluator_match(sql_evaluator_match) similarity = difflib.SequenceMatcher(None, best_action_indices, targets) self._action_similarity(similarity.ratio()) difficulty = self._query_difficulty(targets, action_mapping, i) if difficulty: self._acc_multi(sql_evaluator_match) else: self._acc_single(sql_evaluator_match) beam_hit = False for pos, final_state in enumerate(best_final_states[i]): action_indices = final_state.action_history[0] action_strings = [ action_mapping[(i, action_index)] for action_index in action_indices ] candidate_sql_query = action_sequence_to_sql( action_strings, add_table_names=True) if target_list is not None: correct = self._evaluate_func(original_gold_sql_query, candidate_sql_query, world[i].db_id) if correct: beam_hit = True self._beam_hit(beam_hit)
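# --- Illustrative sketch (not part of the original model) ---
# A minimal rendition of the per-type linking softmax in _get_linking_probabilities
# above: the scores of all entities of one type are normalized together with a
# score-0 "null" entity, so a question word can link to none of them. The tensors
# below are toy stand-ins for one batch instance and one entity type.
import torch

type_scores = torch.tensor([[2.0, 0.5]])          # (num_question_tokens, num_entities_of_type)
null_score = torch.zeros(type_scores.size(0), 1)  # the null entity is pinned to score 0
probs = torch.nn.functional.softmax(torch.cat([null_score, type_scores], dim=1), dim=1)
# probs[:, 0]: probability that the word links to no entity of this type;
# probs[:, 1:]: the per-entity probabilities the model keeps (cf. type_probabilities[:, 1:]).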
class BaseRollinRolloutDecoder(SeqDecoder):
    """
    A base decoder with a rollin and rollout formulation that will be used to define the other
    decoders, such as the autoregressive decoder, reinforce decoder, SEARNN decoder, etc.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    decoder_net : ``DecoderNet``, required
        Module that contains the implementation of the neural network for decoding output elements.
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_embedder : ``Embedding``
        Embedder for target tokens.
    target_namespace : ``str``, optional (default = 'target_tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    beam_size : ``int``, optional (default = 4)
        Width of the beam for beam search.
    tensor_based_metric : ``Metric``, optional (default = None)
        A metric to track on validation data that takes raw tensors when it is called. This metric
        must accept two arguments when called: a batched tensor of predicted token indices, and a
        batched tensor of gold token indices.
    token_based_metric : ``Metric``, optional (default = None)
        A metric to track on validation data that takes lists of lists of tokens as input. This
        metric must accept two arguments when called, both of type `List[List[str]]`. The first is
        a predicted sequence for each item in the batch and the second is a gold sequence for each
        item in the batch.
    scheduled_sampling_ratio : ``float``, optional (default = 0)
        Defines the ratio between teacher-forced training and using the model's own predictions.
        If it is zero (teacher forcing only) and ``decoder_net`` supports parallel decoding, we get
        the output predictions in a single forward pass of the ``decoder_net``.
    """

    default_implementation = "auto_regressive_seq_decoder"

    def __init__(self,
                 vocab: Vocabulary,
                 max_decoding_steps: int,
                 decoder_net: DecoderNet,
                 target_embedder: Embedding,
                 loss_criterion: LossCriterion,
                 generation_batch_size: int = 200,
                 use_in_seq2seq_mode: bool = False,
                 target_namespace: str = "tokens",
                 beam_size: int = None,
                 scheduled_sampling_ratio: float = 0.0,
                 scheduled_sampling_k: int = 100,
                 scheduled_sampling_type: str = 'uniform',
                 rollin_mode: str = 'mixed',
                 rollout_mode: str = 'learned',
                 dropout: float = None,
                 start_token: str = START_SYMBOL,
                 end_token: str = END_SYMBOL,
                 num_decoder_layers: int = 1,
                 mask_pad_and_oov: bool = False,
                 tie_output_embedding: bool = False,
                 rollout_mixing_prob: float = 0.5,
                 use_bleu: bool = False,
                 use_hamming: bool = False,
                 sample_rollouts: bool = False,
                 beam_search_sampling_temperature: float = 1.,
                 top_k: int = 0,
                 top_p: float = 0,
                 tensor_based_metric: Metric = None,
                 tensor_based_metric_mask: Metric = None,
                 token_based_metric: Metric = None,
                 eval_beam_size: int = 1,
                ) -> None:
        super().__init__(target_embedder)

        self.current_device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
        self._vocab = vocab
        self._seq2seq_mode = use_in_seq2seq_mode

        # Decodes the sequence of encoded hidden states into a new sequence of hidden states.
self._max_decoding_steps = max_decoding_steps self._generation_batch_size = generation_batch_size self._decoder_net = decoder_net self._target_namespace = target_namespace # TODO #4 (Kushal): Maybe make them modules so that we can add more of these later. # TODO #8 #7 (Kushal): Rename "mixed" rollin mode to "scheduled sampling". self._rollin_mode = rollin_mode self._rollout_mode = rollout_mode self._scheduled_sampling_ratio = scheduled_sampling_ratio self._scheduled_sampling_k = scheduled_sampling_k self._scheduled_sampling_type = scheduled_sampling_type self._sample_rollouts = sample_rollouts self._mask_pad_and_oov = mask_pad_and_oov self._rollout_mixing_prob = rollout_mixing_prob # At prediction time, we use a beam search to find the most likely sequence of target tokens. # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence. self._start_index = self._vocab.get_token_index(start_token, self._target_namespace) self._end_index = self._vocab.get_token_index(end_token, self._target_namespace) self._padding_index = self._vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._target_namespace) self._oov_index = self._vocab.get_token_index(DEFAULT_OOV_TOKEN, self._target_namespace) if self._mask_pad_and_oov: self._vocab_mask = torch.ones(self._vocab.get_vocab_size(self._target_namespace), device=self.current_device) \ .scatter(0, torch.tensor([self._padding_index, self._oov_index, self._start_index], device=self.current_device), 0) if use_bleu: pad_index = self._vocab.get_token_index(self._vocab._padding_token, self._target_namespace) # pylint: disable=protected-access self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index}) else: self._bleu = None if use_hamming: self._hamming = HammingLoss() else: self._hamming = None # At prediction time, we use a beam search to find the most likely sequence of target tokens. beam_size = beam_size or 1 # TODO(Kushal): Pass in the arguments for sampled. Also, make sure you do not sample in case of Seq2Seq models. self._beam_search = SampledBeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size, temperature=beam_search_sampling_temperature) self._num_classes = self._vocab.get_vocab_size(self._target_namespace) if self.target_embedder.get_output_dim() != self._decoder_net.target_embedding_dim: raise ConfigurationError( "Target Embedder output_dim doesn't match decoder module's input." + f"target_embedder_dim: {self.target_embedder.get_output_dim()}, " + f"decoder input dim: {self._decoder_net.target_embedding_dim}." ) self._ss_ratio = Average() if dropout: self._dropout = torch.nn.Dropout(dropout) else: self._dropout = lambda x: x self.training_iteration = 0 # We project the hidden state from the decoder into the output vocabulary space # in order to get log probabilities of each target token, at each time step. self._output_projection_layer = Linear(self._decoder_net.get_output_dim(), self._num_classes) if tie_output_embedding: if self._output_projection_layer.weight.shape != self.target_embedder.weight.shape: raise ConfigurationError( f"Can't tie embeddings with output linear layer, due to shape mismatch. 
" + f"{self._output_projection_layer.weight.shape} and {self.target_embedder.weight.shape}" ) self._output_projection_layer.weight = self.target_embedder.weight self._loss_criterion = loss_criterion self._top_k = top_k self._top_p = top_p self._eval_beam_size = eval_beam_size self._mle_loss = MaximumLikelihoodLossCriterion() self._perplexity = Perplexity() # These metrics will be updated during training and validation self._tensor_based_metric = tensor_based_metric self._token_based_metric = token_based_metric self._tensor_based_metric_mask = tensor_based_metric_mask self._decode_tokens = partial(decode_tokens, vocab=self._vocab, start_index=self._start_index, end_index=self._end_index) def get_output_dim(self): return self._decoder_net.get_output_dim() def rollin_policy(self, timestep: int, last_predictions: torch.LongTensor, target_tokens: Dict[str, torch.Tensor] = None, rollin_mode = None) -> torch.LongTensor: """ Roll-in policy to use. This takes in targets, timestep and last_predictions, and decide which to use for taking next step i.e., generating next token. What to do is decided by rolling mode. Options are - teacher_forcing, - learned, - mixed, By default the mode is mixed with scheduled_sampling_ratio=0.0. This defaults to teacher_forcing. You can also explicitly run with teacher_forcing mode. Arguments: timestep {int} -- Current timestep decides which target token to use. In case of teacher_forcing this is usually {t-1}^{th} timestep for predicting t^{th} token. last_predictions {torch.LongTensor} -- {t-1}^th token predicted by the model. Keyword Arguments: targets {torch.LongTensor} -- Targets value if it is available. This will be available in training mode but not in inference mode. (default: {None}) rollin_mode {str} -- Rollin mode. Options are teacher_forcing, learned, scheduled-sampling (default: {'teacher_forcing'}) Returns: torch.LongTensor -- The method returns input token for predicting next token. """ rollin_mode = rollin_mode or self._rollin_mode # For first timestep, you are passing start token, so don't do anything smart. if (timestep == 0 or # If no targets, no way to do teacher_forcing, so use your own predictions. target_tokens is None or rollin_mode == 'learned'): # shape: (batch_size,) return last_predictions targets = util.get_token_ids_from_text_field_tensors(target_tokens) if rollin_mode == 'teacher_forcing': # shape: (batch_size,) input_choices = targets[:, timestep] elif rollin_mode == 'mixed': if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio: # Use gold tokens at test time and at a rate of 1 - self._scheduled_sampling_ratio # during training. # shape: (batch_size,) input_choices = last_predictions else: # shape: (batch_size,) input_choices = targets[:, timestep] else: raise ConfigurationError(f"invalid configuration for rollin policy: {rollin_mode}") return input_choices def copy_reference_policy(self, timestep, last_predictions: torch.LongTensor, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor], ) -> torch.FloatTensor: targets = util.get_token_ids_from_text_field_tensors(target_tokens) seq_len = targets.size(1) batch_size = last_predictions.shape[0] if seq_len > timestep + 1: # + 1 because timestep is an index, indexed at 0. # As we might be overriding the next/predicted token/ # We have to use the value corresponding to {t+1}^{th} # timestep. target_at_timesteps = targets[:, timestep + 1] else: # We have overshot the seq_len, so just repeat the # last token which is either _end_token or _pad_token. 
target_at_timesteps = targets[:, -1] # TODO: Add support to allow other types of reference policies. # target_logits: (batch_size, num_classes). # This tensor has 0 at targets and (near) -inf at other places. target_logits = (target_at_timesteps.new_zeros((batch_size, self._num_classes)) + 1e-45) \ .scatter_(dim=1, index=target_at_timesteps.unsqueeze(1), value=1.0).log() return target_logits, state def oracle_reference_policy(self, timestep: int, last_predictions: torch.LongTensor, state: Dict[str, torch.Tensor], token_to_idx: Dict[str, int], idx_to_token: Dict[int, str], ) -> torch.FloatTensor: # TODO(Kushal): #5 This is a temporary fix. Ideally, we should have # an individual oracle for this which is different from cost function. assert hasattr(self._loss_criterion, "_rollout_cost_function") and \ hasattr(self._loss_criterion._rollout_cost_function, "_oracle"), \ "For oracle reference policy, we will need noisy oracle loss function" start_time = time.time() target_logits, state = self._loss_criterion \ ._rollout_cost_function \ ._oracle \ .reference_step_rollout( step=timestep, last_predictions=last_predictions, state=state, token_to_idx=token_to_idx, idx_to_token=idx_to_token) end_time = time.time() logger.info(f"Oracle Reference time: {end_time - start_time} s") return target_logits, state def rollout_policy(self, timestep: int, last_predictions: torch.LongTensor, state: Dict[str, torch.Tensor], logits: torch.FloatTensor, reference_policy:ReferencePolicyType, rollout_mode: str = None, rollout_mixing_func: RolloutMixingProbFuncType = None, ) -> torch.FloatTensor: """Rollout policy to use. This takes in predicted logits at timestep {t}^{th} and depending upon the rollout_mode replaces some of the predictions with targets. The options for rollout mode are: - learned, - reference, - mixed. Arguments: timestep {int} -- Current timestep decides which target token to use. In case of reference this is usually {t-1}^{th} timestep for predicting t^{th} token. logits {torch.LongTensor} -- Logits generated by the model for {t}^{th} timestep. (batch_size, num_classes). Keyword Arguments: targets {torch.LongTensor} -- Targets value if it is available. This will be available in training mode but not in inference mode. (default: {None}) rollout_mode {str} -- Rollout mode: Options are: learned, reference, mixed. (default: {'learned'}) rollout_mixing_func {RolloutMixingProbFuncType} -- Function to get mask to choose predicted logits vs targets in case of mixed rollouts. (default: {0.5}) Returns: torch.LongTensor -- The method returns logits with rollout policy applied. """ rollout_mode = rollout_mode or self._rollout_mode output_logits = logits if rollout_mode == 'learned': # For learned rollout policy, just return the same logits. return output_logits, state target_logits, state = reference_policy(timestep, last_predictions, state) batch_size = logits.size(0) if rollout_mode == 'reference': output_logits += target_logits elif rollout_mode == 'mixed': # Based on the mask (Value=1), copy target values. if rollout_mixing_func is not None: rollout_mixing_prob_tensor = rollout_mixing_func() else: # This returns a (batch_size, num_classes) boolean map where the rows are either all zeros or all ones. 
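# e.g. with batch_size = 3 and rollout_mixing_prob = 0.5, torch.bernoulli below might
# return [1., 0., 1.], switching instances 0 and 2 to the reference distribution for
# this step while instance 1 keeps the learned logits.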
rollout_mixing_prob_tensor = torch.bernoulli(torch.ones(batch_size) * self._rollout_mixing_prob) rollout_mixing_mask = rollout_mixing_prob_tensor \ .unsqueeze(1) \ .expand(logits.shape) \ .to(self.current_device) # The target_logits ranges from (-inf , 0), so, by adding those to logits, # we turn the values that are not target tokens to -inf, hence making the distribution # skew towards the target. output_logits += rollout_mixing_mask * target_logits else: raise ConfigurationError(f"Incompatible rollout mode: {rollout_mode}") return output_logits, state def take_step(self, timestep: int, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], rollin_policy:RollinPolicyType=default_rollin_policy, rollout_policy:RolloutPolicyType=default_rollout_policy, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take a decoding step. This is called by the beam search class. Parameters ---------- last_predictions : ``torch.Tensor`` A tensor of shape ``(group_size,)``, which gives the indices of the predictions during the last time step. state : ``Dict[str, torch.Tensor]`` A dictionary of tensors that contain the current state information needed to predict the next step, which includes the encoder outputs, the source mask, and the decoder hidden state and context. Each of these tensors has shape ``(group_size, *)``, where ``*`` can be any other number of dimensions. Returns ------- Tuple[torch.Tensor, Dict[str, torch.Tensor]] A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` is a tensor of shape ``(group_size, num_classes)`` containing the predicted log probability of each class for the next step, for each item in the group, while ``updated_state`` is a dictionary of tensors containing the encoder outputs, source mask, and updated decoder hidden state and context. Notes ----- We treat the inputs as a batch, even though ``group_size`` is not necessarily equal to ``batch_size``, since the group may contain multiple states for each source sentence in the batch. """ input_choices = rollin_policy(timestep, last_predictions) # State timestep which we might in _prepare_output_projections. state['timestep'] = timestep # shape: (group_size, num_classes) class_logits, state = self._prepare_output_projections( last_predictions=input_choices, state=state) if not self.training and self._mask_pad_and_oov: # This implementation is copied from masked_log_softmax from allennlp.nn.util. mask = (self._vocab_mask.expand(class_logits.shape) + 1e-45).log() # shape: (group_size, num_classes) class_logits = class_logits + mask # shape: (group_size, num_classes) class_logits, state = rollout_policy(timestep, last_predictions, state, class_logits) class_logits = top_k_top_p_filtering(class_logits, self._top_k, self._top_p) return class_logits, state @overrides def forward(self, # type: ignore encoder_out: Dict[str, torch.LongTensor] = {}, target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Make foward pass with decoder logic for producing the entire target sequence. Parameters ---------- target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None) Output of `Textfield.as_array()` applied on target `TextField`. We assume that the target tokens are also represented as a `TextField`. source_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None) The output of `TextField.as_array()` applied on the source `TextField`. This will be passed through a `TextFieldEmbedder` and then through an encoder. 
Returns ------- Dict[str, torch.Tensor] """ output_dict: Dict[str, torch.Tensor] = {} state: Dict[str, torch.Tensor] = {} decoder_init_state: Dict[str, torch.Tensor] = {} state.update(copy.copy(encoder_out)) # In Seq2Seq setting, we will encode the source sequence, # and init the state object with encoder output and decoder # cell will use these encoder outputs for attention/initing # the decoder states. if self._seq2seq_mode: decoder_init_state = \ self._decoder_net.init_decoder_state(state) state.update(decoder_init_state) # Initialize target predictions with the start index. # shape: (batch_size,) start_predictions: torch.LongTensor = \ self._get_start_predictions(state, target_tokens, self._generation_batch_size) # In case we have target_tokens, roll-in and roll-out # only till those many steps, otherwise we roll-out for # `self._max_decoding_steps`. if target_tokens: # shape: (batch_size, max_target_sequence_length) targets: torch.LongTensor = \ util.get_token_ids_from_text_field_tensors(target_tokens) _, target_sequence_length = targets.size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. num_decoding_steps: int = target_sequence_length - 1 else: num_decoding_steps: int = self._max_decoding_steps if target_tokens: decoder_output_dict, rollin_dict, rollout_dict_iter = \ self._forward_loop( state=state, start_predictions=start_predictions, num_decoding_steps=num_decoding_steps, target_tokens=target_tokens) output_dict.update(decoder_output_dict) predictions = decoder_output_dict['predictions'] predicted_tokens = self._decode_tokens(predictions, vocab_namespace=self._target_namespace, truncate=True) output_dict["decoded_predictions"] = predicted_tokens decoded_targets = self._decode_tokens(targets, vocab_namespace=self._target_namespace, truncate=True) output_dict["decoded_targets"] = decoded_targets output_dict.update(self._loss_criterion( rollin_output_dict=rollin_dict, rollout_output_dict_iter=rollout_dict_iter, state=state, target_tokens=target_tokens)) mle_loss_output = self._mle_loss( rollin_output_dict=rollin_dict, rollout_output_dict_iter=rollout_dict_iter, state=state, target_tokens=target_tokens) mle_loss = mle_loss_output['loss'] self._perplexity(mle_loss) if not self.training: # While validating or testing we need to roll out the learned policy and the output # of this rollout is used to compute the secondary metrics # like BLEU. state: Dict[str, torch.Tensor] = {} state.update(copy.copy(encoder_out)) state.update(decoder_init_state) rollout_output_dict = self.rollout(state, start_predictions, rollout_steps=num_decoding_steps, rollout_mode='learned', sampled=self._sample_rollouts, beam_size=self._eval_beam_size, # TODO #6 (Kushal): Add a reason why truncate_at_end_all is False here. 
                                               truncate_at_end_all=False)

            output_dict.update(rollout_output_dict)

            predictions = decoder_output_dict['predictions']
            predicted_tokens = self._decode_tokens(predictions,
                                                   vocab_namespace=self._target_namespace,
                                                   truncate=True)
            output_dict["decoded_predictions"] = predicted_tokens
            decoded_predictions = [predictions[0]
                                   for predictions in output_dict["decoded_predictions"]]

            # shape (predictions): (batch_size, beam_size, num_decoding_steps)
            predictions = rollout_output_dict['predictions']

            # shape (best_predictions): (batch_size, num_decoding_steps)
            best_predictions = predictions[:, 0, :]

            if target_tokens:
                targets = util.get_token_ids_from_text_field_tensors(target_tokens)
                target_mask = util.get_text_field_mask(target_tokens)
                decoded_targets = self._decode_tokens(targets,
                                                      vocab_namespace=self._target_namespace,
                                                      truncate=True)

            # TODO #3 (Kushal): Maybe abstract out these losses and use loss_metric like AllenNLP uses.
            if self._bleu and target_tokens:
                self._bleu(best_predictions, targets)

            if self._hamming and target_tokens:
                self._hamming(best_predictions, targets, target_mask)

            if self._tensor_based_metric is not None:
                self._tensor_based_metric(  # type: ignore
                    predictions=best_predictions,
                    gold_targets=targets,
                )

            if self._tensor_based_metric_mask is not None:
                self._tensor_based_metric_mask(  # type: ignore
                    predictions=best_predictions,
                    gold_targets=targets,
                    mask=~target_mask,
                )

            if self._token_based_metric is not None:
                self._token_based_metric(  # type: ignore
                    predictions=decoded_predictions,
                    gold_targets=decoded_targets,
                )
        return output_dict

    @overrides
    def post_process(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method trims the output predictions to the first end symbol,
        replaces indices with corresponding tokens, and adds a field called
        ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        all_predicted_tokens = self._decode_tokens(predicted_indices,
                                                   vocab_namespace=self._target_namespace,
                                                   truncate=True)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _apply_scheduled_sampling(self):
        if not self.training:
            raise RuntimeError("Scheduled Sampling can only be applied during training.")

        k = self._scheduled_sampling_k
        i = self.training_iteration
        if self._scheduled_sampling_type == 'uniform':
            # Keep the fixed scheduled sampling ratio set by the config.
            pass
        elif self._scheduled_sampling_type == 'linear':
            self._scheduled_sampling_ratio = i / float(k)
        elif self._scheduled_sampling_type == 'inverse_sigmoid':
            self._scheduled_sampling_ratio = 1 - k / (k + math.exp(i / k))
        else:
            raise ConfigurationError(f"{self._scheduled_sampling_type} is not a valid scheduled sampling type.")

        self._ss_ratio(self._scheduled_sampling_ratio)

    def rollin(self,
               state: Dict[str, torch.Tensor],
               start_predictions: torch.LongTensor,
               rollin_steps: int,
               target_tokens: Dict[str, torch.LongTensor] = None,
               beam_size: int = 1,
               per_node_beam_size: int = None,
               sampled: bool = False,
               truncate_at_end_all: bool = False,
               rollin_mode: str = None,
              ):
        self.training_iteration += 1

        # We cannot use a class attribute as a default argument value, so the
        # default is None and we fall back to num_classes here.
        per_node_beam_size: int = per_node_beam_size or self._num_classes

        if self.training:
            self._apply_scheduled_sampling()

        rollin_policy = partial(self.rollin_policy,
                                target_tokens=target_tokens,
                                rollin_mode=rollin_mode)

        rolling_policy = partial(self.take_step,
                                 rollin_policy=rollin_policy)

        # shape (step_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        # shape (logits): (batch_size, beam_size, num_decoding_steps, num_classes)
        step_predictions, log_probabilities, logits = \
            self._beam_search.search(start_predictions,
                                     state,
                                     rolling_policy,
                                     max_steps=rollin_steps,
                                     beam_size=beam_size,
                                     per_node_beam_size=per_node_beam_size,
                                     sampled=sampled,
                                     truncate_at_end_all=truncate_at_end_all)

        logits = torch.cat(logits, dim=2)

        batch_size, beam_size, _ = step_predictions.shape
        start_prediction_length = start_predictions.size(0)
        step_predictions = torch.cat([start_predictions.unsqueeze(1)
                                        .expand(batch_size, beam_size)
                                        .reshape(batch_size, beam_size, 1),
                                      step_predictions],
                                     dim=-1)

        output_dict = {
            "predictions": step_predictions,
            "logits": logits,
            "class_log_probabilities": log_probabilities,
        }
        return output_dict

    def rollin_parallel(self,
                        state: Dict[str, torch.Tensor],
                        start_predictions: torch.LongTensor,
                        rollin_steps: int,
                        target_tokens: Dict[str, torch.LongTensor] = None,
                        beam_size: int = 1,
                        per_node_beam_size: int = None,
                        sampled: bool = False,
                        truncate_at_end_all: bool = False,
                        rollin_mode: str = None,
                       ):
        assert self._decoder_net.decodes_parallel, \
            "Parallel rollin is only applicable for transformer-style decoders " + \
            "that decode the whole sequence in parallel."

        assert not rollin_mode or rollin_mode == "learned", \
            "Parallel decoding only works when following the " + \
            "teacher forcing rollin policy (rollin_mode='learned')."

        assert self._scheduled_sampling_ratio == 0, \
            "For learned rollin mode, the scheduled sampling ratio should always be 0."

        self.training_iteration += 1

        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (batch_size, max_target_sequence_length)
        targets = util.get_token_ids_from_text_field_tensors(target_tokens)

        # Prepare target embeddings; they are used as the gold embeddings
        # during decoder training.
        # shape: (batch_size, max_target_sequence_length, embedding_dim)
        target_embedding = self.target_embedder(targets)

        # shape: (batch_size, max_target_batch_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)

        _, decoder_output = self._decoder_net(
            previous_state=state,
            previous_steps_predictions=target_embedding[:, :-1, :],
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_mask=target_mask[:, :-1],
        )

        # shape: (group_size, max_target_sequence_length, num_classes)
        logits = self._output_projection_layer(decoder_output)

        # Unsqueeze the logits to add a beam-size dimension.
        logits = logits.unsqueeze(dim=1)

        log_probabilities, step_predictions = torch.max(logits, dim=-1)

        return {
            "predictions": step_predictions,
            "logits": logits,
            "class_log_probabilities": log_probabilities,
        }

    def rollout(self,
                state: Dict[str, torch.Tensor],
                start_predictions: torch.LongTensor,
                rollout_steps: int,
                beam_size: int = None,
                per_node_beam_size: int = None,
                target_tokens: Dict[str, torch.LongTensor] = None,
                sampled: bool = True,
                truncate_at_end_all: bool = True,
                # shape (prediction_prefixes): (batch_size, prefix_length)
                prediction_prefixes: torch.LongTensor = None,
                target_prefixes: torch.LongTensor = None,
                rollout_mixing_func: RolloutMixingProbFuncType = None,
                reference_policy_type: str = "copy",
                rollout_mode: str = None,
               ):
        state['rollout_params'] = {}
        if reference_policy_type == 'oracle':
            reference_policy = partial(self.oracle_reference_policy,
                                       token_to_idx=self._vocab._token_to_index['target_tokens'],
                                       idx_to_token=self._vocab._index_to_token['target_tokens'],
                                      )
            num_steps_to_take = rollout_steps
            state['rollout_params']['rollout_prefixes'] = prediction_prefixes
        else:
            reference_policy = partial(self.copy_reference_policy,
                                       target_tokens=target_tokens)
            num_steps_to_take = rollout_steps

        rollout_policy = partial(self.rollout_policy,
                                 rollout_mode=rollout_mode,
                                 rollout_mixing_func=rollout_mixing_func,
                                 reference_policy=reference_policy,
                                )
        rolling_policy = partial(self.take_step,
                                 rollout_policy=rollout_policy)

        # shape (step_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        # shape (logits): (batch_size, beam_size, num_decoding_steps, num_classes)
        step_predictions, log_probabilities, logits = \
            self._beam_search.search(start_predictions,
                                     state,
                                     rolling_policy,
                                     max_steps=num_steps_to_take,
                                     beam_size=beam_size,
                                     per_node_beam_size=per_node_beam_size,
                                     sampled=sampled,
                                     truncate_at_end_all=truncate_at_end_all)

        logits = torch.cat(logits, dim=2)

        # Concatenate the start tokens to the predictions. They are not added
        # to the predictions by default.
        batch_size, beam_size, _ = step_predictions.shape
        start_prediction_length = start_predictions.size(0)
        step_predictions = torch.cat([start_predictions.unsqueeze(1)
                                        .expand(batch_size, beam_size)
                                        .reshape(batch_size, beam_size, 1),
                                      step_predictions],
                                     dim=-1)

        # Some prefix of the predictions may have been generated by the rollin
        # policy. If such prefixes are passed, concatenate them here.
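        # For example, with batch_size=2, beam_size=3 and a prefix of length 4,
        # prediction_prefixes of shape (2, 4) is expanded to (2, 3, 4) and
        # prepended to step_predictions along the last dimension.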
        if prediction_prefixes is not None:
            prefixes_length = prediction_prefixes.size(1)
            step_predictions = torch.cat([prediction_prefixes.unsqueeze(1)
                                            .expand(batch_size, beam_size, prefixes_length),
                                          step_predictions],
                                         dim=-1)

        step_prediction_masks = self._get_mask(
                                    step_predictions.reshape(batch_size * beam_size, -1)) \
                                    .reshape(batch_size, beam_size, -1)

        output_dict = {
            "predictions": step_predictions,
            "prediction_masks": step_prediction_masks,
            "logits": logits,
            "class_log_probabilities": log_probabilities,
        }

        step_targets = None
        step_target_masks = None
        if target_tokens is not None:
            step_targets = util.get_token_ids_from_text_field_tensors(target_tokens)
            if target_prefixes is not None:
                prefixes_length = target_prefixes.size(1)
                step_targets = torch.cat([target_prefixes, step_targets], dim=-1)

            step_target_masks = util.get_text_field_mask({'tokens': {'tokens': step_targets}})

            output_dict.update({
                "targets": step_targets,
                "target_masks": step_target_masks,
            })
        return output_dict

    def compute_sentence_probs(self,
                               sequences_dict: Dict[str, torch.LongTensor],
                              ) -> torch.FloatTensor:
        """
        Given a batch of token sequences, compute their per-token log
        probabilities under the trained model.

        Arguments:
            sequences_dict {Dict[str, torch.LongTensor]} -- The sequences that
                need to be scored.

        Returns:
            seq_probs {torch.FloatTensor} -- Probabilities of the sequences.
            seq_lens {torch.LongTensor} -- Lengths of the non-padded sequences.
            per_step_seq_probs {torch.LongTensor} -- Per-step probabilities of
                each prediction in a sequence.
        """
        state = {}
        sequences = util.get_token_ids_from_text_field_tensors(sequences_dict)

        batch_size = sequences.size(0)
        seq_len = sequences.size(1)
        start_predictions = self._get_start_predictions(state,
                                                        sequences_dict,
                                                        batch_size)

        # We are computing the probability of the given sequences, so we use
        # rollin_mode='teacher_forcing': at every step we feed in the token
        # from the sequence whose probability we want to compute.
        rollin_output_dict = self.rollin(state={},
                                         start_predictions=start_predictions,
                                         rollin_steps=seq_len - 1,
                                         target_tokens=sequences_dict,
                                         rollin_mode='teacher_forcing',
                                        )

        step_log_probs = F.log_softmax(rollin_output_dict['logits'].squeeze(1), dim=-1)
        per_step_seq_probs = torch.gather(step_log_probs, 2,
                                          sequences[:, 1:].unsqueeze(2)) \
                                  .squeeze(2)

        sequence_mask = util.get_text_field_mask(sequences_dict)
        per_step_seq_probs_summed = torch.sum(per_step_seq_probs * sequence_mask[:, 1:], dim=-1)
        non_batch_dims = tuple(range(1, len(sequence_mask.shape)))

        # shape: (batch_size,)
        sequence_mask_sum = sequence_mask[:, 1:].sum(dim=non_batch_dims)

        # (seq_probs, seq_lens, per_step_seq_probs)
        return torch.exp(per_step_seq_probs_summed / sequence_mask_sum), \
               sequence_mask_sum, \
               torch.exp(per_step_seq_probs)

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      start_predictions: torch.LongTensor,
                      num_decoding_steps: int,
                      target_tokens: Dict[str, torch.LongTensor] = None,
                     ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        raise NotImplementedError()

    def _get_start_predictions(self,
                               state: Dict[str, torch.Tensor],
                               target_tokens: Dict[str, torch.LongTensor] = None,
                               generation_batch_size: int = None) -> torch.LongTensor:

        if self._seq2seq_mode:
            source_mask = state["source_mask"]
            batch_size = source_mask.size()[0]
        elif target_tokens:
            targets = util.get_token_ids_from_text_field_tensors(target_tokens)
            batch_size = targets.size(0)
        else:
            batch_size = generation_batch_size

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        return torch.zeros((batch_size,),
                           dtype=torch.long,
                           device=self.current_device) \
                    .fill_(self._start_index)

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Decode the current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state.get("encoder_outputs", None)

        # shape: (group_size, max_input_sequence_length)
        source_mask = state.get("source_mask", None)

        # shape: (group_size, steps_count, decoder_output_dim)
        previous_steps_predictions = state.get("previous_steps_predictions", None)

        # shape: (batch_size, 1, target_embedding_dim)
        last_predictions_embeddings = self.target_embedder(last_predictions).unsqueeze(1)

        if previous_steps_predictions is None or previous_steps_predictions.shape[-1] == 0:
            # There are no previous steps, except for the start vectors in
            # `last_predictions`.
            # shape: (group_size, 1, target_embedding_dim)
            previous_steps_predictions = last_predictions_embeddings
        else:
            # shape: (group_size, steps_count, target_embedding_dim)
            previous_steps_predictions = torch.cat(
                [previous_steps_predictions, last_predictions_embeddings], 1
            )

        decoder_state, decoder_output = self._decoder_net(
            previous_state=state,
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_predictions=previous_steps_predictions,
        )
        state["previous_steps_predictions"] = previous_steps_predictions

        # Update state with the new decoder state, overriding the previous state.
        state.update(decoder_state)

        if self._decoder_net.decodes_parallel:
            decoder_output = decoder_output[:, -1, :]

        # Apply dropout.
        decoder_hidden_with_dropout = self._dropout(decoder_output)

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden_with_dropout)

        return output_projections, state

    def _get_mask(self, predictions) -> torch.FloatTensor:
        # SEARNN with KL may produce sequences whose lengths do not match the
        # target sequence length. This is especially true for LM training with
        # learned rollins, where the observed pattern is that sequence lengths
        # keep shrinking. This method therefore computes a mask from the
        # predicted tokens by finding the first EOS token: everything after
        # that is masked out.
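        # Illustrative example (the token indices are hypothetical): with
        # self._end_index == 2, predictions [[11, 13, 2, 7, 7]] yield the mask
        # [[1, 1, 1, 0, 0]]: everything up to and including the first EOS is
        # kept, everything after it is zeroed out.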
mask = predictions.new_ones(predictions.shape) for i, indices in enumerate(predictions.detach().cpu().tolist()): if self._end_index in indices: end_idx = indices.index(self._end_index) mask[i, :end_idx + 1] = 1 mask[i, end_idx + 1:] = 0 return mask @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} all_metrics.update({ 'ss_ratio': self._ss_ratio.get_metric(reset=reset), 'training_iter': self.training_iteration, 'perplexity': self._perplexity.get_metric(reset=reset), }) if self._bleu and not self.training: all_metrics.update(self._bleu.get_metric(reset=reset)) if self._hamming and not self.training: all_metrics.update({'hamming': self._hamming.get_metric(reset=reset)}) if self._loss_criterion and self._loss_criterion._shall_compute_rollout_loss: all_metrics.update(self._loss_criterion.get_metric(reset=reset)) if not self.training: if self._tensor_based_metric is not None: all_metrics.update( self._tensor_based_metric.get_metric(reset=reset) # type: ignore ) if self._token_based_metric is not None: all_metrics.update(self._token_based_metric.get_metric(reset=reset)) # type: ignore return all_metrics
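# A minimal, self-contained sketch (not part of the model above) of the three
# scheduled-sampling schedules implemented in `_apply_scheduled_sampling`. The
# value of `k` and the iteration counts below are hypothetical, chosen only to
# show the shape of each curve.
import math


def scheduled_sampling_ratio(schedule: str, i: int, k: float) -> float:
    # Mirrors the three branches of `_apply_scheduled_sampling`.
    if schedule == 'uniform':
        return 0.25  # a fixed ratio, as would be set by the config
    if schedule == 'linear':
        return i / float(k)
    if schedule == 'inverse_sigmoid':
        return 1 - k / (k + math.exp(i / k))
    raise ValueError(f"{schedule} is not a valid scheduled sampling type.")


# The inverse-sigmoid schedule starts near 0 (mostly teacher forcing) and
# saturates towards 1 (mostly model rollouts): with k=1000, the ratio is
# roughly 0.001 at i=0, 0.003 at i=1000, and 0.957 at i=10000.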
class AtisSemanticParser(Model): """ Parameters ---------- vocab : ``Vocabulary`` utterance_embedder : ``TextFieldEmbedder`` Embedder for utterances. action_embedding_dim : ``int`` Dimension to use for action embeddings. encoder : ``Seq2SeqEncoder`` The encoder to use for the input utterance. decoder_beam_search : ``BeamSearch`` Beam search used to retrieve best sequences after training. max_decoding_steps : ``int`` When we're decoding with a beam search, what's the maximum number of steps we should take? This only applies at evaluation time, not during training. input_attention: ``Attention`` We compute an attention over the input utterance at each step of the decoder, using the decoder hidden state as the query. Passed to the transition function. add_action_bias : ``bool``, optional (default=True) If ``True``, we will learn a bias weight for each action that gets used when predicting that action, in addition to its embedding. dropout : ``float``, optional (default=0) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). rule_namespace : ``str``, optional (default=rule_labels) The vocabulary namespace to use for production rules. The default corresponds to the default used in the dataset reader, so you likely don't need to modify this. database_file: ``str``, optional (default=/atis/atis.db) The path of the SQLite database when evaluating SQL queries. SQLite is disk based, so we need the file location to connect to it. """ def __init__( self, vocab: Vocabulary, utterance_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, decoder_beam_search: BeamSearch, max_decoding_steps: int, input_attention: Attention, add_action_bias: bool = True, training_beam_size: int = None, decoder_num_layers: int = 1, dropout: float = 0.0, rule_namespace: str = "rule_labels", database_file="/atis/atis.db", ) -> None: # Atis semantic parser init super().__init__(vocab) self._utterance_embedder = utterance_embedder self._encoder = encoder self._max_decoding_steps = max_decoding_steps self._add_action_bias = add_action_bias if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._rule_namespace = rule_namespace self._exact_match = Average() self._valid_sql_query = Average() self._action_similarity = Average() self._denotation_accuracy = Average() self._executor = SqlExecutor(database_file) self._action_padding_index = -1 # the padding value used by IndexField num_actions = vocab.get_vocab_size(self._rule_namespace) if self._add_action_bias: input_action_dim = action_embedding_dim + 1 else: input_action_dim = action_embedding_dim self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim) self._output_action_embedder = Embedding( num_embeddings=num_actions, embedding_dim=action_embedding_dim) # This is what we pass as input in the first step of decoding, when we don't have a # previous action, or a previous utterance attention. self._first_action_embedding = torch.nn.Parameter( torch.FloatTensor(action_embedding_dim)) self._first_attended_utterance = torch.nn.Parameter( torch.FloatTensor(encoder.get_output_dim())) torch.nn.init.normal_(self._first_action_embedding) torch.nn.init.normal_(self._first_attended_utterance) self._num_entity_types = 2 # TODO(kevin): get this in a more principled way somehow? 
        self._entity_type_decoder_embedding = Embedding(
            num_embeddings=self._num_entity_types, embedding_dim=action_embedding_dim)
        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
        self._transition_function = LinkingTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            add_action_bias=self._add_action_bias,
            dropout=dropout,
            num_layers=self._decoder_num_layers,
        )

    def forward(
        self,  # type: ignore
        utterance: Dict[str, torch.LongTensor],
        world: List[AtisWorld],
        actions: List[List[ProductionRule]],
        linking_scores: torch.Tensor,
        target_action_sequence: torch.LongTensor = None,
        sql_queries: List[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        We set up the initial state for the decoder, and pass that state off to either a
        DecoderTrainer, if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        utterance : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the utterance ``TextField``. This
            will be passed through a ``TextFieldEmbedder`` and then through an encoder.
        world : ``List[AtisWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance. Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[AtisWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and
            use the embeddings to determine which action to take at each timestep in the decoder.
        linking_scores : ``torch.Tensor``
            A matrix linking the utterance tokens and the entities. This is a binary matrix that
            is deterministically generated, where each entry indicates whether a token generated
            an entity. This tensor has shape ``(batch_size, num_entities, num_utterance_tokens)``.
        target_action_sequence : torch.Tensor, optional (default=None)
            The integer indices of the correct action sequence, where each action is an index
            into the list of possible actions. This tensor has shape
            ``(batch_size, sequence_length, 1)``. We remove the trailing dimension.
        sql_queries : List[List[str]], optional (default=None)
            A list of the SQL queries that are given during training or validation.
        """
        initial_state = self._get_initial_state(utterance, world, actions, linking_scores)
        batch_size = linking_scores.shape[0]
        if target_action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequence = target_action_sequence.squeeze(-1)
            target_mask = target_action_sequence != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we
            # unsqueeze it for the MML trainer.
            return self._decoder_trainer.decode(
                initial_state,
                self._transition_function,
                (target_action_sequence.unsqueeze(1), target_mask.unsqueeze(1)),
            )
        else:
            # TODO(kevin) Move some of this functionality to a separate method for computing validation outputs.
action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] outputs: Dict[str, Any] = {"action_mapping": action_mapping} outputs["linking_scores"] = linking_scores if target_action_sequence is not None: outputs["loss"] = self._decoder_trainer.decode( initial_state, self._transition_function, (target_action_sequence.unsqueeze(1), target_mask.unsqueeze(1)), )["loss"] num_steps = self._max_decoding_steps # This tells the state to start keeping track of debug info, which we'll pass along in # our output dictionary. initial_state.debug_info = [[] for _ in range(batch_size)] best_final_states = self._beam_search.search( num_steps, initial_state, self._transition_function, keep_final_unfinished_states=False, ) outputs["best_action_sequence"] = [] outputs["debug_info"] = [] outputs["entities"] = [] outputs["predicted_sql_query"] = [] outputs["sql_queries"] = [] outputs["utterance"] = [] outputs["tokenized_utterance"] = [] for i in range(batch_size): # Decoding may not have terminated with any completed valid SQL queries, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). if i not in best_final_states: self._exact_match(0) self._denotation_accuracy(0) self._valid_sql_query(0) self._action_similarity(0) outputs["predicted_sql_query"].append("") continue best_action_indices = best_final_states[i][0].action_history[0] action_strings = [ action_mapping[(i, action_index)] for action_index in best_action_indices ] predicted_sql_query = action_sequence_to_sql(action_strings) if target_action_sequence is not None: # Use a Tensor, not a Variable, to avoid a memory leak. targets = target_action_sequence[i].data sequence_in_targets = 0 sequence_in_targets = self._action_history_match( best_action_indices, targets) self._exact_match(sequence_in_targets) similarity = difflib.SequenceMatcher( None, best_action_indices, targets) self._action_similarity(similarity.ratio()) if sql_queries and sql_queries[i]: denotation_correct = self._executor.evaluate_sql_query( predicted_sql_query, sql_queries[i]) self._denotation_accuracy(denotation_correct) outputs["sql_queries"].append(sql_queries[i]) outputs["utterance"].append(world[i].utterances[-1]) outputs["tokenized_utterance"].append([ token.text for token in world[i].tokenized_utterances[-1] ]) outputs["entities"].append(world[i].entities) outputs["best_action_sequence"].append(action_strings) outputs["predicted_sql_query"].append( sqlparse.format(predicted_sql_query, reindent=True)) outputs["debug_info"].append( best_final_states[i][0].debug_info[0]) # type: ignore return outputs def _get_initial_state( self, utterance: Dict[str, torch.LongTensor], worlds: List[AtisWorld], actions: List[List[ProductionRule]], linking_scores: torch.Tensor, ) -> GrammarBasedState: embedded_utterance = self._utterance_embedder(utterance) utterance_mask = util.get_text_field_mask(utterance) batch_size = embedded_utterance.size(0) num_entities = max([len(world.entities) for world in worlds]) # entity_types: tensor with shape (batch_size, num_entities) entity_types, _ = self._get_type_vector(worlds, num_entities, embedded_utterance) # (batch_size, num_utterance_tokens, embedding_dim) encoder_input = embedded_utterance # (batch_size, utterance_length, encoder_output_dim) encoder_outputs = self._dropout( self._encoder(encoder_input, utterance_mask)) # This will be our initial hidden state and memory cell 
        # for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list. For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            if self._decoder_num_layers > 1:
                initial_rnn_state.append(
                    RnnStatelet(
                        final_encoder_output[i].repeat(self._decoder_num_layers, 1),
                        memory_cell[i].repeat(self._decoder_num_layers, 1),
                        self._first_action_embedding,
                        self._first_attended_utterance,
                        encoder_output_list,
                        utterance_mask_list,
                    ))
            else:
                initial_rnn_state.append(
                    RnnStatelet(
                        final_encoder_output[i],
                        memory_cell[i],
                        self._first_action_embedding,
                        self._first_attended_utterance,
                        encoder_output_list,
                        utterance_mask_list,
                    ))

        initial_grammar_state = [
            self._create_grammar_state(worlds[i], actions[i], linking_scores[i],
                                       entity_types[i]) for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            debug_info=None,
        )
        return initial_state

    @staticmethod
    def _get_type_vector(
        worlds: List[AtisWorld], num_entities: int, tensor: torch.Tensor = None
    ) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces the encoding for each entity's type. In addition, a map from a flattened entity
        index to type is returned to combine entity type operations into one method.

        Parameters
        ----------
        worlds : ``List[AtisWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_types)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            entities = [("number", entity)
                        if any([entity.startswith(numeric_nonterminal)
                                for numeric_nonterminal in NUMERIC_NONTERMINALS])
                        else ("string", entity)
                        for entity in world.entities]

            for entity_index, entity in enumerate(entities):
                # We need numbers to be first, then strings, since our entities are going to be
                # sorted. We do a split by type and then a merge later, and it relies on this
                # sorting.
                if entity[0] == "number":
                    entity_type = 1
                else:
                    entity_type = 0
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
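                # For example, with num_entities == 10, entity 3 of batch
                # element 2 gets the flattened key 2 * 10 + 3 == 23.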
flattened_entity_index = batch_index * num_entities + entity_index entity_types[flattened_entity_index] = entity_type padded = pad_sequence_to_length(types, num_entities, lambda: 0) batch_types.append(padded) return tensor.new_tensor(batch_types, dtype=torch.long), entity_types @staticmethod def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int: # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something. # Check if target is big enough to cover prediction (including start/end symbols) if len(predicted) > targets.size(0): return 0 predicted_tensor = targets.new_tensor(predicted) targets_trimmed = targets[:len(predicted)] # Return 1 if the predicted sequence is anywhere in the list of targets. return predicted_tensor.equal(targets_trimmed) @staticmethod def is_nonterminal(token: str): if token[0] == '"' and token[-1] == '"': return False return True def get_metrics(self, reset: bool = False) -> Dict[str, float]: """ We track four metrics here: 1. exact_match, which is the percentage of the time that our best output action sequence matches the SQL query exactly. 2. denotation_acc, which is the percentage of examples where we get the correct denotation. This is the typical "accuracy" metric, and it is what you should usually report in an experimental result. You need to be careful, though, that you're computing this on the full data, and not just the subset that can be parsed. (make sure you pass "keep_if_unparseable=True" to the dataset reader, which we do for validation data, but not training data). 3. valid_sql_query, which is the percentage of time that decoding actually produces a valid SQL query. We might not produce a valid SQL query if the decoder gets into a repetitive loop, or we're trying to produce a super long SQL query and run out of time steps, or something. 4. action_similarity, which is how similar the action sequence predicted is to the actual action sequence. This is basically a soft measure of exact_match. """ return { "exact_match": self._exact_match.get_metric(reset), "denotation_acc": self._denotation_accuracy.get_metric(reset), "valid_sql_query": self._valid_sql_query.get_metric(reset), "action_similarity": self._action_similarity.get_metric(reset), } def _create_grammar_state( self, world: AtisWorld, possible_actions: List[ProductionRule], linking_scores: torch.Tensor, entity_types: torch.Tensor, ) -> GrammarStatelet: """ This method creates the GrammarStatelet object that's used for decoding. Part of creating that is creating the `valid_actions` dictionary, which contains embedded representations of all of the valid actions. So, we create that here as well. The inputs to this method are for a `single instance in the batch`; none of the tensors we create here are batched. We grab the global action ids from the input ``ProductionRules``, and we use those to embed the valid actions for every non-terminal type. We use the input ``linking_scores`` for non-global actions. Parameters ---------- world : ``AtisWorld`` From the input to ``forward`` for a single batch instance. possible_actions : ``List[ProductionRule]`` From the input to ``forward`` for a single batch instance. linking_scores : ``torch.Tensor`` Assumed to have shape ``(num_entities, num_utterance_tokens)`` (i.e., there is no batch dimension). entity_types : ``torch.Tensor`` Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension). 
""" action_map = {} for action_index, action in enumerate(possible_actions): action_string = action[0] action_map[action_string] = action_index valid_actions = world.valid_actions entity_map = {} entities: Iterable[str] = world.entities for entity_index, entity in enumerate(entities): entity_map[entity] = entity_index translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {} for key, action_strings in valid_actions.items(): translated_valid_actions[key] = {} # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid # productions of that non-terminal. We'll first split those productions by global vs. # linked action. action_indices = [ action_map[action_string] for action_string in action_strings ] production_rule_arrays = [(possible_actions[index], index) for index in action_indices] global_actions = [] linked_actions = [] for production_rule_array, action_index in production_rule_arrays: if production_rule_array[1]: global_actions.append( (production_rule_array[2], action_index)) else: linked_actions.append( (production_rule_array[0], action_index)) if global_actions: global_action_tensors, global_action_ids = zip(*global_actions) global_action_tensor = (torch.cat( global_action_tensors, dim=0).to(entity_types.device).long()) global_input_embeddings = self._action_embedder( global_action_tensor) global_output_embeddings = self._output_action_embedder( global_action_tensor) translated_valid_actions[key]["global"] = ( global_input_embeddings, global_output_embeddings, list(global_action_ids), ) if linked_actions: linked_rules, linked_action_ids = zip(*linked_actions) entities = linked_rules entity_ids = [entity_map[entity] for entity in entities] entity_linking_scores = linking_scores[entity_ids] entity_type_tensor = entity_types[entity_ids] entity_type_embeddings = ( self._entity_type_decoder_embedding(entity_type_tensor).to( entity_types.device).float()) translated_valid_actions[key]["linked"] = ( entity_linking_scores, entity_type_embeddings, list(linked_action_ids), ) return GrammarStatelet(["statement"], translated_valid_actions, self.is_nonterminal) def make_output_human_readable( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a separate notion from the "decoder" in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``. 
""" action_mapping = output_dict["action_mapping"] best_actions = output_dict["best_action_sequence"] debug_infos = output_dict["debug_info"] batch_action_info = [] for batch_index, (predicted_actions, debug_info) in enumerate( zip(best_actions, debug_infos)): instance_action_info = [] for predicted_action, action_debug_info in zip( predicted_actions, debug_info): action_info = {} action_info["predicted_action"] = predicted_action considered_actions = action_debug_info["considered_actions"] probabilities = action_debug_info["probabilities"] actions = [] for action, probability in zip(considered_actions, probabilities): if action != -1: actions.append((action_mapping[(batch_index, action)], probability)) actions.sort() considered_actions, probabilities = zip(*actions) action_info["considered_actions"] = considered_actions action_info["action_probabilities"] = probabilities action_info["utterance_attention"] = action_debug_info.get( "question_attention", []) instance_action_info.append(action_info) batch_action_info.append(instance_action_info) output_dict["predicted_actions"] = batch_action_info return output_dict
class GateBidirectionalAttentionFlow(Model):
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 gate_sent_encoder: Seq2SeqEncoder,
                 gate_self_attention_layer: Seq2SeqEncoder,
                 span_gate: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 output_att_scores: bool = True,
                 sent_labels_src: str = 'sp',
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(GateBidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._dropout = torch.nn.Dropout(p=dropout)
        self._output_att_scores = output_att_scores
        self._sent_labels_src = sent_labels_src

        self._span_gate = span_gate
        if span_gate._gate_self_att:
            self._gate_sent_encoder = gate_sent_encoder
            self._gate_self_attention_layer = gate_self_attention_layer
        else:
            self._gate_sent_encoder = None
            self._gate_self_attention_layer = None

        self._f1_metrics = F1Measure(1)
        self.evd_ans_metric = Average()
        self._loss_trackers = {'loss': Average()}

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                sentence_spans: torch.IntTensor = None,
                sent_labels: torch.IntTensor = None,
                evd_chain_labels: torch.IntTensor = None,
                q_type: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        if self._sent_labels_src == 'chain':
            batch_size, num_spans = sent_labels.size()
            sent_labels_mask = (sent_labels >= 0).float()
            print("chain:", evd_chain_labels)
            # We use the chain as the label to supervise the gate. In this
            # model we only take the first chain in ``evd_chain_labels`` for
            # supervision; right now the number of chains should be one anyway.
            evd_chain_labels = evd_chain_labels[:, 0].long()
            # Build the gate labels. The dim is set to 1 + num_spans to account
            # for the end embedding.
            # shape: (batch_size, 1+num_spans)
            sent_labels = sent_labels.new_zeros((batch_size, 1 + num_spans))
            sent_labels.scatter_(1, evd_chain_labels, 1.)
            # Remove the column for the end embedding.
            # shape: (batch_size, num_spans)
            sent_labels = sent_labels[:, 1:].float()
            # Make the padding be -1.
            sent_labels = sent_labels * sent_labels_mask + -1. * (1 - sent_labels_mask)

        print('\nBert wordpiece size:', passage['bert'].shape)

        # BERT embedding for answer prediction
        # shape: [batch_size, max_q_len, emb_size]
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=0)
        # shape: [batch_size, num_sent, max_sent_len+q_len, embedding_dim]
        embedded_passage = self._text_field_embedder(passage, num_wrapping_dims=1)

        # masks
        ques_mask = util.get_text_field_mask(question, num_wrapping_dims=0).float()
        context_mask = util.get_text_field_mask(passage, num_wrapping_dims=1).float()

        # Gate prediction.
        # Shape(gate_logit): (batch_size * num_spans, 2)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(pred_sent_probs): (batch_size * num_spans, 2)
        # Shape(gate_mask): (batch_size, num_spans)
        gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = \
            self._span_gate(embedded_passage, context_mask,
                            self._gate_self_attention_layer,
                            self._gate_sent_encoder)
        batch_size, num_spans, max_batch_span_width = context_mask.size()

        loss = F.nll_loss(F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
                          sent_labels.long().view(batch_size * num_spans),
                          ignore_index=-1)

        gate = (gate >= 0.3).long()
        gate = gate.view(batch_size, num_spans)

        output_dict = {
            "pred_sent_labels": gate,  # [B, num_span]
            "gate_probs": pred_sent_probs[:, 1].view(batch_size, num_spans),  # [B, num_span]
        }
        if self._output_att_scores:
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        # Compute the loss for training.
        try:
            self._loss_trackers['loss'](loss)
            output_dict["loss"] = loss
        except RuntimeError:
            print('\n meta_data:', metadata)
            print("sent label:")
            for b_label in np.array(sent_labels.cpu()):
                b_label = b_label == 1
                indices = np.arange(len(b_label))
                print(indices[b_label] + 1)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            ids = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_sent_tokens'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                # Shift the sentence indices back.
                evd_possible_chains.append([s_idx - 1
                                            for s_idx in metadata[i]['evd_possible_chains'][0]
                                            if s_idx > 0])
                ans_sent_idxs.append([s_idx - 1 for s_idx in metadata[i]['ans_sent_idxs']])
                if len(metadata[i]['ans_sent_idxs']) > 0:
                    pred_sent_gate = gate[i].detach().cpu().numpy()
                    if any([pred_sent_gate[s_idx - 1] > 0
                            for s_idx in metadata[i]['ans_sent_idxs']]):
                        self.evd_ans_metric(1)
                    else:
                        self.evd_ans_metric(0)

            self._f1_metrics(pred_sent_probs, sent_labels.view(-1), gate_mask.view(-1))
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_sent_tokens'] = passage_tokens
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['_id'] = ids

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        p, r, evidence_f1_score = self._f1_metrics.get_metric(reset)
        ans_in_evd = self.evd_ans_metric.get_metric(reset)
        metrics = {
            'evd_p': p,
            'evd_r': r,
            'evd_f1': evidence_f1_score,
            'ans_in_evd': ans_in_evd,
        }
        for name, tracker in self._loss_trackers.items():
            metrics[name] = tracker.get_metric(reset).item()
        return metrics

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
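# An illustrative, numpy-free rendering of `get_best_span` above: it finds the
# argmax over i <= j of start[i] + end[j] in a single left-to-right pass by
# tracking the best start logit seen so far. The toy logits are made up.
def best_span(start_logits, end_logits):
    best, best_score = (0, 0), float('-inf')
    start_argmax = 0
    for j in range(len(start_logits)):
        if start_logits[j] > start_logits[start_argmax]:
            start_argmax = j
        score = start_logits[start_argmax] + end_logits[j]
        if score > best_score:
            best, best_score = (start_argmax, j), score
    return best


assert best_span([0.1, 2.0, 0.3], [0.2, 0.1, 1.5]) == (1, 2)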
class NlvrCoverageSemanticParser(NlvrSemanticParser):
    """
    ``NlvrCoverageSemanticParser`` is an ``NlvrSemanticParser`` that gets around the problem of
    lack of annotated logical forms by maximizing coverage of the output sequences over a
    prespecified agenda. In addition to the signal from coverage, we also compute the denotations
    given by the logical forms and define a hybrid cost based on coverage and denotation errors.
    The training process then minimizes the expected value of this cost over an approximate set of
    logical forms produced by the parser, obtained by performing beam search.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to super-class.
    sentence_embedder : ``TextFieldEmbedder``
        Passed to super-class.
    action_embedding_dim : ``int``
        Passed to super-class.
    encoder : ``Seq2SeqEncoder``
        Passed to super-class.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query. Passed to the TransitionFunction.
    beam_size : ``int``
        Beam size for the beam search used during training.
    max_num_finished_states : ``int``, optional (default=None)
        Maximum number of finished states the trainer should compute costs for.
    normalize_beam_score_by_length : ``bool``, optional (default=False)
        Should the log probabilities be normalized by length before renormalizing them? Edunov et
        al. do this in their work, but we found that not doing it works better. It's possible they
        did this because their task is NMT, where longer decoded sequences are not necessarily
        worse and shouldn't be penalized, while we mostly want to penalize longer logical forms.
    max_decoding_steps : ``int``
        Maximum number of steps for the beam search during training.
    dropout : ``float``, optional (default=0.0)
        Probability of dropout to apply on encoder outputs, decoder outputs and predicted actions.
    checklist_cost_weight : ``float``, optional (default=0.6)
        Mixture weight (0-1) for combining coverage cost and denotation cost. As this increases,
        we weigh the coverage cost higher, with a value of 1.0 meaning that we do not care about
        denotation accuracy.
    dynamic_cost_weight : ``Dict[str, Union[int, float]]``, optional (default=None)
        A dict containing keys ``wait_num_epochs`` and ``rate`` indicating the number of steps
        after which we should start decreasing the weight on checklist cost in favor of denotation
        cost, and the rate at which we should do it. We will decrease the weight in the following
        way - ``checklist_cost_weight = checklist_cost_weight - rate * checklist_cost_weight``
        starting at the appropriate epoch. The weight will remain constant if this is not provided.
    penalize_non_agenda_actions : ``bool``, optional (default=False)
        Should we penalize the model for producing terminal actions that are outside the agenda?
    initial_mml_model_file : ``str``, optional (default=None)
        If you want to initialize this model using weights from another model trained using MML,
        pass the path to the ``model.tar.gz`` file of that model here.
""" def __init__(self, vocab: Vocabulary, sentence_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, attention: Attention, beam_size: int, max_decoding_steps: int, max_num_finished_states: int = None, dropout: float = 0.0, normalize_beam_score_by_length: bool = False, checklist_cost_weight: float = 0.6, dynamic_cost_weight: Dict[str, Union[int, float]] = None, penalize_non_agenda_actions: bool = False, initial_mml_model_file: str = None) -> None: super(NlvrCoverageSemanticParser, self).__init__(vocab=vocab, sentence_embedder=sentence_embedder, action_embedding_dim=action_embedding_dim, encoder=encoder, dropout=dropout) self._agenda_coverage = Average() self._decoder_trainer: DecoderTrainer[Callable[[CoverageState], torch.Tensor]] = \ ExpectedRiskMinimization(beam_size=beam_size, normalize_by_length=normalize_beam_score_by_length, max_decoding_steps=max_decoding_steps, max_num_finished_states=max_num_finished_states) # Instantiating an empty NlvrWorld just to get the number of terminals. self._terminal_productions = set( NlvrWorld([]).terminal_productions.values()) self._decoder_step = CoverageTransitionFunction( encoder_output_dim=self._encoder.get_output_dim(), action_embedding_dim=action_embedding_dim, input_attention=attention, num_start_types=1, activation=Activation.by_name('tanh')(), predict_start_type_separately=False, add_action_bias=False, dropout=dropout) self._checklist_cost_weight = checklist_cost_weight self._dynamic_cost_wait_epochs = None self._dynamic_cost_rate = None if dynamic_cost_weight: self._dynamic_cost_wait_epochs = dynamic_cost_weight[ "wait_num_epochs"] self._dynamic_cost_rate = dynamic_cost_weight["rate"] self._penalize_non_agenda_actions = penalize_non_agenda_actions self._last_epoch_in_forward: int = None # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've # copied a trained ERM model from a different machine and the original MML model that was # used to initialize it does not exist on the current machine. This may not be the best # solution for the problem. if initial_mml_model_file is not None: if os.path.isfile(initial_mml_model_file): archive = load_archive(initial_mml_model_file) self._initialize_weights_from_archive(archive) else: # A model file is passed, but it does not exist. This is expected to happen when # you're using a trained ERM model to decode. But it may also happen if the path to # the file is really just incorrect. So throwing a warning. logger.warning( "MML model file for initializing weights is passed, but does not exist." " This is fine if you're just decoding.") def _initialize_weights_from_archive(self, archive: Archive) -> None: logger.info("Initializing weights from MML model.") model_parameters = dict(self.named_parameters()) archived_parameters = dict(archive.model.named_parameters()) sentence_embedder_weight = "_sentence_embedder.token_embedder_tokens.weight" if sentence_embedder_weight not in archived_parameters or \ sentence_embedder_weight not in model_parameters: raise RuntimeError( "When initializing model weights from an MML model, we need " "the sentence embedder to be a TokenEmbedder using namespace called " "tokens.") for name, weights in archived_parameters.items(): if name in model_parameters: if name == "_sentence_embedder.token_embedder_tokens.weight": # The shapes of embedding weights will most likely differ between the two models # because the vocabularies will most likely be different. 
                    # We will get a mapping of indices from this model's token indices to the
                    # archived model's, and copy the tensor accordingly.
                    vocab_index_mapping = self._get_vocab_index_mapping(archive.model.vocab)
                    archived_embedding_weights = weights.data
                    new_weights = model_parameters[name].data.clone()
                    for index, archived_index in vocab_index_mapping:
                        new_weights[index] = archived_embedding_weights[archived_index]
                    logger.info("Copied embeddings of %d out of %d tokens",
                                len(vocab_index_mapping), new_weights.size()[0])
                else:
                    new_weights = weights.data
                logger.info("Copying parameter %s", name)
                model_parameters[name].data.copy_(new_weights)

    def _get_vocab_index_mapping(self, archived_vocab: Vocabulary) -> List[Tuple[int, int]]:
        vocab_index_mapping: List[Tuple[int, int]] = []
        for index in range(self.vocab.get_vocab_size(namespace='tokens')):
            token = self.vocab.get_token_from_index(index=index, namespace='tokens')
            archived_token_index = archived_vocab.get_token_index(token, namespace='tokens')
            # Checking if we got the UNK token index, because we don't want all new token
            # representations initialized to the UNK token's representation. We do that by
            # checking if the two tokens are the same. They will not be if the token at the
            # archived index is UNK.
            if archived_vocab.get_token_from_index(archived_token_index,
                                                   namespace="tokens") == token:
                vocab_index_mapping.append((index, archived_token_index))
        return vocab_index_mapping

    @overrides
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.LongTensor],
                worlds: List[List[NlvrWorld]],
                actions: List[List[ProductionRule]],
                agenda: torch.LongTensor,
                identifier: List[str] = None,
                labels: torch.LongTensor = None,
                epoch_num: List[int] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type-constrained target sequences that maximize coverage of
        their respective agendas and minimize a denotation-based loss.
        """
        # We look at the epoch number and adjust the checklist cost weight if needed here.
        instance_epoch_num = epoch_num[0] if epoch_num is not None else None
        if self._dynamic_cost_rate is not None:
            if self.training and instance_epoch_num is None:
                raise RuntimeError("If you want a dynamic cost weight, use the "
                                   "EpochTrackingBucketIterator!")
            if instance_epoch_num != self._last_epoch_in_forward:
                if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                    decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                    self._checklist_cost_weight -= decrement
                    logger.info("Checklist cost weight is now %f", self._checklist_cost_weight)
                self._last_epoch_in_forward = instance_epoch_num

        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                              for i in range(batch_size)]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i])
                                 for i in range(batch_size)]

        label_strings = self._get_label_strings(labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        # TODO(mattg): It looks like the agenda is only ever used on the CPU. In that case, it's a
        # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            initial_checklist = checklist_target.new_zeros(checklist_target.size())
            initial_checklist_states.append(
                ChecklistStatelet(terminal_actions=terminal_actions,
                                  checklist_target=checklist_target,
                                  checklist_mask=checklist_mask,
                                  checklist=initial_checklist))
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=label_strings,
            checklist_state=initial_checklist_states)

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(initial_state,  # type: ignore
                                               self._decoder_step,
                                               partial(self._get_state_cost, worlds))
        if identifier is not None:
            outputs['identifier'] = identifier
        best_final_states = outputs['best_final_states']
        best_action_sequences = {}
        for batch_index, states in best_final_states.items():
            best_action_sequences[batch_index] = [state.action_history[0] for state in states]
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings,
                                 possible_actions=actions,
                                 agenda_data=agenda_data)
        else:
            # We're testing.
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs

    def _get_checklist_info(self,
                            agenda: torch.LongTensor,
                            all_actions: List[ProductionRule]
                           ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Takes an agenda and a list of all actions and returns a target checklist against which the
        checklist at each state will be compared to compute a loss, indices of
        ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal
        actions are relevant for checklist loss computation. If
        ``self.penalize_non_agenda_actions`` is set to ``True``, ``checklist_mask`` will be all 1s
        (i.e., all terminal actions are relevant). If it is set to ``False``, indices of all
        terminals that are not in the agenda will be masked.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRule]``
            All actions for one instance.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRule, a tuple where the first item is the production
            # rule string.
            if action[0] in self._terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        # We want to return the checklist target and terminal actions as column vectors, to make
        # computing the softmax over the difference between checklist and target easier.
# (num_terminals, 1) terminal_actions = agenda.new_tensor(terminal_indices) # (num_terminals, 1) target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float) if self._penalize_non_agenda_actions: # All terminal actions are relevant checklist_mask = torch.ones_like(target_checklist) else: checklist_mask = (target_checklist != 0).float() return target_checklist, terminal_actions, checklist_mask def _update_metrics(self, action_strings: List[List[List[str]]], worlds: List[List[NlvrWorld]], label_strings: List[List[str]], possible_actions: List[List[ProductionRule]], agenda_data: List[List[int]]) -> None: # TODO(pradeep): Move this to the base class. # TODO(pradeep): action_strings contains k-best lists. This method only uses the top decoded # sequence currently. Maybe define top-k metrics? batch_size = len(worlds) for i in range(batch_size): # Using only the top decoded sequence per instance. instance_action_strings = action_strings[i][0] if action_strings[ i] else [] sequence_is_correct = [False] in_agenda_ratio = 0.0 instance_possible_actions = possible_actions[i] if instance_action_strings: terminal_agenda_actions = [] for rule_id in agenda_data[i]: if rule_id == -1: continue action_string = instance_possible_actions[rule_id][0] right_side = action_string.split(" -> ")[1] if right_side.isdigit() or ('[' not in right_side and len(right_side) > 1): terminal_agenda_actions.append(action_string) actions_in_agenda = [ action in instance_action_strings for action in terminal_agenda_actions ] in_agenda_ratio = sum(actions_in_agenda) / len( actions_in_agenda) instance_label_strings = label_strings[i] instance_worlds = worlds[i] sequence_is_correct = self._check_denotation( instance_action_strings, instance_label_strings, instance_worlds) for correct_in_world in sequence_is_correct: self._denotation_accuracy(1 if correct_in_world else 0) self._consistency(1 if all(sequence_is_correct) else 0) self._agenda_coverage(in_agenda_ratio) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: return { 'denotation_accuracy': self._denotation_accuracy.get_metric(reset), 'consistency': self._consistency.get_metric(reset), 'agenda_coverage': self._agenda_coverage.get_metric(reset) } def _get_state_cost(self, batch_worlds: List[List[NlvrWorld]], state: CoverageState) -> torch.Tensor: """ Return the cost of a finished state. Since it is a finished state, the group size will be 1, and hence we'll return just one cost. The ``batch_worlds`` parameter here is because we need the world to check the denotation accuracy of the action sequence in the finished state. Instead of adding a field to the ``State`` object just for this method, we take the ``World`` as a parameter here. """ if not state.is_finished(): raise RuntimeError( "_get_state_cost() is not defined for unfinished states!") instance_worlds = batch_worlds[state.batch_indices[0]] # Our checklist cost is a sum of squared error from where we want to be, making sure we # take into account the mask. checklist_balance = state.checklist_state[0].get_balance() checklist_cost = torch.sum((checklist_balance)**2) # This is the number of items on the agenda that we want to see in the decoded sequence. # We use this as the denotation cost if the path is incorrect. # Note: If we are penalizing the model for producing non-agenda actions, this is not the # upper limit on the checklist cost. That would be the number of terminal actions. 
        denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
        checklist_cost = self._checklist_cost_weight * checklist_cost
        # TODO (pradeep): The denotation based cost below is strict. Maybe define a cost based on
        # how many worlds the logical form is correct in?
        # extras being None happens when we are testing. We do not care about the cost then.
        # TODO (pradeep): Make this cleaner.
        if state.extras is None or all(self._check_state_denotations(state, instance_worlds)):
            cost = checklist_cost
        else:
            cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
        return cost

    def _get_state_info(self, state: CoverageState,
                        batch_worlds: List[List[NlvrWorld]]) -> Dict[str, List]:
        """
        This method is here for debugging purposes, in case you want to look at what the model is
        learning. It may be inefficient to call it while training the model on real data.
        """
        if len(state.batch_indices) == 1 and state.is_finished():
            costs = [float(self._get_state_cost(batch_worlds, state).detach().cpu().numpy())]
        else:
            costs = []
        model_scores = [float(score.detach().cpu().numpy()) for score in state.score]
        all_actions = state.possible_actions[0]
        action_sequences = [[self._get_action_string(all_actions[action]) for action in history]
                            for history in state.action_history]
        agenda_sequences = []
        all_agenda_indices = []
        for checklist_state in state.checklist_state:
            agenda_indices = []
            for action, is_wanted in zip(checklist_state.terminal_actions,
                                         checklist_state.checklist_target):
                action_int = int(action.detach().cpu().numpy())
                is_wanted_int = int(is_wanted.detach().cpu().numpy())
                if is_wanted_int != 0:
                    agenda_indices.append(action_int)
            agenda_sequences.append([self._get_action_string(all_actions[action])
                                     for action in agenda_indices])
            all_agenda_indices.append(agenda_indices)
        return {"agenda": agenda_sequences,
                "agenda_indices": all_agenda_indices,
                "history": action_sequences,
                "history_indices": state.action_history,
                "costs": costs,
                "scores": model_scores}
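The state cost above mixes two signals: a squared-error checklist term and a denotation penalty that fires only when the logical form is wrong. The following is a minimal, self-contained sketch of that combination; the function name coverage_cost and all the values are illustrative, not part of the codebase.

import torch

def coverage_cost(checklist_balance: torch.Tensor,
                  checklist_target: torch.Tensor,
                  denotation_is_correct: bool,
                  checklist_cost_weight: float = 0.2) -> torch.Tensor:
    # Sum of squared error between where the checklist is and where we want it to be.
    checklist_cost = checklist_cost_weight * torch.sum(checklist_balance ** 2)
    if denotation_is_correct:
        return checklist_cost
    # Number of agenda items we wanted to see, used as the denotation penalty.
    denotation_cost = torch.sum(checklist_target.float())
    return checklist_cost + (1 - checklist_cost_weight) * denotation_cost

# E.g., with two of three agenda actions still missing and a wrong denotation:
balance = torch.tensor([1.0, 1.0, 0.0])
target = torch.tensor([1.0, 1.0, 1.0])
print(coverage_cost(balance, target, denotation_is_correct=False))  # 0.2*2 + 0.8*3 = 2.8

With the weight well below 0.5, a wrong denotation dominates the cost, which is what pushes the search toward denotationally consistent logical forms.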
class QuarelSemanticParser(Model):
    """
    A ``QuarelSemanticParser`` is a variant of ``WikiTablesSemanticParser`` with various tweaks
    and changes.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query. Passed to the transition function.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=0)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are. The ``KnowledgeGraphField`` defines ten features by default. If this is 0,
        no learned feature scores are added to the linking score.
    use_entities : ``bool``, optional (default=False)
        Whether dynamic entities are part of the action space.
    num_entity_bits : ``int``, optional (default=0)
        Number of bits added to the encoder input/output to represent tagged entities.
    entity_bits_output : ``bool``, optional (default=True)
        Whether entity bits are added to the encoder output or input.
    denotation_only : ``bool``, optional (default=False)
        Whether to only predict the target denotation, skipping the whole logical form decoder.
    entity_similarity_mode : ``str``, optional (default="dot_product")
        How to compute vector similarity between question and entity tokens; can take the values
        "dot_product" or "weighted_dot_product" (learned weights on each dimension).
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules. The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
""" def __init__(self, vocab: Vocabulary, question_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, decoder_beam_search: BeamSearch, max_decoding_steps: int, attention: Attention, mixture_feedforward: FeedForward = None, add_action_bias: bool = True, dropout: float = 0.0, num_linking_features: int = 0, num_entity_bits: int = 0, entity_bits_output: bool = True, use_entities: bool = False, denotation_only: bool = False, # Deprecated parameter to load older models entity_encoder: Seq2VecEncoder = None, # pylint: disable=unused-argument entity_similarity_mode: str = "dot_product", rule_namespace: str = 'rule_labels') -> None: super(QuarelSemanticParser, self).__init__(vocab) self._question_embedder = question_embedder self._encoder = encoder self._beam_search = decoder_beam_search self._max_decoding_steps = max_decoding_steps if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._rule_namespace = rule_namespace self._denotation_accuracy = Average() self._action_sequence_accuracy = Average() self._has_logical_form = Average() self._embedding_dim = question_embedder.get_output_dim() self._use_entities = use_entities # Note: there's only one non-trivial entity type in QuaRel for now, so most of the # entity_type stuff is irrelevant self._num_entity_types = 4 # TODO(mattg): get this in a more principled way somehow? self._num_start_types = 1 # Hardcoded until we feed lf syntax into the model self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim) self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim) self._entity_similarity_layer = None self._entity_similarity_mode = entity_similarity_mode if self._entity_similarity_mode == "weighted_dot_product": self._entity_similarity_layer = \ TimeDistributed(torch.nn.Linear(self._embedding_dim, 1, bias=False)) # Center initial values around unweighted dot product self._entity_similarity_layer._module.weight.data += 1 # pylint: disable=protected-access elif self._entity_similarity_mode == "dot_product": pass else: raise ValueError("Invalid entity_similarity_mode: {}".format(self._entity_similarity_mode)) if num_linking_features > 0: self._linking_params = torch.nn.Linear(num_linking_features, 1) else: self._linking_params = None self._decoder_trainer = MaximumMarginalLikelihood() self._encoder_output_dim = self._encoder.get_output_dim() if entity_bits_output: self._encoder_output_dim += num_entity_bits self._entity_bits_output = entity_bits_output self._debug_count = 10 self._num_denotation_cats = 2 # Hardcoded for simplicity self._denotation_only = denotation_only if self._denotation_only: self._denotation_accuracy_cat = CategoricalAccuracy() self._denotation_classifier = torch.nn.Linear(self._encoder_output_dim, self._num_denotation_cats) # Rest of init not needed for denotation only where no decoding to actions needed return self._action_padding_index = -1 # the padding value used by IndexField num_actions = vocab.get_vocab_size(self._rule_namespace) self._num_actions = num_actions self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim) # We are tying the action embeddings used for input and output # self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim) self._output_action_embedder = self._action_embedder # tied weights self._add_action_bias = add_action_bias if self._add_action_bias: self._action_biases = 
Embedding(num_embeddings=num_actions, embedding_dim=1) # This is what we pass as input in the first step of decoding, when we don't have a # previous action, or a previous question attention. self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim)) self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(self._encoder_output_dim)) torch.nn.init.normal_(self._first_action_embedding) torch.nn.init.normal_(self._first_attended_question) self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder_output_dim, action_embedding_dim=action_embedding_dim, input_attention=attention, num_start_types=self._num_start_types, predict_start_type_separately=False, add_action_bias=self._add_action_bias, mixture_feedforward=mixture_feedforward, dropout=dropout) @overrides def forward(self, # type: ignore question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[QuarelWorld], actions: List[List[ProductionRule]], entity_bits: torch.Tensor = None, denotation_target: torch.Tensor = None, target_action_sequences: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ # pylint: disable=unused-argument """ In this method we encode the table entities, link them to words in the question, then encode the question. Then we set up the initial state for the decoder, and pass that state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference, if we're not. Parameters ---------- question : Dict[str, torch.LongTensor] The output of ``TextField.as_array()`` applied on the question ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. table : ``Dict[str, torch.LongTensor]`` The output of ``KnowledgeGraphField.as_array()`` applied on the table ``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to get embeddings for each entity. world : ``List[QuarelWorld]`` We use a ``MetadataField`` to get the ``World`` for each input instance. Because of how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``, actions : ``List[List[ProductionRule]]`` A list of all possible actions for each ``World`` in the batch, indexed into a ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and use the embeddings to determine which action to take at each timestep in the decoder. target_action_sequences : torch.Tensor, optional (default=None) A list of possibly valid action sequences, where each action is an index into the list of possible actions. This tensor has shape ``(batch_size, num_action_sequences, sequence_length)``. 
""" table_text = table['text'] self._debug_count -= 1 # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types) # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table) if self._use_entities: if self._entity_similarity_mode == "dot_product": # Compute entity and question word cosine similarity. Need to add a small value to # to the table norm since there are padding values which cause a divide by 0. embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13) embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13) question_entity_similarity = torch.bmm(embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2)) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) linking_scores = question_entity_similarity_max_score elif self._entity_similarity_mode == "weighted_dot_product": embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13) embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13) eqe = embedded_question.unsqueeze(1).expand(-1, num_entities*num_entity_tokens, -1, -1) ete = embedded_table.view(batch_size, num_entities*num_entity_tokens, self._embedding_dim) ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1) product = torch.mul(eqe, ete) product = product.view(batch_size, num_question_tokens*num_entities*num_entity_tokens, self._embedding_dim) question_entity_similarity = self._entity_similarity_layer(product) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) linking_scores = question_entity_similarity_max_score # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table['linking'] if self._linking_params is not None: feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2), question_mask, entity_type_dict) encoder_input = embedded_question else: if entity_bits is not None and not self._entity_bits_output: encoder_input = torch.cat([embedded_question, entity_bits], 2) else: encoder_input = embedded_question # Fake linking_scores added for downstream code to not object linking_scores = question_mask.clone().fill_(0).unsqueeze(1) 
linking_probabilities = None # (batch_size, question_length, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask)) if self._entity_bits_output and entity_bits is not None: encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states(encoder_outputs, question_mask, self._encoder.is_bidirectional()) # For predicting a categorical denotation directly if self._denotation_only: denotation_logits = self._denotation_classifier(final_encoder_output) loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1)) self._denotation_accuracy_cat(denotation_logits, denotation_target) return {"loss": loss} memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim) _, num_entities, num_question_tokens = linking_scores.size() if target_action_sequences is not None: # Remove the trailing dimension (from ListField[ListField[IndexField]]). target_action_sequences = target_action_sequences.squeeze(-1) target_mask = target_action_sequences != self._action_padding_index else: target_mask = None # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append(RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list)) initial_grammar_state = [self._create_grammar_state(world[i], actions[i], linking_scores[i], entity_types[i]) for i in range(batch_size)] initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size) initial_score_list = [initial_score[i] for i in range(batch_size)] initial_state = GrammarBasedState(batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, possible_actions=actions, extras=None, debug_info=None) if self.training: outputs = self._decoder_trainer.decode(initial_state, self._decoder_step, (target_action_sequences, target_mask)) return outputs else: action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] outputs = {'action_mapping': action_mapping} if target_action_sequences is not None: outputs['loss'] = self._decoder_trainer.decode(initial_state, self._decoder_step, (target_action_sequences, target_mask))['loss'] num_steps = self._max_decoding_steps # This tells the state to start keeping track of debug info, which we'll pass along in # our output dictionary. 
initial_state.debug_info = [[] for _ in range(batch_size)] best_final_states = self._beam_search.search(num_steps, initial_state, self._decoder_step, keep_final_unfinished_states=False) outputs['best_action_sequence'] = [] outputs['debug_info'] = [] outputs['entities'] = [] if self._linking_params is not None: outputs['linking_scores'] = linking_scores outputs['feature_scores'] = feature_scores outputs['linking_features'] = linking_features if self._use_entities: outputs['linking_probabilities'] = linking_probabilities if entity_bits is not None: outputs['entity_bits'] = entity_bits # outputs['similarity_scores'] = question_entity_similarity_max_score outputs['logical_form'] = [] outputs['denotation_acc'] = [] outputs['score'] = [] outputs['parse_acc'] = [] outputs['answer_index'] = [] if metadata is not None: outputs['question_tokens'] = [] outputs['world_extractions'] = [] for i in range(batch_size): if metadata is not None: outputs['question_tokens'].append(metadata[i].get('question_tokens', [])) if metadata is not None: outputs['world_extractions'].append(metadata[i].get('world_extractions', {})) outputs['entities'].append(world[i].table_graph.entities) # Decoding may not have terminated with any completed logical forms, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). if i in best_final_states: best_action_indices = best_final_states[i][0].action_history[0] sequence_in_targets = 0 if target_action_sequences is not None: targets = target_action_sequences[i].data sequence_in_targets = self._action_history_match(best_action_indices, targets) self._action_sequence_accuracy(sequence_in_targets) action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices] try: self._has_logical_form(1.0) logical_form = world[i].get_logical_form(action_strings, add_var_function=False) except ParsingError: self._has_logical_form(0.0) logical_form = 'Error producing logical form' denotation_accuracy = 0.0 predicted_answer_index = world[i].execute(logical_form) if metadata is not None and 'answer_index' in metadata[i]: answer_index = metadata[i]['answer_index'] denotation_accuracy = self._denotation_match(predicted_answer_index, answer_index) self._denotation_accuracy(denotation_accuracy) score = math.exp(best_final_states[i][0].score[0].data.cpu().item()) outputs['answer_index'].append(predicted_answer_index) outputs['score'].append(score) outputs['parse_acc'].append(sequence_in_targets) outputs['best_action_sequence'].append(action_strings) outputs['logical_form'].append(logical_form) outputs['denotation_acc'].append(denotation_accuracy) outputs['debug_info'].append(best_final_states[i][0].debug_info[0]) # type: ignore else: outputs['parse_acc'].append(0) outputs['logical_form'].append('') outputs['denotation_acc'].append(0) outputs['score'].append(0) outputs['answer_index'].append(-1) outputs['best_action_sequence'].append([]) outputs['debug_info'].append([]) self._has_logical_form(0.0) return outputs @staticmethod def _get_type_vector(worlds: List[QuarelWorld], num_entities: int, tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]: """ Produces a tensor with shape ``(batch_size, num_entities)`` that encodes each entity's type. In addition, a map from a flattened entity index to type is returned to combine entity type operations into one method. 
Parameters ---------- worlds : ``List[WikiTablesWorld]`` num_entities : ``int`` tensor : ``torch.Tensor`` Used for copying the constructed list onto the right device. Returns ------- A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``. entity_types : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. """ entity_types = {} batch_types = [] for batch_index, world in enumerate(worlds): types = [] for entity_index, entity in enumerate(world.table_graph.entities): # We need numbers to be first, then cells, then parts, then row, because our # entities are going to be sorted. We do a split by type and then a merge later, # and it relies on this sorting. if entity.startswith('fb:cell'): entity_type = 1 elif entity.startswith('fb:part'): entity_type = 2 elif entity.startswith('fb:row'): entity_type = 3 else: entity_type = 0 types.append(entity_type) # For easier lookups later, we're actually using a _flattened_ version # of (batch_index, entity_index) for the key, because this is how the # linking scores are stored. flattened_entity_index = batch_index * num_entities + entity_index entity_types[flattened_entity_index] = entity_type padded = pad_sequence_to_length(types, num_entities, lambda: 0) batch_types.append(padded) return tensor.new_tensor(batch_types, dtype=torch.long), entity_types def _get_linking_probabilities(self, worlds: List[QuarelWorld], linking_scores: torch.FloatTensor, question_mask: torch.LongTensor, entity_type_dict: Dict[int, int]) -> torch.FloatTensor: """ Produces the probability of an entity given a question word and type. The logic below separates the entities by type since the softmax normalization term sums over entities of a single type. Parameters ---------- worlds : ``List[QuarelWorld]`` linking_scores : ``torch.FloatTensor`` Has shape (batch_size, num_question_tokens, num_entities). question_mask: ``torch.LongTensor`` Has shape (batch_size, num_question_tokens). entity_type_dict : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. Returns ------- batch_probabilities : ``torch.FloatTensor`` Has shape ``(batch_size, num_question_tokens, num_entities)``. Contains all the probabilities for an entity given a question word. """ _, num_question_tokens, num_entities = linking_scores.size() batch_probabilities = [] for batch_index, world in enumerate(worlds): all_probabilities = [] num_entities_in_instance = 0 # NOTE: The way that we're doing this here relies on the fact that entities are # implicitly sorted by their types when we sort them by name, and that numbers come # before "fb:cell", and "fb:cell" comes before "fb:row". This is not a great # assumption, and could easily break later, but it should work for now. for type_index in range(self._num_entity_types): # This index of 0 is for the null entity for each type, representing the case where a # word doesn't link to any entity. entity_indices = [0] entities = world.table_graph.entities for entity_index, _ in enumerate(entities): if entity_type_dict[batch_index * num_entities + entity_index] == type_index: entity_indices.append(entity_index) if len(entity_indices) == 1: # No entities of this type; move along... continue # We're subtracting one here because of the null entity we added above. num_entities_in_instance += len(entity_indices) - 1 # We separate the scores by type, since normalization is done per type. There's an # extra "null" entity per type, also, so we have `num_entities_per_type + 1`. 
We're # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_, # so we get back something of shape (num_question_tokens,) for each index we're # selecting. All of the selected indices together then make a tensor of shape # (num_question_tokens, num_entities_per_type + 1). indices = linking_scores.new_tensor(entity_indices, dtype=torch.long) entity_scores = linking_scores[batch_index].index_select(1, indices) # We used index 0 for the null entity, so this will actually have some values in it. # But we want the null entity's score to be 0, so we set that here. entity_scores[:, 0] = 0 # No need for a mask here, as this is done per batch instance, with no padding. type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1) all_probabilities.append(type_probabilities[:, 1:]) # We need to add padding here if we don't have the right number of entities. if num_entities_in_instance != num_entities: zeros = linking_scores.new_zeros(num_question_tokens, num_entities - num_entities_in_instance) all_probabilities.append(zeros) # (num_question_tokens, num_entities) probabilities = torch.cat(all_probabilities, dim=1) batch_probabilities.append(probabilities) batch_probabilities = torch.stack(batch_probabilities, dim=0) return batch_probabilities * question_mask.unsqueeze(-1).float() @staticmethod def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int: # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something. # Check if target is big enough to cover prediction (including start/end symbols) if len(predicted) > targets.size(1): return 0 predicted_tensor = targets.new_tensor(predicted) targets_trimmed = targets[:, :len(predicted)] # Return 1 if the predicted sequence is anywhere in the list of targets. return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item() def _denotation_match(self, predicted_answer_index: int, target_answer_index: int) -> float: if predicted_answer_index < 0: # Logical form doesn't properly resolve, we do random guess with appropriate credit return 1.0/self._num_denotation_cats elif predicted_answer_index == target_answer_index: return 1.0 return 0.0 @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: """ We track three metrics here: 1. parse_acc, which is the percentage of the time that our best output action sequence corresponds to a correct logical form 2. denotation_acc, which is the percentage of examples where we get the correct denotation, including spurious correct answers using the wrong logical form 3. lf_percent, which is the percentage of time that decoding actually produces a finished logical form. We might not produce a valid logical form if the decoder gets into a repetitive loop, or we're trying to produce a super long logical form and run out of time steps, or something. """ if self._denotation_only: metrics = {'denotation_acc': self._denotation_accuracy_cat.get_metric(reset)} else: metrics = { 'parse_acc': self._action_sequence_accuracy.get_metric(reset), 'denotation_acc': self._denotation_accuracy.get_metric(reset), 'lf_percent': self._has_logical_form.get_metric(reset), } return metrics def _create_grammar_state(self, world: QuarelWorld, possible_actions: List[ProductionRule], linking_scores: torch.Tensor, entity_types: torch.Tensor) -> GrammarStatelet: """ This method creates the GrammarStatelet object that's used for decoding. 
Part of creating that is creating the `valid_actions` dictionary, which contains embedded representations of all of the valid actions. So, we create that here as well. The inputs to this method are for a `single instance in the batch`; none of the tensors we create here are batched. We grab the global action ids from the input ``ProductionRules``, and we use those to embed the valid actions for every non-terminal type. We use the input ``linking_scores`` for non-global actions. Parameters ---------- world : ``QuarelWorld`` From the input to ``forward`` for a single batch instance. possible_actions : ``List[ProductionRule]`` From the input to ``forward`` for a single batch instance. linking_scores : ``torch.Tensor`` Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch dimension). entity_types : ``torch.Tensor`` Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension). """ action_map = {} for action_index, action in enumerate(possible_actions): action_string = action[0] action_map[action_string] = action_index entity_map = {} for entity_index, entity in enumerate(world.table_graph.entities): entity_map[entity] = entity_index valid_actions = world.get_valid_actions() translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {} for key, action_strings in valid_actions.items(): translated_valid_actions[key] = {} # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid # productions of that non-terminal. We'll first split those productions by global vs. # linked action. action_indices = [action_map[action_string] for action_string in action_strings] production_rule_arrays = [(possible_actions[index], index) for index in action_indices] global_actions = [] linked_actions = [] for production_rule_array, action_index in production_rule_arrays: if production_rule_array[1]: global_actions.append((production_rule_array[2], action_index)) else: linked_actions.append((production_rule_array[0], action_index)) # Then we get the embedded representations of the global actions. global_action_tensors, global_action_ids = zip(*global_actions) global_action_tensor = torch.cat(global_action_tensors, dim=0) global_input_embeddings = self._action_embedder(global_action_tensor) if self._add_action_bias: global_action_biases = self._action_biases(global_action_tensor) global_input_embeddings = torch.cat([global_input_embeddings, global_action_biases], dim=-1) global_output_embeddings = self._output_action_embedder(global_action_tensor) translated_valid_actions[key]['global'] = (global_input_embeddings, global_output_embeddings, list(global_action_ids)) # Then the representations of the linked actions. 
if linked_actions: linked_rules, linked_action_ids = zip(*linked_actions) entities = [rule.split(' -> ')[1] for rule in linked_rules] entity_ids = [entity_map[entity] for entity in entities] # (num_linked_actions, num_question_tokens) entity_linking_scores = linking_scores[entity_ids] # (num_linked_actions,) entity_type_tensor = entity_types[entity_ids] # (num_linked_actions, entity_type_embedding_dim) entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor) translated_valid_actions[key]['linked'] = (entity_linking_scores, entity_type_embeddings, list(linked_action_ids)) return GrammarStatelet([START_SYMBOL], translated_valid_actions, type_declaration.is_nonterminal) @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a separate notion from the "decoder" in "encoder/decoder", where that decoder logic lives in ``FrictionQDecoderStep``. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. """ action_mapping = output_dict['action_mapping'] best_actions = output_dict["best_action_sequence"] debug_infos = output_dict['debug_info'] batch_action_info = [] for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)): instance_action_info = [] for predicted_action, action_debug_info in zip(predicted_actions, debug_info): action_info = {} action_info['predicted_action'] = predicted_action considered_actions = action_debug_info['considered_actions'] probabilities = action_debug_info['probabilities'] actions = [] for action, probability in zip(considered_actions, probabilities): if action != -1: actions.append((action_mapping[(batch_index, action)], probability)) actions.sort() considered_actions, probabilities = zip(*actions) action_info['considered_actions'] = considered_actions action_info['action_probabilities'] = probabilities action_info['question_attention'] = action_debug_info.get('question_attention', []) instance_action_info.append(action_info) batch_action_info.append(instance_action_info) output_dict["predicted_actions"] = batch_action_info return output_dict
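Two of the helpers in this class are easy to sanity-check in isolation. The snippet below uses toy tensors and hypothetical action indices, not a real model run; the all/any formulation is equivalent to the min/max trick inside ``_action_history_match``, and the second part mirrors the partial-credit rule of ``_denotation_match``.

import torch

targets = torch.tensor([[3, 7, 2, 0],
                        [3, 7, 9, 0]])   # (num_targets, sequence_length), zero-padded
predicted = [3, 7, 2]
predicted_tensor = targets.new_tensor(predicted)
# 1 if the predicted sequence prefix-matches any gold sequence, else 0.
match = int((targets[:, :len(predicted)] == predicted_tensor).all(dim=1).any().item())
print(match)  # 1: the first gold sequence starts with [3, 7, 2]

# When the logical form fails to execute (predicted answer index < 0), the model hands
# out 1/num_denotation_cats credit: the expected score of a random guess over answers.
num_denotation_cats = 2  # hardcoded in the model above
print(1.0 / num_denotation_cats)  # 0.5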
class RedditTask(RankingTask): ''' Task class for Reddit data. ''' def __init__(self, path, max_seq_len, name, **kw): ''' ''' super().__init__(name, **kw) self.scorer1 = Average() # CategoricalAccuracy() self.scorer2 = None self.val_metric = "%s_accuracy" % self.name self.val_metric_decreases = False self.files_by_split = {split: os.path.join(path, "%s.csv" % split) for split in ["train", "val", "test"]} self.max_seq_len = max_seq_len def get_split_text(self, split: str): ''' Get split text as iterable of records. Split should be one of 'train', 'val', or 'test'. ''' return self.load_data(self.files_by_split[split]) def load_data(self, path): ''' Load data ''' with open(path, 'r') as txt_fh: for row in txt_fh: row = row.strip().split('\t') if len(row) < 4 or not row[2] or not row[3]: continue sent1 = process_sentence(row[2], self.max_seq_len) sent2 = process_sentence(row[3], self.max_seq_len) targ = 1 yield (sent1, sent2, targ) def get_sentences(self) -> Iterable[Sequence[str]]: ''' Yield sentences, used to compute vocabulary. ''' for split in self.files_by_split: # Don't use test set for vocab building. if split.startswith("test"): continue path = self.files_by_split[split] for sent1, sent2, _ in self.load_data(path): yield sent1 yield sent2 def count_examples(self): ''' Compute here b/c we're streaming the sentences. ''' example_counts = {} for split, split_path in self.files_by_split.items(): example_counts[split] = sum(1 for line in open(split_path)) self.example_counts = example_counts def process_split(self, split, indexers) -> Iterable[Type[Instance]]: ''' Process split text into a list of AllenNLP Instances. ''' def _make_instance(input1, input2, labels): d = {} d["input1"] = sentence_to_text_field(input1, indexers) #d['sent1_str'] = MetadataField(" ".join(input1[1:-1])) d["input2"] = sentence_to_text_field(input2, indexers) #d['sent2_str'] = MetadataField(" ".join(input2[1:-1])) d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True) return Instance(d) for sent1, sent2, trg in split: yield _make_instance(sent1, sent2, trg) def get_metrics(self, reset=False): '''Get metrics specific to the task''' acc = self.scorer1.get_metric(reset) return {'accuracy': acc}
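For context, ``load_data`` above expects tab-separated rows whose third and fourth columns hold the sentence pair; every kept pair gets the constant target 1, since all pairs are positives for ranking. A tiny sketch with made-up rows shows which lines survive the filter.

import io

sample = ("id1\tmeta\tfirst comment text\treply text\n"
          "id2\tmeta\t\treply with missing first sentence\n")
for row in (line.strip().split('\t') for line in io.StringIO(sample)):
    if len(row) < 4 or not row[2] or not row[3]:
        continue  # mirror load_data: skip short rows and empty sentences
    print(row[2], '||', row[3])  # only the first row prints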
class BCECentreDistanceBoxModel(BCEBoxModel):
    def __init__(self,
                 num_entities: int,
                 num_relations: int,
                 embedding_dim: int,
                 box_type: str = 'SigmoidBoxTensor',
                 single_box: bool = False,
                 softbox_temp: float = 10.,
                 margin: float = 1.,
                 number_of_negative_samples: int = 0,
                 debug: bool = False,
                 regularization_weight: float = 0,
                 init_interval_center: float = 0.25,
                 init_interval_delta: float = 0.1) -> None:
        super().__init__(num_entities,
                         num_relations,
                         embedding_dim,
                         box_type=box_type,
                         single_box=single_box,
                         softbox_temp=softbox_temp,
                         number_of_negative_samples=number_of_negative_samples,
                         debug=debug,
                         regularization_weight=regularization_weight,
                         init_interval_center=init_interval_center,
                         init_interval_delta=init_interval_delta)
        self.number_of_negative_samples = number_of_negative_samples
        self.centre_loss_metric = Average()
        self.margin = margin
        self.loss_f_centre: torch.nn.modules.loss._Loss = torch.nn.MarginRankingLoss(  # type: ignore
            margin=margin, reduction='mean')

    def get_scores(self, embeddings: Dict) -> torch.Tensor:
        p = self._get_triple_score(embeddings['h'], embeddings['t'],
                                   embeddings['r'])
        diff = (embeddings['h'].centre - embeddings['t'].centre).norm(p=1, dim=-1)
        pos_diff = torch.mean(diff[torch.where(embeddings['label'] == 1)]).view(1)
        neg_diff = torch.mean(diff[torch.where(embeddings['label'] == 0)]).view(1)
        # The means are NaN when the batch contains no positive (or no negative)
        # examples. Fall back to neutral values on the same device as `diff`; the
        # previous torch.Tensor([...]) constructors always lived on the CPU and
        # referenced `margin`, which is not in scope in this method.
        if torch.isnan(pos_diff):
            pos_diff = diff.new_zeros(1)
        if torch.isnan(neg_diff):
            neg_diff = diff.new_full((1,), self.margin)
        return p, pos_diff, neg_diff

    def get_loss(self, scores: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        log_p = scores[0]
        log1mp = log1mexp(log_p)
        logits = torch.stack([log1mp, log_p], dim=-1)
        # A target of 1 asks the ranking loss to keep the mean negative centre
        # distance at least `margin` above the mean positive centre distance.
        centre_loss = self.loss_f_centre(scores[2], scores[1], scores[1].new_ones(1))
        self.centre_loss_metric(centre_loss.item())
        loss = self.loss_f(logits, label) + self.regularization_weight * centre_loss
        return loss

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            'hr_rank': self.head_replacement_rank_avg.get_metric(reset),
            'tr_rank': self.tail_replacement_rank_avg.get_metric(reset),
            'avg_rank': self.avg_rank.get_metric(reset),
            'hitsat10': self.hitsat10.get_metric(reset),
            'hr_mrr': self.head_replacement_mrr.get_metric(reset),
            'tr_mrr': self.tail_replacement_mrr.get_metric(reset),
            'int_volume_train': self.int_volume_train.get_metric(reset),
            'int_volume_dev': self.int_volume_dev.get_metric(reset),
            'regularization_loss': self.regularization_loss.get_metric(reset),
            'hr_hitsat1': self.head_hitsat1.get_metric(reset),
            'tr_hitsat1': self.tail_hitsat1.get_metric(reset),
            'hr_hitsat3': self.head_hitsat3.get_metric(reset),
            'tr_hitsat3': self.tail_hitsat3.get_metric(reset),
            'mrr': self.mrr.get_metric(reset),
            'centre_loss': self.centre_loss_metric.get_metric(reset)
        }
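To see what the centre-distance term optimizes: torch.nn.MarginRankingLoss(x1, x2, y) computes mean(max(0, -y * (x1 - x2) + margin)), so with target y = 1 it pushes the mean negative-pair distance above the mean positive-pair distance by at least the margin. A quick check with made-up distances (the numbers are hypothetical, not model output):

import torch

loss_f = torch.nn.MarginRankingLoss(margin=1.0, reduction='mean')
pos_diff = torch.tensor([0.3])  # mean L1 centre distance over positive triples
neg_diff = torch.tensor([0.9])  # mean L1 centre distance over negative triples
# The gap is 0.6, so the loss is max(0, 1.0 - 0.6) = 0.4; it reaches 0 only once
# negatives are at least `margin` farther apart than positives.
print(loss_f(neg_diff, pos_diff, torch.ones(1)))  # tensor(0.4000)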
class WikiTablesErmSemanticParser(WikiTablesSemanticParser):
    """
    A ``WikiTablesErmSemanticParser`` is a :class:`WikiTablesSemanticParser` that learns to
    search for logical forms that yield the correct denotations.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions. Passed to super class.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings. Passed to super class.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question. Passed to super class.
    entity_encoder : ``Seq2VecEncoder``
        The encoder to use for averaging the words of an entity. Passed to super class.
    input_attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query. Passed to WikiTablesDecoderStep.
    decoder_beam_size : ``int``
        Beam size to be used by the ExpectedRiskMinimization algorithm.
    decoder_num_finished_states : ``int``
        Number of finished states for which costs will be computed by the
        ExpectedRiskMinimization algorithm.
    max_decoding_steps : ``int``
        Maximum number of steps the decoder should take before giving up. Used both during
        training and evaluation. Passed to super class.
    normalize_beam_score_by_length : ``bool``, optional (default=False)
        Should we normalize the log-probabilities by length before renormalizing the beam? This
        was shown to work better for NMT by Edunov et al., but that may not be the case for
        semantic parsing.
    checklist_cost_weight : ``float``, optional (default=0.6)
        Mixture weight (0-1) for combining coverage cost and denotation cost. As this increases,
        we weigh the coverage cost higher, with a value of 1.0 meaning that we do not care about
        denotation accuracy.
    use_neighbor_similarity_for_linking : ``bool``, optional (default=False)
        If ``True``, we will compute a max similarity between a question token and the
        `neighbors` of an entity as a component of the linking scores. This is meant to capture
        the same kind of information as the ``related_column`` feature. Passed to super class.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders
        (pytorch LSTMs do not apply dropout to their last layer). Passed to super class.
    num_linking_features : ``int``, optional (default=10)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are. The default of 10 here matches the default in the
        ``KnowledgeGraphField``, which is to use all ten defined features. If this is 0, another
        term will be added to the linking score. This term contains the maximum similarity value
        from the entity's neighbors and the question. Passed to super class.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules. The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this. Passed to
        super class.
    tables_directory : ``str``, optional (default=/wikitables/)
        The directory to find tables when evaluating logical forms. We rely on a call to SEMPRE
        to evaluate logical forms, and SEMPRE needs to read the table from disk itself. This
        tells SEMPRE where to find the tables. Passed to super class.
    initial_mml_model_file : ``str``, optional (default=None)
        If you want to initialize this model using weights from another model trained using MML,
        pass the path to the ``model.tar.gz`` file of that model here.
""" def __init__(self, vocab: Vocabulary, question_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, entity_encoder: Seq2VecEncoder, mixture_feedforward: FeedForward, input_attention: Attention, decoder_beam_size: int, decoder_num_finished_states: int, max_decoding_steps: int, normalize_beam_score_by_length: bool = False, checklist_cost_weight: float = 0.6, use_neighbor_similarity_for_linking: bool = False, dropout: float = 0.0, num_linking_features: int = 10, rule_namespace: str = 'rule_labels', tables_directory: str = '/wikitables/', initial_mml_model_file: str = None) -> None: use_similarity = use_neighbor_similarity_for_linking super().__init__(vocab=vocab, question_embedder=question_embedder, action_embedding_dim=action_embedding_dim, encoder=encoder, entity_encoder=entity_encoder, max_decoding_steps=max_decoding_steps, use_neighbor_similarity_for_linking=use_similarity, dropout=dropout, num_linking_features=num_linking_features, rule_namespace=rule_namespace, tables_directory=tables_directory) # Not sure why mypy needs a type annotation for this! self._decoder_trainer: ExpectedRiskMinimization = \ ExpectedRiskMinimization(beam_size=decoder_beam_size, normalize_by_length=normalize_beam_score_by_length, max_decoding_steps=self._max_decoding_steps, max_num_finished_states=decoder_num_finished_states) unlinked_terminals_global_indices = [] global_vocab = self.vocab.get_token_to_index_vocabulary(rule_namespace) for production, index in global_vocab.items(): right_side = production.split(" -> ")[1] if right_side in types.COMMON_NAME_MAPPING: # This is a terminal production. unlinked_terminals_global_indices.append(index) self._num_unlinked_terminals = len(unlinked_terminals_global_indices) self._decoder_step = WikiTablesDecoderStep( encoder_output_dim=self._encoder.get_output_dim(), action_embedding_dim=action_embedding_dim, input_attention=input_attention, num_start_types=self._num_start_types, num_entity_types=self._num_entity_types, mixture_feedforward=mixture_feedforward, dropout=dropout, unlinked_terminal_indices=unlinked_terminals_global_indices) self._checklist_cost_weight = checklist_cost_weight self._agenda_coverage = Average() # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've # copied a trained ERM model from a different machine and the original MML model that was # used to initialize it does not exist on the current machine. This may not be the best # solution for the problem. if initial_mml_model_file is not None: if os.path.isfile(initial_mml_model_file): archive = load_archive(initial_mml_model_file) self._initialize_weights_from_archive(archive) else: # A model file is passed, but it does not exist. This is expected to happen when # you're using a trained ERM model to decode. But it may also happen if the path to # the file is really just incorrect. So throwing a warning. logger.warning( "MML model file for initializing weights is passed, but does not exist." 
" This is fine if you're just decoding.") def _initialize_weights_from_archive(self, archive: Archive) -> None: logger.info("Initializing weights from MML model.") model_parameters = dict(self.named_parameters()) archived_parameters = dict(archive.model.named_parameters()) question_embedder_weight = "_question_embedder.token_embedder_tokens.weight" if question_embedder_weight not in archived_parameters or \ question_embedder_weight not in model_parameters: raise RuntimeError( "When initializing model weights from an MML model, we need " "the question embedder to be a TokenEmbedder using namespace called " "tokens.") for name, weights in archived_parameters.items(): if name in model_parameters: if name == question_embedder_weight: # The shapes of embedding weights will most likely differ between the two models # because the vocabularies will most likely be different. We will get a mapping # of indices from this model's token indices to the archived model's and copy # the tensor accordingly. vocab_index_mapping = self._get_vocab_index_mapping( archive.model.vocab) archived_embedding_weights = weights.data new_weights = model_parameters[name].data.clone() for index, archived_index in vocab_index_mapping: new_weights[index] = archived_embedding_weights[ archived_index] logger.info("Copied embeddings of %d out of %d tokens", len(vocab_index_mapping), new_weights.size()[0]) else: new_weights = weights.data logger.info("Copying parameter %s", name) model_parameters[name].data.copy_(new_weights) def _get_vocab_index_mapping( self, archived_vocab: Vocabulary) -> List[Tuple[int, int]]: vocab_index_mapping: List[Tuple[int, int]] = [] for index in range(self.vocab.get_vocab_size(namespace='tokens')): token = self.vocab.get_token_from_index(index=index, namespace='tokens') archived_token_index = archived_vocab.get_token_index( token, namespace='tokens') # Checking if we got the UNK token index, because we don't want all new token # representations initialized to UNK token's representation. We do that by checking if # the two tokens are the same. They will not be if the token at the archived index is # UNK. if archived_vocab.get_token_from_index( archived_token_index, namespace="tokens") == token: vocab_index_mapping.append((index, archived_token_index)) return vocab_index_mapping @overrides def forward( self, # type: ignore question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[WikiTablesWorld], actions: List[List[ProductionRuleArray]], agenda: torch.LongTensor, example_lisp_string: List[str]) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] The output of ``TextField.as_array()`` applied on the question ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. table : ``Dict[str, torch.LongTensor]`` The output of ``KnowledgeGraphField.as_array()`` applied on the table ``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to get embeddings for each entity. world : ``List[WikiTablesWorld]`` We use a ``MetadataField`` to get the ``World`` for each input instance. Because of how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``, actions : ``List[List[ProductionRuleArray]]`` A list of all possible actions for each ``World`` in the batch, indexed into a ``ProductionRuleArray`` using a ``ProductionRuleField``. 
We will embed all of these and use the embeddings to determine which action to take at each timestep in the decoder. example_lisp_string : ``List[str]`` The example (lisp-formatted) string corresponding to the given input. This comes directly from the ``.examples`` file provided with the dataset. We pass this to SEMPRE when evaluating denotation accuracy; it is otherwise unused. """ batch_size = list(question.values())[0].size(0) # Each instance's agenda is of size (agenda_size, 1) agenda_list = [agenda[i] for i in range(batch_size)] checklist_states = [] all_terminal_productions = [ set(instance_world.terminal_productions.values()) for instance_world in world ] max_num_terminals = max( [len(terminals) for terminals in all_terminal_productions]) for instance_actions, instance_agenda, terminal_productions in zip( actions, agenda_list, all_terminal_productions): checklist_info = self._get_checklist_info(instance_agenda, instance_actions, terminal_productions, max_num_terminals) checklist_target, terminal_actions, checklist_mask = checklist_info initial_checklist = checklist_target.new_zeros( checklist_target.size()) checklist_states.append( ChecklistState(terminal_actions=terminal_actions, checklist_target=checklist_target, checklist_mask=checklist_mask, checklist=initial_checklist)) initial_info = self._get_initial_state_and_scores( question=question, table=table, world=world, actions=actions, example_lisp_string=example_lisp_string, add_world_to_initial_state=True, checklist_states=checklist_states) initial_state = initial_info["initial_state"] # TODO(pradeep): Keep track of debug info. It's not straightforward currently because the # ERM's decode does not return the best states. outputs = self._decoder_trainer.decode(initial_state, self._decoder_step, self._get_state_cost) if not self.training: # TODO(pradeep): Can move most of this block to super class. linking_scores = initial_info["linking_scores"] feature_scores = initial_info["feature_scores"] similarity_scores = initial_info["similarity_scores"] batch_size = list(question.values())[0].size(0) action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] outputs['action_mapping'] = action_mapping outputs['entities'] = [] outputs['linking_scores'] = linking_scores if feature_scores is not None: outputs['feature_scores'] = feature_scores outputs['similarity_scores'] = similarity_scores outputs['logical_form'] = [] best_action_sequences = outputs['best_action_sequences'] outputs["best_action_sequence"] = [] outputs['debug_info'] = [] agenda_indices = [actions_[:, 0].cpu().data for actions_ in agenda] for i in range(batch_size): in_agenda_ratio = 0.0 # Decoding may not have terminated with any completed logical forms, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). outputs['logical_form'].append([]) if i in best_action_sequences: for j, action_sequence in enumerate( best_action_sequences[i]): action_strings = [ action_mapping[(i, action_index)] for action_index in action_sequence ] try: logical_form = world[i].get_logical_form( action_strings, add_var_function=False) outputs['logical_form'][-1].append(logical_form) except ParsingError: logical_form = "Error producing logical form" if j == 0: # Updating denotation accuracy and has_logical_form only based on the # first logical form. 
                            if logical_form.startswith("Error"):
                                self._has_logical_form(0.0)
                            else:
                                self._has_logical_form(1.0)
                            if example_lisp_string:
                                self._denotation_accuracy(logical_form, example_lisp_string[i])
                            outputs['best_action_sequence'].append(action_strings)
                    outputs['entities'].append(world[i].table_graph.entities)
                    instance_possible_actions = actions[i]
                    agenda_actions = []
                    for rule_id in agenda_indices[i]:
                        rule_id = int(rule_id)
                        if rule_id == -1:
                            continue
                        action_string = instance_possible_actions[rule_id][0]
                        agenda_actions.append(action_string)
                    actions_in_agenda = [action in action_strings for action in agenda_actions]
                    if actions_in_agenda:
                        # Note: This means that when there are no actions on agenda, agenda
                        # coverage will be 0, not 1.
                        in_agenda_ratio = sum(actions_in_agenda) / len(actions_in_agenda)
                else:
                    outputs['best_action_sequence'].append([])
                    outputs['logical_form'][-1].append('')
                    self._has_logical_form(0.0)
                    if example_lisp_string:
                        self._denotation_accuracy(None, example_lisp_string[i])
                self._agenda_coverage(in_agenda_ratio)
        return outputs

    @staticmethod
    def _get_checklist_info(agenda: torch.LongTensor,
                            all_actions: List[ProductionRuleArray],
                            terminal_productions: Set[str],
                            max_num_terminals: int) -> Tuple[torch.Tensor,
                                                             torch.Tensor,
                                                             torch.Tensor]:
        """
        Takes an agenda, a list of all actions, a set of terminal productions in the corresponding
        world, and a length to pad the checklist vectors to, and returns a target checklist
        against which the checklist at each state will be compared to compute a loss, indices of
        ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal
        actions are relevant for checklist loss computation.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRuleArray]``
            All actions for one instance.
        ``terminal_productions`` : ``Set[str]``
            String representations of terminal productions in the corresponding world.
        ``max_num_terminals`` : ``int``
            Length to which the checklist vectors will be padded. This is the maximum number of
            terminal productions in all the worlds in the batch.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRuleArray, a tuple where the first item is the production
            # rule string.
            if action[0] in terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        while len(target_checklist_list) < max_num_terminals:
            target_checklist_list.append([0])
            terminal_indices.append([-1])
        # (max_num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (max_num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
        checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask

    def _get_state_cost(self, state: WikiTablesDecoderState) -> torch.Tensor:
        if not state.is_finished():
            raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
        # Our checklist cost is a sum of squared error from where we want to be, making sure we
        # take into account the mask.
        # We clamp the lower limit of the balance at 0 to avoid penalizing agenda actions
        # produced multiple times.
        checklist_balance = torch.clamp(state.checklist_state[0].get_balance(), min=0.0)
        checklist_cost = torch.sum((checklist_balance) ** 2)
        # This is the number of items on the agenda that we want to see in the decoded sequence.
        # We use this as the denotation cost if the path is incorrect.
        denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
        checklist_cost = self._checklist_cost_weight * checklist_cost
        action_history = state.action_history[0]
        batch_index = state.batch_indices[0]
        action_strings = [state.possible_actions[batch_index][i][0] for i in action_history]
        logical_form = state.world[batch_index].get_logical_form(action_strings)
        lisp_string = state.example_lisp_string[batch_index]
        if self._denotation_accuracy.evaluate_logical_form(logical_form, lisp_string):
            cost = checklist_cost
        else:
            cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
        return cost

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        The base class returns a dict with dpd accuracy, denotation accuracy, and logical form
        percentage metrics. We add the agenda coverage metric here.
        """
        metrics = super().get_metrics(reset)
        metrics["agenda_coverage"] = self._agenda_coverage.get_metric(reset)
        return metrics

    @classmethod
    def from_params(cls, vocab, params: Params) -> 'WikiTablesErmSemanticParser':
        question_embedder = TextFieldEmbedder.from_params(vocab, params.pop("question_embedder"))
        action_embedding_dim = params.pop_int("action_embedding_dim")
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
        entity_encoder = Seq2VecEncoder.from_params(params.pop('entity_encoder'))
        mixture_feedforward_type = params.pop('mixture_feedforward', None)
        if mixture_feedforward_type is not None:
            mixture_feedforward = FeedForward.from_params(mixture_feedforward_type)
        else:
            mixture_feedforward = None
        input_attention = Attention.from_params(params.pop("attention"))
        decoder_beam_size = params.pop_int("decoder_beam_size")
        decoder_num_finished_states = params.pop_int("decoder_num_finished_states", None)
        max_decoding_steps = params.pop_int("max_decoding_steps")
        normalize_beam_score_by_length = params.pop("normalize_beam_score_by_length", False)
        use_neighbor_similarity_for_linking = params.pop_bool("use_neighbor_similarity_for_linking", False)
        dropout = params.pop_float('dropout', 0.0)
        num_linking_features = params.pop_int('num_linking_features', 10)
        tables_directory = params.pop('tables_directory', '/wikitables/')
        rule_namespace = params.pop('rule_namespace', 'rule_labels')
        checklist_cost_weight = params.pop_float("checklist_cost_weight", 0.6)
        mml_model_file = params.pop('mml_model_file', None)
        params.assert_empty(cls.__name__)
        return cls(vocab,
                   question_embedder=question_embedder,
                   action_embedding_dim=action_embedding_dim,
                   encoder=encoder,
                   entity_encoder=entity_encoder,
                   mixture_feedforward=mixture_feedforward,
                   input_attention=input_attention,
                   decoder_beam_size=decoder_beam_size,
                   decoder_num_finished_states=decoder_num_finished_states,
                   max_decoding_steps=max_decoding_steps,
                   normalize_beam_score_by_length=normalize_beam_score_by_length,
                   checklist_cost_weight=checklist_cost_weight,
                   use_neighbor_similarity_for_linking=use_neighbor_similarity_for_linking,
                   dropout=dropout,
                   num_linking_features=num_linking_features,
                   tables_directory=tables_directory,
                   rule_namespace=rule_namespace,
                   initial_mml_model_file=mml_model_file)
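# To make the checklist bookkeeping above concrete: the sketch below builds the same
# (checklist_target, terminal_actions, checklist_mask) column vectors as
# ``_get_checklist_info``, but from plain Python inputs. ``toy_checklist_info`` and its
# toy agenda/actions are hypothetical stand-ins for the real agenda tensor and
# ``ProductionRuleArray`` list; only the padding and masking logic mirrors the method.
import torch


def toy_checklist_info(agenda_indices, actions, terminal_productions, max_num_terminals):
    """Return (checklist_target, terminal_actions, checklist_mask), each (max_num_terminals, 1)."""
    terminal_indices = []
    target_checklist_list = []
    for index, action in enumerate(actions):
        if action in terminal_productions:
            terminal_indices.append([index])
            # 1 if this terminal is on the agenda, else 0.
            target_checklist_list.append([1] if index in agenda_indices else [0])
    # Pad with dummy (-1) rows so every instance in a batch has the same number of slots.
    while len(target_checklist_list) < max_num_terminals:
        target_checklist_list.append([0])
        terminal_indices.append([-1])
    terminal_actions = torch.tensor(terminal_indices)
    checklist_target = torch.tensor(target_checklist_list, dtype=torch.float)
    # Only terminals that appear on the agenda contribute to the checklist loss.
    checklist_mask = (checklist_target != 0).float()
    return checklist_target, terminal_actions, checklist_mask


if __name__ == "__main__":
    target, terminals, mask = toy_checklist_info(
        agenda_indices={2},
        actions=["S -> A", "A -> 'x'", "A -> 'y'"],
        terminal_productions={"A -> 'x'", "A -> 'y'"},
        max_num_terminals=4)
    # target == [[0], [1], [0], [0]]: the second terminal slot ("A -> 'y'") is on the agenda.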
class BidirectionalAttentionFlow(Model): """ This class implements Minjoon Seo's `Bidirectional Attention Flow model <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_ for answering reading comprehension questions (ICLR 2017). The basic layout is pretty simple: encode words as a combination of word embeddings and a character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of attentions to put question information into the passage word representations (this is the only part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and do a softmax over span start and span end. To instantiate this model with parameters matching those in the original paper, simply use ``BidirectionalAttentionFlow.from_params(vocab, Params({}))``. This will construct all of the various dependencies needed for the constructor for you. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. num_highway_layers : ``int`` The number of highway layers to use in between embedding the input and passing it through the phrase layer. phrase_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and doing the bidirectional attention. attention_similarity_function : ``SimilarityFunction`` The similarity function that we will use when comparing encoded passage and question representations. modeling_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between the bidirectional attention and predicting span start and end. span_end_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span start predictions into the passage state before predicting span end. initializer : ``InitializerApplicator`` We will use this to initialize the parameters in the model, calling ``initializer(self)``. dropout : ``float``, optional (default=0.2) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). mask_lstms : ``bool``, optional (default=True) If ``False``, we will skip passing the mask to the LSTM layers. This gives a ~2x speedup, with only a slight performance decrease, if any. We haven't experimented much with this yet, but have confirmed that we still get very similar performance with much faster training times. We still use the mask for all softmaxes, but avoid the shuffling that's required when using masking with pytorch LSTMs. evaluation_json_file : ``str``, optional If given, we will load this JSON into memory and use it to compute official metrics against. We need this separately from the validation dataset, because the official metrics use all of the annotations, while our dataset reader picks the most frequent one. 
""" def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, num_highway_layers: int, phrase_layer: Seq2SeqEncoder, attention_similarity_function: SimilarityFunction, modeling_layer: Seq2SeqEncoder, span_end_encoder: Seq2SeqEncoder, initializer: InitializerApplicator, dropout: float = 0.2, mask_lstms: bool = True, evaluation_json_file: str = None) -> None: super(BidirectionalAttentionFlow, self).__init__(vocab) self._text_field_embedder = text_field_embedder self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(), num_highway_layers)) self._phrase_layer = phrase_layer self._matrix_attention = MatrixAttention(attention_similarity_function) self._modeling_layer = modeling_layer self._span_end_encoder = span_end_encoder encoding_dim = phrase_layer.get_output_dim() modeling_dim = modeling_layer.get_output_dim() span_start_input_dim = encoding_dim * 4 + modeling_dim self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1)) span_end_encoding_dim = span_end_encoder.get_output_dim() span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1)) initializer(self) self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._official_em = Average() self._official_f1 = Average() if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._mask_lstms = mask_lstms if evaluation_json_file: logger.info("Prepping official evaluation dataset from %s", evaluation_json_file) with open(evaluation_json_file) as dataset_file: dataset_json = json.load(dataset_file) question_to_answers = {} for article in dataset_json['data']: for paragraph in article['paragraphs']: for question in paragraph['qas']: question_id = question['id'] answers = [answer['text'] for answer in question['answers']] question_to_answers[question_id] = answers self._official_eval_dataset = question_to_answers else: self._official_eval_dataset = None def forward(self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage. The ending position is `exclusive`, so our :class:`~allennlp.data.dataset_readers.SquadReader` adds a special ending token to the end of the passage, to allow for the last token to be included in the answer span. span_start : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the beginning position of the answer with the passage. This is an `inclusive` index. If this is given, we will compute a loss that gets included in the output dictionary. span_end : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the ending position of the answer with the passage. This is an `exclusive` index. If this is given, we will compute a loss that gets included in the output dictionary. 
metadata : ``List[Dict[str, Any]]``, optional If present, this should contain the question ID, original passage text, and token offsets into the passage for each instance in the batch. We use this for computing official metrics using the official SQuAD evaluation script. The length of this list should be the batch size, and each dictionary should have the keys ``id``, ``original_passage``, and ``token_offsets``. If you only want the best span string and don't care about official metrics, you can omit the ``id`` key. Returns ------- An output dictionary consisting of: span_start_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log probabilities of the span start position. span_start_probs : torch.FloatTensor The result of ``softmax(span_start_logits)``. span_end_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log probabilities of the span end position (exclusive). span_end_probs : torch.FloatTensor The result of ``softmax(span_end_logits)``. best_span : torch.IntTensor The result of a constrained inference over ``span_start_logits`` and ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)``. loss : torch.FloatTensor, optional A scalar loss to be optimised. best_span_str : List[str] If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question. """ embedded_question = self._highway_layer(self._text_field_embedder(question)) embedded_passage = self._highway_layer(self._text_field_embedder(passage)) batch_size = embedded_question.size(0) passage_length = embedded_passage.size(1) question_mask = util.get_text_field_mask(question).float() passage_mask = util.get_text_field_mask(passage).float() question_lstm_mask = question_mask if self._mask_lstms else None passage_lstm_mask = passage_mask if self._mask_lstms else None encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask)) encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask)) encoding_dim = encoded_question.size(-1) # Shape: (batch_size, passage_length, question_length) passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question) # Shape: (batch_size, passage_length, question_length) passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask) # Shape: (batch_size, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. 
masked_similarity = util.replace_masked_values(passage_question_similarity, question_mask.unsqueeze(1), -1e7) # Shape: (batch_size, passage_length) question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1) # Shape: (batch_size, passage_length) question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask) # Shape: (batch_size, encoding_dim) question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention) # Shape: (batch_size, passage_length, encoding_dim) tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size, passage_length, encoding_dim) # Shape: (batch_size, passage_length, encoding_dim * 4) final_merged_passage = torch.cat([encoded_passage, passage_question_vectors, encoded_passage * passage_question_vectors, encoded_passage * tiled_question_passage_vector], dim=-1) modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask)) modeling_dim = modeled_passage.size(-1) # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)) span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1)) # Shape: (batch_size, passage_length) span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1) # Shape: (batch_size, passage_length) span_start_probs = util.masked_softmax(span_start_logits, passage_mask) # Shape: (batch_size, modeling_dim) span_start_representation = util.weighted_sum(modeled_passage, span_start_probs) # Shape: (batch_size, passage_length, modeling_dim) tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size, passage_length, modeling_dim) # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3) span_end_representation = torch.cat([final_merged_passage, modeled_passage, tiled_start_representation, modeled_passage * tiled_start_representation], dim=-1) # Shape: (batch_size, passage_length, encoding_dim) encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation, passage_lstm_mask)) # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim) span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1)) span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1) span_end_probs = util.masked_softmax(span_end_logits, passage_mask) span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7) span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7) best_span = self._get_best_span(span_start_logits, span_end_logits) output_dict = {"span_start_logits": span_start_logits, "span_start_probs": span_start_probs, "span_end_logits": span_end_logits, "span_end_probs": span_end_probs, "best_span": best_span} if span_start is not None: loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)) self._span_start_accuracy(span_start_logits, span_start.squeeze(-1)) loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)) self._span_end_accuracy(span_end_logits, span_end.squeeze(-1)) self._span_accuracy(best_span, torch.stack([span_start, span_end], -1)) output_dict["loss"] = loss if metadata is not None and self._official_eval_dataset: output_dict['best_span_str'] = [] for i in range(batch_size): predicted_span = tuple(best_span[i].data.cpu().numpy()) best_span_string = self._compute_official_metrics(metadata[i], predicted_span) # type: ignore 
output_dict['best_span_str'].append(best_span_string) return output_dict def _compute_official_metrics(self, metadata: Dict[str, Any], predicted_span: Tuple[int, int]) -> str: passage = metadata['original_passage'] offsets = metadata['token_offsets'] question_id = metadata.get('id', None) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] span_string = passage[start_offset:end_offset] if question_id in self._official_eval_dataset: ground_truth = self._official_eval_dataset[question_id] exact_match = squad_eval.metric_max_over_ground_truths( squad_eval.exact_match_score, span_string, ground_truth) f1_score = squad_eval.metric_max_over_ground_truths( squad_eval.f1_score, span_string, ground_truth) self._official_em(100 * exact_match) self._official_f1(100 * f1_score) return span_string def get_metrics(self, reset: bool = False) -> Dict[str, float]: return { 'start_acc': self._span_start_accuracy.get_metric(reset), 'end_acc': self._span_end_accuracy.get_metric(reset), 'span_acc': self._span_accuracy.get_metric(reset), 'em': self._official_em.get_metric(reset), 'f1': self._official_f1.get_metric(reset), } def predict_span(self, question: TextField, passage: TextField) -> Dict[str, Any]: """ Given a question and a passage, predicts the span in the passage that answers the question. Parameters ---------- question : ``TextField`` passage : ``TextField`` A ``TextField`` containing the tokens in the passage. Note that we typically add ``SquadReader.STOP_TOKEN`` as the final token in the passage, because we use exclusive span indices. Be sure you've added that to the passage you pass in here. Returns ------- A Dict containing: span_start_probs : numpy.ndarray span_end_probs : numpy.ndarray best_span : (int, int) """ instance = Instance({'question': question, 'passage': passage}) instance.index_fields(self.vocab) model_input = util.arrays_to_variables(instance.as_array_dict(), add_batch_dimension=True, for_training=False) output_dict = self.forward(**model_input) # Here we're just removing the batch dimension and converting things to numpy arrays / # tuples instead of pytorch variables. 
return { "span_start_probs": output_dict["span_start_probs"].data.squeeze(0).cpu().numpy(), "span_end_probs": output_dict["span_end_probs"].data.squeeze(0).cpu().numpy(), "best_span": tuple(output_dict["best_span"].data.squeeze(0).cpu().numpy()), } @staticmethod def _get_best_span(span_start_logits: Variable, span_end_logits: Variable) -> Variable: if span_start_logits.dim() != 2 or span_end_logits.dim() != 2: raise ValueError("Input shapes must be (batch_size, passage_length)") batch_size, passage_length = span_start_logits.size() max_span_log_prob = [-1e20] * batch_size span_start_argmax = [0] * batch_size best_word_span = Variable(span_start_logits.data.new() .resize_(batch_size, 2).fill_(0)).long() span_start_logits = span_start_logits.data.cpu().numpy() span_end_logits = span_end_logits.data.cpu().numpy() for b in range(batch_size): # pylint: disable=invalid-name for j in range(passage_length): val1 = span_start_logits[b, span_start_argmax[b]] if val1 < span_start_logits[b, j]: span_start_argmax[b] = j val1 = span_start_logits[b, j] val2 = span_end_logits[b, j] if val1 + val2 > max_span_log_prob[b]: best_word_span[b, 0] = span_start_argmax[b] best_word_span[b, 1] = j max_span_log_prob[b] = val1 + val2 return best_word_span @classmethod def from_params(cls, vocab: Vocabulary, params: Params) -> 'BidirectionalAttentionFlow': embedder_params = params.pop("text_field_embedder") text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params) num_highway_layers = params.pop("num_highway_layers") phrase_layer = Seq2SeqEncoder.from_params(params.pop("phrase_layer")) similarity_function = SimilarityFunction.from_params(params.pop("similarity_function")) modeling_layer = Seq2SeqEncoder.from_params(params.pop("modeling_layer")) span_end_encoder = Seq2SeqEncoder.from_params(params.pop("span_end_encoder")) initializer = InitializerApplicator.from_params(params.pop("initializer", [])) dropout = params.pop('dropout', 0.2) evaluation_json_file = params.pop('evaluation_json_file', None) mask_lstms = params.pop('mask_lstms', True) params.assert_empty(cls.__name__) return cls(vocab=vocab, text_field_embedder=text_field_embedder, num_highway_layers=num_highway_layers, phrase_layer=phrase_layer, attention_similarity_function=similarity_function, modeling_layer=modeling_layer, span_end_encoder=span_end_encoder, initializer=initializer, dropout=dropout, mask_lstms=mask_lstms, evaluation_json_file=evaluation_json_file)
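# The linear-time scan in ``_get_best_span`` above is equivalent to brute-forcing
# ``argmax_{i <= j} start[i] + end[j]``: because the end index only moves forward, it is
# enough to remember the best start seen so far. A tiny numpy check of that equivalence
# (toy logits, not model outputs):
import numpy as np


def brute_force_best_span(start, end):
    length = len(start)
    return max(((i, j) for i in range(length) for j in range(i, length)),
               key=lambda ij: start[ij[0]] + end[ij[1]])


def linear_best_span(start, end):
    best, best_score, start_argmax = (0, 0), float("-inf"), 0
    for j in range(len(start)):
        # Track the best start position seen up to (and including) j.
        if start[j] > start[start_argmax]:
            start_argmax = j
        if start[start_argmax] + end[j] > best_score:
            best, best_score = (start_argmax, j), start[start_argmax] + end[j]
    return best


if __name__ == "__main__":
    start_logits = np.array([0.2, 1.5, 0.1, 0.9, 0.3])
    end_logits = np.array([0.1, 0.4, 2.0, 0.2, 0.8])
    assert brute_force_best_span(start_logits, end_logits) == \
        linear_best_span(start_logits, end_logits) == (1, 2)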
class PTNChainBidirectionalAttentionFlow(Model):
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 gate_sent_encoder: Seq2SeqEncoder,
                 gate_self_attention_layer: Seq2SeqEncoder,
                 bert_projection: FeedForward,
                 span_gate: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 output_att_scores: bool = True,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(PTNChainBidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._dropout = torch.nn.Dropout(p=dropout)
        self._output_att_scores = output_att_scores
        self._span_gate = span_gate
        self._bert_projection = bert_projection
        # The gate sentence encoder and self-attention layer are currently disabled.
        self._gate_sent_encoder = None
        self._gate_self_attention_layer = None

        self._f1_metrics = AttF1Measure(0.5, top_k=False)
        self._loss_trackers = {'loss': Average(),
                               'rl_loss': Average()}
        self.evd_sup_acc_metric = ChainAccuracy()
        self.evd_ans_metric = Average()
        self.evd_beam_ans_metric = Average()

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                sentence_spans: torch.IntTensor = None,
                sent_labels: torch.IntTensor = None,
                evd_chain_labels: torch.IntTensor = None,
                q_type: torch.IntTensor = None,
                transition_mask: torch.IntTensor = None,
                start_transition_mask: torch.Tensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # In this model, we only take the first chain in ``evd_chain_labels`` for supervision.
        evd_chain_labels = evd_chain_labels[:, 0] if evd_chain_labels is not None else None
        # There may be instances for which we can't find any evidence chain for training.
        # In that case, use the mask to ignore those instances.
        evd_instance_mask = (evd_chain_labels[:, 0] != 0).float() if evd_chain_labels is not None else None

        # BERT embedding for answer prediction.
        # shape: [batch_size, max_q_len, emb_size]
        embedded_question = self._text_field_embedder(question)
        # shape: [batch_size, num_sent, max_sent_len + q_len, embedding_dim]
        embedded_passage = self._text_field_embedder(passage)

        # Masks.
        ques_mask = util.get_text_field_mask(question, num_wrapping_dims=0).float()
        context_mask = util.get_text_field_mask(passage, num_wrapping_dims=1).float()

        # Chain prediction.
        # Shape(all_predictions): (batch_size, num_decoding_steps)
        # Shape(all_logprobs): (batch_size, num_decoding_steps)
        # Shape(seq_logprobs): (batch_size,)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(gate_probs): (batch_size * num_spans, 1)
        # Shape(gate_mask): (batch_size, num_spans)
        # Shape(g_att_score): (batch_size, num_heads, num_spans, num_spans)
        # Shape(orders): (batch_size, K, num_spans)
        all_predictions, \
            all_logprobs, \
            seq_logprobs, \
            gate, \
            gate_probs, \
            gate_mask, \
            g_att_score, \
            orders = self._span_gate(embedded_passage, context_mask,
                                     embedded_question, ques_mask,
                                     evd_chain_labels,
                                     self._gate_self_attention_layer,
                                     self._gate_sent_encoder,
                                     transition_mask,
                                     start_transition_mask)
        batch_size, num_spans, max_batch_span_width = context_mask.size()

        output_dict = {
            "pred_sent_labels": gate.squeeze(1).view(batch_size, num_spans),  # [B, num_span]
            "gate_probs": gate_probs.squeeze(1).view(batch_size, num_spans),  # [B, num_span]
            "pred_sent_orders": orders,  # [B, K, num_span]
        }
        if self._output_att_scores:
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        # Compute the evidence RL training metric, rewards, and loss.
        evd_TP, evd_NP, evd_NT = self._f1_metrics(gate.squeeze(1).view(batch_size, num_spans),
                                                  sent_labels,
                                                  mask=gate_mask,
                                                  instance_mask=evd_instance_mask if self.training else None,
                                                  sum=False)
        evd_ps = np.array(evd_TP) / (np.array(evd_NP) + 1e-13)
        evd_rs = np.array(evd_TP) / (np.array(evd_NT) + 1e-13)
        evd_f1s = 2. * ((evd_ps * evd_rs) / (evd_ps + evd_rs + 1e-13))
        predict_mask = get_evd_prediction_mask(all_predictions.unsqueeze(1), eos_idx=0)[0]
        gold_mask = get_evd_prediction_mask(evd_chain_labels, eos_idx=0)[0]
        # We default to taking multiple predicted chains, so unsqueeze dim 1.
        self.evd_sup_acc_metric(predictions=all_predictions.unsqueeze(1),
                                gold_labels=evd_chain_labels,
                                predict_mask=predict_mask,
                                gold_mask=gold_mask,
                                instance_mask=evd_instance_mask)
        predict_mask = predict_mask.float().squeeze(1)
        rl_loss = -torch.mean(torch.sum(all_logprobs * predict_mask * evd_instance_mask[:, None], dim=1))

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        # Computed before the RL loss.
        if metadata is not None:
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            pred_chains_include_ans = []
            beam_pred_chains_include_ans = []
            ids = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_sent_tokens'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)
                # Shift the sentence indices back by one.
                evd_possible_chains.append([s_idx - 1
                                            for s_idx in metadata[i]['evd_possible_chains'][0]
                                            if s_idx > 0])
                ans_sent_idxs.append([s_idx - 1 for s_idx in metadata[i]['ans_sent_idxs']])
                if len(metadata[i]['ans_sent_idxs']) > 0:
                    pred_sent_orders = orders[i].detach().cpu().numpy()
                    if any(pred_sent_orders[0][s_idx - 1] >= 0
                           for s_idx in metadata[i]['ans_sent_idxs']):
                        self.evd_ans_metric(1)
                        pred_chains_include_ans.append(1)
                    else:
                        self.evd_ans_metric(0)
                        pred_chains_include_ans.append(0)
                    if any(any(pred_sent_orders[beam][s_idx - 1] >= 0
                               for s_idx in metadata[i]['ans_sent_idxs'])
                           for beam in range(len(pred_sent_orders))):
                        self.evd_beam_ans_metric(1)
                        beam_pred_chains_include_ans.append(1)
                    else:
                        self.evd_beam_ans_metric(0)
                        beam_pred_chains_include_ans.append(0)

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_sent_tokens'] = passage_tokens
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['pred_chains_include_ans'] = pred_chains_include_ans
            output_dict['beam_pred_chains_include_ans'] = beam_pred_chains_include_ans
            output_dict['_id'] = ids

        # Compute the loss for training.
        if evd_chain_labels is not None:
            try:
                loss = rl_loss
                self._loss_trackers['loss'](loss)
                self._loss_trackers['rl_loss'](rl_loss)
                output_dict["loss"] = loss
            except RuntimeError:
                print('\n meta_data:', metadata)
                print(output_dict['_id'])

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        p, r, evidence_f1_score = self._f1_metrics.get_metric(reset)
        ans_in_evd = self.evd_ans_metric.get_metric(reset)
        beam_ans_in_evd = self.evd_beam_ans_metric.get_metric(reset)
        metrics = {
            'evd_p': p,
            'evd_r': r,
            'evd_f1': evidence_f1_score,
            'ans_in_evd': ans_in_evd,
            'beam_ans_in_evd': beam_ans_in_evd,
        }
        for name, tracker in self._loss_trackers.items():
            metrics[name] = tracker.get_metric(reset).item()
        evd_sup_acc = self.evd_sup_acc_metric.get_metric(reset)
        metrics['evd_sup_acc'] = evd_sup_acc
        return metrics

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]
                val2 = span_end_logits[b, j]
                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
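# The ``rl_loss`` above is a masked negative log-likelihood over the predicted chain:
# per-step log-probabilities are kept up to the end-of-chain prediction, zeroed for
# instances without a usable gold chain, summed per instance, and averaged. A minimal
# sketch under those assumptions (the EOS handling below is a stand-in for
# ``get_evd_prediction_mask``, whose exact semantics live elsewhere in the codebase):
import torch


def masked_chain_loss(step_logprobs, predictions, instance_mask, eos_idx=0):
    """step_logprobs, predictions: (batch, num_steps); instance_mask: (batch,)."""
    # Keep every step up to and including the first EOS prediction.
    is_eos = (predictions == eos_idx).long()
    seen_earlier_eos = torch.cumsum(is_eos, dim=1) - is_eos
    step_mask = (seen_earlier_eos == 0).float()
    per_instance_logprob = (step_logprobs * step_mask).sum(dim=1)
    # Ignore instances with no gold chain, then average over the batch.
    return -(per_instance_logprob * instance_mask).mean()


if __name__ == "__main__":
    logprobs = torch.log(torch.tensor([[0.9, 0.8, 0.7], [0.6, 0.5, 0.4]]))
    preds = torch.tensor([[3, 1, 0], [2, 0, 0]])  # 0 marks end-of-chain
    print(masked_chain_loss(logprobs, preds, torch.tensor([1.0, 1.0])))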
class Nlvr2EndToEndModuleNetwork(Model):
    """
    A re-implementation of `End-to-End Module Networks for Visual Question Answering
    <https://www.semanticscholar.org/paper/Learning-to-Reason%3A-End-to-End-Module-Networks-for-Hu-Andreas/5e07d6951b7bc0c4113313a9586ce8178eacdf57>`_.

    This implementation is based on our semantic parsing framework, and uses marginal likelihood
    to train the parser when labeled action sequences are not available. It is `not` an exact
    re-implementation, but rather a very similar model with some significant differences in how
    the grammar is used.

    Parameters
    ----------
    vocab : ``Vocabulary``
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input utterance.
    freeze_encoder: ``bool``, optional (default=False)
        If true, weights of the encoder will be frozen during training.
    dropout : ``float``, optional (default=0)
        Dropout to be applied to encoder outputs and in modules.
    tokens_namespace : ``str``, optional (default=tokens)
        The vocabulary namespace to use for tokens. The default corresponds to the default used
        in the dataset reader, so you likely don't need to modify this.
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules. The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    denotation_namespace : ``str``, optional (default=labels)
        The vocabulary namespace to use for output labels. The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
    num_parse_only_batches : ``int``, optional (default=0)
        We will use this many training batches of `only` parse supervision, not denotation
        supervision. This is helpful in cases where learning the correct programs and the
        program executor at the same time, both from scratch, is challenging. This only works
        if you have labeled programs.
    use_gold_program_for_eval : ``bool``, optional (default=True)
        If true, we will use the gold program for evaluation when it is available (this only
        tests the program executor, not the parser).
    load_weights: ``str``, optional (default=None)
        Path from which to load model weights. If None or if the path does not exist, no
        weights are loaded.
    use_modules: ``bool``, optional (default=True)
        If True, use modules and execute them according to programs. If False, use a
        feedforward network on top of the encoder to directly predict the label.
    positive_iou_threshold: ``float``, optional (default=0.5)
        Intersection-over-union (IOU) threshold to use for determining matches between
        ground-truth and predicted boxes in the faithfulness recall score.
    negative_iou_threshold: ``float``, optional (default=0.5)
        Intersection-over-union (IOU) threshold to use for determining matches between
        ground-truth and predicted boxes in the faithfulness precision score.
    nmn_settings: Dict, optional (default=None)
        A dictionary specifying choices determining the architectures of the modules. This
        should not be None if use_modules == True.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        encoder: Seq2SeqEncoder,
        freeze_encoder: bool = False,
        dropout: float = 0.0,
        tokens_namespace: str = "tokens",
        rule_namespace: str = "rule_labels",
        denotation_namespace: str = "labels",
        num_parse_only_batches: int = 0,
        use_gold_program_for_eval: bool = True,
        load_weights: str = None,
        use_modules: bool = True,
        positive_iou_threshold: float = 0.5,
        negative_iou_threshold: float = 0.5,
        nmn_settings: Dict = None,
    ) -> None:
        super().__init__(vocab)
        self._encoder = encoder
        self._max_decoding_steps = 10
        self._add_action_bias = True
        self._dropout = torch.nn.Dropout(p=dropout)
        self._tokens_namespace = tokens_namespace
        self._rule_namespace = rule_namespace
        self._denotation_namespace = denotation_namespace
        self._num_parse_only_batches = num_parse_only_batches
        self._use_gold_program_for_eval = use_gold_program_for_eval
        self._nmn_settings = nmn_settings
        self._use_modules = use_modules
        self._training_batches_so_far = 0

        self._denotation_accuracy = CategoricalAccuracy()
        self._box_f1_score = ClassificationModuleScore(
            positive_iou_threshold=positive_iou_threshold,
            negative_iou_threshold=negative_iou_threshold,
        )
        self._best_box_f1_score = ClassificationModuleScore(
            positive_iou_threshold=positive_iou_threshold,
            negative_iou_threshold=negative_iou_threshold,
        )
        # TODO(mattg): use FullSequenceMatch instead of this.
        self._program_accuracy = Average()
        self._program_similarity = Average()
        self.loss = torch.nn.BCELoss()
        self.loss_with_logits = torch.nn.BCEWithLogitsLoss()

        self._action_padding_index = -1  # the padding value used by IndexField
        num_actions = vocab.get_vocab_size(self._rule_namespace)
        action_embedding_dim = 100
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions,
                                          embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions,
                                                 embedding_dim=action_embedding_dim)

        if self._use_modules:
            self._language_parameters = VisualReasoningNlvr2Parameters(
                hidden_dim=self._encoder.get_output_dim(),
                initializer=self._encoder.encoder.model.init_bert_weights,
                max_boxes=self._nmn_settings["max_boxes"],
                dropout=dropout,
                nmn_settings=nmn_settings,
            )
        else:
            hid_dim = self._encoder.get_output_dim()
            self.logit_fc = torch.nn.Sequential(
                torch.nn.Linear(hid_dim * 2, hid_dim * 2),
                GeLU(),
                BertLayerNorm(hid_dim * 2, eps=1e-12),
                torch.nn.Linear(hid_dim * 2, 1),
            )
            self.logit_fc.apply(self._encoder.encoder.model.init_bert_weights)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
encoder_output_dim = self._encoder.get_output_dim() self._decoder_num_layers = 1 self._beam_search = BeamSearch(beam_size=10) self._decoder_trainer = MaximumMarginalLikelihood() self._first_action_embedding = torch.nn.Parameter( torch.FloatTensor(action_embedding_dim)) self._first_attended_utterance = torch.nn.Parameter( torch.FloatTensor(encoder_output_dim)) torch.nn.init.normal_(self._first_action_embedding) torch.nn.init.normal_(self._first_attended_utterance) self._transition_function = BasicTransitionFunction( encoder_output_dim=encoder_output_dim, action_embedding_dim=action_embedding_dim, input_attention=AdditiveAttention(vector_dim=encoder_output_dim, matrix_dim=encoder_output_dim), add_action_bias=self._add_action_bias, dropout=dropout, num_layers=self._decoder_num_layers, ) # Our language is constant across instances, so we just create one up front that we can # re-use to construct the `GrammarStatelet`. self._world = VisualReasoningNlvr2Language(None, None, None, None, None, None) if load_weights is not None: if not os.path.exists(load_weights): print('Could not find weights path: ' + load_weights + '. Continuing without loading weights.') else: if torch.cuda.is_available(): state = torch.load(load_weights) else: state = torch.load(load_weights, map_location="cpu") encoder_prefix = "_encoder" lang_params_prefix = "_language_parameters" for key in list(state.keys()): if (key[:len(encoder_prefix)] != encoder_prefix and key[:len(lang_params_prefix)] != lang_params_prefix): del state[key] if "relate_layer" in key: del state[key] self.load_state_dict(state, strict=False) if freeze_encoder: for param in self._encoder.parameters(): param.requires_grad = False self.consistency_group_map = {} def consistency(self, reset: bool = False): if reset: self.consistency_group_map = {} if len(self.consistency_group_map) == 0: return 0.0 consistency = len([ group for group in self.consistency_group_map if self.consistency_group_map[group] == True ]) / float(len(self.consistency_group_map)) return consistency @overrides def forward( self, # type: ignore sentence: Dict[str, torch.LongTensor], visual_feat: torch.Tensor, pos: torch.Tensor, image_id: List[str], gold_question_attentions: torch.Tensor = None, gold_box_annotations: List[List[List[List[float]]]] = None, identifier: List[str] = None, logical_form: List[str] = None, actions: List[List[ProductionRule]] = None, target_action_sequence: torch.LongTensor = None, valid_target_sequence: torch.Tensor = None, denotation: torch.Tensor = None, metadata: Dict = None, ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ batch_size, img_num, obj_num, feat_size = visual_feat.size() assert img_num == 2 and feat_size == 2048 text_masks = util.get_text_field_mask(sentence) (l1, v1, text, vis_only1), x1 = self._encoder(sentence[self._tokens_namespace], text_masks, visual_feat[:, 0], pos[:, 0]) (l2, v2, text, vis_only2), x2 = self._encoder(sentence[self._tokens_namespace], text_masks, visual_feat[:, 1], pos[:, 1]) l_orig = torch.cat((l1.unsqueeze(1), l2.unsqueeze(1)), dim=1) v_orig = torch.cat((v1.unsqueeze(1), v2.unsqueeze(1)), dim=1) x_orig = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1) vis_only = torch.cat((vis_only1.unsqueeze(1), vis_only2.unsqueeze(1)), dim=1) # NOTE: Taking the lxmert output before cross modality layer (which is the same for both images) # Can also try concatenating (dim=-1) the two encodings encoded_sentence = text valid_target_sequence = valid_target_sequence.long() initial_state = self._get_initial_state( 
            encoded_sentence[valid_target_sequence == 1],
            text_masks[valid_target_sequence == 1],
            actions,
        )
        initial_state.debug_info = [[] for _ in range(batch_size)]
        if target_action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequence = target_action_sequence.squeeze(-1)
            target_action_sequence = target_action_sequence[valid_target_sequence == 1]
            target_mask = target_action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        losses = []
        if (self.training or self._use_gold_program_for_eval) and target_action_sequence is not None:
            if valid_target_sequence.sum() > 0:
                outputs, final_states = self._decoder_trainer.decode(
                    initial_state,
                    self._transition_function,
                    (target_action_sequence.unsqueeze(1), target_mask.unsqueeze(1)),
                )
                # B x TARGET x SENTENCE
                question_attention = [[dbg["question_attention"]
                                       for dbg in final_states[i][0].debug_info[0]]
                                      for i in range(len(final_states))
                                      if valid_target_sequence[i].item() == 1]
                target_attn_loss = self._compute_target_attn_loss(
                    question_attention,
                    gold_question_attentions[valid_target_sequence == 1].squeeze(-1),
                )
                if not self._use_gold_program_for_eval:
                    outputs["loss"] += target_attn_loss
                else:
                    outputs["loss"] = torch.tensor(0.0)
                    if torch.cuda.is_available():
                        outputs["loss"] = outputs["loss"].cuda()
            else:
                final_states = None
                outputs["loss"] = torch.tensor(0.0).cuda()
            if (1 - valid_target_sequence).sum() > 0:
                if final_states is None:
                    final_states = {}
                initial_state = self._get_initial_state(
                    encoded_sentence[valid_target_sequence == 0],
                    text_masks[valid_target_sequence == 0],
                    actions,
                )
                remaining_states = self._beam_search.search(
                    self._max_decoding_steps,
                    initial_state,
                    self._transition_function,
                    keep_final_unfinished_states=False,
                )
                # Merge the gold-decoded and beam-searched states back into batch order.
                new_final_states = {}
                count = 0
                for i in range(valid_target_sequence.shape[0]):
                    if valid_target_sequence[i] < 0.5:
                        new_final_states[i] = remaining_states[i - count]
                    else:
                        new_final_states[i] = final_states[count]
                        count += 1
                final_states = new_final_states
            if final_states is None:
                len_final_states = 0
            else:
                len_final_states = len(final_states)
        else:
            initial_state = self._get_initial_state(encoded_sentence, text_masks, actions)
            final_states = self._beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False,
            )

        action_mapping = {}
        for action_index, action in enumerate(actions[0]):
            action_mapping[action_index] = action[0]
        outputs["action_mapping"] = action_mapping
        outputs["debug_info"] = []
        outputs["modules_debug_info"] = []
        outputs["best_action_sequence"] = []
        outputs["image_id"] = []
        outputs["prediction"] = []
        outputs["label"] = []
        outputs["correct"] = []
        outputs["bboxes"] = []
        outputs = self._compute_parsing_validation_outputs(
            actions,
            target_action_sequence.shape[0],
            final_states,
            initial_state,
            [datum for i, datum in enumerate(metadata)
             if valid_target_sequence[i].item() == 1],
            outputs,
            target_action_sequence,
        )
        if not self._use_modules:
            # ``x_orig`` (the concatenated cross-modality outputs for the two images) is what
            # exists at this point; ``x`` is only defined on the module path below.
            logits = self.logit_fc(x_orig.view(batch_size, -1))
            if denotation is not None:
                self._denotation_accuracy(
                    torch.cat((torch.zeros_like(logits), logits), dim=-1), denotation)
                if self.training:
                    outputs["loss"] += self.loss_with_logits(logits.view(-1),
                                                             denotation.view(-1).float())
                    self._training_batches_so_far += 1
                else:
                    outputs["loss"] = self.loss_with_logits(logits.view(-1),
                                                            denotation.view(-1).float())
            return outputs

        if self._nmn_settings["mask_non_attention"]:
            zero_one_mult = torch.zeros_like(text_masks).unsqueeze(1).repeat(
                1, target_action_sequence.shape[1], 1)
            reformatted_gold_question_attentions = torch.where(
                gold_question_attentions.squeeze(-1) == -1,
                torch.zeros_like(gold_question_attentions.squeeze(-1)),
                gold_question_attentions.squeeze(-1),
            )
            pred_question_attention = [
                torch.stack([
                    torch.nn.functional.pad(
                        dbg["question_attention"].view(-1),
                        pad=(0, zero_one_mult.shape[-1] - dbg["question_attention"].numel()),
                    ) for dbg in final_states[i][0].debug_info[0]
                ]) for i in range(len(final_states))
            ]
            pred_question_attention = torch.stack([
                torch.nn.functional.pad(
                    attention,
                    pad=(0, 0, 0, zero_one_mult.shape[1] - attention.shape[0]),
                ) for attention in pred_question_attention
            ]).to(zero_one_mult.device)
            zero_one_mult.scatter_(
                2,
                reformatted_gold_question_attentions,
                torch.ones_like(reformatted_gold_question_attentions),
            )
            zero_one_mult[:, :, 0] = 1.0
            sep_indices = (text_masks * (1 + torch.arange(text_masks.shape[1])
                                         .unsqueeze(0).repeat(batch_size, 1)
                                         .to(text_masks.device))).argmax(1).long()
            sep_indices = (sep_indices.unsqueeze(1).repeat(1, text_masks.shape[1])
                           .unsqueeze(1).repeat(1, target_action_sequence.shape[1], 1))
            indices_dim2 = (torch.arange(text_masks.shape[1]).unsqueeze(0).unsqueeze(0)
                            .repeat(batch_size, target_action_sequence.shape[1], 1)
                            .to(sep_indices.device).long())
            zero_one_mult = torch.where(
                sep_indices == indices_dim2,
                torch.ones_like(zero_one_mult),
                zero_one_mult,
            ).float()
            reshaped_questions = (sentence[self._tokens_namespace].unsqueeze(1)
                                  .repeat(1, target_action_sequence.shape[1], 1)
                                  .view(-1, text_masks.shape[-1]))
            reshaped_visual_feat = (visual_feat.unsqueeze(1)
                                    .repeat(1, target_action_sequence.shape[1], 1, 1, 1)
                                    .view(-1, img_num, obj_num, visual_feat.shape[-1]))
            reshaped_pos = (pos.unsqueeze(1)
                            .repeat(1, target_action_sequence.shape[1], 1, 1, 1)
                            .view(-1, img_num, obj_num, pos.shape[-1]))
            zero_one_mult = zero_one_mult.view(-1, text_masks.shape[-1])
            q_att_filter = zero_one_mult.sum(1) > 2
            (l1, v1, text, vis_only1), x1 = self._encoder(
                reshaped_questions[q_att_filter, :],
                zero_one_mult[q_att_filter, :],
                reshaped_visual_feat[q_att_filter, 0, :, :],
                reshaped_pos[q_att_filter, 0, :, :],
            )
            (l2, v2, text, vis_only2), x2 = self._encoder(
                reshaped_questions[q_att_filter, :],
                zero_one_mult[q_att_filter, :],
                reshaped_visual_feat[q_att_filter, 1, :, :],
                reshaped_pos[q_att_filter, 1, :, :],
            )
            l_cat = torch.cat((l1.unsqueeze(1), l2.unsqueeze(1)), dim=1)
            v_cat = torch.cat((v1.unsqueeze(1), v2.unsqueeze(1)), dim=1)
            x_cat = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
            l = [{} for _ in range(batch_size)]
            v = [{} for _ in range(batch_size)]
            x = [{} for _ in range(batch_size)]
            count = 0
            batch_index = -1
            for i in range(zero_one_mult.shape[0]):
                module_num = i % target_action_sequence.shape[1]
                if module_num == 0:
                    batch_index += 1
                    state = final_states[batch_index][0]
                    action_indices = state.action_history[0]
                    action_strings = [action_mapping[action_index]
                                      for action_index in action_indices]
                if q_att_filter[i].item():
                    l[batch_index][module_num] = self._dropout(l_cat[count])
                    v[batch_index][module_num] = self._dropout(v_cat[count])
                    x[batch_index][module_num] = self._dropout(x_cat[count])
                    count += 1
        else:
            l = self._dropout(l_orig)
            v = self._dropout(v_orig)
            x = self._dropout(x_orig)

        outputs["box_acc"] = [{} for _ in range(batch_size)]
        outputs["best_box_acc"] = [{} for _ in range(batch_size)]
        outputs["box_score"] = [{} for _ in range(batch_size)]
        outputs["box_f1"] = [{} for _ in range(batch_size)]
        outputs["box_f1_overall_score"] = [{} for _ in range(batch_size)]
        outputs["best_box_f1"] = [{} for _ in range(batch_size)]
        outputs["gold_box"] = []
        outputs["ious"] = [{} for _ in range(batch_size)]
        for batch_index in range(batch_size):
            if (self.training
                    and self._training_batches_so_far < self._num_parse_only_batches):
                continue
            if not final_states[batch_index]:
                logger.error(f"No program found for batch index {batch_index}")
                outputs["best_action_sequence"].append([])
                outputs["debug_info"].append([])
                continue
            outputs["modules_debug_info"].append([])
            denotation_log_prob_list = []
            # TODO(pradeep/mattg): maybe we want to limit the number of states we evaluate
            # (programs we execute) at test time, just for efficiency.
            for state_index, state in enumerate(final_states[batch_index]):
                world = VisualReasoningNlvr2Language(
                    l[batch_index],
                    v[batch_index],
                    x[batch_index],
                    self._language_parameters,
                    metadata[batch_index]["tokenized_utterance"],
                    pos[batch_index],
                    self._nmn_settings,
                )
                action_indices = state.action_history[0]
                action_strings = [action_mapping[action_index]
                                  for action_index in action_indices]
                # Shape: (num_denotations,)
                assert len(action_strings) == len(state.debug_info[0])
                # Plug in the gold question attentions.
                for i in range(len(state.debug_info[0])):
                    if (self._use_gold_program_for_eval
                            and valid_target_sequence[batch_index] == 1):
                        n_att_words = (gold_question_attentions[batch_index, i] >= 0).float().sum()
                        state.debug_info[0][i]["question_attention"] = torch.zeros_like(
                            state.debug_info[0][i]["question_attention"])
                        if n_att_words > 0:
                            for j in gold_question_attentions[batch_index, i]:
                                if j >= 0:
                                    state.debug_info[0][i]["question_attention"][j] = 1.0 / n_att_words
                    if (i not in l[batch_index]
                            and self._nmn_settings["mask_non_attention"]
                            and (action_strings[i][-4:] == "find"
                                 or action_strings[i][-6:] == "filter"
                                 or action_strings[i][-13:] == "with_relation"
                                 or action_strings[i][-7:] == "project")):
                        l[batch_index][i] = l_orig[batch_index, :, :]
                        v[batch_index][i] = v_orig[batch_index, :, :]
                        x[batch_index][i] = x_orig[batch_index, :]
                        world = VisualReasoningNlvr2Language(
                            l[batch_index],
                            v[batch_index],
                            x[batch_index],
                            self._language_parameters,
                            metadata[batch_index]["tokenized_utterance"],
                            pos[batch_index],
                            self._nmn_settings,
                        )
                world.parameters.train(self.training)
                state_denotation_probs = world.execute_action_sequence(
                    action_strings, state.debug_info[0])
                outputs["modules_debug_info"][batch_index].append(world.modules_debug_info)
                # P(denotation | parse) * P(parse | question)
                world_log_prob = (state_denotation_probs + 1e-6).log()
                if not self._use_gold_program_for_eval:
                    world_log_prob += state.score[0]
                denotation_log_prob_list.append(world_log_prob)
            # P(denotation | parse) * P(parse | question) for all of the programs on the beam.
            # Shape: (beam_size, num_denotations)
            denotation_log_probs = torch.stack(denotation_log_prob_list)
            # \Sum_parse P(denotation | parse) * P(parse | question) = P(denotation | question)
            # Shape: (num_denotations,)
            marginalized_denotation_log_probs = util.logsumexp(denotation_log_probs, dim=0)
            if denotation is not None:
                # This line is needed, otherwise we have numbers slightly exceeding 0..1.
                # Should check why.
                state_denotation_probs = state_denotation_probs.clamp(min=0, max=1)
                loss = self.loss(
                    state_denotation_probs.unsqueeze(0),
                    denotation[batch_index].unsqueeze(0).float(),
                ).view(1)
                losses.append(loss)
                self._denotation_accuracy(
                    torch.tensor([1 - state_denotation_probs,
                                  state_denotation_probs]).to(denotation.device),
                    denotation[batch_index],
                )
                group_id = metadata[batch_index]["identifier"].split("-")
                group_id = group_id[0] + "-" + group_id[1] + "-" + group_id[-1]
                if group_id not in self.consistency_group_map:
                    self.consistency_group_map[group_id] = True
                if (state_denotation_probs.item() >= 0.5
                        and denotation[batch_index].item() < 0.5) or \
                   (state_denotation_probs.item() < 0.5
                        and denotation[batch_index].item() > 0.5):
                    self.consistency_group_map[group_id] = False
            if (gold_box_annotations is not None
                    and len(gold_box_annotations[batch_index]) > 0):
                box_f1_score, overall_f1_score_value = self._box_f1_score(
                    outputs["modules_debug_info"][batch_index][0],
                    gold_box_annotations[batch_index],
                    pos[batch_index],
                )
                outputs["box_f1"][batch_index] = box_f1_score
                outputs["box_f1_overall_score"][batch_index] = overall_f1_score_value
                best_f1_predictions = self._best_box_f1_score.compute_best_box_predictions(
                    outputs["modules_debug_info"][batch_index][0],
                    gold_box_annotations[batch_index],
                    pos[batch_index],
                )
                best_f1, _ = self._best_box_f1_score(
                    best_f1_predictions,
                    gold_box_annotations[batch_index],
                    pos[batch_index],
                )
                outputs["best_box_f1"][batch_index] = best_f1
                outputs["gold_box"].append(gold_box_annotations[batch_index])
            outputs["image_id"].append(image_id[batch_index])
            outputs["prediction"].append(world_log_prob.exp())
            if denotation is not None:
                outputs["label"].append(denotation[batch_index])
                outputs["correct"].append(
                    world_log_prob.exp().round().int() == denotation[batch_index].int())
            outputs["bboxes"].append(pos[batch_index])
        if losses:
            outputs["loss"] += torch.stack(losses).mean()
        if self.training:
            self._training_batches_so_far += 1
        return outputs

    def _compute_parsing_validation_outputs(
        self,
        actions,
        batch_size,
        final_states,
        initial_state,
        metadata,
        outputs,
        target_action_sequence,
    ):
        if (not self.training and target_action_sequence is not None
                and target_action_sequence.numel() > 0):
            outputs["parse_correct"] = []
            # Beam search was already done in ``forward``, so here we only score the states.
            best_action_sequences: Dict[int, List[int]] = {}
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
if i in final_states: best_action_indices = final_states[i][0].action_history[0] best_action_sequences[i] = best_action_indices targets = target_action_sequence[i].data sequence_in_targets = self._action_history_match( best_action_indices, targets) self._program_accuracy(sequence_in_targets) similarity = difflib.SequenceMatcher( None, best_action_indices, targets) self._program_similarity(similarity.ratio()) outputs["parse_correct"].append(sequence_in_targets) else: self._program_accuracy(0) self._program_similarity(0) continue batch_action_strings = self._get_action_strings( actions, best_action_sequences) if metadata is not None: outputs["sentence_tokens"] = [ x["tokenized_utterance"] for x in metadata ] outputs["utterance"] = [x["utterance"] for x in metadata] outputs["parse_gold"] = [x["gold"] for x in metadata] outputs["debug_info"] = [] outputs["parse_predicted"] = [] outputs["action_mapping"] = [] for i in range(batch_size): if i in final_states: outputs["debug_info"].append( final_states[i][0].debug_info[0]) # type: ignore outputs["action_mapping"].append( [a[0] for a in actions[i]]) outputs["parse_predicted"].append( self._world.action_sequence_to_logical_form( batch_action_strings[i])) outputs["best_action_strings"] = batch_action_strings action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] return outputs def _get_initial_state( self, encoder_outputs: torch.Tensor, utterance_mask: torch.Tensor, actions: List[ProductionRule], ) -> GrammarBasedState: batch_size = encoder_outputs.size(0) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states( encoder_outputs, utterance_mask, self._encoder.is_bidirectional()) # Use CLS states as final encoder outputs memory_cell = encoder_outputs.new_zeros(batch_size, encoder_outputs.shape[-1]) initial_score = encoder_outputs.data.new_zeros(batch_size) attended_sentence, _ = self._transition_function.attend_on_question( final_encoder_output, encoder_outputs, utterance_mask) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. 
initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] utterance_mask_list = [utterance_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): if self._decoder_num_layers > 1: encoder_output = final_encoder_output[i].repeat( self._decoder_num_layers, 1) cell = memory_cell[i].repeat(self._decoder_num_layers, 1) else: encoder_output = final_encoder_output[i] cell = memory_cell[i] initial_rnn_state.append( RnnStatelet( encoder_output, cell, self._first_action_embedding, attended_sentence[i], encoder_output_list, utterance_mask_list, )) initial_grammar_state = [ self._create_grammar_state(actions[i]) for i in range(batch_size) ] initial_state = GrammarBasedState( batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, possible_actions=actions, debug_info=[[] for _ in range(batch_size)], ) return initial_state @staticmethod def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int: # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something. # Check if target is big enough to cover prediction (including start/end symbols) if len(predicted) > targets.size(0): return 0 predicted_tensor = targets.new_tensor(predicted) targets_trimmed = targets[:len(predicted)] # Return 1 if the predicted sequence is anywhere in the list of targets. return predicted_tensor.equal(targets_trimmed) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: metrics = { "denotation_acc": self._denotation_accuracy.get_metric(reset), "program_acc": self._program_accuracy.get_metric(reset), "consistency": self.consistency(reset), } for m in self._box_f1_score.modules: metrics["_" + m + "_box_f1"] = self._box_f1_score.get_metric( reset=False, module=m)["f1"] metrics["_" + m + "_best_box_f1"] = self._best_box_f1_score.get_metric( reset=False, module=m)["f1"] box_f1_pr = self._box_f1_score.get_metric(reset=reset) for key in box_f1_pr: metrics["overall_box_" + key] = box_f1_pr[key] best_box_f1_pr = self._best_box_f1_score.get_metric(reset=reset) for key in best_box_f1_pr: metrics["overall_best_box_" + key] = best_box_f1_pr[key] return metrics def _create_grammar_state( self, possible_actions: List[ProductionRule]) -> GrammarStatelet: """ This method creates the GrammarStatelet object that's used for decoding. Part of creating that is creating the `valid_actions` dictionary, which contains embedded representations of all of the valid actions. So, we create that here as well. The inputs to this method are for a `single instance in the batch`; none of the tensors we create here are batched. We grab the global action ids from the input ``ProductionRules``, and we use those to embed the valid actions for every non-terminal type. Parameters ---------- possible_actions : ``List[ProductionRule]`` From the input to ``forward`` for a single batch instance. 
""" action_map = {} for action_index, action in enumerate(possible_actions): action_string = action[0] action_map[action_string] = action_index valid_actions = self._world.get_nonterminal_productions() translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {} for key, action_strings in valid_actions.items(): translated_valid_actions[key] = {} # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid # productions of that non-terminal. We'll first split those productions by global vs. # linked action. action_indices = [ action_map[action_string] for action_string in action_strings ] production_rule_arrays = [(possible_actions[index], index) for index in action_indices] global_actions = [] for production_rule_array, action_index in production_rule_arrays: global_actions.append((production_rule_array[2], action_index)) global_action_tensors, global_action_ids = zip(*global_actions) global_action_tensor = torch.cat(global_action_tensors, dim=0).long() global_input_embeddings = self._action_embedder( global_action_tensor) global_output_embeddings = self._output_action_embedder( global_action_tensor) translated_valid_actions[key]["global"] = ( global_input_embeddings, global_output_embeddings, list(global_action_ids), ) return GrammarStatelet([START_SYMBOL], translated_valid_actions, self._world.is_nonterminal) @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a separate notion from the "decoder" in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``. """ # TODO(mattg): FIX THIS - I haven't touched this method yet. 
action_mapping = output_dict["action_mapping"] best_actions = output_dict["best_action_sequence"] debug_infos = output_dict["debug_info"] batch_action_info = [] for batch_index, (predicted_actions, debug_info) in enumerate( zip(best_actions, debug_infos)): instance_action_info = [] for predicted_action, action_debug_info in zip( predicted_actions, debug_info): action_info = {} action_info["predicted_action"] = predicted_action considered_actions = action_debug_info["considered_actions"] probabilities = action_debug_info["probabilities"] actions = [] for action, probability in zip(considered_actions, probabilities): if action != -1: actions.append((action_mapping[(batch_index, action)], probability)) actions.sort() considered_actions, probabilities = zip(*actions) action_info["considered_actions"] = considered_actions action_info["action_probabilities"] = probabilities action_info["utterance_attention"] = action_debug_info.get( "question_attention", []) instance_action_info.append(action_info) batch_action_info.append(instance_action_info) output_dict["predicted_actions"] = batch_action_info return output_dict def _compute_target_attn_loss(self, question_attention, gold_question_attention): attn_loss = 0 normalizer = 0 gold_question_attention = gold_question_attention.cpu().numpy() # TODO: Pad and batch this for performance for instance_attn, gld_instance_attn in zip(question_attention, gold_question_attention): for step_attn, gld_step_attn in zip(instance_attn, gld_instance_attn): if gld_step_attn[0] == -1: continue # consider only non-padding indices gld_step_attn = [a for a in gld_step_attn if a > -1] given_attn = step_attn[gld_step_attn] attn_loss += (given_attn.sum() + 1e-8).log() normalizer += 1 if normalizer == 0: return 0 else: return -1 * (attn_loss / normalizer) @classmethod def _get_action_strings( cls, possible_actions: List[List[ProductionRule]], action_indices: Dict[int, List[List[int]]], ) -> List[List[List[str]]]: """ Takes a list of possible actions and indices of decoded actions into those possible actions for a batch and returns sequences of action strings. We assume ``action_indices`` is a dict mapping batch indices to k-best decoded sequence lists. """ all_action_strings: List[List[List[str]]] = [] batch_size = len(possible_actions) for i in range(batch_size): batch_actions = possible_actions[i] batch_best_sequences = action_indices[ i] if i in action_indices else [] # This will append an empty list to ``all_action_strings`` if ``batch_best_sequences`` # is empty. action_strings = [ batch_actions[rule_id][0] for rule_id in batch_best_sequences ] all_action_strings.append(action_strings) return all_action_strings
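# --- Illustrative sketch, not part of the original model: how the program
# accuracy and similarity metrics above behave. ``_action_history_match`` only
# gives credit when the prediction matches a prefix of the gold action indices,
# while ``difflib.SequenceMatcher`` supplies the soft score fed to
# ``_program_similarity``. All values below are made up for illustration.
def _demo_program_metrics():
    import difflib
    import torch
    predicted = [3, 1, 4, 1]
    targets = torch.tensor([3, 1, 4, 1, 5])  # gold indices, possibly longer
    # Prefix match, as in _action_history_match above.
    exact = int(targets.new_tensor(predicted).equal(targets[:len(predicted)]))
    # Soft credit, as in _program_similarity above: 2 * matches / total length.
    soft = difflib.SequenceMatcher(None, predicted, targets.tolist()).ratio()
    return exact, soft  # -> (1, 0.888...)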
class DialogQA(Model): """ This class implements a modified version of the BiDAF model (with self attention and a residual layer, from Clark and Gardner, ACL 2017) as used in the Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf]. In this set-up, a single instance is a dialog: a list of question-answer pairs. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. phrase_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and doing the bidirectional attention. residual_encoder : ``Seq2SeqEncoder`` The encoder used in the residual self-attention block over the merged passage representation. span_start_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span start predictions into the passage state before predicting span end. span_end_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span end predictions into the passage state. dropout : ``float``, optional (default=0.2) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). num_context_answers : ``int``, optional (default=0) If greater than 0, the model will consider previous question answering context. marker_embedding_dim : ``int``, optional (default=10) Dimension of the embeddings used to mark previous answers in the passage and question turn numbers. max_span_length: ``int``, optional (default=30) Maximum token length of the output span. max_turn_length: ``int``, optional (default=12) Maximum number of question-answer turns in a dialog. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, phrase_layer: Seq2SeqEncoder, residual_encoder: Seq2SeqEncoder, span_start_encoder: Seq2SeqEncoder, span_end_encoder: Seq2SeqEncoder, initializer: InitializerApplicator, dropout: float = 0.2, num_context_answers: int = 0, marker_embedding_dim: int = 10, max_span_length: int = 30, max_turn_length: int = 12) -> None: super().__init__(vocab) self._num_context_answers = num_context_answers self._max_span_length = max_span_length self._text_field_embedder = text_field_embedder self._phrase_layer = phrase_layer self._marker_embedding_dim = marker_embedding_dim self._encoding_dim = phrase_layer.get_output_dim() self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y') self._merge_atten = TimeDistributed(torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim)) self._residual_encoder = residual_encoder if num_context_answers > 0: self._question_num_marker = torch.nn.Embedding(max_turn_length, marker_embedding_dim * num_context_answers) self._prev_ans_marker = torch.nn.Embedding((num_context_answers * 4) + 1, marker_embedding_dim) self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y') self._followup_lin = torch.nn.Linear(self._encoding_dim, 3) self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3, self._encoding_dim)) self._span_start_encoder = span_start_encoder self._span_end_encoder = span_end_encoder self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1)) self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1)) self._span_yesno_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 3)) self._span_followup_predictor = TimeDistributed(self._followup_lin) check_dimensions_match(phrase_layer.get_input_dim(), text_field_embedder.get_output_dim() + marker_embedding_dim * num_context_answers, "phrase layer input dim", "embedding dim + marker dim * num context answers") initializer(self) self._span_start_accuracy =
CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_yesno_accuracy = CategoricalAccuracy() self._span_followup_accuracy = CategoricalAccuracy() self._span_gt_yesno_accuracy = CategoricalAccuracy() self._span_gt_followup_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._official_f1 = Average() self._variational_dropout = InputVariationalDropout(dropout) def forward(self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, p1_answer_marker: torch.IntTensor = None, p2_answer_marker: torch.IntTensor = None, p3_answer_marker: torch.IntTensor = None, yesno_list: torch.IntTensor = None, followup_list: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage. span_start : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the beginning position of the answer within the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. span_end : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the ending position of the answer within the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. p1_answer_marker : ``torch.IntTensor``, optional This is one of the inputs, but only when num_context_answers > 0. This is a tensor of shape [batch_size, max_qa_count, max_passage_length]. Most passage tokens will be assigned 'O', except the passage tokens that belong to the previous answer in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>. For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac p2_answer_marker : ``torch.IntTensor``, optional This is one of the inputs, but only when num_context_answers > 1. It is similar to p1_answer_marker, but marks the answer from two turns back in the passage. p3_answer_marker : ``torch.IntTensor``, optional This is one of the inputs, but only when num_context_answers > 2. It is similar to p1_answer_marker, but marks the answer from three turns back in the passage. yesno_list : ``torch.IntTensor``, optional This is one of the outputs that we are trying to predict. Three-way classification (yes / no / not a yes-no question). followup_list : ``torch.IntTensor``, optional This is one of the outputs that we are trying to predict. Three-way classification (follow up / maybe follow up / don't follow up). metadata : ``List[Dict[str, Any]]``, optional If present, this should contain the question ID, original passage text, and token offsets into the passage for each instance in the batch. We use this for computing official metrics using the official SQuAD evaluation script. The length of this list should be the batch size, and each dictionary should have the keys ``id``, ``original_passage``, and ``token_offsets``. If you only want the best span string and don't care about official metrics, you can omit the ``id`` key.
Returns ------- An output dictionary consisting of the following entries. Each entry is a nested list, because we first iterate over dialogs, then over the questions in each dialog. qid : List[List[str]] A list of lists of question ids. followup : List[List[int]] A list of lists of continuation-marker prediction indices. (y: yes, m: maybe follow up, n: don't follow up) yesno : List[List[int]] A list of lists of affirmation-marker prediction indices. (y: yes, x: not a yes/no question, n: no) best_span_str : List[List[str]] If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question. loss : torch.FloatTensor, optional A scalar loss to be optimised. """ batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size() total_qa_count = batch_size * max_qa_count qa_mask = torch.ge(followup_list, 0).view(total_qa_count) embedded_question = self._text_field_embedder(question, num_wrapping_dims=1) embedded_question = embedded_question.reshape(total_qa_count, max_q_len, self._text_field_embedder.get_output_dim()) embedded_question = self._variational_dropout(embedded_question) embedded_passage = self._variational_dropout(self._text_field_embedder(passage)) passage_length = embedded_passage.size(1) question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float() question_mask = question_mask.reshape(total_qa_count, max_q_len) passage_mask = util.get_text_field_mask(passage).float() repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1) repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length) if self._num_context_answers > 0: # Encode question turn number inside the dialog into question embedding. question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question)) question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len) question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1) question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len) question_num_marker_emb = self._question_num_marker(question_num_ind) embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1) # Encode the previous answers in passage embedding. repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). 
\ view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim()) # batch_size * max_qa_count, passage_length, word_embed_dim p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length) p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker) repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1) if self._num_context_answers > 1: p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length) p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker) repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1) if self._num_context_answers > 2: p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length) p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker) repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb], dim=-1) repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage, repeated_passage_mask)) else: encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask)) repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1) repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count, passage_length, self._encoding_dim) encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask)) # Shape: (batch_size * max_qa_count, passage_length, question_length) passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question) # Shape: (batch_size * max_qa_count, passage_length, question_length) passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask) # Shape: (batch_size * max_qa_count, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. 
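# (Why -1e7 rather than -inf: the max over the question dimension below can
# then never pick a padded token, exp(-1e7) underflows to zero in any later
# softmax, and, unlike -inf, a fully-masked row still yields finite values
# instead of NaNs.)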
masked_similarity = util.replace_masked_values(passage_question_similarity, question_mask.unsqueeze(1), -1e7) question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1) question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask) # Shape: (batch_size * max_qa_count, encoding_dim) question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention) tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count, passage_length, self._encoding_dim) # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4) final_merged_passage = torch.cat([repeated_encoded_passage, passage_question_vectors, repeated_encoded_passage * passage_question_vectors, repeated_encoded_passage * tiled_question_passage_vector], dim=-1) final_merged_passage = F.relu(self._merge_atten(final_merged_passage)) residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage, repeated_passage_mask)) self_attention_matrix = self._self_attention(residual_layer, residual_layer) mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \ * repeated_passage_mask.reshape(total_qa_count, 1, passage_length) self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device) self_mask = self_mask.reshape(1, passage_length, passage_length) mask = mask * (1 - self_mask) self_attention_probs = util.masked_softmax(self_attention_matrix, mask) # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim) self_attention_vecs = torch.matmul(self_attention_probs, residual_layer) self_attention_vecs = torch.cat([self_attention_vecs, residual_layer, residual_layer * self_attention_vecs], dim=-1) residual_layer = F.relu(self._merge_self_attention(self_attention_vecs)) final_merged_passage = final_merged_passage + residual_layer # batch_size * maxqa_pair_len * max_passage_len * 200 final_merged_passage = self._variational_dropout(final_merged_passage) start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask) span_start_logits = self._span_start_predictor(start_rep).squeeze(-1) end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1), repeated_passage_mask) span_end_logits = self._span_end_predictor(end_rep).squeeze(-1) span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1) span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1) span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7) # batch_size * maxqa_len_pair, max_document_len span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7) best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits, span_yesno_logits, span_followup_logits, self._max_span_length) output_dict: Dict[str, Any] = {} # Compute the loss. 
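# The yes/no and followup selection in the loss below works on a flattened
# view: span_yesno_logits has shape (total_qa_count, passage_length, 3), so
# entry (i, j, k) sits at flat index i * passage_length * 3 + j * 3 + k, which
# is exactly the ``span_end[i] * 3 + i * passage_length * 3 (+ k)`` arithmetic
# used to build gold_span_end_loc and pred_span_end_loc.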
if span_start is not None: loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1), ignore_index=-1) self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask) loss += nll_loss(util.masked_log_softmax(span_end_logits, repeated_passage_mask), span_end.view(-1), ignore_index=-1) self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask) self._span_accuracy(best_span[:, 0:2], torch.stack([span_start, span_end], -1).view(total_qa_count, 2), mask=qa_mask.unsqueeze(1).expand(-1, 2).long()) # add a select for the right span to compute loss gold_span_end_loc = [] span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy() for i in range(0, total_qa_count): gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0)) gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0)) gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0)) gold_span_end_loc = span_start.new(gold_span_end_loc) pred_span_end_loc = [] for i in range(0, total_qa_count): pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0)) pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0)) pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0)) predicted_end = span_start.new(pred_span_end_loc) _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3) _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3) loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1) loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1) _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3) _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3) self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask) self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask) output_dict["loss"] = loss # Compute F1 and prepare the output dictionary. output_dict['best_span_str'] = [] output_dict['qid'] = [] output_dict['followup'] = [] output_dict['yesno'] = [] best_span_cpu = best_span.detach().cpu().numpy() for i in range(batch_size): passage_str = metadata[i]['original_passage'] offsets = metadata[i]['token_offsets'] f1_score = 0.0 per_dialog_best_span_list = [] per_dialog_yesno_list = [] per_dialog_followup_list = [] per_dialog_query_id_list = [] for per_dialog_query_index, (iid, answer_texts) in enumerate( zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])): predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index]) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] yesno_pred = predicted_span[2] followup_pred = predicted_span[3] per_dialog_yesno_list.append(yesno_pred) per_dialog_followup_list.append(followup_pred) per_dialog_query_id_list.append(iid) best_span_string = passage_str[start_offset:end_offset] per_dialog_best_span_list.append(best_span_string) if answer_texts: if len(answer_texts) > 1: t_f1 = [] # Compute F1 over N-1 human references and average the scores. 
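# e.g. with three references [a, b, c], the prediction is scored against
# [b, c], [a, c] and [a, b] in turn, and the three max-F1 values are averaged.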
for answer_index in range(len(answer_texts)): idxes = list(range(len(answer_texts))) idxes.pop(answer_index) refs = [answer_texts[z] for z in idxes] t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score, best_span_string, refs)) f1_score = 1.0 * sum(t_f1) / len(t_f1) else: f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score, best_span_string, answer_texts) self._official_f1(100 * f1_score) output_dict['qid'].append(per_dialog_query_id_list) output_dict['best_span_str'].append(per_dialog_best_span_list) output_dict['yesno'].append(per_dialog_yesno_list) output_dict['followup'].append(per_dialog_followup_list) return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]: yesno_tags = [[self.vocab.get_token_from_index(x, namespace="yesno_labels") for x in yn_list] \ for yn_list in output_dict.pop("yesno")] followup_tags = [[self.vocab.get_token_from_index(x, namespace="followup_labels") for x in followup_list] \ for followup_list in output_dict.pop("followup")] output_dict['yesno'] = yesno_tags output_dict['followup'] = followup_tags return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: return {'start_acc': self._span_start_accuracy.get_metric(reset), 'end_acc': self._span_end_accuracy.get_metric(reset), 'span_acc': self._span_accuracy.get_metric(reset), 'yesno': self._span_yesno_accuracy.get_metric(reset), 'followup': self._span_followup_accuracy.get_metric(reset), 'f1': self._official_f1.get_metric(reset), } @staticmethod def _get_best_span_yesno_followup(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor, span_yesno_logits: torch.Tensor, span_followup_logits: torch.Tensor, max_span_length: int) -> torch.Tensor: # Returns the index of highest-scoring span that is not longer than 30 tokens, as well as # yesno prediction bit and followup prediction bit from the predicted span end token. if span_start_logits.dim() != 2 or span_end_logits.dim() != 2: raise ValueError("Input shapes must be (batch_size, passage_length)") batch_size, passage_length = span_start_logits.size() max_span_log_prob = [-1e20] * batch_size span_start_argmax = [0] * batch_size best_word_span = span_start_logits.new_zeros((batch_size, 4), dtype=torch.long) span_start_logits = span_start_logits.data.cpu().numpy() span_end_logits = span_end_logits.data.cpu().numpy() span_yesno_logits = span_yesno_logits.data.cpu().numpy() span_followup_logits = span_followup_logits.data.cpu().numpy() for b_i in range(batch_size): # pylint: disable=invalid-name for j in range(passage_length): val1 = span_start_logits[b_i, span_start_argmax[b_i]] if val1 < span_start_logits[b_i, j]: span_start_argmax[b_i] = j val1 = span_start_logits[b_i, j] val2 = span_end_logits[b_i, j] if val1 + val2 > max_span_log_prob[b_i]: if j - span_start_argmax[b_i] > max_span_length: continue best_word_span[b_i, 0] = span_start_argmax[b_i] best_word_span[b_i, 1] = j max_span_log_prob[b_i] = val1 + val2 for b_i in range(batch_size): j = best_word_span[b_i, 1] yesno_pred = np.argmax(span_yesno_logits[b_i, j]) followup_pred = np.argmax(span_followup_logits[b_i, j]) best_word_span[b_i, 2] = int(yesno_pred) best_word_span[b_i, 3] = int(followup_pred) return best_word_span
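# --- Illustrative sketch, not part of the original model: a tiny sanity check
# of ``_get_best_span_yesno_followup`` with made-up logits. Start logits peak
# at index 1 and end logits at index 2, so the best span is (1, 2), and the
# yes/no and followup bits are read off the predicted end token.
def _demo_best_span_search():
    import torch
    start = torch.tensor([[0.1, 2.0, 0.3, 0.2]])
    end = torch.tensor([[0.0, 0.1, 1.5, 0.4]])
    yesno = torch.zeros(1, 4, 3)
    yesno[0, 2, 1] = 1.0  # highest yes/no logit at the predicted end token
    followup = torch.zeros(1, 4, 3)
    followup[0, 2, 2] = 1.0  # highest followup logit at the predicted end token
    return DialogQA._get_best_span_yesno_followup(
        start, end, yesno, followup, max_span_length=30)
    # -> tensor([[1, 2, 1, 2]])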
class SpansText2SqlParser(Model): """ Parameters ---------- vocab : ``Vocabulary`` utterance_embedder : ``TextFieldEmbedder`` Embedder for utterances. action_embedding_dim : ``int`` Dimension to use for action embeddings. encoder : ``Seq2SeqEncoder`` The encoder to use for the input utterance. decoder_beam_search : ``BeamSearch`` Beam search used to retrieve best sequences after training. max_decoding_steps : ``int`` When we're decoding with a beam search, what's the maximum number of steps we should take? This only applies at evaluation time, not during training. input_attention: ``Attention`` We compute an attention over the input utterance at each step of the decoder, using the decoder hidden state as the query. Passed to the transition function. add_action_bias : ``bool``, optional (default=True) If ``True``, we will learn a bias weight for each action that gets used when predicting that action, in addition to its embedding. dropout : ``float``, optional (default=0) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). span_extractor: ``SpanExtractor``, optional If provided, extracts spans representations based on the encoded inputs. The span representations are used for decoding. """ def __init__(self, vocab: Vocabulary, mydatabase: str, schema_path: str, utterance_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, decoder_beam_search: BeamSearch, max_decoding_steps: int, input_attention: Attention, add_action_bias: bool = True, dropout: float = 0.0, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None, span_extractor: SpanExtractor = None) -> None: super().__init__(vocab, regularizer) self._utterance_embedder = utterance_embedder self._encoder = encoder self._max_decoding_steps = max_decoding_steps self._add_action_bias = add_action_bias self._dropout = torch.nn.Dropout(p=dropout) # span extractor, allows using spans from the source as input to the decoder self._span_extractor = span_extractor self._exact_match = Average() self._action_similarity = Average() self._valid_sql_query = SqlValidity(mydatabase=mydatabase) self._token_match = TokenSequenceAccuracy() self._kb_match = KnowledgeBaseConstsAccuracy(schema_path=schema_path) self._schema_free_match = GlobalTemplAccuracy(schema_path=schema_path) # the padding value used by IndexField self._action_padding_index = -1 num_actions = vocab.get_vocab_size("rule_labels") input_action_dim = action_embedding_dim if self._add_action_bias: input_action_dim += 1 self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim) self._output_action_embedder = Embedding( num_embeddings=num_actions, embedding_dim=action_embedding_dim) # This is what we pass as input in the first step of decoding, when we don't have a # previous action, or a previous utterance attention. 
self._first_action_embedding = torch.nn.Parameter( torch.FloatTensor(action_embedding_dim)) self._first_attended_utterance = torch.nn.Parameter( torch.FloatTensor(encoder.get_output_dim())) torch.nn.init.normal_(self._first_action_embedding) torch.nn.init.normal_(self._first_attended_utterance) self._beam_search = decoder_beam_search self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1) self._transition_function = BasicTransitionFunction( encoder_output_dim=self._encoder.get_output_dim(), action_embedding_dim=action_embedding_dim, input_attention=input_attention, add_action_bias=self._add_action_bias, dropout=dropout) self.parse_sql_on_decoding = True initializer(self) @overrides def forward( self, # type: ignore tokens: Dict[str, torch.LongTensor], valid_actions: List[List[ProductionRule]], action_sequence: torch.LongTensor = None, spans: torch.IntTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference, if we're not. Parameters ---------- tokens : Dict[str, torch.LongTensor] The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. valid_actions : ``List[List[ProductionRule]]`` A list of all possible actions for each ``World`` in the batch, indexed into a ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and use the embeddings to determine which action to take at each timestep in the decoder. action_sequence : torch.Tensor, optional (default=None) The action sequence for the correct action sequence, where each action is an index into the list of possible actions. This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the trailing dimension. spans: torch.Tensor, optional (default=None) A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end indices of input spans that could be informative for the decoder. Comes from a ``ListField[SpanField]`` """ encode_outputs = self._encode(tokens, spans) # encode_outputs['mask'] shape: (batch_size, num_tokens, encoder_output_dim) batch_size = encode_outputs['mask'].size(0) initial_state = self._get_initial_state( encode_outputs['encoder_outputs'], encode_outputs['mask'], valid_actions) if action_sequence is not None: # Remove the trailing dimension (from ListField[ListField[IndexField]]). action_sequence = action_sequence.squeeze(-1) target_mask = action_sequence != self._action_padding_index else: target_mask = None outputs: Dict[str, Any] = {} if action_sequence is not None: # target_action_sequence is of shape (batch_size, 1, target_sequence_length) # here after we unsqueeze it for the MML trainer. try: loss_output = self._decoder_trainer.decode( initial_state, self._transition_function, (action_sequence.unsqueeze(1), target_mask.unsqueeze(1))) except ZeroDivisionError as e: logger.info( f"Input utterance in ZeroDivisionError: {[t.text for t in tokens['tokens']]}" ) raise e outputs.update(loss_output) if not self.training: action_mapping = [] for batch_actions in valid_actions: batch_action_mapping = {} for action_index, action in enumerate(batch_actions): batch_action_mapping[action_index] = action[0] action_mapping.append(batch_action_mapping) outputs['action_mapping'] = action_mapping # This tells the state to start keeping track of debug info, which we'll pass along in # our output dictionary. 
initial_state.debug_info = [[] for _ in range(batch_size)] best_final_states = self._beam_search.search( self._max_decoding_steps, initial_state, self._transition_function, keep_final_unfinished_states=True) outputs['best_action_sequence'] = [] outputs['debug_info'] = [] outputs['predicted_sql_query'] = [] outputs['target_sql_query'] = [] outputs['sql_queries'] = [] for i in range(batch_size): # Add the target sql from the target actions for sql tokens exact match comparison target_sql_query = '' if action_sequence is not None: target_action_strings = [ action_mapping[i][action_index] for action_index in action_sequence[i].data.tolist() if action_index != self._action_padding_index ] target_sql_query = action_sequence_to_sql( target_action_strings) # target_sql_query = sqlparse.format(target_sql_query, reindent=True) target_sql_query_for_acc = target_sql_query.split() # Decoding may not have terminated with any completed valid SQL queries, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). if i not in best_final_states: self._exact_match(0) self._action_similarity(0) outputs['target_sql_query'].append( target_sql_query_for_acc) outputs['predicted_sql_query'].append('') continue best_action_indices = best_final_states[i][0].action_history[0] action_strings = [ action_mapping[i][action_index] for action_index in best_action_indices ] predicted_sql_query = action_sequence_to_sql(action_strings) predicted_sql_query_for_acc = predicted_sql_query.split() if action_sequence is not None: # Use a Tensor, not a Variable, to avoid a memory leak. targets = action_sequence[i].data sequence_in_targets = 0 sequence_in_targets = self._action_history_match( best_action_indices, targets) self._exact_match(sequence_in_targets) similarity = difflib.SequenceMatcher( None, best_action_indices, targets) self._action_similarity(similarity.ratio()) # predicted_sql_query_for_acc = [token if '@' not in token else token.split('@')[1] for token in # predicted_sql_query.split()] # target_sql_query_for_acc = [token if '@' not in token else token.split('@')[1] for token in # target_sql_query.split()] predicted_sql_query_for_acc = re.sub( r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) ", r" \g<1> AS \g<1>\g<2> ", predicted_sql_query).split() target_sql_query_for_acc = re.sub( r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) ", r" \g<1> AS \g<1>\g<2> ", target_sql_query).split() self._valid_sql_query([predicted_sql_query_for_acc], [target_sql_query_for_acc]) self._token_match([predicted_sql_query_for_acc], [target_sql_query_for_acc]) self._kb_match([predicted_sql_query_for_acc], [target_sql_query_for_acc]) self._schema_free_match([predicted_sql_query_for_acc], [target_sql_query_for_acc]) outputs['best_action_sequence'].append(action_strings) # outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True)) outputs['predicted_sql_query'].append( predicted_sql_query_for_acc) outputs['target_sql_query'].append(target_sql_query_for_acc) outputs['debug_info'].append( best_final_states[i][0].debug_info[0]) # type: ignore return outputs def _encode(self, tokens: Dict[str, torch.LongTensor], spans: torch.Tensor = None): """ If spans are provided, returns the encoded spans (by self._span_extractor) instead of the encoded utterance tokens """ outputs = {} embedded_utterance = self._utterance_embedder(tokens) mask = util.get_text_field_mask(tokens).float() outputs['mask'] = mask # (batch_size, num_tokens, encoder_output_dim) encoder_outputs 
= self._dropout(self._encoder(embedded_utterance, mask)) outputs['encoder_outputs'] = encoder_outputs # if spans (over the input) are given, return their representation instead of the # source tokens representation if spans is not None and self._span_extractor is not None: # Looking at the span start index is enough to know if # this is padding or not. Shape: (batch_size, num_spans) span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long() span_representations = self._span_extractor( encoder_outputs, spans, mask, span_mask) outputs["mask"] = span_mask outputs["encoder_outputs"] = span_representations return outputs def _get_initial_state( self, encoder_outputs: torch.Tensor, mask: torch.Tensor, actions: List[List[ProductionRule]]) -> GrammarBasedState: batch_size = encoder_outputs.size(0) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states( encoder_outputs, mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) initial_score = encoder_outputs.data.new_zeros(batch_size) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] utterance_mask_list = [mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append( RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_utterance, encoder_output_list, utterance_mask_list)) initial_grammar_state = [ self._create_grammar_state(actions[i]) for i in range(batch_size) ] initial_sql_state = [ SqlStatelet(actions[i], self.parse_sql_on_decoding) for i in range(batch_size) ] initial_state = GrammarBasedState( batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, sql_state=initial_sql_state, possible_actions=actions, debug_info=None) return initial_state @staticmethod def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int: # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something. # Check if target is big enough to cover prediction (including start/end symbols) if len(predicted) > targets.size(0): return 0 predicted_tensor = targets.new_tensor(predicted) targets_trimmed = targets[:len(predicted)] # Return 1 if the predicted sequence is anywhere in the list of targets. 
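# (Strictly speaking this is a prefix match, not an "anywhere" match:
# ``Tensor.equal`` compares the prediction against the first len(predicted)
# target indices and returns a bool, which callers treat as 0/1.)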
return int(predicted_tensor.equal(targets_trimmed)) @staticmethod def is_nonterminal(token: str): if token[0] == '"' and token[-1] == '"': return False return True @staticmethod def get_terminals_mask(action_strings): terminals_mask = [] for rule in action_strings: _, rhs = rule.split('->') rhs_values = rhs.strip().strip('[]').split(',') if len(rhs_values) == 1 and rhs_values[0].strip().strip( '"') != rhs_values[0].strip(): terminals_mask.append(1) elif 'TABLE_PLACEHOLDER' in rhs: terminals_mask.append(1) else: terminals_mask.append(0) return terminals_mask @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: """ We track several metrics here: 1. exact_match, which is the percentage of the time that our best output action sequence matches the target action sequence exactly. 2. sql_validity, which is the percentage of the time that decoding actually produces a valid SQL query. We might not produce a valid SQL query if the decoder gets into a repetitive loop, or we're trying to produce a super long SQL query and run out of time steps, or something. 3. action_similarity, which is how similar the predicted action sequence is to the target action sequence. This is basically a soft version of exact_match. On top of these, we report the token-level, KB and schema-free match metrics computed by self._token_match, self._kb_match and self._schema_free_match. You need to be careful that you're computing these on the full data, and not just the subset that can be parsed (make sure you pass "keep_if_unparseable=True" to the dataset reader, which we do for validation data, but not training data). """ validation_correct = self._exact_match._total_value # pylint: disable=protected-access validation_total = self._exact_match._count # pylint: disable=protected-access all_metrics = { '_exact_match_count': validation_correct, '_example_count': validation_total, 'exact_match': self._exact_match.get_metric(reset), 'sql_validity': self._valid_sql_query.get_metric(reset=reset)['sql_validity'], 'action_similarity': self._action_similarity.get_metric(reset) } all_metrics.update(self._token_match.get_metric(reset=reset)) all_metrics.update(self._kb_match.get_metric(reset=reset)) all_metrics.update(self._schema_free_match.get_metric(reset=reset)) return all_metrics def _create_grammar_state( self, possible_actions: List[ProductionRule]) -> GrammarStatelet: """ This method creates the GrammarStatelet object that's used for decoding. Part of creating that is creating the `valid_actions` dictionary, which contains embedded representations of all of the valid actions. So, we create that here as well. The inputs to this method are for a `single instance in the batch`; none of the tensors we create here are batched. We grab the global action ids from the input ``ProductionRules``, and we use those to embed the valid actions for every non-terminal type. This parser only supports global actions, so we raise an error if we encounter anything else. Parameters ---------- possible_actions : ``List[ProductionRule]`` From the input to ``forward`` for a single batch instance. """ device = util.get_device_of(self._action_embedder.weight) # TODO(Mark): This type is pure \(- . 
^)/ translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {} actions_grouped_by_nonterminal: Dict[str, List[Tuple[ ProductionRule, int]]] = defaultdict(list) for i, action in enumerate(possible_actions): if action.rule == "": continue if action.is_global_rule: actions_grouped_by_nonterminal[action.nonterminal].append( (action, i)) else: raise ValueError( "The sql parser doesn't support non-global actions yet.") for key, production_rule_arrays in actions_grouped_by_nonterminal.items(): translated_valid_actions[key] = {} # `key` here is a non-terminal from the grammar, and `production_rule_arrays` are all the # valid productions of that non-terminal. This parser only has global actions, so there is # no global vs. linked split to do here. global_actions = [] for production_rule_array, action_index in production_rule_arrays: global_actions.append( (production_rule_array.rule_id, action_index)) if global_actions: global_action_tensors, global_action_ids = zip(*global_actions) global_action_tensor = torch.cat(global_action_tensors, dim=0).long() if device >= 0: global_action_tensor = global_action_tensor.to(device) global_input_embeddings = self._action_embedder( global_action_tensor) global_output_embeddings = self._output_action_embedder( global_action_tensor) translated_valid_actions[key]['global'] = ( global_input_embeddings, global_output_embeddings, list(global_action_ids)) return GrammarStatelet(['statement'], translated_valid_actions, self.is_nonterminal, reverse_productions=True) @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a separate notion from the "decoder" in "encoder/decoder", where that decoder logic lives in ``TransitionFunction``. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``. """ action_mapping = output_dict['action_mapping'] best_actions = output_dict["best_action_sequence"] debug_infos = output_dict['debug_info'] batch_action_info = [] for batch_index, (predicted_actions, debug_info) in enumerate( zip(best_actions, debug_infos)): instance_action_info = [] for predicted_action, action_debug_info in zip( predicted_actions, debug_info): action_info = {} action_info['predicted_action'] = predicted_action considered_actions = action_debug_info['considered_actions'] probabilities = action_debug_info['probabilities'] actions = [] for action, probability in zip(considered_actions, probabilities): if action != -1: actions.append( (action_mapping[batch_index][action], probability)) actions.sort() considered_actions, probabilities = zip(*actions) action_info['considered_actions'] = considered_actions action_info['action_probabilities'] = probabilities action_info['utterance_attention'] = action_debug_info.get( 'question_attention', []) instance_action_info.append(action_info) batch_action_info.append(instance_action_info) output_dict["predicted_actions"] = batch_action_info return output_dict
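# A minimal standalone sketch (not part of the model) of the alias normalization that
# `forward` above applies before computing the token-level match metrics: the grammar's
# `TABLE_PLACEHOLDER AS <TABLE> <alias>` pattern is rewritten into the usual
# `<TABLE> AS <TABLE><alias>` SQL aliasing form. The query string here is made up.
import re

pattern = r" TABLE_PLACEHOLDER AS ([A-Z_]+)\s*(alias[0-9]) "
replacement = r" \g<1> AS \g<1>\g<2> "

query = "SELECT DISTINCT CITYalias0.CITY_NAME FROM TABLE_PLACEHOLDER AS CITY alias0 WHERE CITYalias0.POPULATION > 150000 ;"
print(re.sub(pattern, replacement, query))
# SELECT DISTINCT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION > 150000 ;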
class WikiTablesSemanticParser(Model): """ A ``WikiTablesSemanticParser`` is a :class:`Model` which takes as input a table and a question, and produces a logical form that answers the question when executed over the table. The logical form is generated by a `type-constrained`, `transition-based` parser. This is an abstract class that defines most of the functionality related to the transition-based parser. It does not contain the implementation for actually training the parser. You may want to train it using a learning-to-search algorithm, in which case you will want to use ``WikiTablesErmSemanticParser``, or if you have a set of approximate logical forms that give the correct denotation, you will want to use ``WikiTablesMmlSemanticParser``. Parameters ---------- vocab : ``Vocabulary`` question_embedder : ``TextFieldEmbedder`` Embedder for questions. action_embedding_dim : ``int`` Dimension to use for action embeddings. encoder : ``Seq2SeqEncoder`` The encoder to use for the input question. entity_encoder : ``Seq2VecEncoder`` The encoder to use for averaging the words of an entity. max_decoding_steps : ``int`` When we're decoding with a beam search, what's the maximum number of steps we should take? This only applies at evaluation time, not during training. add_action_bias : ``bool``, optional (default=True) If ``True``, we will learn a bias weight for each action that gets used when predicting that action, in addition to its embedding. use_neighbor_similarity_for_linking : ``bool``, optional (default=False) If ``True``, we will compute a max similarity between a question token and the `neighbors` of an entity as a component of the linking scores. This is meant to capture the same kind of information as the ``related_column`` feature. dropout : ``float``, optional (default=0) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). num_linking_features : ``int``, optional (default=10) We need to construct a parameter vector for the linking features, so we need to know how many there are. The default of 10 here matches the default in the ``KnowledgeGraphField``, which is to use all ten defined features. If this is 0, another term will be added to the linking score. This term contains the maximum similarity value from the entity's neighbors and the question. rule_namespace : ``str``, optional (default=rule_labels) The vocabulary namespace to use for production rules. The default corresponds to the default used in the dataset reader, so you likely don't need to modify this. tables_directory : ``str``, optional (default=/wikitables/) The directory to find tables when evaluating logical forms. We rely on a call to SEMPRE to evaluate logical forms, and SEMPRE needs to read the table from disk itself. This tells SEMPRE where to find the tables. 
""" # pylint: disable=abstract-method def __init__(self, vocab: Vocabulary, question_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, entity_encoder: Seq2VecEncoder, max_decoding_steps: int, add_action_bias: bool = True, use_neighbor_similarity_for_linking: bool = False, dropout: float = 0.0, num_linking_features: int = 10, rule_namespace: str = 'rule_labels', tables_directory: str = '/wikitables/') -> None: super().__init__(vocab) self._question_embedder = question_embedder self._encoder = encoder self._entity_encoder = TimeDistributed(entity_encoder) self._max_decoding_steps = max_decoding_steps self._add_action_bias = add_action_bias self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._rule_namespace = rule_namespace self._executor = WikiTablesSempreExecutor(tables_directory) self._denotation_accuracy = Average() self._action_sequence_accuracy = Average() self._has_logical_form = Average() self._action_padding_index = -1 # the padding value used by IndexField num_actions = vocab.get_vocab_size(self._rule_namespace) if self._add_action_bias: self._action_biases = Embedding(num_embeddings=num_actions, embedding_dim=1) self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim) self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim) # This is what we pass as input in the first step of decoding, when we don't have a # previous action, or a previous question attention. self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim)) self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(encoder.get_output_dim())) torch.nn.init.normal_(self._first_action_embedding) torch.nn.init.normal_(self._first_attended_question) check_dimensions_match(entity_encoder.get_output_dim(), question_embedder.get_output_dim(), "entity word average embedding dim", "question embedding dim") self._num_entity_types = 4 # TODO(mattg): get this in a more principled way somehow? self._num_start_types = 5 # TODO(mattg): get this in a more principled way somehow? self._embedding_dim = question_embedder.get_output_dim() self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim) self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim) self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim) if num_linking_features > 0: self._linking_params = torch.nn.Linear(num_linking_features, 1) else: self._linking_params = None if self._use_neighbor_similarity_for_linking: self._question_entity_params = torch.nn.Linear(1, 1) self._question_neighbor_params = torch.nn.Linear(1, 1) else: self._question_entity_params = None self._question_neighbor_params = None def _get_initial_rnn_and_grammar_state(self, question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[WikiTablesWorld], actions: List[List[ProductionRule]], outputs: Dict[str, Any]) -> Tuple[List[RnnStatelet], List[LambdaGrammarStatelet]]: """ Encodes the question and table, computes a linking between the two, and constructs an initial RnnStatelet and LambdaGrammarStatelet for each batch instance to pass to the decoder. We take ``outputs`` as a parameter here and `modify` it, adding things that we want to visualize in a demo. 
""" table_text = table['text'] # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float() batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, embedding_dim) encoded_table = self._entity_encoder(embedded_table, table_mask) # (batch_size, num_entities, num_neighbors) neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table) # Neighbor_indices is padded with -1 since 0 is a potential neighbor index. # Thus, the absolute value needs to be taken in the index_select, and 1 needs to # be added for the mask since that method expects 0 for padding. # (batch_size, num_entities, num_neighbors, embedding_dim) embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices)) neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1}, num_wrapping_dims=1).float() # Encoder initialized to easily obtain a masked average. neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True)) # (batch_size, num_entities, embedding_dim) embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask) # entity_types: tensor with shape (batch_size, num_entities), where each entry is the # entity's type id. # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table) entity_type_embeddings = self._entity_type_encoder_embedding(entity_types) projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float()) # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings) # Compute entity and question word similarity. We tried using cosine distance here, but # because this similarity is the main mechanism that the model can use to push apart logit # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger # output range than [-1, 1]. question_entity_similarity = torch.bmm(embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2)) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table['linking'] linking_scores = question_entity_similarity_max_score if self._use_neighbor_similarity_for_linking: # The linking score is computed as a linear projection of two terms. The first is the # maximum similarity score over the entity's words and the question token. The second # is the maximum similarity over the words in the entity's neighbors and the question # token. # # The second term, projected_question_neighbor_similarity, is useful when a column # needs to be selected. 
For example, the question token might have no similarity with # the column name, but is similar to the cells in the column. # # Note that projected_question_neighbor_similarity is intended to capture the same # information as the related_column feature. # # Also note that this block needs to be _before_ the `linking_params` block, because # we're overwriting `linking_scores`, not adding to it. # (batch_size, num_entities, num_neighbors, num_question_tokens) question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score, torch.abs(neighbor_indices)) # (batch_size, num_entities, num_question_tokens) question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2) projected_question_entity_similarity = self._question_entity_params( question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1) projected_question_neighbor_similarity = self._question_neighbor_params( question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1) linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity feature_scores = None if self._linking_params is not None: feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2), question_mask, entity_type_dict) # (batch_size, num_question_tokens, embedding_dim) link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities) encoder_input = torch.cat([link_embedding, embedded_question], 2) # (batch_size, question_length, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask)) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states(encoder_outputs, question_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append(RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list)) initial_grammar_state = [self._create_grammar_state(world[i], actions[i], linking_scores[i], entity_types[i]) for i in range(batch_size)] if not self.training: # We add a few things to the outputs that will be returned from `forward` at evaluation # time, for visualization in a demo. 
outputs['linking_scores'] = linking_scores if feature_scores is not None: outputs['feature_scores'] = feature_scores outputs['similarity_scores'] = question_entity_similarity_max_score return initial_rnn_state, initial_grammar_state @staticmethod def _get_neighbor_indices(worlds: List[WikiTablesWorld], num_entities: int, tensor: torch.Tensor) -> torch.LongTensor: """ This method returns the indices of each entity's neighbors. A tensor is accepted as a parameter for copying purposes. Parameters ---------- worlds : ``List[WikiTablesWorld]`` num_entities : ``int`` tensor : ``torch.Tensor`` Used for copying the constructed list onto the right device. Returns ------- A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded with -1 instead of 0, since 0 is a valid neighbor index. """ num_neighbors = 0 for world in worlds: for entity in world.table_graph.entities: if len(world.table_graph.neighbors[entity]) > num_neighbors: num_neighbors = len(world.table_graph.neighbors[entity]) batch_neighbors = [] for world in worlds: # Each batch instance has its own world, which has a corresponding table. entities = world.table_graph.entities entity2index = {entity: i for i, entity in enumerate(entities)} entity2neighbors = world.table_graph.neighbors neighbor_indexes = [] for entity in entities: entity_neighbors = [entity2index[n] for n in entity2neighbors[entity]] # Pad with -1 instead of 0, since 0 represents a neighbor index. padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1) neighbor_indexes.append(padded) neighbor_indexes = pad_sequence_to_length(neighbor_indexes, num_entities, lambda: [-1] * num_neighbors) batch_neighbors.append(neighbor_indexes) return tensor.new_tensor(batch_neighbors, dtype=torch.long) @staticmethod def _get_type_vector(worlds: List[WikiTablesWorld], num_entities: int, tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]: """ Produces a tensor with shape ``(batch_size, num_entities)`` that encodes each entity's type. In addition, a map from a flattened entity index to type is returned to combine entity type operations into one method. Parameters ---------- worlds : ``List[WikiTablesWorld]`` num_entities : ``int`` tensor : ``torch.Tensor`` Used for copying the constructed list onto the right device. Returns ------- A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``. entity_types : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. """ entity_types = {} batch_types = [] for batch_index, world in enumerate(worlds): types = [] for entity_index, entity in enumerate(world.table_graph.entities): # We need numbers to be first, then cells, then parts, then row, because our # entities are going to be sorted. We do a split by type and then a merge later, # and it relies on this sorting. if entity.startswith('fb:cell'): entity_type = 1 elif entity.startswith('fb:part'): entity_type = 2 elif entity.startswith('fb:row'): entity_type = 3 else: entity_type = 0 types.append(entity_type) # For easier lookups later, we're actually using a _flattened_ version # of (batch_index, entity_index) for the key, because this is how the # linking scores are stored. 
flattened_entity_index = batch_index * num_entities + entity_index entity_types[flattened_entity_index] = entity_type padded = pad_sequence_to_length(types, num_entities, lambda: 0) batch_types.append(padded) return tensor.new_tensor(batch_types, dtype=torch.long), entity_types def _get_linking_probabilities(self, worlds: List[WikiTablesWorld], linking_scores: torch.FloatTensor, question_mask: torch.LongTensor, entity_type_dict: Dict[int, int]) -> torch.FloatTensor: """ Produces the probability of an entity given a question word and type. The logic below separates the entities by type since the softmax normalization term sums over entities of a single type. Parameters ---------- worlds : ``List[WikiTablesWorld]`` linking_scores : ``torch.FloatTensor`` Has shape (batch_size, num_question_tokens, num_entities). question_mask: ``torch.LongTensor`` Has shape (batch_size, num_question_tokens). entity_type_dict : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. Returns ------- batch_probabilities : ``torch.FloatTensor`` Has shape ``(batch_size, num_question_tokens, num_entities)``. Contains all the probabilities for an entity given a question word. """ _, num_question_tokens, num_entities = linking_scores.size() batch_probabilities = [] for batch_index, world in enumerate(worlds): all_probabilities = [] num_entities_in_instance = 0 # NOTE: The way that we're doing this here relies on the fact that entities are # implicitly sorted by their types when we sort them by name, and that numbers come # before "fb:cell", and "fb:cell" comes before "fb:row". This is not a great # assumption, and could easily break later, but it should work for now. for type_index in range(self._num_entity_types): # This index of 0 is for the null entity for each type, representing the case where a # word doesn't link to any entity. entity_indices = [0] entities = world.table_graph.entities for entity_index, _ in enumerate(entities): if entity_type_dict[batch_index * num_entities + entity_index] == type_index: entity_indices.append(entity_index) if len(entity_indices) == 1: # No entities of this type; move along... continue # We're subtracting one here because of the null entity we added above. num_entities_in_instance += len(entity_indices) - 1 # We separate the scores by type, since normalization is done per type. There's an # extra "null" entity per type, also, so we have `num_entities_per_type + 1`. We're # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_, # so we get back something of shape (num_question_tokens,) for each index we're # selecting. All of the selected indices together then make a tensor of shape # (num_question_tokens, num_entities_per_type + 1). indices = linking_scores.new_tensor(entity_indices, dtype=torch.long) entity_scores = linking_scores[batch_index].index_select(1, indices) # We used index 0 for the null entity, so this will actually have some values in it. # But we want the null entity's score to be 0, so we set that here. entity_scores[:, 0] = 0 # No need for a mask here, as this is done per batch instance, with no padding. type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1) all_probabilities.append(type_probabilities[:, 1:]) # We need to add padding here if we don't have the right number of entities. 
if num_entities_in_instance != num_entities: zeros = linking_scores.new_zeros(num_question_tokens, num_entities - num_entities_in_instance) all_probabilities.append(zeros) # (num_question_tokens, num_entities) probabilities = torch.cat(all_probabilities, dim=1) batch_probabilities.append(probabilities) batch_probabilities = torch.stack(batch_probabilities, dim=0) return batch_probabilities * question_mask.unsqueeze(-1).float() @staticmethod def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int: # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something. # Check if target is big enough to cover prediction (including start/end symbols) if len(predicted) > targets.size(1): return 0 predicted_tensor = targets.new_tensor(predicted) targets_trimmed = targets[:, :len(predicted)] # Return 1 if the predicted sequence is anywhere in the list of targets. return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item() @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: """ We track three metrics here: 1. dpd_acc, which is the percentage of the time that our best output action sequence is in the set of action sequences provided by DPD. This is an easy-to-compute lower bound on denotation accuracy for the set of examples where we actually have DPD output. We only score dpd_acc on that subset. 2. denotation_acc, which is the percentage of examples where we get the correct denotation. This is the typical "accuracy" metric, and it is what you should usually report in an experimental result. You need to be careful, though, that you're computing this on the full data, and not just the subset that has DPD output (make sure you pass "keep_if_no_dpd=True" to the dataset reader, which we do for validation data, but not training data). 3. lf_percent, which is the percentage of time that decoding actually produces a finished logical form. We might not produce a valid logical form if the decoder gets into a repetitive loop, or we're trying to produce a super long logical form and run out of time steps, or something. """ return { 'dpd_acc': self._action_sequence_accuracy.get_metric(reset), 'denotation_acc': self._denotation_accuracy.get_metric(reset), 'lf_percent': self._has_logical_form.get_metric(reset), } def _create_grammar_state(self, world: WikiTablesWorld, possible_actions: List[ProductionRule], linking_scores: torch.Tensor, entity_types: torch.Tensor) -> LambdaGrammarStatelet: """ This method creates the LambdaGrammarStatelet object that's used for decoding. Part of creating that is creating the `valid_actions` dictionary, which contains embedded representations of all of the valid actions. So, we create that here as well. The way we represent the valid expansions is a little complicated: we use a dictionary of `action types`, where the key is the action type (like "global", "linked", or whatever your model is expecting), and the value is a tuple representing all actions of that type. The tuple is (input tensor, output tensor, action id). The input tensor has the representation that is used when `selecting` actions, for all actions of this type. The output tensor has the representation that is used when feeding the action to the next step of the decoder (this could just be the same as the input tensor). The action ids are a list of indices into the main action list for each batch instance. The inputs to this method are for a `single instance in the batch`; none of the tensors we create here are batched. 
We grab the global action ids from the input ``ProductionRules``, and we use those to embed the valid actions for every non-terminal type. We use the input ``linking_scores`` for non-global actions. Parameters ---------- world : ``WikiTablesWorld`` From the input to ``forward`` for a single batch instance. possible_actions : ``List[ProductionRule]`` From the input to ``forward`` for a single batch instance. linking_scores : ``torch.Tensor`` Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch dimension). entity_types : ``torch.Tensor`` Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension). """ # TODO(mattg): Move the "valid_actions" construction to another method. action_map = {} for action_index, action in enumerate(possible_actions): action_string = action[0] action_map[action_string] = action_index entity_map = {} for entity_index, entity in enumerate(world.table_graph.entities): entity_map[entity] = entity_index valid_actions = world.get_valid_actions() translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {} for key, action_strings in valid_actions.items(): translated_valid_actions[key] = {} # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid # productions of that non-terminal. We'll first split those productions by global vs. # linked action. action_indices = [action_map[action_string] for action_string in action_strings] production_rule_arrays = [(possible_actions[index], index) for index in action_indices] global_actions = [] linked_actions = [] for production_rule_array, action_index in production_rule_arrays: if production_rule_array[1]: global_actions.append((production_rule_array[2], action_index)) else: linked_actions.append((production_rule_array[0], action_index)) # Then we get the embedded representations of the global actions. global_action_tensors, global_action_ids = zip(*global_actions) global_action_tensor = torch.cat(global_action_tensors, dim=0) global_input_embeddings = self._action_embedder(global_action_tensor) if self._add_action_bias: global_action_biases = self._action_biases(global_action_tensor) global_input_embeddings = torch.cat([global_input_embeddings, global_action_biases], dim=-1) global_output_embeddings = self._output_action_embedder(global_action_tensor) translated_valid_actions[key]['global'] = (global_input_embeddings, global_output_embeddings, list(global_action_ids)) # Then the representations of the linked actions. if linked_actions: linked_rules, linked_action_ids = zip(*linked_actions) entities = [rule.split(' -> ')[1] for rule in linked_rules] entity_ids = [entity_map[entity] for entity in entities] # (num_linked_actions, num_question_tokens) entity_linking_scores = linking_scores[entity_ids] # (num_linked_actions,) entity_type_tensor = entity_types[entity_ids] # (num_linked_actions, entity_type_embedding_dim) entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor) translated_valid_actions[key]['linked'] = (entity_linking_scores, entity_type_embeddings, list(linked_action_ids)) # Lastly, we need to also create embedded representations of context-specific actions. In # this case, those are only variable productions, like "r -> x". Note that our language # only permits one lambda at a time, so we don't need to worry about how nested lambdas # might impact this. 
context_actions = {} for action_id, action in enumerate(possible_actions): if action[0].endswith(" -> x"): input_embedding = self._action_embedder(action[2]) if self._add_action_bias: input_bias = self._action_biases(action[2]) input_embedding = torch.cat([input_embedding, input_bias], dim=-1) output_embedding = self._output_action_embedder(action[2]) context_actions[action[0]] = (input_embedding, output_embedding, action_id) return LambdaGrammarStatelet([START_SYMBOL], {}, translated_valid_actions, context_actions, type_declaration.is_nonterminal) def _compute_validation_outputs(self, actions: List[List[ProductionRule]], best_final_states: Mapping[int, Sequence[GrammarBasedState]], world: List[WikiTablesWorld], example_lisp_string: List[str], metadata: List[Dict[str, Any]], outputs: Dict[str, Any]) -> None: """ Does common things for validation time: computing logical form accuracy (which is expensive and unnecessary during training), adding visualization info to the output dictionary, etc. This doesn't return anything; instead it `modifies` the given ``outputs`` dictionary, and calls metrics on ``self``. """ batch_size = len(actions) action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] outputs['action_mapping'] = action_mapping outputs['best_action_sequence'] = [] outputs['debug_info'] = [] outputs['entities'] = [] outputs['logical_form'] = [] for i in range(batch_size): # Decoding may not have terminated with any completed logical forms, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). if i in best_final_states: best_action_indices = best_final_states[i][0].action_history[0] action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices] try: logical_form = world[i].get_logical_form(action_strings, add_var_function=False) self._has_logical_form(1.0) except ParsingError: self._has_logical_form(0.0) logical_form = 'Error producing logical form' if example_lisp_string: denotation_correct = self._executor.evaluate_logical_form(logical_form, example_lisp_string[i]) self._denotation_accuracy(1.0 if denotation_correct else 0.0) outputs['best_action_sequence'].append(action_strings) outputs['logical_form'].append(logical_form) outputs['debug_info'].append(best_final_states[i][0].debug_info[0]) # type: ignore outputs['entities'].append(world[i].table_graph.entities) else: outputs['logical_form'].append('') self._has_logical_form(0.0) self._denotation_accuracy(0.0) if metadata is not None: outputs["question_tokens"] = [x["question_tokens"] for x in metadata] outputs["original_table"] = [x["original_table"] for x in metadata] @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a separate notion from the "decoder" in "encoder/decoder", where that decoder logic lives in the ``TransitionFunction``. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_actions`` to the ``output_dict``. 
""" action_mapping = output_dict['action_mapping'] best_actions = output_dict["best_action_sequence"] debug_infos = output_dict['debug_info'] batch_action_info = [] for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)): instance_action_info = [] for predicted_action, action_debug_info in zip(predicted_actions, debug_info): action_info = {} action_info['predicted_action'] = predicted_action considered_actions = action_debug_info['considered_actions'] probabilities = action_debug_info['probabilities'] actions = [] for action, probability in zip(considered_actions, probabilities): if action != -1: actions.append((action_mapping[(batch_index, action)], probability)) actions.sort() considered_actions, probabilities = zip(*actions) action_info['considered_actions'] = considered_actions action_info['action_probabilities'] = probabilities action_info['question_attention'] = action_debug_info.get('question_attention', []) instance_action_info.append(action_info) batch_action_info.append(instance_action_info) output_dict["predicted_actions"] = batch_action_info return output_dict
class PassageAttnToCount(Model): def __init__( self, vocab: Vocabulary, passage_attention_to_count: Seq2SeqEncoder, dropout: float = 0.2, initializers: InitializerApplicator = InitializerApplicator() ) -> None: super(PassageAttnToCount, self).__init__(vocab=vocab) self.scaling_vals = [1, 2, 5, 10] self.passage_attention_to_count = passage_attention_to_count assert len(self.scaling_vals) == self.passage_attention_to_count.get_input_dim() self.num_counts = 10 # self.passage_count_predictor = torch.nn.Linear(self.passage_attention_to_count.get_output_dim(), # self.num_counts, bias=False) # We want to predict a score for each passage token self.passage_count_hidden2logits = torch.nn.Linear( self.passage_attention_to_count.get_output_dim(), 1, bias=True) self.passagelength_to_bias = torch.nn.Linear(1, 1, bias=True) self.count_acc = Average() if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x initializers(self) # self.passage_count_hidden2logits.bias.data.fill_(-1.0) # self.passage_count_hidden2logits.bias.requires_grad = False def device_id(self, tensor: torch.Tensor) -> int: # Returns the CUDA device id of the given tensor (-1 when it lives on the CPU). return allenutil.get_device_of(tensor) @overrides def forward( self, passage_attention: torch.Tensor, passage_lengths: List[int], count_answer: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: device_id = allenutil.get_device_of(passage_attention) batch_size, max_passage_length = passage_attention.size() # Shape: (B, passage_length) passage_mask = (passage_attention >= 0).float() # List of (B, P) shaped tensors scaled_attentions = [ passage_attention * sf for sf in self.scaling_vals ] # Shape: (B, passage_length, num_scaling_factors) scaled_passage_attentions = torch.stack(scaled_attentions, dim=2) # Shape (batch_size, 1) passage_len_bias = self.passagelength_to_bias( passage_mask.sum(1, keepdim=True)) scaled_passage_attentions = scaled_passage_attentions * passage_mask.unsqueeze( 2) # Shape: (B, passage_length, hidden_dim) count_hidden_repr = self.passage_attention_to_count( scaled_passage_attentions, passage_mask) # Shape: (B, passage_length, 1) -- score for each token passage_span_logits = self.passage_count_hidden2logits( count_hidden_repr) # Shape: (B, passage_length) -- sigmoid on token-score token_sigmoids = torch.sigmoid(passage_span_logits.squeeze(2)) token_sigmoids = token_sigmoids * passage_mask # Shape: (B, 1) -- sum of sigmoids. 
This will act as the predicted mean # passage_count_mean = torch.sum(token_sigmoids, dim=1, keepdim=True) + passage_len_bias passage_count_mean = torch.sum(token_sigmoids, dim=1, keepdim=True) # Shape: (1, count_vals) self.countvals = allenutil.get_range_vector( self.num_counts, device=device_id).unsqueeze(0).float() variance = 0.2 # Shape: (batch_size, count_vals) l2_by_vsquared = torch.pow(self.countvals - passage_count_mean, 2) / (2 * variance * variance) exp_val = torch.exp(-1 * l2_by_vsquared) + 1e-30 # Shape: (batch_size, count_vals) count_distribution = exp_val / (torch.sum(exp_val, 1, keepdim=True)) # Loss computation output_dict = {} loss = 0.0 pred_count_idx = torch.argmax(count_distribution, 1) if count_answer is not None: # L2-loss passage_count_mean = passage_count_mean.squeeze(1) L2Loss = F.mse_loss(input=passage_count_mean, target=count_answer.float()) loss = L2Loss predictions = passage_count_mean.detach().cpu().numpy() predictions = np.round_(predictions) gold_count = count_answer.detach().cpu().numpy() correct_vec = (predictions == gold_count) correct_perc = sum(correct_vec) / batch_size # print(f"{correct_perc} {predictions} {gold_count}") self.count_acc(correct_perc) # loss = F.cross_entropy(input=count_distribution, target=count_answer) # List of predicted count idxs, Shape: (B,) # correct_vec = (pred_count_idx == count_answer).float() # correct_perc = torch.sum(correct_vec) / batch_size # self.count_acc(correct_perc.item()) batch_loss = loss / batch_size output_dict["loss"] = batch_loss output_dict["passage_attention"] = passage_attention output_dict["passage_sigmoid"] = token_sigmoids output_dict["count_mean"] = passage_count_mean output_dict["count_distribution"] = count_distribution output_dict["count_answer"] = count_answer output_dict["pred_count"] = pred_count_idx return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: metric_dict = {} count_acc = self.count_acc.get_metric(reset) metric_dict.update({'acc': count_acc}) return metric_dict
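# A standalone sketch of the count discretization in `forward` above: the predicted scalar
# mean is expanded into a distribution over the count values {0, ..., num_counts - 1} with
# a squared-exponential kernel (variance = 0.2 and num_counts = 10, as in the model). The
# input mean below is a toy value.
import torch

variance = 0.2
count_vals = torch.arange(10).float().unsqueeze(0)  # Shape: (1, 10)
passage_count_mean = torch.tensor([[3.4]])          # Shape: (batch_size, 1)

l2_by_vsquared = (count_vals - passage_count_mean) ** 2 / (2 * variance * variance)
exp_val = torch.exp(-l2_by_vsquared) + 1e-30
count_distribution = exp_val / exp_val.sum(dim=1, keepdim=True)
print(count_distribution.argmax(dim=1))  # tensor([3]) -- mass concentrates around the mean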
class SimpleProjectionOld(Model): """ Embeds and pools the input(s), then applies a feed-forward projection to predict NLI labels, tracking per-task accuracy. """ def __init__(self, vocab: Vocabulary, input_embedder: TextFieldEmbedder, pooler: Seq2VecEncoder, nli_projection_layer: FeedForward, training_tasks: Any, validation_tasks: Any, dropout: float = 0.0, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(SimpleProjectionOld, self).__init__(vocab, regularizer) if isinstance(training_tasks, dict): self._training_tasks = list(training_tasks.keys()) else: self._training_tasks = training_tasks if isinstance(validation_tasks, dict): self._validation_tasks = list(validation_tasks.keys()) else: self._validation_tasks = validation_tasks self._input_embedder = input_embedder self._pooler = pooler self._label_namespace = "labels" self._num_labels = vocab.get_vocab_size( namespace=self._label_namespace) self._nli_projection_layer = nli_projection_layer print( vocab.get_token_to_index_vocabulary( namespace=self._label_namespace)) assert nli_projection_layer.get_output_dim() == self._num_labels self._dropout = torch.nn.Dropout(p=dropout) self._loss = torch.nn.CrossEntropyLoss() initializer(self._nli_projection_layer) self._nli_per_lang_acc: Dict[str, CategoricalAccuracy] = dict() for taskname in self._validation_tasks: # this will hide some metrics from tqdm, but they will still be computed self._nli_per_lang_acc[taskname] = CategoricalAccuracy() self._nli_avg_acc = Average() def forward( self, # type: ignore premise_hypothesis: Dict[str, torch.Tensor] = None, premise: Dict[str, torch.Tensor] = None, hypothesis: Dict[str, torch.LongTensor] = None, dataset: List[str] = None, label: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- premise_hypothesis : Dict[str, torch.LongTensor] Combined in a single text field for BERT encoding premise : Dict[str, torch.LongTensor] From a ``TextField`` hypothesis : Dict[str, torch.LongTensor] From a ``TextField`` label : torch.IntTensor, optional, (default = None) From a ``LabelField`` dataset : List[str] Task indicator metadata : ``List[Dict[str, Any]]``, optional, (default = None) Metadata containing the original tokenization of the premise and hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively. Returns ------- An output dictionary consisting of: label_logits : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log probabilities of the entailment label. label_probs : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the entailment label. loss : torch.FloatTensor, optional A scalar loss to be optimised. 
""" if dataset is not None: # TODO: hardcoded; used when not multitask reader was used taskname = dataset[0] else: taskname = "nli-en" if premise_hypothesis is not None: assert premise is None and hypothesis is None if premise_hypothesis is not None: embedded_combined = self._input_embedder(premise_hypothesis) mask = get_text_field_mask(premise_hypothesis).float() pooled_combined = self._pooler(embedded_combined, mask=mask) elif premise is not None and hypothesis is not None: embedded_premise = self._input_embedder(premise) embedded_hypothesis = self._input_embedder(hypothesis) mask_premise = get_text_field_mask(premise).float() mask_hypothesis = get_text_field_mask(hypothesis).float() pooled_premise = self._pooler(embedded_premise, mask=mask_premise) pooled_hypothesis = self._pooler(embedded_hypothesis, mask=mask_hypothesis) pooled_combined = torch.cat([pooled_premise, pooled_hypothesis], dim=-1) else: raise ConfigurationError( "One of premise or hypothesis is None. Check your DatasetReader" ) pooled_combined = self._dropout(pooled_combined) logits = self._nli_projection_layer(pooled_combined) probs = torch.nn.functional.softmax(logits, dim=-1) output_dict = {"logits": logits, "probs": probs} if label is not None: loss = self._loss(logits, label.long().view(-1)) output_dict["loss"] = loss self._nli_per_lang_acc[taskname](logits, label) if metadata is not None: output_dict["premise_tokens"] = [ x["premise_tokens"] for x in metadata ] output_dict["hypothesis_tokens"] = [ x["hypothesis_tokens"] for x in metadata ] return output_dict @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Does a simple argmax over the probabilities, converts index to string label, and add ``"label"`` key to the dictionary with the result. """ predictions = output_dict["probs"] if predictions.dim() == 2: predictions_list = [ predictions[i] for i in range(predictions.shape[0]) ] else: predictions_list = [predictions] classes = [] for prediction in predictions_list: label_idx = prediction.argmax(dim=-1).item() label_str = (self.vocab.get_index_to_token_vocabulary( self._label_namespace).get(label_idx, str(label_idx))) classes.append(label_str) output_dict["label"] = classes return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: metrics = {} if self.training: tasks = self._training_tasks else: tasks = self._validation_tasks for taskname in tasks: metricname = taskname if metricname[-2:] != 'en' or metricname[-2:] != 'de' or metricname[ -2:] != 'ru': # hide other langs from tqdn metricname = '_' + metricname metrics[metricname] = self._nli_per_lang_acc[taskname].get_metric( reset) accs = metrics.values() # TODO: should only count 'nli-*' metrics avg = sum(accs) / sum(x > 0 for x in accs) self._nli_avg_acc(avg) metrics["nli-avg"] = self._nli_avg_acc.get_metric(reset) return metrics