Example #1
    def __init__(
            self,
            device,
            preproc,
            rule_emb_size=128,
            node_embed_size=64,
            # TODO: This should be automatically inferred from encoder
            enc_recurrent_size=256,
            recurrent_size=256,
            dropout=0.,
            desc_attn='bahdanau',
            copy_pointer=None,
            multi_loss_type='logsumexp',
            sup_att=None,
            use_align_mat=False,
            use_align_loss=False,
            enumerate_order=False,
            loss_type="softmax"):
        super().__init__()
        self._device = device
        self.preproc = preproc
        self.ast_wrapper = preproc.ast_wrapper
        self.terminal_vocab = preproc.vocab

        self.rule_emb_size = rule_emb_size
        self.node_emb_size = node_embed_size
        self.enc_recurrent_size = enc_recurrent_size
        self.recurrent_size = recurrent_size

        # Map each grammar production rule to an integer index, used by the rule
        # embedding table and the rule classifier below.
        self.rules_index = {
            v: idx
            for idx, v in enumerate(self.preproc.all_rules)
        }
        self.use_align_mat = use_align_mat
        self.use_align_loss = use_align_loss
        self.enumerate_order = enumerate_order

        if use_align_mat:
            from ratsql.models.spider import spider_dec_func
            self.compute_align_loss = lambda *args: \
                spider_dec_func.compute_align_loss(self, *args)
            self.compute_pointer_with_align = lambda *args: \
                spider_dec_func.compute_pointer_with_align(self, *args)

        # Vocabulary of AST node types; with use_seq_elem_rules, sum-type constructors and
        # field-presence / sequence-length markers stand in for the raw sum and singular types.
        if self.preproc.use_seq_elem_rules:
            self.node_type_vocab = vocab.Vocab(
                sorted(self.preproc.primitive_types) +
                sorted(self.ast_wrapper.custom_primitive_types) +
                sorted(self.preproc.sum_type_constructors.keys()) +
                sorted(self.preproc.field_presence_infos.keys()) +
                sorted(self.preproc.seq_lengths.keys()),
                special_elems=())
        else:
            self.node_type_vocab = vocab.Vocab(
                sorted(self.preproc.primitive_types) +
                sorted(self.ast_wrapper.custom_primitive_types) +
                sorted(self.ast_wrapper.sum_types.keys()) +
                sorted(self.ast_wrapper.singular_types.keys()) +
                sorted(self.preproc.seq_lengths.keys()),
                special_elems=())

        # LSTM cell (with recurrent dropout) that updates the decoder state at each step;
        # its input size is the sum of the action/rule embedding, encoder context, parent
        # state, and node-type embedding dimensions.
        self.state_update = variational_lstm.RecurrentDropoutLSTMCell(
            input_size=self.rule_emb_size * 2 + self.enc_recurrent_size +
            self.recurrent_size + self.node_emb_size,
            hidden_size=self.recurrent_size,
            dropout=dropout)

        # Attention over the encoder outputs ("description"); 'sep' keeps separate heads
        # for the question and the schema.
        self.attn_type = desc_attn
        if desc_attn == 'bahdanau':
            self.desc_attn = attention.BahdanauAttention(
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size,
                proj_size=50)
        elif desc_attn == 'mha':
            self.desc_attn = attention.MultiHeadedAttention(
                h=8,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
        elif desc_attn == 'mha-1h':
            self.desc_attn = attention.MultiHeadedAttention(
                h=1,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
        elif desc_attn == 'sep':
            self.question_attn = attention.MultiHeadedAttention(
                h=1,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
            self.schema_attn = attention.MultiHeadedAttention(
                h=1,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
        else:
            # TODO: Figure out how to get right sizes (query, value) to module
            self.desc_attn = desc_attn
        self.sup_att = sup_att

        # Rule prediction: a two-layer MLP scores every grammar rule from the decoder
        # state, and an embedding table provides dense rule embeddings.
        self.rule_logits = torch.nn.Sequential(
            torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.rule_emb_size, len(self.rules_index)))
        self.rule_embedding = torch.nn.Embedding(
            num_embeddings=len(self.rules_index),
            embedding_dim=self.rule_emb_size)

        # Terminal (token) prediction: gen_logodds gates between generating a vocabulary
        # token and copying from the input; the MLP below scores vocabulary tokens.
        self.gen_logodds = torch.nn.Linear(self.recurrent_size, 1)
        self.terminal_logits = torch.nn.Sequential(
            torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.rule_emb_size, len(self.terminal_vocab)))
        self.terminal_embedding = torch.nn.Embedding(
            num_embeddings=len(self.terminal_vocab),
            embedding_dim=self.rule_emb_size)
        # Pointer network used to copy tokens directly from the encoded input.
        if copy_pointer is None:
            self.copy_pointer = attention.BahdanauPointer(
                query_size=self.recurrent_size,
                key_size=self.enc_recurrent_size,
                proj_size=50)
        else:
            # TODO: Figure out how to get right sizes (query, key) to module
            self.copy_pointer = copy_pointer
        # When several targets are acceptable, reduce their log-probabilities either by
        # marginalizing over them (logsumexp) or by averaging them (mean).
        if multi_loss_type == 'logsumexp':
            self.multi_loss_reduction = lambda logprobs: -torch.logsumexp(
                logprobs, dim=1)
        elif multi_loss_type == 'mean':
            self.multi_loss_reduction = lambda logprobs: -torch.mean(logprobs,
                                                                     dim=1)

        # One pointer network (plus a projection of the pointed-to encoding back into the
        # action embedding space) per pointer type declared by the grammar, e.g. columns and tables.
        self.pointers = torch.nn.ModuleDict()
        self.pointer_action_emb_proj = torch.nn.ModuleDict()
        for pointer_type in self.preproc.grammar.pointers:
            self.pointers[pointer_type] = attention.ScaledDotProductPointer(
                query_size=self.recurrent_size,
                key_size=self.enc_recurrent_size)
            self.pointer_action_emb_proj[pointer_type] = torch.nn.Linear(
                self.enc_recurrent_size, self.rule_emb_size)

        self.node_type_embedding = torch.nn.Embedding(
            num_embeddings=len(self.node_type_vocab),
            embedding_dim=self.node_emb_size)

        # TODO batching
        self.zero_rule_emb = torch.zeros(1,
                                         self.rule_emb_size,
                                         device=self._device)
        self.zero_recurrent_emb = torch.zeros(1,
                                              self.recurrent_size,
                                              device=self._device)
        if loss_type == "softmax":
            self.xent_loss = torch.nn.CrossEntropyLoss(reduction='none')
        elif loss_type == "entmax":
            self.xent_loss = entmax.entmax15_loss
        elif loss_type == "sparsemax":
            self.xent_loss = entmax.sparsemax_loss
        elif loss_type == "label_smooth":
            self.xent_loss = self.label_smooth_loss
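
The two multi_loss_type reductions above differ in how they combine the log-probabilities of several acceptable targets. A minimal, self-contained sketch (plain PyTorch, toy numbers chosen purely for illustration, not taken from the repo):

    import torch

    # Log-probabilities assigned to three acceptable targets for one example.
    logprobs = torch.log(torch.tensor([[0.5, 0.25, 0.25]]))

    # 'logsumexp' marginalizes over the targets: here the acceptable targets
    # carry all the probability mass, so the loss is ~0.
    loss_marginal = -torch.logsumexp(logprobs, dim=1)   # ~0.0

    # 'mean' averages the log-probabilities instead, penalizing low-probability
    # targets even when the total mass on acceptable targets is high.
    loss_mean = -torch.mean(logprobs, dim=1)            # ~1.155
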
Example #2
    def __init__(
            self,
            device,
            preproc,
            grammar_path,
            rule_emb_size=128,
            node_embed_size=64,
            # TODO: This should be automatically inferred from encoder
            enc_recurrent_size=256,
            recurrent_size=256,
            dropout=0.,
            desc_attn='bahdanau',
            copy_pointer=None,
            multi_loss_type='logsumexp',
            sup_att=None,
            use_align_mat=False,
            use_align_loss=False,
            enumerate_order=False,
            loss_type="softmax"):
        super().__init__()
        self._device = device
        self.preproc = preproc
        self.ast_wrapper = preproc.ast_wrapper
        self.terminal_vocab = preproc.vocab
        self.preproc.primitive_types.append("singleton")

        if self.preproc.use_seq_elem_rules:
            self.node_type_vocab = vocab.Vocab(
                sorted(self.preproc.primitive_types) +
                sorted(self.ast_wrapper.custom_primitive_types) +
                sorted(self.preproc.sum_type_constructors.keys()) +
                sorted(self.preproc.field_presence_infos.keys()) +
                sorted(self.preproc.seq_lengths.keys()),
                special_elems=())
        else:
            self.node_type_vocab = vocab.Vocab(
                sorted(self.preproc.primitive_types) +
                sorted(self.ast_wrapper.custom_primitive_types) +
                sorted(self.ast_wrapper.sum_types.keys()) +
                sorted(self.ast_wrapper.singular_types.keys()) +
                sorted(self.preproc.seq_lengths.keys()),
                special_elems=())

        # Build the rule index and the head-corner tables (preterminal and head-corner
        # masks, parent-to-head and parent-to-rule maps) from the grammar definition.
        (self.all_rules, self.rules_index,
         self.parent_to_preterminal, self.preterminal_mask, self.preterminal_debug,
         self.preterminal_types, self.parent_to_hc, self.hc_table, self.hc_debug,
         self.parent_to_head, self.parent_to_rule) = self.compute_rule_masks(grammar_path)

        # Debugging aid (disabled): the head-corner tables and vocab mappings can be dumped
        # to data/spider/head-corner-glove,cv_link=true/head_corner_elems.json for inspection.

        self.rule_emb_size = rule_emb_size
        self.node_emb_size = node_embed_size
        self.enc_recurrent_size = enc_recurrent_size
        self.recurrent_size = recurrent_size

        self.use_align_mat = use_align_mat
        self.use_align_loss = use_align_loss
        self.enumerate_order = enumerate_order

        if use_align_mat:
            from ratsql.models.spider import spider_dec_func
            self.compute_align_loss = lambda *args: \
                spider_dec_func.compute_align_loss(self, *args)
            self.compute_pointer_with_align = lambda *args: \
                spider_dec_func.compute_pointer_with_align_head_corner(self, *args)

        # As in Example #1, but the input reserves two decoder-state slots (recurrent_size * 2).
        self.state_update = variational_lstm.RecurrentDropoutLSTMCell(
            input_size=self.rule_emb_size * 2 + self.enc_recurrent_size +
            self.recurrent_size * 2 + self.node_emb_size,
            hidden_size=self.recurrent_size,
            dropout=dropout)

        self.attn_type = desc_attn
        if desc_attn == 'bahdanau':
            self.desc_attn = attention.BahdanauAttention(
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size,
                proj_size=50)
        elif desc_attn == 'mha':
            self.desc_attn = attention.MultiHeadedAttention(
                h=8,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
        elif desc_attn == 'mha-1h':
            self.desc_attn = attention.MultiHeadedAttention(
                h=1,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
        elif desc_attn == 'sep':
            self.question_attn = attention.MultiHeadedAttention(
                h=1,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
            self.schema_attn = attention.MultiHeadedAttention(
                h=1,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
        else:
            # TODO: Figure out how to get right sizes (query, value) to module
            self.desc_attn = desc_attn
        self.sup_att = sup_att
        self.rule_logits = torch.nn.Sequential(
            torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.rule_emb_size, len(self.rules_index)))
        self.rule_embedding = torch.nn.Embedding(
            num_embeddings=len(self.rules_index),
            embedding_dim=self.rule_emb_size)

        self.gen_logodds = torch.nn.Linear(self.recurrent_size, 1)
        self.terminal_logits = torch.nn.Sequential(
            torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.rule_emb_size, len(self.terminal_vocab)))
        self.terminal_embedding = torch.nn.Embedding(
            num_embeddings=len(self.terminal_vocab),
            embedding_dim=self.rule_emb_size)
        if copy_pointer is None:
            self.copy_pointer = attention.BahdanauPointer(
                query_size=self.recurrent_size,
                key_size=self.enc_recurrent_size,
                proj_size=50)
        else:
            # TODO: Figure out how to get right sizes (query, key) to module
            self.copy_pointer = copy_pointer
        if multi_loss_type == 'logsumexp':
            self.multi_loss_reduction = lambda logprobs: -torch.logsumexp(
                logprobs, dim=1)
        elif multi_loss_type == 'mean':
            self.multi_loss_reduction = lambda logprobs: -torch.mean(logprobs,
                                                                     dim=1)

        self.pointers = torch.nn.ModuleDict()
        self.pointer_action_emb_proj = torch.nn.ModuleDict()
        for pointer_type in self.preproc.grammar.pointers:
            self.pointers[pointer_type] = attention.ScaledDotProductPointer(
                query_size=self.recurrent_size,
                key_size=self.enc_recurrent_size)
            self.pointer_action_emb_proj[pointer_type] = torch.nn.Linear(
                self.enc_recurrent_size, self.rule_emb_size)

        self.node_type_embedding = torch.nn.Embedding(
            num_embeddings=len(self.node_type_vocab),
            embedding_dim=self.node_emb_size)

        # TODO batching
        self.zero_rule_emb = torch.zeros(1,
                                         self.rule_emb_size,
                                         device=self._device)
        self.zero_recurrent_emb = torch.zeros(1,
                                              self.recurrent_size,
                                              device=self._device)
        if loss_type == "softmax":
            self.xent_loss = torch.nn.CrossEntropyLoss(reduction='none')
        elif loss_type == "entmax":
            self.xent_loss = entmax.entmax15_loss
        elif loss_type == "sparsemax":
            self.xent_loss = entmax.sparsemax_loss
        elif loss_type == "label_smooth":
            self.xent_loss = self.label_smooth_loss

        # Mutable state used by the head-corner decoding procedure.
        self.goals = None
        self.head_corners = None
        self.operation = None
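
Example #2 doubles the decoder-state slot in the state-update input (recurrent_size * 2). Below is a hedged sketch of how that input could be assembled for one decoding step, using torch.nn.LSTMCell as a stand-in for the repo's RecurrentDropoutLSTMCell; the names of the concatenated pieces are a reading of the size arithmetic above, not taken from the source.

    import torch

    rule_emb_size, node_emb_size = 128, 64
    enc_recurrent_size, recurrent_size = 256, 256

    # Matches Example #2's input_size: 128*2 + 256 + 256*2 + 64 = 1088.
    input_size = rule_emb_size * 2 + enc_recurrent_size + recurrent_size * 2 + node_emb_size
    cell = torch.nn.LSTMCell(input_size, recurrent_size)

    # Placeholder tensors for one step (batch size 1); in the real decoder these would be
    # learned embeddings, attention outputs, and previously computed decoder states.
    pieces = [
        torch.zeros(1, rule_emb_size),        # previous action embedding (assumed)
        torch.zeros(1, rule_emb_size),        # parent action embedding (assumed)
        torch.zeros(1, enc_recurrent_size),   # attended encoder context
        torch.zeros(1, recurrent_size * 2),   # two decoder-state slots (assumed)
        torch.zeros(1, node_emb_size),        # node-type embedding
    ]
    h, c = cell(torch.cat(pieces, dim=-1))    # h, c: (1, recurrent_size)
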