Ejemplo n.º 1
0
def transfer(args):
    """Evaluate a saved checkpoint on the 'overnight' transfer data.

    Builds the inference graph with ``model``, restores the checkpoint at
    ``model/<args.subset>`` into an interactive session, then calls
    ``decode_data`` on each of five named 'overnight_set' subsets (passing
    ``<subset>.txt`` as ``filename``) and finally on the full 'overnight'
    set. The session is closed before returning, even if decoding raises.

    Side effects: overwrites ``args.data`` and ``args.load_data``.

    Args:
        args: configuration namespace; ``mode``, ``load_model`` and
            ``subset`` are read here, and the object is also passed to
            ``model``, ``load_data`` and ``decode_data``.

    Returns:
        None.
    """
    # NOTE(review): this value is never used. It may have been intended as
    # ``args.load_model = ...`` so that ``model()`` sees it — confirm before
    # changing, since ``model`` may read ``args.load_model``.
    load_model = args.load_model if args.mode == 'train' else True

    class Dummy:
        # Empty attribute bag; ``model`` must populate it, since
        # ``infer_env.infer_saver`` is read below.
        pass

    train_env = Dummy()
    infer_env = Dummy()

    _, infer_graph = model(args, train_env, infer_env)

    args.data = 'overnight'
    args.load_data = True
    X_tran, y_tran = load_data(args)
    args.data = 'overnight_set'
    # Assumed to yield one (X, y) pair per subset, in the same order as
    # ``subsets`` below — TODO confirm against load_data.
    tran_sets = load_data(args)
    model2load = 'model/{}'.format(args.subset)

    sess = tf.InteractiveSession(graph=infer_graph)
    try:
        infer_env.infer_saver.restore(sess, model2load)

        print('========subset transfer set========')
        subsets = ['basketball', 'calendar', 'housing', 'recipes', 'restaurants']
        # zip() stops at the shorter input: any subset without a matching
        # entry in ``tran_sets`` (or vice versa) is skipped silently.
        for subset, (X_tran_subset, y_tran_subset) in zip(subsets, tran_sets):
            print('---------' + subset + '---------')
            decode_data(sess,
                        infer_env,
                        args,
                        X_tran_subset,
                        y_tran_subset,
                        filename=subset + '.txt')
        print('===========transfer set============')
        decode_data(sess, infer_env, args, X_tran, y_tran)
    finally:
        # Release the session even if restoring or decoding fails.
        sess.close()
Ejemplo n.º 2
0
def inferrence(args):
    """Decode the 'wikisql' dev and test splits with a saved checkpoint.

    Forces ``args.load_model = True``, builds the inference graph with
    ``model``, restores the checkpoint at ``model/<args.subset>``, and runs
    both ``decode_data`` and ``decode_data_recover`` on the dev split and then
    on the test split. The session is closed before returning, even if
    decoding raises.

    Side effects: overwrites ``args.load_model``, ``args.data`` and
    ``args.load_data``.

    Args:
        args: configuration namespace; ``subset`` is read here, and the object
            is also passed to ``model``, ``load_data`` and the decoders.

    Returns:
        None.
    """
    args.load_model = True

    class Dummy:
        # Empty attribute bag; ``model`` must populate it, since
        # ``infer_env.infer_saver`` is read below.
        pass

    train_env = Dummy()
    infer_env = Dummy()
    _, infer_graph = model(args, train_env, infer_env)

    args.data = 'wikisql'
    args.load_data = True
    # The training split is not needed for inference.
    _, _, X_test, y_test, X_dev, y_dev = load_data(args)
    model2load = 'model/{}'.format(args.subset)

    sess = tf.InteractiveSession(graph=infer_graph)
    try:
        infer_env.infer_saver.restore(sess, model2load)
        print('===========dev set============')
        decode_data(sess, infer_env, args, X_dev, y_dev)
        decode_data_recover(sess, infer_env, args, X_dev, y_dev, 'dev')
        print('==========test set===========')
        decode_data(sess, infer_env, args, X_test, y_test)
        decode_data_recover(sess, infer_env, args, X_test, y_test, 'test')
    finally:
        # Release the session even if restoring or decoding fails.
        sess.close()
