Example no. 1
0
def __train():
    """Train an SRL-based fine-grained entity typing (FET) model on the BBN dataset.

    Sets up file logging, builds the training configuration and model
    hyperparameters, resolves the dataset-specific data files, loads shared
    resources (type vocabulary + word vectors), and delegates the actual
    training loop to ``srlfetexp.train_srlfet``.

    Uses module-level ``config``, ``args``, ``str_today`` and ``device``;
    writes model checkpoints under ``save_model_file_prefix``.
    """
    # Log file name encodes script name, run index, date and machine for traceability.
    log_file = os.path.join(config.LOG_DIR, '{}-{}-{}-{}.log'.format(os.path.splitext(
        os.path.basename(__file__))[0], args.idx, str_today, config.MACHINE_NAME))
    # log_file = None
    init_universal_logging(log_file, mode='a', to_stdout=True)
    logging.info('logging to {}'.format(log_file))

    # Max-margin loss with a small negative-sample scale; 70k optimization steps.
    train_config = srlfetexp.TrainConfig(loss_name='mm', neg_scale=0.1, n_steps=70000)

    # Model hyperparameters.
    lstm_dim = 250
    mlp_hidden_dim = 500
    type_embed_dim = 500
    word_vecs_file = config.WIKI_FETEL_WORDVEC_FILE

    # dataset = 'figer'
    dataset = 'bbn'
    datafiles = config.FIGER_FILES if dataset == 'figer' else config.BBN_FILES

    data_prefix = datafiles['srl-train-data-prefix']
    dev_data_pkl = data_prefix + '-dev.pkl'
    train_data_pkl = data_prefix + '-train.pkl'

    test_file_tup = (datafiles['test-mentions'], datafiles['test-sents'],
                     datafiles['test-sents-dep'], datafiles['test-srl'])
    # Direct comparison instead of `False if ... else True`; True for all
    # non-figer datasets (here: bbn), matching the original ternary.
    single_type_path = dataset != 'figer'

    # output_model_file = None
    save_model_file_prefix = os.path.join(config.DATA_DIR, 'models/srl-{}'.format(dataset))

    gres = expdata.ResData(datafiles['type-vocab'], word_vecs_file)
    logging.info('dataset={} {}'.format(dataset, data_prefix))
    srlfetexp.train_srlfet(device, gres, train_data_pkl, dev_data_pkl, None, test_file_tup, lstm_dim, mlp_hidden_dim,
                           type_embed_dim, train_config, single_type_path,
                           save_model_file_prefix=save_model_file_prefix)
