Example #1
class TrainData:
    def __init__(self, model_config):
        self.model_config = model_config
        self.voc_abstr = Vocab(self.model_config,
                               vocab_path=self.model_config.path_abstr_voc)
        self.voc_kword = Vocab(self.model_config,
                               vocab_path=self.model_config.path_kword_voc)
        self.populate_data()
        self.size = len(self.data)

    def populate_data(self):
        pad_id = self.voc_abstr.encode(constant.SYMBOL_PAD)
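        # pad_id is a list of subword ids when subword_vocab_size > 0, and a
        # single id otherwise; the two branches below pad the sequences accordingly.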
        self.data = []
        for line in open(self.model_config.path_train_json):
            try:
                obj = json.loads(line.strip())
                if self.model_config.subword_vocab_size > 0:
                    abstr = self.voc_abstr.encode(
                        ' '.join([constant.SYMBOL_START] +
                                 obj['title'].split() +
                                 obj['abstract'].split() +
                                 [constant.SYMBOL_END]))
                    if len(abstr) > self.model_config.max_abstr_len:
                        abstr = abstr[:self.model_config.max_abstr_len]
                    else:
                        num_pad = self.model_config.max_abstr_len - len(abstr)
                        abstr.extend(num_pad * pad_id)
                    kwords = [
                        self.voc_kword.encode(
                            ' '.join([constant.SYMBOL_START] +
                                     kphrase.split() + [constant.SYMBOL_END]))
                        for kphrase in obj['kphrases'].split(';')
                    ]
                    for kword_id, kword in enumerate(kwords):
                        if len(kword) >= self.model_config.max_kword_len:
                            kwords[kword_id] = kword[:self.model_config.
                                                     max_kword_len]
                        else:
                            num_pad = self.model_config.max_kword_len - len(
                                kword)
                            kwords[kword_id].extend(num_pad * pad_id)
                else:
                    abstr = [
                        self.voc_abstr.encode(w)
                        for w in [constant.SYMBOL_START] +
                        obj['title'].split() + obj['abstract'].split() +
                        [constant.SYMBOL_END]
                    ]
                    if len(abstr) > self.model_config.max_abstr_len:
                        abstr = abstr[:self.model_config.max_abstr_len]
                    else:
                        num_pad = self.model_config.max_abstr_len - len(abstr)
                        abstr.extend(num_pad * [pad_id])

                    kwords = [[
                        self.voc_kword.encode(w)
                        for w in [constant.SYMBOL_START] + kphrase.split() +
                        [constant.SYMBOL_END]
                    ] for kphrase in obj['kphrases'].split(';')]
                    for kword_id, kword in enumerate(kwords):
                        if len(kword) >= self.model_config.max_kword_len:
                            kwords[kword_id] = kword[:self.model_config.
                                                     max_kword_len]
                        else:
                            num_pad = self.model_config.max_kword_len - len(
                                kword)
                            kwords[kword_id].extend(num_pad * [pad_id])
            except Exception as e:
                print('json error: %s' % e)
                continue

            self.data.append({'abstr': abstr, 'kwords': kwords})

    def get_data_sample(self):
        i = rd.sample(range(self.size), 1)[0]
        return self.data[i]
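A minimal usage sketch for the class above. The config fields and paths below are hypothetical placeholders, and the snippet assumes the module-level imports this class relies on (json, random as rd, Vocab, constant) are in place.
from types import SimpleNamespace

model_config = SimpleNamespace(
    path_abstr_voc='vocab/abstr30k.subvocab',  # hypothetical path
    path_kword_voc='vocab/kword30k.subvocab',  # hypothetical path
    path_train_json='data/train.json',         # hypothetical path
    subword_vocab_size=30000,
    max_abstr_len=400,
    max_kword_len=10,
    min_count=0,
    lower_case=True)

train_data = TrainData(model_config)
sample = train_data.get_data_sample()
print(len(sample['abstr']), [len(kw) for kw in sample['kwords']])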
Example #2
from data_generator.vocab import Vocab
from util import constant
from types import SimpleNamespace
from collections import Counter

model_config = SimpleNamespace(min_count=0,
                               subword_vocab_size=50000,
                               lower_case=True)
vocab = Vocab(model_config,
              '/zfs1/hdaqing/saz31/dataset/vocab/all30k.subvocab')
ids = vocab.encode(constant.SYMBOL_START +
                   ' -lrb- . #pad# #pad# #pad# #pad# #pad#')

print(ids)
print('=====')

print([vocab.describe([id]) for id in ids])
print(vocab.describe(ids))

print([vocab.subword.all_subtoken_strings[id] for id in ids])
Example #3
class ValData:
    def __init__(self, model_config):
        self.model_config = model_config

        vocab_simple_path = self.model_config.vocab_simple
        vocab_complex_path = self.model_config.vocab_complex
        vocab_all_path = self.model_config.vocab_all
        if self.model_config.subword_vocab_size > 0:
            vocab_simple_path = self.model_config.subword_vocab_simple
            vocab_complex_path = self.model_config.subword_vocab_complex
            vocab_all_path = self.model_config.subword_vocab_all

        if (self.model_config.tie_embedding == 'none' or
                    self.model_config.tie_embedding == 'dec_out'):
            self.vocab_simple = Vocab(model_config, vocab_simple_path)
            self.vocab_complex = Vocab(model_config, vocab_complex_path)
        elif (self.model_config.tie_embedding == 'all' or
                    self.model_config.tie_embedding == 'enc_dec'):
            self.vocab_simple = Vocab(model_config, vocab_all_path)
            self.vocab_complex = Vocab(model_config, vocab_all_path)

        # Populate basic complex simple pairs
        self.data = self.populate_data(self.vocab_complex, self.vocab_simple, True)
        self.data_complex_raw_lines = self.populate_data_rawfile(
            self.model_config.val_dataset_complex_rawlines_file)
        # Populate simple references
        self.data_references_raw_lines = []
        for i in range(self.model_config.num_refs):
            ref_tmp_rawlines = self.populate_data_rawfile(
                self.model_config.val_dataset_simple_folder +
                self.model_config.val_dataset_simple_rawlines_file_references)
            self.data_references_raw_lines.append(ref_tmp_rawlines)

        if self.model_config.replace_ner:
            self.mapper = load_mappers(self.model_config.val_mapper, self.model_config.lower_case)
            while len(self.mapper) < len(self.data):
                self.mapper.append({})
        else:
            self.mapper = [{} for _ in range(len(self.data))]

        assert len(self.data_complex_raw_lines) == len(self.data)
        assert len(self.mapper) == len(self.data)
        for i in range(self.model_config.num_refs):
            assert len(self.data_references_raw_lines[i]) == len(self.data)
        print('Use Val Dataset: \n Simple\t %s. \n Complex\t %s. \n Size\t %d'
              % (self.model_config.val_dataset_simple_folder + self.model_config.val_dataset_simple_file,
                 self.model_config.val_dataset_complex, len(self.data)))

        if 'rule' in self.model_config.memory:
            self.vocab_rule = Rule(model_config, self.model_config.vocab_rules)
            self.rules = self.populate_rules(
                self.model_config.val_dataset_complex_ppdb, self.vocab_rule)
            print('Populate Rule with size:%s' % self.vocab_rule.get_rule_size())

    def populate_rules(self, rule_path, vocab_rule):
        data = []
        for line in open(rule_path, encoding='utf-8'):
            cur_rules = line.split('\t')
            tmp = []
            for cur_rule in cur_rules:
                rule_id, _, rule_targets = vocab_rule.encode(cur_rule)
                if rule_targets is not None:
                    tmp.append((rule_id, [self.vocab_simple.encode(rule_target) for rule_target in rule_targets]))
            data.append(tmp)
        return data

    def populate_data_rawfile(self, data_path):
        """Populate data raw lines into memory"""
        data = []
        for line in open(data_path, encoding='utf-8'):
            data.append(line.strip())
        return data

    def process_line(self, line, vocab, max_len, need_raw=False):
        if self.model_config.tokenizer == 'split':
            words = line.split()
        elif self.model_config.tokenizer == 'nltk':
            words = word_tokenize(line)
        else:
            raise Exception('Unknown tokenizer.')

        words = [Vocab.process_word(word, self.model_config)
                 for word in words]
        if need_raw:
            words_raw = [constant.SYMBOL_START] + words + [constant.SYMBOL_END]
        else:
            words_raw = None

        if self.model_config.subword_vocab_size > 0:
            words = [constant.SYMBOL_START] + words + [constant.SYMBOL_END]
            words = vocab.encode(' '.join(words))
        else:
            words = [vocab.encode(word) for word in words]
            words = ([self.vocab_simple.encode(constant.SYMBOL_START)] + words +
                     [self.vocab_simple.encode(constant.SYMBOL_END)])

        if self.model_config.subword_vocab_size > 0:
            pad_id = vocab.encode(constant.SYMBOL_PAD)
        else:
            pad_id = [vocab.encode(constant.SYMBOL_PAD)]

        if len(words) < max_len:
            num_pad = max_len - len(words)
            words.extend(num_pad * pad_id)
        else:
            words = words[:max_len]

        return words, words_raw

    def populate_data(self, vocab_comp, vocab_simp, need_raw=False):
        # Populate data into memory
        data = []
        # max_len = -1
        # from collections import Counter
        # len_report = Counter()
        lines_comp = open(
            self.model_config.val_dataset_complex, encoding='utf-8').readlines()
        lines_simp = open(
            self.model_config.val_dataset_simple_folder + self.model_config.val_dataset_simple_file,
            encoding='utf-8').readlines()
        assert len(lines_comp) == len(lines_simp)
        for line_id in range(len(lines_comp)):
            obj = {}
            line_comp = lines_comp[line_id]
            line_simp = lines_simp[line_id]
            words_comp, words_raw_comp = self.process_line(
                line_comp, vocab_comp, self.model_config.max_complex_sentence, need_raw)
            words_simp, words_raw_simp = self.process_line(
                line_simp, vocab_simp, self.model_config.max_simple_sentence, need_raw)

            obj['words_comp'] = words_comp
            obj['words_simp'] = words_simp
            if need_raw:
                obj['words_raw_comp'] = words_raw_comp
                obj['words_raw_simp'] = words_raw_simp

            data.append(obj)
            # len_report.update([len(words)])
            # if len(words) > max_len:
            #     max_len = len(words)
        # print('Max length for data %s is %s.' % (data_path, max_len))
        # print('counter:%s' % len_report)
        return data

    def get_data_iter(self):
        i = 0
        while True:
            if i >= len(self.data):
                yield None
                continue
            if i % 100 == 0:
                print("Processed " + str(i) + " examples so far")
            ref_rawlines_batch = [self.data_references_raw_lines[j][i]
                                  for j in range(self.model_config.num_refs)]
            supplement = {}
            if 'rule' in self.model_config.memory:
                try:
                    supplement['mem'] = self.rules[i]
                except IndexError:
                    print("****INDEX ERROR: " + str(i))
                    yield None

            obj = {
                'sentence_simple': self.data[i]['words_simp'],
                'sentence_complex': self.data[i]['words_comp'],
                'sentence_complex_raw': self.data[i]['words_raw_comp'],
                'sentence_simple_raw': self.data[i]['words_raw_simp'],
                'sentence_complex_raw_lines': self.data_complex_raw_lines[i],
                'mapper': self.mapper[i],
                'ref_raw_lines': ref_rawlines_batch,
                'sup': supplement,
            }

            yield obj

            i += 1
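A minimal consumption sketch for the validation iterator above, assuming model_config is a fully populated config object of the kind ValData expects:
val_data = ValData(model_config)
for obj in val_data.get_data_iter():
    if obj is None:  # the iterator signals the end of the validation set with None
        break
    print(len(obj['sentence_complex']), len(obj['sentence_simple']),
          len(obj['ref_raw_lines']))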
Example #4
from model.model_config import BaseConfig
from data_generator.vocab import Vocab
import pickle
import re
from collections import Counter

c = Counter()
max_len = -1
voc = Vocab(BaseConfig(), '/Users/sanqiangzhao/git/wsd_data/mimic/subvocab')
with open('/Users/sanqiangzhao/git/wsd_data/mimic/cui_extra.pkl',
          'rb') as cui_file:
    cui_extra = pickle.load(cui_file)
    for cui in cui_extra:
        info = cui_extra[cui]
        text = info[0]
        text = re.compile(r'<[^>]+>').sub('', text)
        l = len(voc.encode(text))
        max_len = max(max_len, l)
        c.update([l])

print(max_len)
print(c.most_common())
Example #5
"""Get max lens for subvoc dataset for trans data"""
from data_generator.vocab import Vocab
from model.model_config import WikiTransTrainConfig

vocab_comp = Vocab(
    WikiTransTrainConfig(),
    '/Users/sanqiangzhao/git/text_simplification_data/vocab/comp30k.subvocab')
vocab_simp = Vocab(
    WikiTransTrainConfig(),
    '/Users/sanqiangzhao/git/text_simplification_data/vocab/simp30k.subvocab')

max_l_comp, max_l_simp = 0, 0
for line in open(
        '/Users/sanqiangzhao/git/text_simplification_data/val_0930/words_comps'
):
    l_comp = len(vocab_comp.encode(line))
    l_simp = len(vocab_simp.encode(line))
    max_l_comp = max(max_l_comp, l_comp)
    max_l_simp = max(max_l_simp, l_simp)

print(max_l_comp)
print(max_l_simp)
Example #6
from data_generator.vocab import Vocab
from util import constant
from types import SimpleNamespace
from collections import Counter
import os
import json

model_config = SimpleNamespace(min_count=0,
                               subword_vocab_size=0,
                               lower_case=True,
                               bert_mode=['bert_token'],
                               top_count=9999999999999)
vocab = Vocab(model_config, '/zfs1/hdaqing/saz31/dataset/vocab/bert/vocab_30k')
base = '/zfs1/hdaqing/saz31/dataset/tmp_trans/ner/features/'
max_len = 0
for file in os.listdir(base):
    f = open(base + file)
    for line in f:
        obj = json.loads(line)
        rule = obj['ppdb_rule'].split('\t')

        for pair in rule:
            items = pair.split('=>')
            if len(items) != 3 or '-' in pair or '\'' in pair or '"' in pair or ',' in pair:
                continue
            words = items[1].lower().split()
            words = vocab.encode(' '.join(words))
            if len(words) >= max_len:
                print(pair)
                max_len = len(words)
            # max_len = max(max_len, len(words))

    print('cur max_len:%s with file %s' % (max_len, file))

print(max_len)
Example #7
class TrainData:
    def __init__(self, model_config):
        self.model_config = model_config

        vocab_simple_path = self.model_config.vocab_simple
        vocab_complex_path = self.model_config.vocab_complex
        vocab_all_path = self.model_config.vocab_all
        if self.model_config.subword_vocab_size > 0:
            vocab_simple_path = self.model_config.subword_vocab_simple
            vocab_complex_path = self.model_config.subword_vocab_complex
            vocab_all_path = self.model_config.subword_vocab_all

        data_simple_path = self.model_config.train_dataset_simple
        data_complex_path = self.model_config.train_dataset_complex

        if (self.model_config.tie_embedding == 'none'
                or self.model_config.tie_embedding == 'dec_out'):
            self.vocab_simple = Vocab(model_config, vocab_simple_path)
            self.vocab_complex = Vocab(model_config, vocab_complex_path)
        elif (self.model_config.tie_embedding == 'all'
              or self.model_config.tie_embedding == 'enc_dec'):
            self.vocab_simple = Vocab(model_config, vocab_all_path)
            self.vocab_complex = Vocab(model_config, vocab_all_path)

        self.size = self.get_size(data_complex_path)
        if self.model_config.use_dataset2:
            self.size2 = self.get_size(
                self.model_config.train_dataset_complex2)
        # Populate basic complex simple pairs
        if not self.model_config.it_train:
            self.data = self.populate_data(data_complex_path, data_simple_path,
                                           self.vocab_complex,
                                           self.vocab_simple, True)
        else:
            self.data_it = self.get_data_sample_it(data_simple_path,
                                                   data_complex_path)

        print(
            'Use Train Dataset: \n Simple\t %s. \n Complex\t %s. \n Size\t %d.'
            % (data_simple_path, data_complex_path, self.size))

        if 'rule' in self.model_config.memory or 'rule' in self.model_config.rl_configs:
            self.vocab_rule = Rule(model_config, self.model_config.vocab_rules)
            self.rules_target, self.rules_align = self.populate_rules(
                self.model_config.train_dataset_complex_ppdb, self.vocab_rule)
            assert len(self.rules_align) == self.size
            assert len(self.rules_target) == self.size
            print('Populate Rule with size:%s' %
                  self.vocab_rule.get_rule_size())
            # if self.model_config.use_dataset2:
            #     self.rules2 = self.populate_rules(
            #         self.model_config.train_dataset_complex_ppdb2, self.vocab_rule)
            #     assert len(self.rules2) == self.size2

    def process_line(self, line, vocab, max_len, need_raw=False):
        if self.model_config.tokenizer == 'split':
            words = line.split()
        elif self.model_config.tokenizer == 'nltk':
            words = word_tokenize(line)
        else:
            raise Exception('Unknown tokenizer.')

        words = [Vocab.process_word(word, self.model_config) for word in words]
        if need_raw:
            words_raw = [constant.SYMBOL_START] + words + [constant.SYMBOL_END]
        else:
            words_raw = None

        if self.model_config.subword_vocab_size > 0:
            words = [constant.SYMBOL_START] + words + [constant.SYMBOL_END]
            words = vocab.encode(' '.join(words))
        else:
            words = [vocab.encode(word) for word in words]
            words = ([self.vocab_simple.encode(constant.SYMBOL_START)] +
                     words + [self.vocab_simple.encode(constant.SYMBOL_END)])

        if self.model_config.subword_vocab_size > 0:
            pad_id = vocab.encode(constant.SYMBOL_PAD)
        else:
            pad_id = [vocab.encode(constant.SYMBOL_PAD)]

        if len(words) < max_len:
            num_pad = max_len - len(words)
            words.extend(num_pad * pad_id)
        else:
            words = words[:max_len]

        return words, words_raw

    def get_size(self, data_complex_path):
        return len(open(data_complex_path, encoding='utf-8').readlines())

    def get_data_sample_it(self, data_simple_path, data_complex_path):
        f_simple = open(data_simple_path, encoding='utf-8')
        f_complex = open(data_complex_path, encoding='utf-8')
        # if self.model_config.use_dataset2:
        #     f_simple2 = open(self.model_config.train_dataset_simple2, encoding='utf-8')
        #     f_complex2 = open(self.model_config.train_dataset_complex2, encoding='utf-8')
        #     j = 0
        i = 0
        while True:
            if i >= self.size:
                f_simple = open(data_simple_path, encoding='utf-8')
                f_complex = open(data_complex_path, encoding='utf-8')
                i = 0
            line_complex = f_complex.readline()
            line_simple = f_simple.readline()
            if rd.random() < 0.5 or i >= self.size:
                i += 1
                continue

            words_complex, words_raw_comp = self.process_line(
                line_complex, self.vocab_complex,
                self.model_config.max_complex_sentence, True)
            words_simple, words_raw_simp = self.process_line(
                line_simple, self.vocab_simple,
                self.model_config.max_simple_sentence, True)

            supplement = {}
            if 'rule' in self.model_config.memory:
                supplement['rules_target'] = self.rules_target[i]
                supplement['rules_align'] = self.rules_align[i]

            obj = {}
            obj['words_comp'] = words_complex
            obj['words_simp'] = words_simple
            obj['words_raw_comp'] = words_raw_comp
            obj['words_raw_simp'] = words_raw_simp

            yield i, obj, supplement

            i += 1

            # if self.model_config.use_dataset2:
            #     if j == self.size2:
            #         f_simple2 = open(self.model_config.train_dataset_simple2, encoding='utf-8')
            #         f_complex2 = open(self.model_config.train_dataset_complex2, encoding='utf-8')
            #         j = 0
            #     line_complex2 = f_complex2.readline()
            #     line_simple2 = f_simple2.readline()
            #     words_complex2, _ = self.process_line(line_complex2, self.vocab_complex)
            #     words_simple2, _ = self.process_line(line_simple2, self.vocab_simple)
            #
            #     supplement2 = {}
            #     if self.model_config.memory == 'rule':
            #         supplement2['mem'] = self.rules2[j]
            #
            #     yield j, words_simple2, words_complex2, cp.deepcopy([1.0] * len(words_simple2)), cp.deepcopy([1.0] * len(words_complex2)), supplement2
            #     j += 1

    def populate_rules(self, rule_path, vocab_rule):
        data_target, data_align = [], []
        for line in open(rule_path, encoding='utf-8'):
            cur_rules = line.split('\t')
            tmp, tmp_align = [], []
            for cur_rule in cur_rules:
                rule_id, rule_origins, rule_targets = vocab_rule.encode(
                    cur_rule)
                if rule_targets is not None and rule_origins is not None:
                    tmp.append((rule_id, [
                        self.vocab_simple.encode(rule_target)
                        for rule_target in rule_targets
                    ]))

                    if len(rule_origins) == 1 and len(rule_targets) == 1:
                        tmp_align.append(
                            (self.vocab_complex.encode(rule_origins[0]),
                             self.vocab_simple.encode(rule_targets[0])))
            data_target.append(tmp)
            data_align.append(tmp_align)

        return data_target, data_align

    def populate_data(self,
                      data_path_comp,
                      data_path_simp,
                      vocab_comp,
                      vocab_simp,
                      need_raw=False):
        # Populate data into memory
        data = []
        # max_len = -1
        # from collections import Counter
        # len_report = Counter()
        lines_comp = open(data_path_comp, encoding='utf-8').readlines()
        lines_simp = open(data_path_simp, encoding='utf-8').readlines()
        assert len(lines_comp) == len(lines_simp)
        for line_id in range(len(lines_comp)):
            obj = {}
            line_comp = lines_comp[line_id]
            line_simp = lines_simp[line_id]
            words_comp, words_raw_comp = self.process_line(
                line_comp, vocab_comp, self.model_config.max_complex_sentence,
                need_raw)
            words_simp, words_raw_simp = self.process_line(
                line_simp, vocab_simp, self.model_config.max_simple_sentence,
                need_raw)
            obj['words_comp'] = words_comp
            obj['words_simp'] = words_simp
            if need_raw:
                obj['words_raw_comp'] = words_raw_comp
                obj['words_raw_simp'] = words_raw_simp

            data.append(obj)
        return data

    def get_data_sample(self):
        i = rd.sample(range(self.size), 1)[0]
        supplement = {}
        if 'rule' in self.model_config.memory:
            supplement['rules_target'] = self.rules_target[i]
            supplement['rules_align'] = self.rules_align[i]

        return i, self.data[i], supplement
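A minimal consumption sketch for the training data above, assuming model_config is a fully populated training config:
train_data = TrainData(model_config)
if model_config.it_train:
    # Iterator mode: pull a few (index, example, supplement) tuples from the generator.
    for _ in range(3):
        i, obj, supplement = next(train_data.data_it)
        print(i, len(obj['words_comp']), len(obj['words_simp']))
else:
    # Pre-loaded mode: draw one random example.
    i, obj, supplement = train_data.get_data_sample()
    print(i, sorted(supplement.keys()))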
Example #8
class EvalData():
    def __init__(self, model_config):
        self.model_config = model_config
        if self.model_config.tied_embedding == 'enc|dec':
            self.voc_abstr = Vocab(self.model_config, vocab_path=self.model_config.path_abstrkword_voc)
            self.voc_kword = Vocab(self.model_config, vocab_path=self.model_config.path_abstrkword_voc)
        else:
            self.voc_abstr = Vocab(self.model_config, vocab_path=self.model_config.path_abstr_voc)
            self.voc_kword = Vocab(self.model_config, vocab_path=self.model_config.path_kword_voc)
        self.populate_data()
        self.size = len(self.data)
        if model_config.eval_mode == 'truncate2000':
            self.data = self.data[:2000]
            self.size = len(self.data)
            # assert self.size == 2000

    def populate_data(self):
        pad_id = self.voc_abstr.encode(constant.SYMBOL_PAD)
        self.data = []
        for line in open(self.model_config.path_val_json):
            obj = json.loads(line.strip())
            if self.model_config.subword_vocab_size > 0:
                abstr = self.voc_abstr.encode(' '.join(
                        [constant.SYMBOL_START] + obj['title'].split() + obj['abstr'].split() + [constant.SYMBOL_END]))
                if len(abstr) > self.model_config.max_abstr_len:
                    abstr = abstr[:self.model_config.max_abstr_len]
                else:
                    num_pad = self.model_config.max_abstr_len - len(abstr)
                    abstr.extend(num_pad * pad_id)
                kwords = [self.voc_kword.encode(' '.join(
                    [constant.SYMBOL_START] + kphrase.split() + [constant.SYMBOL_END]))
                    for kphrase in obj['kphrases'].split(';')]
                for kword_id, kword in enumerate(kwords):
                    if len(kword) >= self.model_config.max_kword_len:
                        kwords[kword_id] = kword[:self.model_config.max_kword_len]
                    else:
                        num_pad = self.model_config.max_kword_len - len(kword)
                        kwords[kword_id].extend(num_pad * pad_id)
            else:
                abstr = [self.voc_abstr.encode(w) for w
                         in [constant.SYMBOL_START] + obj['title'].split() + obj['abstr'].split() + [constant.SYMBOL_END]]
                if len(abstr) > self.model_config.max_abstr_len:
                    abstr = abstr[:self.model_config.max_abstr_len]
                else:
                    num_pad = self.model_config.max_abstr_len - len(abstr)
                    abstr.extend(num_pad * [pad_id])

                kwords = [[self.voc_kword.encode(w) for w
                           in [constant.SYMBOL_START] + kphrase.split() + [constant.SYMBOL_END]]
                          for kphrase in obj['kphrases'].split(';') ]
                for kword_id, kword in enumerate(kwords):
                    if len(kword) >= self.model_config.max_kword_len:
                        kwords[kword_id] = kword[:self.model_config.max_kword_len]
                    else:
                        num_pad = self.model_config.max_kword_len - len(kword)
                        kwords[kword_id].extend(num_pad * [pad_id])

            abstr_raw = [w for w
                         in [constant.SYMBOL_START] + obj['title'].split() + obj['abstr'].split() + [
                             constant.SYMBOL_END]]
            kwords_raw = [[w for w
                           in [constant.SYMBOL_START] + kphrase.split() + [constant.SYMBOL_END]]
                          for kphrase in obj['kphrases'].split(';')]

            obj = {
                'abstr': abstr,
                'abstr_raw': abstr_raw,
                'kwords': kwords,
                'kwords_raw': kwords_raw,
            }
            self.data.append(obj)

    def get_data_sample_it(self):
        i = 0
        while True:
            if i == self.size:
                yield None
            else:
                yield self.data[i]
                i += 1
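A minimal consumption sketch for the evaluation iterator above, assuming model_config provides the fields EvalData reads:
eval_data = EvalData(model_config)
for obj in eval_data.get_data_sample_it():
    if obj is None:  # the generator yields None once the data set is exhausted
        break
    print(len(obj['abstr']), len(obj['kwords']), obj['kwords_raw'][0])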
Example #9
class ValData:
    def __init__(self, model_config):
        self.model_config = model_config

        vocab_simple_path = self.model_config.vocab_simple
        vocab_complex_path = self.model_config.vocab_complex
        vocab_all_path = self.model_config.vocab_all
        if self.model_config.subword_vocab_size > 0:
            vocab_simple_path = self.model_config.subword_vocab_simple
            vocab_complex_path = self.model_config.subword_vocab_complex
            vocab_all_path = self.model_config.subword_vocab_all

        if (self.model_config.tie_embedding == 'none' or
                    self.model_config.tie_embedding == 'dec_out'):
            self.vocab_simple = Vocab(model_config, vocab_simple_path)
            self.vocab_complex = Vocab(model_config, vocab_complex_path)
        elif (self.model_config.tie_embedding == 'all' or
                    self.model_config.tie_embedding == 'enc_dec'):
            self.vocab_simple = Vocab(model_config, vocab_all_path)
            self.vocab_complex = Vocab(model_config, vocab_all_path)

        # Populate basic complex simple pairs
        self.data = self.populate_data(self.vocab_complex, self.vocab_simple, True)
        self.data_complex_raw_lines = self.populate_data_rawfile(
            self.model_config.val_dataset_complex_rawlines_file,
            self.model_config.lower_case)
        # Populate simple references
        self.data_references_raw_lines = []
        for i in range(self.model_config.num_refs):
            ref_tmp_rawlines = self.populate_data_rawfile(
                self.model_config.val_dataset_simple_folder +
                self.model_config.val_dataset_simple_rawlines_file_references +
                str(i), self.model_config.lower_case)
            self.data_references_raw_lines.append(ref_tmp_rawlines)

        if self.model_config.replace_ner:
            self.mapper = load_mappers(self.model_config.val_mapper, self.model_config.lower_case)
            while len(self.mapper) < len(self.data):
                self.mapper.append({})
        else:
            self.mapper = [{} for _ in range(len(self.data))]

        assert len(self.data_complex_raw_lines) == len(self.data)
        assert len(self.mapper) == len(self.data)
        for i in range(self.model_config.num_refs):
            assert len(self.data_references_raw_lines[i]) == len(self.data)
        print('Use Val Dataset: \n Simple\t %s. \n Complex\t %s. \n Size\t %d'
              % (self.model_config.val_dataset_simple_folder + self.model_config.val_dataset_simple_file,
                 self.model_config.val_dataset_complex, len(self.data)))

        if 'rule' in self.model_config.memory or 'direct' in self.model_config.memory:
            self.vocab_rule = Rule(model_config, self.model_config.vocab_rules)
            self.rules = self.populate_rules(
                self.model_config.val_dataset_complex_ppdb, self.vocab_rule)
            print('Populate Rule with size:%s' % self.vocab_rule.get_rule_size())

        if self.model_config.tune_style:
            self.comp_features = self.populate_comp_features(
                self.model_config.val_dataset_complex_features)

    def populate_comp_features(self, feature_path):
        data = []
        for line in open(feature_path, encoding='utf-8'):
            items = line.split('\t')
            data.append(
                (float(items[0]), float(items[1])))
        return data

    def populate_rules(self, rule_path, vocab_rule):
        data = []
        for line in open(rule_path, encoding='utf-8'):
            cur_rules = line.split('\t')
            tmp = []
            for cur_rule in cur_rules:
                rule_id, _, rule_targets = vocab_rule.encode(cur_rule)
                if rule_targets is not None:
                    if self.model_config.subword_vocab_size or 'bert_token' in self.model_config.bert_mode:
                        tmp.append((rule_id, self.vocab_simple.encode(rule_targets)))
                    else:
                        tmp.append((rule_id, [self.vocab_simple.encode(rule_target) for rule_target in rule_targets]))
            data.append(tmp)
        return data

    def populate_data_rawfile(self, data_path, lower_case=True):
        """Populate data raw lines into memory"""
        data = []
        for line in open(data_path, encoding='utf-8'):
            if lower_case:
                line = line.lower()
            data.append(line.strip())
        return data

    def populate_data(self, vocab_comp, vocab_simp, need_raw=False):
        # Populate data into memory
        data = []
        # max_len = -1
        # from collections import Counter
        # len_report = Counter()
        lines_comp = open(
            self.model_config.val_dataset_complex, encoding='utf-8').readlines()
        lines_simp = open(
            self.model_config.val_dataset_simple_folder + self.model_config.val_dataset_simple_file,
            encoding='utf-8').readlines()
        assert len(lines_comp) == len(lines_simp)
        for line_id in range(len(lines_comp)):
            obj = {}
            line_comp = lines_comp[line_id]
            line_simp = lines_simp[line_id]
            words_comp, words_raw_comp, obj_comp = data_utils.process_line(
                line_comp, vocab_comp, self.model_config.max_complex_sentence, self.model_config, need_raw,
                self.model_config.lower_case)
            words_simp, words_raw_simp, obj_simp = data_utils.process_line(
                line_simp, vocab_simp, self.model_config.max_simple_sentence, self.model_config, need_raw,
                self.model_config.lower_case)

            obj['words_comp'] = words_comp
            obj['words_simp'] = words_simp
            if need_raw:
                obj['words_raw_comp'] = words_raw_comp
                obj['words_raw_simp'] = words_raw_simp
            if self.model_config.subword_vocab_size and self.model_config.seg_mode:
                obj['line_comp_segids'] = obj_comp['segment_idxs']
                obj['line_simp_segids'] = obj_simp['segment_idxs']

            data.append(obj)
            # len_report.update([len(words)])
            # if len(words) > max_len:
            #     max_len = len(words)
        # print('Max length for data %s is %s.' % (data_path, max_len))
        # print('counter:%s' % len_report)
        return data

    def get_data_iter(self):
        i = 0
        while True:
            if i >= len(self.data):
                yield None
            else:
                ref_rawlines_batch = [self.data_references_raw_lines[j][i]
                                      for j in range(self.model_config.num_refs)]
                supplement = {}
                if 'rule' in self.model_config.memory or 'direct' in self.model_config.memory:
                    supplement['mem'] = self.rules[i]

                if self.model_config.tune_style:
                    supplement['comp_features'] = self.comp_features[i]

                obj = {
                    'sentence_simple': self.data[i]['words_simp'],
                    'sentence_complex': self.data[i]['words_comp'],
                    'sentence_complex_raw': self.data[i]['words_raw_comp'],
                    'sentence_simple_raw': self.data[i]['words_raw_simp'],
                    'sentence_complex_raw_lines': self.data_complex_raw_lines[i],
                    'mapper': self.mapper[i],
                    'ref_raw_lines': ref_rawlines_batch,
                    'sup': supplement,
                }

                if self.model_config.subword_vocab_size and self.model_config.seg_mode:
                    obj['line_comp_segids'] = self.data[i]['line_comp_segids']
                    obj['line_simp_segids'] = self.data[i]['line_simp_segids']

                yield obj

                i += 1
Example #10
class Graph:
    def __init__(self, model_config, is_train):
        self.model_config = model_config
        self.is_train = is_train
        self.voc_abstr = Vocab(self.model_config,
                               vocab_path=self.model_config.path_abstr_voc)
        self.voc_kword = Vocab(self.model_config,
                               vocab_path=self.model_config.path_kword_voc)
        self.hparams = transformer.transformer_base()
        self.setup_hparams()

    def get_embedding(self):
        # Alternative: tf.random_uniform_initializer(-0.08, 0.08)
        emb_init = tf.contrib.layers.xavier_initializer()
        emb_abstr = tf.get_variable(
            'embedding_abstr',
            [self.voc_abstr.vocab_size(), self.model_config.dimension],
            tf.float32,
            initializer=emb_init)
        emb_kword = tf.get_variable(
            'embedding_kword',
            [self.voc_kword.vocab_size(), self.model_config.dimension],
            tf.float32,
            initializer=emb_init)
        proj_w = tf.get_variable(
            'proj_w',
            [self.voc_kword.vocab_size(), self.model_config.dimension],
            tf.float32,
            initializer=emb_init)
        proj_b = tf.get_variable('proj_b',
                                 shape=[self.voc_kword.vocab_size()],
                                 initializer=emb_init)
        return emb_abstr, emb_kword, proj_w, proj_b

    def embedding_fn(self, inputs, embedding):
        if type(inputs) == list:
            if not inputs:
                return []
            else:
                return [
                    tf.nn.embedding_lookup(embedding, inp) for inp in inputs
                ]
        else:
            return tf.nn.embedding_lookup(embedding, inputs)

    def decode_step(self, kword_input, abstr_outputs, abstr_bias, attn_stick):
        batch_go = [
            tf.zeros(
                [self.model_config.batch_size, self.model_config.dimension])
        ]
        kword_length = len(kword_input) + 1
        kword_input = tf.stack(batch_go + kword_input, axis=1)
        kword_output, new_attn_stick = self.decode_inputs_to_outputs(
            kword_input, abstr_outputs, abstr_bias, attn_stick)
        kword_output_list = [
            tf.squeeze(d, 1)
            for d in tf.split(kword_output, kword_length, axis=1)
        ]
        return kword_output_list, new_attn_stick

    def decode_inputs_to_outputs(self, kword_input, abstr_outputs, abstr_bias,
                                 attn_stick):
        if self.hparams.pos == 'timing':
            kword_input = common_attention.add_timing_signal_1d(kword_input)
        kword_tribias = common_attention.attention_bias_lower_triangle(
            tf.shape(kword_input)[1])
        kword_input = tf.nn.dropout(
            kword_input, 1.0 - self.hparams.layer_prepostprocess_dropout)
        kword_output, new_attn_stick = transformer.transformer_decoder(
            kword_input,
            abstr_outputs,
            kword_tribias,
            abstr_bias,
            self.hparams,
            attn_stick=attn_stick)
        return kword_output, new_attn_stick

    def output_to_logit(self, prev_out, w, b):
        prev_logit = tf.add(tf.matmul(prev_out, tf.transpose(w)), b)
        return prev_logit

    def transformer_beam_search(self,
                                abstr_outputs,
                                abstr_bias,
                                emb_kword,
                                proj_w,
                                proj_b,
                                attn_stick=None):
        # Use Beam Search in evaluation stage
        # Update [a, b, c] to [a, a, a, b, b, b, c, c, c] if beam_search_size == 3
        encoder_beam_outputs = tf.concat([
            tf.tile(tf.expand_dims(abstr_outputs[o, :, :], axis=0),
                    [self.model_config.beam_search_size, 1, 1])
            for o in range(self.model_config.batch_size)
        ],
                                         axis=0)

        encoder_attn_beam_bias = tf.concat([
            tf.tile(tf.expand_dims(abstr_bias[o, :, :, :], axis=0),
                    [self.model_config.beam_search_size, 1, 1, 1])
            for o in range(self.model_config.batch_size)
        ],
                                           axis=0)

        if attn_stick is not None and 'tuzhaopeng' in self.model_config.cov_mode and (
                'wd_attn' in self.model_config.cov_mode
                and 'kp_attn' in self.model_config.cov_mode):
            attn_beam_stick = tf.stack([
                attn_stick for _ in range(self.model_config.beam_search_size)
            ],
                                       axis=1)
        elif attn_stick is not None and 'tuzhaopeng' in self.model_config.cov_mode and 'kp_attn' in self.model_config.cov_mode:
            attn_beam_stick = tf.concat([
                tf.tile(tf.expand_dims(attn_stick[o, :, :, :], axis=0),
                        [self.model_config.beam_search_size, 1, 1, 1])
                for o in range(self.model_config.batch_size)
            ],
                                        axis=0)

        def symbol_to_logits_fn(ids, cur_attn_stick=None):
            embs = tf.nn.embedding_lookup(emb_kword, ids[:, 1:])
            embs = tf.pad(embs, [[0, 0], [1, 0], [0, 0]])
            if attn_stick is not None and 'tuzhaopeng' in self.model_config.cov_mode and (
                    'wd_attn' in self.model_config.cov_mode
                    and 'kp_attn' in self.model_config.cov_mode):
                final_outputs, new_attn_stick = self.decode_inputs_to_outputs(
                    embs,
                    encoder_beam_outputs,
                    encoder_attn_beam_bias,
                    attn_stick=cur_attn_stick)
            elif attn_stick is not None and 'tuzhaopeng' in self.model_config.cov_mode and 'kp_attn' in self.model_config.cov_mode:
                final_outputs, new_attn_stick = self.decode_inputs_to_outputs(
                    embs,
                    encoder_beam_outputs,
                    encoder_attn_beam_bias,
                    attn_stick=attn_beam_stick)
            return self.output_to_logit(final_outputs[:, -1, :], proj_w,
                                        proj_b), new_attn_stick

        if attn_stick is not None and 'tuzhaopeng' in self.model_config.cov_mode and 'wd_attn' in self.model_config.cov_mode:
            beam_ids, beam_score, new_attn_stick = beam_search.beam_search(
                symbol_to_logits_fn,
                tf.zeros([self.model_config.batch_size], tf.int32),
                self.model_config.beam_search_size,
                self.model_config.max_kword_len,
                self.voc_kword.vocab_size(),
                0.6,
                attn_stick=attn_beam_stick,
                model_config=self.model_config)
            new_attn_stick = new_attn_stick[:, 0, :]
        else:
            beam_ids, beam_score = beam_search_t2t.beam_search(
                symbol_to_logits_fn,
                tf.zeros([self.model_config.batch_size],
                         tf.int32), self.model_config.beam_search_size,
                self.model_config.max_kword_len, self.voc_kword.vocab_size(),
                0.6)
            new_attn_stick = attn_stick

        top_beam_ids = beam_ids[:, 0, 1:]
        top_beam_ids = tf.pad(
            top_beam_ids,
            [[0, 0],
             [0, self.model_config.max_kword_len - tf.shape(top_beam_ids)[1]]])
        decoder_target_list = [
            tf.squeeze(d, 1) for d in tf.split(
                top_beam_ids, self.model_config.max_kword_len, axis=1)
        ]
        decoder_score = -beam_score[:, 0] / tf.to_float(
            tf.shape(top_beam_ids)[1])

        return decoder_score, top_beam_ids, new_attn_stick

    def greed_search(self,
                     id,
                     abstr_outputs,
                     abstr_bias,
                     emb_kword,
                     proj_w,
                     proj_b,
                     attn_stick=None):
        kword_target_tensor = tf.TensorArray(
            tf.int64,
            size=self.model_config.max_kword_len,
            clear_after_read=False,
            element_shape=[
                self.model_config.batch_size,
            ],
            name='kword_target_tensor_%s' % str(id))
        kword_logit_tensor = tf.TensorArray(
            tf.float32,
            size=self.model_config.max_kword_len,
            clear_after_read=False,
            element_shape=[
                self.model_config.batch_size,
                self.voc_kword.vocab_size()
            ],
            name='kword_logit_tensor_%s' % str(id))
        kword_embed_inputs_tensor = tf.TensorArray(
            tf.float32,
            size=1,
            dynamic_size=True,
            clear_after_read=False,
            element_shape=[
                self.model_config.batch_size, self.model_config.dimension
            ],
            name='kword_embed_inputs_tensor_%s' % str(id))
        kword_output_tensor = tf.TensorArray(
            tf.float32,
            size=self.model_config.max_kword_len,
            clear_after_read=False,
            element_shape=[
                self.model_config.batch_size, self.model_config.dimension
            ],
            name='kword_output_tensor_%s' % str(id))

        kword_embed_inputs_tensor = kword_embed_inputs_tensor.write(
            0,
            tf.zeros(
                [self.model_config.batch_size, self.model_config.dimension]))

        def _is_finished(step, kword_target_tensor, kword_logit_tensor,
                         kword_embed_inputs_tensor, kword_output_tensor,
                         attn_stick):
            return tf.less(step, self.model_config.max_kword_len)

        def _recursive(step, kword_target_tensor, kword_logit_tensor,
                       kword_embed_inputs_tensor, kword_output_tensor,
                       attn_stick):
            cur_kword_embed_inputs_tensor = kword_embed_inputs_tensor.stack()
            cur_kword_embed_inputs_tensor = tf.transpose(
                cur_kword_embed_inputs_tensor, perm=[1, 0, 2])

            # decode_inputs_to_outputs returns (outputs, new_attn_stick); only
            # the decoder outputs are needed at this step.
            kword_outputs, _ = self.decode_inputs_to_outputs(
                cur_kword_embed_inputs_tensor,
                abstr_outputs,
                abstr_bias,
                attn_stick=attn_stick)
            kword_output = kword_outputs[:, -1, :]

            kword_logit = self.output_to_logit(kword_output, proj_w, proj_b)
            kword_target = tf.argmax(kword_logit,
                                     output_type=tf.int64,
                                     axis=-1)
            kword_output_tensor = kword_output_tensor.write(step, kword_output)
            kword_logit_tensor = kword_logit_tensor.write(step, kword_logit)
            kword_target_tensor = kword_target_tensor.write(step, kword_target)
            kword_embed_inputs_tensor = kword_embed_inputs_tensor.write(
                step + 1, tf.nn.embedding_lookup(emb_kword, kword_target))
            return step + 1, kword_target_tensor, kword_logit_tensor, kword_embed_inputs_tensor, kword_output_tensor, attn_stick

        step = tf.constant(0)
        (_, kword_target_tensor, kword_logit_tensor, kword_embed_inputs_tensor,
         kword_output_tensor, attn_stick) = tf.while_loop(
             _is_finished,
             _recursive, [
                 step, kword_target_tensor, kword_logit_tensor,
                 kword_embed_inputs_tensor, kword_output_tensor, attn_stick
             ],
             back_prop=False,
             parallel_iterations=1)

        kword_target_tensor = kword_target_tensor.stack()
        kword_target_tensor.set_shape(
            [self.model_config.max_kword_len, self.model_config.batch_size])
        kword_target_tensor = tf.transpose(kword_target_tensor, perm=[1, 0])
        return tf.constant(10.0), kword_target_tensor, attn_stick

    def create_model(self):
        with tf.variable_scope('variables'):
            abstr_ph = []
            for _ in range(self.model_config.max_abstr_len):
                abstr_ph.append(
                    tf.zeros(self.model_config.batch_size,
                             tf.int32,
                             name='abstract_input'))

            kwords_ph = []
            for _ in range(self.model_config.max_cnt_kword):
                kword = []
                for _ in range(self.model_config.max_kword_len):
                    kword.append(
                        tf.zeros(self.model_config.batch_size,
                                 tf.int32,
                                 name='kword_input'))
                kwords_ph.append(kword)

            emb_abstr, emb_kword, proj_w, proj_b = self.get_embedding()
            abstr = tf.stack(self.embedding_fn(abstr_ph, emb_abstr), axis=1)
            kwords = []
            for kword_idx in range(self.model_config.max_cnt_kword):
                kwords.append(
                    self.embedding_fn(kwords_ph[kword_idx], emb_kword))

        with tf.variable_scope('model_encoder'):
            if self.hparams.pos == 'timing':
                abstr = common_attention.add_timing_signal_1d(abstr)
            encoder_embed_inputs = tf.nn.dropout(
                abstr, 1.0 - self.hparams.layer_prepostprocess_dropout)
            abstr_bias = common_attention.attention_bias_ignore_padding(
                tf.to_float(
                    tf.equal(tf.stack(abstr_ph, axis=1),
                             self.voc_kword.encode(constant.SYMBOL_PAD))))
            abstr_outputs = transformer.transformer_encoder(
                encoder_embed_inputs, abstr_bias, self.hparams)

            # attn_stick acts as a per-head attention memory of shape
            # [batch_size, num_heads, 1, dimension // num_heads]; it is updated
            # with each decoded keyword's embedding when 'kp_attn' is enabled.
            if 'tuzhaopeng' in self.model_config.cov_mode:
                attn_stick = tf.ones([
                    self.model_config.batch_size, self.model_config.num_heads,
                    1,
                    self.model_config.dimension // self.model_config.num_heads
                ], tf.float32, 'attn_memory')

        losses = []
        targets = []
        obj = {}
        with tf.variable_scope('model_decoder'):
            for kword_idx in range(self.model_config.max_cnt_kword):
                if self.is_train:
                    kword = kwords[kword_idx][:-1]
                    kword_ph = kwords_ph[kword_idx]
                    kword_output_list, new_attn_stick = self.decode_step(
                        kword, abstr_outputs, abstr_bias, attn_stick)
                    kword_logit_list = [
                        self.output_to_logit(o, proj_w, proj_b)
                        for o in kword_output_list
                    ]
                    kword_target_list = [
                        tf.argmax(o, output_type=tf.int32, axis=-1)
                        for o in kword_logit_list
                    ]
                    attn_stick = new_attn_stick

                    if self.model_config.number_samples > 0:
                        loss_fn = tf.nn.sampled_softmax_loss
                    else:
                        loss_fn = None
                    kword_lossbias = [
                        tf.to_float(
                            tf.not_equal(
                                d, self.voc_kword.encode(constant.SYMBOL_PAD)))
                        for d in kword_ph
                    ]
                    kword_lossbias = tf.stack(kword_lossbias, axis=1)
                    loss = sequence_loss(
                        logits=tf.stack(kword_logit_list, axis=1),
                        targets=tf.stack(kword_ph, axis=1),
                        weights=kword_lossbias,
                        softmax_loss_function=loss_fn,
                        w=proj_w,
                        b=proj_b,
                        decoder_outputs=tf.stack(kword_output_list, axis=1),
                        number_samples=self.model_config.number_samples)
                    targets.append(tf.stack(kword_target_list, axis=1))

                    if 'tuzhaopeng' in self.model_config.cov_mode and 'kp_attn' in self.model_config.cov_mode:
                        target_emb = tf.stack(self.embedding_fn(
                            kword_target_list, emb_kword),
                                              axis=1)
                        target_emb = common_attention.split_heads(
                            target_emb, self.model_config.num_heads)
                        target_emb = tf.reduce_mean(target_emb, axis=2)
                        target_emb_trans = tf.get_variable(
                            'dim_weight_trans',
                            shape=[
                                1,
                                target_emb.get_shape()[-1].value,
                                target_emb.get_shape()[-1].value
                            ],
                            dtype=tf.float32,
                            initializer=tf.contrib.layers.xavier_initializer())
                        target_emb = tf.nn.conv1d(target_emb, target_emb_trans,
                                                  1, 'SAME')
                        target_emb = tf.expand_dims(target_emb, axis=2)
                        attn_stick += target_emb
                    losses.append(loss)
                else:
                    if self.model_config.beam_search_size > 0:
                        loss, target, new_attn_stick = self.transformer_beam_search(
                            abstr_outputs,
                            abstr_bias,
                            emb_kword,
                            proj_w,
                            proj_b,
                            attn_stick=attn_stick)
                    else:
                        loss, target, new_attn_stick = self.greed_search(
                            kword_idx,
                            abstr_outputs,
                            abstr_bias,
                            emb_kword,
                            proj_w,
                            proj_b,
                            attn_stick=attn_stick)
                    targets.append(target)
                    losses = loss
                    attn_stick = new_attn_stick
                    if 'tuzhaopeng' in self.model_config.cov_mode and 'kp_attn' in self.model_config.cov_mode:
                        target.set_shape([
                            self.model_config.batch_size,
                            self.model_config.max_kword_len
                        ])
                        target_list = tf.unstack(target, axis=1)
                        target_emb = tf.stack(self.embedding_fn(
                            target_list, emb_kword),
                                              axis=1)
                        target_emb = common_attention.split_heads(
                            target_emb, self.model_config.num_heads)
                        target_emb = tf.reduce_mean(target_emb, axis=2)
                        target_emb_trans = tf.get_variable(
                            'dim_weight_trans',
                            shape=[
                                1,
                                target_emb.get_shape()[-1].value,
                                target_emb.get_shape()[-1].value
                            ],
                            dtype=tf.float32,
                            initializer=tf.contrib.layers.xavier_initializer())
                        target_emb = tf.nn.conv1d(target_emb, target_emb_trans,
                                                  1, 'SAME')
                        target_emb = tf.expand_dims(target_emb, axis=2)
                        attn_stick += target_emb
                tf.get_variable_scope().reuse_variables()
        if targets:
            obj['targets'] = tf.stack(targets, axis=1)
        obj['abstr_ph'] = abstr_ph
        obj['kwords_ph'] = kwords_ph
        obj['attn_stick'] = attn_stick
        if type(losses) is list:
            losses = tf.add_n(losses)
        return losses, obj

    def create_model_multigpu(self):
        losses = []
        grads = []
        optim = self.get_optim()
        self.objs = []

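        # Build one model replica ('tower') per GPU; per-tower losses and
        # gradients are collected here and averaged below, with variables
        # shared across towers via reuse.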
        with tf.variable_scope(tf.get_variable_scope()) as scope:
            for gpu_id in range(self.model_config.num_gpus):
                with tf.device('/gpu:%d' % gpu_id):
                    loss, obj = self.create_model()
                    grad = optim.compute_gradients(loss)
                    losses.append(loss)
                    grads.append(grad)
                    self.objs.append(obj)
                    tf.get_variable_scope().reuse_variables()

        self.global_step = tf.get_variable('global_step',
                                           initializer=tf.constant(
                                               0, dtype=tf.int64),
                                           trainable=False)
        with tf.variable_scope('optimization'):
            self.loss = tf.divide(tf.add_n(losses), self.model_config.num_gpus)
            self.perplexity = tf.exp(
                tf.reduce_mean(self.loss) / self.model_config.max_cnt_kword)

            if self.is_train:
                avg_grad = self.average_gradients(grads)
                grads = [g for (g, v) in avg_grad]
                clipped_grads, _ = tf.clip_by_global_norm(
                    grads, self.model_config.max_grad_norm)
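                # Note: pairing clipped_grads with tf.trainable_variables() assumes
                # compute_gradients returned gradients in that same variable order.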
                self.train_op = optim.apply_gradients(
                    zip(clipped_grads, tf.trainable_variables()),
                    global_step=self.global_step)
                self.increment_global_step = tf.assign_add(self.global_step, 1)

            self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2)

    def get_optim(self):
        learning_rate = tf.constant(self.model_config.learning_rate)

        if self.model_config.optimizer == 'adagrad':
            opt = tf.train.AdagradOptimizer(learning_rate)
        # Adam needs a lower learning rate
        elif self.model_config.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate)
        else:
            raise NotImplementedError('Optimizer %s is not implemented.' %
                                      self.model_config.optimizer)
        return opt

    # Adapted from https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py#L101
    def average_gradients(self, tower_grads):
        """Calculate the average gradient for each shared variable across all towers.
        Note that this function provides a synchronization point across all towers.
        Args:
          tower_grads: List of lists of (gradient, variable) tuples. The outer list
            is over individual gradients. The inner list is over the gradient
            calculation for each tower.
        Returns:
           List of pairs of (gradient, variable) where the gradient has been averaged
           across all towers.
        """
        average_grads = []
        for grad_and_vars in zip(*tower_grads):
            # Note that each grad_and_vars looks like the following:
            #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
            grads = []
            for g, _ in grad_and_vars:
                # Add 0 dimension to the gradients to represent the tower.
                expanded_g = tf.expand_dims(g, 0)

                # Append on a 'tower' dimension which we will average over below.
                grads.append(expanded_g)

            # Average over the 'tower' dimension.
            grad = tf.concat(axis=0, values=grads)
            grad = tf.reduce_mean(grad, 0)

            # Keep in mind that the Variables are redundant because they are shared
            # across towers. So .. we will just return the first tower's pointer to
            # the Variable.
            v = grad_and_vars[0][1]
            grad_and_var = (grad, v)
            average_grads.append(grad_and_var)
        return average_grads

    def setup_hparams(self):
        self.hparams.num_heads = self.model_config.num_heads
        self.hparams.num_hidden_layers = self.model_config.num_hidden_layers
        self.hparams.num_encoder_layers = self.model_config.num_encoder_layers
        self.hparams.num_decoder_layers = self.model_config.num_decoder_layers
        self.hparams.pos = self.model_config.hparams_pos
        self.hparams.hidden_size = self.model_config.dimension
        self.hparams.layer_prepostprocess_dropout = self.model_config.layer_prepostprocess_dropout
        self.hparams.cov_mode = self.model_config.cov_mode

        if self.is_train:
            self.hparams.add_hparam('mode', tf.estimator.ModeKeys.TRAIN)
        else:
            self.hparams.add_hparam('mode', tf.estimator.ModeKeys.EVAL)
            self.hparams.layer_prepostprocess_dropout = 0.0
            self.hparams.attention_dropout = 0.0
            self.hparams.dropout = 0.0
            self.hparams.relu_dropout = 0.0
Example #11
0
class Data:
    def __init__(self, model_config):
        self.model_config = model_config
        # For Abbr
        self.populate_abbr()
        # For Context
        self.voc = Vocab(model_config, model_config.voc_file)

    def populate_abbr(self):
        def update(item, item2id, id2item):
            if item not in item2id:
                item2id[item] = len(id2item)
                id2item.append(item)

        s_i = 0
        self.abbrs_pos = {}
        self.abbr2id, self.id2abbr = {}, []
        self.sense2id, self.id2sense = {}, []
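        # abbrs_pos maps abbr_id -> {'s_i', 'e_i'}, the half-open range of that abbreviation's
        # senses within the global sense index space; sen_cnt ends up as the total sense count.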
        for line in open(self.model_config.abbr_common_file):
            items = line.strip().split('|')
            abbr = items[0]
            update(abbr, self.abbr2id, self.id2abbr)
            senses = items[1].split()

            abbr_id = self.abbr2id[abbr]
            if abbr_id not in self.abbrs_pos:
                self.abbrs_pos[abbr_id] = {}
                self.abbrs_pos[abbr_id]['s_i'] = s_i
                self.abbrs_pos[abbr_id]['e_i'] = s_i + len(senses)
                s_i = s_i + len(senses)
            for sense in senses:
                update(abbr + '|' + sense, self.sense2id, self.id2sense)
        self.sen_cnt = s_i

        self.abbrs_filterout = set()
        for line in open(self.model_config.abbr_rare_file):
            self.abbrs_filterout.add(line.strip())

    def process_line(self, line):
        checker = set()
        contexts = []
        targets = []
        words = line.split()
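        # Tokens of the form 'abbr|<ABBR>|<sense>' mark labeled abbreviations;
        # all other tokens are treated as ordinary context words.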
        contexts.extend(self.voc.encode(BOS))
        for pos_id, word in enumerate(words):
            if word.startswith('abbr|'):
                pair = word.split('|')
                abbr = pair[1]
                if abbr in self.abbrs_filterout:
                    continue
                sense = pair[2]

                if 'add_abbr' in self.model_config.voc_process:
                    wid = self.voc.encode(abbr)
                else:
                    wid = self.voc.encode(NONTAR)
                if abbr not in self.abbrs_pos:
                    if abbr not in self.abbr2id:
                        continue
                    abbr_id = self.abbr2id[abbr]
                    if abbr_id not in checker and len(
                            targets) < self.model_config.max_abbrs:
                        if abbr + '|' + sense in self.sense2id:
                            sense_id = self.sense2id[abbr + '|' + sense]
                            targets.append([pos_id, abbr_id, sense_id])
                            checker.add(abbr_id)
            else:
                wid = self.voc.encode(word)
            contexts.extend(wid)
        contexts.extend(self.voc.encode(EOS))

        if len(contexts) > self.model_config.max_context_len:
            contexts = contexts[:self.model_config.max_context_len]
        else:
            num_pad = self.model_config.max_context_len - len(contexts)
            contexts.extend(self.voc.encode(PAD) * num_pad)
        assert len(contexts) == self.model_config.max_context_len

        if len(targets) > self.model_config.max_abbrs:
            targets = targets[:self.model_config.max_abbrs]
        else:
            num_pad = self.model_config.max_abbrs - len(targets)
            targets.extend([[0, 0, 0]] * num_pad)
        assert len(targets) == self.model_config.max_abbrs

        obj = {'contexts': contexts, 'targets': targets}
        return obj

    def populate_data(self, path):
        self.datas = []
        for line in open(path):
            obj = self.process_line(line)
            self.datas.append(obj)
Example #12
0
voc_abstr = Vocab(
    DefaultConfig(),
    vocab_path=...  # path truncated in the source listing
)
voc_kword = Vocab(
    DefaultConfig(),
    vocab_path=
    '/Users/sanqiangzhao/git/kp/keyphrase_data/kp20k_cleaned/tf_data/kword.subvoc'
)

max_abstr_len, max_kword_len, max_kword_cnt = 0, 0, 0
c_abstr_len, c_kword_len, c_kword_cnt = Counter(), Counter(), Counter()
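# Track the maxima and full length distributions, presumably to guide the choice of
# max_abstr_len / max_kword_len / max_cnt_kword in the model config.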

for line in open(PATH):
    obj = json.loads(line)
    kphrases = obj['kphrases'].split(';')
    abstr = obj['abstr']

    abstr_ids = voc_abstr.encode(abstr)
    if len(abstr_ids) > max_abstr_len:
        print(abstr)
    max_abstr_len = max(max_abstr_len, len(abstr_ids))
    c_abstr_len.update([len(abstr_ids)])

    max_kword_cnt = max(max_kword_cnt, len(kphrases))
    c_kword_cnt.update([len(kphrases)])

    for kphrase in kphrases:
        max_kword_len = max(max_kword_len, len(voc_kword.encode(kphrase)))
        c_kword_len.update([len(voc_kword.encode(kphrase))])

print(max_abstr_len)
print(max_kword_len)
print(max_kword_cnt)
Example #13
0
File: data.py  Project: Astroneko404/wsd
class Data:
    def __init__(self, model_config):
        self.model_config = model_config
        # For Abbr
        self.populate_abbr()
        # For Context
        self.voc = Vocab(model_config, model_config.voc_file)

        if 'stype' in model_config.extra_mode or 'def' in model_config.extra_mode:
            self.populate_cui()

    def populate_abbr(self):
        self.id2abbr = [
            abbr.strip()
            for abbr in open(self.model_config.abbr_file).readlines()
        ]
        self.abbr2id = dict(zip(self.id2abbr, range(len(self.id2abbr))))
        self.id2sense = [
            cui.strip()
            for cui in open(self.model_config.cui_file).readlines()
        ]
        self.sense2id = dict(zip(self.id2sense, range(len(self.id2sense))))
        self.sen_cnt = len(self.id2sense)

        if self.model_config.extra_mode:
            self.id2abbr.append(con)  # 'con' is presumably a constant token defined elsewhere in the project

    def populate_cui(self):
        self.stype2id, self.id2stype = {}, []
        self.id2stype = [
            stype.split('\t')[0].lower()
            for stype in open(self.model_config.stype_voc_file).readlines()
        ]
        self.id2stype.append('unk')
        self.stype2id = dict(zip(self.id2stype, range(len(self.id2stype))))

        self.cui2stype = {}
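        # cui_extra appears to map each CUI to extra info where info[0] is its textual
        # definition and info[1] its semantic type (looked up in stype2id below).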
        with open(self.model_config.cui_extra_pkl, 'rb') as cui_file:
            cui_extra = pickle.load(cui_file)
            for cui in cui_extra:
                info = cui_extra[cui]
                self.cui2stype[cui] = self.stype2id[info[1].lower()]

        self.cui2def = {}
        with open(self.model_config.cui_extra_pkl, 'rb') as cui_file:
            cui_extra = pickle.load(cui_file)
            for cui in cui_extra:
                info = cui_extra[cui]
                if self.model_config.subword_vocab_size <= 0:
                    definition = [
                        self.voc.encode(w)
                        for w in info[0].lower().strip().split()
                    ]
                else:
                    definition = self.voc.encode(info[0].lower().strip())

                if len(definition) > self.model_config.max_def_len:
                    definition = definition[:self.model_config.max_def_len]
                else:
                    num_pad = self.model_config.max_def_len - len(definition)
                    if self.model_config.subword_vocab_size <= 0:
                        definition.extend([self.voc.encode(PAD)] * num_pad)
                    else:
                        definition.extend(self.voc.encode(PAD) * num_pad)
                assert len(definition) == self.model_config.max_def_len
                self.cui2def[cui] = definition

        np_mask = np.loadtxt(self.model_config.abbr_mask_file)
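        # abbr_mask_file appears to hold a binary (num_abbrs x num_senses) matrix; invert it
        # so that each sense (CUI) id maps to the abbreviation ids it can belong to.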
        self.cuiud2abbrid = defaultdict(list)
        for abbrid in range(len(self.id2abbr)):
            cuiids = list(np.where(np_mask[abbrid] == 1)[0])
            for cuiid in cuiids:
                self.cuiud2abbrid[cuiid].append(abbrid)

    def process_line(self, line, line_id, inst_id):
        '''
        Process one line and return examples; a single line may contain multiple
        abbreviation labels, each of which yields one example.
        :param line: raw text line to process
        :param line_id: index of the line within the input file
        :param inst_id: running global instance counter
        :return: (examples, updated inst_id)
        '''
        contexts = []
        targets = []
        words = line.split()

        # if self.model_config.subword_vocab_size <= 0:
        #     contexts.append(self.voc.encode(BOS))
        # else:
        #     contexts.extend(self.voc.encode(BOS))

        for pos_id, word in enumerate(words):
            if word.startswith('abbr|'):
                abbr, sense, long_form = dataset_helper.process_abbr_token(
                    word)

                if 'add_abbr' in self.model_config.voc_process:
                    wid = self.voc.encode(abbr)
                else:
                    wid = self.voc.encode(NONTAR)

                # Fall back to abbr_id / sense_id 0 if the abbreviation or sense is not in our vocab/inventory
                if abbr not in self.abbr2id:
                    # print('abbr %s not found in abbr vocab (size=%d), ignore this data example'
                    #       % (abbr, len(self.abbr2id)))
                    abbr_id = 0
                else:
                    abbr_id = self.abbr2id[abbr]

                if sense in self.sense2id:
                    sense_id = self.sense2id[sense]
                else:
                    sense_id = 0
                    # print('sense %s is not in sense inventory (size=%d), ignore this data example'
                    #       % (sense, len(self.sense2id)))

                # return each target as a dict instead of a list
                targets.append({
                    'pos_id': pos_id,
                    'abbr_id': abbr_id,
                    'abbr': abbr,
                    'sense_id': sense_id,
                    'sense': sense,
                    'long_form': long_form,
                    'line_id': line_id,
                    'inst_id': inst_id
                })
                # targets.append([pos_id, abbr_id, sense_id, line_id, inst_id])
                # targets.append([id, abbr_id, sense_id, line_id, inst_id, longform_tokens])
                inst_id += 1  # global instance id increment
            else:
                wid = self.voc.encode(word)

            if self.model_config.subword_vocab_size <= 0:
                contexts.append(wid)
            else:
                contexts.extend(wid)

        # if self.model_config.subword_vocab_size <= 0:
        #     contexts.append(self.voc.encode(EOS))
        # else:
        #     contexts.extend(self.voc.encode(EOS))

        examples = []
        window_size = int(self.model_config.max_context_len / 2)
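        # Build a window of up to max_context_len tokens roughly centred on each labeled
        # abbreviation; if the left side is short, the window is extended to the right.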
        for target in targets:
            pos_id = target['pos_id']
            extend_size = 0
            if pos_id < window_size:
                left_idx = 0
                extend_size = window_size - pos_id
            else:
                left_idx = pos_id - window_size

            if pos_id + window_size > len(contexts):
                right_idx = len(contexts)
            else:
                right_idx = min(pos_id + window_size + extend_size,
                                len(contexts))

            cur_contexts = contexts[left_idx:right_idx]

            if len(cur_contexts) > self.model_config.max_context_len:
                cur_contexts = cur_contexts[:self.model_config.max_context_len]
            else:
                num_pad = self.model_config.max_context_len - len(cur_contexts)
                if self.model_config.subword_vocab_size <= 0:
                    cur_contexts.extend([self.voc.encode(PAD)] * num_pad)
                else:
                    cur_contexts.extend(self.voc.encode(PAD) * num_pad)
            assert len(cur_contexts) == self.model_config.max_context_len

            example = {
                'contexts': cur_contexts,
                'target': target,
                'line': line,
            }

            examples.append(example)

        return examples, inst_id

    def populate_data(self, path):
        self.datas = []
        line_id = 0
        inst_id = 0
        for line in open(path):
            objs, inst_id = self.process_line(line, line_id, inst_id)
            self.datas.extend(objs)
            line_id += 1
            if line_id % 10000 == 0:
                print('Processed %s lines.' % line_id)
        print('Finished processing; total instances: %s' % inst_id)
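
# Minimal usage sketch (hypothetical; 'train_file' is an assumed config attribute):
#   data = Data(model_config)
#   data.populate_data(model_config.train_file)
#   print(len(data.datas), 'examples loaded')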