예제 #1
0
 def __init__(self, model_config):
     """Build the abstract/keyword vocabularies and load the dataset.

     Args:
       model_config: configuration object; its ``path_abstr_voc`` and
         ``path_kword_voc`` attributes are the vocabulary files handed to
         ``Vocab``.

     ``populate_data`` is expected to fill ``self.data``; ``self.size`` is
     then set to its length.
     """
     self.model_config = model_config
     config = self.model_config
     abstr_path = config.path_abstr_voc
     self.voc_abstr = Vocab(config, vocab_path=abstr_path)
     kword_path = config.path_kword_voc
     self.voc_kword = Vocab(config, vocab_path=kword_path)
     self.populate_data()
     self.size = len(self.data)
예제 #2
0
 def __init__(self, model_config, is_train):
     """Store config and mode, build vocabularies and set up hyper-parameters.

     Args:
       model_config: configuration object providing ``path_abstr_voc`` and
         ``path_kword_voc`` for the two ``Vocab`` instances.
       is_train: training-mode flag, kept as ``self.is_train``.
     """
     self.model_config, self.is_train = model_config, is_train
     config = self.model_config
     self.voc_abstr = Vocab(config, vocab_path=config.path_abstr_voc)
     self.voc_kword = Vocab(config, vocab_path=config.path_kword_voc)
     # Start from the base transformer hparams; presumably customised by
     # setup_hparams (TODO confirm).
     self.hparams = transformer.transformer_base()
     self.setup_hparams()
예제 #3
0
    def __init__(self, model_config):
        """Load the validation set, its vocabularies and reference files.

        Args:
          model_config: configuration object. Vocabulary paths come from
            ``vocab_simple``/``vocab_complex``/``vocab_all`` or, when
            ``subword_vocab_size > 0``, their ``subword_vocab_*`` variants.
            ``tie_embedding`` chooses separate (``'none'``/``'dec_out'``) or
            shared (``'all'``/``'enc_dec'``) vocabularies. ``replace_ner``
            enables loading mappers from ``val_mapper``. A ``'rule'`` entry in
            ``memory`` enables PPDB rule loading.

        Raises:
          ValueError: if ``tie_embedding`` is not one of the values above.
        """
        self.model_config = model_config

        vocab_simple_path = self.model_config.vocab_simple
        vocab_complex_path = self.model_config.vocab_complex
        vocab_all_path = self.model_config.vocab_all
        if self.model_config.subword_vocab_size > 0:
            vocab_simple_path = self.model_config.subword_vocab_simple
            vocab_complex_path = self.model_config.subword_vocab_complex
            vocab_all_path = self.model_config.subword_vocab_all

        tie_embedding = self.model_config.tie_embedding
        if tie_embedding in ('none', 'dec_out'):
            self.vocab_simple = Vocab(model_config, vocab_simple_path)
            self.vocab_complex = Vocab(model_config, vocab_complex_path)
        elif tie_embedding in ('all', 'enc_dec'):
            # Both sides share a single vocabulary file.
            self.vocab_simple = Vocab(model_config, vocab_all_path)
            self.vocab_complex = Vocab(model_config, vocab_all_path)
        else:
            # Fail fast: otherwise vocab_simple/vocab_complex stay unset and
            # populate_data below dies with an unhelpful AttributeError.
            raise ValueError('Unsupported tie_embedding: %s' % tie_embedding)

        # Populate basic complex simple pairs
        self.data = self.populate_data(self.vocab_complex, self.vocab_simple,
                                       True)
        self.data_complex_raw_lines = self.populate_data_rawfile(
            self.model_config.val_dataset_complex_rawlines_file)
        # Populate simple references: each file name is the references
        # prefix followed by the reference index (0 .. num_refs - 1).
        self.data_references_raw_lines = []
        for i in range(self.model_config.num_refs):
            ref_tmp_rawlines = self.populate_data_rawfile(
                self.model_config.val_dataset_simple_folder +
                self.model_config.val_dataset_simple_rawlines_file_references +
                str(i))
            self.data_references_raw_lines.append(ref_tmp_rawlines)

        if self.model_config.replace_ner:
            self.mapper = load_mappers(self.model_config.val_mapper,
                                       self.model_config.lower_case)
            # Pad with empty mappers so there is one per data entry.
            while len(self.mapper) < len(self.data):
                self.mapper.append({})

        assert len(self.data_complex_raw_lines) == len(self.data)
        if self.model_config.replace_ner:
            # self.mapper only exists when NER replacement is enabled.
            assert len(self.mapper) == len(self.data)
        for ref_lines in self.data_references_raw_lines:
            assert len(ref_lines) == len(self.data)
        print(
            'Use Val Dataset: \n Simple\t %s. \n Complex\t %s. \n Size\t %d' %
            (self.model_config.val_dataset_simple_folder +
             self.model_config.val_dataset_simple_file,
             self.model_config.val_dataset_complex, len(self.data)))

        if 'rule' in self.model_config.memory:
            self.vocab_rule = Rule(model_config, self.model_config.vocab_rules)
            self.rules = self.populate_rules(
                self.model_config.val_dataset_complex_ppdb, self.vocab_rule)
            print('Populate Rule with size:%s' %
                  self.vocab_rule.get_rule_size())
