# --- Evaluation setup -------------------------------------------------------
USE_SMALL = False
GPU = True
BATCH_SIZE = 64
TEST_ENTRY = (True, True, True)  # (AGG, SEL, COND)

# load_dataset returns nine values, unpacked positionally: the query/table
# data for each of the three splits followed by TRAIN_DB, DEV_DB, TEST_DB.
(sql_data, table_data, val_sql_data, val_table_data,
 test_sql_data, test_table_data,
 TRAIN_DB, DEV_DB, TEST_DB) = load_dataset(args.dataset, use_small=USE_SMALL)

# Start from the dummy examples/tables and append the first (train) split.
examples, tables = load_dataset_dummy(0)
examples.extend(sql_data)
tables.update(table_data)

# load_used can speed up loading.
word_emb = load_word_emb('glove/glove.%dB.%dd.txt' % (B_word, N_word),
                         load_used=True, use_small=USE_SMALL)

if args.baseline:
    model = Seq2SQL(word_emb, N_word=N_word, gpu=GPU, trainable_emb=True)
else:
    model = SQLNet(word_emb, N_word=N_word, use_ca=args.ca, gpu=GPU,
                   trainable_emb=True)

if args.train_emb:
    agg_m, sel_m, cond_m, agg_e, sel_e, cond_e = best_model_name(
        args, savedstr='_mconly')
    # Single-argument print(...) behaves the same under Python 2 and 3; the
    # previous multi-argument form printed a tuple repr under Python 2, and
    # bare `print "..."` statements do not parse under Python 3.
    print('==> best model names: %s %s %s' % (agg_m, sel_m, cond_m))
    print("Loading from %s" % agg_m)
    model.agg_pred.load_state_dict(torch.load(agg_m))
    print("Loading from %s" % sel_m)
    model.sel_pred.load_state_dict(torch.load(sel_m))
# --- Training setup (configured from `args`) --------------------------------
BATCH_SIZE = 64
TRAIN_ENTRY = (True, True, True)  # (AGG, SEL, COND)
TRAIN_AGG, TRAIN_SEL, TRAIN_COND = TRAIN_ENTRY
# RL training uses a 10x smaller learning rate than supervised training.
learning_rate = 1e-4 if args.rl else 1e-3

# Reject unsupported flag combinations up front. Explicit raises are not
# stripped under `python -O` (asserts are), and checking here fails before
# the dataset, GloVe vectors and model are loaded rather than after.
if args.baseline and args.train_emb:
    raise ValueError("Seq2SQL can't train embedding.")
if not args.baseline and args.rl:
    raise ValueError("SQLNet can't do reinforcement learning.")

# load_dataset returns nine values, unpacked positionally.
(sql_data, table_data, val_sql_data, val_table_data,
 test_sql_data, test_table_data,
 TRAIN_DB, DEV_DB, TEST_DB) = load_dataset(args.dataset, use_small=USE_SMALL)

word_emb = load_word_emb('glove/glove.%dB.%dd.txt' % (B_word, N_word),
                         load_used=args.train_emb, use_small=USE_SMALL)

if args.baseline:
    model = Seq2SQL(word_emb, N_word=N_word, gpu=GPU,
                    trainable_emb=args.train_emb)
else:
    model = SQLNet(word_emb, N_word=N_word, use_ca=args.ca, gpu=GPU,
                   trainable_emb=args.train_emb)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                             weight_decay=0)

if args.train_emb:
    agg_m, sel_m, cond_m, agg_e, sel_e, cond_e = best_model_name(args)
# --- Training setup (configured from the `params` dict) ---------------------
GPU = True
BATCH_SIZE = 64
TRAIN_ENTRY = (True, True, True)  # (AGG, SEL, COND)
TRAIN_AGG, TRAIN_SEL, TRAIN_COND = TRAIN_ENTRY
# RL training uses a 10x smaller learning rate than supervised training.
learning_rate = 1e-4 if params['rl'] else 1e-3

# Reject unsupported option combinations up front. Explicit raises are not
# stripped under `python -O` (asserts are), and checking here fails before
# the dataset, GloVe vectors and model are loaded rather than after.
if params['baseline'] and params['train_emb']:
    raise ValueError("Seq2SQL can't train embedding.")
if not params['baseline'] and params['rl']:
    raise ValueError("SQLNet can't do reinforcement learning.")

# load_dataset returns nine values, unpacked positionally.
(sql_data, table_data, val_sql_data, val_table_data,
 test_sql_data, test_table_data,
 TRAIN_DB, DEV_DB, TEST_DB) = load_dataset(params['dataset'],
                                           use_small=USE_SMALL)

word_emb = load_word_emb('glove/glove.%dB.%dd.txt' % (B_word, N_word),
                         load_used=params['train_emb'], use_small=USE_SMALL)

if params['baseline']:
    model = Seq2SQL(word_emb, N_word=N_word, gpu=GPU,
                    trainable_emb=params['train_emb'])
else:
    model = SQLNet(word_emb, N_word=N_word, use_ca=params['ca'], gpu=GPU,
                   trainable_emb=params['train_emb'])

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                             weight_decay=0)

# NOTE(review): every other setting here comes from `params`, but the model
# file names are derived from `args`. If `params` can differ from `args`
# (e.g. a config or sweep overrides), saved/loaded file names may not match
# this configuration -- confirm `params` mirrors `args`.
if params['train_emb']:
    agg_m, sel_m, cond_m, agg_e, sel_e, cond_e = best_model_name(args)
else:
    agg_m, sel_m, cond_m = best_model_name(args)

if params['rl'] or params['train_emb']:
    # Load pretrained model.
    agg_lm, sel_lm, cond_lm = best_model_name(args, for_load=True)