Ejemplo n.º 3
0
def train_model(args):
    """Train on the 'geo' data, evaluating on dev/test between rounds.

    Training is split into ``args.total_epochs // args.epochs`` rounds of
    ``args.epochs`` epochs each, run in a session on the training graph. After
    a round whose returned ``acc`` is positive, the checkpoint path returned
    by ``train`` is restored into a second session on the inference graph and
    the dev and test splits are decoded. Whenever the dev EM improves on the
    best so far, the training session is saved to
    ``model/<subset>-<round>-<dev EM>``. Both sessions are closed before
    returning, even on error.

    Side effects: overwrites ``args.data``, ``args.load_data`` and
    ``args.load_model``; prints progress.

    Args:
        args: configuration namespace; ``subset``, ``epochs``,
            ``total_epochs`` and ``batch_size`` are read here, and the object
            is also passed to ``model``, ``load_data`` and ``decode_data``.

    Returns:
        None.
    """
    class Dummy:
        # Empty attribute bag; ``model`` must populate it, since
        # ``train_env.saver`` and ``infer_env.infer_saver`` are read below.
        pass

    train_env = Dummy()
    infer_env = Dummy()

    train_graph, infer_graph = model(args, train_env, infer_env)

    args.data = 'geo'
    args.load_data = True
    args.load_model = False
    X_train, y_train, X_test, y_test, X_dev, y_dev = load_data(args)
    model2load = 'model/{}'.format(args.subset)
    max_em, global_test_em, best_base = -1, -1, -1
    acc = 0

    sess1 = tf.InteractiveSession(graph=train_graph)
    sess2 = None
    try:
        sess1.run(tf.global_variables_initializer())
        sess1.run(tf.local_variables_initializer())
        sess2 = tf.InteractiveSession(graph=infer_graph)
        # Mirror sess1: initialize both global and local variables.
        sess2.run(tf.global_variables_initializer())
        sess2.run(tf.local_variables_initializer())

        # Floor division: under Python 3 `/` yields a float, which range()
        # rejects with a TypeError.
        for base in range(args.total_epochs // args.epochs):
            print('\nIteration: %d (%d epochs)' % (base, args.epochs))
            model2load, acc = train(
                sess1,
                train_env,
                X_train,
                y_train,
                epochs=args.epochs,
                load=args.load_model,
                name=args.subset,
                batch_size=args.batch_size,
                base=base,
                acc=acc,
                model2load=model2load)
            # Later rounds continue from the checkpoint returned by train().
            args.load_model = True
            if acc > 0:
                infer_env.infer_saver.restore(sess2, model2load)

                print('===========dev set============')
                dev_em = decode_data(sess2, infer_env, args, X_dev, y_dev)
                print('==========test set===========')
                test_em = decode_data(sess2, infer_env, args, X_test, y_test)

                # Model selection is on dev EM; the test EM reported is the
                # one observed at the best dev round.
                if dev_em > max_em:
                    max_em = dev_em
                    global_test_em = test_em
                    best_base = base
                    print('\n Saving model for best testing')
                    train_env.saver.save(sess1, 'model/{0}-{1}-{2}'.format(args.subset, base, max_em))
                print('Max EM acc: %.4f during %d iteration.' % (max_em, best_base))
                print('test EM acc: %.4f ' % global_test_em)
    finally:
        if sess2 is not None:
            sess2.close()
        sess1.close()
Ejemplo n.º 4
0
def train_model(args):
    """Train on the 'wikisql' data, evaluating on dev/test after each round.

    Training is split into ``args.total_epochs // args.epochs`` rounds of
    ``args.epochs`` epochs each. Every round opens a fresh session on the
    training graph, initializes its variables, and calls ``train``; the
    checkpoint path ``train`` returns is then restored into a session on the
    inference graph, and the dev and test splits are decoded with
    ``decode_data`` and scored with ``decode_data_recover``. The best dev EM,
    its round, and the matching test EM are tracked and printed. Each session
    is closed as soon as it is no longer needed, even on error.

    Side effects: overwrites ``args.data``, ``args.load_data`` and
    ``args.load_model``; prints progress.

    Args:
        args: configuration namespace; ``subset``, ``epochs``,
            ``total_epochs`` and ``batch_size`` are read here, and the object
            is also passed to ``model``, ``load_data`` and the decoders.

    Returns:
        None.
    """
    class Dummy:
        # Empty attribute bag; ``model`` must populate it, since
        # ``infer_env.infer_saver`` is read below.
        pass

    train_env = Dummy()
    infer_env = Dummy()

    train_graph, infer_graph = model(args, train_env, infer_env)

    args.data = 'wikisql'
    args.load_data = True
    args.load_model = False
    X_train, y_train, X_test, y_test, X_dev, y_dev = load_data(args)
    model2load = 'model/{}'.format(args.subset)
    max_em, global_test_em, best_base = -1, -1, -1

    # Floor division: under Python 3 `/` yields a float, which range()
    # rejects with a TypeError.
    for base in range(args.total_epochs // args.epochs):
        print('\nIteration: %d (%d epochs)' % (base, args.epochs))

        sess = tf.InteractiveSession(graph=train_graph)
        try:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            # NOTE(review): with load=True, train() is presumably expected to
            # restore from model2load after this re-initialization — confirm.
            model2load = train(sess,
                               train_env,
                               X_train,
                               y_train,
                               epochs=args.epochs,
                               load=args.load_model,
                               name=args.subset,
                               batch_size=args.batch_size,
                               base=base,
                               model2load=model2load)
        finally:
            sess.close()
        args.load_model = True

        sess = tf.InteractiveSession(graph=infer_graph)
        try:
            infer_env.infer_saver.restore(sess, model2load)

            print('===========dev set============')
            decode_data(sess, infer_env, args, X_dev, y_dev)
            dev_em = decode_data_recover(sess, infer_env, args, X_dev, y_dev,
                                         'dev')
            print('==========test set===========')
            decode_data(sess, infer_env, args, X_test, y_test)
            test_em = decode_data_recover(sess, infer_env, args, X_test, y_test,
                                          'test')
        finally:
            sess.close()

        # Model selection is on dev EM; the test EM reported is the one
        # observed at the best dev round.
        if dev_em > max_em:
            max_em = dev_em
            global_test_em = test_em
            best_base = base
        print('Max EM acc: %.4f during %d iteration.' % (max_em, best_base))
        print('test EM acc: %.4f ' % global_test_em)