예제 #4
0
    def __init__(self, model_config):
        """Load the training set, its vocabularies and (optionally) PPDB rules.

        Args:
          model_config: configuration object. Vocabulary paths come from
            ``vocab_simple``/``vocab_complex``/``vocab_all`` or, when
            ``subword_vocab_size > 0``, their ``subword_vocab_*`` variants.
            ``tie_embedding`` chooses separate (``'none'``/``'dec_out'``) or
            shared (``'all'``/``'enc_dec'``) vocabularies. ``it_train``
            selects ``get_data_sample_it`` instead of ``populate_data``.
            A ``'rule'`` entry in ``memory`` or ``rl_configs`` enables PPDB
            rule loading.

        Raises:
          ValueError: if ``tie_embedding`` is not one of the values above.
        """
        self.model_config = model_config

        vocab_simple_path = self.model_config.vocab_simple
        vocab_complex_path = self.model_config.vocab_complex
        vocab_all_path = self.model_config.vocab_all
        if self.model_config.subword_vocab_size > 0:
            vocab_simple_path = self.model_config.subword_vocab_simple
            vocab_complex_path = self.model_config.subword_vocab_complex
            vocab_all_path = self.model_config.subword_vocab_all

        data_simple_path = self.model_config.train_dataset_simple
        data_complex_path = self.model_config.train_dataset_complex

        tie_embedding = self.model_config.tie_embedding
        if tie_embedding in ('none', 'dec_out'):
            self.vocab_simple = Vocab(model_config, vocab_simple_path)
            self.vocab_complex = Vocab(model_config, vocab_complex_path)
        elif tie_embedding in ('all', 'enc_dec'):
            # Both sides share a single vocabulary file.
            self.vocab_simple = Vocab(model_config, vocab_all_path)
            self.vocab_complex = Vocab(model_config, vocab_all_path)
        else:
            # Fail fast: otherwise vocab_simple/vocab_complex stay unset and
            # the data loading below dies with an unhelpful AttributeError.
            raise ValueError('Unsupported tie_embedding: %s' % tie_embedding)

        self.size = self.get_size(data_complex_path)
        if self.model_config.use_dataset2:
            self.size2 = self.get_size(
                self.model_config.train_dataset_complex2)
        # Populate basic complex simple pairs
        if not self.model_config.it_train:
            self.data = self.populate_data(data_complex_path, data_simple_path,
                                           self.vocab_complex,
                                           self.vocab_simple, True)
        else:
            # NOTE(review): argument order here is (simple, complex), the
            # reverse of populate_data above; confirm against the signature
            # of get_data_sample_it.
            self.data_it = self.get_data_sample_it(data_simple_path,
                                                   data_complex_path)

        print(
            'Use Train Dataset: \n Simple\t %s. \n Complex\t %s. \n Size\t %d.'
            % (data_simple_path, data_complex_path, self.size))

        if ('rule' in self.model_config.memory
                or 'rule' in self.model_config.rl_configs):
            self.vocab_rule = Rule(model_config, self.model_config.vocab_rules)
            self.rules_target, self.rules_align = self.populate_rules(
                self.model_config.train_dataset_complex_ppdb, self.vocab_rule)
            # One rule entry per training example. (Former debug prints that
            # indexed a hard-coded row 94206 were removed: they raised
            # IndexError on any smaller dataset before these checks ran.)
            assert len(self.rules_align) == self.size
            assert len(self.rules_target) == self.size
            print('Populate Rule with size:%s' %
                  self.vocab_rule.get_rule_size())
