예제 #1
0
파일: converter.py 프로젝트: vzhong/gazp
 def __init__(self, tables='data/spider/tables.json', db='data/database'):
     """Load the schema file and set up SQL evaluation helpers.

     Args:
         tables: path handed both to ``editsql_preprocess.read_database_schema``
             and to ``evaluation.build_foreign_key_map_from_json``.
         db: stored as ``self.db``; not used inside this method (presumably
             the directory holding the database files — confirm with callers).
     """
     self.tables = tables
     self.db = db
     # read_database_schema is given three fresh empty dicts (presumably
     # accumulators it fills) and yields three values, unpacked below.
     parsed_schema = editsql_preprocess.read_database_schema(
         tables, {}, {}, {})
     (self.schema_tokens,
      self.column_names,
      self.database_schemas) = parsed_schema
     self.kmaps = evaluation.build_foreign_key_map_from_json(tables)
     self.evaluator = evaluation.Evaluator()
예제 #2
0
 def __init__(self, args, ext, remove_invalid=True):
     """Set up the module, optionally with SQL conversion/evaluation helpers.

     Args:
         args: options namespace; ``args.tables`` and ``args.db`` are read
             when present, otherwise the fallbacks below are used.
         ext: dict of shared resources; ``ext['sql_voc']`` is read only when
             ``remove_invalid`` is true.
         remove_invalid: when true, build a ``converter.Converter``, keep the
             SQL vocabulary and create an ``evaluation.Evaluator`` — presumably
             used to drop invalid candidate queries (confirm against callers).
     """
     self.args = args
     self.remove_invalid = remove_invalid
     if self.remove_invalid:
         # The fallback must match converter.Converter's own default
         # ('data/spider/tables.json'): Converter passes this path to
         # build_foreign_key_map_from_json, so it has to name the JSON file.
         self.conv = converter.Converter(
             tables=getattr(args, 'tables', 'data/spider/tables.json'),
             db=getattr(args, 'db', 'data/database'))
         # NOTE(review): sql_vocab and evaluator exist only when
         # remove_invalid is true; callers must not rely on them otherwise.
         self.sql_vocab = ext['sql_voc']
         self.evaluator = evaluation.Evaluator()
예제 #3
0
파일: nl2sql.py 프로젝트: vzhong/gazp
    def __init__(self, args, ext):
        """Build the NL-to-SQL model's encoders, decoders and reranker.

        Args:
            args: options namespace. Reads ``tables`` and ``db`` (optional,
                with fallbacks), ``dcache``, ``demb``, ``drnn``, ``dropout``,
                ``bert_dropout``, ``dec_dropout`` and ``num_layers``.
            ext: dict of shared resources. Reads ``'sql_voc'`` (vocabulary
                exposing ``word2index``), ``'sql_emb'`` (initial weights for
                the trainable SQL token embedding) and, if present,
                ``'reranker'``.

        NOTE: submodules are created in a fixed order. Assuming the base class
        is a torch ``nn.Module``, reordering them changes which random numbers
        each layer's initialization consumes and the key order of
        ``state_dict``, so keep the order stable.
        """
        super().__init__(args, ext)
        # The fallback must match converter.Converter's own default
        # ('data/spider/tables.json'); Converter passes this path to
        # build_foreign_key_map_from_json, so it has to name the JSON file.
        self.conv = converter.Converter(
            tables=getattr(args, 'tables', 'data/spider/tables.json'),
            db=getattr(args, 'db', 'data/database'))
        # Tokenizer vocabulary and model weights are both loaded from the
        # local cache directory args.dcache.
        self.bert_tokenizer = DistilBertTokenizer.from_pretrained(
            args.dcache + '/vocab.txt', cache_dir=args.dcache)
        self.bert_embedder = DistilBertModel.from_pretrained(
            args.dcache, cache_dir=args.dcache)
        # A second, independently loaded copy of the same weights, so the
        # value pathway has its own parameters separate from bert_embedder.
        self.value_bert_embedder = DistilBertModel.from_pretrained(
            args.dcache, cache_dir=args.dcache)
        self.denc = 768  # presumably the DistilBERT hidden size — confirm
        self.demb = args.demb
        self.sql_vocab = ext['sql_voc']
        # freeze=False: the pretrained SQL embeddings are fine-tuned.
        self.sql_emb = nn.Embedding.from_pretrained(ext['sql_emb'],
                                                    freeze=False)
        self.pad_id = self.sql_vocab.word2index('PAD')

        self.dropout = nn.Dropout(args.dropout)
        self.bert_dropout = nn.Dropout(args.bert_dropout)
        # Linear scorers mapping a denc-wide feature vector to one scalar.
        self.table_sa_scorer = nn.Linear(self.denc, 1)
        self.col_sa_scorer = nn.Linear(self.denc, 1)
        # Bidirectional with demb // 2 units per direction -> demb-wide output
        # (for even demb).
        self.col_trans = nn.LSTM(self.denc,
                                 self.demb // 2,
                                 bidirectional=True,
                                 batch_first=True)
        self.table_trans = nn.LSTM(self.denc,
                                   args.drnn,
                                   bidirectional=True,
                                   batch_first=True)
        # Encoder width 2 * drnn matches the bidirectional table_trans output.
        self.pointer_decoder = decoder.PointerDecoder(
            demb=self.demb,
            denc=2 * args.drnn,
            ddec=args.drnn,
            dropout=args.dec_dropout,
            num_layers=args.num_layers)

        self.utt_trans = nn.LSTM(self.denc,
                                 self.demb // 2,
                                 bidirectional=True,
                                 batch_first=True)
        # Encoder width is denc (768), presumably raw BERT features — confirm.
        self.value_decoder = decoder.PointerDecoder(demb=self.demb,
                                                    denc=self.denc,
                                                    ddec=args.drnn,
                                                    dropout=args.dec_dropout,
                                                    num_layers=args.num_layers)

        self.evaluator = evaluation.Evaluator()
        # Reuse a reranker shared through ext when one is supplied; otherwise
        # build one with invalid-query removal enabled.
        if 'reranker' in ext:
            self.reranker = ext['reranker']
        else:
            self.reranker = rank_max.Module(args, ext, remove_invalid=True)
