Example #1
def load_model(models_path, glove_path, toy=False):
    ### CONFIGURABLE
    GPU = True  # run on GPU
    B_word = 42  # GloVe corpus size (glove.42B)
    N_word = 300  # word embedding dimension
    N_h = 300  # hidden layer size (unused here)
    N_depth = 2  # number of LSTM layers (unused here)

    print("Loading GloVE word embeddings...")
    word_emb = load_word_emb('{}/glove.{}B.{}d.txt'.format(
        glove_path, B_word, N_word),
                             load_used=False,
                             use_small=toy)

    model = SuperModel(word_emb,
                       N_word=N_word,
                       gpu=GPU,
                       trainable_emb=False,
                       table_type='std',
                       use_hs=True)

    print("Loading trained models...")
    model.multi_sql.load_state_dict(
        torch.load("{}/multi_sql_models.dump".format(models_path)))
    model.key_word.load_state_dict(
        torch.load("{}/keyword_models.dump".format(models_path)))
    model.col.load_state_dict(
        torch.load("{}/col_models.dump".format(models_path)))
    model.op.load_state_dict(
        torch.load("{}/op_models.dump".format(models_path)))
    model.agg.load_state_dict(
        torch.load("{}/agg_models.dump".format(models_path)))
    model.root_teminal.load_state_dict(
        torch.load("{}/root_tem_models.dump".format(models_path)))
    model.des_asc.load_state_dict(
        torch.load("{}/des_asc_models.dump".format(models_path)))
    model.having.load_state_dict(
        torch.load("{}/having_models.dump".format(models_path)))
    return model
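
A minimal usage sketch for load_model; the directory arguments are hypothetical placeholders for wherever the model dumps and the GloVe file live:

# Hypothetical usage; "saved_models" and "glove" are placeholder paths.
model = load_model(models_path="saved_models",
                   glove_path="glove",
                   toy=True)  # toy=True loads a small embedding subset for quick tests
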
Example #2
File: test.py Project: ygan/syntaxSQL
    # TRAIN_ENTRY=(False, True, False)  # (AGG, SEL, COND)
    # TRAIN_AGG, TRAIN_SEL, TRAIN_COND = TRAIN_ENTRY
    learning_rate = 1e-4

    #TODO
    data = json.load(open(args.test_data_path))
    # dev_data = load_train_dev_dataset(args.train_component, "dev", args.history)

    word_emb = load_word_emb('glove/glove.%dB.%dd.txt' % (B_word, N_word),
                             load_used=args.train_emb, use_small=USE_SMALL)
    #word_emb = load_concat_wemb('glove/glove.42B.300d.txt', "/data/projects/paraphrase/generation/para-nmt-50m/data/paragram_sl999_czeng.txt")

    model = SuperModel(word_emb, N_word=N_word, gpu=GPU, trainable_emb=args.train_emb,
                       table_type=args.table_type, use_hs=use_hs)

    # agg_m, sel_m, cond_m = best_model_name(args)
    # torch.save(model.state_dict(), "saved_models/{}_models.dump".format(args.train_component))

    print("Loading from modules...")
    model.multi_sql.load_state_dict(torch.load("{}/multi_sql_models.dump".format(args.models)))
    model.key_word.load_state_dict(torch.load("{}/keyword_models.dump".format(args.models)))
    model.col.load_state_dict(torch.load("{}/col_models.dump".format(args.models)))
    model.op.load_state_dict(torch.load("{}/op_models.dump".format(args.models)))
    model.agg.load_state_dict(torch.load("{}/agg_models.dump".format(args.models)))
    model.root_teminal.load_state_dict(torch.load("{}/root_tem_models.dump".format(args.models)))
    model.des_asc.load_state_dict(torch.load("{}/des_asc_models.dump".format(args.models)))
    model.having.load_state_dict(torch.load("{}/having_models.dump".format(args.models)))

    test_acc(model, BATCH_SIZE, data, args.output_path)
Example #3
    # TRAIN_AGG, TRAIN_SEL, TRAIN_COND = TRAIN_ENTRY
    learning_rate = 1e-4

    #TODO
    data = json.load(open(args.test_data_path))
    # dev_data = load_train_dev_dataset(args.train_component, "dev", args.history)

    word_emb = load_word_emb('glove/glove.%dB.%dd.txt' % (B_word, N_word),
                             load_used=args.train_emb, use_small=USE_SMALL)
    # dev_data = load_train_dev_dataset(args.train_component, "dev", args.history)
    #word_emb = load_concat_wemb('glove/glove.42B.300d.txt', "/data/projects/paraphrase/generation/para-nmt-50m/data/paragram_sl999_czeng.txt")

    model = SuperModel(args.data_root,
                       word_emb,
                       N_word=N_word,
                       gpu=GPU,
                       trainable_emb=args.train_emb,
                       table_type=args.table_type,
                       use_hs=use_hs,
                       feats_format=args.feats_format)

    # agg_m, sel_m, cond_m = best_model_name(args)
    # torch.save(model.state_dict(), "saved_models/{}_models.dump".format(args.train_component))

    print "Loading from modules..."
    model.multi_sql.load_state_dict(
        torch.load("{}/multi_sql_models.dump".format(args.models)))
    model.key_word.load_state_dict(
        torch.load("{}/keyword_models.dump".format(args.models)))
    model.col.load_state_dict(
        torch.load("{}/col_models.dump".format(args.models)))
    model.op.load_state_dict(
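
The eight per-module load_state_dict calls above repeat one pattern; a loop over (attribute, dump-file) pairs is an equivalent, shorter form. A minimal sketch, assuming model, args, and torch from the surrounding scope:

# Sketch: the same loads as above, expressed as a loop.
MODULE_DUMPS = [
    ("multi_sql", "multi_sql_models.dump"),
    ("key_word", "keyword_models.dump"),
    ("col", "col_models.dump"),
    ("op", "op_models.dump"),
    ("agg", "agg_models.dump"),
    ("root_teminal", "root_tem_models.dump"),  # attribute spelled this way in SuperModel
    ("des_asc", "des_asc_models.dump"),
    ("having", "having_models.dump"),
]
for attr, dump in MODULE_DUMPS:
    getattr(model, attr).load_state_dict(
        torch.load("{}/{}".format(args.models, dump)))
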
Example #4
def infer_script(nlq, db_name, toy, word_emb):
    """
    Arguments:
        nlq: english question (tokenization is done here)
        db_name: name of the database the query targets
        toy: uses a small example of word embeddings to debug faster

    """

    SAVED_MODELS_FOLDER = "saved_models"
    OUTPUT_PATH = "output_inference.txt"
    HISTORY_TYPE = "full"
    GPU_ENABLE = False
    TRAIN_EMB = False
    TABLE_TYPE = "std"
    LOAD_USED_W2I = False

    use_hs = True
    if HISTORY_TYPE == "no":
        HISTORY_TYPE = "full"
        use_hs = False

    N_word = 300
    B_word = 42
    N_h = 300
    N_depth = 2  # not used in test.py

    USE_SMALL = toy  # small embedding subset when debugging

    GPU = GPU_ENABLE
    BATCH_SIZE = 1  #64

    # QUESTION TOKENIZATION
    tok_q = tokenize(nlq)
    # print("tokenized question: {}".format(tokenize("What are the maximum and minimum budget of the departments?")))

    # Natural language question and database reading
    nlq = {'db_id': db_name, 'question_toks': tok_q}
    print("nlq: {}".format(nlq))

    db_id = nlq["db_id"]

    table_dict = get_table_dict("./data/spider/tables.json")
    # table_dict = table_dict[db_id] # subset table dict to the specified database
    # table_dict[db_id] = {'column_names': [[-1, '*'], [0, 'department id'], [0, 'name'], [0, 'creation'], [0, 'ranking'], [0, 'budget in billions'], [0, 'num employees'], [1, 'head id'], [1, 'name'], [1, 'born state'], [1, 'age'], [2, 'department id'], [2, 'head id'], [2, 'temporary acting']],
    #         'column_names_original': [[-1, '*'], [0, 'Department_ID'], [0, 'Name'], [0, 'Creation'], [0, 'Ranking'], [0, 'Budget_in_Billions'], [0, 'Num_Employees'], [1, 'head_ID'], [1, 'name'], [1, 'born_state'], [1, 'age'], [2, 'department_ID'], [2, 'head_ID'], [2, 'temporary_acting']],
    #         'column_types': ['text', 'number', 'text', 'text', 'number', 'number', 'number', 'number', 'text', 'text', 'number', 'number', 'number', 'text'],
    #         'db_id': 'department_management',
    #         'foreign_keys': [[12, 7], [11, 1]],
    #         'primary_keys': [1, 7, 11],
    #         'table_names': ['department', 'head', 'management'],
    #         'table_names_original': ['department', 'head', 'management']}

    # LOAD WORD EMBEDDINGS
    # if not os.path.isfile('./glove/usedwordemb.pickle'):
    # print("Creating word embedding dictionary...")
    # word_emb = load_word_emb('glove/glove.%dB.%dd.txt'%(B_word,N_word), \
    #         load_used=LOAD_USED_W2I,
    #         use_small=USE_SMALL)
    # print("word_emb: {}".format(word_emb))
    # print("tyep word_emb: {}".format(type(word_emb)))
    #     with open('./glove/usedwordemb.pickle', 'wb') as handle:
    #         print("Saving word embedding as pickle...")
    #         pickle.dump(word_emb, handle, protocol=pickle.HIGHEST_PROTOCOL)
    # else:
    #     with open('./glove/usedwordemb.pickle', 'rb') as handle:
    #         print("Loading word embedding pickle...")
    #         word_emb = pickle.load(handle)

    # dev_data = load_train_dev_dataset(args.train_component, "dev", args.history)
    #word_emb = load_concat_wemb('glove/glove.42B.300d.txt', "/data/projects/paraphrase/generation/para-nmt-50m/data/paragram_sl999_czeng.txt")

    # print("table_type = " + str(args.table_type))
    model = SuperModel(word_emb,
                       N_word=N_word,
                       gpu=GPU,
                       trainable_emb=TRAIN_EMB,
                       table_type=TABLE_TYPE,
                       use_hs=use_hs)

    # agg_m, sel_m, cond_m = best_model_name(args)
    # torch.save(model.state_dict(), "saved_models/{}_models.dump".format(args.train_component))

    print("Loading modules...")
    if GPU_ENABLE:
        map_to = "cuda"  # torch.load expects "cuda"/"cpu" (or a torch.device), not "gpu"
    else:
        map_to = "cpu"

    # LOAD THE TRAINED MODELS
    model.multi_sql.load_state_dict(
        torch.load("{}/multi_sql_models.dump".format(SAVED_MODELS_FOLDER),
                   map_location=map_to))
    model.key_word.load_state_dict(
        torch.load("{}/keyword_models.dump".format(SAVED_MODELS_FOLDER),
                   map_location=map_to))
    model.col.load_state_dict(
        torch.load("{}/col_models.dump".format(SAVED_MODELS_FOLDER),
                   map_location=map_to))
    model.op.load_state_dict(
        torch.load("{}/op_models.dump".format(SAVED_MODELS_FOLDER),
                   map_location=map_to))
    model.agg.load_state_dict(
        torch.load("{}/agg_models.dump".format(SAVED_MODELS_FOLDER),
                   map_location=map_to))
    model.root_teminal.load_state_dict(
        torch.load("{}/root_tem_models.dump".format(SAVED_MODELS_FOLDER),
                   map_location=map_to))
    model.des_asc.load_state_dict(
        torch.load("{}/des_asc_models.dump".format(SAVED_MODELS_FOLDER),
                   map_location=map_to))
    model.having.load_state_dict(
        torch.load("{}/having_models.dump".format(SAVED_MODELS_FOLDER),
                   map_location=map_to))
    model.andor.load_state_dict(
        torch.load("{}/andor_models.dump".format(SAVED_MODELS_FOLDER),
                   map_location=map_to))

    # print("Model used:")
    # print(model)

    # test_acc(model, BATCH_SIZE, data, args.output_path)
    # test_exec_acc()

    # This should return the generated SQL query
    gen_sql = infer_sql(model=model,
                        batch_size=BATCH_SIZE,
                        nlq=nlq,
                        table_dict=table_dict,
                        output_path=OUTPUT_PATH)
    # print("Generated SQL: {}".format(gen_sql))

    return gen_sql
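
A minimal driver sketch for infer_script; the question and database id come from the commented examples above, and the embedding load mirrors the commented-out call inside the function:

# Hypothetical driver; question and db_id are taken from the comments above.
word_emb = load_word_emb('glove/glove.42B.300d.txt',
                         load_used=False, use_small=True)  # small subset for testing
gen_sql = infer_script(
    nlq="What are the maximum and minimum budget of the departments?",
    db_name="department_management",
    toy=True,
    word_emb=word_emb)
print("Generated SQL: {}".format(gen_sql))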