Exemplo n.º 1
0
        # Training-configuration flags for this fragment (enclosing def is not visible).
        USE_SMALL = False  # when True, load a reduced dataset for quick experiments
        GPU = True  # run the model on GPU
        BATCH_SIZE = 64  # NOTE(review): assigned here but not used in the visible fragment
    # Which prediction heads to evaluate: (aggregate, select, condition).
    TEST_ENTRY = (True, True, True)  # (AGG, SEL, COND)

    # load_dataset / load_dataset_dummy / load_word_emb, Seq2SQL, SQLNet,
    # best_model_name, `args`, B_word and N_word are defined elsewhere in
    # this file — this is a fragment of a larger setup routine.
    sql_data, table_data, val_sql_data, val_table_data, test_sql_data, test_table_data, TRAIN_DB, DEV_DB, TEST_DB = load_dataset(
        args.dataset, use_small=USE_SMALL)
    # Seed with dummy examples, then fold in the real training split.
    examples, tables = load_dataset_dummy(0)
    examples.extend(sql_data)
    tables.update(table_data)

    word_emb = load_word_emb('glove/glove.%dB.%dd.txt' % (B_word, N_word),
                             load_used=True, use_small=USE_SMALL)  # load_used can speed up loading

    # Baseline = Seq2SQL; otherwise SQLNet (optionally with column attention).
    if args.baseline:
        model = Seq2SQL(word_emb, N_word=N_word, gpu=GPU, trainable_emb=True)
    else:
        model = SQLNet(word_emb,
                       N_word=N_word,
                       use_ca=args.ca,
                       gpu=GPU,
                       trainable_emb=True)

    if args.train_emb:
        # Restore the per-head checkpoints saved by the embedding-training run.
        agg_m, sel_m, cond_m, agg_e, sel_e, cond_e = best_model_name(
            args, savedstr='_mconly')
        print('==> best model names:', agg_m, sel_m, cond_m)
        # FIX: the original used Python 2 `print "..."` statements here,
        # a SyntaxError under Python 3 and inconsistent with the print()
        # call just above; converted to print() calls (messages unchanged).
        print("Loading from %s" % agg_m)
        model.agg_pred.load_state_dict(torch.load(agg_m))
        print("Loading from %s" % sel_m)
        model.sel_pred.load_state_dict(torch.load(sel_m))
Exemplo n.º 2
0
        BATCH_SIZE = 64  # NOTE(review): assigned but not used in the visible fragment
    # Which prediction heads to train: (aggregate, select, condition).
    TRAIN_ENTRY = (True, True, True)  # (AGG, SEL, COND)
    TRAIN_AGG, TRAIN_SEL, TRAIN_COND = TRAIN_ENTRY
    # Smaller LR for the reinforcement-learning fine-tuning phase.
    learning_rate = 1e-4 if args.rl else 1e-3

    # USE_SMALL and GPU are not defined in this fragment — presumably set
    # earlier in the enclosing function; verify against the full file.
    sql_data, table_data, val_sql_data, val_table_data, \
            test_sql_data, test_table_data, \
            TRAIN_DB, DEV_DB, TEST_DB = load_dataset(
                    args.dataset, use_small=USE_SMALL)

    # Only restrict to the previously-used vocabulary when embeddings are
    # being trained; otherwise the full GloVe table is loaded.
    word_emb = load_word_emb('glove/glove.%dB.%dd.txt'%(B_word,N_word), \
            load_used=args.train_emb, use_small=USE_SMALL)

    # NOTE(review): these asserts are input validation and are stripped
    # under `python -O`; a raised ValueError would be more robust.
    if args.baseline:
        model = Seq2SQL(word_emb,
                        N_word=N_word,
                        gpu=GPU,
                        trainable_emb=args.train_emb)
        assert not args.train_emb, "Seq2SQL can\'t train embedding."
    else:
        model = SQLNet(word_emb,
                       N_word=N_word,
                       use_ca=args.ca,
                       gpu=GPU,
                       trainable_emb=args.train_emb)
        assert not args.rl, "SQLNet can\'t do reinforcement learning."
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=0)

    # Fragment is truncated here — the matching else-branch (seen in the
    # third example below) is outside the visible span.
    if args.train_emb:
        agg_m, sel_m, cond_m, agg_e, sel_e, cond_e = best_model_name(args)
Exemplo n.º 3
0
            GPU=True
            BATCH_SIZE=64  # NOTE(review): assigned but not used in the visible fragment
        # Which prediction heads to train: (aggregate, select, condition).
        TRAIN_ENTRY=(True, True, True)  # (AGG, SEL, COND)
        TRAIN_AGG, TRAIN_SEL, TRAIN_COND = TRAIN_ENTRY
        # Smaller LR for the reinforcement-learning fine-tuning phase.
        learning_rate = 1e-4 if params['rl'] else 1e-3

        # USE_SMALL is not defined in this fragment — presumably set earlier
        # in the enclosing function; verify against the full file.
        sql_data, table_data, val_sql_data, val_table_data, \
                test_sql_data, test_table_data, \
                TRAIN_DB, DEV_DB, TEST_DB = load_dataset(
                        params['dataset'], use_small=USE_SMALL)

        # Only restrict to the previously-used vocabulary when embeddings
        # are being trained; otherwise the full GloVe table is loaded.
        word_emb = load_word_emb('glove/glove.%dB.%dd.txt'%(B_word,N_word), \
                load_used=params['train_emb'], use_small=USE_SMALL)

        # NOTE(review): these asserts are input validation and are stripped
        # under `python -O`; a raised ValueError would be more robust.
        if params['baseline']:
            model = Seq2SQL(word_emb, N_word=N_word, gpu=GPU,
                    trainable_emb = params['train_emb'])
            assert not params['train_emb'], "Seq2SQL can\'t train embedding."
        else:
            model = SQLNet(word_emb, N_word=N_word, use_ca=params['ca'],
                    gpu=GPU, trainable_emb = params['train_emb'])
            assert not params['rl'], "SQLNet can\'t do reinforcement learning."
        optimizer = torch.optim.Adam(model.parameters(),
                lr=learning_rate, weight_decay = 0)

        # NOTE(review): this variant is driven entirely by the `params`
        # dict, yet the three best_model_name(...) calls below pass `args`,
        # which is never defined in the visible fragment — likely a leftover
        # from the argparse-based variant above and probably a NameError at
        # runtime. Confirm best_model_name's signature and pass `params`
        # (or the appropriate namespace) instead.
        if params['train_emb']:
            agg_m, sel_m, cond_m, agg_e, sel_e, cond_e = best_model_name(args)
        else:
            agg_m, sel_m, cond_m = best_model_name(args)

        if params['rl'] or params['train_emb']: # Load pretrained model.
            agg_lm, sel_lm, cond_lm = best_model_name(args, for_load=True)