Example no. 2
0
def __train():
    """Train a stacking model that combines several FET prediction sources for FIGER.

    Builds training samples from manually labeled dev mentions plus prediction
    files from the base model, the SRL model, and hypernym extraction
    (raw, verified, and logits), then evaluates against the FIGER test set via
    ``stackexp.train_stacking``.

    The ``use_vr`` / ``use_hr`` flags toggle the verified-hypernym and
    hypernym features passed through to sample loading and training.
    """
    use_vr = True
    # use_vr = False
    use_hr = True
    # use_hr = False

    # Training-side input files (manually labeled dev mentions + predictions).
    train_samples_file = os.path.join(config.DATA_DIR, 'figer/figer-dev-man-labeled.txt')
    train_base_preds_file = os.path.join(config.DATA_DIR, 'figer/ftbfet7069-results-figer-wikival.txt')
    train_srl_preds_file = os.path.join(config.DATA_DIR, 'figer/figer-train-srl0-preds.txt')
    # train_srl_preds_file = os.path.join(config.DATA_DIR, 'figer/figer-train-srl-preds.txt')
    train_hyp_file = os.path.join(config.DATA_DIR, 'figer/figer-wikival-all-fmm-hypext.txt')
    train_verif_hyp_file = os.path.join(config.DATA_DIR, 'figer/figer-wikival-all-fmm-hypext-bbntyped.txt')
    train_hypext_logits_file = os.path.join(config.DATA_DIR,
                                            'figer/bert-tmhypextbbn-results-all-figer-wikival.txt')
    train_pred_file_tup = (train_base_preds_file, train_srl_preds_file, train_hyp_file, train_verif_hyp_file,
                           train_hypext_logits_file)
    gres = expdata.ResData(config.FIGER_FILES['type-vocab'], None)
    type_infer_train = TypeInfer(gres.type_vocab, gres.type_id_dict, single_type_path=False)
    # Only the child-type vectors are used on the training side.
    _, _, child_type_vecs_train = fetutils.build_hierarchy_vecs(gres.type_vocab, gres.type_id_dict)
    train_samples = expdata.load_labeled_samples(gres.type_id_dict, child_type_vecs_train,
                                                 train_samples_file, train_pred_file_tup, use_vr, use_hr)
    print(len(train_samples), 'labeled samples')

    # Test-side input files (FIGER test mentions + matching prediction files).
    test_mentions_file = config.FIGER_FILES['test-mentions']
    test_sents_file = config.FIGER_FILES['test-sents']
    test_base_preds_file = os.path.join(config.DATA_DIR, 'figer/ftbfet7069-results-figer.txt')
    test_srl_preds_file = os.path.join(config.DATA_DIR, 'figer/figer-test-srl0-preds.txt')
    # test_srl_preds_file = os.path.join(config.DATA_DIR, 'figer/figer-test-srl-preds.txt')
    test_hyp_preds_file = os.path.join(config.DATA_DIR, 'figer/figer-test-all-fmm-hypext.txt')
    test_verif_hypext_file = os.path.join(config.DATA_DIR, 'figer/figer-test-all-fmm-hypext-bbntyped.txt')
    test_hypext_logits_file = os.path.join(config.DATA_DIR, 'figer/bert-tmhypextbbn-results-all-figer-test.txt')
    test_file_tup = (test_mentions_file, test_sents_file, test_base_preds_file, test_srl_preds_file,
                     test_hyp_preds_file, test_verif_hypext_file, test_hypext_logits_file)
    # NOTE(review): gres_test is built from the same inputs as gres, and the
    # hierarchy vecs below are derived from gres (not gres_test) — presumably
    # equivalent since both load the same type vocab; confirm.
    gres_test = expdata.ResData(config.FIGER_FILES['type-vocab'], None)

    # Only the child-type vectors are needed here; the l1 indices/vector
    # returned alongside were unused.
    _, _, child_type_vecs_test = fetutils.build_hierarchy_vecs(
        gres.type_vocab, gres.type_id_dict)
    type_infer_test = TypeInfer(gres_test.type_vocab, gres_test.type_id_dict, single_type_path=False)
    test_samples, test_true_labels_dict = expdata.samples_from_test(gres_test, child_type_vecs_test, test_file_tup)

    stackexp.train_stacking(device, gres, use_vr, use_hr, type_infer_train, train_samples,
                            gres_test, type_infer_test, test_samples, test_true_labels_dict)
