def transfer(args):
    """Decode the Overnight transfer data with a previously saved model.

    The graphs are built through ``model``; the checkpoint
    ``model/<args.subset>`` is restored into a session on the inference
    graph, after which every per-domain subset and then the complete
    transfer set are passed to ``decode_data``.

    Args:
        args: run configuration. ``args.mode`` and ``args.subset`` are read
            (plus ``args.load_model`` when the mode is ``'train'``);
            ``args.data`` and ``args.load_data`` are overwritten.

    Returns:
        None. The values returned by ``decode_data`` are discarded.
    """
    # NOTE(review): this value is computed but never used below; possibly
    # ``args.load_model`` was meant to be assigned -- confirm with the author.
    load_model = args.load_model if args.mode == 'train' else True

    class Dummy:
        pass

    train_env, infer_env = Dummy(), Dummy()
    _train_graph, infer_graph = model(args, train_env, infer_env)

    # Full transfer set.
    args.data = 'overnight'
    args.load_data = True
    X_all, y_all = load_data(args)

    # Per-domain subsets; iterated below as (X, y) pairs.
    args.data = 'overnight_set'
    per_domain_sets = load_data(args)

    sess = tf.InteractiveSession(graph=infer_graph)
    infer_env.infer_saver.restore(sess, 'model/{}'.format(args.subset))

    print('========subset transfer set========')
    domain_names = ('basketball', 'calendar', 'housing', 'recipes',
                    'restaurants')
    for domain, (X_domain, y_domain) in zip(domain_names, per_domain_sets):
        print('---------' + domain + '---------')
        decode_data(sess, infer_env, args, X_domain, y_domain,
                    filename=str(domain + '.txt'))

    print('===========transfer set============')
    decode_data(sess, infer_env, args, X_all, y_all)
def inferrence(args):
    """Run a saved model over the WikiSQL dev and test splits.

    Forces ``args.load_model`` to True, builds the graphs through ``model``,
    restores ``model/<args.subset>`` into a session on the inference graph
    and calls ``decode_data`` followed by ``decode_data_recover`` on the dev
    split and then on the test split.

    Args:
        args: run configuration. ``args.subset`` is read; ``args.load_model``,
            ``args.data`` and ``args.load_data`` are overwritten.

    Returns:
        None. The values returned by the decode helpers are discarded.
    """
    args.load_model = True

    class Dummy:
        pass

    train_env, infer_env = Dummy(), Dummy()
    _train_graph, infer_graph = model(args, train_env, infer_env)

    args.data = 'wikisql'
    args.load_data = True
    # The training split is loaded as part of the tuple but not used here.
    _X_train, _y_train, X_test, y_test, X_dev, y_dev = load_data(args)

    checkpoint = 'model/{}'.format(args.subset)
    sess = tf.InteractiveSession(graph=infer_graph)
    infer_env.infer_saver.restore(sess, checkpoint)

    print('===========dev set============')
    decode_data(sess, infer_env, args, X_dev, y_dev)
    decode_data_recover(sess, infer_env, args, X_dev, y_dev, 'dev')

    print('==========test set===========')
    decode_data(sess, infer_env, args, X_test, y_test)
    decode_data_recover(sess, infer_env, args, X_test, y_test, 'test')
