Code example #1
File: lm.py Project: rekriz11/sockeye-recipes
    def generate_data_file(self, file='googlebooks-eng-all-3gram-20120701-en'):
        grams = defaultdict(list)
        gram_path = get_path('../text_simplification_data/wiki/' + file)
        idx = 0
        pre_time = time.time()
        for line in open(gram_path, encoding='utf-8'):
            items = line.split('\t')
            ngram = '|'.join(
                [word.split('_')[0] for word in items[0].split(' ')])
            # year = int(items[1])
            cnt = int(items[2])
            grams[ngram].append(cnt)
            idx += 1
            if idx % 1000000 == 0:
                cur_time = time.time()
                print('processed %s. in %s' % (idx, cur_time - pre_time))
                pre_time = cur_time

        output_file = get_path('../text_simplification_data/wiki/' + file +
                               '.processed')
        f = open(output_file, 'w', encoding='utf-8')
        outputs = []
        # Average the per-year counts for each n-gram and write in batches;
        # the '\n' appended after each full batch keeps consecutive batches
        # from running together on a single line.
        for ngram in grams:
            output = '\t'.join([ngram, str(np.mean(grams[ngram]))])
            outputs.append(output)
            if len(outputs) == 100000:
                f.write('\n'.join(outputs) + '\n')
                f.flush()
                outputs = []
        f.write('\n'.join(outputs))
        f.flush()
        f.close()
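
The parser assumes the tab-separated layout of the Google Books v2 n-gram files (ngram TAB year TAB match_count TAB volume_count), where each token carries a _POS suffix. A minimal sketch of the parsing step, using a made-up input line:

line = 'burnt_VERB down in\t1850\t7\t5'  # made-up sample in the v2 layout
items = line.split('\t')
# strip the _POS suffix from each token and join the words with '|'
ngram = '|'.join(word.split('_')[0] for word in items[0].split(' '))
cnt = int(items[2])
print(ngram, cnt)  # burnt|down|in 7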
Code example #2
            model_config = WikiDressLargeNewDefault()
            ckpt = get_ckpt(model_config.modeldir, model_config.logdir)

            if ckpt:
                vconfig = WikiDressLargeNewTestDefault()
                if best_sari is None:
                    best_sari = get_best_sari(vconfig.resultdir)

                sari_point = eval(vconfig, ckpt)

                # Try different max_cand_rules
                if args.memory is not None and 'rule' in args.memory:
                    for rcand in [15, 30, 50]:
                        vconfig.max_cand_rules = rcand
                        vconfig.resultdir = get_path(
                            '../' + vconfig.output_folder +
                            '/result/eightref_val_cand' + str(rcand),
                            vconfig.environment)
                        eval(vconfig, ckpt)

                eval(WikiDressLargeNewTestDefault(), ckpt)
                print(
                    '=====================Current Best SARI:%s====================='
                    % best_sari)
                if float(sari_point) < best_sari:
                    remove(ckpt + '.index')
                    remove(ckpt + '.meta')
                    remove(ckpt + '.data-00000-of-00001')
                    print('remove ckpt:%s' % ckpt)
                else:
                    for file in listdir(model_config.modeldir):
                        step = ckpt[ckpt.rindex('model.ckpt-') +
                                    len('model.ckpt-'):-1]
                        if step not in file:
                            remove(model_config.modeldir + file)
                    print('Get Best Model, remove ckpt except:%s.' % ckpt)
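
Note that eval here is the project's own evaluation helper, which shadows Python's builtin. The three remove calls delete the TF1 checkpoint triple for a single step; a minimal standalone sketch, with a hypothetical checkpoint prefix:

from os import remove

ckpt = 'model.ckpt-1000'  # hypothetical checkpoint prefix
for suffix in ('.index', '.meta', '.data-00000-of-00001'):
    remove(ckpt + suffix)  # the three files together form one checkpoint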
Code example #3
from model.model_config import BaseConfig, get_path
from data_generator.data import EvalData
from operator import itemgetter

mapper = {}
for line in open(get_path('../wsd_data/medline/abbr.txt')):
    items = line.strip().split('\t')
    cnt = int(items[1])
    abbr_items = items[0].split('|')
    if abbr_items[1] not in mapper:
        mapper[abbr_items[1]] = []
    mapper[abbr_items[1]].append((abbr_items[2], cnt))
for abbr in mapper:
    mapper[abbr].sort(key=itemgetter(-1), reverse=True)

model_config = BaseConfig()
data = EvalData(model_config)

sample = data.get_sample()
correct_cnt, correct_cnt2, correct_cnt3, correct_cnt4, correct_cnt5 = 0.0, 0.0, 0.0, 0.0, 0.0
total_cnt = 0.0
while sample is not None:
    ts = sample['targets']

    for t in ts:
        if t[1] == 0 or t[2] == 0:
            continue
        cascade_add = False
        try:
            pred = data.sense2id[data.id2abbr[t[1]] + '|' +
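
The snippet is cut off mid-expression in the source, but the structure it builds is clear: mapper sends each abbreviation to its (sense, count) pairs sorted by count, so the first entry is the most-frequent-sense baseline. A toy sketch with made-up data:

from operator import itemgetter

mapper = {'ra': [('right atrium', 45), ('rheumatoid arthritis', 120)]}  # toy counts
for abbr in mapper:
    mapper[abbr].sort(key=itemgetter(-1), reverse=True)
print(mapper['ra'][0][0])  # rheumatoid arthritis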
Code example #4
File: dataset_helper.py Project: Astroneko404/wsd
    def __init__(self, environment):

        # DataSet Corpus files
        if environment == 'luoz3_x1':
            mimic_base_folder = "/home/mengr/Project/wsd/wsd_data/mimic/data"
            self.mimic_train_txt = mimic_base_folder+"/train"
            self.mimic_eval_txt = mimic_base_folder+"/eval"
        else:
            self.mimic_train_txt = "/exp_data/wsd_data/mimic/train"
            self.mimic_eval_txt = "/exp_data/wsd_data/mimic/eval"
            # self.mimic_train_txt = get_path('../wsd_data/mimic/train', env=environment)
            # self.mimic_eval_txt = get_path('../wsd_data/mimic/eval', env=environment)
            # # mimic v1 (deprecated)
            # mimic_train_txt = '/home/zhaos5/projs/wsd/wsd_data/mimic/train'
            # mimic_eval_txt = '/home/zhaos5/projs/wsd/wsd_data/mimic/eval'

        self.share_txt = get_path('../wsd_data/share/processed/share_all_processed.txt', env=environment)
        self.msh_txt = get_path('../wsd_data/msh/msh_processed/msh_processed.txt', env=environment)
        self.umn_txt = get_path('../wsd_data/umn/umn_processed/umn_processed.txt', env=environment)
        self.upmc_example_txt = get_path('../wsd_data/upmc/example/processed/upmc_example_processed.txt', env=environment)
        self.upmc_ab_train_txt = get_path('../wsd_data/upmc/AB/processed/upmc_ab_train.txt', env=environment)
        self.upmc_ab_test_txt = get_path('../wsd_data/upmc/AB/processed/upmc_ab_test.txt', env=environment)
        self.upmc_all_no_mark_txt = get_path('../wsd_data/upmc/batch1_4/processed/train_no_mark.txt', env=environment)
        self.upmc_ad_train_txt = get_path('../wsd_data/upmc/AD/processed/upmc_train.txt', env=environment)
        self.upmc_ad_test_txt = get_path('../wsd_data/upmc/AD/processed/upmc_test.txt', env=environment)

        # paths for processed files
        self.mimic_train_folder = get_path('../wsd_data/mimic/processed/train/', env=environment)
        self.mimic_test_folder = get_path('../wsd_data/mimic/processed/test/', env=environment)
        self.share_test_folder = get_path('../wsd_data/share/processed/test/', env=environment)
        self.msh_test_folder = get_path('../wsd_data/msh/msh_processed/test/', env=environment)
        self.umn_test_folder = get_path('../wsd_data/umn/umn_processed/test/', env=environment)
        self.upmc_example_folder = get_path('../wsd_data/upmc/example/processed/test/', env=environment)
        self.upmc_ab_train_folder = get_path('../wsd_data/upmc/AB/processed/train/', env=environment)
        self.upmc_ab_test_folder = get_path('../wsd_data/upmc/AB/processed/test/', env=environment)
        self.upmc_ad_train_folder = get_path('../wsd_data/upmc/AD/processed/train/', env=environment)
        self.upmc_ad_test_folder = get_path('../wsd_data/upmc/AD/processed/test/', env=environment)
        self.upmc_all_no_mark_folder = get_path('../wsd_data/upmc/batch1_4/processed/', env=environment)

        # path to sense inventory
        self.sense_inventory_json = get_path('../wsd_data/sense_inventory/final_cleaned_sense_inventory_with_testsets.json', env=environment)
        self.sense_inventory_pkl = get_path('../wsd_data/sense_inventory/final_cleaned_sense_inventory_with_testsets.pkl', env=environment)
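
get_path (imported from model.model_config) resolves a repository-relative path for the given environment; its real implementation is not shown here. A minimal sketch of the behavior these calls assume, with hypothetical base directories:

def get_path(rel_path, env=None):
    # hypothetical mapping; the real one lives in model.model_config
    base = '/exp_data' if env == 'some_cluster' else '/home/user/projs'
    if rel_path.startswith('../'):
        rel_path = rel_path[len('../'):]
    return base + '/' + rel_path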
Code example #5
File: eval.py Project: zchenack/text_simplification
        while True:
            model_config = WikiDressLargeDefault()
            ckpt = get_ckpt(model_config.modeldir, model_config.logdir)

            if ckpt:
                vconfig = WikiDressLargeEvalDefault()
                if best_sari is None:
                    best_sari = get_best_sari(vconfig.resultdir)

                sari_point = eval(vconfig, ckpt)
                eval(WikiDressLargeTestDefault(), ckpt)
                if args.memory is not None and 'rule' in args.memory:
                    for rcand in [15, 30, 50]:
                        vconfig.max_cand_rules = rcand
                        vconfig.resultdir = get_path(
                            '../' + vconfig.output_folder +
                            '/result/eightref_val_cand' + str(rcand),
                            vconfig.environment)
                        eval(vconfig, ckpt)

                    tconfig = WikiDressLargeTestDefault()
                    for rcand in [15, 30, 50]:
                        tconfig.max_cand_rules = rcand
                        tconfig.resultdir = get_path(
                            '../' + tconfig.output_folder +
                            '/result/eightref_test_cand' + str(rcand),
                            tconfig.environment)
                        eval(tconfig, ckpt)
                print(
                    '=====================Current Best SARI:%s====================='
                    % best_sari)
                # if float(sari_point) < best_sari:
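
The keep-or-discard policy driving these eval loops compares each checkpoint's validation SARI against the best seen so far; the commented-out line marks where the discard branch begins (example #8 shows it in full). A toy illustration with made-up scores:

best_sari = 37.2     # made-up numbers
sari_point = '36.8'  # eval() returns the new checkpoint's SARI
if float(sari_point) < best_sari:
    print('discard this checkpoint')
else:
    print('keep it and prune the others')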
Code example #6
File: lm.py Project: rekriz11/sockeye-recipes
"""
Deprecated: Will train our new language model in language_model folder.
"""
import tensorflow as tf
import numpy as np
from collections import defaultdict
from google.protobuf import text_format
import time

from model.model_config import get_path

MAX_WORD_LEN = 50
BASE_PATH = get_path('../text_simplification_data/lm1b/')


class GoogleLM:
    """Get from https://github.com/tensorflow/models/tree/master/research/lm_1b."""
    def __init__(self, batch_size=32):
        self.vocab = CharsVocabulary(BASE_PATH + 'vocab-2016-09-10.txt',
                                     MAX_WORD_LEN)
        self.sess, self.t = self.load_model()
        print('Init GoogleLM Session.')

    def get_batch_weight(self, sentences, num_steps):
        inputs, targets, weights, char_inputs = self.get_batch_data(
            sentences, num_steps)
        log_perps = []
        for inp, target, weight, char_input in zip(inputs, targets, weights,
                                                   char_inputs):
            input_dict = {
                self.t['inputs_in']: inp,
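
The snippet is truncated mid-feed-dict in the source; CharsVocabulary and the self.t tensor handles come from the lm_1b code referenced in the class docstring. The quantity get_batch_weight accumulates is the sentence-level log perplexity, i.e. the mean negative log-probability of the words; a toy sketch with made-up probabilities:

import numpy as np

word_probs = [0.1, 0.05, 0.2]  # made-up P(w_t | w_<t) from an LM
log_perp = -np.mean(np.log(word_probs))
print(log_perp)  # ~2.30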
Code example #7
from model.model_config import get_path

abbrs = set()

for line in open(get_path('../wsd_data/medline/abbr_common.txt')):
    items = line.strip().split('\t')
    w = items[0]
    ws = w.split('|')
    abbr = ws[0]
    abbrs.add(abbr)

for line in open(get_path('../wsd_data/medline/abbr_rare.txt')):
    items = line.strip().split('\t')
    w = items[0]
    ws = w.split('|')
    abbr = ws[0]
    abbrs.add(abbr)

f = open(get_path('../wsd_data/medline/abbr_all.txt'), 'w')
for abbr in abbrs:
    f.write('\'' + abbr + '_\'')
    f.write('\n')
f.close()
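
The output is one quoted, underscore-suffixed abbreviation per line, merging the common and rare lists. A toy run (set iteration order may vary):

abbrs = {'ra', 'pt'}  # toy abbreviations
for abbr in abbrs:
    print('\'' + abbr + '_\'')
# 'ra_'
# 'pt_'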
Code example #8
def train(model_config=None):
    model_config = (DefaultConfig() if model_config is None else model_config)

    if model_config.fetch_mode == 'tf_example_dataset':
        data = TfExampleTrainDataset(model_config)
    else:
        data = TrainData(model_config)

    if model_config.framework == 'transformer':
        graph = TransformerGraph(data, True, model_config)
    elif model_config.framework == 'seq2seq':
        graph = Seq2SeqGraph(data, True, model_config)
    else:
        raise NotImplementedError('Unknown Framework.')
    graph.create_model_multigpu()

    ckpt_path = None
    if model_config.warm_start:
        ckpt_path = model_config.warm_start
        var_list = slim.get_variables_to_restore()
    if ckpt_path is not None:
        # Handling missing vars by ourselves
        available_vars = {}
        reader = tf.train.NewCheckpointReader(ckpt_path)
        var_dict = {var.op.name: var for var in var_list}
        for var in var_dict:
            if 'global_step' in var and 'optim' not in model_config.warm_config:
                print('Ignore var:', var)
                continue
            if 'optimization' in var and 'optim' not in model_config.warm_config:
                print('Ignore var:', var)
                continue
            if reader.has_tensor(var):
                var_ckpt = reader.get_tensor(var)
                var_cur = var_dict[var]
                if any([
                        var_cur.shape[i] != var_ckpt.shape[i]
                        for i in range(len(var_ckpt.shape))
                ]):
                    print('Variable missing due to shape.', var)
                else:
                    available_vars[var] = var_dict[var]
            else:
                print('Variable missing:', var)

        partial_restore_ckpt = slim.assign_from_checkpoint_fn(
            ckpt_path,
            available_vars,
            ignore_missing_vars=False,
            reshape_variables=False)

    if model_config.bert_mode:
        bert_restore_ckpt = utils.restore_bert(ckpt=model_config.bert_ckpt)

    if 'direct' in model_config.memory:
        bert_direct_restore_ckpt = utils.restore_bert(
            ckpt=model_config.bert_ckpt, model='direct/')

    sess = tf.train.MonitoredTrainingSession(
        checkpoint_dir=model_config.logdir,
        save_checkpoint_secs=model_config.save_model_secs,
        config=session.get_session_config(model_config),
        hooks=[
            tf.train.CheckpointSaverHook(
                model_config.logdir,
                save_secs=model_config.save_model_secs,
                saver=graph.saver)
        ],
        save_summaries_steps=None,
        save_summaries_secs=None,  # Disable tf.summary
    )

    if checkpoint.is_fresh_run(
            model_config.logdir) and 'init' in model_config.bert_mode:
        if model_config.bert_mode:
            if 'direct' in model_config.memory:
                bert_direct_restore_ckpt(sess)
            # else:
            bert_restore_ckpt(sess)
            print('BERT init')

    if checkpoint.is_fresh_run(model_config.logdir):
        if ckpt_path is not None:
            partial_restore_ckpt(sess)
            print('Restore from %s' % ckpt_path)

    perplexitys = []
    start_time = datetime.now()

    # Initialize tf example dataset reader
    if model_config.fetch_mode == 'tf_example_dataset':
        if model_config.dmode == 'listalter':
            assert type(data.training_init_op) == list
            for init_op in data.training_init_op:
                sess.run(init_op)
        else:
            sess.run(data.training_init_op)
            print('Init dataset iterator.')
            if model_config.dmode == 'alter':
                sess.run(data.training_init_op2)
                print('Init dataset2 iterator.')

    # with tf.contrib.tfprof.ProfileContext('/zfs1/hdaqing/saz31/text_simplification_0924/bertbaseal2_ls/profile') as pctx:
    while True:
        fetches = [
            graph.train_op, graph.loss, graph.global_step, graph.perplexity,
            graph.ops, graph.increment_global_step, graph.loss_style
        ]
        if model_config.fetch_mode:
            _, loss, step, perplexity, _, _, loss_style = sess.run(fetches)
        else:
            input_feed = get_graph_train_data(data, graph.objs, model_config)
            # fetches has seven elements, so unpack all seven here as well
            _, loss, step, perplexity, _, _, loss_style = sess.run(
                fetches, input_feed)
        perplexitys.append(perplexity)

        if step % model_config.model_print_freq == 0:
            end_time = datetime.now()
            time_span = end_time - start_time
            start_time = end_time
            print('Perplexity:\t%f at step %d using %s.' %
                  (perplexity, step, time_span))
            if 'pred' in model_config.tune_mode:
                print('Loss:%s\tLoss_style:%s' % (loss, loss_style))
            perplexitys.clear()
            if step / model_config.model_print_freq == 1:
                print_cpu_usage()
                print_cpu_memory()
                print_gpu_memory()

        #if step % (100 * model_config.model_print_freq) == 0:
        #    graph.saver.save(sess, join(model_config.logdir, 'bk.ckpt-', step))

        if model_config.model_eval_freq > 0 and step % model_config.model_eval_freq == 0:
            if args.mode == 'dress':
                from model.model_config import WikiDressLargeDefault, WikiDressLargeEvalDefault, \
                    WikiDressLargeTestDefault
                model_config = WikiDressLargeDefault()
                ckpt = get_ckpt(model_config.modeldir, model_config.logdir)

                vconfig = WikiDressLargeEvalDefault()
                best_sari = get_best_sari(vconfig.resultdir)
                sari_point = eval(vconfig, ckpt)
                eval(WikiDressLargeTestDefault(), ckpt)
                if args.memory is not None and 'rule' in args.memory:
                    for rcand in [15, 30, 50]:
                        vconfig.max_cand_rules = rcand
                        vconfig.resultdir = get_path(
                            '../' + vconfig.output_folder +
                            '/result/eightref_val_cand' + str(rcand),
                            vconfig.environment)
                        eval(vconfig, ckpt)
                print(
                    '=====================Current Best SARI:%s====================='
                    % best_sari)
                if float(sari_point) < best_sari:
                    remove(ckpt + '.index')
                    remove(ckpt + '.meta')
                    remove(ckpt + '.data-00000-of-00001')
                    print('remove ckpt:%s' % ckpt)
                else:
                    for file in listdir(model_config.modeldir):
                        step = ckpt[ckpt.rindex('model.ckpt-') +
                                    len('model.ckpt-'):-1]
                        if step not in file:
                            remove(model_config.modeldir + file)
                    print('Get Best Model, remove ckpt except:%s.' % ckpt)
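
The partial warm start above inspects the checkpoint with tf.train.NewCheckpointReader and restores only variables whose names and shapes match the live graph. A minimal standalone sketch of that inspection step, with a hypothetical checkpoint path (TF1 API):

import tensorflow as tf

reader = tf.train.NewCheckpointReader('model.ckpt-1000')  # hypothetical path
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)  # compare against the live graph's variables

The training_init_op calls reflect the TF1 initializable-iterator pattern: the iterator must be initialized in the session before get_next can be fetched. A minimal sketch:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
it = ds.make_initializable_iterator()
nxt = it.get_next()
with tf.Session() as sess:
    sess.run(it.initializer)  # analogous to data.training_init_op
    print(sess.run(nxt))      # prints 1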