Example no. 3
0
def __train3():
    """Train an SRL-based FET model on FIGER with margin-loss settings.

    Like ``__train`` but uses symmetric pos/neg margins, a larger batch size,
    a learning-rate schedule, and saves checkpoints under the ``srl3-`` prefix.
    A manually labeled wiki validation file tuple is prepared but currently
    disabled (``manual_val_file_tup = None``).

    Uses module-level ``config``, ``args``, ``str_today`` and ``device``.
    """
    # Log file name encodes script name, run index, date and machine for traceability.
    log_file = os.path.join(config.LOG_DIR, '{}-{}-{}-{}.log'.format(os.path.splitext(
        os.path.basename(__file__))[0], args.idx, str_today, config.MACHINE_NAME))
    # log_file = None
    init_universal_logging(log_file, mode='a', to_stdout=True)
    logging.info('logging to {}'.format(log_file))

    # Same margin on the positive and negative side of the ranking loss.
    margin = 1.0
    train_config = srlfetexp.TrainConfig(
        pos_margin=margin, neg_margin=margin, neg_scale=1.0, batch_size=128, schedule_lr=True, n_steps=70000)

    # Model hyperparameters.
    lstm_dim = 250
    mlp_hidden_dim = 500
    type_embed_dim = 500
    word_vecs_file = config.WIKI_FETEL_WORDVEC_FILE

    dataset = 'figer'
    # dataset = 'bbn'
    datafiles = config.FIGER_FILES if dataset == 'figer' else config.BBN_FILES

    data_prefix = datafiles['srl-train-data-prefix']
    dev_data_pkl = data_prefix + '-dev.pkl'
    train_data_pkl = data_prefix + '-train.pkl'

    test_file_tup = (datafiles['test-mentions'], datafiles['test-sents'],
                     datafiles['test-sents-dep'], datafiles['test-srl'])
    # Direct comparison instead of `False if ... else True`: only figer
    # allows multiple type paths per mention.
    single_type_path = dataset != 'figer'

    # output_model_file = None
    save_model_file_prefix = os.path.join(config.DATA_DIR, 'models/srl3-{}'.format(dataset))

    # Optional manually-labeled validation set (currently disabled).
    val_mentions_file = os.path.join(config.DATA_DIR, 'figer/wiki-valcands-figer-mentions.json')
    val_sents_file = os.path.join(config.DATA_DIR, 'figer/wiki-valcands-figer-sents.json')
    val_srl_file = os.path.join(config.DATA_DIR, 'figer/wiki-valcands-figer-srl.txt')
    val_dep_file = os.path.join(config.DATA_DIR, 'figer/wiki-valcands-figer-tok-dep.txt')
    val_manual_label_file = os.path.join(config.DATA_DIR, 'figer/figer-dev-man-labeled.txt')
    # manual_val_file_tup = (val_mentions_file, val_sents_file, val_dep_file, val_srl_file, val_manual_label_file)
    manual_val_file_tup = None

    gres = expdata.ResData(datafiles['type-vocab'], word_vecs_file)
    logging.info('dataset={} {}'.format(dataset, data_prefix))
    srlfetexp.train_srlfet(
        device, gres, train_data_pkl, dev_data_pkl, manual_val_file_tup, test_file_tup, lstm_dim, mlp_hidden_dim,
        type_embed_dim, train_config, single_type_path, save_model_file_prefix=save_model_file_prefix)
Example no. 4
0
def __eval():
    """Run a trained SRL-FET model on the test and train subsets and write predictions.

    Resolves the FIGER/BBN input files for each subset — for FIGER, the
    'train' subset uses the wiki validation-candidate files instead of the
    regular training files — and calls ``srlfetexp.eval_trained`` once per
    subset, writing per-subset prediction files under DATA_DIR.
    """
    dataset = 'figer'
    # dataset = 'bbn'
    datafiles = config.FIGER_FILES if dataset == 'figer' else config.BBN_FILES
    word_vecs_file = config.WIKI_FETEL_WORDVEC_FILE
    model_file_prefix = os.path.join(config.DATA_DIR,
                                     'models/srl-{}'.format(dataset))
    # Direct comparison instead of `True if ... else False`; only BBN uses a
    # single type path per mention.
    single_type_path = dataset == 'bbn'

    # Loading resources (type vocab + word vectors) does not depend on the
    # subset, so do it once instead of per loop iteration.
    gres = expdata.ResData(datafiles['type-vocab'], word_vecs_file)

    # sub_set = 'test'
    # sub_set = 'train'
    sub_sets = ['test', 'train']
    for sub_set in sub_sets:
        if sub_set == 'test':
            mentions_file = datafiles['test-mentions']
            sents_file = datafiles['test-sents']
            srl_file = datafiles['test-srl']
            dep_file = datafiles['test-sents-dep']
        else:
            if dataset == 'bbn':
                mentions_file = datafiles['train-mentions']
                sents_file = datafiles['train-sents']
                srl_file = datafiles['train-srl']
                dep_file = datafiles['train-sents-dep']
            else:
                # FIGER 'train' evaluation runs on the wiki validation candidates.
                mentions_file = os.path.join(
                    config.DATA_DIR, 'figer/wiki-valcands-figer-mentions.json')
                sents_file = os.path.join(
                    config.DATA_DIR, 'figer/wiki-valcands-figer-sents.json')
                srl_file = os.path.join(config.DATA_DIR,
                                        'figer/wiki-valcands-figer-srl.txt')
                dep_file = os.path.join(
                    config.DATA_DIR, 'figer/wiki-valcands-figer-tok-dep.txt')

        output_preds_file = os.path.join(
            config.DATA_DIR,
            '{}/{}-{}-srl-preds.txt'.format(dataset, dataset, sub_set))

        srlfetexp.eval_trained(device, gres, model_file_prefix, mentions_file,
                               sents_file, srl_file, dep_file,
                               single_type_path, output_preds_file)