Code example #1
0
def load():
    """Initialise module-level inference state for the web service.

    Side effects: binds the globals ``test_tables``, ``label_encoder``,
    ``query_tokenizer``, ``tokenizer``, ``models`` and ``graph``.
    Loads both stage models and their pretrained weights from disk.
    """
    global test_tables, label_encoder, query_tokenizer, tokenizer, models, graph

    # Table schemas for the validation set, plus BERT checkpoint paths.
    test_tables = read_tables('../data/val.tables.json')
    checkpoint_paths = get_checkpoint_paths('../model')

    label_encoder = SqlLabelEncoder()

    # Stage-1 model and its tokenizer, restored from the best checkpoint.
    stage1_model, query_tokenizer = construct_model(checkpoint_paths)
    stage1_model.load_weights('../task1_best_model.h5')

    # Stage-2 model and its tokenizer, restored from its best weights.
    stage2_model, tokenizer = construct_model2(checkpoint_paths)
    stage2_model.load_weights('../model_best_weights.h5')

    models = {'stage1': stage1_model, 'stage2': stage2_model}

    # NOTE(review): presumably captured so request-handler threads can run
    # predictions against the same TF1 graph — confirm against the server code.
    graph = tf.get_default_graph()
Code example #2
0
File: model2.py  Project: GongCQ/tianchi_nl2sql
# Paths to the test-set table schemas and questions.
test_table_file = '../data/test/test.tables.json'
test_data_file = '../data/test/test.json'

# Download pretrained BERT model from https://github.com/ymcui/Chinese-BERT-wwm
bert_model_path = '../model/chinese_wwm_L-12_H-768_A-12'

# Resolve the BERT config/checkpoint/vocab file paths from the model directory.
paths = get_checkpoint_paths(bert_model_path)

# Stage-1 predictions consumed by this (stage-2) script.
task1_file = '../submit/task1_output.json'

# ## Read Data

# In[ ]:

# NOTE(review): train_table_file / train_data_file / val_table_file /
# val_data_file and the read_tables / read_data helpers are defined earlier
# in the original notebook — not visible in this fragment.
train_tables = read_tables(train_table_file)
train_data = read_data(train_data_file, train_tables)

val_tables = read_tables(val_table_file)
val_data = read_data(val_data_file, val_tables)

test_tables = read_tables(test_table_file)
test_data = read_data(test_data_file, test_tables)

# ## Build Dataset

# In[ ]:

def is_float(value):
    try:
Code example #3
0
import keras.backend as K
from keras.layers import Input, Dense, Lambda, Multiply, Masking, Concatenate
from keras.models import Model
from keras.preprocessing.sequence import pad_sequences
from keras.callbacks import Callback, ModelCheckpoint
from keras.utils.data_utils import Sequence
from keras.utils import multi_gpu_model

from nl2sql.utils import read_data, read_line, read_tables, SQL, MultiSentenceTokenizer, Query, Question, Table
from nl2sql.utils.optimizer import RAdam
from dbengine import DBEngine
from build_model import construct_model, construct_model2, outputs_to_sqls, SqlLabelEncoder, DataSequence, QuestionCondPairsDataset, QuestionCondPairsDataseq, merge_result, CandidateCondsExtractor, FullSampler

# Validation table schemas and the local BERT model directory.
test_table_file = './data/val.tables.json'
bert_model_path = './model'
test_tables = read_tables(test_table_file)
paths = get_checkpoint_paths(bert_model_path)
# Stage-1 model + tokenizer, restored from the best task-1 checkpoint.
model, query_tokenizer = construct_model(paths)
model_path = 'task1_best_model.h5'
model.load_weights(model_path)
# Stage-2 model + tokenizer, restored from its best weights.
model2, tokenizer = construct_model2(paths)
model2.load_weights('model_best_weights.h5')

# Maps SQL labels (agg / cond ops / connectors) to and from model outputs.
label_encoder = SqlLabelEncoder()

# One sample record in the competition's JSON-lines format, used as a
# smoke-test input for the pipeline below.
test_json_line = '{"question": "长沙2011年平均每天成交量是3.17,那么近一周的成交量是多少", "table_id": "69cc8c0c334311e98692542696d6e445", "sql": {"agg": [0], "cond_conn_op": 1, "sel": [5], "conds": [[1, 2, "3.17"], [0, 2, "长沙"]]}}'

# Parse the single line against the loaded tables into a Query dataset.
test_data = read_line(test_json_line, test_tables)
test_dataseq = DataSequence(data=test_data,
                            tokenizer=query_tokenizer,