def __init__(self, tables='data/spider/tables.json', db='data/database'):
    """Record the data locations and load schema and evaluation helpers.

    Args:
        tables: Path handed to both
            ``editsql_preprocess.read_database_schema`` and
            ``evaluation.build_foreign_key_map_from_json``.
        db: Database location; only stored on the instance here.
    """
    self.tables = tables
    self.db = db

    # The reader is given three fresh, independent dicts and returns a
    # triple that is unpacked onto the instance.
    schema_info = editsql_preprocess.read_database_schema(tables, {}, {}, {})
    (self.schema_tokens,
     self.column_names,
     self.database_schemas) = schema_info

    self.kmaps = evaluation.build_foreign_key_map_from_json(tables)
    self.evaluator = evaluation.Evaluator()
def __init__(self, args, ext, remove_invalid=True):
    """Store configuration and, optionally, build a SQL converter.

    Args:
        args: Configuration object. Optional attributes ``tables`` and
            ``db`` override the default data locations.
        ext: Mapping that must contain the key ``'sql_voc'``.
        remove_invalid: When truthy, a ``converter.Converter`` is created
            and stored as ``self.conv``; otherwise ``self.conv`` is not
            set by this constructor.
    """
    self.args = args
    self.remove_invalid = remove_invalid
    if self.remove_invalid:
        # The fallback tables path carries the ``.json`` extension so that
        # it agrees with the ``tables='data/spider/tables.json'`` default
        # used when the converter is constructed without arguments.
        self.conv = converter.Converter(
            tables=getattr(args, 'tables', 'data/spider/tables.json'),
            db=getattr(args, 'db', 'data/database'),
        )
    self.sql_vocab = ext['sql_voc']
    self.evaluator = evaluation.Evaluator()
def __init__(self, args, ext):
    """Build the encoders, decoders and helper objects of the model.

    Args:
        args: Configuration object. Reads ``dcache``, ``demb``, ``drnn``,
            ``dropout``, ``bert_dropout``, ``dec_dropout`` and
            ``num_layers``; optional ``tables`` and ``db`` override the
            default data locations.
        ext: Mapping with keys ``'sql_voc'`` and ``'sql_emb'``; an
            optional ``'reranker'`` entry is used as-is when present.
    """
    super().__init__(args, ext)

    # The fallback tables path carries the ``.json`` extension so that it
    # agrees with the ``tables='data/spider/tables.json'`` default used
    # when the converter is constructed without arguments.
    self.conv = converter.Converter(
        tables=getattr(args, 'tables', 'data/spider/tables.json'),
        db=getattr(args, 'db', 'data/database'),
    )

    # Tokenizer and both embedders are loaded from the local cache
    # directory; the two embedders are separate instances.
    self.bert_tokenizer = DistilBertTokenizer.from_pretrained(
        args.dcache + '/vocab.txt', cache_dir=args.dcache)
    self.bert_embedder = DistilBertModel.from_pretrained(
        args.dcache, cache_dir=args.dcache)
    self.value_bert_embedder = DistilBertModel.from_pretrained(
        args.dcache, cache_dir=args.dcache)

    # NOTE(review): 768 is presumably the hidden size of the loaded
    # DistilBERT checkpoint -- confirm against its config.
    self.denc = 768
    self.demb = args.demb

    self.sql_vocab = ext['sql_voc']
    # freeze=False keeps the pretrained SQL embeddings trainable.
    self.sql_emb = nn.Embedding.from_pretrained(ext['sql_emb'], freeze=False)
    self.pad_id = self.sql_vocab.word2index('PAD')

    self.dropout = nn.Dropout(args.dropout)
    self.bert_dropout = nn.Dropout(args.bert_dropout)

    self.table_sa_scorer = nn.Linear(self.denc, 1)
    self.col_sa_scorer = nn.Linear(self.denc, 1)

    # Bidirectional LSTMs: the per-direction size is demb // 2 (or drnn),
    # so the concatenated output is about demb (or 2 * drnn) wide.
    self.col_trans = nn.LSTM(
        self.denc, self.demb // 2, bidirectional=True, batch_first=True)
    self.table_trans = nn.LSTM(
        self.denc, args.drnn, bidirectional=True, batch_first=True)
    self.pointer_decoder = decoder.PointerDecoder(
        demb=self.demb,
        denc=2 * args.drnn,
        ddec=args.drnn,
        dropout=args.dec_dropout,
        num_layers=args.num_layers,
    )
    self.utt_trans = nn.LSTM(
        self.denc, self.demb // 2, bidirectional=True, batch_first=True)
    self.value_decoder = decoder.PointerDecoder(
        demb=self.demb,
        denc=self.denc,
        ddec=args.drnn,
        dropout=args.dec_dropout,
        num_layers=args.num_layers,
    )

    self.evaluator = evaluation.Evaluator()

    # Use the caller-supplied reranker when there is one; otherwise build
    # the default one with invalid-query removal switched on.
    if 'reranker' in ext:
        self.reranker = ext['reranker']
    else:
        self.reranker = rank_max.Module(args, ext, remove_invalid=True)
def __init__(self, args, ext):
    """Assemble the model's encoders, pointer decoder and lookup tables.

    Args:
        args: Configuration object. Reads ``dcache``, ``demb``, ``drnn``,
            ``dropout``, ``bert_dropout``, ``dec_dropout`` and
            ``num_layers``.
        ext: Mapping with keys ``'database_schemas'``, ``'db_content'``,
            ``'kmaps'``, ``'sql_voc'`` and ``'sql_emb'``.
    """
    super().__init__(args, ext)

    # Database metadata handed in by the caller.
    self.database_schemas = ext['database_schemas']
    self.database_content = ext['db_content']
    self.kmaps = ext['kmaps']

    # Pretrained tokenizer and embedder share the same model name and cache.
    bert_name = preprocess.BERT_MODEL
    cache_dir = args.dcache
    self.bert_tokenizer = DistilBertTokenizer.from_pretrained(
        bert_name, cache_dir=cache_dir)
    self.bert_embedder = DistilBertModel.from_pretrained(
        bert_name, cache_dir=cache_dir)

    # NOTE(review): 768 is presumably the hidden size of the pretrained
    # DistilBERT model -- confirm against its config.
    enc_dim = 768
    emb_dim = args.demb
    self.denc = enc_dim
    self.demb = emb_dim

    vocab = ext['sql_voc']
    self.sql_vocab = vocab
    # freeze=False keeps the pretrained SQL embeddings trainable.
    self.sql_emb = nn.Embedding.from_pretrained(ext['sql_emb'], freeze=False)
    self.pad_id = vocab.word2index('PAD')

    self.dropout = nn.Dropout(args.dropout)
    self.bert_dropout = nn.Dropout(args.bert_dropout)

    self.table_sa_scorer = nn.Linear(enc_dim, 1)
    self.col_sa_scorer = nn.Linear(enc_dim, 1)

    # Every recurrent encoder is a batch-first bidirectional LSTM.
    bilstm = {'bidirectional': True, 'batch_first': True}
    self.col_trans = nn.LSTM(enc_dim, emb_dim // 2, **bilstm)

    rnn_dim = args.drnn
    self.table_trans = nn.LSTM(enc_dim, rnn_dim, **bilstm)
    self.pointer_decoder = decoder.PointerDecoder(
        demb=emb_dim,
        denc=2 * rnn_dim,
        ddec=rnn_dim,
        dropout=args.dec_dropout,
        num_layers=args.num_layers,
    )
    self.utt_trans = nn.LSTM(enc_dim, emb_dim // 2, **bilstm)

    self.evaluator = evaluation.Evaluator()