예제 #5
0
 def __init__(self, model_config):
     """Build vocabularies, load data and optionally keep the first 2000 rows.

     Args:
       model_config: configuration object. ``path_abstr_voc`` and
         ``path_kword_voc`` are the vocabulary files. When ``eval_mode`` is
         ``'truncate2000'``, ``self.data`` is cut to its first 2000 entries
         and exactly 2000 must remain (asserted).
     """
     self.model_config = model_config
     config = self.model_config
     self.voc_abstr = Vocab(config, vocab_path=config.path_abstr_voc)
     self.voc_kword = Vocab(config, vocab_path=config.path_kword_voc)
     self.populate_data()
     self.size = len(self.data)
     if model_config.eval_mode != 'truncate2000':
         return
     # Evaluation subset: keep only the leading 2000 examples.
     self.data = self.data[:2000]
     self.size = len(self.data)
     assert self.size == 2000
예제 #6
0
 def __init__(self, model_config):
     """Build (optionally shared) vocabularies and load the data.

     Args:
       model_config: configuration object. When ``tied_embedding`` is
         ``'enc|dec'``, both vocabularies read ``path_abstrkword_voc``.
         Otherwise they read ``path_abstr_voc`` and ``path_kword_voc``
         respectively. ``eval_mode == 'truncate2000'`` keeps at most the
         first 2000 entries of ``self.data``.
     """
     self.model_config = model_config
     config = self.model_config
     if config.tied_embedding != 'enc|dec':
         abstr_path = config.path_abstr_voc
         kword_path = config.path_kword_voc
     else:
         # Encoder and decoder share one vocabulary file.
         abstr_path = kword_path = config.path_abstrkword_voc
     self.voc_abstr = Vocab(config, vocab_path=abstr_path)
     self.voc_kword = Vocab(config, vocab_path=kword_path)
     self.populate_data()
     self.size = len(self.data)
     truncate = model_config.eval_mode == 'truncate2000'
     if truncate:
         self.data = self.data[:2000]
         self.size = len(self.data)
예제 #7
0
파일: data.py 프로젝트: Astroneko404/wsd
    def __init__(self, model_config):
        """Load abbreviation data, the context vocabulary and optional CUI data.

        Args:
          model_config: configuration object. ``voc_file`` is the context
            vocabulary path. ``populate_cui`` runs when ``extra_mode``
            contains ``'stype'`` or ``'def'``.
        """
        self.model_config = model_config
        self.populate_abbr()  # abbreviation data
        self.voc = Vocab(model_config, model_config.voc_file)  # context vocab

        extra_mode = model_config.extra_mode
        if any(mode in extra_mode for mode in ('stype', 'def')):
            self.populate_cui()
예제 #8
0
from data_generator.vocab import Vocab
from util import constant
from types import SimpleNamespace
from collections import Counter