def train_model(args):
    """Train on the 'geo' data and evaluate dev/test after each round.

    One session is opened on the training graph and one on the inference
    graph. Training proceeds in ``args.total_epochs // args.epochs`` rounds
    of ``args.epochs`` epochs each. After a round, if the accuracy returned
    by ``train`` is positive, the checkpoint it returned is restored into the
    inference session and the dev and test splits are decoded. Whenever the
    dev score beats the best seen so far, the training session is saved to
    ``model/<subset>-<round>-<dev score>``.

    NOTE(review): a later ``def train_model`` in this file rebinds the name,
    so this definition is shadowed once the module has been loaded.

    Args:
        args: run configuration. ``total_epochs``, ``epochs``, ``subset`` and
            ``batch_size`` are read; ``data``, ``load_data`` and
            ``load_model`` are overwritten.

    Returns:
        None. The best dev score, its round and the matching test score are
        printed.
    """
    class Dummy:
        pass

    train_env = Dummy()
    infer_env = Dummy()
    train_graph, infer_graph = model(args, train_env, infer_env)

    args.data = 'geo'
    args.load_data = True
    args.load_model = False
    X_train, y_train, X_test, y_test, X_dev, y_dev = load_data(args)

    model2load = 'model/{}'.format(args.subset)
    max_em, global_test_em, best_base = -1, -1, -1
    acc = 0

    sess1 = tf.InteractiveSession(graph=train_graph)
    sess1.run(tf.global_variables_initializer())
    sess1.run(tf.local_variables_initializer())

    sess2 = tf.InteractiveSession(graph=infer_graph)
    sess2.run(tf.global_variables_initializer())
    # Mirrors the sess1 setup; the original ran the global initializer twice
    # here and never initialized the inference graph's local variables.
    sess2.run(tf.local_variables_initializer())

    # Floor division keeps the round count an int on Python 3 as well.
    for base in range(args.total_epochs // args.epochs):
        print('\nIteration: %d (%d epochs)' % (base, args.epochs))
        model2load, acc = train(
            sess1, train_env, X_train, y_train,
            epochs=args.epochs,
            load=args.load_model,
            name=args.subset,
            batch_size=args.batch_size,
            base=base,
            acc=acc,
            model2load=model2load)
        # Every round after the first is started with load=True.
        args.load_model = True

        # Evaluation is skipped until train() reports a positive accuracy.
        if acc > 0:
            infer_env.infer_saver.restore(sess2, model2load)

            print('===========dev set============')
            dev_em = decode_data(sess2, infer_env, args, X_dev, y_dev)

            print('==========test set===========')
            test_em = decode_data(sess2, infer_env, args, X_test, y_test)

            if dev_em > max_em:
                max_em = dev_em
                global_test_em = test_em
                best_base = base
                print('\n Saving model for best testing')
                train_env.saver.save(
                    sess1,
                    'model/{0}-{1}-{2}'.format(args.subset, base, max_em))

    print('Max EM acc: %.4f during %d iteration.' % (max_em, best_base))
    print('test EM acc: %.4f ' % global_test_em)
    return
def train_model(args):
    """Run one training round on the 'wikisql' data.

    Builds the graphs through ``model``, loads the data, opens a session on
    the training graph, initializes its variables and calls ``train`` for
    ``args.epochs`` epochs. The loop is left immediately afterwards, so the
    dev/test evaluation written below the ``break`` does not execute and the
    final summary prints the initial ``-1`` placeholders.

    NOTE(review): this definition has the same name as an earlier
    ``train_model`` in this file and replaces it once the module is loaded.

    Args:
        args: run configuration. ``total_epochs``, ``epochs``, ``subset`` and
            ``batch_size`` are read; ``data``, ``load_data`` and
            ``load_model`` are overwritten.

    Returns:
        None.
    """
    class Dummy:
        pass

    train_env = Dummy()
    infer_env = Dummy()
    train_graph, infer_graph = model(args, train_env, infer_env)

    args.data = 'wikisql'
    args.load_data = True
    args.load_model = False
    X_train, y_train, X_test, y_test, X_dev, y_dev = load_data(args)

    model2load = 'model/{}'.format(args.subset)
    max_em, global_test_em, best_base = -1, -1, -1

    # Floor division keeps the round count an int on Python 3 as well.
    for base in range(args.total_epochs // args.epochs):
        print('\nIteration: %d (%d epochs)' % (base, args.epochs))
        sess = tf.InteractiveSession(graph=train_graph)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        model2load = train(sess, train_env, X_train, y_train,
                           epochs=args.epochs,
                           load=args.load_model,
                           name=args.subset,
                           batch_size=args.batch_size,
                           base=base,
                           model2load=model2load)
        args.load_model = True
        # NOTE(review): this break ends the loop after the first round, which
        # makes the evaluation below unreachable. It is kept so runtime
        # behavior is unchanged; remove it to re-enable per-round evaluation.
        break

        # --- unreachable while the break above is in place ---
        sess = tf.InteractiveSession(graph=infer_graph)
        infer_env.infer_saver.restore(sess, model2load)

        print('===========dev set============')
        # ``args`` is passed to match every other decode_data call in this
        # file; the original omitted it here.
        decode_data(sess, infer_env, args, X_dev, y_dev)
        dev_em = decode_data_recover(sess, infer_env, args, X_dev, y_dev,
                                     'dev')

        print('==========test set===========')
        decode_data(sess, infer_env, args, X_test, y_test)
        test_em = decode_data_recover(sess, infer_env, args, X_test, y_test,
                                      'test')

        # The original compared an undefined name ``em``; the dev score is
        # the value assigned to max_em, so it is what must be compared.
        if dev_em > max_em:
            max_em = dev_em
            global_test_em = test_em
            best_base = base

    print('Max EM acc: %.4f during %d iteration.' % (max_em, best_base))
    print('test EM acc: %.4f ' % global_test_em)
    return