예제 #4
0
    def __init__(self, args, ext):
        """Build the BERT encoder, schema encoders and pointer decoder.

        Args:
            args: options namespace. Reads ``dcache``, ``demb``, ``drnn``,
                ``dropout``, ``bert_dropout``, ``dec_dropout`` and
                ``num_layers``.
            ext: dict of shared resources. Reads ``'database_schemas'``,
                ``'db_content'``, ``'kmaps'``, ``'sql_voc'`` (vocabulary
                exposing ``word2index``) and ``'sql_emb'`` (initial weights
                for the trainable SQL token embedding).

        NOTE: submodule creation order below matches the original on purpose;
        it determines initialization RNG consumption and ``state_dict`` order.
        """
        super().__init__(args, ext)
        # Schema resources precomputed elsewhere and shared through ext.
        self.database_schemas = ext['database_schemas']
        self.database_content = ext['db_content']
        self.kmaps = ext['kmaps']

        bert_name, bert_cache = preprocess.BERT_MODEL, args.dcache
        self.bert_tokenizer = DistilBertTokenizer.from_pretrained(
            bert_name, cache_dir=bert_cache)
        self.bert_embedder = DistilBertModel.from_pretrained(
            bert_name, cache_dir=bert_cache)
        self.denc = 768  # presumably the DistilBERT hidden size — confirm
        self.demb = args.demb

        vocab = ext['sql_voc']
        self.sql_vocab = vocab
        # freeze=False: the pretrained SQL embeddings are fine-tuned.
        self.sql_emb = nn.Embedding.from_pretrained(ext['sql_emb'],
                                                    freeze=False)
        self.pad_id = vocab.word2index('PAD')

        self.dropout = nn.Dropout(args.dropout)
        self.bert_dropout = nn.Dropout(args.bert_dropout)
        self.table_sa_scorer = nn.Linear(self.denc, 1)
        self.col_sa_scorer = nn.Linear(self.denc, 1)

        def bi_lstm(hidden):
            # Bidirectional, batch-first LSTM reading denc-wide features.
            return nn.LSTM(self.denc, hidden,
                           bidirectional=True, batch_first=True)

        self.col_trans = bi_lstm(self.demb // 2)
        self.table_trans = bi_lstm(args.drnn)
        # Encoder width 2 * drnn matches the bidirectional table_trans output.
        self.pointer_decoder = decoder.PointerDecoder(
            demb=self.demb, denc=2 * args.drnn, ddec=args.drnn,
            dropout=args.dec_dropout, num_layers=args.num_layers)
        self.utt_trans = bi_lstm(self.demb // 2)

        self.evaluator = evaluation.Evaluator()