model_config = SimpleNamespace(
    min_count=0, subword_vocab_size=50000, lower_case=True)
vocab = Vocab(model_config,
              '/zfs1/hdaqing/saz31/dataset/vocab/all30k.subvocab')

# Encode the start symbol followed by a short, padded sample sentence.
sample = constant.SYMBOL_START + ' -lrb- . #pad# #pad# #pad# #pad# #pad#'
ids = vocab.encode(sample)

print(ids)
print('=====')

# Describe each id on its own, then the whole sequence at once.
per_token = [vocab.describe([token_id]) for token_id in ids]
print(per_token)
print(vocab.describe(ids))

print([vocab.subword.all_subtoken_strings[token_id] for token_id in ids])
예제 #9
0
from model.model_config import BaseConfig
from data_generator.vocab import Vocab
import pickle
import re
from collections import Counter

# Distribution of subword-encoded lengths of the CUI extra-info texts.
c = Counter()
max_len = -1
voc = Vocab(BaseConfig(), '/Users/sanqiangzhao/git/wsd_data/mimic/subvocab')
# Compiled once and reused: strips anything of the form <...> (tags).
tag_pattern = re.compile(r'<[^>]+>')
with open('/Users/sanqiangzhao/git/wsd_data/mimic/cui_extra.pkl',
          'rb') as cui_file:
    # NOTE(review): pickle.load can execute arbitrary code; only use it on
    # trusted, locally produced files.
    cui_extra = pickle.load(cui_file)

# The first element of each entry holds the text to measure.
for info in cui_extra.values():
    text = tag_pattern.sub('', info[0])
    encoded_len = len(voc.encode(text))
    max_len = max(max_len, encoded_len)
    c.update([encoded_len])

print(max_len)
print(c.most_common())
from data_generator.vocab import Vocab
from util import constant
from types import SimpleNamespace
from collections import Counter
import os
import json

model_config = SimpleNamespace(min_count=0, subword_vocab_size=0, lower_case=True, bert_mode=['bert_token'], top_count=9999999999999)
vocab = Vocab(model_config, '/zfs1/hdaqing/saz31/dataset/vocab/bert/vocab_30k')
base = '/zfs1/hdaqing/saz31/dataset/tmp_trans/ner/features/'
# Rule pairs containing any of these characters are skipped.
SKIP_CHARS = ('-', '\'', '"', ',')
max_len = 0
# sorted(): os.listdir order is arbitrary, and which pairs get printed
# depends on processing order, so sort for reproducible output.
for file_name in sorted(os.listdir(base)):
    # JSON-lines input; JSON text is UTF-8, so do not rely on the locale.
    with open(os.path.join(base, file_name), encoding='utf-8') as f:
        for line in f:
            obj = json.loads(line)
            rule = obj['ppdb_rule'].split('\t')

            for pair in rule:
                items = pair.split('=>')
                if len(items) != 3 or any(ch in pair for ch in SKIP_CHARS):
                    continue
                # Lower-case and collapse whitespace of the middle field.
                words = vocab.encode(' '.join(items[1].lower().split()))
                # Print every pair that ties or exceeds the current maximum.
                if len(words) >= max_len:
                    print(pair)
                    max_len = len(words)

    print('cur max_len:%s with file %s' % (max_len, file_name))

print(max_len)
예제 #11
0
"""Get max lens for subvoc dataset for trans data"""
from data_generator.vocab import Vocab
from model.model_config import WikiTransTrainConfig

vocab_comp = Vocab(
    WikiTransTrainConfig(),
    '/Users/sanqiangzhao/git/text_simplification_data/vocab/comp30k.subvocab')
vocab_simp = Vocab(
    WikiTransTrainConfig(),
    '/Users/sanqiangzhao/git/text_simplification_data/vocab/simp30k.subvocab')

