Code Example #1
File: test.py Project: jungokasai/stagging_srl
import os
import pickle

import numpy as np
import tensorflow as tf

# vocab, SRL_Model, and run_evaluation_script are project-local names
# imported at the top of the original file (module paths not shown here).


def test(args):
    model_dir = args.model_dir
    # Reload the hyperparameters pickled when the model was trained
    with open(os.path.join(model_dir, 'args.pkl'), 'rb') as f:
        model_args = pickle.load(f)
    # Older saved models predate the language flag; default to English
    if not hasattr(model_args, 'language'):
        model_args.language = 'eng'

    # File paths for the chosen evaluation split
    fn_txt_valid = 'data/{}/conll09/{}.txt'.format(model_args.language,
                                                   args.data)
    fn_preds_valid = 'data/{}/conll09/pred/{}_predicates.txt'.format(
        model_args.language, args.data)
    # Supertag file: an explicit path overrides the default location
    if args.stags is None:
        fn_stags_valid = 'data/{}/conll09/{}/{}_stags_{}.txt'.format(
            model_args.language, model_args.stags_dir, args.data,
            model_args.stag_type)
    else:
        fn_stags_valid = args.stags

    fn_sys = '{}.txt'.format(args.data)
    fn_sys = os.path.join(args.model_dir, fn_sys)

    vocabs = vocab.get_vocabs(model_args.language, model_args.stag_type)

    with tf.Graph().as_default():
        # Seed TF and NumPy for reproducibility
        tf.set_random_seed(model_args.seed)
        np.random.seed(model_args.seed)

        print("Building model...")
        model = SRL_Model(vocabs, model_args)

        saver = tf.train.Saver()

        with tf.Session() as session:
            print('Restoring model...')
            saver.restore(session, tf.train.latest_checkpoint(model_dir))

            print('-' * 78)
            print('Validating...')
            valid_loss = model.run_testing_epoch(session, vocabs, fn_txt_valid,
                                                 fn_preds_valid,
                                                 fn_stags_valid, fn_sys,
                                                 model_args.language)
            print('Validation loss: {}'.format(valid_loss))

            print('-' * 78)
            print('Running evaluation script...')
            labeled_f1, unlabeled_f1 = run_evaluation_script(
                fn_txt_valid, fn_sys)
            print('Labeled F1:    {0:.2f}'.format(labeled_f1))
            print('Unlabeled F1:  {0:.2f}'.format(unlabeled_f1))
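
For reference, a minimal sketch of how test(args) could be invoked from the command line. The flag names simply mirror the attributes the function reads (model_dir, data, stags); the project's actual CLI may differ, so treat the parser below as an assumption.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate a trained SRL model')
    # Attribute names follow how test() uses args; defaults are illustrative.
    parser.add_argument('model_dir',
                        help='directory containing args.pkl and the checkpoint')
    parser.add_argument('--data', default='dev',
                        help='data split to evaluate, e.g. dev or test')
    parser.add_argument('--stags', default=None,
                        help='optional explicit path to a supertag file')
    test(parser.parse_args())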
Code Example #2
File: train.py Project: jungokasai/stagging_srl
import os
import pickle
from timeit import default_timer as timer  # assumed import; timer() is used below

import numpy as np
import tensorflow as tf

# vocab, SRL_Model, and run_evaluation_script are project-local names
# imported at the top of the original file (module paths not shown here).


def train(args):
    # Set the filepaths for training and validation
    fn_txt_train = 'data/{}/conll09/train.txt'.format(args.language)
    fn_preds_train = 'data/{}/conll09/{}/train_predicates.txt'.format(
        args.language, args.training_split)
    if args.use_stags:
        fn_stags_train = 'data/{}/conll09/{}/train_stags_{}.txt'.format(
            args.language, args.stags_dir, args.stag_type)
    else:
        fn_stags_train = fn_preds_train

    fn_txt_valid = 'data/{}/conll09/dev.txt'.format(args.language)
    if args.language == 'eng':
        fn_preds_valid = 'data/{}/conll09/pred/_dev_predicates_mt.txt'.format(
            args.language)
    else:
        fn_preds_valid = 'data/{}/conll09/gold/dev_predicates.txt'.format(
            args.language)
    if args.use_stags:
        fn_stags_valid = 'data/{}/conll09/{}/dev_stags_{}.txt'.format(
            args.language, args.stags_dir, args.stag_type)
    else:
        fn_stags_valid = fn_preds_valid

    # Come up with a model name based on the hyperparameters
    model_suffix = '_'
    if args.training_split == 'gold':
        model_suffix += 'g'
    else:
        model_suffix += 'p'
    if args.testing_split == 'gold':
        model_suffix += 'g'
    else:
        model_suffix += 'p'
    if args.language != 'eng':
        model_suffix += '_' + args.language
    if args.restrict_labels:
        model_suffix += '_rl'
    if args.use_stags:
        model_suffix += '_st{}_{}'.format(args.stag_embed_size, args.stag_type)
        if args.stags_dir == 'gold':
            model_suffix += 'g'
        elif args.stags_dir == 'pred':
            model_suffix += 'p'
        elif args.stags_dir == 'malt':
            model_suffix += 'm'
        elif args.stags_dir == 'pred_pos':
            model_suffix += 'o'
        elif args.stags_dir == 'predicted_stag_0.5':
            model_suffix += '5'
        elif args.stags_dir == 'pred_new':
            model_suffix += 'n'
        elif args.stags_dir == 'pred_elmo':
            model_suffix += 'e'
        if args.use_stag_features:
            model_suffix += 'f{}'.format(args.stag_feature_embed_size)
    if args.dropout < 1.0:
        model_suffix += '_dr{}'.format(args.dropout)
    if args.recurrent_dropout < 1.0:
        model_suffix += '_rdr{}'.format(args.recurrent_dropout)
    if args.use_word_dropout:
        model_suffix += '_wdr'
    if args.use_basic_classifier:
        model_suffix += '_bc'
    if args.use_highway_lstm:
        model_suffix += '_hw'
    if args.optimizer != 'adam':
        model_suffix += '_' + args.optimizer
    if args.seed != 89:
        model_suffix += '_s{}'.format(args.seed)
    fn_sys = 'output/predictions/dev{}.txt'.format(model_suffix)

    # Prepare for saving the model
    model_dir = 'output/models/srl' + model_suffix + '/'
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    print('Saving args to', model_dir + 'args.pkl')
    with open(model_dir + 'args.pkl', 'wb') as f:
        pickle.dump(args, f)

    vocabs = vocab.get_vocabs(args.language, args.stag_type)

    with tf.Graph().as_default():
        # Seed TF and NumPy for reproducibility
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)

        print("Building model...")
        model = SRL_Model(vocabs, args)
        saver = tf.train.Saver(max_to_keep=1)

        with tf.Session() as session:
            # Track the best dev F1 and epochs without improvement for early stopping
            best_f1 = 0
            bad_streak = 0

            session.run(tf.global_variables_initializer())

            for i in range(args.max_epochs):
                print('-' * 78)
                print('Epoch {}'.format(i))
                start = timer()
                train_loss = model.run_training_epoch(session, vocabs,
                                                      fn_txt_train,
                                                      fn_preds_train,
                                                      fn_stags_train,
                                                      args.language)
                end = timer()
                print('Done with epoch {}'.format(i))
                print('Avg loss: {}, total time: {}'.format(
                    train_loss, end - start))

                print('-' * 78)
                print('Validating...')
                valid_loss = model.run_testing_epoch(session, vocabs,
                                                     fn_txt_valid,
                                                     fn_preds_valid,
                                                     fn_stags_valid, fn_sys,
                                                     args.language)
                print('Validation loss: {}'.format(valid_loss))

                print('-' * 78)
                print('Running evaluation script...')
                labeled_f1, unlabeled_f1 = run_evaluation_script(
                    fn_txt_valid, fn_sys)
                print('Labeled F1:    {0:.2f}'.format(labeled_f1))
                print('Unlabeled F1:  {0:.2f}'.format(unlabeled_f1))

                if labeled_f1 > best_f1:
                    best_f1 = labeled_f1
                    bad_streak = 0
                    print('Saving model to', model_dir + 'model')
                    saver.save(session, model_dir + 'model')
                else:
                    print('F1 deteriorated (best score: {})'.format(best_f1))
                    bad_streak += 1
                    if bad_streak >= args.early_stopping:
                        print(
                            'No F1 improvement for %d epochs, stopping early' %
                            args.early_stopping)
                        print('Best F1 score: {0:.2f}'.format(best_f1))
                        break
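
Likewise, a hedged sketch of the argument namespace train(args) expects, listing only the attributes the function itself reads above. All values are illustrative assumptions; the only baselines the code hints at are seed 89 and the 'adam' optimizer, which the model-suffix logic treats as defaults. SRL_Model may read further attributes not visible in this snippet.

from types import SimpleNamespace

# Illustrative configuration only; attribute names are taken from train() above,
# values are assumptions rather than the project's actual defaults.
args = SimpleNamespace(
    language='eng',
    training_split='gold',          # 'gold' or 'pred'
    testing_split='gold',
    use_stags=True,
    stags_dir='pred',               # e.g. 'gold', 'pred', 'malt', 'pred_elmo', ...
    stag_type='stag',               # placeholder; must match the *_stags_<type>.txt files
    stag_embed_size=50,             # assumed embedding size
    use_stag_features=False,
    stag_feature_embed_size=0,
    restrict_labels=False,
    dropout=0.9,                    # assumed keep probabilities
    recurrent_dropout=0.9,
    use_word_dropout=False,
    use_basic_classifier=False,
    use_highway_lstm=True,
    optimizer='adam',
    seed=89,
    max_epochs=100,
    early_stopping=10,
)
train(args)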