def __init__(
        self,
        device,
        preproc,
        rule_emb_size=128,
        node_embed_size=64,
        # TODO: This should be automatically inferred from encoder
        enc_recurrent_size=256,
        recurrent_size=256,
        dropout=0.,
        desc_attn='bahdanau',
        copy_pointer=None,
        multi_loss_type='logsumexp',
        sup_att=None,
        use_align_mat=False,
        use_align_loss=False,
        enumerate_order=False,
        loss_type="softmax"):
    super().__init__()
    self._device = device
    self.preproc = preproc
    self.ast_wrapper = preproc.ast_wrapper
    self.terminal_vocab = preproc.vocab

    self.rule_emb_size = rule_emb_size
    self.node_emb_size = node_embed_size
    self.enc_recurrent_size = enc_recurrent_size
    self.recurrent_size = recurrent_size

    # Map each grammar rule to an integer index for the rule classifier/embedding.
    self.rules_index = {v: idx for idx, v in enumerate(self.preproc.all_rules)}

    self.use_align_mat = use_align_mat
    self.use_align_loss = use_align_loss
    self.enumerate_order = enumerate_order

    if use_align_mat:
        from ratsql.models.spider import spider_dec_func
        self.compute_align_loss = lambda *args: \
            spider_dec_func.compute_align_loss(self, *args)
        self.compute_pointer_with_align = lambda *args: \
            spider_dec_func.compute_pointer_with_align(self, *args)

    # Vocabulary of AST node types; its contents depend on whether the
    # preprocessor uses sequence-element rules.
    if self.preproc.use_seq_elem_rules:
        self.node_type_vocab = vocab.Vocab(
            sorted(self.preproc.primitive_types) +
            sorted(self.ast_wrapper.custom_primitive_types) +
            sorted(self.preproc.sum_type_constructors.keys()) +
            sorted(self.preproc.field_presence_infos.keys()) +
            sorted(self.preproc.seq_lengths.keys()),
            special_elems=())
    else:
        self.node_type_vocab = vocab.Vocab(
            sorted(self.preproc.primitive_types) +
            sorted(self.ast_wrapper.custom_primitive_types) +
            sorted(self.ast_wrapper.sum_types.keys()) +
            sorted(self.ast_wrapper.singular_types.keys()) +
            sorted(self.preproc.seq_lengths.keys()),
            special_elems=())

    self.state_update = variational_lstm.RecurrentDropoutLSTMCell(
        input_size=self.rule_emb_size * 2 + self.enc_recurrent_size
                   + self.recurrent_size + self.node_emb_size,
        hidden_size=self.recurrent_size,
        dropout=dropout)

    self.attn_type = desc_attn
    if desc_attn == 'bahdanau':
        self.desc_attn = attention.BahdanauAttention(
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size,
            proj_size=50)
    elif desc_attn == 'mha':
        self.desc_attn = attention.MultiHeadedAttention(
            h=8,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
    elif desc_attn == 'mha-1h':
        self.desc_attn = attention.MultiHeadedAttention(
            h=1,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
    elif desc_attn == 'sep':
        self.question_attn = attention.MultiHeadedAttention(
            h=1,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
        self.schema_attn = attention.MultiHeadedAttention(
            h=1,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
    else:
        # TODO: Figure out how to get right sizes (query, value) to module
        self.desc_attn = desc_attn
    self.sup_att = sup_att

    self.rule_logits = torch.nn.Sequential(
        torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
        torch.nn.Tanh(),
        torch.nn.Linear(self.rule_emb_size, len(self.rules_index)))
    self.rule_embedding = torch.nn.Embedding(
        num_embeddings=len(self.rules_index),
        embedding_dim=self.rule_emb_size)

    self.gen_logodds = torch.nn.Linear(self.recurrent_size, 1)
    self.terminal_logits = torch.nn.Sequential(
        torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
        torch.nn.Tanh(),
        torch.nn.Linear(self.rule_emb_size, len(self.terminal_vocab)))
    self.terminal_embedding = torch.nn.Embedding(
        num_embeddings=len(self.terminal_vocab),
        embedding_dim=self.rule_emb_size)

    if copy_pointer is None:
        self.copy_pointer = attention.BahdanauPointer(
            query_size=self.recurrent_size,
            key_size=self.enc_recurrent_size,
            proj_size=50)
    else:
        # TODO: Figure out how to get right sizes (query, key) to module
        self.copy_pointer = copy_pointer

    if multi_loss_type == 'logsumexp':
        self.multi_loss_reduction = lambda logprobs: -torch.logsumexp(logprobs, dim=1)
    elif multi_loss_type == 'mean':
        self.multi_loss_reduction = lambda logprobs: -torch.mean(logprobs, dim=1)

    self.pointers = torch.nn.ModuleDict()
    self.pointer_action_emb_proj = torch.nn.ModuleDict()
    for pointer_type in self.preproc.grammar.pointers:
        self.pointers[pointer_type] = attention.ScaledDotProductPointer(
            query_size=self.recurrent_size,
            key_size=self.enc_recurrent_size)
        self.pointer_action_emb_proj[pointer_type] = torch.nn.Linear(
            self.enc_recurrent_size, self.rule_emb_size)

    self.node_type_embedding = torch.nn.Embedding(
        num_embeddings=len(self.node_type_vocab),
        embedding_dim=self.node_emb_size)

    # TODO batching
    self.zero_rule_emb = torch.zeros(1, self.rule_emb_size, device=self._device)
    self.zero_recurrent_emb = torch.zeros(1, self.recurrent_size, device=self._device)

    if loss_type == "softmax":
        self.xent_loss = torch.nn.CrossEntropyLoss(reduction='none')
    elif loss_type == "entmax":
        self.xent_loss = entmax.entmax15_loss
    elif loss_type == "sparsemax":
        self.xent_loss = entmax.sparsemax_loss
    elif loss_type == "label_smooth":
        self.xent_loss = self.label_smooth_loss
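# A minimal standalone sketch (kept as a comment, not part of the model code) of
# how the two `multi_loss_type` reductions above behave; `logprobs` is assumed to
# hold the log-probability of each valid copy choice for a single decoding step,
# with shape [batch, num_choices]. Only `torch` is required.
#
#     import torch
#     logprobs = torch.log(torch.tensor([[0.1, 0.2, 0.3]]))
#     # 'logsumexp' marginalizes over the choices: loss = -log(sum_i p_i)
#     loss_lse = -torch.logsumexp(logprobs, dim=1)   # tensor([0.5108])
#     # 'mean' averages the per-choice log-probs: loss = -mean_i log(p_i)
#     loss_mean = -torch.mean(logprobs, dim=1)       # tensor([1.7053])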
def __init__(
        self,
        device,
        preproc,
        grammar_path,
        rule_emb_size=128,
        node_embed_size=64,
        # TODO: This should be automatically inferred from encoder
        enc_recurrent_size=256,
        recurrent_size=256,
        dropout=0.,
        desc_attn='bahdanau',
        copy_pointer=None,
        multi_loss_type='logsumexp',
        sup_att=None,
        use_align_mat=False,
        use_align_loss=False,
        enumerate_order=False,
        loss_type="softmax"):
    super().__init__()
    self._device = device
    self.preproc = preproc
    self.ast_wrapper = preproc.ast_wrapper
    self.terminal_vocab = preproc.vocab
    self.preproc.primitive_types.append("singleton")

    # Vocabulary of AST node types; its contents depend on whether the
    # preprocessor uses sequence-element rules.
    if self.preproc.use_seq_elem_rules:
        self.node_type_vocab = vocab.Vocab(
            sorted(self.preproc.primitive_types) +
            sorted(self.ast_wrapper.custom_primitive_types) +
            sorted(self.preproc.sum_type_constructors.keys()) +
            sorted(self.preproc.field_presence_infos.keys()) +
            sorted(self.preproc.seq_lengths.keys()),
            special_elems=())
    else:
        self.node_type_vocab = vocab.Vocab(
            sorted(self.preproc.primitive_types) +
            sorted(self.ast_wrapper.custom_primitive_types) +
            sorted(self.ast_wrapper.sum_types.keys()) +
            sorted(self.ast_wrapper.singular_types.keys()) +
            sorted(self.preproc.seq_lengths.keys()),
            special_elems=())

    # Rule inventory plus the head-corner tables derived from the grammar file.
    (self.all_rules, self.rules_index, self.parent_to_preterminal, self.preterminal_mask,
     self.preterminal_debug, self.preterminal_types, self.parent_to_hc, self.hc_table,
     self.hc_debug, self.parent_to_head,
     self.parent_to_rule) = self.compute_rule_masks(grammar_path)

    # json.dump(dict(self.parent_to_preterminal), open('data/spider/head-corner-glove,cv_link=true/p.json'))
    # json.dump({"parent_to_preterminal": dict(self.parent_to_preterminal),
    #            "preterminal_mask": dict(self.preterminal_mask),
    #            "parent_to_hc": {key: sorted(list(self.parent_to_hc[key])) for key in self.parent_to_hc},
    #            "hc_table": {key: dict(self.hc_table[key]) for key in self.hc_table},
    #            "parent_to_head": dict(self.parent_to_head),
    #            "node_type_vocab_e2i": dict(self.node_type_vocab.elem_to_id),
    #            "node_type_vocab_i2e": dict(self.node_type_vocab.id_to_elem),
    #            # "terminal_vocab": self.terminal_vocab,
    #            # "rules_index": self.rules_index,
    #            "parent_to_rule": dict(self.parent_to_rule),
    #            },
    #           open('data/spider/head-corner-glove,cv_link=true/head_corner_elems.json', 'w'))

    self.rule_emb_size = rule_emb_size
    self.node_emb_size = node_embed_size
    self.enc_recurrent_size = enc_recurrent_size
    self.recurrent_size = recurrent_size

    self.use_align_mat = use_align_mat
    self.use_align_loss = use_align_loss
    self.enumerate_order = enumerate_order

    if use_align_mat:
        from ratsql.models.spider import spider_dec_func
        self.compute_align_loss = lambda *args: \
            spider_dec_func.compute_align_loss(self, *args)
        self.compute_pointer_with_align = lambda *args: \
            spider_dec_func.compute_pointer_with_align_head_corner(self, *args)

    # Note the extra `self.recurrent_size` input term compared to the base decoder.
    self.state_update = variational_lstm.RecurrentDropoutLSTMCell(
        input_size=self.rule_emb_size * 2 + self.enc_recurrent_size
                   + self.recurrent_size * 2 + self.node_emb_size,
        hidden_size=self.recurrent_size,
        dropout=dropout)

    self.attn_type = desc_attn
    if desc_attn == 'bahdanau':
        self.desc_attn = attention.BahdanauAttention(
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size,
            proj_size=50)
    elif desc_attn == 'mha':
        self.desc_attn = attention.MultiHeadedAttention(
            h=8,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
    elif desc_attn == 'mha-1h':
        self.desc_attn = attention.MultiHeadedAttention(
            h=1,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
    elif desc_attn == 'sep':
        self.question_attn = attention.MultiHeadedAttention(
            h=1,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
        self.schema_attn = attention.MultiHeadedAttention(
            h=1,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
    else:
        # TODO: Figure out how to get right sizes (query, value) to module
        self.desc_attn = desc_attn
    self.sup_att = sup_att

    self.rule_logits = torch.nn.Sequential(
        torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
        torch.nn.Tanh(),
        torch.nn.Linear(self.rule_emb_size, len(self.rules_index)))
    self.rule_embedding = torch.nn.Embedding(
        num_embeddings=len(self.rules_index),
        embedding_dim=self.rule_emb_size)

    self.gen_logodds = torch.nn.Linear(self.recurrent_size, 1)
    self.terminal_logits = torch.nn.Sequential(
        torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
        torch.nn.Tanh(),
        torch.nn.Linear(self.rule_emb_size, len(self.terminal_vocab)))
    self.terminal_embedding = torch.nn.Embedding(
        num_embeddings=len(self.terminal_vocab),
        embedding_dim=self.rule_emb_size)

    if copy_pointer is None:
        self.copy_pointer = attention.BahdanauPointer(
            query_size=self.recurrent_size,
            key_size=self.enc_recurrent_size,
            proj_size=50)
    else:
        # TODO: Figure out how to get right sizes (query, key) to module
        self.copy_pointer = copy_pointer

    if multi_loss_type == 'logsumexp':
        self.multi_loss_reduction = lambda logprobs: -torch.logsumexp(logprobs, dim=1)
    elif multi_loss_type == 'mean':
        self.multi_loss_reduction = lambda logprobs: -torch.mean(logprobs, dim=1)

    self.pointers = torch.nn.ModuleDict()
    self.pointer_action_emb_proj = torch.nn.ModuleDict()
    for pointer_type in self.preproc.grammar.pointers:
        self.pointers[pointer_type] = attention.ScaledDotProductPointer(
            query_size=self.recurrent_size,
            key_size=self.enc_recurrent_size)
        self.pointer_action_emb_proj[pointer_type] = torch.nn.Linear(
            self.enc_recurrent_size, self.rule_emb_size)

    self.node_type_embedding = torch.nn.Embedding(
        num_embeddings=len(self.node_type_vocab),
        embedding_dim=self.node_emb_size)

    # TODO batching
    self.zero_rule_emb = torch.zeros(1, self.rule_emb_size, device=self._device)
    self.zero_recurrent_emb = torch.zeros(1, self.recurrent_size, device=self._device)

    if loss_type == "softmax":
        self.xent_loss = torch.nn.CrossEntropyLoss(reduction='none')
    elif loss_type == "entmax":
        self.xent_loss = entmax.entmax15_loss
    elif loss_type == "sparsemax":
        self.xent_loss = entmax.sparsemax_loss
    elif loss_type == "label_smooth":
        self.xent_loss = self.label_smooth_loss

    # Head-corner decoding state, filled in during decoding.
    self.goals = None
    self.head_corners = None
    self.operation = None
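# A minimal sketch (kept as a comment) of the state-update input layout assumed
# above, using torch.nn.LSTMCell as a stand-in for the project-specific
# variational_lstm.RecurrentDropoutLSTMCell. The labels on the concatenated parts
# are an interpretation inferred from the sizes, not taken verbatim from this file.
#
#     import torch
#     rule_emb_size, node_emb_size = 128, 64
#     enc_recurrent_size, recurrent_size = 256, 256
#     cell = torch.nn.LSTMCell(
#         input_size=rule_emb_size * 2 + enc_recurrent_size
#                    + recurrent_size * 2 + node_emb_size,
#         hidden_size=recurrent_size)
#     parts = [
#         torch.zeros(1, rule_emb_size),       # previous action embedding
#         torch.zeros(1, rule_emb_size),       # parent action embedding
#         torch.zeros(1, enc_recurrent_size),  # attention context over encoder outputs
#         torch.zeros(1, recurrent_size),      # parent decoder state
#         torch.zeros(1, recurrent_size),      # extra recurrent-size input used by this variant
#         torch.zeros(1, node_emb_size),       # node-type embedding
#     ]
#     state = (torch.zeros(1, recurrent_size), torch.zeros(1, recurrent_size))
#     h, c = cell(torch.cat(parts, dim=-1), state)   # h, c: [1, recurrent_size]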