# Longest encoding of any line under each vocabulary.
max_l_comp, max_l_simp = 0, 0
# Use a context manager so the file is closed; read explicitly as UTF-8.
with open(
        '/Users/sanqiangzhao/git/text_simplification_data/val_0930/words_comps',
        encoding='utf-8') as words_file:
    for line in words_file:
        # NOTE(review): the same (complex-side) line is measured with both
        # vocabularies; confirm this is intended for vocab_simp.
        max_l_comp = max(max_l_comp, len(vocab_comp.encode(line)))
        max_l_simp = max(max_l_simp, len(vocab_simp.encode(line)))

print(max_l_comp)
print(max_l_simp)
예제 #12
0
 def __init__(self, model_config):
     """Load abbreviation data and the context vocabulary.

     Args:
       model_config: configuration object whose ``voc_file`` is passed to
         ``Vocab`` as the vocabulary path.
     """
     self.model_config = model_config
     self.populate_abbr()
     vocab_file = model_config.voc_file
     self.voc = Vocab(model_config, vocab_file)
예제 #13
0
"""Valid proper length"""
import json
from collections import Counter
from data_generator.vocab import Vocab
from model.model_config import DefaultConfig

PATH = '/Users/sanqiangzhao/git/kp/keyphrase_data/kp20k_cleaned/tf_data/kp20k.valid.one2many.json'

voc_abstr = Vocab(
    DefaultConfig(),
    vocab_path=
    '/Users/sanqiangzhao/git/kp/keyphrase_data/kp20k_cleaned/tf_data/abstr.subvoc'
)
voc_kword = Vocab(
    DefaultConfig(),
    vocab_path=
    '/Users/sanqiangzhao/git/kp/keyphrase_data/kp20k_cleaned/tf_data/kword.subvoc'
)

# NOTE(review): only the abstract statistics are computed below; the
# keyword-side maxima/counters are initialised but never updated yet.
max_abstr_len, max_kword_len, max_kword_cnt = 0, 0, 0
c_abstr_len, c_kword_len, c_kword_cnt = Counter(), Counter(), Counter()

# JSON-lines input; JSON text is UTF-8, and the context manager closes it.
with open(PATH, encoding='utf-8') as data_file:
    for line in data_file:
        obj = json.loads(line)
        kphrases = obj['kphrases'].split(';')
        abstr = obj['abstr']

        abstr_ids = voc_abstr.encode(abstr)
        # Show each abstract that sets a new maximum encoded length.
        if len(abstr_ids) > max_abstr_len:
            print(abstr)
        max_abstr_len = max(max_abstr_len, len(abstr_ids))

print(max_abstr_len)
예제 #14
0
from data_generator.vocab import Vocab
from util import constant
from types import SimpleNamespace
from collections import Counter

model_config = SimpleNamespace(min_count=0,
                               subword_vocab_size=50000,
                               lower_case=True)
vocab = Vocab(
    model_config,
    '/Users/zhaosanqiang916/git/text_simplification_data/wiki/voc/voc_all_sub50k.txt'
)
# Sentences whose encoding (including start/end symbols) is longer than
# this count as "long".
LONG_THRESHOLD = 200
cnt = 0  # long sentences seen so far
tcnt = 0  # total sentences seen so far
with open(
        '/Users/zhaosanqiang916/git/text_simplification_data/wiki/ner3/ner_comp.txt',
        encoding='utf-8') as f:
    for line in f:
        words = [constant.SYMBOL_START] + line.lower().split() + [constant.SYMBOL_END]
        words = vocab.encode(' '.join(words))
        tcnt += 1
        if len(words) > LONG_THRESHOLD:
            cnt += 1
        # Periodic progress report of the running ratio.
        if tcnt % 100000 == 0:
            print(cnt / tcnt)

# Report the final ratio unless the loop just printed it (also skips the
# empty-file case, where tcnt == 0).
if tcnt % 100000:
    print(cnt / tcnt)